feat: adds dataset module, fixs problem in CoverTree implementation

This commit is contained in:
Volodymyr Orlov
2020-09-10 12:21:59 -07:00
parent cc1f84e81f
commit b95e11cc98
9 changed files with 598 additions and 394 deletions
+2 -1
View File
@@ -5,9 +5,10 @@ authors = ["SmartCore Developers"]
edition = "2018"
[features]
default = []
default = ["datasets"]
ndarray-bindings = ["ndarray"]
nalgebra-bindings = ["nalgebra"]
datasets = []
[dependencies]
ndarray = { version = "0.13", optional = true }
+351 -313
View File
@@ -21,14 +21,11 @@
//! tree.find(&5, 3); // find 3 knn points from 5
//!
//! ```
use core::hash::{Hash, Hasher};
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::iter::FromIterator;
use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelect;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
@@ -36,329 +33,358 @@ use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)]
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
base: F,
max_level: i8,
min_level: i8,
inv_log_base: F,
distance: D,
nodes: Vec<Node<T>>,
root: Node<F>,
data: Vec<T>,
identical_excluded: bool,
}
impl<T: Debug, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
fn eq(&self, other: &Self) -> bool {
if self.data.len() != other.data.len() {
return false;
}
for i in 0..self.data.len() {
if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() {
return false;
}
}
return true;
}
}
#[derive(Debug, Serialize, Deserialize)]
struct Node<F: RealNumber> {
idx: usize,
max_dist: F,
parent_dist: F,
children: Vec<Node<F>>,
scale: i64,
}
#[derive(Debug, Serialize, Deserialize)]
struct DistanceSet<F: RealNumber> {
idx: usize,
dist: Vec<F>,
}
impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
/// Construct a cover tree.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
let base = F::from_f64(1.3).unwrap();
let root = Node {
idx: 0,
max_dist: F::zero(),
parent_dist: F::zero(),
children: Vec::new(),
scale: 0,
};
let mut tree = CoverTree {
base: F::two(),
max_level: 100,
min_level: 100,
base: base,
inv_log_base: F::one() / base.ln(),
distance: distance,
nodes: Vec::new(),
root: root,
data: data,
identical_excluded: false,
};
let p = tree.new_node(None, data.remove(0));
tree.construct(p, data, Vec::new(), 10);
tree.build_cover_tree();
tree
}
/// Insert new data point into the cover tree.
/// * `p` - new data points.
pub fn insert(&mut self, p: T) {
if self.nodes.is_empty() {
self.new_node(None, p);
} else {
let mut parent: Option<NodeId> = Option::None;
let mut p_i = 0;
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
let mut i = self.max_level;
loop {
let i_d = self.base.powf(F::from(i).unwrap());
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_by_distance(&q_p_ds);
if d_p_q < F::epsilon() {
return;
} else if d_p_q > i_d {
break;
}
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()) {
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
p_i = i;
}
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect();
i -= 1;
}
let new_node = self.new_node(parent, p);
self.add_child(parent.unwrap(), new_node, p_i);
self.min_level = i8::min(self.min_level, p_i - 1);
}
}
fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
let next_index = self.nodes.len();
let node_id = NodeId { index: next_index };
self.nodes.push(Node {
index: node_id,
data: data,
parent: parent,
children: HashMap::new(),
});
node_id
}
/// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
for i in (self.min_level..self.max_level + 1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds
.into_iter()
.filter(|(_, d)| d <= &(d_p_q + i_d))
.collect();
if k <= 0 {
panic!("k should be > 0");
}
if k > self.data.len() {
panic!("k is > than the dataset size");
}
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(&e, p);
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
current_cover_set.push((d, &self.root));
let mut heap = HeapSelection::with_capacity(k);
heap.add(F::max_value());
let mut empty_heap = true;
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
heap.add(d);
empty_heap = false;
}
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
let child = &parent.children[c];
if c == 0 {
d = par.0;
} else {
d = self.distance.distance(self.get_data_value(child.idx), p);
}
let upper_bound = if empty_heap {
F::infinity()
} else {
*heap.peek()
};
if d <= (upper_bound + child.max_dist) {
if c > 0 && d < upper_bound {
if !self.identical_excluded || self.get_data_value(child.idx) != p {
heap.add(d);
}
}
if !child.children.is_empty() {
next_cover_set.push((d, child));
} else if d <= upper_bound {
zero_set.push((d, child));
}
}
}
}
current_cover_set = next_cover_set;
}
let mut neighbors: Vec<(usize, F)> = Vec::new();
let upper_bound = *heap.peek();
for ds in zero_set {
if ds.0 <= upper_bound {
let v = self.get_data_value(ds.1.idx);
if !self.identical_excluded || v != p {
neighbors.push((ds.1.idx, ds.0));
}
}
}
neighbors.into_iter().take(k).collect()
}
fn new_leaf(&self, idx: usize) -> Node<F> {
Node {
idx: idx,
max_dist: F::zero(),
parent_dist: F::zero(),
children: Vec::new(),
scale: 100,
}
}
fn build_cover_tree(&mut self) {
let mut point_set: Vec<DistanceSet<F>> = Vec::new();
let mut consumed_set: Vec<DistanceSet<F>> = Vec::new();
let point = &self.data[0];
let idx = 0;
let mut max_dist = -F::one();
for i in 1..self.data.len() {
let dist = self.distance.distance(point, &self.data[i]);
let set = DistanceSet {
idx: i,
dist: vec![dist],
};
point_set.push(set);
if dist > max_dist {
max_dist = dist;
}
}
self.root = self.batch_insert(
idx,
self.get_scale(max_dist),
self.get_scale(max_dist),
&mut point_set,
&mut consumed_set,
);
}
fn batch_insert(
&self,
p: usize,
max_scale: i64,
top_scale: i64,
point_set: &mut Vec<DistanceSet<F>>,
consumed_set: &mut Vec<DistanceSet<F>>,
) -> Node<F> {
if point_set.is_empty() {
self.new_leaf(p)
} else {
let max_dist = self.max(&point_set);
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
if next_scale == std::i64::MIN {
let mut children: Vec<Node<F>> = Vec::new();
let mut leaf = self.new_leaf(p);
children.push(leaf);
while !point_set.is_empty() {
let set = point_set.remove(point_set.len() - 1);
leaf = self.new_leaf(set.idx);
children.push(leaf);
consumed_set.push(set);
}
Node {
idx: p,
max_dist: F::zero(),
parent_dist: F::zero(),
children: children,
scale: 100,
}
} else {
let mut far: Vec<DistanceSet<F>> = Vec::new();
self.split(point_set, &mut far, max_scale);
let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set);
if point_set.is_empty() {
point_set.append(&mut far);
child
} else {
let mut children: Vec<Node<F>> = Vec::new();
children.push(child);
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
while !point_set.is_empty() {
let set: DistanceSet<F> = point_set.remove(point_set.len() - 1);
let new_dist: F = set.dist[set.dist.len() - 1];
self.dist_split(
point_set,
&mut new_point_set,
self.get_data_value(set.idx),
max_scale,
);
self.dist_split(
&mut far,
&mut new_point_set,
self.get_data_value(set.idx),
max_scale,
);
let mut new_child = self.batch_insert(
set.idx,
next_scale,
top_scale,
&mut new_point_set,
&mut new_consumed_set,
);
new_child.parent_dist = new_dist;
consumed_set.push(set);
children.push(new_child);
let fmax = self.get_cover_radius(max_scale);
for mut set in new_point_set.drain(0..) {
set.dist.remove(set.dist.len() - 1);
if set.dist[set.dist.len() - 1] <= fmax {
point_set.push(set);
} else {
far.push(set);
}
}
for mut set in new_consumed_set.drain(0..) {
set.dist.remove(set.dist.len() - 1);
consumed_set.push(set);
}
}
point_set.append(&mut far);
Node {
idx: p,
max_dist: self.max(consumed_set),
parent_dist: F::zero(),
children: children,
scale: (top_scale - max_scale),
}
}
}
}
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
.iter()
.map(|(n, d)| (n.index.index, *d))
.collect()
}
fn split(
&self,
p_id: NodeId,
r: F,
s1: &mut Vec<T>,
s2: Option<&mut Vec<T>>,
) -> (Vec<T>, Vec<T>) {
let mut my_near = (Vec::new(), Vec::new());
my_near = self.split_remove_s(p_id, r, s1, my_near);
for s in s2 {
my_near = self.split_remove_s(p_id, r, s, my_near);
}
return my_near;
}
fn split_remove_s(
&self,
p_id: NodeId,
r: F,
s: &mut Vec<T>,
mut my_near: (Vec<T>, Vec<T>),
) -> (Vec<T>, Vec<T>) {
if s.len() > 0 {
let p = &self.nodes.get(p_id.index).unwrap().data;
let mut i = 0;
while i != s.len() {
let d = self.distance.distance(p, &s[i]);
if d <= r {
my_near.0.push(s.remove(i));
} else if d > r && d <= F::two() * r {
my_near.1.push(s.remove(i));
} else {
i += 1;
}
point_set: &mut Vec<DistanceSet<F>>,
far_set: &mut Vec<DistanceSet<F>>,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
for n in point_set.drain(0..) {
if n.dist[n.dist.len() - 1] <= fmax {
new_set.push(n);
} else {
far_set.push(n);
}
}
return my_near;
point_set.append(&mut new_set);
}
fn construct<'b>(
&mut self,
p: NodeId,
mut near: Vec<T>,
mut far: Vec<T>,
i: i8,
) -> (NodeId, Vec<T>) {
if near.len() < 1 {
self.min_level = std::cmp::min(self.min_level, i);
return (p, far);
fn dist_split(
&self,
point_set: &mut Vec<DistanceSet<F>>,
new_point_set: &mut Vec<DistanceSet<F>>,
new_point: &T,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
for mut n in point_set.drain(0..) {
let new_dist = self
.distance
.distance(new_point, self.get_data_value(n.idx));
if new_dist <= fmax {
n.dist.push(new_dist);
new_point_set.push(n);
} else {
new_set.push(n);
}
}
point_set.append(&mut new_set);
}
fn get_cover_radius(&self, s: i64) -> F {
self.base.powf(F::from_i64(s).unwrap())
}
fn get_data_value(&self, idx: usize) -> &T {
&self.data[idx]
}
fn get_scale(&self, d: F) -> i64 {
if d == F::zero() {
std::i64::MIN
} else {
let (my, n) = self.split(p, self.base.powf(F::from(i - 1).unwrap()), &mut near, None);
let (pi, mut near) = self.construct(p, my, n, i - 1);
while near.len() > 0 {
let q_data = near.remove(0);
let nn = self.new_node(Some(p), q_data);
let (my, n) = self.split(
nn,
self.base.powf(F::from(i - 1).unwrap()),
&mut near,
Some(&mut far),
);
let (child, mut unused) = self.construct(nn, my, n, i - 1);
self.add_child(pi, child, i);
let new_near_far =
self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
near.extend(new_near_far.0);
far.extend(new_near_far.1);
}
self.min_level = std::cmp::min(self.min_level, i);
return (pi, far);
(self.inv_log_base * d.ln()).ceil().to_i64().unwrap()
}
}
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
self.nodes
.get_mut(parent.index)
.unwrap()
.children
.insert(i, node);
}
fn root(&self) -> &Node<T> {
self.nodes.first().unwrap()
}
fn get_children_dist<'b>(
&'b self,
p: &T,
qi_p_ds: &Vec<(&'b Node<T>, F)>,
i: i8,
) -> Vec<(&'b Node<T>, F)> {
let mut children = Vec::<(&'b Node<T>, F)>::new();
children.extend(qi_p_ds.iter().cloned());
let q: Vec<&Node<T>> = qi_p_ds
.iter()
.flat_map(|(n, _)| self.get_child(n, i))
.collect();
children.extend(
q.into_iter()
.map(|n| (n, self.distance.distance(&n.data, &p))),
);
children
}
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
let mut heap = HeapSelect::with_capacity(k);
for (_, d) in q_p_ds {
heap.add(d);
}
heap.sort();
*heap.get().pop().unwrap()
}
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
q_p_ds
.into_iter()
.min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
.unwrap()
.1
}
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
node.children
.get(&i)
.and_then(|n_id| self.nodes.get(n_id.index))
}
#[allow(dead_code)]
fn check_invariant(
&self,
invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> (),
) {
let mut current_nodes: Vec<&Node<T>> = Vec::new();
current_nodes.push(self.root());
for i in (self.min_level..self.max_level + 1).rev() {
let mut next_nodes: Vec<&Node<T>> = Vec::new();
next_nodes.extend(current_nodes.iter());
next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i)));
invariant(self, &current_nodes, &next_nodes, i);
current_nodes = next_nodes
}
}
#[allow(dead_code)]
fn nesting_invariant(
_: &CoverTree<T, F, D>,
nodes: &Vec<&Node<T>>,
next_nodes: &Vec<&Node<T>>,
_: i8,
) {
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
let next_nodes_set: HashSet<&Node<T>> =
HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
for n in nodes_set.iter() {
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
}
}
#[allow(dead_code)]
fn covering_tree(
tree: &CoverTree<T, F, D>,
nodes: &Vec<&Node<T>>,
next_nodes: &Vec<&Node<T>>,
i: i8,
) {
let mut p_selected: Vec<&Node<T>> = Vec::new();
for p in next_nodes {
for q in nodes {
if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
p_selected.push(*p);
}
}
let c = p_selected
.iter()
.filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false))
.count();
assert!(c <= 1);
}
}
#[allow(dead_code)]
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
for p in nodes {
for q in nodes {
if p != q {
assert!(
tree.distance.distance(&p.data, &q.data)
> tree.base.powf(F::from(i).unwrap())
);
}
fn max(&self, distance_set: &Vec<DistanceSet<F>>) -> F {
let mut max = F::zero();
for n in distance_set {
if max < n.dist[n.dist.len() - 1] {
max = n.dist[n.dist.len() - 1];
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct NodeId {
index: usize,
}
#[derive(Debug, Serialize, Deserialize)]
struct Node<T> {
index: NodeId,
data: T,
children: HashMap<i8, NodeId>,
parent: Option<NodeId>,
}
impl<T> PartialEq for Node<T> {
fn eq(&self, other: &Self) -> bool {
self.index.index == other.index.index
}
}
impl<T> Eq for Node<T> {}
impl<T> Hash for Node<T> {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
state.write_usize(self.index.index);
state.finish();
return max;
}
}
@@ -366,7 +392,9 @@ impl<T> Hash for Node<T> {
mod tests {
use super::*;
use crate::math::distance::Distances;
#[derive(Debug, Serialize, Deserialize)]
struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance {
@@ -379,32 +407,42 @@ mod tests {
fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let mut tree = CoverTree::new(data, SimpleDistance {});
for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] {
tree.insert(d);
}
let tree = CoverTree::new(data, SimpleDistance {});
let mut nearest_3_to_5 = tree.find(&5, 3);
nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let nearest_3_to_5_indexes: Vec<usize> = nearest_3_to_5.iter().map(|v| v.0).collect();
assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes);
let mut nearest_3_to_15 = tree.find(&15, 3);
nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let nearest_3_to_15_indexes: Vec<usize> = nearest_3_to_15.iter().map(|v| v.0).collect();
assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes);
assert_eq!(-1, tree.min_level);
assert_eq!(100, tree.max_level);
let mut knn = tree.find(&5, 3);
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
assert_eq!(vec!(3, 4, 5), knn);
}
#[test]
fn test_invariants() {
fn cover_tree_test1() {
let data = vec![
vec![1., 2.],
vec![3., 4.],
vec![5., 6.],
vec![7., 8.],
vec![9., 10.],
];
let tree = CoverTree::new(data, Distances::euclidian());
let mut knn = tree.find(&vec![1., 2.], 3);
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
assert_eq!(vec!(0, 1, 2), knn);
}
#[test]
fn serde() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let tree = CoverTree::new(data, SimpleDistance {});
tree.check_invariant(CoverTree::nesting_invariant);
tree.check_invariant(CoverTree::covering_tree);
tree.check_invariant(CoverTree::separation);
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree);
}
}
+8 -8
View File
@@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelect;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
@@ -57,7 +57,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
panic!("k should be >= 1 and <= length(data)");
}
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
for _ in 0..k {
heap.add(KNNPoint {
@@ -76,8 +76,6 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
}
}
heap.sort();
heap.get()
.into_iter()
.flat_map(|x| x.index.map(|i| (i, x.distance)))
@@ -124,9 +122,10 @@ mod tests {
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
let found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
found_idxs1.sort();
assert_eq!(vec!(1, 2, 0), found_idxs1);
assert_eq!(vec!(0, 1, 2), found_idxs1);
let data2 = vec![
vec![1., 1.],
@@ -138,13 +137,14 @@ mod tests {
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
let found_idxs2: Vec<usize> = algorithm2
let mut found_idxs2: Vec<usize> = algorithm2
.find(&vec![3., 3.], 3)
.iter()
.map(|v| v.0)
.collect();
found_idxs2.sort();
assert_eq!(vec!(2, 3, 1), found_idxs2);
assert_eq!(vec!(1, 2, 3), found_idxs2);
}
#[test]
+1 -1
View File
@@ -22,7 +22,7 @@
//!
//! ## References:
//! * ["The Art of Computer Programming" Knuth, D, Vol. 3, 2nd ed, Sorting and Searching, 1998](https://www-cs-faculty.stanford.edu/~knuth/taocp.html)
//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://homes.cs.washington.edu/~sham/papers/ml/cover_tree.pdf)
//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://hunch.net/~jl/projects/cover_tree/cover_tree.html)
//! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15)
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
//!
+75 -70
View File
@@ -1,20 +1,24 @@
//! # Heap Selection Algorithm
//!
//! The goal is to find the k smallest elements in a list or array.
use std::cmp::Ordering;
use std::fmt::Debug;
#[derive(Debug)]
pub struct HeapSelect<T: PartialOrd> {
pub struct HeapSelection<T: PartialOrd + Debug> {
k: usize,
n: usize,
sorted: bool,
heap: Vec<T>,
}
impl<'a, T: PartialOrd> HeapSelect<T> {
pub fn with_capacity(k: usize) -> HeapSelect<T> {
HeapSelect {
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
pub fn with_capacity(k: usize) -> HeapSelection<T> {
HeapSelection {
k: k,
n: 0,
sorted: false,
heap: Vec::<T>::new(),
heap: Vec::new(),
}
}
@@ -24,12 +28,13 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
self.heap.push(element);
self.n += 1;
if self.n == self.k {
self.heapify();
self.sort();
}
} else {
self.n += 1;
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
self.heap[0] = element;
self.sift_down(0, self.k - 1);
}
}
}
@@ -41,58 +46,46 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
}
}
#[allow(dead_code)]
pub fn peek(&self) -> &T {
return &self.heap[0];
if self.sorted {
return &self.heap[0];
} else {
&self
.heap
.iter()
.max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap()
}
}
pub fn peek_mut(&mut self) -> &mut T {
return &mut self.heap[0];
}
pub fn sift_down(&mut self, from: usize, n: usize) {
let mut k = from;
while 2 * k <= n {
let mut j = 2 * k;
if j < n && self.heap[j] < self.heap[j + 1] {
j += 1;
}
if self.heap[k] >= self.heap[j] {
break;
}
self.heap.swap(k, j);
k = j;
}
}
pub fn get(self) -> Vec<T> {
return self.heap;
}
pub fn sort(&mut self) {
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k, self.n));
fn sift_down(&mut self, k: usize, n: usize) {
let mut kk = k;
while 2 * kk <= n {
let mut j = 2 * kk;
if j < n && self.heap[j].partial_cmp(&self.heap[j + 1]) == Some(Ordering::Less) {
j += 1;
}
if self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Equal)
|| self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Greater)
{
break;
}
self.heap.swap(kk, j);
kk = j;
}
}
pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) {
let mut inc = 1;
while inc <= n {
inc *= 3;
inc += 1
}
let len = n;
while inc >= 1 {
let mut i = inc;
while i < len {
let mut j = i;
while j >= inc && vec[j - inc] > vec[j] {
vec.swap(j - inc, j);
j -= inc;
}
i += 1;
}
inc /= 3
}
fn sort(&mut self) {
self.sorted = true;
self.heap.sort_by(|a, b| b.partial_cmp(a).unwrap());
}
}
@@ -102,50 +95,62 @@ mod tests {
#[test]
fn with_capacity() {
let heap = HeapSelect::<i32>::with_capacity(3);
let heap = HeapSelection::<i32>::with_capacity(3);
assert_eq!(3, heap.k);
}
#[test]
fn test_add() {
let mut heap = HeapSelect::with_capacity(3);
let mut heap = HeapSelection::with_capacity(3);
heap.add(-5);
assert_eq!(-5, *heap.peek());
heap.add(333);
heap.add(2);
assert_eq!(333, *heap.peek());
heap.add(13);
heap.add(10);
heap.add(2);
heap.add(0);
heap.add(40);
heap.add(30);
assert_eq!(6, heap.n);
assert_eq!(&10, heap.peek());
assert_eq!(&10, heap.peek_mut());
assert_eq!(8, heap.n);
assert_eq!(vec![2, 0, -5], heap.get());
}
#[test]
fn test_add1() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(std::f64::INFINITY);
heap.add(-5f64);
heap.add(4f64);
heap.add(-1f64);
heap.add(2f64);
heap.add(1f64);
heap.add(0f64);
assert_eq!(7, heap.n);
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
}
#[test]
fn test_add2() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(std::f64::INFINITY);
heap.add(0.0);
heap.add(8.4852);
heap.add(5.6568);
heap.add(2.8284);
assert_eq!(5, heap.n);
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
}
#[test]
fn test_add_ordered() {
let mut heap = HeapSelect::with_capacity(3);
let mut heap = HeapSelection::with_capacity(3);
heap.add(1.);
heap.add(2.);
heap.add(3.);
heap.add(4.);
heap.add(5.);
heap.add(6.);
let result = heap.get();
assert_eq!(vec![2., 3., 1.], result);
}
#[test]
fn test_shuffle_sort() {
let mut v1 = vec![10, 33, 22, 105, 12];
let n = v1.len();
HeapSelect::shuffle_sort(&mut v1, n);
assert_eq!(vec![10, 12, 22, 33, 105], v1);
let mut v2 = vec![10, 33, 22, 105, 12];
HeapSelect::shuffle_sort(&mut v2, 3);
assert_eq!(vec![10, 22, 33, 105, 12], v2);
let mut v3 = vec![4, 5, 3, 2, 1];
HeapSelect::shuffle_sort(&mut v3, 3);
assert_eq!(vec![3, 4, 5, 2, 1], v3);
assert_eq!(vec![3., 2., 1.], heap.get());
}
}
+90
View File
@@ -0,0 +1,90 @@
//! # The Iris Dataset flower
//!
//! [Fisher's Iris dataset](https://archive.ics.uci.edu/ml/datasets/iris) is a multivariate dataset that was published in 1936 by Ronald Fisher.
//! This multivariate dataset is frequently used to demonstrate various machine learning algorithms.
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, f32> {
let x = vec![
5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6,
1.4, 0.2, 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2,
4.9, 3.1, 1.5, 0.1, 5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0,
1.1, 0.1, 5.8, 4.0, 1.2, 0.2, 5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3,
5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3, 5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6,
1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2, 5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4,
5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2, 4.8, 3.1, 1.6, 0.2, 5.4, 3.4,
1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1, 5.0, 3.2, 1.2, 0.2,
5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2, 5.0, 3.5,
1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3,
1.4, 0.2, 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3,
6.5, 2.8, 4.6, 1.5, 5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9,
4.6, 1.3, 5.2, 2.7, 3.9, 1.4, 5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0,
6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3, 6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7,
4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1, 5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3,
6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3, 6.6, 3.0, 4.4, 1.4, 6.8, 2.8,
4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0, 5.5, 2.4, 3.8, 1.1,
5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5, 6.0, 3.4,
4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7,
4.2, 1.3, 5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1,
5.7, 2.8, 4.1, 1.3, 6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9,
5.6, 1.8, 6.5, 3.0, 5.8, 2.2, 7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8,
6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5, 6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0,
5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4, 6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8,
7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5, 6.9, 3.2, 5.7, 2.3, 5.6, 2.8,
4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1, 7.2, 3.2, 6.0, 1.8,
6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6, 7.4, 2.8,
6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1,
5.4, 2.1, 6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3,
6.7, 3.3, 5.7, 2.5, 6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4,
5.4, 2.3, 5.9, 3.0, 5.1, 1.8,
];
let setosa = std::iter::repeat(0f32).take(50);
let versicolor = std::iter::repeat(1f32).take(50);
let virginica = std::iter::repeat(2f32).take(50);
let y = setosa
.chain(versicolor)
.chain(virginica)
.collect::<Vec<f32>>();
let shape = (150, 4);
Dataset {
data: x,
target: y,
num_samples: shape.0,
num_features: shape.1,
feature_names: vec![
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)",
]
.iter()
.map(|s| s.to_string())
.collect(),
target_names: vec!["setosa", "versicolor", "virginica"]
.iter()
.map(|s| s.to_string())
.collect(),
description: "Iris dataset: https://archive.ics.uci.edu/ml/datasets/iris".to_string(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn iris_dataset() {
let dataset = load_dataset();
assert_eq!(dataset.data.len(), 50 * 3 * 4);
assert_eq!(dataset.target.len(), 50 * 3);
assert_eq!(dataset.num_features, 4);
assert_eq!(dataset.num_samples, 50 * 3);
}
}
+65
View File
@@ -0,0 +1,65 @@
//! Datasets
//!
//! In this module you will find small datasets that are used in SmartCore for demonstration purpose mostly.
/// Iris flower data set
pub mod iris;
/// Dataset
pub struct Dataset<X, Y> {
/// data in one-dimensional array.
pub data: Vec<X>,
/// target values or class labels.
pub target: Vec<Y>,
/// number of samples (number of rows in matrix form).
pub num_samples: usize,
/// number of features (number of columns in matrix form).
pub num_features: usize,
/// names of dependent variables.
pub feature_names: Vec<String>,
/// names of target variables.
pub target_names: Vec<String>,
/// dataset description
pub description: String,
}
impl<X, Y> Dataset<X, Y> {
/// Reshape data into a two-dimensional matrix
pub fn as_2d_vector(&self) -> Vec<Vec<&X>> {
let mut result: Vec<Vec<&X>> = Vec::with_capacity(self.num_samples);
for r in 0..self.num_samples {
let mut row = Vec::with_capacity(self.num_features);
for c in 0..self.num_features {
row.push(&self.data[r * self.num_features + c]);
}
result.push(row);
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn as_2d_vector() {
let dataset = Dataset {
data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
target: vec![1, 2, 3],
num_samples: 2,
num_features: 5,
feature_names: vec![],
target_names: vec![],
description: "".to_string(),
};
let m = dataset.as_2d_vector();
assert_eq!(m.len(), 2);
assert_eq!(m[0].len(), 5);
assert_eq!(*m[1][3], 9);
}
}
+3
View File
@@ -68,6 +68,9 @@
pub mod algorithm;
/// Algorithms for clustering of unlabeled data
pub mod cluster;
/// Various datasets
#[cfg(feature = "datasets")]
pub mod dataset;
/// Matrix decomposition algorithms
pub mod decomposition;
/// Ensemble methods, including Random Forest classifier and regressor
+3 -1
View File
@@ -143,6 +143,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
fn predict_for_row(&self, x: Vec<T>) -> T {
let search_result = self.knn_algorithm.find(&x, self.k);
println!("{:?}", search_result);
let mut result = T::zero();
let weights = self
@@ -195,9 +196,10 @@ mod tests {
let y_exp = vec![2., 2., 3., 4., 4.];
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
let y_hat = knn.predict(&x);
println!("{:?}", y_hat);
assert_eq!(5, Vec::len(&y_hat));
for i in 0..y_hat.len() {
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
}
}