feat: new distance function parameter in KNN, extends KNN documentation
This commit is contained in:
@@ -37,13 +37,15 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction};
|
||||
|
||||
/// Hyperparameters that configure a `KNNClassifier`. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifierParameters {
|
||||
/// Backend search algorithm used to look up nearest neighbours. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is the default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
/// Weighting function that is used to calculate the estimated class value. The default is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
/// Number of training samples to consider when estimating the class of a new point. The default value is 3.
|
||||
pub k: usize,
|
||||
|
||||
}
|
||||
@@ -54,6 +56,7 @@ pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
weight: KNNWeightFunction,
|
||||
k: usize,
|
||||
}
|
||||
|
||||
@@ -61,6 +64,7 @@ impl Default for KNNClassifierParameters {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
k: 3,
|
||||
}
|
||||
}
|
||||
@@ -90,7 +94,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
/// Fits KNN Classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// Fits KNN classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `distance` - a function that defines a distance between each pair of points in training data.
|
||||
@@ -136,6 +140,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
y: yi,
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||
weight: parameters.weight,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,15 +158,21 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: Vec<T>) -> usize {
|
||||
let idxs = self.knn_algorithm.find(&x, self.k);
|
||||
let mut c = vec![0; self.classes.len()];
|
||||
let mut max_c = 0;
|
||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||
|
||||
let weights = self
|
||||
.weight
|
||||
.calc_weights(search_result.iter().map(|v| v.1).collect());
|
||||
let w_sum = weights.iter().map(|w| *w).sum();
|
||||
|
||||
let mut c = vec![T::zero(); self.classes.len()];
|
||||
let mut max_c = T::zero();
|
||||
let mut max_i = 0;
|
||||
for i in idxs {
|
||||
c[self.y[i]] += 1;
|
||||
if c[self.y[i]] > max_c {
|
||||
max_c = c[self.y[i]];
|
||||
max_i = self.y[i];
|
||||
for (r, w) in search_result.iter().zip(weights.iter()) {
|
||||
c[self.y[r.0]] = c[self.y[r.0]] + (*w / w_sum);
|
||||
if c[self.y[r.0]] > max_c {
|
||||
max_c = c[self.y[r.0]];
|
||||
max_i = self.y[r.0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,18 +190,28 @@ mod tests {
|
||||
fn knn_fit_predict() {
|
||||
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
let y_hat = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
assert_eq!(y.to_vec(), y_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn knn_fit_predict_weighted() {
|
||||
let x = DenseMatrix::from_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNClassifierParameters {
|
||||
k: 3,
|
||||
k: 5,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
);
|
||||
let r = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&r));
|
||||
assert_eq!(y.to_vec(), r);
|
||||
let y_hat = knn.predict(&DenseMatrix::from_array(&[&[4.1]]));
|
||||
assert_eq!(vec![3.0], y_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user