feat: new distance function parameter in KNN, extends KNN documentation
This commit is contained in:
@@ -78,7 +78,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||
node_id
|
||||
}
|
||||
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
|
||||
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
|
||||
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
|
||||
for i in (self.min_level..self.max_level + 1).rev() {
|
||||
let i_d = self.base.powf(F::from(i).unwrap());
|
||||
@@ -92,7 +92,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
|
||||
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
|
||||
.iter()
|
||||
.map(|(n, _)| n.index.index)
|
||||
.map(|(n, d)| (n.index.index, *d))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -353,12 +353,14 @@ mod tests {
|
||||
}
|
||||
|
||||
let mut nearest_3_to_5 = tree.find(&5, 3);
|
||||
nearest_3_to_5.sort();
|
||||
assert_eq!(vec!(3, 4, 5), nearest_3_to_5);
|
||||
nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
let nearest_3_to_5_indexes: Vec<usize> = nearest_3_to_5.iter().map(|v| v.0).collect();
|
||||
assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes);
|
||||
|
||||
let mut nearest_3_to_15 = tree.find(&15, 3);
|
||||
nearest_3_to_15.sort();
|
||||
assert_eq!(vec!(13, 14, 15), nearest_3_to_15);
|
||||
nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
let nearest_3_to_15_indexes: Vec<usize> = nearest_3_to_15.iter().map(|v| v.0).collect();
|
||||
assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes);
|
||||
|
||||
assert_eq!(-1, tree.min_level);
|
||||
assert_eq!(100, tree.max_level);
|
||||
|
||||
@@ -22,7 +22,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
|
||||
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
}
|
||||
@@ -48,7 +48,10 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
|
||||
heap.sort();
|
||||
|
||||
heap.get().into_iter().flat_map(|x| x.index).collect()
|
||||
heap.get()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +94,9 @@ mod tests {
|
||||
|
||||
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||
let found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), found_idxs1);
|
||||
|
||||
let data2 = vec![
|
||||
vec![1., 1.],
|
||||
@@ -103,7 +108,13 @@ mod tests {
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
|
||||
|
||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
||||
let found_idxs2: Vec<usize> = algorithm2
|
||||
.find(&vec![3., 3.], 3)
|
||||
.iter()
|
||||
.map(|v| v.0)
|
||||
.collect();
|
||||
|
||||
assert_eq!(vec!(2, 3, 1), found_idxs2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user