feat: adds KNN Regressor
This commit is contained in:
+2
-2
@@ -36,14 +36,14 @@
|
|||||||
//!
|
//!
|
||||||
//! Each category is assigned to a separate module.
|
//! Each category is assigned to a separate module.
|
||||||
//!
|
//!
|
||||||
//! For example, KNN classifier is defined in [smartcore::neighbors::knn](neighbors/knn/index.html). To train and run it using standard Rust vectors you will
|
//! For example, KNN classifier is defined in [smartcore::neighbors::knn_classifier](neighbors/knn_classifier/index.html). To train and run it using standard Rust vectors you will
|
||||||
//! run this code:
|
//! run this code:
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
//! // DenseMatrix defenition
|
//! // DenseMatrix defenition
|
||||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||||
//! // KNNClassifier
|
//! // KNNClassifier
|
||||||
//! use smartcore::neighbors::knn::*;
|
//! use smartcore::neighbors::knn_classifier::*;
|
||||||
//! // Various distance metrics
|
//! // Various distance metrics
|
||||||
//! use smartcore::math::distance::*;
|
//! use smartcore::math::distance::*;
|
||||||
//!
|
//!
|
||||||
|
|||||||
@@ -1,17 +1,10 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
|
||||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
|
||||||
use crate::linalg::{row_iter, Matrix};
|
use crate::linalg::{row_iter, Matrix};
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
pub enum KNNAlgorithmName {
|
|
||||||
LinearSearch,
|
|
||||||
CoverTree,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KNNClassifierParameters {
|
pub struct KNNClassifierParameters {
|
||||||
pub algorithm: KNNAlgorithmName,
|
pub algorithm: KNNAlgorithmName,
|
||||||
@@ -26,12 +19,6 @@ pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
|||||||
k: usize,
|
k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
|
||||||
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
|
||||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
|
||||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for KNNClassifierParameters {
|
impl Default for KNNClassifierParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
KNNClassifierParameters {
|
KNNClassifierParameters {
|
||||||
@@ -41,30 +28,6 @@ impl Default for KNNClassifierParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KNNAlgorithmName {
|
|
||||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
|
||||||
&self,
|
|
||||||
data: Vec<Vec<T>>,
|
|
||||||
distance: D,
|
|
||||||
) -> KNNAlgorithm<T, D> {
|
|
||||||
match *self {
|
|
||||||
KNNAlgorithmName::LinearSearch => {
|
|
||||||
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
|
|
||||||
}
|
|
||||||
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
|
||||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
|
||||||
match *self {
|
|
||||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
|
||||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.classes.len() != other.classes.len()
|
if self.classes.len() != other.classes.len()
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::neighbors::{KNNAlgorithmName, KNNAlgorithm};
|
||||||
|
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||||
|
use crate::math::distance::Distance;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct KNNRegressorParameters {
|
||||||
|
pub algorithm: KNNAlgorithmName,
|
||||||
|
pub k: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub struct KNNRegressor<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
|
y: Vec<T>,
|
||||||
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
|
k: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for KNNRegressorParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
KNNRegressorParameters {
|
||||||
|
algorithm: KNNAlgorithmName::CoverTree,
|
||||||
|
k: 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.k != other.k || self.y.len() != other.y.len(){
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
for i in 0..self.y.len() {
|
||||||
|
if (self.y[i] - other.y[i]).abs() > T::epsilon() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||||
|
pub fn fit<M: Matrix<T>>(
|
||||||
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
distance: D,
|
||||||
|
parameters: KNNRegressorParameters,
|
||||||
|
) -> KNNRegressor<T, D> {
|
||||||
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
|
let (_, y_n) = y_m.shape();
|
||||||
|
let (x_n, _) = x.shape();
|
||||||
|
|
||||||
|
let data = row_iter(x).collect();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
x_n == y_n,
|
||||||
|
format!(
|
||||||
|
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||||
|
x_n, y_n
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
parameters.k > 1,
|
||||||
|
format!("k should be > 1, k=[{}]", parameters.k)
|
||||||
|
);
|
||||||
|
|
||||||
|
KNNRegressor {
|
||||||
|
y: y.to_vec(),
|
||||||
|
k: parameters.k,
|
||||||
|
knn_algorithm: parameters.algorithm.fit(data, distance),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||||
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
|
row_iter(x)
|
||||||
|
.enumerate()
|
||||||
|
.for_each(|(i, x)| result.set(0, i, self.predict_for_row(x)));
|
||||||
|
|
||||||
|
result.to_row_vector()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row(&self, x: Vec<T>) -> T {
|
||||||
|
let idxs = self.knn_algorithm.find(&x, self.k);
|
||||||
|
let mut result = T::zero();
|
||||||
|
for i in idxs {
|
||||||
|
result = result + self.y[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
result / T::from_usize(self.k).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn knn_fit_predict() {
|
||||||
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
|
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||||
|
let knn = KNNRegressor::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
Distances::euclidian(),
|
||||||
|
KNNRegressorParameters {
|
||||||
|
k: 3,
|
||||||
|
algorithm: KNNAlgorithmName::LinearSearch,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
let y_hat = knn.predict(&x);
|
||||||
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
|
for i in 0..y_hat.len() {
|
||||||
|
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
let y = vec![1., 2., 3., 4., 5.];
|
||||||
|
|
||||||
|
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||||
|
|
||||||
|
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(knn, deserialized_knn);
|
||||||
|
}
|
||||||
|
}
|
||||||
+47
-1
@@ -1 +1,47 @@
|
|||||||
pub mod knn;
|
//! # Nearest Neighbors
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||||
|
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||||
|
use crate::math::distance::Distance;
|
||||||
|
use crate::math::num::FloatExt;
|
||||||
|
|
||||||
|
///
|
||||||
|
pub mod knn_classifier;
|
||||||
|
pub mod knn_regressor;
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
pub enum KNNAlgorithmName {
|
||||||
|
LinearSearch,
|
||||||
|
CoverTree,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
|
enum KNNAlgorithm<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||||
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
|
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KNNAlgorithmName {
|
||||||
|
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(
|
||||||
|
&self,
|
||||||
|
data: Vec<Vec<T>>,
|
||||||
|
distance: D,
|
||||||
|
) -> KNNAlgorithm<T, D> {
|
||||||
|
match *self {
|
||||||
|
KNNAlgorithmName::LinearSearch => {
|
||||||
|
KNNAlgorithm::LinearSearch(LinearKNNSearch::new(data, distance))
|
||||||
|
}
|
||||||
|
KNNAlgorithmName::CoverTree => KNNAlgorithm::CoverTree(CoverTree::new(data, distance)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||||
|
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize> {
|
||||||
|
match *self {
|
||||||
|
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||||
|
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user