feat: adds serialization/deserialization methods
This commit is contained in:
+88
-18
@@ -1,26 +1,83 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::linalg::{Matrix, row_iter};
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
|
||||
pub struct KNNClassifier<'a, T: FloatExt> {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a>,
|
||||
k: usize,
|
||||
knn_algorithm: KNNAlgorithmV<T, D>,
|
||||
k: usize
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
pub enum KNNAlgorithmName {
|
||||
LinearSearch,
|
||||
CoverTree
|
||||
}
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: &'a dyn Fn(&Vec<T>, &Vec<T>) -> T, algorithm: KNNAlgorithmName) -> KNNClassifier<'a, T> {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum KNNAlgorithmV<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>)
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
|
||||
fn fit<T: FloatExt, D: Distance<Vec<T>, T>>(&self, data: Vec<Vec<T>>, distance: D) -> KNNAlgorithmV<T, D> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => KNNAlgorithmV::LinearSearch(LinearKNNSearch::new(data, distance)),
|
||||
KNNAlgorithmName::CoverTree => KNNAlgorithmV::CoverTree(CoverTree::new(data, distance)),
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNAlgorithmV<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Vec<usize>{
|
||||
match *self {
|
||||
KNNAlgorithmV::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithmV::CoverTree(ref cover) => cover.find(from, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() ||
|
||||
self.k != other.k ||
|
||||
self.y.len() != other.y.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.y.len() {
|
||||
if self.y[i] != other.y[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, k: usize, distance: D, algorithm: KNNAlgorithmName) -> KNNClassifier<T, D> {
|
||||
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_n) = y_m.shape();
|
||||
let (x_n, _) = x.shape();
|
||||
|
||||
let data = row_iter(x).collect();
|
||||
let data = row_iter(x).collect();
|
||||
|
||||
let mut yi: Vec<usize> = vec![0; y_n];
|
||||
let classes = y_m.unique();
|
||||
@@ -32,14 +89,9 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
|
||||
assert!(x_n == y_n, format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", x_n, y_n));
|
||||
|
||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
|
||||
|
||||
let knn_algorithm: Box<dyn KNNAlgorithm<Vec<T>> + 'a> = match algorithm {
|
||||
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<Vec<T>, T>::new(data, distance)),
|
||||
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<Vec<T>, T>::new(data, distance))
|
||||
};
|
||||
|
||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: knn_algorithm}
|
||||
KNNClassifier{classes:classes, y: yi, k: k, knn_algorithm: algorithm.fit(data, distance)}
|
||||
|
||||
}
|
||||
|
||||
@@ -74,8 +126,8 @@ impl<'a, T: FloatExt> KNNClassifier<'a, T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::Distances;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
@@ -85,10 +137,28 @@ mod tests {
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, &euclidian::distance, KNNAlgorithmName::LinearSearch);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::LinearSearch);
|
||||
let r = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&r));
|
||||
assert_eq!(y.to_vec(), r);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., 2.],
|
||||
&[3., 4.],
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, 3, Distances::euclidian(), KNNAlgorithmName::CoverTree);
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(knn, deserialized_knn);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user