diff --git a/Cargo.toml b/Cargo.toml index e98db56..bda1c06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,9 +5,10 @@ authors = ["SmartCore Developers"] edition = "2018" [features] -default = [] +default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] +datasets = [] [dependencies] ndarray = { version = "0.13", optional = true } diff --git a/src/algorithm/neighbour/cover_tree.rs b/src/algorithm/neighbour/cover_tree.rs index 43e366e..8ffec19 100644 --- a/src/algorithm/neighbour/cover_tree.rs +++ b/src/algorithm/neighbour/cover_tree.rs @@ -21,14 +21,11 @@ //! tree.find(&5, 3); // find 3 knn points from 5 //! //! ``` -use core::hash::{Hash, Hasher}; -use std::collections::{HashMap, HashSet}; use std::fmt::Debug; -use std::iter::FromIterator; use serde::{Deserialize, Serialize}; -use crate::algorithm::sort::heap_select::HeapSelect; +use crate::algorithm::sort::heap_select::HeapSelection; use crate::math::distance::Distance; use crate::math::num::RealNumber; @@ -36,329 +33,358 @@ use crate::math::num::RealNumber; #[derive(Serialize, Deserialize, Debug)] pub struct CoverTree> { base: F, - max_level: i8, - min_level: i8, + inv_log_base: F, distance: D, - nodes: Vec>, + root: Node, + data: Vec, + identical_excluded: bool, } -impl> CoverTree { +impl> PartialEq for CoverTree { + fn eq(&self, other: &Self) -> bool { + if self.data.len() != other.data.len() { + return false; + } + for i in 0..self.data.len() { + if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() { + return false; + } + } + return true; + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct Node { + idx: usize, + max_dist: F, + parent_dist: F, + children: Vec>, + scale: i64, +} + +#[derive(Debug, Serialize, Deserialize)] +struct DistanceSet { + idx: usize, + dist: Vec, +} + +impl> CoverTree { /// Construct a cover tree. /// * `data` - vector of data points to search for. /// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface. - pub fn new(mut data: Vec, distance: D) -> CoverTree { + pub fn new(data: Vec, distance: D) -> CoverTree { + let base = F::from_f64(1.3).unwrap(); + let root = Node { + idx: 0, + max_dist: F::zero(), + parent_dist: F::zero(), + children: Vec::new(), + scale: 0, + }; let mut tree = CoverTree { - base: F::two(), - max_level: 100, - min_level: 100, + base: base, + inv_log_base: F::one() / base.ln(), distance: distance, - nodes: Vec::new(), + root: root, + data: data, + identical_excluded: false, }; - let p = tree.new_node(None, data.remove(0)); - tree.construct(p, data, Vec::new(), 10); + tree.build_cover_tree(); tree } - /// Insert new data point into the cover tree. - /// * `p` - new data points. - pub fn insert(&mut self, p: T) { - if self.nodes.is_empty() { - self.new_node(None, p); - } else { - let mut parent: Option = Option::None; - let mut p_i = 0; - let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))]; - let mut i = self.max_level; - loop { - let i_d = self.base.powf(F::from(i).unwrap()); - let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); - let d_p_q = self.min_by_distance(&q_p_ds); - if d_p_q < F::epsilon() { - return; - } else if d_p_q > i_d { - break; - } - if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()) { - parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index); - p_i = i; - } - - qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect(); - i -= 1; - } - - let new_node = self.new_node(parent, p); - self.add_child(parent.unwrap(), new_node, p_i); - self.min_level = i8::min(self.min_level, p_i - 1); - } - } - - fn new_node(&mut self, parent: Option, data: T) -> NodeId { - let next_index = self.nodes.len(); - let node_id = NodeId { index: next_index }; - self.nodes.push(Node { - index: node_id, - data: data, - parent: parent, - children: HashMap::new(), - }); - node_id - } - /// Find k nearest neighbors of `p` /// * `p` - look for k nearest points to `p` /// * `k` - the number of nearest neighbors to return pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> { - let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))]; - for i in (self.min_level..self.max_level + 1).rev() { - let i_d = self.base.powf(F::from(i).unwrap()); - let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i); - let d_p_q = self.min_k_by_distance(&mut q_p_ds, k); - qi_p_ds = q_p_ds - .into_iter() - .filter(|(_, d)| d <= &(d_p_q + i_d)) - .collect(); + if k <= 0 { + panic!("k should be > 0"); + } + + if k > self.data.len() { + panic!("k is > than the dataset size"); + } + + let e = self.get_data_value(self.root.idx); + let mut d = self.distance.distance(&e, p); + + let mut current_cover_set: Vec<(F, &Node)> = Vec::new(); + let mut zero_set: Vec<(F, &Node)> = Vec::new(); + + current_cover_set.push((d, &self.root)); + + let mut heap = HeapSelection::with_capacity(k); + heap.add(F::max_value()); + + let mut empty_heap = true; + if !self.identical_excluded || self.get_data_value(self.root.idx) != p { + heap.add(d); + empty_heap = false; + } + + while !current_cover_set.is_empty() { + let mut next_cover_set: Vec<(F, &Node)> = Vec::new(); + for par in current_cover_set { + let parent = par.1; + for c in 0..parent.children.len() { + let child = &parent.children[c]; + if c == 0 { + d = par.0; + } else { + d = self.distance.distance(self.get_data_value(child.idx), p); + } + + let upper_bound = if empty_heap { + F::infinity() + } else { + *heap.peek() + }; + if d <= (upper_bound + child.max_dist) { + if c > 0 && d < upper_bound { + if !self.identical_excluded || self.get_data_value(child.idx) != p { + heap.add(d); + } + } + + if !child.children.is_empty() { + next_cover_set.push((d, child)); + } else if d <= upper_bound { + zero_set.push((d, child)); + } + } + } + } + current_cover_set = next_cover_set; + } + + let mut neighbors: Vec<(usize, F)> = Vec::new(); + let upper_bound = *heap.peek(); + for ds in zero_set { + if ds.0 <= upper_bound { + let v = self.get_data_value(ds.1.idx); + if !self.identical_excluded || v != p { + neighbors.push((ds.1.idx, ds.0)); + } + } + } + + neighbors.into_iter().take(k).collect() + } + + fn new_leaf(&self, idx: usize) -> Node { + Node { + idx: idx, + max_dist: F::zero(), + parent_dist: F::zero(), + children: Vec::new(), + scale: 100, + } + } + + fn build_cover_tree(&mut self) { + let mut point_set: Vec> = Vec::new(); + let mut consumed_set: Vec> = Vec::new(); + + let point = &self.data[0]; + let idx = 0; + let mut max_dist = -F::one(); + + for i in 1..self.data.len() { + let dist = self.distance.distance(point, &self.data[i]); + let set = DistanceSet { + idx: i, + dist: vec![dist], + }; + point_set.push(set); + if dist > max_dist { + max_dist = dist; + } + } + + self.root = self.batch_insert( + idx, + self.get_scale(max_dist), + self.get_scale(max_dist), + &mut point_set, + &mut consumed_set, + ); + } + + fn batch_insert( + &self, + p: usize, + max_scale: i64, + top_scale: i64, + point_set: &mut Vec>, + consumed_set: &mut Vec>, + ) -> Node { + if point_set.is_empty() { + self.new_leaf(p) + } else { + let max_dist = self.max(&point_set); + let next_scale = (max_scale - 1).min(self.get_scale(max_dist)); + if next_scale == std::i64::MIN { + let mut children: Vec> = Vec::new(); + let mut leaf = self.new_leaf(p); + children.push(leaf); + while !point_set.is_empty() { + let set = point_set.remove(point_set.len() - 1); + leaf = self.new_leaf(set.idx); + children.push(leaf); + consumed_set.push(set); + } + Node { + idx: p, + max_dist: F::zero(), + parent_dist: F::zero(), + children: children, + scale: 100, + } + } else { + let mut far: Vec> = Vec::new(); + self.split(point_set, &mut far, max_scale); + + let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set); + + if point_set.is_empty() { + point_set.append(&mut far); + child + } else { + let mut children: Vec> = Vec::new(); + children.push(child); + let mut new_point_set: Vec> = Vec::new(); + let mut new_consumed_set: Vec> = Vec::new(); + + while !point_set.is_empty() { + let set: DistanceSet = point_set.remove(point_set.len() - 1); + + let new_dist: F = set.dist[set.dist.len() - 1]; + + self.dist_split( + point_set, + &mut new_point_set, + self.get_data_value(set.idx), + max_scale, + ); + self.dist_split( + &mut far, + &mut new_point_set, + self.get_data_value(set.idx), + max_scale, + ); + + let mut new_child = self.batch_insert( + set.idx, + next_scale, + top_scale, + &mut new_point_set, + &mut new_consumed_set, + ); + new_child.parent_dist = new_dist; + + consumed_set.push(set); + children.push(new_child); + + let fmax = self.get_cover_radius(max_scale); + for mut set in new_point_set.drain(0..) { + set.dist.remove(set.dist.len() - 1); + if set.dist[set.dist.len() - 1] <= fmax { + point_set.push(set); + } else { + far.push(set); + } + } + + for mut set in new_consumed_set.drain(0..) { + set.dist.remove(set.dist.len() - 1); + consumed_set.push(set); + } + } + + point_set.append(&mut far); + + Node { + idx: p, + max_dist: self.max(consumed_set), + parent_dist: F::zero(), + children: children, + scale: (top_scale - max_scale), + } + } + } } - qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()); - qi_p_ds[..usize::min(qi_p_ds.len(), k)] - .iter() - .map(|(n, d)| (n.index.index, *d)) - .collect() } fn split( &self, - p_id: NodeId, - r: F, - s1: &mut Vec, - s2: Option<&mut Vec>, - ) -> (Vec, Vec) { - let mut my_near = (Vec::new(), Vec::new()); - - my_near = self.split_remove_s(p_id, r, s1, my_near); - - for s in s2 { - my_near = self.split_remove_s(p_id, r, s, my_near); - } - - return my_near; - } - - fn split_remove_s( - &self, - p_id: NodeId, - r: F, - s: &mut Vec, - mut my_near: (Vec, Vec), - ) -> (Vec, Vec) { - if s.len() > 0 { - let p = &self.nodes.get(p_id.index).unwrap().data; - let mut i = 0; - while i != s.len() { - let d = self.distance.distance(p, &s[i]); - if d <= r { - my_near.0.push(s.remove(i)); - } else if d > r && d <= F::two() * r { - my_near.1.push(s.remove(i)); - } else { - i += 1; - } + point_set: &mut Vec>, + far_set: &mut Vec>, + max_scale: i64, + ) { + let fmax = self.get_cover_radius(max_scale); + let mut new_set: Vec> = Vec::new(); + for n in point_set.drain(0..) { + if n.dist[n.dist.len() - 1] <= fmax { + new_set.push(n); + } else { + far_set.push(n); } } - return my_near; + point_set.append(&mut new_set); } - fn construct<'b>( - &mut self, - p: NodeId, - mut near: Vec, - mut far: Vec, - i: i8, - ) -> (NodeId, Vec) { - if near.len() < 1 { - self.min_level = std::cmp::min(self.min_level, i); - return (p, far); + fn dist_split( + &self, + point_set: &mut Vec>, + new_point_set: &mut Vec>, + new_point: &T, + max_scale: i64, + ) { + let fmax = self.get_cover_radius(max_scale); + let mut new_set: Vec> = Vec::new(); + for mut n in point_set.drain(0..) { + let new_dist = self + .distance + .distance(new_point, self.get_data_value(n.idx)); + if new_dist <= fmax { + n.dist.push(new_dist); + new_point_set.push(n); + } else { + new_set.push(n); + } + } + + point_set.append(&mut new_set); + } + + fn get_cover_radius(&self, s: i64) -> F { + self.base.powf(F::from_i64(s).unwrap()) + } + + fn get_data_value(&self, idx: usize) -> &T { + &self.data[idx] + } + + fn get_scale(&self, d: F) -> i64 { + if d == F::zero() { + std::i64::MIN } else { - let (my, n) = self.split(p, self.base.powf(F::from(i - 1).unwrap()), &mut near, None); - let (pi, mut near) = self.construct(p, my, n, i - 1); - while near.len() > 0 { - let q_data = near.remove(0); - let nn = self.new_node(Some(p), q_data); - let (my, n) = self.split( - nn, - self.base.powf(F::from(i - 1).unwrap()), - &mut near, - Some(&mut far), - ); - let (child, mut unused) = self.construct(nn, my, n, i - 1); - self.add_child(pi, child, i); - let new_near_far = - self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None); - near.extend(new_near_far.0); - far.extend(new_near_far.1); - } - self.min_level = std::cmp::min(self.min_level, i); - return (pi, far); + (self.inv_log_base * d.ln()).ceil().to_i64().unwrap() } } - fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) { - self.nodes - .get_mut(parent.index) - .unwrap() - .children - .insert(i, node); - } - - fn root(&self) -> &Node { - self.nodes.first().unwrap() - } - - fn get_children_dist<'b>( - &'b self, - p: &T, - qi_p_ds: &Vec<(&'b Node, F)>, - i: i8, - ) -> Vec<(&'b Node, F)> { - let mut children = Vec::<(&'b Node, F)>::new(); - - children.extend(qi_p_ds.iter().cloned()); - - let q: Vec<&Node> = qi_p_ds - .iter() - .flat_map(|(n, _)| self.get_child(n, i)) - .collect(); - - children.extend( - q.into_iter() - .map(|n| (n, self.distance.distance(&n.data, &p))), - ); - - children - } - - fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node, F)>, k: usize) -> F { - let mut heap = HeapSelect::with_capacity(k); - for (_, d) in q_p_ds { - heap.add(d); - } - heap.sort(); - *heap.get().pop().unwrap() - } - - fn min_by_distance(&self, q_p_ds: &Vec<(&Node, F)>) -> F { - q_p_ds - .into_iter() - .min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()) - .unwrap() - .1 - } - - fn get_child(&self, node: &Node, i: i8) -> Option<&Node> { - node.children - .get(&i) - .and_then(|n_id| self.nodes.get(n_id.index)) - } - - #[allow(dead_code)] - fn check_invariant( - &self, - invariant: fn(&CoverTree, &Vec<&Node>, &Vec<&Node>, i8) -> (), - ) { - let mut current_nodes: Vec<&Node> = Vec::new(); - current_nodes.push(self.root()); - for i in (self.min_level..self.max_level + 1).rev() { - let mut next_nodes: Vec<&Node> = Vec::new(); - next_nodes.extend(current_nodes.iter()); - next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i))); - invariant(self, ¤t_nodes, &next_nodes, i); - current_nodes = next_nodes - } - } - - #[allow(dead_code)] - fn nesting_invariant( - _: &CoverTree, - nodes: &Vec<&Node>, - next_nodes: &Vec<&Node>, - _: i8, - ) { - let nodes_set: HashSet<&Node> = HashSet::from_iter(nodes.into_iter().map(|n| *n)); - let next_nodes_set: HashSet<&Node> = - HashSet::from_iter(next_nodes.into_iter().map(|n| *n)); - for n in nodes_set.iter() { - assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set); - } - } - - #[allow(dead_code)] - fn covering_tree( - tree: &CoverTree, - nodes: &Vec<&Node>, - next_nodes: &Vec<&Node>, - i: i8, - ) { - let mut p_selected: Vec<&Node> = Vec::new(); - for p in next_nodes { - for q in nodes { - if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) { - p_selected.push(*p); - } - } - let c = p_selected - .iter() - .filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false)) - .count(); - assert!(c <= 1); - } - } - - #[allow(dead_code)] - fn separation(tree: &CoverTree, nodes: &Vec<&Node>, _: &Vec<&Node>, i: i8) { - for p in nodes { - for q in nodes { - if p != q { - assert!( - tree.distance.distance(&p.data, &q.data) - > tree.base.powf(F::from(i).unwrap()) - ); - } + fn max(&self, distance_set: &Vec>) -> F { + let mut max = F::zero(); + for n in distance_set { + if max < n.dist[n.dist.len() - 1] { + max = n.dist[n.dist.len() - 1]; } } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] -struct NodeId { - index: usize, -} - -#[derive(Debug, Serialize, Deserialize)] -struct Node { - index: NodeId, - data: T, - children: HashMap, - parent: Option, -} - -impl PartialEq for Node { - fn eq(&self, other: &Self) -> bool { - self.index.index == other.index.index - } -} - -impl Eq for Node {} - -impl Hash for Node { - fn hash(&self, state: &mut H) - where - H: Hasher, - { - state.write_usize(self.index.index); - state.finish(); + return max; } } @@ -366,7 +392,9 @@ impl Hash for Node { mod tests { use super::*; + use crate::math::distance::Distances; + #[derive(Debug, Serialize, Deserialize)] struct SimpleDistance {} impl Distance for SimpleDistance { @@ -379,32 +407,42 @@ mod tests { fn cover_tree_test() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; - let mut tree = CoverTree::new(data, SimpleDistance {}); - for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] { - tree.insert(d); - } + let tree = CoverTree::new(data, SimpleDistance {}); - let mut nearest_3_to_5 = tree.find(&5, 3); - nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); - let nearest_3_to_5_indexes: Vec = nearest_3_to_5.iter().map(|v| v.0).collect(); - assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes); - - let mut nearest_3_to_15 = tree.find(&15, 3); - nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); - let nearest_3_to_15_indexes: Vec = nearest_3_to_15.iter().map(|v| v.0).collect(); - assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes); - - assert_eq!(-1, tree.min_level); - assert_eq!(100, tree.max_level); + let mut knn = tree.find(&5, 3); + knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + let knn: Vec = knn.iter().map(|v| v.0).collect(); + assert_eq!(vec!(3, 4, 5), knn); } #[test] - fn test_invariants() { + fn cover_tree_test1() { + let data = vec![ + vec![1., 2.], + vec![3., 4.], + vec![5., 6.], + vec![7., 8.], + vec![9., 10.], + ]; + + let tree = CoverTree::new(data, Distances::euclidian()); + + let mut knn = tree.find(&vec![1., 2.], 3); + knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + let knn: Vec = knn.iter().map(|v| v.0).collect(); + + assert_eq!(vec!(0, 1, 2), knn); + } + + #[test] + fn serde() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; let tree = CoverTree::new(data, SimpleDistance {}); - tree.check_invariant(CoverTree::nesting_invariant); - tree.check_invariant(CoverTree::covering_tree); - tree.check_invariant(CoverTree::separation); + + let deserialized_tree: CoverTree = + serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap(); + + assert_eq!(tree, deserialized_tree); } } diff --git a/src/algorithm/neighbour/linear_search.rs b/src/algorithm/neighbour/linear_search.rs index 00e09d9..e69ad6b 100644 --- a/src/algorithm/neighbour/linear_search.rs +++ b/src/algorithm/neighbour/linear_search.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use std::cmp::{Ordering, PartialOrd}; use std::marker::PhantomData; -use crate::algorithm::sort::heap_select::HeapSelect; +use crate::algorithm::sort::heap_select::HeapSelection; use crate::math::distance::Distance; use crate::math::num::RealNumber; @@ -57,7 +57,7 @@ impl> LinearKNNSearch { panic!("k should be >= 1 and <= length(data)"); } - let mut heap = HeapSelect::>::with_capacity(k); + let mut heap = HeapSelection::>::with_capacity(k); for _ in 0..k { heap.add(KNNPoint { @@ -76,8 +76,6 @@ impl> LinearKNNSearch { } } - heap.sort(); - heap.get() .into_iter() .flat_map(|x| x.index.map(|i| (i, x.distance))) @@ -124,9 +122,10 @@ mod tests { let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {}); - let found_idxs1: Vec = algorithm1.find(&2, 3).iter().map(|v| v.0).collect(); + let mut found_idxs1: Vec = algorithm1.find(&2, 3).iter().map(|v| v.0).collect(); + found_idxs1.sort(); - assert_eq!(vec!(1, 2, 0), found_idxs1); + assert_eq!(vec!(0, 1, 2), found_idxs1); let data2 = vec![ vec![1., 1.], @@ -138,13 +137,14 @@ mod tests { let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian()); - let found_idxs2: Vec = algorithm2 + let mut found_idxs2: Vec = algorithm2 .find(&vec![3., 3.], 3) .iter() .map(|v| v.0) .collect(); + found_idxs2.sort(); - assert_eq!(vec!(2, 3, 1), found_idxs2); + assert_eq!(vec!(1, 2, 3), found_idxs2); } #[test] diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index f228aed..48c8835 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -22,7 +22,7 @@ //! //! ## References: //! * ["The Art of Computer Programming" Knuth, D, Vol. 3, 2nd ed, Sorting and Searching, 1998](https://www-cs-faculty.stanford.edu/~knuth/taocp.html) -//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://homes.cs.washington.edu/~sham/papers/ml/cover_tree.pdf) +//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://hunch.net/~jl/projects/cover_tree/cover_tree.html) //! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15) //! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/) //! diff --git a/src/algorithm/sort/heap_select.rs b/src/algorithm/sort/heap_select.rs index 6c13b9b..063ffc6 100644 --- a/src/algorithm/sort/heap_select.rs +++ b/src/algorithm/sort/heap_select.rs @@ -1,20 +1,24 @@ +//! # Heap Selection Algorithm +//! +//! The goal is to find the k smallest elements in a list or array. use std::cmp::Ordering; +use std::fmt::Debug; #[derive(Debug)] -pub struct HeapSelect { +pub struct HeapSelection { k: usize, n: usize, sorted: bool, heap: Vec, } -impl<'a, T: PartialOrd> HeapSelect { - pub fn with_capacity(k: usize) -> HeapSelect { - HeapSelect { +impl<'a, T: PartialOrd + Debug> HeapSelection { + pub fn with_capacity(k: usize) -> HeapSelection { + HeapSelection { k: k, n: 0, sorted: false, - heap: Vec::::new(), + heap: Vec::new(), } } @@ -24,12 +28,13 @@ impl<'a, T: PartialOrd> HeapSelect { self.heap.push(element); self.n += 1; if self.n == self.k { - self.heapify(); + self.sort(); } } else { self.n += 1; if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) { self.heap[0] = element; + self.sift_down(0, self.k - 1); } } } @@ -41,58 +46,46 @@ impl<'a, T: PartialOrd> HeapSelect { } } - #[allow(dead_code)] pub fn peek(&self) -> &T { - return &self.heap[0]; + if self.sorted { + return &self.heap[0]; + } else { + &self + .heap + .iter() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap() + } } pub fn peek_mut(&mut self) -> &mut T { return &mut self.heap[0]; } - pub fn sift_down(&mut self, from: usize, n: usize) { - let mut k = from; - while 2 * k <= n { - let mut j = 2 * k; - if j < n && self.heap[j] < self.heap[j + 1] { - j += 1; - } - if self.heap[k] >= self.heap[j] { - break; - } - self.heap.swap(k, j); - k = j; - } - } - pub fn get(self) -> Vec { return self.heap; } - pub fn sort(&mut self) { - HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k, self.n)); + fn sift_down(&mut self, k: usize, n: usize) { + let mut kk = k; + while 2 * kk <= n { + let mut j = 2 * kk; + if j < n && self.heap[j].partial_cmp(&self.heap[j + 1]) == Some(Ordering::Less) { + j += 1; + } + if self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Equal) + || self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Greater) + { + break; + } + self.heap.swap(kk, j); + kk = j; + } } - pub fn shuffle_sort(vec: &mut Vec, n: usize) { - let mut inc = 1; - while inc <= n { - inc *= 3; - inc += 1 - } - - let len = n; - while inc >= 1 { - let mut i = inc; - while i < len { - let mut j = i; - while j >= inc && vec[j - inc] > vec[j] { - vec.swap(j - inc, j); - j -= inc; - } - i += 1; - } - inc /= 3 - } + fn sort(&mut self) { + self.sorted = true; + self.heap.sort_by(|a, b| b.partial_cmp(a).unwrap()); } } @@ -102,50 +95,62 @@ mod tests { #[test] fn with_capacity() { - let heap = HeapSelect::::with_capacity(3); + let heap = HeapSelection::::with_capacity(3); assert_eq!(3, heap.k); } #[test] fn test_add() { - let mut heap = HeapSelect::with_capacity(3); + let mut heap = HeapSelection::with_capacity(3); + heap.add(-5); + assert_eq!(-5, *heap.peek()); heap.add(333); - heap.add(2); + assert_eq!(333, *heap.peek()); heap.add(13); heap.add(10); + heap.add(2); + heap.add(0); heap.add(40); heap.add(30); - assert_eq!(6, heap.n); - assert_eq!(&10, heap.peek()); - assert_eq!(&10, heap.peek_mut()); + assert_eq!(8, heap.n); + assert_eq!(vec![2, 0, -5], heap.get()); + } + + #[test] + fn test_add1() { + let mut heap = HeapSelection::with_capacity(3); + heap.add(std::f64::INFINITY); + heap.add(-5f64); + heap.add(4f64); + heap.add(-1f64); + heap.add(2f64); + heap.add(1f64); + heap.add(0f64); + assert_eq!(7, heap.n); + assert_eq!(vec![0f64, -1f64, -5f64], heap.get()); + } + + #[test] + fn test_add2() { + let mut heap = HeapSelection::with_capacity(3); + heap.add(std::f64::INFINITY); + heap.add(0.0); + heap.add(8.4852); + heap.add(5.6568); + heap.add(2.8284); + assert_eq!(5, heap.n); + assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get()); } #[test] fn test_add_ordered() { - let mut heap = HeapSelect::with_capacity(3); + let mut heap = HeapSelection::with_capacity(3); heap.add(1.); heap.add(2.); heap.add(3.); heap.add(4.); heap.add(5.); heap.add(6.); - let result = heap.get(); - assert_eq!(vec![2., 3., 1.], result); - } - - #[test] - fn test_shuffle_sort() { - let mut v1 = vec![10, 33, 22, 105, 12]; - let n = v1.len(); - HeapSelect::shuffle_sort(&mut v1, n); - assert_eq!(vec![10, 12, 22, 33, 105], v1); - - let mut v2 = vec![10, 33, 22, 105, 12]; - HeapSelect::shuffle_sort(&mut v2, 3); - assert_eq!(vec![10, 22, 33, 105, 12], v2); - - let mut v3 = vec![4, 5, 3, 2, 1]; - HeapSelect::shuffle_sort(&mut v3, 3); - assert_eq!(vec![3, 4, 5, 2, 1], v3); + assert_eq!(vec![3., 2., 1.], heap.get()); } } diff --git a/src/dataset/iris.rs b/src/dataset/iris.rs new file mode 100644 index 0000000..6c20576 --- /dev/null +++ b/src/dataset/iris.rs @@ -0,0 +1,90 @@ +//! # The Iris Dataset flower +//! +//! [Fisher's Iris dataset](https://archive.ics.uci.edu/ml/datasets/iris) is a multivariate dataset that was published in 1936 by Ronald Fisher. +//! This multivariate dataset is frequently used to demonstrate various machine learning algorithms. +use crate::dataset::Dataset; + +/// Get dataset +pub fn load_dataset() -> Dataset { + let x = vec![ + 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6, + 1.4, 0.2, 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, + 4.9, 3.1, 1.5, 0.1, 5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0, + 1.1, 0.1, 5.8, 4.0, 1.2, 0.2, 5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3, + 5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3, 5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6, + 1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2, 5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4, + 5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2, 4.8, 3.1, 1.6, 0.2, 5.4, 3.4, + 1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1, 5.0, 3.2, 1.2, 0.2, + 5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2, 5.0, 3.5, + 1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4, + 4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3, + 1.4, 0.2, 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3, + 6.5, 2.8, 4.6, 1.5, 5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, + 4.6, 1.3, 5.2, 2.7, 3.9, 1.4, 5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0, + 6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3, 6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7, + 4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1, 5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3, + 6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3, 6.6, 3.0, 4.4, 1.4, 6.8, 2.8, + 4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0, 5.5, 2.4, 3.8, 1.1, + 5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5, 6.0, 3.4, + 4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3, + 5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7, + 4.2, 1.3, 5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1, + 5.7, 2.8, 4.1, 1.3, 6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9, + 5.6, 1.8, 6.5, 3.0, 5.8, 2.2, 7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8, + 6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5, 6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0, + 5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4, 6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8, + 7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5, 6.9, 3.2, 5.7, 2.3, 5.6, 2.8, + 4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1, 7.2, 3.2, 6.0, 1.8, + 6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6, 7.4, 2.8, + 6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4, + 7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1, + 5.4, 2.1, 6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3, + 6.7, 3.3, 5.7, 2.5, 6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4, + 5.4, 2.3, 5.9, 3.0, 5.1, 1.8, + ]; + + let setosa = std::iter::repeat(0f32).take(50); + let versicolor = std::iter::repeat(1f32).take(50); + let virginica = std::iter::repeat(2f32).take(50); + let y = setosa + .chain(versicolor) + .chain(virginica) + .collect::>(); + let shape = (150, 4); + + Dataset { + data: x, + target: y, + num_samples: shape.0, + num_features: shape.1, + feature_names: vec![ + "sepal length (cm)", + "sepal width (cm)", + "petal length (cm)", + "petal width (cm)", + ] + .iter() + .map(|s| s.to_string()) + .collect(), + target_names: vec!["setosa", "versicolor", "virginica"] + .iter() + .map(|s| s.to_string()) + .collect(), + description: "Iris dataset: https://archive.ics.uci.edu/ml/datasets/iris".to_string(), + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn iris_dataset() { + let dataset = load_dataset(); + assert_eq!(dataset.data.len(), 50 * 3 * 4); + assert_eq!(dataset.target.len(), 50 * 3); + assert_eq!(dataset.num_features, 4); + assert_eq!(dataset.num_samples, 50 * 3); + } +} diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs new file mode 100644 index 0000000..406507a --- /dev/null +++ b/src/dataset/mod.rs @@ -0,0 +1,65 @@ +//! Datasets +//! +//! In this module you will find small datasets that are used in SmartCore for demonstration purpose mostly. + +/// Iris flower data set +pub mod iris; + +/// Dataset +pub struct Dataset { + /// data in one-dimensional array. + pub data: Vec, + /// target values or class labels. + pub target: Vec, + /// number of samples (number of rows in matrix form). + pub num_samples: usize, + /// number of features (number of columns in matrix form). + pub num_features: usize, + /// names of dependent variables. + pub feature_names: Vec, + /// names of target variables. + pub target_names: Vec, + /// dataset description + pub description: String, +} + +impl Dataset { + /// Reshape data into a two-dimensional matrix + pub fn as_2d_vector(&self) -> Vec> { + let mut result: Vec> = Vec::with_capacity(self.num_samples); + + for r in 0..self.num_samples { + let mut row = Vec::with_capacity(self.num_features); + for c in 0..self.num_features { + row.push(&self.data[r * self.num_features + c]); + } + result.push(row); + } + + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn as_2d_vector() { + let dataset = Dataset { + data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + target: vec![1, 2, 3], + num_samples: 2, + num_features: 5, + feature_names: vec![], + target_names: vec![], + description: "".to_string(), + }; + + let m = dataset.as_2d_vector(); + + assert_eq!(m.len(), 2); + assert_eq!(m[0].len(), 5); + assert_eq!(*m[1][3], 9); + } +} diff --git a/src/lib.rs b/src/lib.rs index 7be40ca..b67d0f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,6 +68,9 @@ pub mod algorithm; /// Algorithms for clustering of unlabeled data pub mod cluster; +/// Various datasets +#[cfg(feature = "datasets")] +pub mod dataset; /// Matrix decomposition algorithms pub mod decomposition; /// Ensemble methods, including Random Forest classifier and regressor diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 99a22e1..d8ddcf3 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -143,6 +143,7 @@ impl, T>> KNNRegressor { fn predict_for_row(&self, x: Vec) -> T { let search_result = self.knn_algorithm.find(&x, self.k); + println!("{:?}", search_result); let mut result = T::zero(); let weights = self @@ -195,9 +196,10 @@ mod tests { let y_exp = vec![2., 2., 3., 4., 4.]; let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()); let y_hat = knn.predict(&x); + println!("{:?}", y_hat); assert_eq!(5, Vec::len(&y_hat)); for i in 0..y_hat.len() { - assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON); + assert!((y_hat[i] - y_exp[i]).abs() < 1e-7); } }