feat: adds serialization/deserialization methods
This commit is contained in:
@@ -2,7 +2,7 @@ use std::fmt::Debug;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::math::distance::euclidian::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree<T: FloatExt> {
|
||||
@@ -79,10 +79,10 @@ impl<T: FloatExt> BBDTree<T> {
|
||||
let d = centroids[0].len();
|
||||
|
||||
// Determine which mean the node mean is closest to
|
||||
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut min_dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut closest = candidates[0];
|
||||
for i in 1..k {
|
||||
let dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
let dist = Euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
closest = candidates[i];
|
||||
|
||||
@@ -3,25 +3,26 @@ use std::iter::FromIterator;
|
||||
use std::fmt::Debug;
|
||||
use core::hash::{Hash, Hasher};
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
|
||||
pub struct CoverTree<'a, T, F: FloatExt>
|
||||
where T: Debug
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CoverTree<T, F: FloatExt, D: Distance<T, F>>
|
||||
{
|
||||
base: F,
|
||||
max_level: i8,
|
||||
min_level: i8,
|
||||
distance: &'a dyn Fn(&T, &T) -> F,
|
||||
distance: D,
|
||||
nodes: Vec<Node<T>>
|
||||
}
|
||||
|
||||
impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
{
|
||||
|
||||
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
|
||||
pub fn new(mut data: Vec<T>, distance: D) -> CoverTree<T, F, D> {
|
||||
let mut tree = CoverTree {
|
||||
base: F::two(),
|
||||
max_level: 100,
|
||||
@@ -43,7 +44,7 @@ where T: Debug
|
||||
} else {
|
||||
let mut parent: Option<NodeId> = Option::None;
|
||||
let mut p_i = 0;
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||
let mut i = self.max_level;
|
||||
loop {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
@@ -82,6 +83,18 @@ where T: Debug
|
||||
node_id
|
||||
}
|
||||
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), D::distance(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||
}
|
||||
|
||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||
|
||||
let mut my_near = (Vec::new(), Vec::new());
|
||||
@@ -102,7 +115,7 @@ where T: Debug
|
||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||
let mut i = 0;
|
||||
while i != s.len() {
|
||||
let d = (self.distance)(p, &s[i]);
|
||||
let d = D::distance(p, &s[i]);
|
||||
if d <= r {
|
||||
my_near.0.push(s.remove(i));
|
||||
} else if d > r && d <= F::two() * r{
|
||||
@@ -156,7 +169,7 @@ where T: Debug
|
||||
|
||||
let q: Vec<&Node<T>> = qi_p_ds.iter().flat_map(|(n, _)| self.get_child(n, i)).collect();
|
||||
|
||||
children.extend(q.into_iter().map(|n| (n, (self.distance)(&n.data, &p))));
|
||||
children.extend(q.into_iter().map(|n| (n, D::distance(&n.data, &p))));
|
||||
|
||||
children
|
||||
|
||||
@@ -180,7 +193,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F, D>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||
current_nodes.push(self.root());
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
@@ -193,7 +206,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
fn nesting_invariant(_: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||
for n in nodes_set.iter() {
|
||||
@@ -202,11 +215,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
fn covering_tree(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||
for p in next_nodes {
|
||||
for q in nodes {
|
||||
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
if D::distance(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
p_selected.push(*p);
|
||||
}
|
||||
}
|
||||
@@ -216,11 +229,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
fn separation(tree: &CoverTree<T, F, D>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
for p in nodes {
|
||||
for q in nodes {
|
||||
if p != q {
|
||||
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
assert!(D::distance(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,28 +241,12 @@ where T: Debug
|
||||
|
||||
}
|
||||
|
||||
impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
{
|
||||
fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
}
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)].iter().map(|(n, _)| n.index.index).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeId {
|
||||
index: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Node<T> {
|
||||
index: NodeId,
|
||||
data: T,
|
||||
@@ -280,13 +277,19 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cover_tree_test() {
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let mut tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
|
||||
let mut tree = CoverTree::new(data, SimpleDistance{});
|
||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
||||
tree.insert(d);
|
||||
}
|
||||
@@ -306,10 +309,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_invariants(){
|
||||
let data = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9);
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance{});
|
||||
tree.check_invariant(CoverTree::nesting_invariant);
|
||||
tree.check_invariant(CoverTree::covering_tree);
|
||||
tree.check_invariant(CoverTree::separation);
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use num_traits::Float;
|
||||
use std::marker::PhantomData;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
pub struct LinearKNNSearch<'a, T, F: Float> {
|
||||
distance: Box<dyn Fn(&T, &T) -> F + 'a>,
|
||||
data: Vec<T>
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LinearKNNSearch<T, F: FloatExt, D: Distance<T, F>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
f: PhantomData<F>
|
||||
}
|
||||
|
||||
impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
pub fn new(data: Vec<T>, distance: D) -> LinearKNNSearch<T, F, D>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: distance,
|
||||
f: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
@@ -19,14 +31,14 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
distance: F::infinity(),
|
||||
index: None
|
||||
});
|
||||
}
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
|
||||
let d = (self.distance)(&from, &self.data[i]);
|
||||
let d = D::distance(&from, &self.data[i]);
|
||||
let datum = heap.peek_mut();
|
||||
if d < datum.distance {
|
||||
datum.distance = d;
|
||||
@@ -41,43 +53,34 @@ impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
|
||||
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: Box::new(distance)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct KNNPoint<F: Float> {
|
||||
struct KNNPoint<F: FloatExt> {
|
||||
distance: F,
|
||||
index: Option<usize>
|
||||
}
|
||||
|
||||
impl<F: Float> PartialOrd for KNNPoint<F> {
|
||||
impl<F: FloatExt> PartialOrd for KNNPoint<F> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Float> PartialEq for KNNPoint<F> {
|
||||
impl<F: FloatExt> PartialEq for KNNPoint<F> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Float> Eq for KNNPoint<F> {}
|
||||
impl<F: FloatExt> Eq for KNNPoint<F> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
impl SimpleDistance {
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
fn distance(a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
@@ -87,13 +90,13 @@ mod tests {
|
||||
fn knn_find() {
|
||||
let data1 = vec!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
|
||||
|
||||
let algorithm1 = LinearKNNSearch::new(data1, &SimpleDistance::distance);
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance{});
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||
|
||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||
|
||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
||||
}
|
||||
@@ -116,7 +119,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint{
|
||||
distance: Float::infinity(),
|
||||
distance: std::f64::INFINITY,
|
||||
index: Some(3)
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
pub mod cover_tree;
|
||||
pub mod linear_search;
|
||||
pub mod bbd_tree;
|
||||
|
||||
pub enum KNNAlgorithmName {
|
||||
CoverTree,
|
||||
LinearSearch,
|
||||
}
|
||||
|
||||
pub trait KNNAlgorithm<T>{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<usize>;
|
||||
}
|
||||
pub mod bbd_tree;
|
||||
Reference in New Issue
Block a user