fix: code cleanup, documentation
This commit is contained in:
+37
-24
@@ -7,44 +7,60 @@ use crate::math::distance::Distance;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
knn_algorithm: KNNAlgorithmV<T, D>,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
pub enum KNNAlgorithmName {
|
||||
LinearSearch,
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
pub struct KNNClassifierParameters {
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
pub k: usize
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
}
|
||||
|
||||
impl Default for KNNClassifierParameters {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
k: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
) -> KNNAlgorithmV<T, D> {
|
||||
) -> KNNAlgorithm<T, D> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => {
|
||||
KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance))
|
||||
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
||||
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
||||
match *self {
|
||||
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k),
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -76,9 +92,8 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
k: usize,
|
||||
distance: D,
|
||||
algorithm: KNNAlgorithmName,
|
||||
parameters: KNNClassifierParameters
|
||||
) -> KNNClassifier<T, D> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -103,13 +118,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
)
|
||||
);
|
||||
|
||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||
assert!(parameters.k > 1, format!("k should be > 1, k=[{}]", parameters.k));
|
||||
|
||||
KNNClassifier {
|
||||
classes: classes,
|
||||
y: yi,
|
||||
k: k,
|
||||
knn_algorithm: algorithm.fit(data, distance),
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,9 +168,8 @@ mod tests {
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
3,
|
||||
Distances::euclidian(),
|
||||
KNNAlgorithmName::LinearSearch,
|
||||
KNNClassifierParameters{k: 3, algorithm: KNNAlgorithmName::LinearSearch}
|
||||
);
|
||||
let r = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&r));
|
||||
@@ -169,10 +183,9 @@ mod tests {
|
||||
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
3,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNAlgorithmName::CoverTree,
|
||||
Default::default()
|
||||
);
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user