feat: new distance function parameter in KNN, extends KNN documentation
This commit is contained in:
@@ -78,7 +78,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
|||||||
node_id
|
node_id
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
|
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
||||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||||
for i in (self.min_level..self.max_level + 1).rev() {
|
for i in (self.min_level..self.max_level + 1).rev() {
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
let i_d = self.base.powf(F::from(i).unwrap());
|
||||||
@@ -92,7 +92,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
|||||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(n, _)| n.index.index)
|
.map(|(n, d)| (n.index.index, *d))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -353,12 +353,14 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut nearest_3_to_5 = tree.find(&5, 3);
|
let mut nearest_3_to_5 = tree.find(&5, 3);
|
||||||
nearest_3_to_5.sort();
|
nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||||
assert_eq!(vec!(3, 4, 5), nearest_3_to_5);
|
let nearest_3_to_5_indexes: Vec<usize> = nearest_3_to_5.iter().map(|v| v.0).collect();
|
||||||
|
assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes);
|
||||||
|
|
||||||
let mut nearest_3_to_15 = tree.find(&15, 3);
|
let mut nearest_3_to_15 = tree.find(&15, 3);
|
||||||
nearest_3_to_15.sort();
|
nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||||
assert_eq!(vec!(13, 14, 15), nearest_3_to_15);
|
let nearest_3_to_15_indexes: Vec<usize> = nearest_3_to_15.iter().map(|v| v.0).collect();
|
||||||
|
assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes);
|
||||||
|
|
||||||
assert_eq!(-1, tree.min_level);
|
assert_eq!(-1, tree.min_level);
|
||||||
assert_eq!(100, tree.max_level);
|
assert_eq!(100, tree.max_level);
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
|
||||||
if k < 1 || k > self.data.len() {
|
if k < 1 || k > self.data.len() {
|
||||||
panic!("k should be >= 1 and <= length(data)");
|
panic!("k should be >= 1 and <= length(data)");
|
||||||
}
|
}
|
||||||
@@ -48,7 +48,10 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
|
|
||||||
heap.sort();
|
heap.sort();
|
||||||
|
|
||||||
heap.get().into_iter().flat_map(|x| x.index).collect()
|
heap.get()
|
||||||
|
.into_iter()
|
||||||
|
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +94,9 @@ mod tests {
|
|||||||
|
|
||||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||||
|
|
||||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
let found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
||||||
|
|
||||||
|
assert_eq!(vec!(1, 2, 0), found_idxs1);
|
||||||
|
|
||||||
let data2 = vec![
|
let data2 = vec![
|
||||||
vec![1., 1.],
|
vec![1., 1.],
|
||||||
@@ -103,7 +108,13 @@ mod tests {
|
|||||||
|
|
||||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||||
|
|
||||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
let found_idxs2: Vec<usize> = algorithm2
|
||||||
|
.find(&vec![3., 3.], 3)
|
||||||
|
.iter()
|
||||||
|
.map(|v| v.0)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(vec!(2, 3, 1), found_idxs2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
+2
-1
@@ -1,8 +1,9 @@
|
|||||||
use num_traits::{Float, FromPrimitive};
|
use num_traits::{Float, FromPrimitive};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use std::fmt::{Debug, Display};
|
use std::fmt::{Debug, Display};
|
||||||
|
use std::iter::{Product, Sum};
|
||||||
|
|
||||||
pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
|
pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy + Sum + Product {
|
||||||
fn copysign(self, sign: Self) -> Self;
|
fn copysign(self, sign: Self) -> Self;
|
||||||
|
|
||||||
fn ln_1pe(self) -> Self;
|
fn ln_1pe(self) -> Self;
|
||||||
|
|||||||
@@ -37,13 +37,15 @@ use serde::{Deserialize, Serialize};
|
|||||||
use crate::linalg::{row_iter, Matrix};
|
use crate::linalg::{row_iter, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction};
|
||||||
|
|
||||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNClassifierParameters {
|
pub struct KNNClassifierParameters {
|
||||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||||
pub algorithm: KNNAlgorithmName,
|
pub algorithm: KNNAlgorithmName,
|
||||||
|
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||||
|
pub weight: KNNWeightFunction,
|
||||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||||
pub k: usize,
|
pub k: usize,
|
||||||
}
|
}
|
||||||
@@ -54,6 +56,7 @@ pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
|||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
knn_algorithm: KNNAlgorithm<T, D>,
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
|
weight: KNNWeightFunction,
|
||||||
k: usize,
|
k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -61,6 +64,7 @@ impl Default for KNNClassifierParameters {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
KNNClassifierParameters {
|
KNNClassifierParameters {
|
||||||
algorithm: KNNAlgorithmName::CoverTree,
|
algorithm: KNNAlgorithmName::CoverTree,
|
||||||
|
weight: KNNWeightFunction::Uniform,
|
||||||
k: 3,
|
k: 3,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -90,7 +94,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||||
/// Fits KNN Classifier to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits KNN classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data
|
/// * `x` - training data
|
||||||
/// * `y` - vector with target values (classes) of length N
|
/// * `y` - vector with target values (classes) of length N
|
||||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||||
@@ -136,6 +140,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
y: yi,
|
y: yi,
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||||
|
weight: parameters.weight,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,15 +158,21 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
||||||
let idxs = self.knn_algorithm.find(&x, self.k);
|
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||||
let mut c = vec![0; self.classes.len()];
|
|
||||||
let mut max_c = 0;
|
let weights = self
|
||||||
|
.weight
|
||||||
|
.calc_weights(search_result.iter().map(|v| v.1).collect());
|
||||||
|
let w_sum = weights.iter().map(|w| *w).sum();
|
||||||
|
|
||||||
|
let mut c = vec![T::zero(); self.classes.len()];
|
||||||
|
let mut max_c = T::zero();
|
||||||
let mut max_i = 0;
|
let mut max_i = 0;
|
||||||
for i in idxs {
|
for (r, w) in search_result.iter().zip(weights.iter()) {
|
||||||
c[self.y[i]] += 1;
|
c[self.y[r.0]] = c[self.y[r.0]] + (*w / w_sum);
|
||||||
if c[self.y[i]] > max_c {
|
if c[self.y[r.0]] > max_c {
|
||||||
max_c = c[self.y[i]];
|
max_c = c[self.y[r.0]];
|
||||||
max_i = self.y[i];
|
max_i = self.y[r.0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,18 +190,28 @@ mod tests {
|
|||||||
fn knn_fit_predict() {
|
fn knn_fit_predict() {
|
||||||
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y = vec![2., 2., 2., 3., 3.];
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
|
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||||
|
let y_hat = knn.predict(&x);
|
||||||
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
|
assert_eq!(y.to_vec(), y_hat);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn knn_fit_predict_weighted() {
|
||||||
|
let x = DenseMatrix::from_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
||||||
|
let y = vec![2., 2., 2., 3., 3.];
|
||||||
let knn = KNNClassifier::fit(
|
let knn = KNNClassifier::fit(
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
Distances::euclidian(),
|
Distances::euclidian(),
|
||||||
KNNClassifierParameters {
|
KNNClassifierParameters {
|
||||||
k: 3,
|
k: 5,
|
||||||
algorithm: KNNAlgorithmName::LinearSearch,
|
algorithm: KNNAlgorithmName::LinearSearch,
|
||||||
|
weight: KNNWeightFunction::Distance,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
let r = knn.predict(&x);
|
let y_hat = knn.predict(&DenseMatrix::from_array(&[&[4.1]]));
|
||||||
assert_eq!(5, Vec::len(&r));
|
assert_eq!(vec![3.0], y_hat);
|
||||||
assert_eq!(y.to_vec(), r);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -1,20 +1,63 @@
|
|||||||
|
//! # K Nearest Neighbors Regressor
|
||||||
|
//!
|
||||||
|
//! Regressor that predicts estimated values as a function of the k nearest neighbors.
|
||||||
|
//!
|
||||||
|
//! `KNNRegressor` relies on 2 backend algorithms to speed up KNN queries:
|
||||||
|
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
|
||||||
|
//! * [`CoverTree`](../../algorithm/neighbour/cover_tree/index.html)
|
||||||
|
//!
|
||||||
|
//! The parameter `k` controls the stability of the KNN estimate: when `k` is small the algorithm is sensitive to the noise in data. When `k` increases the estimator becomes more stable.
|
||||||
|
//! In terms of the bias variance trade-off the variance decreases with `k` and the bias is likely to increase with `k`.
|
||||||
|
//!
|
||||||
|
//! When you don't know which search algorithm and `k` value to use, go with the default parameters defined by `Default::default()`.
|
||||||
|
//!
|
||||||
|
//! To fit the model to a 5 x 2 matrix with 5 training samples, 2 features per sample:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||||
|
//! use smartcore::neighbors::knn_regressor::*;
|
||||||
|
//! use smartcore::math::distance::*;
|
||||||
|
//!
|
||||||
|
//! //your explanatory variables. Each row is a training sample with 2 numerical features
|
||||||
|
//! let x = DenseMatrix::from_array(&[
|
||||||
|
//! &[1., 1.],
|
||||||
|
//! &[2., 2.],
|
||||||
|
//! &[3., 3.],
|
||||||
|
//! &[4., 4.],
|
||||||
|
//! &[5., 5.]]);
|
||||||
|
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||||
|
//!
|
||||||
|
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||||
|
//! let y_hat = knn.predict(&x);
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! The variable `y_hat` will hold the predicted values.
|
||||||
|
//!
|
||||||
|
//!
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction};
|
||||||
|
|
||||||
|
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNRegressorParameters {
|
pub struct KNNRegressorParameters {
|
||||||
|
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||||
pub algorithm: KNNAlgorithmName,
|
pub algorithm: KNNAlgorithmName,
|
||||||
|
/// weighting function that is used to calculate the estimated value. Default function is `KNNWeightFunction::Uniform`.
|
||||||
|
pub weight: KNNWeightFunction,
|
||||||
|
/// number of training samples to consider when estimating the value for a new point. Default value is 3.
|
||||||
pub k: usize,
|
pub k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// K Nearest Neighbors Regressor
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
y: Vec<T>,
|
y: Vec<T>,
|
||||||
knn_algorithm: KNNAlgorithm<T, D>,
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
|
weight: KNNWeightFunction,
|
||||||
k: usize,
|
k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,6 +65,7 @@ impl Default for KNNRegressorParameters {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
KNNRegressorParameters {
|
KNNRegressorParameters {
|
||||||
algorithm: KNNAlgorithmName::CoverTree,
|
algorithm: KNNAlgorithmName::CoverTree,
|
||||||
|
weight: KNNWeightFunction::Uniform,
|
||||||
k: 3,
|
k: 3,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -43,6 +87,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||||
|
/// Fits KNN regressor to a NxM matrix where N is number of samples and M is number of features.
|
||||||
|
/// * `x` - training data
|
||||||
|
/// * `y` - vector with real values
|
||||||
|
/// * `distance` - a function that defines a distance between each pair of points in the training data.
|
||||||
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
|
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||||
|
/// * `parameters` - additional parameters like search algorithm and k
|
||||||
pub fn fit<M: Matrix<T>>(
|
pub fn fit<M: Matrix<T>>(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
@@ -73,9 +124,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
y: y.to_vec(),
|
y: y.to_vec(),
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||||
|
weight: parameters.weight,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict the target for the provided data.
|
||||||
|
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||||
|
/// Returns a vector of size N with estimates.
|
||||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||||
let mut result = M::zeros(1, x.shape().0);
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
@@ -87,13 +142,19 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> T {
|
fn predict_for_row(&self, x: Vec<T>) -> T {
|
||||||
let idxs = self.knn_algorithm.find(&x, self.k);
|
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
for i in idxs {
|
|
||||||
result = result + self.y[i];
|
let weights = self
|
||||||
|
.weight
|
||||||
|
.calc_weights(search_result.iter().map(|v| v.1).collect());
|
||||||
|
let w_sum = weights.iter().map(|w| *w).sum();
|
||||||
|
|
||||||
|
for (r, w) in search_result.iter().zip(weights.iter()) {
|
||||||
|
result = result + self.y[r.0] * (*w / w_sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
result / T::from_usize(self.k).unwrap()
|
result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,10 +165,10 @@ mod tests {
|
|||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict() {
|
fn knn_fit_predict_weighted() {
|
||||||
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
let y_exp = vec![1., 2., 3., 4., 5.];
|
||||||
let knn = KNNRegressor::fit(
|
let knn = KNNRegressor::fit(
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
@@ -115,6 +176,7 @@ mod tests {
|
|||||||
KNNRegressorParameters {
|
KNNRegressorParameters {
|
||||||
k: 3,
|
k: 3,
|
||||||
algorithm: KNNAlgorithmName::LinearSearch,
|
algorithm: KNNAlgorithmName::LinearSearch,
|
||||||
|
weight: KNNWeightFunction::Distance,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
let y_hat = knn.predict(&x);
|
let y_hat = knn.predict(&x);
|
||||||
@@ -124,6 +186,19 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn knn_fit_predict_uniform() {
|
||||||
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
|
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||||
|
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||||
|
let y_hat = knn.predict(&x);
|
||||||
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
|
for i in 0..y_hat.len() {
|
||||||
|
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
|||||||
+30
-1
@@ -52,12 +52,41 @@ pub enum KNNAlgorithmName {
|
|||||||
CoverTree,
|
CoverTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Weight function that is used to determine estimated value.
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub enum KNNWeightFunction {
|
||||||
|
/// All k nearest points are weighted equally
|
||||||
|
Uniform,
|
||||||
|
/// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
|
||||||
|
Distance,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl KNNWeightFunction {
|
||||||
|
fn calc_weights<T: FloatExt>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
|
||||||
|
match *self {
|
||||||
|
KNNWeightFunction::Distance => {
|
||||||
|
// if there are any points that have zero distance from one or more training points,
|
||||||
|
// those training points are weighted as 1.0 and the other points as 0.0
|
||||||
|
if distances.iter().any(|&e| e == T::zero()) {
|
||||||
|
distances
|
||||||
|
.iter()
|
||||||
|
.map(|e| if *e == T::zero() { T::one() } else { T::zero() })
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
distances.iter().map(|e| T::one() / *e).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
KNNWeightFunction::Uniform => vec![T::one(); distances.len()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl KNNAlgorithmName {
|
impl KNNAlgorithmName {
|
||||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
||||||
&self,
|
&self,
|
||||||
@@ -74,7 +103,7 @@ impl KNNAlgorithmName {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
fn find(&self, from: &Vec<T>, k: usize) -> Vec<(usize, T)> {
|
||||||
match *self {
|
match *self {
|
||||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||||
|
|||||||
Reference in New Issue
Block a user