feat: new distance function parameter in KNN, extends KNN documentation

This commit is contained in:
Volodymyr Orlov
2020-08-28 15:30:52 -07:00
parent dcf636a5f1
commit 367ea62608
6 changed files with 172 additions and 33 deletions
+8 -6
View File
@@ -78,7 +78,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
node_id
}
pub fn find(&self, p: &T, k: usize) -> Vec<usize> {
pub fn find(&self, p: &T, k: usize) -> Vec<(usize, F)> {
let mut qi_p_ds = vec![(self.root(), self.distance.distance(&p, &self.root().data))];
for i in (self.min_level..self.max_level + 1).rev() {
let i_d = self.base.powf(F::from(i).unwrap());
@@ -92,7 +92,7 @@ impl<T: Debug, F: FloatExt, D: Distance<T, F>> CoverTree<T, F, D> {
qi_p_ds.sort_by(|(_, d1), (_, d2)| d1.partial_cmp(d2).unwrap());
qi_p_ds[..usize::min(qi_p_ds.len(), k)]
.iter()
.map(|(n, _)| n.index.index)
.map(|(n, d)| (n.index.index, *d))
.collect()
}
@@ -353,12 +353,14 @@ mod tests {
}
let mut nearest_3_to_5 = tree.find(&5, 3);
nearest_3_to_5.sort();
assert_eq!(vec!(3, 4, 5), nearest_3_to_5);
nearest_3_to_5.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let nearest_3_to_5_indexes: Vec<usize> = nearest_3_to_5.iter().map(|v| v.0).collect();
assert_eq!(vec!(4, 5, 3), nearest_3_to_5_indexes);
let mut nearest_3_to_15 = tree.find(&15, 3);
nearest_3_to_15.sort();
assert_eq!(vec!(13, 14, 15), nearest_3_to_15);
nearest_3_to_15.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
let nearest_3_to_15_indexes: Vec<usize> = nearest_3_to_15.iter().map(|v| v.0).collect();
assert_eq!(vec!(14, 13, 15), nearest_3_to_15_indexes);
assert_eq!(-1, tree.min_level);
assert_eq!(100, tree.max_level);
+15 -4
View File
@@ -22,7 +22,7 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
}
}
pub fn find(&self, from: &T, k: usize) -> Vec<usize> {
pub fn find(&self, from: &T, k: usize) -> Vec<(usize, F)> {
if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)");
}
@@ -48,7 +48,10 @@ impl<T, F: FloatExt, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
heap.sort();
heap.get().into_iter().flat_map(|x| x.index).collect()
heap.get()
.into_iter()
.flat_map(|x| x.index.map(|i| (i, x.distance)))
.collect()
}
}
@@ -91,7 +94,9 @@ mod tests {
let algorithm1 = LinearKNNSearch::new(data1, SimpleDistance {});
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
let found_idxs1: Vec<usize> = algorithm1.find(&2, 3).iter().map(|v| v.0).collect();
assert_eq!(vec!(1, 2, 0), found_idxs1);
let data2 = vec![
vec![1., 1.],
@@ -103,7 +108,13 @@ mod tests {
let algorithm2 = LinearKNNSearch::new(data2, Distances::euclidian());
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
let found_idxs2: Vec<usize> = algorithm2
.find(&vec![3., 3.], 3)
.iter()
.map(|v| v.0)
.collect();
assert_eq!(vec!(2, 3, 1), found_idxs2);
}
#[test]