From 874d528f586bafe6352749aee1c3675df2f5686a Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Mon, 23 Sep 2019 20:54:21 -0700 Subject: [PATCH] Adds CoverTree implementation --- src/algorithm/mod.rs | 3 +- src/algorithm/neighbour/cover_tree.rs | 161 +++++++++++++++++++++++ src/algorithm/neighbour/linear_search.rs | 129 ++++++++++++++++++ src/algorithm/neighbour/mod.rs | 11 ++ src/classification/knn.rs | 149 ++++----------------- src/math/mod.rs | 4 +- 6 files changed, 333 insertions(+), 124 deletions(-) create mode 100644 src/algorithm/neighbour/cover_tree.rs create mode 100644 src/algorithm/neighbour/linear_search.rs create mode 100644 src/algorithm/neighbour/mod.rs diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index 20ae7d2..2aec5b9 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -1 +1,2 @@ -pub mod sort; \ No newline at end of file +pub mod sort; +pub mod neighbour; \ No newline at end of file diff --git a/src/algorithm/neighbour/cover_tree.rs b/src/algorithm/neighbour/cover_tree.rs new file mode 100644 index 0000000..1e1fe34 --- /dev/null +++ b/src/algorithm/neighbour/cover_tree.rs @@ -0,0 +1,161 @@ +use crate::math; +use crate::algorithm::neighbour::KNNAlgorithm; +use std::collections::HashMap; +use std::fmt::Debug; + +pub struct CoverTree<'a, T> +where T: Debug +{ + + base: f64, + max_level: i8, + min_level: i8, + distance: &'a Fn(&T, &T) -> f64, + nodes: Vec> +} + +impl<'a, T> CoverTree<'a, T> +where T: Debug +{ + + pub fn new(data: Vec, distance: &'a Fn(&T, &T) -> f64) -> CoverTree { + let mut tree = CoverTree { + base: 2f64, + max_level: 10, + min_level: 10, + distance: distance, + nodes: Vec::new() + }; + + for p in data { + tree.insert(p); + } + + tree + + } + + pub fn new_node(&mut self, data: T) -> NodeId { + let next_index = self.nodes.len(); + let node_id = NodeId { index: next_index }; + self.nodes.push( + Node { + index: node_id, + data: data, + parent: None, + children: HashMap::new() + }); + node_id + } + + fn insert(&mut self, p: T) { + if self.nodes.is_empty(){ + self.new_node(p); + } else { + let mut parent: Option = Option::None; + let mut p_i = 0; + let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data))); + let mut i = self.max_level; + loop { + let i_d = self.base.powf(i as f64); + let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); + let d_p_Q = self.min_ds(&q_p_ds); + if d_p_Q < math::small_e { + return + } else if d_p_Q > i_d { + break; + } + if self.min_ds(&qi_p_ds) <= self.base.powf(i as f64){ + parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, d)| n.index); + p_i = i; + } + + qi_p_ds = q_p_ds.into_iter().filter(|(n, d)| d <= &i_d).collect(); + i -= 1; + } + + let new_node = self.new_node(p); + self.nodes.get_mut(parent.unwrap().index).unwrap().children.insert(p_i, new_node); + self.min_level = i8::min(self.min_level, p_i-1); + } + } + + fn root(&self) -> &Node { + self.nodes.first().unwrap() + } + + fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node, f64)>, i: i8) -> Vec<(&'b Node, f64)> { + + let mut children = Vec::<(&'b Node, f64)>::new(); + + children.extend(qi_p_ds.iter().cloned()); + + let q: Vec<&Node> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect(); + + children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p)))); + + children + + } + + fn min_ds(&self, q_p_ds: &Vec<(&Node, f64)>) -> f64 { + q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1 + } + + fn min_p_ds(&self, q_p_ds: &mut Vec<(&Node, f64)>, k: usize) -> f64 { + q_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()); + q_p_ds[..usize::min(q_p_ds.len(), k)].last().unwrap().1 + } + + fn get_child(&self, node: &Node, i: i8) -> Option<&Node> { + node.children.get(&i).and_then(|n_id| self.nodes.get(n_id.index)) + } + +} + +impl<'a, T> KNNAlgorithm for CoverTree<'a, T> +where T: Debug +{ + fn find(&self, p: &T, k: usize) -> Vec{ + let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data))); + for i in (self.min_level..self.max_level+1).rev() { + let i_d = self.base.powf(i as f64); + let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); + let d_p_q = self.min_p_ds(&mut q_p_ds, k); + qi_p_ds = q_p_ds.into_iter().filter(|(n, d)| d <= &(d_p_q + i_d)).collect(); + } + qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()); + qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect() + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NodeId { + index: usize, +} + +#[derive(Debug)] +struct Node { + index: NodeId, + data: T, + children: HashMap, + parent: Option +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn cover_tree_test() { + let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9); + let distance = |a: &i32, b: &i32| -> f64 { + (a - b).abs() as f64 + }; + let tree = CoverTree::::new(data, &distance); + let nearest_3 = tree.find(&5, 3); + assert_eq!(vec!(4, 5, 3), nearest_3); + } + +} \ No newline at end of file diff --git a/src/algorithm/neighbour/linear_search.rs b/src/algorithm/neighbour/linear_search.rs new file mode 100644 index 0000000..785e211 --- /dev/null +++ b/src/algorithm/neighbour/linear_search.rs @@ -0,0 +1,129 @@ +use crate::algorithm::neighbour::KNNAlgorithm; +use crate::algorithm::sort::heap_select::HeapSelect; +use std::cmp::{Ordering, PartialOrd}; +use num_traits::Float; + +pub struct LinearKNNSearch<'a, T> { + distance: Box f64 + 'a>, + data: Vec +} + +impl<'a, T> KNNAlgorithm for LinearKNNSearch<'a, T> +{ + fn find(&self, from: &T, k: usize) -> Vec { + if k < 1 || k > self.data.len() { + panic!("k should be >= 1 and <= length(data)"); + } + + let mut heap = HeapSelect::::with_capacity(k); + + for _ in 0..k { + heap.add(KNNPoint{ + distance: Float::infinity(), + index: None + }); + } + + for i in 0..self.data.len() { + + let d = (self.distance)(&from, &self.data[i]); + let datum = heap.peek_mut(); + if d < datum.distance { + datum.distance = d; + datum.index = Some(i); + heap.heapify(); + } + } + + heap.sort(); + + heap.get().into_iter().flat_map(|x| x.index).collect() + } +} + +impl<'a, T> LinearKNNSearch<'a, T> { + pub fn new(data: Vec, distance: &'a Fn(&T, &T) -> f64) -> LinearKNNSearch{ + LinearKNNSearch{ + data: data, + distance: Box::new(distance) + } + } +} + +#[derive(Debug)] +struct KNNPoint { + distance: f64, + index: Option +} + +impl PartialOrd for KNNPoint { + fn partial_cmp(&self, other: &Self) -> Option { + self.distance.partial_cmp(&other.distance) + } +} + +impl PartialEq for KNNPoint { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} + +impl Eq for KNNPoint {} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math::distance::Distance; + use ndarray::{arr1, Array1}; + + struct SimpleDistance{} + + impl SimpleDistance { + fn distance(a: &i32, b: &i32) -> f64 { + (a - b).abs() as f64 + } + } + + #[test] + fn knn_find() { + let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance); + + assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3)); + + let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5])); + + let algorithm2 = LinearKNNSearch::new(data2, &Array1::distance); + + assert_eq!(vec!(2, 3, 1), algorithm2.find(&arr1(&[3, 3]), 3)); + } + + #[test] + fn knn_point_eq() { + let point1 = KNNPoint{ + distance: 10., + index: Some(0) + }; + + let point2 = KNNPoint{ + distance: 100., + index: Some(1) + }; + + let point3 = KNNPoint{ + distance: 10., + index: Some(2) + }; + + let point_inf = KNNPoint{ + distance: Float::infinity(), + index: Some(3) + }; + + assert!(point2 > point1); + assert_eq!(point3, point1); + assert_ne!(point3, point2); + assert!(point_inf > point3 && point_inf > point2 && point_inf > point1); + } +} \ No newline at end of file diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs new file mode 100644 index 0000000..ff6a070 --- /dev/null +++ b/src/algorithm/neighbour/mod.rs @@ -0,0 +1,11 @@ +pub mod cover_tree; +pub mod linear_search; + +pub enum KNNAlgorithmName { + CoverTree, + LinearSearch, +} + +pub trait KNNAlgorithm{ + fn find(&self, from: &T, k: usize) -> Vec; +} \ No newline at end of file diff --git a/src/classification/knn.rs b/src/classification/knn.rs index a8518fd..56daecd 100644 --- a/src/classification/knn.rs +++ b/src/classification/knn.rs @@ -1,31 +1,33 @@ use super::Classifier; use std::collections::HashSet; -use crate::algorithm::sort::heap_select::HeapSelect; +use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; +use crate::algorithm::neighbour::linear_search::LinearKNNSearch; +use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::common::Nominal; use ndarray::{ArrayBase, Data, Ix1, Ix2}; -use num_traits::{Float}; -use std::cmp::{Ordering, PartialOrd}; +use std::fmt::Debug; type F = Fn(&X, &X) -> f64; -pub struct KNNClassifier +pub struct KNNClassifier<'a, X, Y> where - Y: Nominal + Y: Nominal, + X: Debug { classes: Vec, - y: Vec, - data: Vec, - distance: Box>, + y: Vec, + knn_algorithm: Box + 'a>, k: usize, } -impl KNNClassifier +impl<'a, X, Y> KNNClassifier<'a, X, Y> where - Y: Nominal + Y: Nominal, + X: Debug { - pub fn fit(x: Vec, y: Vec, k: usize, distance: &'static F) -> KNNClassifier { + pub fn fit(x: Vec, y: Vec, k: usize, distance: &'a F, algorithm: KNNAlgorithmName) -> KNNClassifier { assert!(Vec::len(&x) == Vec::len(&y), format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", Vec::len(&x), Vec::len(&y))); @@ -33,20 +35,27 @@ where let c_hash: HashSet = y.clone().into_iter().collect(); let classes: Vec = c_hash.into_iter().collect(); - let y_i:Vec = y.into_iter().map(|y| classes.iter().position(|yy| yy == &y).unwrap()).collect(); + let y_i:Vec = y.into_iter().map(|y| classes.iter().position(|yy| yy == &y).unwrap()).collect(); + + let knn_algorithm: Box + 'a> = match algorithm { + KNNAlgorithmName::CoverTree => Box::new(CoverTree::::new(x, distance)), + KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::::new(x, distance)) + }; + + KNNClassifier{classes:classes, y: y_i, k: k, knn_algorithm: knn_algorithm} - KNNClassifier{classes:classes, y: y_i, data: x, k: k, distance: Box::new(distance)} } } -impl Classifier for KNNClassifier +impl<'a, X, Y> Classifier for KNNClassifier<'a, X, Y> where - Y: Nominal + Y: Nominal, + X: Debug { fn predict(&self, x: &X) -> Y { - let idxs = self.data.find(x, self.k, &self.distance); + let idxs = self.knn_algorithm.find(x, self.k); let mut c = vec![0; self.classes.len()]; let mut max_c = 0; let mut max_i = 0; @@ -79,123 +88,19 @@ impl NDArrayUtils { } } -pub trait KNNAlgorithm{ - fn find(&self, from: &T, k: usize, d: &Fn(&T, &T) -> f64) -> Vec; -} - -impl KNNAlgorithm for Vec -{ - fn find(&self, from: &T, k: usize, d: &Fn(&T, &T) -> f64) -> Vec { - if k < 1 || k > self.len() { - panic!("k should be >= 1 and <= length(data)"); - } - - let mut heap = HeapSelect::::with_capacity(k); - - for _ in 0..k { - heap.add(KNNPoint{ - distance: Float::infinity(), - index: None - }); - } - - for i in 0..self.len() { - - let d = d(&from, &self[i]); - let datum = heap.peek_mut(); - if d < datum.distance { - datum.distance = d; - datum.index = Some(i); - heap.heapify(); - } - } - - heap.sort(); - - heap.get().into_iter().flat_map(|x| x.index).collect() - } -} - -#[derive(Debug)] -struct KNNPoint { - distance: f64, - index: Option -} - -impl PartialOrd for KNNPoint { - fn partial_cmp(&self, other: &Self) -> Option { - self.distance.partial_cmp(&other.distance) - } -} - -impl PartialEq for KNNPoint { - fn eq(&self, other: &Self) -> bool { - self.distance == other.distance - } -} - -impl Eq for KNNPoint {} - #[cfg(test)] mod tests { use super::*; use crate::math::distance::Distance; - use ndarray::{arr1, arr2, Array1}; - - struct SimpleDistance{} - - impl SimpleDistance { - fn distance(a: &i32, b: &i32) -> f64 { - (a - b).abs() as f64 - } - } + use ndarray::{arr1, arr2, Array1}; #[test] fn knn_fit_predict() { let x = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]); let y = arr1(&[2, 2, 2, 3, 3]); - let knn = KNNClassifier::fit(NDArrayUtils::array2_to_vec(&x), y.to_vec(), 3, &Array1::distance); + let knn = KNNClassifier::fit(NDArrayUtils::array2_to_vec(&x), y.to_vec(), 3, &Array1::distance, KNNAlgorithmName::LinearSearch); let r = knn.predict_vec(&NDArrayUtils::array2_to_vec(&x)); assert_eq!(5, Vec::len(&r)); assert_eq!(y.to_vec(), r); } - - #[test] - fn knn_find() { - let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); - - assert_eq!(vec!(1, 2, 0), data1.find(&2, 3, &SimpleDistance::distance)); - - let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5])); - - assert_eq!(vec!(2, 3, 1), data2.find(&arr1(&[3, 3]), 3, &Array1::distance)); - } - - #[test] - fn knn_point_eq() { - let point1 = KNNPoint{ - distance: 10., - index: Some(0) - }; - - let point2 = KNNPoint{ - distance: 100., - index: Some(1) - }; - - let point3 = KNNPoint{ - distance: 10., - index: Some(2) - }; - - let point_inf = KNNPoint{ - distance: Float::infinity(), - index: Some(3) - }; - - assert!(point2 > point1); - assert_eq!(point3, point1); - assert_ne!(point3, point2); - assert!(point_inf > point3 && point_inf > point2 && point_inf > point1); - } } \ No newline at end of file diff --git a/src/math/mod.rs b/src/math/mod.rs index 4567f3e..a4eeb67 100644 --- a/src/math/mod.rs +++ b/src/math/mod.rs @@ -1 +1,3 @@ -pub mod distance; \ No newline at end of file +pub mod distance; + +pub static small_e:f64 = 0.000000001f64; \ No newline at end of file