diff --git a/src/algorithm/sort/heap_select.rs b/src/algorithm/sort/heap_select.rs index 0fb7677..60c25ae 100644 --- a/src/algorithm/sort/heap_select.rs +++ b/src/algorithm/sort/heap_select.rs @@ -1,44 +1,102 @@ use std::cmp::Ordering; +use std::mem; +use std::fmt::Display; -pub struct HeapSelect { +#[derive(Debug)] +pub struct HeapSelect { k: usize, n: usize, sorted: bool, heap: Vec - - } -impl HeapSelect { +impl<'a, T: PartialOrd> HeapSelect { - pub fn from_vec(vec: Vec) -> HeapSelect { + pub fn with_capacity(k: usize) -> HeapSelect { HeapSelect{ - k: vec.len(), + k: k, n: 0, sorted: false, - heap: vec + heap: Vec::::new() } } - pub fn add(&mut self, element: T) { + pub fn add(&mut self, element: T) { self.sorted = false; if self.n < self.k { - self.heap[self.n] = element; + self.heap.push(element); self.n += 1; if self.n == self.k { self.heapify(); } } else { self.n += 1; - if element.cmp(&self.heap[0]) == Ordering::Less { - self.heap[0] = element; + if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) { + self.heap[0] = element; } } } - pub fn heapify(&mut self){ + pub fn heapify(&mut self) { + let n = self.heap.len(); + for i in (0..=(n / 2 - 1)).rev() { + self.sift_down(i, n-1); + } + } + pub fn peek(&self) -> &T { + return &self.heap[0]; + } + + pub fn peek_mut(&mut self) -> &mut T { + return &mut self.heap[0]; + } + + pub fn sift_down(&mut self, from: usize, n: usize) { + let mut k = from; + while 2 * k <= n { + let mut j = 2 * k; + if j < n && self.heap[j] < self.heap[j + 1] { + j += 1; + } + if self.heap[k] >= self.heap[j] { + break; + } + self.heap.swap(k, j); + k = j; + } + + } + + pub fn get(self) -> Vec { + return self.heap; + } + + pub fn sort(&mut self) { + HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k,self.n)); + } + + pub fn shuffle_sort(vec: &mut Vec, n: usize) { + let mut inc = 1; + while inc <= n { + inc *= 3; + inc += 1 + } + + let len = n; + while inc >= 1 { + let mut i = inc; + while i < len { + let mut j = i; + while j >= inc && vec[j - inc] > vec[j] { + vec.swap(j - inc, j); + j -= inc; + } + i += 1; + } + inc /= 3 + } } } @@ -48,17 +106,52 @@ mod tests { use super::*; #[test] - fn test_from_vec() { - let heap = HeapSelect::from_vec(vec!(1, 2, 3)); + fn with_capacity() { + let heap = HeapSelect::::with_capacity(3); assert_eq!(3, heap.k); } #[test] - fn test_add() { - let mut heap = HeapSelect::from_vec(Vec::::new()); - heap.add(1); + fn test_add() { + let mut heap = HeapSelect::with_capacity(3); + heap.add(333); heap.add(2); - heap.add(3); - assert_eq!(3, heap.n); + heap.add(13); + heap.add(10); + heap.add(40); + heap.add(30); + assert_eq!(6, heap.n); + assert_eq!(&10, heap.peek()); + assert_eq!(&10, heap.peek_mut()); } + + #[test] + fn test_add_ordered() { + let mut heap = HeapSelect::with_capacity(3); + heap.add(1.); + heap.add(2.); + heap.add(3.); + heap.add(4.); + heap.add(5.); + heap.add(6.); + let result = heap.get(); + assert_eq!(vec![2., 3., 1.], result); + } + + #[test] + fn test_shuffle_sort() { + let mut v1 = vec![10, 33, 22, 105, 12]; + let n = v1.len(); + HeapSelect::shuffle_sort(&mut v1, n); + assert_eq!(vec![10, 12, 22, 33, 105], v1); + + let mut v2 = vec![10, 33, 22, 105, 12]; + HeapSelect::shuffle_sort(&mut v2, 3); + assert_eq!(vec![10, 22, 33, 105, 12], v2); + + let mut v3 = vec![4, 5, 3, 2, 1]; + HeapSelect::shuffle_sort(&mut v3, 3); + assert_eq!(vec![3, 4, 5, 2, 1], v3); + } + } \ No newline at end of file diff --git a/src/classification/knn.rs b/src/classification/knn.rs index 0a00d9d..073de53 100644 --- a/src/classification/knn.rs +++ b/src/classification/knn.rs @@ -1,39 +1,81 @@ use super::Classifier; -use super::super::math::distance::Distance; -use super::super::math::distance::euclidian::EuclidianDistance; +use crate::math::distance::Distance; +use crate::math::distance::euclidian::EuclidianDistance; +use crate::algorithm::sort::heap_select::HeapSelect; use ndarray::prelude::*; use num_traits::Signed; -use num_traits::Float; +use num_traits::{Float, Num}; use std::marker::PhantomData; +use std::cmp::{Ordering, PartialOrd}; +use std::fmt::Debug; pub struct KNNClassifier { y: Option> } -pub trait KNNAlgorithm{ - fn find(&self, from: &T, k: i32) -> &Vec; +pub trait KNNAlgorithm{ + fn find(&self, from: &T, k: usize) -> Vec<&T>; } -pub struct SimpleKNNAlgorithm -where - A: Float, - D: Distance +pub struct SimpleKNNAlgorithm> { data: Vec, - distance: D, - __phantom: PhantomData + distance: D } -impl KNNAlgorithm for SimpleKNNAlgorithm -where - A: Float, - D: Distance +impl> KNNAlgorithm for SimpleKNNAlgorithm { - fn find(&self, from: &T, k: i32) -> &Vec { - &self.data + fn find(&self, from: &T, k: usize) -> Vec<&T> { + if k < 1 || k > self.data.len() { + panic!("k should be >= 1 and <= length(data)"); + } + + let mut heap = HeapSelect::::with_capacity(k); + + for _ in 0..k { + heap.add(KNNPoint{ + distance: Float::infinity(), + index: None + }); + } + + for i in 0..self.data.len() { + + let d = D::distance(&from, &self.data[i]); + let datum = heap.peek_mut(); + if d < datum.distance { + datum.distance = d; + datum.index = Some(i); + heap.heapify(); + } + } + + heap.sort(); + + heap.get().into_iter().flat_map(|x| x.index).map(|i| &self.data[i]).collect() } } +#[derive(Debug)] +struct KNNPoint { + distance: f64, + index: Option +} + +impl PartialOrd for KNNPoint { + fn partial_cmp(&self, other: &Self) -> Option { + self.distance.partial_cmp(&other.distance) + } +} + +impl PartialEq for KNNPoint { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for KNNPoint {} + impl Classifier for KNNClassifier where A2: Signed + Clone, @@ -51,7 +93,15 @@ where #[cfg(test)] mod tests { - use super::*; + use super::*; + + struct SimpleDistance{} + + impl Distance for SimpleDistance { + fn distance(a: &i32, b: &i32) -> f64 { + (a - b).abs() as f64 + } + } #[test] fn knn_fit_predict() { @@ -64,13 +114,40 @@ mod tests { } #[test] - fn knn_find() { + fn knn_find() { let sKnn = SimpleKNNAlgorithm{ - data: vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])), - distance: EuclidianDistance{}, - __phantom: PhantomData + data: vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + distance: SimpleDistance{} + }; + + assert_eq!(vec!(&2, &3, &1), sKnn.find(&2, 3)); + } + + #[test] + fn knn_point_eq() { + let point1 = KNNPoint{ + distance: 10., + index: Some(0) }; - assert_eq!(&vec!(arr1(&[1., 2.]), arr1(&[1., 2.]), arr1(&[1., 2.])), sKnn.find(&arr1(&[1., 2.]), 3)); + let point2 = KNNPoint{ + distance: 100., + index: Some(1) + }; + + let point3 = KNNPoint{ + distance: 10., + index: Some(2) + }; + + let point_inf = KNNPoint{ + distance: Float::infinity(), + index: Some(3) + }; + + assert!(point2 > point1); + assert_eq!(point3, point1); + assert_ne!(point3, point2); + assert!(point_inf > point3 && point_inf > point2 && point_inf > point1); } } \ No newline at end of file diff --git a/src/math/distance/euclidian.rs b/src/math/distance/euclidian.rs index 84e5707..965413f 100644 --- a/src/math/distance/euclidian.rs +++ b/src/math/distance/euclidian.rs @@ -1,21 +1,22 @@ use super::Distance; use ndarray::{ArrayBase, Data, Dimension}; -use num_traits::Float; +use num_traits::{Num, ToPrimitive}; +use ndarray::{ScalarOperand}; pub struct EuclidianDistance{} -impl Distance, A> for EuclidianDistance +impl Distance> for EuclidianDistance where - A: Float, + A: Num + ScalarOperand + ToPrimitive, S: Data, D: Dimension { - fn distance(a: &ArrayBase, b: &ArrayBase) -> A { + fn distance(a: &ArrayBase, b: &ArrayBase) -> f64 { if a.len() != b.len() { panic!("vectors a and b have different length"); } else { - ((a - b)*(a - b)).sum().sqrt() + ((a - b)*(a - b)).sum().to_f64().unwrap().sqrt() } } } @@ -28,8 +29,8 @@ mod tests { #[test] fn measure_simple_euclidian_distance() { - let a = Array::from_vec(vec![1., 2., 3.]); - let b = Array::from_vec(vec![4., 5., 6.]); + let a = arr1(&[1, 2, 3]); + let b = arr1(&[4, 5, 6]); let d = EuclidianDistance::distance(&a, &b); diff --git a/src/math/distance/mod.rs b/src/math/distance/mod.rs index 3967638..b58dfea 100644 --- a/src/math/distance/mod.rs +++ b/src/math/distance/mod.rs @@ -2,9 +2,7 @@ pub mod euclidian; use num_traits::Float; -pub trait Distance -where - A: Float +pub trait Distance { - fn distance(a: &T, b: &T) -> A; + fn distance(a: &T, b: &T) -> f64; } \ No newline at end of file