feat: + cross_validate, trait Predictor, refactoring
This commit is contained in:
@@ -25,31 +25,40 @@
|
||||
//! &[9., 10.]]);
|
||||
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
||||
//!
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold a vector with estimates of class labels
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifierParameters {
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of points in training data.
|
||||
/// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
/// marker that ties the parameters to the numeric type `T`; it stores no data and is not used at runtime
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
@@ -62,12 +71,39 @@ pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Default for KNNClassifierParameters {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifierParameters<T, D> {
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// a function that defines a distance between each pair of points in training data.
|
||||
/// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance(mut self, distance: D) -> Self {
|
||||
self.distance = distance;
|
||||
self
|
||||
}
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
|
||||
self.algorithm = algorithm;
|
||||
self
|
||||
}
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,19 +131,23 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for KNNClassifier<T, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
/// Fits KNN classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNClassifierParameters,
|
||||
parameters: KNNClassifierParameters<T, D>,
|
||||
) -> Result<KNNClassifier<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -142,7 +182,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
classes,
|
||||
y: yi,
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
knn_algorithm: parameters.algorithm.fit(data, parameters.distance)?,
|
||||
weight: parameters.weight,
|
||||
})
|
||||
}
|
||||
@@ -187,14 +227,13 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
assert_eq!(y.to_vec(), y_hat);
|
||||
@@ -207,12 +246,10 @@ mod tests {
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNClassifierParameters {
|
||||
k: 5,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
KNNClassifierParameters::default()
|
||||
.with_k(5)
|
||||
.with_algorithm(KNNAlgorithmName::LinearSearch)
|
||||
.with_weight(KNNWeightFunction::Distance),
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||
@@ -225,7 +262,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
@@ -27,31 +27,41 @@
|
||||
//! &[5., 5.]]);
|
||||
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||
//!
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold a vector with the predicted values
|
||||
//!
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNRegressorParameters {
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of points in training data.
|
||||
/// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
distance: D,
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
/// weighting function that is used to calculate the estimated value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
/// number of training samples to consider when estimating the value for a new point. Default value is 3.
|
||||
pub k: usize,
|
||||
/// marker that ties the parameters to the numeric type `T`; it stores no data and is not used at runtime
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Regressor
|
||||
@@ -63,12 +73,39 @@ pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Default for KNNRegressorParameters {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressorParameters<T, D> {
|
||||
/// number of training samples to consider when estimating the value for a new point. Default value is 3.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// a function that defines a distance between each pair of points in training data.
|
||||
/// This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance(mut self, distance: D) -> Self {
|
||||
self.distance = distance;
|
||||
self
|
||||
}
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
|
||||
self.algorithm = algorithm;
|
||||
self
|
||||
}
|
||||
/// weighting function that is used to calculate the estimated value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNRegressorParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -88,19 +125,23 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for KNNRegressor<T, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
/// Fits KNN regressor to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with real values
|
||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
/// * `y` - vector with real values
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNRegressorParameters,
|
||||
parameters: KNNRegressorParameters<T, D>,
|
||||
) -> Result<KNNRegressor<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -126,7 +167,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
Ok(KNNRegressor {
|
||||
y: y.to_vec(),
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
knn_algorithm: parameters.algorithm.fit(data, parameters.distance)?,
|
||||
weight: parameters.weight,
|
||||
})
|
||||
}
|
||||
@@ -176,12 +217,11 @@ mod tests {
|
||||
let knn = KNNRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNRegressorParameters {
|
||||
k: 3,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
KNNRegressorParameters::default()
|
||||
.with_k(3)
|
||||
.with_distance(Distances::euclidian())
|
||||
.with_algorithm(KNNAlgorithmName::LinearSearch)
|
||||
.with_weight(KNNWeightFunction::Distance),
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
@@ -197,7 +237,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
for i in 0..y_hat.len() {
|
||||
@@ -211,7 +251,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ pub mod knn_regressor;
|
||||
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
||||
|
||||
/// Weight function that is used to determine estimated value.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum KNNWeightFunction {
|
||||
/// All k nearest points are weighted equally
|
||||
Uniform,
|
||||
|
||||
Reference in New Issue
Block a user