fix: refactors knn and distance functions

This commit is contained in:
Volodymyr Orlov
2020-02-21 18:54:50 -08:00
parent 0e89113297
commit fe50509d3b
8 changed files with 101 additions and 154 deletions
+4 -18
View File
@@ -1,6 +1,5 @@
use std::collections::LinkedList;
use crate::linalg::Matrix;
use crate::math::distance::euclidian;
#[derive(Debug)]
pub struct BBDTree {
@@ -77,10 +76,10 @@ impl BBDTree {
let d = centroids[0].len();
// Determine which mean the node mean is closest to
let mut min_dist = BBDTree::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[0]]);
let mut closest = candidates[0];
for i in 1..k {
let dist = BBDTree::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
let dist = euclidian::squared_distance(&self.nodes[node].center, &centroids[candidates[i]]);
if dist < min_dist {
min_dist = dist;
closest = candidates[i];
@@ -146,20 +145,7 @@ impl BBDTree {
}
return lhs >= 2f64 * rhs;
}
fn squared_distance(x: &Vec<f64>,y: &Vec<f64>) -> f64 {
if x.len() != y.len() {
panic!("Input vector sizes are different.");
}
let mut sum = 0f64;
for i in 0..x.len() {
sum += (x[i] - y[i]).powf(2.);
}
return sum;
}
}
fn build_node<M: Matrix>(&mut self, data: &M, begin: usize, end: usize) -> usize {
let (_, d) = data.shape();
+5 -6
View File
@@ -72,9 +72,8 @@ impl Eq for KNNPoint {}
#[cfg(test)]
mod tests {
use super::*;
use crate::math::distance::Distance;
use ndarray::{arr1, Array1};
use super::*;
use crate::math::distance::euclidian;
struct SimpleDistance{}
@@ -92,11 +91,11 @@ mod tests {
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5]));
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
let algorithm2 = LinearKNNSearch::new(data2, &Array1::distance);
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
assert_eq!(vec!(2, 3, 1), algorithm2.find(&arr1(&[3, 3]), 3));
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
}
#[test]