feat: adds dataset module, fixes problem in CoverTree implementation
This commit is contained in:
+2
-1
@@ -5,9 +5,10 @@ authors = ["SmartCore Developers"]
|
|||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = ["datasets"]
|
||||||
ndarray-bindings = ["ndarray"]
|
ndarray-bindings = ["ndarray"]
|
||||||
nalgebra-bindings = ["nalgebra"]
|
nalgebra-bindings = ["nalgebra"]
|
||||||
|
datasets = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = { version = "0.13", optional = true }
|
ndarray = { version = "0.13", optional = true }
|
||||||
|
|||||||
@@ -21,14 +21,11 @@
|
|||||||
//! tree.find(&5, 3); // find 3 knn points from 5
|
//! tree.find(&5, 3); // find 3 knn points from 5
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
use core::hash::{Hash, Hasher};
|
|
||||||
use std::collections::{HashMap, HashSet};
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::FromIterator;
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -36,329 +33,358 @@ use crate::math::num::RealNumber;
|
|||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||||
base: F,
|
base: F,
|
||||||
max_level: i8,
|
inv_log_base: F,
|
||||||
min_level: i8,
|
|
||||||
distance: D,
|
distance: D,
|
||||||
nodes: Vec<Node<T>>,
|
root: Node<F>,
|
||||||
|
data: Vec<T>,
|
||||||
|
identical_excluded: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Debug, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
|
impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.data.len() != other.data.len() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for i in 0..self.data.len() {
|
||||||
|
if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct Node<F: RealNumber> {
|
||||||
|
idx: usize,
|
||||||
|
max_dist: F,
|
||||||
|
parent_dist: F,
|
||||||
|
children: Vec<Node<F>>,
|
||||||
|
scale: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct DistanceSet<F: RealNumber> {
|
||||||
|
idx: usize,
|
||||||
|
dist: Vec<F>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||||
/// Construct a cover tree.
|
/// Construct a cover tree.
|
||||||
/// * `data` - vector of data points to search for.
|
/// * `data` - vector of data points to search for.
|
||||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||||
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
pub fn new(data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||||
|
let base = F::from_f64(1.3).unwrap();
|
||||||
|
let root = Node {
|
||||||
|
idx: 0,
|
||||||
|
max_dist: F::zero(),
|
||||||
|
parent_dist: F::zero(),
|
||||||
|
children: Vec::new(),
|
||||||
|
scale: 0,
|
||||||
|
};
|
||||||
let mut tree = CoverTree {
|
let mut tree = CoverTree {
|
||||||
base: F::two(),
|
base: base,
|
||||||
max_level: 100,
|
inv_log_base: F::one() / base.ln(),
|
||||||
min_level: 100,
|
|
||||||
distance: distance,
|
distance: distance,
|
||||||
nodes: Vec::new(),
|
root: root,
|
||||||
|
data: data,
|
||||||
|
identical_excluded: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let p = tree.new_node(None, data.remove(0));
|
tree.build_cover_tree();
|
||||||
tree.construct(p, data, Vec::new(), 10);
|
|
||||||
|
|
||||||
tree
|
tree
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert new data point into the cover tree.
|
|
||||||
/// * `p` - new data points.
|
|
||||||
pub fn insert(&mut self, p: T) {
|
|
||||||
if self.nodes.is_empty() {
|
|
||||||
self.new_node(None, p);
|
|
||||||
} else {
|
|
||||||
let mut parent: Option<NodeId> = Option::None;
|
|
||||||
let mut p_i = 0;
|
|
||||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
|
||||||
let mut i = self.max_level;
|
|
||||||
loop {
|
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
|
||||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
|
||||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
|
||||||
if d_p_q < F::epsilon() {
|
|
||||||
return;
|
|
||||||
} else if d_p_q > i_d {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()) {
|
|
||||||
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
|
|
||||||
p_i = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &i_d).collect();
|
|
||||||
i -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_node = self.new_node(parent, p);
|
|
||||||
self.add_child(parent.unwrap(), new_node, p_i);
|
|
||||||
self.min_level = i8::min(self.min_level, p_i - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_node(&mut self, parent: Option<NodeId>, data: T) -> NodeId {
|
|
||||||
let next_index = self.nodes.len();
|
|
||||||
let node_id = NodeId { index: next_index };
|
|
||||||
self.nodes.push(Node {
|
|
||||||
index: node_id,
|
|
||||||
data: data,
|
|
||||||
parent: parent,
|
|
||||||
children: HashMap::new(),
|
|
||||||
});
|
|
||||||
node_id
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find k nearest neighbors of `p`
|
/// Find k nearest neighbors of `p`
|
||||||
/// * `p` - look for k nearest points to `p`
|
/// * `p` - look for k nearest points to `p`
|
||||||
/// * `k` - the number of nearest neighbors to return
|
/// * `k` - the number of nearest neighbors to return
|
||||||
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
||||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
if k <= 0 {
|
||||||
for i in (self.min_level..self.max_level + 1).rev() {
|
panic!("k should be > 0");
|
||||||
let i_d = self.base.powf(F::from(i).unwrap());
|
}
|
||||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
|
||||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
if k > self.data.len() {
|
||||||
qi_p_ds = q_p_ds
|
panic!("k is > than the dataset size");
|
||||||
.into_iter()
|
}
|
||||||
.filter(|(_, d)| d <= &(d_p_q + i_d))
|
|
||||||
.collect();
|
let e = self.get_data_value(self.root.idx);
|
||||||
|
let mut d = self.distance.distance(&e, p);
|
||||||
|
|
||||||
|
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
|
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
|
|
||||||
|
current_cover_set.push((d, &self.root));
|
||||||
|
|
||||||
|
let mut heap = HeapSelection::with_capacity(k);
|
||||||
|
heap.add(F::max_value());
|
||||||
|
|
||||||
|
let mut empty_heap = true;
|
||||||
|
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
|
||||||
|
heap.add(d);
|
||||||
|
empty_heap = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
while !current_cover_set.is_empty() {
|
||||||
|
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
|
for par in current_cover_set {
|
||||||
|
let parent = par.1;
|
||||||
|
for c in 0..parent.children.len() {
|
||||||
|
let child = &parent.children[c];
|
||||||
|
if c == 0 {
|
||||||
|
d = par.0;
|
||||||
|
} else {
|
||||||
|
d = self.distance.distance(self.get_data_value(child.idx), p);
|
||||||
|
}
|
||||||
|
|
||||||
|
let upper_bound = if empty_heap {
|
||||||
|
F::infinity()
|
||||||
|
} else {
|
||||||
|
*heap.peek()
|
||||||
|
};
|
||||||
|
if d <= (upper_bound + child.max_dist) {
|
||||||
|
if c > 0 && d < upper_bound {
|
||||||
|
if !self.identical_excluded || self.get_data_value(child.idx) != p {
|
||||||
|
heap.add(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !child.children.is_empty() {
|
||||||
|
next_cover_set.push((d, child));
|
||||||
|
} else if d <= upper_bound {
|
||||||
|
zero_set.push((d, child));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
current_cover_set = next_cover_set;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut neighbors: Vec<(usize, F)> = Vec::new();
|
||||||
|
let upper_bound = *heap.peek();
|
||||||
|
for ds in zero_set {
|
||||||
|
if ds.0 <= upper_bound {
|
||||||
|
let v = self.get_data_value(ds.1.idx);
|
||||||
|
if !self.identical_excluded || v != p {
|
||||||
|
neighbors.push((ds.1.idx, ds.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
neighbors.into_iter().take(k).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||||
|
Node {
|
||||||
|
idx: idx,
|
||||||
|
max_dist: F::zero(),
|
||||||
|
parent_dist: F::zero(),
|
||||||
|
children: Vec::new(),
|
||||||
|
scale: 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_cover_tree(&mut self) {
|
||||||
|
let mut point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
let mut consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
|
||||||
|
let point = &self.data[0];
|
||||||
|
let idx = 0;
|
||||||
|
let mut max_dist = -F::one();
|
||||||
|
|
||||||
|
for i in 1..self.data.len() {
|
||||||
|
let dist = self.distance.distance(point, &self.data[i]);
|
||||||
|
let set = DistanceSet {
|
||||||
|
idx: i,
|
||||||
|
dist: vec![dist],
|
||||||
|
};
|
||||||
|
point_set.push(set);
|
||||||
|
if dist > max_dist {
|
||||||
|
max_dist = dist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.root = self.batch_insert(
|
||||||
|
idx,
|
||||||
|
self.get_scale(max_dist),
|
||||||
|
self.get_scale(max_dist),
|
||||||
|
&mut point_set,
|
||||||
|
&mut consumed_set,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn batch_insert(
|
||||||
|
&self,
|
||||||
|
p: usize,
|
||||||
|
max_scale: i64,
|
||||||
|
top_scale: i64,
|
||||||
|
point_set: &mut Vec<DistanceSet<F>>,
|
||||||
|
consumed_set: &mut Vec<DistanceSet<F>>,
|
||||||
|
) -> Node<F> {
|
||||||
|
if point_set.is_empty() {
|
||||||
|
self.new_leaf(p)
|
||||||
|
} else {
|
||||||
|
let max_dist = self.max(&point_set);
|
||||||
|
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
||||||
|
if next_scale == std::i64::MIN {
|
||||||
|
let mut children: Vec<Node<F>> = Vec::new();
|
||||||
|
let mut leaf = self.new_leaf(p);
|
||||||
|
children.push(leaf);
|
||||||
|
while !point_set.is_empty() {
|
||||||
|
let set = point_set.remove(point_set.len() - 1);
|
||||||
|
leaf = self.new_leaf(set.idx);
|
||||||
|
children.push(leaf);
|
||||||
|
consumed_set.push(set);
|
||||||
|
}
|
||||||
|
Node {
|
||||||
|
idx: p,
|
||||||
|
max_dist: F::zero(),
|
||||||
|
parent_dist: F::zero(),
|
||||||
|
children: children,
|
||||||
|
scale: 100,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut far: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
self.split(point_set, &mut far, max_scale);
|
||||||
|
|
||||||
|
let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set);
|
||||||
|
|
||||||
|
if point_set.is_empty() {
|
||||||
|
point_set.append(&mut far);
|
||||||
|
child
|
||||||
|
} else {
|
||||||
|
let mut children: Vec<Node<F>> = Vec::new();
|
||||||
|
children.push(child);
|
||||||
|
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
|
||||||
|
while !point_set.is_empty() {
|
||||||
|
let set: DistanceSet<F> = point_set.remove(point_set.len() - 1);
|
||||||
|
|
||||||
|
let new_dist: F = set.dist[set.dist.len() - 1];
|
||||||
|
|
||||||
|
self.dist_split(
|
||||||
|
point_set,
|
||||||
|
&mut new_point_set,
|
||||||
|
self.get_data_value(set.idx),
|
||||||
|
max_scale,
|
||||||
|
);
|
||||||
|
self.dist_split(
|
||||||
|
&mut far,
|
||||||
|
&mut new_point_set,
|
||||||
|
self.get_data_value(set.idx),
|
||||||
|
max_scale,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut new_child = self.batch_insert(
|
||||||
|
set.idx,
|
||||||
|
next_scale,
|
||||||
|
top_scale,
|
||||||
|
&mut new_point_set,
|
||||||
|
&mut new_consumed_set,
|
||||||
|
);
|
||||||
|
new_child.parent_dist = new_dist;
|
||||||
|
|
||||||
|
consumed_set.push(set);
|
||||||
|
children.push(new_child);
|
||||||
|
|
||||||
|
let fmax = self.get_cover_radius(max_scale);
|
||||||
|
for mut set in new_point_set.drain(0..) {
|
||||||
|
set.dist.remove(set.dist.len() - 1);
|
||||||
|
if set.dist[set.dist.len() - 1] <= fmax {
|
||||||
|
point_set.push(set);
|
||||||
|
} else {
|
||||||
|
far.push(set);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for mut set in new_consumed_set.drain(0..) {
|
||||||
|
set.dist.remove(set.dist.len() - 1);
|
||||||
|
consumed_set.push(set);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
point_set.append(&mut far);
|
||||||
|
|
||||||
|
Node {
|
||||||
|
idx: p,
|
||||||
|
max_dist: self.max(consumed_set),
|
||||||
|
parent_dist: F::zero(),
|
||||||
|
children: children,
|
||||||
|
scale: (top_scale - max_scale),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
|
||||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
|
||||||
.iter()
|
|
||||||
.map(|(n, d)| (n.index.index, *d))
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split(
|
fn split(
|
||||||
&self,
|
&self,
|
||||||
p_id: NodeId,
|
point_set: &mut Vec<DistanceSet<F>>,
|
||||||
r: F,
|
far_set: &mut Vec<DistanceSet<F>>,
|
||||||
s1: &mut Vec<T>,
|
max_scale: i64,
|
||||||
s2: Option<&mut Vec<T>>,
|
) {
|
||||||
) -> (Vec<T>, Vec<T>) {
|
let fmax = self.get_cover_radius(max_scale);
|
||||||
let mut my_near = (Vec::new(), Vec::new());
|
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
for n in point_set.drain(0..) {
|
||||||
my_near = self.split_remove_s(p_id, r, s1, my_near);
|
if n.dist[n.dist.len() - 1] <= fmax {
|
||||||
|
new_set.push(n);
|
||||||
for s in s2 {
|
} else {
|
||||||
my_near = self.split_remove_s(p_id, r, s, my_near);
|
far_set.push(n);
|
||||||
}
|
|
||||||
|
|
||||||
return my_near;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn split_remove_s(
|
|
||||||
&self,
|
|
||||||
p_id: NodeId,
|
|
||||||
r: F,
|
|
||||||
s: &mut Vec<T>,
|
|
||||||
mut my_near: (Vec<T>, Vec<T>),
|
|
||||||
) -> (Vec<T>, Vec<T>) {
|
|
||||||
if s.len() > 0 {
|
|
||||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
|
||||||
let mut i = 0;
|
|
||||||
while i != s.len() {
|
|
||||||
let d = self.distance.distance(p, &s[i]);
|
|
||||||
if d <= r {
|
|
||||||
my_near.0.push(s.remove(i));
|
|
||||||
} else if d > r && d <= F::two() * r {
|
|
||||||
my_near.1.push(s.remove(i));
|
|
||||||
} else {
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return my_near;
|
point_set.append(&mut new_set);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn construct<'b>(
|
fn dist_split(
|
||||||
&mut self,
|
&self,
|
||||||
p: NodeId,
|
point_set: &mut Vec<DistanceSet<F>>,
|
||||||
mut near: Vec<T>,
|
new_point_set: &mut Vec<DistanceSet<F>>,
|
||||||
mut far: Vec<T>,
|
new_point: &T,
|
||||||
i: i8,
|
max_scale: i64,
|
||||||
) -> (NodeId, Vec<T>) {
|
) {
|
||||||
if near.len() < 1 {
|
let fmax = self.get_cover_radius(max_scale);
|
||||||
self.min_level = std::cmp::min(self.min_level, i);
|
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
return (p, far);
|
for mut n in point_set.drain(0..) {
|
||||||
|
let new_dist = self
|
||||||
|
.distance
|
||||||
|
.distance(new_point, self.get_data_value(n.idx));
|
||||||
|
if new_dist <= fmax {
|
||||||
|
n.dist.push(new_dist);
|
||||||
|
new_point_set.push(n);
|
||||||
|
} else {
|
||||||
|
new_set.push(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
point_set.append(&mut new_set);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_cover_radius(&self, s: i64) -> F {
|
||||||
|
self.base.powf(F::from_i64(s).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_data_value(&self, idx: usize) -> &T {
|
||||||
|
&self.data[idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_scale(&self, d: F) -> i64 {
|
||||||
|
if d == F::zero() {
|
||||||
|
std::i64::MIN
|
||||||
} else {
|
} else {
|
||||||
let (my, n) = self.split(p, self.base.powf(F::from(i - 1).unwrap()), &mut near, None);
|
(self.inv_log_base * d.ln()).ceil().to_i64().unwrap()
|
||||||
let (pi, mut near) = self.construct(p, my, n, i - 1);
|
|
||||||
while near.len() > 0 {
|
|
||||||
let q_data = near.remove(0);
|
|
||||||
let nn = self.new_node(Some(p), q_data);
|
|
||||||
let (my, n) = self.split(
|
|
||||||
nn,
|
|
||||||
self.base.powf(F::from(i - 1).unwrap()),
|
|
||||||
&mut near,
|
|
||||||
Some(&mut far),
|
|
||||||
);
|
|
||||||
let (child, mut unused) = self.construct(nn, my, n, i - 1);
|
|
||||||
self.add_child(pi, child, i);
|
|
||||||
let new_near_far =
|
|
||||||
self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
|
||||||
near.extend(new_near_far.0);
|
|
||||||
far.extend(new_near_far.1);
|
|
||||||
}
|
|
||||||
self.min_level = std::cmp::min(self.min_level, i);
|
|
||||||
return (pi, far);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_child(&mut self, parent: NodeId, node: NodeId, i: i8) {
|
fn max(&self, distance_set: &Vec<DistanceSet<F>>) -> F {
|
||||||
self.nodes
|
let mut max = F::zero();
|
||||||
.get_mut(parent.index)
|
for n in distance_set {
|
||||||
.unwrap()
|
if max < n.dist[n.dist.len() - 1] {
|
||||||
.children
|
max = n.dist[n.dist.len() - 1];
|
||||||
.insert(i, node);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn root(&self) -> &Node<T> {
|
|
||||||
self.nodes.first().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_children_dist<'b>(
|
|
||||||
&'b self,
|
|
||||||
p: &T,
|
|
||||||
qi_p_ds: &Vec<(&'b Node<T>, F)>,
|
|
||||||
i: i8,
|
|
||||||
) -> Vec<(&'b Node<T>, F)> {
|
|
||||||
let mut children = Vec::<(&'b Node<T>, F)>::new();
|
|
||||||
|
|
||||||
children.extend(qi_p_ds.iter().cloned());
|
|
||||||
|
|
||||||
let q: Vec<&Node<T>> = qi_p_ds
|
|
||||||
.iter()
|
|
||||||
.flat_map(|(n, _)| self.get_child(n, i))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
children.extend(
|
|
||||||
q.into_iter()
|
|
||||||
.map(|n| (n, self.distance.distance(&n.data, &p))),
|
|
||||||
);
|
|
||||||
|
|
||||||
children
|
|
||||||
}
|
|
||||||
|
|
||||||
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
|
|
||||||
let mut heap = HeapSelect::with_capacity(k);
|
|
||||||
for (_, d) in q_p_ds {
|
|
||||||
heap.add(d);
|
|
||||||
}
|
|
||||||
heap.sort();
|
|
||||||
*heap.get().pop().unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
|
|
||||||
q_p_ds
|
|
||||||
.into_iter()
|
|
||||||
.min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap())
|
|
||||||
.unwrap()
|
|
||||||
.1
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
|
|
||||||
node.children
|
|
||||||
.get(&i)
|
|
||||||
.and_then(|n_id| self.nodes.get(n_id.index))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn check_invariant(
|
|
||||||
&self,
|
|
||||||
invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> (),
|
|
||||||
) {
|
|
||||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
|
||||||
current_nodes.push(self.root());
|
|
||||||
for i in (self.min_level..self.max_level + 1).rev() {
|
|
||||||
let mut next_nodes: Vec<&Node<T>> = Vec::new();
|
|
||||||
next_nodes.extend(current_nodes.iter());
|
|
||||||
next_nodes.extend(current_nodes.iter().flat_map(|n| self.get_child(n, i)));
|
|
||||||
invariant(self, &current_nodes, &next_nodes, i);
|
|
||||||
current_nodes = next_nodes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn nesting_invariant(
|
|
||||||
_: &CoverTree<T, F, D>,
|
|
||||||
nodes: &Vec<&Node<T>>,
|
|
||||||
next_nodes: &Vec<&Node<T>>,
|
|
||||||
_: i8,
|
|
||||||
) {
|
|
||||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
|
||||||
let next_nodes_set: HashSet<&Node<T>> =
|
|
||||||
HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
|
||||||
for n in nodes_set.iter() {
|
|
||||||
assert!(next_nodes_set.contains(n), "Nesting invariant of the cover tree is not satisfied. Set of nodes [{:?}] is not a subset of [{:?}]", nodes_set, next_nodes_set);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn covering_tree(
|
|
||||||
tree: &CoverTree<T, F, D>,
|
|
||||||
nodes: &Vec<&Node<T>>,
|
|
||||||
next_nodes: &Vec<&Node<T>>,
|
|
||||||
i: i8,
|
|
||||||
) {
|
|
||||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
|
||||||
for p in next_nodes {
|
|
||||||
for q in nodes {
|
|
||||||
if tree.distance.distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
|
||||||
p_selected.push(*p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let c = p_selected
|
|
||||||
.iter()
|
|
||||||
.filter(|q| p.parent.map(|p| q.index == p).unwrap_or(false))
|
|
||||||
.count();
|
|
||||||
assert!(c <= 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
|
||||||
for p in nodes {
|
|
||||||
for q in nodes {
|
|
||||||
if p != q {
|
|
||||||
assert!(
|
|
||||||
tree.distance.distance(&p.data, &q.data)
|
|
||||||
> tree.base.powf(F::from(i).unwrap())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return max;
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
|
||||||
struct NodeId {
|
|
||||||
index: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct Node<T> {
|
|
||||||
index: NodeId,
|
|
||||||
data: T,
|
|
||||||
children: HashMap<i8, NodeId>,
|
|
||||||
parent: Option<NodeId>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> PartialEq for Node<T> {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
self.index.index == other.index.index
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> Eq for Node<T> {}
|
|
||||||
|
|
||||||
impl<T> Hash for Node<T> {
|
|
||||||
fn hash<H>(&self, state: &mut H)
|
|
||||||
where
|
|
||||||
H: Hasher,
|
|
||||||
{
|
|
||||||
state.write_usize(self.index.index);
|
|
||||||
state.finish();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,7 +392,9 @@ impl<T> Hash for Node<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct SimpleDistance {}
|
struct SimpleDistance {}
|
||||||
|
|
||||||
impl Distance<i32, f64> for SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
@@ -379,32 +407,42 @@ mod tests {
|
|||||||
fn cover_tree_test() {
|
fn cover_tree_test() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let mut tree = CoverTree::new(data, SimpleDistance {});
|
let tree = CoverTree::new(data, SimpleDistance {});
|
||||||
for d in vec![10, 11, 12, 13, 14, 15, 16, 17, 18, 19] {
|
|
||||||
tree.insert(d);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut nearest_3_to_5 = tree.find(&5, 3);
|
let mut knn = tree.find(&5, 3);
|
||||||
nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||||
let nearest_3_to_5_indexes: Vec<usize> = nearest_3_to_5.iter().map(|v| v.0).collect();
|
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||||
assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes);
|
assert_eq!(vec!(3, 4, 5), knn);
|
||||||
|
|
||||||
let mut nearest_3_to_15 = tree.find(&15, 3);
|
|
||||||
nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
|
||||||
let nearest_3_to_15_indexes: Vec<usize> = nearest_3_to_15.iter().map(|v| v.0).collect();
|
|
||||||
assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes);
|
|
||||||
|
|
||||||
assert_eq!(-1, tree.min_level);
|
|
||||||
assert_eq!(100, tree.max_level);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_invariants() {
|
fn cover_tree_test1() {
|
||||||
|
let data = vec![
|
||||||
|
vec![1., 2.],
|
||||||
|
vec![3., 4.],
|
||||||
|
vec![5., 6.],
|
||||||
|
vec![7., 8.],
|
||||||
|
vec![9., 10.],
|
||||||
|
];
|
||||||
|
|
||||||
|
let tree = CoverTree::new(data, Distances::euclidian());
|
||||||
|
|
||||||
|
let mut knn = tree.find(&vec![1., 2.], 3);
|
||||||
|
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||||
|
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||||
|
|
||||||
|
assert_eq!(vec!(0, 1, 2), knn);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
let tree = CoverTree::new(data, SimpleDistance {});
|
let tree = CoverTree::new(data, SimpleDistance {});
|
||||||
tree.check_invariant(CoverTree::nesting_invariant);
|
|
||||||
tree.check_invariant(CoverTree::covering_tree);
|
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
|
||||||
tree.check_invariant(CoverTree::separation);
|
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tree, deserialized_tree);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use std::cmp::{Ordering, PartialOrd};
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
panic!("k should be >= 1 and <= length(data)");
|
panic!("k should be >= 1 and <= length(data)");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
|
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
|
||||||
|
|
||||||
for _ in 0..k {
|
for _ in 0..k {
|
||||||
heap.add(KNNPoint {
|
heap.add(KNNPoint {
|
||||||
@@ -76,8 +76,6 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
heap.sort();
|
|
||||||
|
|
||||||
heap.get()
|
heap.get()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||||
@@ -124,9 +122,10 @@ mod tests {
|
|||||||
|
|
||||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||||
|
|
||||||
let found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
let mut found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
||||||
|
found_idxs1.sort();
|
||||||
|
|
||||||
assert_eq!(vec!(1, 2, 0), found_idxs1);
|
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
||||||
|
|
||||||
let data2 = vec![
|
let data2 = vec![
|
||||||
vec![1., 1.],
|
vec![1., 1.],
|
||||||
@@ -138,13 +137,14 @@ mod tests {
|
|||||||
|
|
||||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||||
|
|
||||||
let found_idxs2: Vec<usize> = algorithm2
|
let mut found_idxs2: Vec<usize> = algorithm2
|
||||||
.find(&vec![3., 3.], 3)
|
.find(&vec![3., 3.], 3)
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| v.0)
|
.map(|v| v.0)
|
||||||
.collect();
|
.collect();
|
||||||
|
found_idxs2.sort();
|
||||||
|
|
||||||
assert_eq!(vec!(2, 3, 1), found_idxs2);
|
assert_eq!(vec!(1, 2, 3), found_idxs2);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
//! * ["The Art of Computer Programming" Knuth, D, Vol. 3, 2nd ed, Sorting and Searching, 1998](https://www-cs-faculty.stanford.edu/~knuth/taocp.html)
|
//! * ["The Art of Computer Programming" Knuth, D, Vol. 3, 2nd ed, Sorting and Searching, 1998](https://www-cs-faculty.stanford.edu/~knuth/taocp.html)
|
||||||
//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://homes.cs.washington.edu/~sham/papers/ml/cover_tree.pdf)
|
//! * ["Cover Trees for Nearest Neighbor" Beygelzimer et al., Proceedings of the 23rd international conference on Machine learning, ICML'06 (2006)](https://hunch.net/~jl/projects/cover_tree/cover_tree.html)
|
||||||
//! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15)
|
//! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15)
|
||||||
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
|
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
|
||||||
//!
|
//!
|
||||||
|
|||||||
@@ -1,20 +1,24 @@
|
|||||||
|
//! # Heap Selection Algorithm
|
||||||
|
//!
|
||||||
|
//! The goal is to find the k smallest elements in a list or array.
|
||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct HeapSelect<T: PartialOrd> {
|
pub struct HeapSelection<T: PartialOrd + Debug> {
|
||||||
k: usize,
|
k: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
sorted: bool,
|
sorted: bool,
|
||||||
heap: Vec<T>,
|
heap: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: PartialOrd> HeapSelect<T> {
|
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||||
pub fn with_capacity(k: usize) -> HeapSelect<T> {
|
pub fn with_capacity(k: usize) -> HeapSelection<T> {
|
||||||
HeapSelect {
|
HeapSelection {
|
||||||
k: k,
|
k: k,
|
||||||
n: 0,
|
n: 0,
|
||||||
sorted: false,
|
sorted: false,
|
||||||
heap: Vec::<T>::new(),
|
heap: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,12 +28,13 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
|||||||
self.heap.push(element);
|
self.heap.push(element);
|
||||||
self.n += 1;
|
self.n += 1;
|
||||||
if self.n == self.k {
|
if self.n == self.k {
|
||||||
self.heapify();
|
self.sort();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.n += 1;
|
self.n += 1;
|
||||||
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
|
if element.partial_cmp(&self.heap[0]) == Some(Ordering::Less) {
|
||||||
self.heap[0] = element;
|
self.heap[0] = element;
|
||||||
|
self.sift_down(0, self.k - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -41,58 +46,46 @@ impl<'a, T: PartialOrd> HeapSelect<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn peek(&self) -> &T {
|
pub fn peek(&self) -> &T {
|
||||||
return &self.heap[0];
|
if self.sorted {
|
||||||
|
return &self.heap[0];
|
||||||
|
} else {
|
||||||
|
&self
|
||||||
|
.heap
|
||||||
|
.iter()
|
||||||
|
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn peek_mut(&mut self) -> &mut T {
|
pub fn peek_mut(&mut self) -> &mut T {
|
||||||
return &mut self.heap[0];
|
return &mut self.heap[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sift_down(&mut self, from: usize, n: usize) {
|
|
||||||
let mut k = from;
|
|
||||||
while 2 * k <= n {
|
|
||||||
let mut j = 2 * k;
|
|
||||||
if j < n && self.heap[j] < self.heap[j + 1] {
|
|
||||||
j += 1;
|
|
||||||
}
|
|
||||||
if self.heap[k] >= self.heap[j] {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
self.heap.swap(k, j);
|
|
||||||
k = j;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(self) -> Vec<T> {
|
pub fn get(self) -> Vec<T> {
|
||||||
return self.heap;
|
return self.heap;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sort(&mut self) {
|
fn sift_down(&mut self, k: usize, n: usize) {
|
||||||
HeapSelect::shuffle_sort(&mut self.heap, std::cmp::min(self.k, self.n));
|
let mut kk = k;
|
||||||
|
while 2 * kk <= n {
|
||||||
|
let mut j = 2 * kk;
|
||||||
|
if j < n && self.heap[j].partial_cmp(&self.heap[j + 1]) == Some(Ordering::Less) {
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
if self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Equal)
|
||||||
|
|| self.heap[kk].partial_cmp(&self.heap[j]) == Some(Ordering::Greater)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self.heap.swap(kk, j);
|
||||||
|
kk = j;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shuffle_sort(vec: &mut Vec<T>, n: usize) {
|
fn sort(&mut self) {
|
||||||
let mut inc = 1;
|
self.sorted = true;
|
||||||
while inc <= n {
|
self.heap.sort_by(|a, b| b.partial_cmp(a).unwrap());
|
||||||
inc *= 3;
|
|
||||||
inc += 1
|
|
||||||
}
|
|
||||||
|
|
||||||
let len = n;
|
|
||||||
while inc >= 1 {
|
|
||||||
let mut i = inc;
|
|
||||||
while i < len {
|
|
||||||
let mut j = i;
|
|
||||||
while j >= inc && vec[j - inc] > vec[j] {
|
|
||||||
vec.swap(j - inc, j);
|
|
||||||
j -= inc;
|
|
||||||
}
|
|
||||||
i += 1;
|
|
||||||
}
|
|
||||||
inc /= 3
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,50 +95,62 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn with_capacity() {
|
fn with_capacity() {
|
||||||
let heap = HeapSelect::<i32>::with_capacity(3);
|
let heap = HeapSelection::<i32>::with_capacity(3);
|
||||||
assert_eq!(3, heap.k);
|
assert_eq!(3, heap.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add() {
|
fn test_add() {
|
||||||
let mut heap = HeapSelect::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
|
heap.add(-5);
|
||||||
|
assert_eq!(-5, *heap.peek());
|
||||||
heap.add(333);
|
heap.add(333);
|
||||||
heap.add(2);
|
assert_eq!(333, *heap.peek());
|
||||||
heap.add(13);
|
heap.add(13);
|
||||||
heap.add(10);
|
heap.add(10);
|
||||||
|
heap.add(2);
|
||||||
|
heap.add(0);
|
||||||
heap.add(40);
|
heap.add(40);
|
||||||
heap.add(30);
|
heap.add(30);
|
||||||
assert_eq!(6, heap.n);
|
assert_eq!(8, heap.n);
|
||||||
assert_eq!(&10, heap.peek());
|
assert_eq!(vec![2, 0, -5], heap.get());
|
||||||
assert_eq!(&10, heap.peek_mut());
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_add1() {
|
||||||
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
|
heap.add(std::f64::INFINITY);
|
||||||
|
heap.add(-5f64);
|
||||||
|
heap.add(4f64);
|
||||||
|
heap.add(-1f64);
|
||||||
|
heap.add(2f64);
|
||||||
|
heap.add(1f64);
|
||||||
|
heap.add(0f64);
|
||||||
|
assert_eq!(7, heap.n);
|
||||||
|
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_add2() {
|
||||||
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
|
heap.add(std::f64::INFINITY);
|
||||||
|
heap.add(0.0);
|
||||||
|
heap.add(8.4852);
|
||||||
|
heap.add(5.6568);
|
||||||
|
heap.add(2.8284);
|
||||||
|
assert_eq!(5, heap.n);
|
||||||
|
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add_ordered() {
|
fn test_add_ordered() {
|
||||||
let mut heap = HeapSelect::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
heap.add(1.);
|
heap.add(1.);
|
||||||
heap.add(2.);
|
heap.add(2.);
|
||||||
heap.add(3.);
|
heap.add(3.);
|
||||||
heap.add(4.);
|
heap.add(4.);
|
||||||
heap.add(5.);
|
heap.add(5.);
|
||||||
heap.add(6.);
|
heap.add(6.);
|
||||||
let result = heap.get();
|
assert_eq!(vec![3., 2., 1.], heap.get());
|
||||||
assert_eq!(vec![2., 3., 1.], result);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_shuffle_sort() {
|
|
||||||
let mut v1 = vec![10, 33, 22, 105, 12];
|
|
||||||
let n = v1.len();
|
|
||||||
HeapSelect::shuffle_sort(&mut v1, n);
|
|
||||||
assert_eq!(vec![10, 12, 22, 33, 105], v1);
|
|
||||||
|
|
||||||
let mut v2 = vec![10, 33, 22, 105, 12];
|
|
||||||
HeapSelect::shuffle_sort(&mut v2, 3);
|
|
||||||
assert_eq!(vec![10, 22, 33, 105, 12], v2);
|
|
||||||
|
|
||||||
let mut v3 = vec![4, 5, 3, 2, 1];
|
|
||||||
HeapSelect::shuffle_sort(&mut v3, 3);
|
|
||||||
assert_eq!(vec![3, 4, 5, 2, 1], v3);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,90 @@
|
|||||||
|
//! # The Iris flower dataset
|
||||||
|
//!
|
||||||
|
//! [Fisher's Iris dataset](https://archive.ics.uci.edu/ml/datasets/iris) is a multivariate dataset that was published in 1936 by Ronald Fisher.
|
||||||
|
//! This multivariate dataset is frequently used to demonstrate various machine learning algorithms.
|
||||||
|
use crate::dataset::Dataset;
|
||||||
|
|
||||||
|
/// Get dataset
|
||||||
|
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||||
|
let x = vec![
|
||||||
|
5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6,
|
||||||
|
1.4, 0.2, 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2,
|
||||||
|
4.9, 3.1, 1.5, 0.1, 5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0,
|
||||||
|
1.1, 0.1, 5.8, 4.0, 1.2, 0.2, 5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3,
|
||||||
|
5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3, 5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6,
|
||||||
|
1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2, 5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4,
|
||||||
|
5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2, 4.8, 3.1, 1.6, 0.2, 5.4, 3.4,
|
||||||
|
1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1, 5.0, 3.2, 1.2, 0.2,
|
||||||
|
5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2, 5.0, 3.5,
|
||||||
|
1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
|
||||||
|
4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3,
|
||||||
|
1.4, 0.2, 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3,
|
||||||
|
6.5, 2.8, 4.6, 1.5, 5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9,
|
||||||
|
4.6, 1.3, 5.2, 2.7, 3.9, 1.4, 5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0,
|
||||||
|
6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3, 6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7,
|
||||||
|
4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1, 5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3,
|
||||||
|
6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3, 6.6, 3.0, 4.4, 1.4, 6.8, 2.8,
|
||||||
|
4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0, 5.5, 2.4, 3.8, 1.1,
|
||||||
|
5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5, 6.0, 3.4,
|
||||||
|
4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
|
||||||
|
5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7,
|
||||||
|
4.2, 1.3, 5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1,
|
||||||
|
5.7, 2.8, 4.1, 1.3, 6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9,
|
||||||
|
5.6, 1.8, 6.5, 3.0, 5.8, 2.2, 7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8,
|
||||||
|
6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5, 6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0,
|
||||||
|
5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4, 6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8,
|
||||||
|
7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5, 6.9, 3.2, 5.7, 2.3, 5.6, 2.8,
|
||||||
|
4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1, 7.2, 3.2, 6.0, 1.8,
|
||||||
|
6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6, 7.4, 2.8,
|
||||||
|
6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
|
||||||
|
7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1,
|
||||||
|
5.4, 2.1, 6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3,
|
||||||
|
6.7, 3.3, 5.7, 2.5, 6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4,
|
||||||
|
5.4, 2.3, 5.9, 3.0, 5.1, 1.8,
|
||||||
|
];
|
||||||
|
|
||||||
|
let setosa = std::iter::repeat(0f32).take(50);
|
||||||
|
let versicolor = std::iter::repeat(1f32).take(50);
|
||||||
|
let virginica = std::iter::repeat(2f32).take(50);
|
||||||
|
let y = setosa
|
||||||
|
.chain(versicolor)
|
||||||
|
.chain(virginica)
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
|
let shape = (150, 4);
|
||||||
|
|
||||||
|
Dataset {
|
||||||
|
data: x,
|
||||||
|
target: y,
|
||||||
|
num_samples: shape.0,
|
||||||
|
num_features: shape.1,
|
||||||
|
feature_names: vec![
|
||||||
|
"sepal length (cm)",
|
||||||
|
"sepal width (cm)",
|
||||||
|
"petal length (cm)",
|
||||||
|
"petal width (cm)",
|
||||||
|
]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect(),
|
||||||
|
target_names: vec!["setosa", "versicolor", "virginica"]
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.collect(),
|
||||||
|
description: "Iris dataset: https://archive.ics.uci.edu/ml/datasets/iris".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The loader must report 150 samples of 4 features, with buffers sized to match.
    #[test]
    fn iris_dataset() {
        let iris = load_dataset();

        // Declared shape.
        assert_eq!(iris.num_samples, 50 * 3);
        assert_eq!(iris.num_features, 4);

        // Buffer lengths implied by that shape.
        assert_eq!(iris.target.len(), 50 * 3);
        assert_eq!(iris.data.len(), 50 * 3 * 4);
    }
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
//! Datasets
|
||||||
|
//!
|
||||||
|
//! This module contains small datasets that are used in SmartCore mostly for demonstration purposes.
|
||||||
|
|
||||||
|
/// Iris flower data set
|
||||||
|
pub mod iris;
|
||||||
|
|
||||||
|
/// An in-memory dataset: a flat buffer of feature values, the matching targets,
/// and descriptive metadata.
#[derive(Debug, Clone)]
pub struct Dataset<X, Y> {
    /// Feature values in one flat vector, one sample after another
    /// (`num_features` consecutive values per sample).
    pub data: Vec<X>,
    /// Target values or class labels.
    pub target: Vec<Y>,
    /// Number of samples (number of rows in matrix form).
    pub num_samples: usize,
    /// Number of features (number of columns in matrix form).
    pub num_features: usize,
    /// Names of the features (the columns of the matrix form).
    pub feature_names: Vec<String>,
    /// Names of the target values.
    pub target_names: Vec<String>,
    /// Dataset description.
    pub description: String,
}

impl<X, Y> Dataset<X, Y> {
    /// Reshapes `data` into a two-dimensional matrix of references.
    ///
    /// The result has `num_samples` rows of `num_features` entries each; row `r`
    /// borrows `data[r * num_features..(r + 1) * num_features]`. Values stored
    /// after the first `num_samples * num_features` entries are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than `num_samples * num_features` values.
    pub fn as_2d_vector(&self) -> Vec<Vec<&X>> {
        let width = self.num_features;
        let required = self.num_samples * width;

        // State the precondition up front so a malformed dataset fails with a
        // descriptive message instead of a bare index-out-of-bounds panic.
        assert!(
            self.data.len() >= required,
            "Dataset::as_2d_vector: data holds {} values but {} samples x {} features need {}",
            self.data.len(),
            self.num_samples,
            width,
            required
        );

        (0..self.num_samples)
            .map(|row| {
                let start = row * width;
                self.data[start..start + width].iter().collect()
            })
            .collect()
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2 x 5 dataset must reshape into two rows of five references, in order.
    #[test]
    fn as_2d_vector() {
        let dataset = Dataset {
            data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            // One target per sample, so the fixture agrees with `num_samples`.
            target: vec![1, 2],
            num_samples: 2,
            num_features: 5,
            feature_names: vec![],
            target_names: vec![],
            description: "".to_string(),
        };

        let m = dataset.as_2d_vector();

        assert_eq!(m.len(), 2);
        assert_eq!(m[0].len(), 5);
        assert_eq!(*m[1][3], 9);
        // Check every cell, not just one, so a wrong stride cannot slip through.
        assert_eq!(
            m,
            vec![vec![&1, &2, &3, &4, &5], vec![&6, &7, &8, &9, &10]]
        );
    }
}
|
||||||
@@ -68,6 +68,9 @@
|
|||||||
pub mod algorithm;
|
pub mod algorithm;
|
||||||
/// Algorithms for clustering of unlabeled data
|
/// Algorithms for clustering of unlabeled data
|
||||||
pub mod cluster;
|
pub mod cluster;
|
||||||
|
/// Various datasets
|
||||||
|
#[cfg(feature = "datasets")]
|
||||||
|
pub mod dataset;
|
||||||
/// Matrix decomposition algorithms
|
/// Matrix decomposition algorithms
|
||||||
pub mod decomposition;
|
pub mod decomposition;
|
||||||
/// Ensemble methods, including Random Forest classifier and regressor
|
/// Ensemble methods, including Random Forest classifier and regressor
|
||||||
|
|||||||
@@ -143,6 +143,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
|||||||
|
|
||||||
fn predict_for_row(&self, x: Vec<T>) -> T {
|
fn predict_for_row(&self, x: Vec<T>) -> T {
|
||||||
let search_result = self.knn_algorithm.find(&x, self.k);
|
let search_result = self.knn_algorithm.find(&x, self.k);
|
||||||
|
println!("{:?}", search_result);
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
|
|
||||||
let weights = self
|
let weights = self
|
||||||
@@ -195,9 +196,10 @@ mod tests {
|
|||||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||||
let y_hat = knn.predict(&x);
|
let y_hat = knn.predict(&x);
|
||||||
|
println!("{:?}", y_hat);
|
||||||
assert_eq!(5, Vec::len(&y_hat));
|
assert_eq!(5, Vec::len(&y_hat));
|
||||||
for i in 0..y_hat.len() {
|
for i in 0..y_hat.len() {
|
||||||
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
assert!((y_hat[i] - y_exp[i]).abs() < 1e-7);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user