More KNN experiments

This commit is contained in:
Volodymyr Orlov
2019-09-05 18:34:50 -07:00
parent a9ec6dfcd0
commit 9c5f6eb307
5 changed files with 118 additions and 63 deletions
+91 -46
View File
@@ -1,31 +1,84 @@
use super::Classifier;
use crate::math::distance::Distance;
use crate::math::distance::euclidian::EuclidianDistance;
use crate::algorithm::sort::heap_select::HeapSelect;
use crate::common::AnyNumber;
use ndarray::prelude::*;
use num_traits::Signed;
use num_traits::{Float, Num};
use std::marker::PhantomData;
use ndarray::{ArrayBase, Data, Ix1, Ix2};
use num_traits::{Float};
use std::cmp::{Ordering, PartialOrd};
use std::fmt::Debug;
use ndarray::arr1;
pub struct KNNClassifier<E> {
y: Option<Array1<E>>
/// K-nearest-neighbours model over rows of type `Array1<X>` with labels of type `Y`.
///
/// `F` is the caller-supplied distance function between two rows; it returns an `f64`.
pub struct KNNClassifier<X, Y, F>
where
X: AnyNumber,
Y: AnyNumber,
F: Fn(&Array1<X>, &Array1<X>) -> f64
{
// Training labels; `predict` indexes this vector with the indices returned by `knn_algorithm.find`.
y: Vec<Y>,
// Distance function passed by reference to `knn_algorithm.find` on every query.
distance: F,
// Number of neighbours requested per query.
k: usize,
// Boxed nearest-neighbour search implementation; `fit` builds it from the training rows.
knn_algorithm: Box<KNNAlgorithm<Array1<X>, F>>
}
/// Nearest-neighbour lookup over items of type `T`.
pub trait KNNAlgorithm<T: Clone + Debug>{
/// Takes a query item `from` and a neighbour count `k`, and returns references to items of type `T`.
/// NOTE(review): result ordering (presumably nearest first) is decided by the implementor — confirm.
fn find(&self, from: &T, k: usize) -> Vec<&T>;
}
pub struct SimpleKNNAlgorithm<T, D: Distance<T>>
impl<X, Y, F> KNNClassifier<X, Y, F>
where
X: AnyNumber,
Y: AnyNumber,
F: Fn(&Array1<X>, &Array1<X>) -> f64
{
data: Vec<T>,
distance: D
pub fn fit<SX: Data<Elem = X>, SY: Data<Elem = Y>>(x: &ArrayBase<SX, Ix2>, y: &ArrayBase<SY, Ix1>, k: usize, distance: F) -> KNNClassifier<X, Y, F> {
assert!(ArrayBase::shape(x)[0] == ArrayBase::shape(y)[0], format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", ArrayBase::shape(x)[0], ArrayBase::shape(y)[0]));
assert!(k > 1, format!("k should be > 1, k=[{}]", k));
let v: Vec<Array1<X>> = x.outer_iter().map(|x| x.to_owned()).collect();
let knn = Box::new(SimpleKNNAlgorithm{
data: v
});
KNNClassifier{y: y.to_owned().to_vec(), k: k, distance: distance, knn_algorithm: knn}
}
}
impl<T: Clone + Debug, D: Distance<T>> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, D>
impl<X, Y, SX, F> Classifier<X, Y, SX> for KNNClassifier<X, Y, F>
where
    X: AnyNumber,
    Y: AnyNumber,
    SX: Data<Elem = X>,
    F: Fn(&Array1<X>, &Array1<X>) -> f64
{
    /// Predicts one value per outer row of `x`.
    ///
    /// For each row the search backend is asked for `self.k` neighbour indices; the labels
    /// stored at those indices are summed and divided by the number of indices returned,
    /// using `Y`'s own division.
    fn predict(&self, x: &ArrayBase<SX, Ix2>) -> Array1<Y> {
        let predictions: Vec<Y> = x
            .outer_iter()
            .map(|row| {
                let neighbours = self.knn_algorithm.find(&row.to_owned(), self.k, &self.distance);
                // Accumulate the label sum and the neighbour count in a single pass.
                let (total, n) = neighbours
                    .into_iter()
                    .fold((Y::zero(), 0u64), |(acc, n), i| (acc + self.y[i].clone(), n + 1));
                total / Y::from_u64(n).unwrap()
            })
            .collect();
        arr1(&predictions)
    }
}
/// Nearest-neighbour lookup over items of type `T`, with the distance function `F` supplied per call.
pub trait KNNAlgorithm<T: Clone, F: Fn(&T, &T) -> f64>{
/// Takes a query item `from`, a neighbour count `k` and a distance function `d`;
/// returns `usize` indices identifying the selected items.
/// NOTE(review): result ordering (presumably nearest first) is decided by the implementor — confirm.
fn find(&self, from: &T, k: usize, d: &F) -> Vec<usize>;
}
pub struct SimpleKNNAlgorithm<T>
{
fn find(&self, from: &T, k: usize) -> Vec<&T> {
data: Vec<T>
}
impl<T: Clone, F: Fn(&T, &T) -> f64> KNNAlgorithm<T, F> for SimpleKNNAlgorithm<T>
{
fn find(&self, from: &T, k: usize, d: &F) -> Vec<usize> {
if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)");
}
@@ -41,7 +94,7 @@ impl<T: Clone + Debug, D: Distance<T>> KNNAlgorithm<T> for SimpleKNNAlgorithm<T,
for i in 0..self.data.len() {
let d = D::distance(&from, &self.data[i]);
let d = d(&from, &self.data[i]);
let datum = heap.peek_mut();
if d < datum.distance {
datum.distance = d;
@@ -52,7 +105,7 @@ impl<T: Clone + Debug, D: Distance<T>> KNNAlgorithm<T> for SimpleKNNAlgorithm<T,
heap.sort();
heap.get().into_iter().flat_map(|x| x.index).map(|i| &self.data[i]).collect()
heap.get().into_iter().flat_map(|x| x.index).collect()
}
}
@@ -76,24 +129,11 @@ impl PartialEq for KNNPoint {
impl Eq for KNNPoint {}
/// Placeholder `Classifier` implementation: neither method reads the contents of its inputs.
impl<A1, A2> Classifier<A1, A2> for KNNClassifier<A2>
where
A2: Signed + Clone,
{
/// Stores a zero-filled array with as many entries as `y`; `x` and the values of `y` are not used.
fn fit(&mut self, x: &Array2<A1>, y: &Array1<A2>){
self.y = Some(Array1::<A2>::zeros(ArrayBase::len(y)));
}
/// Returns a zero-filled array whose length equals that of the array stored by `fit`; `x` is ignored.
/// Panics if `self.y` is `None`, i.e. nothing has been stored yet.
fn predict(&self, x: &Array2<A1>) -> Array1<A2>{
let array = Array1::<A2>::zeros(ArrayBase::len(self.y.as_ref().unwrap()));
array
}
}
#[cfg(test)]
mod tests {
use super::*;
use super::*;
use crate::math::distance::Distance;
use crate::math::distance::euclidian::EuclidianDistance;
struct SimpleDistance{}
@@ -104,23 +144,28 @@ mod tests {
}
#[test]
fn knn_fit_predict() {
let mut knn = KNNClassifier{y: None};
let x = arr2(&[[1, 2, 3],[4, 5, 6]]);
let y = arr1(&[1, 2]);
knn.fit(&x, &y);
fn knn_fit_predict() {
let x = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]);
let y = arr1(&[1, 2, 3, 4, 5]);
let knn = KNNClassifier::fit(&x, &y, 3, EuclidianDistance::distance);
let r = knn.predict(&x);
assert_eq!(2, ArrayBase::len(&r));
assert_eq!(5, ArrayBase::len(&r));
assert_eq!(arr1(&[2, 2, 3, 4, 4]), r);
}
#[test]
fn knn_find() {
let sKnn = SimpleKNNAlgorithm{
data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
distance: SimpleDistance{}
};
let simple_knn = SimpleKNNAlgorithm{
data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
};
assert_eq!(vec!(&2, &3, &1), sKnn.find(&2, 3));
assert_eq!(vec!(1, 2, 0), simple_knn.find(&2, 3, &SimpleDistance::distance));
let knn2 = SimpleKNNAlgorithm{
data: vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5]))
};
assert_eq!(vec!(2, 3, 1), knn2.find(&arr1(&[3, 3]), 3, &EuclidianDistance::distance));
}
#[test]
+9 -8
View File
@@ -1,14 +1,15 @@
use ndarray::prelude::*;
use ndarray::{arr1, arr2};
use ndarray::FixedInitializer;
use crate::common::AnyNumber;
use ndarray::{Array1, ArrayBase, Data, Ix2};
pub mod knn;
pub trait Classifier<E1, E2>
{
pub trait Classifier<X, Y, SX>
where
X: AnyNumber,
Y: AnyNumber,
SX: Data<Elem = X>
{
fn fit(&mut self, x: &Array2<E1>, y: &Array1<E2>);
fn predict(&self, x: &Array2<E1>) -> Array1<E2>;
fn predict(&self, x: &ArrayBase<SX, Ix2>) -> Array1<Y>;
}
+7
View File
@@ -0,0 +1,7 @@
use num_traits::{Num, ToPrimitive, FromPrimitive};
use ndarray::{ScalarOperand};
/// Shorthand bound bundling `Num`, `ScalarOperand`, `ToPrimitive` and `FromPrimitive`.
pub trait AnyNumber: Num + ScalarOperand + ToPrimitive + FromPrimitive{}
// Blanket impl: every type that satisfies the four bounds is automatically an `AnyNumber`.
impl<T> AnyNumber for T where T: Num + ScalarOperand + ToPrimitive + FromPrimitive {}
+2 -1
View File
@@ -1,4 +1,5 @@
pub mod classification;
pub mod math;
pub mod error;
pub mod algorithm;
pub mod algorithm;
pub mod common;
+9 -8
View File
@@ -1,13 +1,12 @@
use super::Distance;
use ndarray::{ArrayBase, Data, Dimension};
use num_traits::{Num, ToPrimitive};
use ndarray::{ScalarOperand};
use crate::common::AnyNumber;
/// Field-less type that carries the `Distance` implementation for `ArrayBase` values.
pub struct EuclidianDistance{}
impl<A, S, D> Distance<ArrayBase<S, D>> for EuclidianDistance
where
A: Num + ScalarOperand + ToPrimitive,
A: AnyNumber,
S: Data<Elem = A>,
D: Dimension
{
@@ -25,17 +24,19 @@ where
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr1, Array};
use ndarray::arr1;
#[test]
fn measure_simple_euclidian_distance() {
let a = arr1(&[1, 2, 3]);
let b = arr1(&[4, 5, 6]);
let b = arr1(&[4, 5, 6]);
let d = EuclidianDistance::distance(&a, &b);
let d_arr = EuclidianDistance::distance(&a, &b);
let d_view = EuclidianDistance::distance(&a.view(), &b.view());
assert!((d - 5.19615242).abs() < 1e-8);
}
assert!((d_arr - 5.19615242).abs() < 1e-8);
assert!((d_view - 5.19615242).abs() < 1e-8);
}
#[test]
fn measure_simple_euclidian_distance_static() {