diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index e22c897..553011d 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -1,9 +1,9 @@ use serde::{Deserialize, Serialize}; -use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm}; use crate::linalg::{row_iter, Matrix}; use crate::math::distance::Distance; use crate::math::num::FloatExt; +use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName}; #[derive(Serialize, Deserialize, Debug)] pub struct KNNClassifierParameters { diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 2d183f4..61df138 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -1,10 +1,9 @@ use serde::{Deserialize, Serialize}; -use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm}; use crate::linalg::{row_iter, BaseVector, Matrix}; use crate::math::distance::Distance; use crate::math::num::FloatExt; - +use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName}; #[derive(Serialize, Deserialize, Debug)] pub struct KNNRegressorParameters { @@ -13,7 +12,7 @@ pub struct KNNRegressorParameters { } #[derive(Serialize, Deserialize, Debug)] -pub struct KNNRegressor, T>> { +pub struct KNNRegressor, T>> { y: Vec, knn_algorithm: KNNAlgorithm, k: usize, @@ -30,7 +29,7 @@ impl Default for KNNRegressorParameters { impl, T>> PartialEq for KNNRegressor { fn eq(&self, other: &Self) -> bool { - if self.k != other.k || self.y.len() != other.y.len(){ + if self.k != other.k || self.y.len() != other.y.len() { return false; } else { for i in 0..self.y.len() { @@ -56,7 +55,7 @@ impl, T>> KNNRegressor { let (x_n, _) = x.shape(); let data = row_iter(x).collect(); - + assert!( x_n == y_n, format!( @@ -68,9 +67,9 @@ impl, T>> KNNRegressor { assert!( parameters.k > 1, format!("k should be > 1, k=[{}]", parameters.k) - ); + ); - KNNRegressor { + KNNRegressor { y: y.to_vec(), k: parameters.k, knn_algorithm: parameters.algorithm.fit(data, distance), @@ -88,10 +87,10 @@ impl, T>> KNNRegressor { } fn predict_for_row(&self, x: Vec) -> T { - let idxs = self.knn_algorithm.find(&x, self.k); + let idxs = self.knn_algorithm.find(&x, self.k); let mut result = T::zero(); for i in idxs { - result = result + self.y[i]; + result = result + self.y[i]; } result / T::from_usize(self.k).unwrap() diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs index b68b204..f727726 100644 --- a/src/neighbors/mod.rs +++ b/src/neighbors/mod.rs @@ -1,12 +1,12 @@ //! # Nearest Neighbors -use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::linear_search::LinearKNNSearch; use crate::math::distance::Distance; use crate::math::num::FloatExt; +use serde::{Deserialize, Serialize}; -/// +/// pub mod knn_classifier; pub mod knn_regressor;