Adds CoverTree implementation
This commit is contained in:
@@ -1 +1,2 @@
|
||||
pub mod sort;
|
||||
pub mod neighbour;
|
||||
@@ -0,0 +1,161 @@
|
||||
use crate::math;
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub struct CoverTree<'a, T>
|
||||
where T: Debug
|
||||
{
|
||||
|
||||
base: f64,
|
||||
max_level: i8,
|
||||
min_level: i8,
|
||||
distance: &'a Fn(&T, &T) -> f64,
|
||||
nodes: Vec<Node<T>>
|
||||
}
|
||||
|
||||
impl<'a, T> CoverTree<'a, T>
|
||||
where T: Debug
|
||||
{
|
||||
|
||||
pub fn new(data: Vec<T>, distance: &'a Fn(&T, &T) -> f64) -> CoverTree<T> {
|
||||
let mut tree = CoverTree {
|
||||
base: 2f64,
|
||||
max_level: 10,
|
||||
min_level: 10,
|
||||
distance: distance,
|
||||
nodes: Vec::new()
|
||||
};
|
||||
|
||||
for p in data {
|
||||
tree.insert(p);
|
||||
}
|
||||
|
||||
tree
|
||||
|
||||
}
|
||||
|
||||
pub fn new_node(&mut self, data: T) -> NodeId {
|
||||
let next_index = self.nodes.len();
|
||||
let node_id = NodeId { index: next_index };
|
||||
self.nodes.push(
|
||||
Node {
|
||||
index: node_id,
|
||||
data: data,
|
||||
parent: None,
|
||||
children: HashMap::new()
|
||||
});
|
||||
node_id
|
||||
}
|
||||
|
||||
fn insert(&mut self, p: T) {
|
||||
if self.nodes.is_empty(){
|
||||
self.new_node(p);
|
||||
} else {
|
||||
let mut parent: Option<NodeId> = Option::None;
|
||||
let mut p_i = 0;
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
let mut i = self.max_level;
|
||||
loop {
|
||||
let i_d = self.base.powf(i as f64);
|
||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_Q = self.min_ds(&q_p_ds);
|
||||
if d_p_Q < math::small_e {
|
||||
return
|
||||
} else if d_p_Q > i_d {
|
||||
break;
|
||||
}
|
||||
if self.min_ds(&qi_p_ds) <= self.base.powf(i as f64){
|
||||
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, d)| n.index);
|
||||
p_i = i;
|
||||
}
|
||||
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(n, d)| d <= &i_d).collect();
|
||||
i -= 1;
|
||||
}
|
||||
|
||||
let new_node = self.new_node(p);
|
||||
self.nodes.get_mut(parent.unwrap().index).unwrap().children.insert(p_i, new_node);
|
||||
self.min_level = i8::min(self.min_level, p_i-1);
|
||||
}
|
||||
}
|
||||
|
||||
fn root(&self) -> &Node<T> {
|
||||
self.nodes.first().unwrap()
|
||||
}
|
||||
|
||||
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, f64)>, i: i8) -> Vec<(&'b Node<T>, f64)> {
|
||||
|
||||
let mut children = Vec::<(&'b Node<T>, f64)>::new();
|
||||
|
||||
children.extend(qi_p_ds.iter().cloned());
|
||||
|
||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
||||
|
||||
children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p))));
|
||||
|
||||
children
|
||||
|
||||
}
|
||||
|
||||
fn min_ds(&self, q_p_ds: &Vec<(&Node<T>, f64)>) -> f64 {
|
||||
q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1
|
||||
}
|
||||
|
||||
fn min_p_ds(&self, q_p_ds: &mut Vec<(&Node<T>, f64)>, k: usize) -> f64 {
|
||||
q_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
q_p_ds[..usize::min(q_p_ds.len(), k)].last().unwrap().1
|
||||
}
|
||||
|
||||
fn get_child(&self, node: &Node<T>, i: i8) -> Option<&Node<T>> {
|
||||
node.children.get(&i).and_then(|n_id| self.nodes.get(n_id.index))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl<'a, T> KNNAlgorithm<T> for CoverTree<'a, T>
where T: Debug
{
    /// Returns the indices (insertion order) of the `k` nearest neighbours
    /// of `p`, closest first.
    fn find(&self, p: &T, k: usize) -> Vec<usize>{
        // Candidate set Q_i paired with distances to p, seeded with the root.
        let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
        // Descend one level at a time, pruning candidates that cannot
        // contain one of the k nearest points.
        for i in (self.min_level..self.max_level+1).rev() {
            let i_d = self.base.powf(i as f64); // covering radius at level i
            let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
            // Distance to the k-th closest candidate found so far.
            let d_p_q = self.min_p_ds(&mut q_p_ds, k);
            // Keep a candidate only if its cover could still hold a point
            // closer than the current k-th best.
            qi_p_ds = q_p_ds.into_iter().filter(|(n, d)| d <= &(d_p_q + i_d)).collect();
        }
        // Sort survivors by distance and return the first k indices.
        qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
        qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
    }
}
|
||||
|
||||
/// Stable handle for a node: its index into the tree's `nodes` arena.
#[derive(Debug, Clone, Copy)]
pub struct NodeId {
    index: usize,
}
|
||||
|
||||
/// Arena-allocated cover-tree node.
#[derive(Debug)]
struct Node<T> {
    // This node's own id (its position in the arena vector).
    index: NodeId,
    data: T,
    // Children keyed by level. NOTE(review): only ONE child can be stored
    // per (node, level) pair — HashMap::insert overwrites an existing
    // entry at the same level. Confirm this is intended.
    children: HashMap<i8, NodeId>,
    parent: Option<NodeId>
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn cover_tree_test() {
        let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
        // Absolute difference as a metric over integers.
        let distance = |a: &i32, b: &i32| -> f64 {
            (a - b).abs() as f64
        };
        let tree = CoverTree::<i32>::new(data, &distance);
        let nearest_3 = tree.find(&5, 3);
        // Indices are insertion order: values 5, 6, 4 -> indices 4, 5, 3.
        assert_eq!(vec!(4, 5, 3), nearest_3);
    }

}
|
||||
@@ -0,0 +1,129 @@
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use num_traits::Float;
|
||||
|
||||
/// Brute-force k-NN index: every query scans the whole dataset.
pub struct LinearKNNSearch<'a, T> {
    /// Owned, boxed metric used to compare points.
    distance: Box<dyn Fn(&T, &T) -> f64 + 'a>,
    /// The indexed points; `find` returns indices into this vector.
    data: Vec<T>
}
|
||||
|
||||
impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T>
{
    /// Scans every datum and keeps the `k` closest to `from` in a bounded
    /// heap; returns their indices into `self.data`, closest first.
    ///
    /// # Panics
    /// Panics when `k` is 0 or larger than the dataset.
    fn find(&self, from: &T, k: usize) -> Vec<usize> {
        if k < 1 || k > self.data.len() {
            panic!("k should be >= 1 and <= length(data)");
        }

        // Pre-fill the heap with k sentinels at infinite distance so the
        // current worst candidate is always the one replaced.
        let mut heap = HeapSelect::<KNNPoint>::with_capacity(k);

        for _ in 0..k {
            heap.add(KNNPoint{
                distance: Float::infinity(),
                index: None
            });
        }

        for i in 0..self.data.len() {

            let d = (self.distance)(&from, &self.data[i]);
            let datum = heap.peek_mut();
            // Replace the current worst candidate when a closer point appears.
            if d < datum.distance {
                datum.distance = d;
                datum.index = Some(i);
                heap.heapify();
            }
        }

        heap.sort();

        // Leftover sentinel slots (index == None) are dropped by flat_map.
        heap.get().into_iter().flat_map(|x| x.index).collect()
    }
}
|
||||
|
||||
impl<'a, T> LinearKNNSearch<'a, T> {
|
||||
pub fn new(data: Vec<T>, distance: &'a Fn(&T, &T) -> f64) -> LinearKNNSearch<T>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: Box::new(distance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A candidate neighbour held in the selection heap: distance from the
/// query point plus the index of the datum in the search structure.
#[derive(Debug)]
struct KNNPoint {
    distance: f64,
    // None marks a sentinel slot (the heap is pre-filled with infinities).
    index: Option<usize>
}
|
||||
|
||||
impl PartialOrd for KNNPoint {
    // Candidates are ordered by distance only; the index does not participate.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.distance.partial_cmp(&other.distance)
    }
}
|
||||
|
||||
impl PartialEq for KNNPoint {
    // Two candidates are equal when they are equidistant from the query,
    // regardless of which datum each one points at.
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}
|
||||
|
||||
impl Eq for KNNPoint {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::distance::Distance;
    use ndarray::{arr1, Array1};

    // Trivial metric over integers used by the tests below.
    struct SimpleDistance{}

    impl SimpleDistance {
        fn distance(a: &i32, b: &i32) -> f64 {
            (a - b).abs() as f64
        }
    }

    #[test]
    fn knn_find() {
        // find() returns indices of the k nearest, closest first.
        let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

        let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance);

        assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));

        let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5]));

        let algorithm2 = LinearKNNSearch::new(data2, &Array1::distance);

        assert_eq!(vec!(2, 3, 1), algorithm2.find(&arr1(&[3, 3]), 3));
    }

    #[test]
    fn knn_point_eq() {
        // Ordering and equality of KNNPoint depend only on `distance`.
        let point1 = KNNPoint{
            distance: 10.,
            index: Some(0)
        };

        let point2 = KNNPoint{
            distance: 100.,
            index: Some(1)
        };

        let point3 = KNNPoint{
            distance: 10.,
            index: Some(2)
        };

        let point_inf = KNNPoint{
            distance: Float::infinity(),
            index: Some(3)
        };

        assert!(point2 > point1);
        assert_eq!(point3, point1);
        assert_ne!(point3, point2);
        assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
    }
}
|
||||
@@ -0,0 +1,11 @@
|
||||
pub mod cover_tree;
|
||||
pub mod linear_search;
|
||||
|
||||
/// Selects which k-NN search backend to build.
pub enum KNNAlgorithmName {
    /// Tree-based search (`cover_tree::CoverTree`).
    CoverTree,
    /// Brute-force scan (`linear_search::LinearKNNSearch`).
    LinearSearch,
}
|
||||
|
||||
/// A k-nearest-neighbour index: `find` returns the indices of the `k`
/// points closest to `from`, ordered by increasing distance.
pub trait KNNAlgorithm<T>{
    fn find(&self, from: &T, k: usize) -> Vec<usize>;
}
|
||||
+24
-119
@@ -1,31 +1,33 @@
|
||||
use super::Classifier;
|
||||
use std::collections::HashSet;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::common::Nominal;
|
||||
use ndarray::{ArrayBase, Data, Ix1, Ix2};
|
||||
use num_traits::{Float};
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::fmt::Debug;
|
||||
|
||||
|
||||
type F<X> = Fn(&X, &X) -> f64;
|
||||
|
||||
// NOTE(review): this span is raw diff residue — it contains BOTH the old and
// the new version of the struct header and bounds with no +/- markers. All
// lines are kept verbatim; resolve against the real commit before editing.
pub struct KNNClassifier<X, Y>
pub struct KNNClassifier<'a, X, Y>   // new header: adds the 'a lifetime
where
    Y: Nominal                        // old bound list
    Y: Nominal,                       // new bound list adds X: Debug
    X: Debug
{
    classes: Vec<Y>,
    y: Vec<usize>,
    // NOTE(review): the new constructor below no longer sets `data` or
    // `distance`; presumably these two fields were removed by the diff.
    data: Vec<X>,
    distance: Box<F<X>>,
    knn_algorithm: Box<KNNAlgorithm<X> + 'a>,   // added by the diff
    k: usize,
}
|
||||
|
||||
impl<X, Y> KNNClassifier<X, Y>
|
||||
impl<'a, X, Y> KNNClassifier<'a, X, Y>
|
||||
where
|
||||
Y: Nominal
|
||||
Y: Nominal,
|
||||
X: Debug
|
||||
{
|
||||
|
||||
pub fn fit(x: Vec<X>, y: Vec<Y>, k: usize, distance: &'static F<X>) -> KNNClassifier<X, Y> {
|
||||
pub fn fit(x: Vec<X>, y: Vec<Y>, k: usize, distance: &'a F<X>, algorithm: KNNAlgorithmName) -> KNNClassifier<X, Y> {
|
||||
|
||||
assert!(Vec::len(&x) == Vec::len(&y), format!("Size of x should equal size of y; |x|=[{}], |y|=[{}]", Vec::len(&x), Vec::len(&y)));
|
||||
|
||||
@@ -35,18 +37,25 @@ where
|
||||
let classes: Vec<Y> = c_hash.into_iter().collect();
|
||||
let y_i:Vec<usize> = y.into_iter().map(|y| classes.iter().position(|yy| yy == &y).unwrap()).collect();
|
||||
|
||||
KNNClassifier{classes:classes, y: y_i, data: x, k: k, distance: Box::new(distance)}
|
||||
let knn_algorithm: Box<KNNAlgorithm<X> + 'a> = match algorithm {
|
||||
KNNAlgorithmName::CoverTree => Box::new(CoverTree::<X>::new(x, distance)),
|
||||
KNNAlgorithmName::LinearSearch => Box::new(LinearKNNSearch::<X>::new(x, distance))
|
||||
};
|
||||
|
||||
KNNClassifier{classes:classes, y: y_i, k: k, knn_algorithm: knn_algorithm}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
impl<X, Y> Classifier<X, Y> for KNNClassifier<X, Y>
|
||||
impl<'a, X, Y> Classifier<X, Y> for KNNClassifier<'a, X, Y>
|
||||
where
|
||||
Y: Nominal
|
||||
Y: Nominal,
|
||||
X: Debug
|
||||
{
|
||||
|
||||
fn predict(&self, x: &X) -> Y {
|
||||
let idxs = self.data.find(x, self.k, &self.distance);
|
||||
let idxs = self.knn_algorithm.find(x, self.k);
|
||||
let mut c = vec![0; self.classes.len()];
|
||||
let mut max_c = 0;
|
||||
let mut max_i = 0;
|
||||
@@ -79,123 +88,19 @@ impl NDArrayUtils {
|
||||
}
|
||||
}
|
||||
|
||||
// Old search abstraction: the metric is passed on every call. The diff
// replaces it with `algorithm::neighbour::KNNAlgorithm`, which stores the
// metric inside the index instead.
pub trait KNNAlgorithm<T>{
    /// Indices of the `k` elements closest to `from` under metric `d`.
    fn find(&self, from: &T, k: usize, d: &Fn(&T, &T) -> f64) -> Vec<usize>;
}
|
||||
|
||||
// Old brute-force k-NN implemented directly on Vec<T>; the diff replaces it
// with `linear_search::LinearKNNSearch`.
impl<T> KNNAlgorithm<T> for Vec<T>
{
    /// Scans every element and keeps the `k` closest to `from` in a bounded
    /// heap; returns their indices, closest first.
    ///
    /// # Panics
    /// Panics when `k` is 0 or larger than the vector.
    fn find(&self, from: &T, k: usize, d: &Fn(&T, &T) -> f64) -> Vec<usize> {
        if k < 1 || k > self.len() {
            panic!("k should be >= 1 and <= length(data)");
        }

        // Pre-fill the heap with k sentinels at infinite distance so the
        // current worst candidate is always the one replaced.
        let mut heap = HeapSelect::<KNNPoint>::with_capacity(k);

        for _ in 0..k {
            heap.add(KNNPoint{
                distance: Float::infinity(),
                index: None
            });
        }

        for i in 0..self.len() {

            // Note: this local `d` shadows the metric parameter `d`.
            let d = d(&from, &self[i]);
            let datum = heap.peek_mut();
            // Replace the current worst candidate when a closer point appears.
            if d < datum.distance {
                datum.distance = d;
                datum.index = Some(i);
                heap.heapify();
            }
        }

        heap.sort();

        // Leftover sentinel slots (index == None) are dropped by flat_map.
        heap.get().into_iter().flat_map(|x| x.index).collect()
    }
}
|
||||
|
||||
/// A candidate neighbour held in the selection heap: distance from the
/// query point plus the index of the datum in the scanned vector.
#[derive(Debug)]
struct KNNPoint {
    distance: f64,
    // None marks a sentinel slot (the heap is pre-filled with infinities).
    index: Option<usize>
}
|
||||
|
||||
impl PartialOrd for KNNPoint {
    // Candidates are ordered by distance only; the index does not participate.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.distance.partial_cmp(&other.distance)
    }
}
|
||||
|
||||
impl PartialEq for KNNPoint {
    // Two candidates are equal when they are equidistant from the query,
    // regardless of which datum each one points at.
    fn eq(&self, other: &Self) -> bool {
        self.distance == other.distance
    }
}
|
||||
|
||||
impl Eq for KNNPoint {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::distance::Distance;
    use ndarray::{arr1, arr2, Array1};

    // Trivial metric over integers used by the tests below.
    struct SimpleDistance{}

    impl SimpleDistance {
        fn distance(a: &i32, b: &i32) -> f64 {
            (a - b).abs() as f64
        }
    }

    #[test]
    fn knn_fit_predict() {
        let x = arr2(&[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]);
        let y = arr1(&[2, 2, 2, 3, 3]);
        // NOTE(review): diff residue — both the old and the new `fit` call are
        // present below without +/- markers; only one belongs in the file.
        let knn = KNNClassifier::fit(NDArrayUtils::array2_to_vec(&x), y.to_vec(), 3, &Array1::distance);
        let knn = KNNClassifier::fit(NDArrayUtils::array2_to_vec(&x), y.to_vec(), 3, &Array1::distance, KNNAlgorithmName::LinearSearch);
        let r = knn.predict_vec(&NDArrayUtils::array2_to_vec(&x));
        assert_eq!(5, Vec::len(&r));
        assert_eq!(y.to_vec(), r);
    }

    #[test]
    fn knn_find() {
        // find() returns indices of the k nearest, closest first.
        let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);

        assert_eq!(vec!(1, 2, 0), data1.find(&2, 3, &SimpleDistance::distance));

        let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5]));

        assert_eq!(vec!(2, 3, 1), data2.find(&arr1(&[3, 3]), 3, &Array1::distance));
    }

    #[test]
    fn knn_point_eq() {
        // Ordering and equality of KNNPoint depend only on `distance`.
        let point1 = KNNPoint{
            distance: 10.,
            index: Some(0)
        };

        let point2 = KNNPoint{
            distance: 100.,
            index: Some(1)
        };

        let point3 = KNNPoint{
            distance: 10.,
            index: Some(2)
        };

        let point_inf = KNNPoint{
            distance: Float::infinity(),
            index: Some(3)
        };

        assert!(point2 > point1);
        assert_eq!(point3, point1);
        assert_ne!(point3, point2);
        assert!(point_inf > point3 && point_inf > point2 && point_inf > point1);
    }
}
|
||||
@@ -1 +1,3 @@
|
||||
pub mod distance;
|
||||
|
||||
pub static small_e:f64 = 0.000000001f64;
|
||||
Reference in New Issue
Block a user