feat: adds serialization/deserialization methods

This commit is contained in:
Volodymyr Orlov
2020-04-03 11:12:15 -07:00
parent 5766364311
commit eb0c36223f
16 changed files with 555 additions and 159 deletions
+3 -3
View File
@@ -2,7 +2,7 @@ use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
use crate::math::distance::euclidian;
use crate::math::distance::euclidian::*;
#[derive(Debug)]
pub struct BBDTree<T: FloatExt> {
@@ -79,10 +79,10 @@ impl<T: FloatExt> BBDTree<T> {
let d = centroids[0].len();
// Determine which mean the node mean is closest to
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
let mut closest = candidates[0];
for i in 1..k {
let dist = euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
let dist = Euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
if dist < min_dist {
min_dist = dist;
closest = candidates[i];
+43 -42
View File
@@ -3,25 +3,26 @@ use std::iter::FromIterator;
use std::fmt::Debug;
use core::hash::{Hash, Hasher};
use serde::{Serialize, Deserialize};
use crate::math::num::FloatExt;
use crate::algorithm::neighbour::KNNAlgorithm;
use crate::math::distance::Distance;
use crate::algorithm::sort::heap_select::HeapSelect;
pub struct CoverTree<'a, T, F: FloatExt>
where T: Debug
#[derive(Serialize, Deserialize, Debug)]
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
{
base: F,
max_level: i8,
min_level: i8,
distance: &'a dyn Fn(&T, &T) -> F,
distance: D,
nodes: Vec<Node<T>>
}
impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
where T: Debug
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
{
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
let mut tree = CoverTree {
base: F::two(),
max_level: 100,
@@ -43,7 +44,7 @@ where T: Debug
} else {
let mut parent: Option<NodeId> = Option::None;
let mut p_i = 0;
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
let mut i = self.max_level;
loop {
let i_d = self.base.powf(F::from(i).unwrap());
@@ -82,6 +83,18 @@ where T: Debug
node_id
}
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
for i in (self.min_level..self.max_level+1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
}
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
}
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
let mut my_near = (Vec::new(), Vec::new());
@@ -102,7 +115,7 @@ where T: Debug
let p = &self.nodes.get(p_id.index).unwrap().data;
let mut i = 0;
while i != s.len() {
let d = (self.distance)(p, &s[i]);
let d = D::distance(p, &s[i]);
if d <= r {
my_near.0.push(s.remove(i));
} else if d > r && d <= F::two() * r{
@@ -156,7 +169,7 @@ where T: Debug
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p))));
children.extend(q.into_iter().map(|n| (n, D::distance(&n.data, &p))));
children
@@ -180,7 +193,7 @@ where T: Debug
}
#[allow(dead_code)]
fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
let mut current_nodes: Vec<&Node<T>> = Vec::new();
current_nodes.push(self.root());
for i in (self.min_level..self.max_level+1).rev() {
@@ -193,7 +206,7 @@ where T: Debug
}
#[allow(dead_code)]
fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
for n in nodes_set.iter() {
@@ -202,11 +215,11 @@ where T: Debug
}
#[allow(dead_code)]
fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
let mut p_selected: Vec<&Node<T>> = Vec::new();
for p in next_nodes {
for q in nodes {
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
if D::distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
p_selected.push(*p);
}
}
@@ -216,11 +229,11 @@ where T: Debug
}
#[allow(dead_code)]
fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
for p in nodes {
for q in nodes {
if p != q {
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
assert!(D::distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
}
}
}
@@ -228,28 +241,12 @@ where T: Debug
}
impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
where T: Debug
{
fn find(&self, p: &T, k: usize) -> Vec<usize>{
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
for i in (self.min_level..self.max_level+1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
}
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct NodeId {
index: usize,
}
#[derive(Debug)]
#[derive(Debug, Serialize, Deserialize)]
struct Node<T> {
index: NodeId,
data: T,
@@ -280,13 +277,19 @@ mod tests {
use super::*;
struct SimpleDistance{}
impl Distance<i32, f64> for SimpleDistance {
fn distance(a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
}
#[test]
fn cover_tree_test() {
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
let distance = |a: &i32, b: &i32| -> f64 {
(a - b).abs() as f64
};
let mut tree = CoverTree::<i32, f64>::new(data, &distance);
let mut tree = CoverTree::new(data, SimpleDistance{});
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
tree.insert(d);
}
@@ -306,10 +309,8 @@ mod tests {
#[test]
fn test_invariants(){
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
let distance = |a: &i32, b: &i32| -> f64 {
(a - b).abs() as f64
};
let tree = CoverTree::<i32, f64>::new(data, &distance);
let tree = CoverTree::new(data, SimpleDistance{});
tree.check_invariant(CoverTree::nesting_invariant);
tree.check_invariant(CoverTree::covering_tree);
tree.check_invariant(CoverTree::separation);
+32 -29
View File
@@ -1,16 +1,28 @@
use crate::algorithm::neighbour::KNNAlgorithm;
use crate::algorithm::sort::heap_select::HeapSelect;
use std::cmp::{Ordering, PartialOrd};
use num_traits::Float;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
pub struct LinearKNNSearch<'a, T, F: Float> {
distance: Box<dyn Fn(&T, &T) -> F + 'a>,
data: Vec<T>
use crate::math::num::FloatExt;
use crate::math::distance::Distance;
use crate::algorithm::sort::heap_select::HeapSelect;
#[derive(Serialize, Deserialize, Debug)]
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
distance: D,
data: Vec<T>,
f: PhantomData<F>
}
impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
{
fn find(&self, from: &T, k: usize) -> Vec<usize> {
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{
LinearKNNSearch{
data: data,
distance: distance,
f: PhantomData
}
}
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)");
}
@@ -19,14 +31,14 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
for _ in 0..k {
heap.add(KNNPoint{
distance: Float::infinity(),
distance: F::infinity(),
index: None
});
}
for i in 0..self.data.len() {
let d = (self.distance)(&from, &self.data[i]);
let d = D::distance(&from, &self.data[i]);
let datum = heap.peek_mut();
if d < datum.distance {
datum.distance = d;
@@ -41,43 +53,34 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
}
}
impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
LinearKNNSearch{
data: data,
distance: Box::new(distance)
}
}
}
#[derive(Debug)]
struct KNNPoint<F: Float> {
struct KNNPoint<F: FloatExt> {
distance: F,
index: Option<usize>
}
impl<F: Float> PartialOrd for KNNPoint<F> {
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
impl<F: Float> PartialEq for KNNPoint<F> {
impl<F: FloatExt> PartialEq for KNNPoint<F> {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl<F: Float> Eq for KNNPoint<F> {}
impl<F: FloatExt> Eq for KNNPoint<F> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::math::distance::euclidian;
use crate::math::distance::Distances;
struct SimpleDistance{}
impl SimpleDistance {
impl Distance<i32, f64> for SimpleDistance {
fn distance(a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
@@ -87,13 +90,13 @@ mod tests {
fn knn_find() {
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance);
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{});
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
}
@@ -116,7 +119,7 @@ mod tests {
};
let point_inf = KNNPoint{
distance: Float::infinity(),
distance: std::f64::INFINITY,
index: Some(3)
};
+1 -10
View File
@@ -1,12 +1,3 @@
pub mod cover_tree;
pub mod linear_search;
pub mod bbd_tree;
pub enum KNNAlgorithmName {
CoverTree,
LinearSearch,
}
pub trait KNNAlgorithm<T>{
fn find(&self, from: &T, k: usize) -> Vec<usize>;
}
pub mod bbd_tree;