fix: formatting
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
|
|
||||||
use crate::linalg::{row_iter, Matrix};
|
use crate::linalg::{row_iter, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNClassifierParameters {
|
pub struct KNNClassifierParameters {
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
|
|
||||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNRegressorParameters {
|
pub struct KNNRegressorParameters {
|
||||||
@@ -13,7 +12,7 @@ pub struct KNNRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
y: Vec<T>,
|
y: Vec<T>,
|
||||||
knn_algorithm: KNNAlgorithm<T, D>,
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
k: usize,
|
k: usize,
|
||||||
@@ -30,7 +29,7 @@ impl Default for KNNRegressorParameters {
|
|||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.k != other.k || self.y.len() != other.y.len(){
|
if self.k != other.k || self.y.len() != other.y.len() {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
for i in 0..self.y.len() {
|
for i in 0..self.y.len() {
|
||||||
@@ -56,7 +55,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
let (x_n, _) = x.shape();
|
let (x_n, _) = x.shape();
|
||||||
|
|
||||||
let data = row_iter(x).collect();
|
let data = row_iter(x).collect();
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
x_n == y_n,
|
x_n == y_n,
|
||||||
format!(
|
format!(
|
||||||
@@ -68,9 +67,9 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
assert!(
|
assert!(
|
||||||
parameters.k > 1,
|
parameters.k > 1,
|
||||||
format!("k should be > 1, k=[{}]", parameters.k)
|
format!("k should be > 1, k=[{}]", parameters.k)
|
||||||
);
|
);
|
||||||
|
|
||||||
KNNRegressor {
|
KNNRegressor {
|
||||||
y: y.to_vec(),
|
y: y.to_vec(),
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||||
@@ -88,10 +87,10 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> T {
|
fn predict_for_row(&self, x: Vec<T>) -> T {
|
||||||
let idxs = self.knn_algorithm.find(&x, self.k);
|
let idxs = self.knn_algorithm.find(&x, self.k);
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
for i in idxs {
|
for i in idxs {
|
||||||
result = result + self.y[i];
|
result = result + self.y[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
result / T::from_usize(self.k).unwrap()
|
result / T::from_usize(self.k).unwrap()
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
//! # Nearest Neighbors
|
//! # Nearest Neighbors
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
///
|
///
|
||||||
pub mod knn_classifier;
|
pub mod knn_classifier;
|
||||||
pub mod knn_regressor;
|
pub mod knn_regressor;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user