Extends basic KNN search algorithm
This commit is contained in:
+100
-23
@@ -1,39 +1,81 @@
|
||||
use super::Classifier;
|
||||
use super::super::math::distance::Distance;
|
||||
use super::super::math::distance::euclidian::EuclidianDistance;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::EuclidianDistance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use ndarray::prelude::*;
|
||||
use num_traits::Signed;
|
||||
use num_traits::Float;
|
||||
use num_traits::{Float, Num};
|
||||
use std::marker::PhantomData;
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub struct KNNClassifier<E> {
|
||||
y: Option<Array1<E>>
|
||||
}
|
||||
|
||||
pub trait KNNAlgorithm<T>{
|
||||
fn find(&self, from: &T, k: i32) -> &Vec<T>;
|
||||
pub trait KNNAlgorithm<T: Clone + Debug>{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<&T>;
|
||||
}
|
||||
|
||||
pub struct SimpleKNNAlgorithm<T, A, D>
|
||||
where
|
||||
A: Float,
|
||||
D: Distance<T, A>
|
||||
pub struct SimpleKNNAlgorithm<T, D: Distance<T>>
|
||||
{
|
||||
data: Vec<T>,
|
||||
distance: D,
|
||||
__phantom: PhantomData<A>
|
||||
distance: D
|
||||
}
|
||||
|
||||
impl<T, A, D> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, A, D>
|
||||
where
|
||||
A: Float,
|
||||
D: Distance<T, A>
|
||||
impl<T: Clone + Debug, D: Distance<T>> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, D>
|
||||
{
|
||||
fn find(&self, from: &T, k: i32) -> &Vec<T> {
|
||||
&self.data
|
||||
fn find(&self, from: &T, k: usize) -> Vec<&T> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
|
||||
let mut heap = HeapSelect::<KNNPoint>::with_capacity(k);
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
index: None
|
||||
});
|
||||
}
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
|
||||
let d = D::distance(&from, &self.data[i]);
|
||||
let datum = heap.peek_mut();
|
||||
if d < datum.distance {
|
||||
datum.distance = d;
|
||||
datum.index = Some(i);
|
||||
heap.heapify();
|
||||
}
|
||||
}
|
||||
|
||||
heap.sort();
|
||||
|
||||
heap.get().into_iter().flat_map(|x| x.index).map(|i| &self.data[i]).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// A candidate neighbour: its distance to the query and where it lives in the data.
///
/// The type is a small plain-data record, so it is `Copy`.
#[derive(Debug, Clone, Copy)]
struct KNNPoint {
    /// Distance from the query item to this candidate.
    distance: f64,
    /// Position of the candidate in the backing data; `None` for a placeholder
    /// that does not refer to any stored item.
    index: Option<usize>,
}
|
||||
|
||||
impl PartialOrd for KNNPoint {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for KNNPoint {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl declaring KNNPoint equality to be a full equivalence relation.
// NOTE(review): equality compares the `f64` distance, and NaN != NaN, so this
// only holds while distances are never NaN — confirm that callers guarantee it.
impl Eq for KNNPoint {}
|
||||
|
||||
impl<A1, A2> Classifier<A1, A2> for KNNClassifier<A2>
|
||||
where
|
||||
A2: Signed + Clone,
|
||||
@@ -51,7 +93,15 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::*;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
/// Test-only metric: the absolute difference between two integers.
impl Distance<i32> for SimpleDistance {
    /// Returns `|a - b|` as an `f64`.
    ///
    /// Both operands are widened to `f64` before subtracting: every `i32` is
    /// exactly representable, so results are unchanged for ordinary inputs while
    /// extreme pairs (e.g. `i32::MIN` and `1`) no longer overflow.
    fn distance(a: &i32, b: &i32) -> f64 {
        (f64::from(*a) - f64::from(*b)).abs()
    }
}
|
||||
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
@@ -64,13 +114,40 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn knn_find() {
|
||||
fn knn_find() {
|
||||
let sKnn = SimpleKNNAlgorithm{
|
||||
data: vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])),
|
||||
distance: EuclidianDistance{},
|
||||
__phantom: PhantomData
|
||||
data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
|
||||
distance: SimpleDistance{}
|
||||
};
|
||||
|
||||
assert_eq!(vec!(&2, &3, &1), sKnn.find(&2, 3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn knn_point_eq() {
|
||||
let point1 = KNNPoint{
|
||||
distance: 10.,
|
||||
index: Some(0)
|
||||
};
|
||||
|
||||
assert_eq!(&vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])), sKnn.find(&arr1(&[1., 2.]), 3));
|
||||
let point2 = KNNPoint{
|
||||
distance: 100.,
|
||||
index: Some(1)
|
||||
};
|
||||
|
||||
let point3 = KNNPoint{
|
||||
distance: 10.,
|
||||
index: Some(2)
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
index: Some(3)
|
||||
};
|
||||
|
||||
assert!(point2 > point1);
|
||||
assert_eq!(point3, point1);
|
||||
assert_ne!(point3, point2);
|
||||
assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user