feat: extends interface of Matrix to support for broad range of types
This commit is contained in:
@@ -1,42 +1,45 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree {
|
||||
nodes: Vec<BBDTreeNode>,
|
||||
pub struct BBDTree<T: FloatExt + Debug> {
|
||||
nodes: Vec<BBDTreeNode<T>>,
|
||||
index: Vec<usize>,
|
||||
root: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BBDTreeNode {
|
||||
struct BBDTreeNode<T: FloatExt + Debug> {
|
||||
count: usize,
|
||||
index: usize,
|
||||
center: Vec<f64>,
|
||||
radius: Vec<f64>,
|
||||
sum: Vec<f64>,
|
||||
cost: f64,
|
||||
center: Vec<T>,
|
||||
radius: Vec<T>,
|
||||
sum: Vec<T>,
|
||||
cost: T,
|
||||
lower: Option<usize>,
|
||||
upper: Option<usize>
|
||||
}
|
||||
|
||||
impl BBDTreeNode {
|
||||
fn new(d: usize) -> BBDTreeNode {
|
||||
impl<T: FloatExt + Debug> BBDTreeNode<T> {
|
||||
fn new(d: usize) -> BBDTreeNode<T> {
|
||||
BBDTreeNode {
|
||||
count: 0,
|
||||
index: 0,
|
||||
center: vec![0f64; d],
|
||||
radius: vec![0f64; d],
|
||||
sum: vec![0f64; d],
|
||||
cost: 0f64,
|
||||
center: vec![T::zero(); d],
|
||||
radius: vec![T::zero(); d],
|
||||
sum: vec![T::zero(); d],
|
||||
cost: T::zero(),
|
||||
lower: Option::None,
|
||||
upper: Option::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BBDTree {
|
||||
pub fn new<M: Matrix>(data: &M) -> BBDTree {
|
||||
impl<T: FloatExt + Debug> BBDTree<T> {
|
||||
pub fn new<M: Matrix<T>>(data: &M) -> BBDTree<T> {
|
||||
let nodes = Vec::new();
|
||||
|
||||
let (n, _) = data.shape();
|
||||
@@ -59,20 +62,20 @@ impl BBDTree {
|
||||
tree
|
||||
}
|
||||
|
||||
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<f64>>, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64 {
|
||||
pub(in crate) fn clustering(&self, centroids: &Vec<Vec<T>>, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T {
|
||||
let k = centroids.len();
|
||||
|
||||
counts.iter_mut().for_each(|x| *x = 0);
|
||||
let mut candidates = vec![0; k];
|
||||
for i in 0..k {
|
||||
candidates[i] = i;
|
||||
sums[i].iter_mut().for_each(|x| *x = 0f64);
|
||||
sums[i].iter_mut().for_each(|x| *x = T::zero());
|
||||
}
|
||||
|
||||
self.filter(self.root, centroids, &candidates, k, sums, counts, membership)
|
||||
}
|
||||
|
||||
fn filter(&self, node: usize, centroids: &Vec<Vec<f64>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<f64>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> f64{
|
||||
fn filter(&self, node: usize, centroids: &Vec<Vec<T>>, candidates: &Vec<usize>, k: usize, sums: &mut Vec<Vec<T>>, counts: &mut Vec<usize>, membership: &mut Vec<usize>) -> T{
|
||||
let d = centroids[0].len();
|
||||
|
||||
// Determine which mean the node mean is closest to
|
||||
@@ -109,7 +112,7 @@ impl BBDTree {
|
||||
|
||||
// Assigns all data within this node to a single mean
|
||||
for i in 0..d {
|
||||
sums[closest][i] += self.nodes[node].sum[i];
|
||||
sums[closest][i] = sums[closest][i] + self.nodes[node].sum[i];
|
||||
}
|
||||
|
||||
counts[closest] += self.nodes[node].count;
|
||||
@@ -123,7 +126,7 @@ impl BBDTree {
|
||||
|
||||
}
|
||||
|
||||
fn prune(center: &Vec<f64>, radius: &Vec<f64>, centroids: &Vec<Vec<f64>>, best_index: usize, test_index: usize) -> bool {
|
||||
fn prune(center: &Vec<T>, radius: &Vec<T>, centroids: &Vec<Vec<T>>, best_index: usize, test_index: usize) -> bool {
|
||||
if best_index == test_index {
|
||||
return false;
|
||||
}
|
||||
@@ -132,22 +135,22 @@ impl BBDTree {
|
||||
|
||||
let best = ¢roids[best_index];
|
||||
let test = ¢roids[test_index];
|
||||
let mut lhs = 0f64;
|
||||
let mut rhs = 0f64;
|
||||
let mut lhs = T::zero();
|
||||
let mut rhs = T::zero();
|
||||
for i in 0..d {
|
||||
let diff = test[i] - best[i];
|
||||
lhs += diff * diff;
|
||||
if diff > 0f64 {
|
||||
rhs += (center[i] + radius[i] - best[i]) * diff;
|
||||
lhs = lhs + diff * diff;
|
||||
if diff > T::zero() {
|
||||
rhs = rhs + (center[i] + radius[i] - best[i]) * diff;
|
||||
} else {
|
||||
rhs += (center[i] - radius[i] - best[i]) * diff;
|
||||
rhs = rhs + (center[i] - radius[i] - best[i]) * diff;
|
||||
}
|
||||
}
|
||||
|
||||
return lhs >= 2f64 * rhs;
|
||||
return lhs >= T::two() * rhs;
|
||||
}
|
||||
|
||||
fn build_node<M: Matrix>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
let (_, d) = data.shape();
|
||||
|
||||
// Allocate the node
|
||||
@@ -158,8 +161,8 @@ impl BBDTree {
|
||||
node.index = begin;
|
||||
|
||||
// Calculate the bounding box
|
||||
let mut lower_bound = vec![0f64; d];
|
||||
let mut upper_bound = vec![0f64; d];
|
||||
let mut lower_bound = vec![T::zero(); d];
|
||||
let mut upper_bound = vec![T::zero(); d];
|
||||
|
||||
for i in 0..d {
|
||||
lower_bound[i] = data.get(self.index[begin],i);
|
||||
@@ -179,11 +182,11 @@ impl BBDTree {
|
||||
}
|
||||
|
||||
// Calculate bounding box stats
|
||||
let mut max_radius = -1.;
|
||||
let mut max_radius = T::from(-1.).unwrap();
|
||||
let mut split_index = 0;
|
||||
for i in 0..d {
|
||||
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2.;
|
||||
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2.;
|
||||
node.center[i] = (lower_bound[i] + upper_bound[i]) / T::two();
|
||||
node.radius[i] = (upper_bound[i] - lower_bound[i]) / T::two();
|
||||
if node.radius[i] > max_radius {
|
||||
max_radius = node.radius[i];
|
||||
split_index = i;
|
||||
@@ -191,7 +194,7 @@ impl BBDTree {
|
||||
}
|
||||
|
||||
// If the max spread is 0, make this a leaf node
|
||||
if max_radius < 1E-10 {
|
||||
if max_radius < T::from(1E-10).unwrap() {
|
||||
node.lower = Option::None;
|
||||
node.upper = Option::None;
|
||||
for i in 0..d {
|
||||
@@ -201,11 +204,11 @@ impl BBDTree {
|
||||
if end > begin + 1 {
|
||||
let len = end - begin;
|
||||
for i in 0..d {
|
||||
node.sum[i] *= len as f64;
|
||||
node.sum[i] = node.sum[i] * T::from(len).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
node.cost = 0f64;
|
||||
node.cost = T::zero();
|
||||
return self.add_node(node);
|
||||
}
|
||||
|
||||
@@ -247,9 +250,9 @@ impl BBDTree {
|
||||
node.sum[i] = self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||
}
|
||||
|
||||
let mut mean = vec![0f64; d];
|
||||
let mut mean = vec![T::zero(); d];
|
||||
for i in 0..d {
|
||||
mean[i] = node.sum[i] / node.count as f64;
|
||||
mean[i] = node.sum[i] / T::from(node.count).unwrap();
|
||||
}
|
||||
|
||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean) + BBDTree::node_cost(&self.nodes[node.upper.unwrap()], &mean);
|
||||
@@ -257,17 +260,17 @@ impl BBDTree {
|
||||
self.add_node(node)
|
||||
}
|
||||
|
||||
fn node_cost(node: &BBDTreeNode, center: &Vec<f64>) -> f64 {
|
||||
fn node_cost(node: &BBDTreeNode<T>, center: &Vec<T>) -> T {
|
||||
let d = center.len();
|
||||
let mut scatter = 0f64;
|
||||
let mut scatter = T::zero();
|
||||
for i in 0..d {
|
||||
let x = (node.sum[i] / node.count as f64) - center[i];
|
||||
scatter += x * x;
|
||||
let x = (node.sum[i] / T::from(node.count).unwrap()) - center[i];
|
||||
scatter = scatter + x * x;
|
||||
}
|
||||
node.cost + node.count as f64 * scatter
|
||||
node.cost + T::from(node.count).unwrap() * scatter
|
||||
}
|
||||
|
||||
fn add_node(&mut self, new_node: BBDTreeNode) -> usize{
|
||||
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize{
|
||||
let idx = self.nodes.len();
|
||||
self.nodes.push(new_node);
|
||||
idx
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
use crate::math;
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::iter::FromIterator;
|
||||
use std::fmt::Debug;
|
||||
use std::cmp::{PartialOrd};
|
||||
use core::hash::{Hash, Hasher};
|
||||
|
||||
pub struct CoverTree<'a, T>
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::algorithm::neighbour::KNNAlgorithm;
|
||||
use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
|
||||
pub struct CoverTree<'a, T, F: FloatExt>
|
||||
where T: Debug
|
||||
{
|
||||
base: f64,
|
||||
base: F,
|
||||
max_level: i8,
|
||||
min_level: i8,
|
||||
distance: &'a dyn Fn(&T, &T) -> f64,
|
||||
distance: &'a dyn Fn(&T, &T) -> F,
|
||||
nodes: Vec<Node<T>>
|
||||
}
|
||||
|
||||
impl<'a, T> CoverTree<'a, T>
|
||||
impl<'a, T, F: FloatExt> CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
{
|
||||
|
||||
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> CoverTree<T> {
|
||||
pub fn new(mut data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> CoverTree<T, F> {
|
||||
let mut tree = CoverTree {
|
||||
base: 2f64,
|
||||
base: F::two(),
|
||||
max_level: 100,
|
||||
min_level: 100,
|
||||
distance: distance,
|
||||
@@ -46,15 +46,15 @@ where T: Debug
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
let mut i = self.max_level;
|
||||
loop {
|
||||
let i_d = self.base.powf(i as f64);
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_by_distance(&q_p_ds);
|
||||
if d_p_q < math::EPSILON {
|
||||
if d_p_q < F::epsilon() {
|
||||
return
|
||||
} else if d_p_q > i_d {
|
||||
break;
|
||||
}
|
||||
if self.min_by_distance(&qi_p_ds) <= self.base.powf(i as f64){
|
||||
if self.min_by_distance(&qi_p_ds) <= self.base.powf(F::from(i).unwrap()){
|
||||
parent = q_p_ds.iter().find(|(_, d)| d <= &i_d).map(|(n, _)| n.index);
|
||||
p_i = i;
|
||||
}
|
||||
@@ -82,7 +82,7 @@ where T: Debug
|
||||
node_id
|
||||
}
|
||||
|
||||
fn split(&self, p_id: NodeId, r: f64, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||
fn split(&self, p_id: NodeId, r: F, s1: &mut Vec<T>, s2: Option<&mut Vec<T>>) -> (Vec<T>, Vec<T>){
|
||||
|
||||
let mut my_near = (Vec::new(), Vec::new());
|
||||
|
||||
@@ -96,7 +96,7 @@ where T: Debug
|
||||
|
||||
}
|
||||
|
||||
fn split_remove_s(&self, p_id: NodeId, r: f64, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){
|
||||
fn split_remove_s(&self, p_id: NodeId, r: F, s: &mut Vec<T>, mut my_near: (Vec<T>, Vec<T>)) -> (Vec<T>, Vec<T>){
|
||||
|
||||
if s.len() > 0 {
|
||||
let p = &self.nodes.get(p_id.index).unwrap().data;
|
||||
@@ -105,7 +105,7 @@ where T: Debug
|
||||
let d = (self.distance)(p, &s[i]);
|
||||
if d <= r {
|
||||
my_near.0.push(s.remove(i));
|
||||
} else if d > r && d <= 2f64 * r{
|
||||
} else if d > r && d <= F::two() * r{
|
||||
my_near.1.push(s.remove(i));
|
||||
} else {
|
||||
i += 1;
|
||||
@@ -122,15 +122,15 @@ where T: Debug
|
||||
self.min_level = std::cmp::min(self.min_level, i);
|
||||
return (p, far);
|
||||
} else {
|
||||
let (my, n) = self.split(p, self.base.powf((i-1) as f64), &mut near, None);
|
||||
let (my, n) = self.split(p, self.base.powf(F::from(i-1).unwrap()), &mut near, None);
|
||||
let (pi, mut near) = self.construct(p, my, n, i-1);
|
||||
while near.len() > 0 {
|
||||
let q_data = near.remove(0);
|
||||
let nn = self.new_node(Some(p), q_data);
|
||||
let (my, n) = self.split(nn, self.base.powf((i-1) as f64), &mut near, Some(&mut far));
|
||||
let (my, n) = self.split(nn, self.base.powf(F::from(i-1).unwrap()), &mut near, Some(&mut far));
|
||||
let (child, mut unused) = self.construct(nn, my, n, i-1);
|
||||
self.add_child(pi, child, i);
|
||||
let new_near_far = self.split(p, self.base.powf(i as f64), &mut unused, None);
|
||||
let new_near_far = self.split(p, self.base.powf(F::from(i).unwrap()), &mut unused, None);
|
||||
near.extend(new_near_far.0);
|
||||
far.extend(new_near_far.1);
|
||||
}
|
||||
@@ -148,9 +148,9 @@ where T: Debug
|
||||
self.nodes.first().unwrap()
|
||||
}
|
||||
|
||||
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, f64)>, i: i8) -> Vec<(&'b Node<T>, f64)> {
|
||||
fn get_children_dist<'b>(&'b self, p: &T, qi_p_ds: &Vec<(&'b Node<T>, F)>, i: i8) -> Vec<(&'b Node<T>, F)> {
|
||||
|
||||
let mut children = Vec::<(&'b Node<T>, f64)>::new();
|
||||
let mut children = Vec::<(&'b Node<T>, F)>::new();
|
||||
|
||||
children.extend(qi_p_ds.iter().cloned());
|
||||
|
||||
@@ -162,7 +162,7 @@ where T: Debug
|
||||
|
||||
}
|
||||
|
||||
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, f64)>, k: usize) -> f64 {
|
||||
fn min_k_by_distance(&self, q_p_ds: &mut Vec<(&Node<T>, F)>, k: usize) -> F {
|
||||
let mut heap = HeapSelect::with_capacity(k);
|
||||
for (_, d) in q_p_ds {
|
||||
heap.add(d);
|
||||
@@ -171,7 +171,7 @@ where T: Debug
|
||||
*heap.get().pop().unwrap()
|
||||
}
|
||||
|
||||
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, f64)>) -> f64 {
|
||||
fn min_by_distance(&self, q_p_ds: &Vec<(&Node<T>, F)>) -> F {
|
||||
q_p_ds.into_iter().min_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap()).unwrap().1
|
||||
}
|
||||
|
||||
@@ -180,7 +180,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
fn check_invariant(&self, invariant: fn(&CoverTree<T, F>, &Vec<&Node<T>>, &Vec<&Node<T>>, i8) -> ()) {
|
||||
let mut current_nodes: Vec<&Node<T>> = Vec::new();
|
||||
current_nodes.push(self.root());
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
@@ -193,7 +193,7 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nesting_invariant(_: &CoverTree<T>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
fn nesting_invariant(_: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, _: i8) {
|
||||
let nodes_set: HashSet<&Node<T>> = HashSet::from_iter(nodes.into_iter().map(|n| *n));
|
||||
let next_nodes_set: HashSet<&Node<T>> = HashSet::from_iter(next_nodes.into_iter().map(|n| *n));
|
||||
for n in nodes_set.iter() {
|
||||
@@ -202,11 +202,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn covering_tree(tree: &CoverTree<T>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
fn covering_tree(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, next_nodes: &Vec<&Node<T>>, i: i8) {
|
||||
let mut p_selected: Vec<&Node<T>> = Vec::new();
|
||||
for p in next_nodes {
|
||||
for q in nodes {
|
||||
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(i as f64) {
|
||||
if (tree.distance)(&p.data, &q.data) <= tree.base.powf(F::from(i).unwrap()) {
|
||||
p_selected.push(*p);
|
||||
}
|
||||
}
|
||||
@@ -216,11 +216,11 @@ where T: Debug
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn separation(tree: &CoverTree<T>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
fn separation(tree: &CoverTree<T, F>, nodes: &Vec<&Node<T>>, _: &Vec<&Node<T>>, i: i8) {
|
||||
for p in nodes {
|
||||
for q in nodes {
|
||||
if p != q {
|
||||
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(i as f64));
|
||||
assert!((tree.distance)(&p.data, &q.data) > tree.base.powf(F::from(i).unwrap()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -228,13 +228,13 @@ where T: Debug
|
||||
|
||||
}
|
||||
|
||||
impl<'a, T> KNNAlgorithm<T> for CoverTree<'a, T>
|
||||
impl<'a, T, F: FloatExt> KNNAlgorithm<T> for CoverTree<'a, T, F>
|
||||
where T: Debug
|
||||
{
|
||||
fn find(&self, p: &T, k: usize) -> Vec<usize>{
|
||||
let mut qi_p_ds = vec!((self.root(), (self.distance)(&p, &self.root().data)));
|
||||
for i in (self.min_level..self.max_level+1).rev() {
|
||||
let i_d = self.base.powf(i as f64);
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
let mut q_p_ds = self.get_children_dist(&p, &qi_p_ds, i);
|
||||
let d_p_q = self.min_k_by_distance(&mut q_p_ds, k);
|
||||
qi_p_ds = q_p_ds.into_iter().filter(|(_, d)| d <= &(d_p_q + i_d)).collect();
|
||||
@@ -286,7 +286,7 @@ mod tests {
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let mut tree = CoverTree::<i32>::new(data, &distance);
|
||||
let mut tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
for d in vec!(10, 11, 12, 13, 14, 15, 16, 17, 18, 19) {
|
||||
tree.insert(d);
|
||||
}
|
||||
@@ -309,7 +309,7 @@ mod tests {
|
||||
let distance = |a: &i32, b: &i32| -> f64 {
|
||||
(a - b).abs() as f64
|
||||
};
|
||||
let tree = CoverTree::<i32>::new(data, &distance);
|
||||
let tree = CoverTree::<i32, f64>::new(data, &distance);
|
||||
tree.check_invariant(CoverTree::nesting_invariant);
|
||||
tree.check_invariant(CoverTree::covering_tree);
|
||||
tree.check_invariant(CoverTree::separation);
|
||||
|
||||
@@ -3,19 +3,19 @@ use crate::algorithm::sort::heap_select::HeapSelect;
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use num_traits::Float;
|
||||
|
||||
pub struct LinearKNNSearch<'a, T> {
|
||||
distance: Box<dyn Fn(&T, &T) -> f64 + 'a>,
|
||||
pub struct LinearKNNSearch<'a, T, F: Float> {
|
||||
distance: Box<dyn Fn(&T, &T) -> F + 'a>,
|
||||
data: Vec<T>
|
||||
}
|
||||
|
||||
impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T>
|
||||
impl<'a, T, F: Float> KNNAlgorithm<T> for LinearKNNSearch<'a, T, F>
|
||||
{
|
||||
fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
|
||||
let mut heap = HeapSelect::<KNNPoint>::with_capacity(k);
|
||||
let mut heap = HeapSelect::<KNNPoint<F>>::with_capacity(k);
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint{
|
||||
@@ -41,8 +41,8 @@ impl<'a, T> KNNAlgorithm<T> for LinearKNNSearch<'a, T>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> LinearKNNSearch<'a, T> {
|
||||
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> f64) -> LinearKNNSearch<T>{
|
||||
impl<'a, T, F: Float> LinearKNNSearch<'a, T, F> {
|
||||
pub fn new(data: Vec<T>, distance: &'a dyn Fn(&T, &T) -> F) -> LinearKNNSearch<T, F>{
|
||||
LinearKNNSearch{
|
||||
data: data,
|
||||
distance: Box::new(distance)
|
||||
@@ -51,24 +51,24 @@ impl<'a, T> LinearKNNSearch<'a, T> {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct KNNPoint {
|
||||
distance: f64,
|
||||
struct KNNPoint<F: Float> {
|
||||
distance: F,
|
||||
index: Option<usize>
|
||||
}
|
||||
|
||||
impl PartialOrd for KNNPoint {
|
||||
impl<F: Float> PartialOrd for KNNPoint<F> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for KNNPoint {
|
||||
impl<F: Float> PartialEq for KNNPoint<F> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for KNNPoint {}
|
||||
impl<F: Float> Eq for KNNPoint<F> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use num_traits::Float;
|
||||
|
||||
pub trait QuickArgSort {
|
||||
fn quick_argsort(&mut self) -> Vec<usize>;
|
||||
}
|
||||
|
||||
impl QuickArgSort for Vec<f64> {
|
||||
impl<T: Float> QuickArgSort for Vec<T> {
|
||||
|
||||
fn quick_argsort(&mut self) -> Vec<usize> {
|
||||
let stack_size = 64;
|
||||
|
||||
Reference in New Issue
Block a user