fix: formatting

This commit is contained in:
Volodymyr Orlov
2020-08-27 14:17:49 -07:00
parent e5b412451f
commit 762bc3d765
3 changed files with 11 additions and 12 deletions
+1 -1
View File
@@ -1,9 +1,9 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
use crate::linalg::{row_iter, Matrix}; use crate::linalg::{row_iter, Matrix};
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct KNNClassifierParameters { pub struct KNNClassifierParameters {
+8 -9
View File
@@ -1,10 +1,9 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
use crate::linalg::{row_iter, BaseVector, Matrix}; use crate::linalg::{row_iter, BaseVector, Matrix};
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct KNNRegressorParameters { pub struct KNNRegressorParameters {
@@ -13,7 +12,7 @@ pub struct KNNRegressorParameters {
} }
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> { pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
y: Vec<T>, y: Vec<T>,
knn_algorithm: KNNAlgorithm<T, D>, knn_algorithm: KNNAlgorithm<T, D>,
k: usize, k: usize,
@@ -30,7 +29,7 @@ impl Default for KNNRegressorParameters {
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> { impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.k != other.k || self.y.len() != other.y.len(){ if self.k != other.k || self.y.len() != other.y.len() {
return false; return false;
} else { } else {
for i in 0..self.y.len() { for i in 0..self.y.len() {
@@ -56,7 +55,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
let (x_n, _) = x.shape(); let (x_n, _) = x.shape();
let data = row_iter(x).collect(); let data = row_iter(x).collect();
assert!( assert!(
x_n == y_n, x_n == y_n,
format!( format!(
@@ -68,9 +67,9 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
assert!( assert!(
parameters.k > 1, parameters.k > 1,
format!("k should be > 1, k=[{}]", parameters.k) format!("k should be > 1, k=[{}]", parameters.k)
); );
KNNRegressor { KNNRegressor {
y: y.to_vec(), y: y.to_vec(),
k: parameters.k, k: parameters.k,
knn_algorithm: parameters.algorithm.fit(data, distance), knn_algorithm: parameters.algorithm.fit(data, distance),
@@ -88,10 +87,10 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
} }
fn predict_for_row(&self, x: Vec<T>) -> T { fn predict_for_row(&self, x: Vec<T>) -> T {
let idxs = self.knn_algorithm.find(&x, self.k); let idxs = self.knn_algorithm.find(&x, self.k);
let mut result = T::zero(); let mut result = T::zero();
for i in idxs { for i in idxs {
result = result + self.y[i]; result = result + self.y[i];
} }
result / T::from_usize(self.k).unwrap() result / T::from_usize(self.k).unwrap()
+2 -2
View File
@@ -1,12 +1,12 @@
//! # Nearest Neighbors //! # Nearest Neighbors
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch; use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::FloatExt; use crate::math::num::FloatExt;
use serde::{Deserialize, Serialize};
/// ///
pub mod knn_classifier; pub mod knn_classifier;
pub mod knn_regressor; pub mod knn_regressor;