diff --git a/src/classification/knn.rs b/src/classification/knn.rs index 073de53..1dad6d5 100644 --- a/src/classification/knn.rs +++ b/src/classification/knn.rs @@ -1,31 +1,84 @@ use super::Classifier; -use crate::math::distance::Distance; -use crate::math::distance::euclidian::EuclidianDistance; use crate::algorithm::sort::heap_select::HeapSelect; +use crate::common::AnyNumber; use ndarray::prelude::*; -use num_traits::Signed; -use num_traits::{Float, Num}; -use std::marker::PhantomData; +use ndarray::{ArrayBase, Data, Ix1, Ix2}; +use num_traits::{Float}; use std::cmp::{Ordering, PartialOrd}; -use std::fmt::Debug; +use ndarray::arr1; -pub struct KNNClassifier { - y: Option> +pub struct KNNClassifier +where + X: AnyNumber, + Y: AnyNumber, + F: Fn(&Array1, &Array1) -> f64 +{ + y: Vec, + distance: F, + k: usize, + knn_algorithm: Box, F>> } -pub trait KNNAlgorithm{ - fn find(&self, from: &T, k: usize) -> Vec<&T>; -} - -pub struct SimpleKNNAlgorithm> +impl KNNClassifier +where + X: AnyNumber, + Y: AnyNumber, + F: Fn(&Array1, &Array1) -> f64 { - data: Vec, - distance: D + + pub fn fit, SY: Data>(x: &ArrayBase, y: &ArrayBase, k: usize, distance: F) -> KNNClassifier { + + assert!(ArrayBase::shape(x)[0] == ArrayBase::shape(y)[0], format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", ArrayBase::shape(x)[0], ArrayBase::shape(y)[0])); + + assert!(k > 1, format!("k should be > 1, k=[{}]", k)); + + let v: Vec> = x.outer_iter().map(|x| x.to_owned()).collect(); + + let knn = Box::new(SimpleKNNAlgorithm{ + data: v + }); + + KNNClassifier{y: y.to_owned().to_vec(), k: k, distance: distance, knn_algorithm: knn} + } } -impl> KNNAlgorithm for SimpleKNNAlgorithm +impl Classifier for KNNClassifier +where + X: AnyNumber, + Y: AnyNumber, + SX: Data, + F: Fn(&Array1, &Array1) -> f64 + { + + fn predict(&self, x: &ArrayBase) -> Array1 { + let mut result = Vec::new(); + for x in x.outer_iter() { + let idxs = self.knn_algorithm.find(&x.to_owned(), self.k, &self.distance); + let mut sum: Y = Y::zero(); + let mut count = 0; + for i in idxs { + sum = sum + self.y[i].to_owned(); + count += 1; + } + result.push(sum / Y::from_u64(count).unwrap()); + } + arr1(&result) + } + +} + +pub trait KNNAlgorithm f64>{ + fn find(&self, from: &T, k: usize, d: &F) -> Vec; +} + +pub struct SimpleKNNAlgorithm { - fn find(&self, from: &T, k: usize) -> Vec<&T> { + data: Vec +} + +impl f64> KNNAlgorithm for SimpleKNNAlgorithm +{ + fn find(&self, from: &T, k: usize, d: &F) -> Vec { if k < 1 || k > self.data.len() { panic!("k should be >= 1 and <= length(data)"); } @@ -41,7 +94,7 @@ impl> KNNAlgorithm for SimpleKNNAlgorithm> KNNAlgorithm for SimpleKNNAlgorithm Classifier for KNNClassifier -where - A2: Signed + Clone, - { - fn fit(&mut self, x: &Array2, y: &Array1){ - self.y = Some(Array1::::zeros(ArrayBase::len(y))); - } - - fn predict(&self, x: &Array2) -> Array1{ - let array = Array1::::zeros(ArrayBase::len(self.y.as_ref().unwrap())); - array - } - -} - #[cfg(test)] mod tests { - use super::*; + use super::*; + use crate::math::distance::Distance; + use crate::math::distance::euclidian::EuclidianDistance; struct SimpleDistance{} @@ -104,23 +144,28 @@ mod tests { } #[test] - fn knn_fit_predict() { - let mut knn = KNNClassifier{y: None}; - let x = arr2(&[[1, 2, 3],[4, 5, 6]]); - let y = arr1(&[1, 2]); - knn.fit(&x, &y); + fn knn_fit_predict() { + let x = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]); + let y = arr1(&[1, 2, 3, 4, 5]); + let knn = KNNClassifier::fit(&x, &y, 3, EuclidianDistance::distance); let r = knn.predict(&x); - assert_eq!(2, ArrayBase::len(&r)); + assert_eq!(5, ArrayBase::len(&r)); + assert_eq!(arr1(&[2, 2, 3, 4, 4]), r); } #[test] fn knn_find() { - let sKnn = SimpleKNNAlgorithm{ - data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), - distance: SimpleDistance{} - }; + let simple_knn = SimpleKNNAlgorithm{ + data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + }; - assert_eq!(vec!(&2, &3, &1), sKnn.find(&2, 3)); + assert_eq!(vec!(1, 2, 0), simple_knn.find(&2, 3, &SimpleDistance::distance)); + + let knn2 = SimpleKNNAlgorithm{ + data: vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5])) + }; + + assert_eq!(vec!(2, 3, 1), knn2.find(&arr1(&[3, 3]), 3, &EuclidianDistance::distance)); } #[test] diff --git a/src/classification/mod.rs b/src/classification/mod.rs index 1f94636..b3747d2 100644 --- a/src/classification/mod.rs +++ b/src/classification/mod.rs @@ -1,14 +1,15 @@ -use ndarray::prelude::*; -use ndarray::{arr1, arr2}; -use ndarray::FixedInitializer; +use crate::common::AnyNumber; +use ndarray::{Array1, ArrayBase, Data, Ix2}; pub mod knn; -pub trait Classifier -{ +pub trait Classifier +where + X: AnyNumber, + Y: AnyNumber, + SX: Data +{ - fn fit(&mut self, x: &Array2, y: &Array1); - - fn predict(&self, x: &Array2) -> Array1; + fn predict(&self, x: &ArrayBase) -> Array1; } \ No newline at end of file diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..773c328 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,7 @@ +use num_traits::{Num, ToPrimitive, FromPrimitive}; +use ndarray::{ScalarOperand}; + +pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive{} + + +impl AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index cc6a7c1..3203152 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod classification; pub mod math; pub mod error; -pub mod algorithm; \ No newline at end of file +pub mod algorithm; +pub mod common; \ No newline at end of file diff --git a/src/math/distance/euclidian.rs b/src/math/distance/euclidian.rs index 965413f..31bf785 100644 --- a/src/math/distance/euclidian.rs +++ b/src/math/distance/euclidian.rs @@ -1,13 +1,12 @@ use super::Distance; use ndarray::{ArrayBase, Data, Dimension}; -use num_traits::{Num, ToPrimitive}; -use ndarray::{ScalarOperand}; +use crate::common::AnyNumber; pub struct EuclidianDistance{} impl Distance> for EuclidianDistance where - A: Num + ScalarOperand + ToPrimitive, + A: AnyNumber, S: Data, D: Dimension { @@ -25,17 +24,19 @@ where #[cfg(test)] mod tests { use super::*; - use ndarray::{arr1, Array}; + use ndarray::arr1; #[test] fn measure_simple_euclidian_distance() { let a = arr1(&[1, 2, 3]); - let b = arr1(&[4, 5, 6]); + let b = arr1(&[4, 5, 6]); - let d = EuclidianDistance::distance(&a, &b); + let d_arr = EuclidianDistance::distance(&a, &b); + let d_view = EuclidianDistance::distance(&a.view(), &b.view()); - assert!((d - 5.19615242).abs() < 1e-8); - } + assert!((d_arr - 5.19615242).abs() < 1e-8); + assert!((d_view - 5.19615242).abs() < 1e-8); + } #[test] fn measure_simple_euclidian_distance_static() {