Extends basic KNN search algorithm
This commit is contained in:
@@ -1,44 +1,102 @@
|
|||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
|
use std::mem;
|
||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
pub struct HeapSelect<T: std::cmp::Ord> {
|
#[derive(Debug)]
|
||||||
|
pub struct HeapSelect<T: PartialOrd> {
|
||||||
|
|
||||||
k: usize,
|
k: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
sorted: bool,
|
sorted: bool,
|
||||||
heap: Vec<T>
|
heap: Vec<T>
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: std::cmp::Ord> HeapSelect<T> {
|
impl<'a, T: PartialOrd> HeapSelect<T> {
|
||||||
|
|
||||||
pub fn from_vec(vec: Vec<T>) -> HeapSelect<T> {
|
pub fn with_capacity(k: usize) -> HeapSelect<T> {
|
||||||
HeapSelect{
|
HeapSelect{
|
||||||
k: vec.len(),
|
k: k,
|
||||||
n: 0,
|
n: 0,
|
||||||
sorted: false,
|
sorted: false,
|
||||||
heap: vec
|
heap: Vec::<T>::new()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add(&mut self, element: T) {
|
pub fn add(&mut self, element: T) {
|
||||||
self.sorted = false;
|
self.sorted = false;
|
||||||
if self.n < self.k {
|
if self.n < self.k {
|
||||||
self.heap[self.n] = element;
|
self.heap.push(element);
|
||||||
self.n += 1;
|
self.n += 1;
|
||||||
if self.n == self.k {
|
if self.n == self.k {
|
||||||
self.heapify();
|
self.heapify();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.n += 1;
|
self.n += 1;
|
||||||
if element.cmp(&self.heap[0]) == Ordering::Less {
|
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
|
||||||
self.heap[0] = element;
|
self.heap[0] = element;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn heapify(&mut self){
|
pub fn heapify(&mut self) {
|
||||||
|
let n = self.heap.len();
|
||||||
|
for i in (0..=(n / 2 - 1)).rev() {
|
||||||
|
self.sift_down(i, n-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peek(&self) -> &T {
|
||||||
|
return &self.heap[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn peek_mut(&mut self) -> &mut T {
|
||||||
|
return &mut self.heap[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sift_down(&mut self, from: usize, n: usize) {
|
||||||
|
let mut k = from;
|
||||||
|
while 2 * k <= n {
|
||||||
|
let mut j = 2 * k;
|
||||||
|
if j < n && self.heap[j] < self.heap[j + 1] {
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
if self.heap[k] >= self.heap[j] {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.heap.swap(k, j);
|
||||||
|
k = j;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(self) -> Vec<T> {
|
||||||
|
return self.heap;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sort(&mut self) {
|
||||||
|
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k,self.n));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) {
|
||||||
|
let mut inc = 1;
|
||||||
|
while inc <= n {
|
||||||
|
inc *= 3;
|
||||||
|
inc += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = n;
|
||||||
|
while inc >= 1 {
|
||||||
|
let mut i = inc;
|
||||||
|
while i < len {
|
||||||
|
let mut j = i;
|
||||||
|
while j >= inc && vec[j - inc] > vec[j] {
|
||||||
|
vec.swap(j - inc, j);
|
||||||
|
j -= inc;
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
inc /= 3
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -48,17 +106,52 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_from_vec() {
|
fn with_capacity() {
|
||||||
let heap = HeapSelect::from_vec(vec!(1, 2, 3));
|
let heap = HeapSelect::<i32>::with_capacity(3);
|
||||||
assert_eq!(3, heap.k);
|
assert_eq!(3, heap.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add() {
|
fn test_add() {
|
||||||
let mut heap = HeapSelect::from_vec(Vec::<i32>::new());
|
let mut heap = HeapSelect::with_capacity(3);
|
||||||
heap.add(1);
|
heap.add(333);
|
||||||
heap.add(2);
|
heap.add(2);
|
||||||
heap.add(3);
|
heap.add(13);
|
||||||
assert_eq!(3, heap.n);
|
heap.add(10);
|
||||||
|
heap.add(40);
|
||||||
|
heap.add(30);
|
||||||
|
assert_eq!(6, heap.n);
|
||||||
|
assert_eq!(&10, heap.peek());
|
||||||
|
assert_eq!(&10, heap.peek_mut());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_add_ordered() {
|
||||||
|
let mut heap = HeapSelect::with_capacity(3);
|
||||||
|
heap.add(1.);
|
||||||
|
heap.add(2.);
|
||||||
|
heap.add(3.);
|
||||||
|
heap.add(4.);
|
||||||
|
heap.add(5.);
|
||||||
|
heap.add(6.);
|
||||||
|
let result = heap.get();
|
||||||
|
assert_eq!(vec![2., 3., 1.], result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_shuffle_sort() {
|
||||||
|
let mut v1 = vec![10, 33, 22, 105, 12];
|
||||||
|
let n = v1.len();
|
||||||
|
HeapSelect::shuffle_sort(&mut v1, n);
|
||||||
|
assert_eq!(vec![10, 12, 22, 33, 105], v1);
|
||||||
|
|
||||||
|
let mut v2 = vec![10, 33, 22, 105, 12];
|
||||||
|
HeapSelect::shuffle_sort(&mut v2, 3);
|
||||||
|
assert_eq!(vec![10, 22, 33, 105, 12], v2);
|
||||||
|
|
||||||
|
let mut v3 = vec![4, 5, 3, 2, 1];
|
||||||
|
HeapSelect::shuffle_sort(&mut v3, 3);
|
||||||
|
assert_eq!(vec![3, 4, 5, 2, 1], v3);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+100
-23
@@ -1,39 +1,81 @@
|
|||||||
use super::Classifier;
|
use super::Classifier;
|
||||||
use super::super::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use super::super::math::distance::euclidian::EuclidianDistance;
|
use crate::math::distance::euclidian::EuclidianDistance;
|
||||||
|
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||||
use ndarray::prelude::*;
|
use ndarray::prelude::*;
|
||||||
use num_traits::Signed;
|
use num_traits::Signed;
|
||||||
use num_traits::Float;
|
use num_traits::{Float, Num};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
pub struct KNNClassifier<E> {
|
pub struct KNNClassifier<E> {
|
||||||
y: Option<Array1<E>>
|
y: Option<Array1<E>>
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait KNNAlgorithm<T>{
|
pub trait KNNAlgorithm<T: Clone + Debug>{
|
||||||
fn find(&self, from: &T, k: i32) -> &Vec<T>;
|
fn find(&self, from: &T, k: usize) -> Vec<&T>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SimpleKNNAlgorithm<T, A, D>
|
pub struct SimpleKNNAlgorithm<T, D: Distance<T>>
|
||||||
where
|
|
||||||
A: Float,
|
|
||||||
D: Distance<T, A>
|
|
||||||
{
|
{
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
distance: D,
|
distance: D
|
||||||
__phantom: PhantomData<A>
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, A, D> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, A, D>
|
impl<T: Clone + Debug, D: Distance<T>> KNNAlgorithm<T> for SimpleKNNAlgorithm<T, D>
|
||||||
where
|
|
||||||
A: Float,
|
|
||||||
D: Distance<T, A>
|
|
||||||
{
|
{
|
||||||
fn find(&self, from: &T, k: i32) -> &Vec<T> {
|
fn find(&self, from: &T, k: usize) -> Vec<&T> {
|
||||||
&self.data
|
if k < 1 || k > self.data.len() {
|
||||||
|
panic!("k should be >= 1 and <= length(data)");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut heap = HeapSelect::<KNNPoint>::with_capacity(k);
|
||||||
|
|
||||||
|
for _ in 0..k {
|
||||||
|
heap.add(KNNPoint{
|
||||||
|
distance: Float::infinity(),
|
||||||
|
index: None
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..self.data.len() {
|
||||||
|
|
||||||
|
let d = D::distance(&from, &self.data[i]);
|
||||||
|
let datum = heap.peek_mut();
|
||||||
|
if d < datum.distance {
|
||||||
|
datum.distance = d;
|
||||||
|
datum.index = Some(i);
|
||||||
|
heap.heapify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
heap.sort();
|
||||||
|
|
||||||
|
heap.get().into_iter().flat_map(|x| x.index).map(|i| &self.data[i]).collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct KNNPoint {
|
||||||
|
distance: f64,
|
||||||
|
index: Option<usize>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialOrd for KNNPoint {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
self.distance.partial_cmp(&other.distance)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for KNNPoint {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.distance == other.distance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eq for KNNPoint {}
|
||||||
|
|
||||||
impl<A1, A2> Classifier<A1, A2> for KNNClassifier<A2>
|
impl<A1, A2> Classifier<A1, A2> for KNNClassifier<A2>
|
||||||
where
|
where
|
||||||
A2: Signed + Clone,
|
A2: Signed + Clone,
|
||||||
@@ -51,7 +93,15 @@ where
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
struct SimpleDistance{}
|
||||||
|
|
||||||
|
impl Distance<i32> for SimpleDistance {
|
||||||
|
fn distance(a: &i32, b: &i32) -> f64 {
|
||||||
|
(a - b).abs() as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict() {
|
fn knn_fit_predict() {
|
||||||
@@ -64,13 +114,40 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_find() {
|
fn knn_find() {
|
||||||
let sKnn = SimpleKNNAlgorithm{
|
let sKnn = SimpleKNNAlgorithm{
|
||||||
data: vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])),
|
data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
|
||||||
distance: EuclidianDistance{},
|
distance: SimpleDistance{}
|
||||||
__phantom: PhantomData
|
};
|
||||||
|
|
||||||
|
assert_eq!(vec!(&2, &3, &1), sKnn.find(&2, 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn knn_point_eq() {
|
||||||
|
let point1 = KNNPoint{
|
||||||
|
distance: 10.,
|
||||||
|
index: Some(0)
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(&vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])), sKnn.find(&arr1(&[1., 2.]), 3));
|
let point2 = KNNPoint{
|
||||||
|
distance: 100.,
|
||||||
|
index: Some(1)
|
||||||
|
};
|
||||||
|
|
||||||
|
let point3 = KNNPoint{
|
||||||
|
distance: 10.,
|
||||||
|
index: Some(2)
|
||||||
|
};
|
||||||
|
|
||||||
|
let point_inf = KNNPoint{
|
||||||
|
distance: Float::infinity(),
|
||||||
|
index: Some(3)
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(point2 > point1);
|
||||||
|
assert_eq!(point3, point1);
|
||||||
|
assert_ne!(point3, point2);
|
||||||
|
assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,21 +1,22 @@
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
use ndarray::{ArrayBase, Data, Dimension};
|
use ndarray::{ArrayBase, Data, Dimension};
|
||||||
use num_traits::Float;
|
use num_traits::{Num, ToPrimitive};
|
||||||
|
use ndarray::{ScalarOperand};
|
||||||
|
|
||||||
pub struct EuclidianDistance{}
|
pub struct EuclidianDistance{}
|
||||||
|
|
||||||
impl<A, S, D> Distance<ArrayBase<S, D>, A> for EuclidianDistance
|
impl<A, S, D> Distance<ArrayBase<S, D>> for EuclidianDistance
|
||||||
where
|
where
|
||||||
A: Float,
|
A: Num + ScalarOperand + ToPrimitive,
|
||||||
S: Data<Elem = A>,
|
S: Data<Elem = A>,
|
||||||
D: Dimension
|
D: Dimension
|
||||||
{
|
{
|
||||||
|
|
||||||
fn distance(a: &ArrayBase<S, D>, b: &ArrayBase<S, D>) -> A {
|
fn distance(a: &ArrayBase<S, D>, b: &ArrayBase<S, D>) -> f64 {
|
||||||
if a.len() != b.len() {
|
if a.len() != b.len() {
|
||||||
panic!("vectors a and b have different length");
|
panic!("vectors a and b have different length");
|
||||||
} else {
|
} else {
|
||||||
((a - b)*(a - b)).sum().sqrt()
|
((a - b)*(a - b)).sum().to_f64().unwrap().sqrt()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -28,8 +29,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn measure_simple_euclidian_distance() {
|
fn measure_simple_euclidian_distance() {
|
||||||
let a = Array::from_vec(vec![1., 2., 3.]);
|
let a = arr1(&[1, 2, 3]);
|
||||||
let b = Array::from_vec(vec![4., 5., 6.]);
|
let b = arr1(&[4, 5, 6]);
|
||||||
|
|
||||||
let d = EuclidianDistance::distance(&a, &b);
|
let d = EuclidianDistance::distance(&a, &b);
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ pub mod euclidian;
|
|||||||
|
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
|
||||||
pub trait Distance<T, A>
|
pub trait Distance<T>
|
||||||
where
|
|
||||||
A: Float
|
|
||||||
{
|
{
|
||||||
fn distance(a: &T, b: &T) -> A;
|
fn distance(a: &T, b: &T) -> f64;
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user