Files
smartcore/src/classification/knn.rs
Volodymyr Orlov f4d3a80490 Initial commit
2019-05-28 17:46:03 -07:00

76 lines
1.8 KiB
Rust

use super::Classifier;
use super::super::math::distance::Distance;
use super::super::math::distance::euclidian::EuclidianDistance;
use ndarray::prelude::*;
use num_traits::Signed;
use num_traits::Float;
use std::marker::PhantomData;
/// k-nearest-neighbours classifier producing labels of type `E`.
///
/// `y` is `None` until `fit` populates it.
#[derive(Debug, Clone)]
pub struct KNNClassifier<E> {
    y: Option<Array1<E>>,
}

impl<E> KNNClassifier<E> {
    /// Creates an unfitted classifier.
    ///
    /// The `y` field is private, so this constructor (or [`Default`]) is the
    /// only way to obtain a `KNNClassifier` outside this module.
    pub fn new() -> Self {
        KNNClassifier { y: None }
    }
}

// Implemented by hand (not derived) so that no `E: Default` bound is imposed.
impl<E> Default for KNNClassifier<E> {
    fn default() -> Self {
        Self::new()
    }
}
/// Neighbour-search strategy over a dataset of points of type `T`.
///
/// `find` receives a query point `from` and a requested neighbour count `k`
/// and hands back a borrowed collection of points owned by the implementor.
pub trait KNNAlgorithm<T> {
    /// Returns the neighbours of `from`.
    ///
    /// NOTE(review): returning at most `k` points nearest to `from` is the
    /// presumed intent; the signature itself does not enforce it.
    fn find(&self, from: &T, k: i32) -> &Vec<T>;
}
/// Neighbour searcher holding an in-memory dataset and a distance metric.
///
/// `A` is the numeric type parameter of `Distance<T, A>` (presumably the type
/// of the computed distance). It is carried only at the type level through
/// `PhantomData`.
pub struct SimpleKNNAlgorithm<T, A, D>
where
    A: Float,
    D: Distance<T, A>,
{
    data: Vec<T>,
    distance: D,
    __phantom: PhantomData<A>,
}

impl<T, A, D> SimpleKNNAlgorithm<T, A, D>
where
    A: Float,
    D: Distance<T, A>,
{
    /// Creates a searcher over `data` that uses `distance` as its metric.
    ///
    /// All fields are private, so this is the only way to construct the type
    /// outside this module.
    pub fn new(data: Vec<T>, distance: D) -> Self {
        SimpleKNNAlgorithm {
            data,
            distance,
            __phantom: PhantomData,
        }
    }
}
impl<T, A, D> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, A, D>
where
    A: Float,
    D: Distance<T, A>,
{
    /// Returns a borrow of the whole stored dataset, in stored order.
    ///
    /// NOTE(review): placeholder implementation. The query point and `k` are
    /// ignored, and the `distance` metric is not consulted.
    fn find(&self, _from: &T, _k: i32) -> &Vec<T> {
        &self.data
    }
}
impl<A1, A2> Classifier<A1, A2> for KNNClassifier<A2>
where
    A2: Signed + Clone,
{
    /// Records the training labels `y`.
    ///
    /// The feature matrix is not retained (the struct has no field for it),
    /// so `predict` cannot yet perform a real neighbour vote.
    fn fit(&mut self, _x: &Array2<A1>, y: &Array1<A2>) {
        // Keep the actual labels instead of a zero-filled array of equal length.
        self.y = Some(y.to_owned());
    }

    /// Returns one prediction per row of `x`.
    ///
    /// NOTE(review): placeholder. Every prediction is zero until a neighbour
    /// vote is implemented.
    ///
    /// # Panics
    ///
    /// Panics if called before `fit`.
    fn predict(&self, x: &Array2<A1>) -> Array1<A2> {
        assert!(
            self.y.is_some(),
            "KNNClassifier::predict called before fit"
        );
        // The output length follows the number of query rows, not the size
        // of the training set.
        Array1::<A2>::zeros(x.shape()[0])
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `predict` must yield one value per input row after `fit`.
    #[test]
    fn knn_fit_predict() {
        let mut knn = KNNClassifier { y: None };
        let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
        let y = arr1(&[1, 2]);
        knn.fit(&x, &y);
        let r = knn.predict(&x);
        assert_eq!(2, r.len());
    }

    /// The current `find` returns the full stored dataset.
    #[test]
    fn knn_find() {
        let knn_algorithm = SimpleKNNAlgorithm {
            data: vec![arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])],
            distance: EuclidianDistance {},
            __phantom: PhantomData,
        };
        assert_eq!(
            &vec![arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])],
            knn_algorithm.find(&arr1(&[1., 2.]), 3)
        );
    }
}