fix: refactors knn and distance functions
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree {
|
||||
@@ -77,10 +76,10 @@ impl BBDTree {
|
||||
let d = centroids[0].len();
|
||||
|
||||
// Determine which mean the node mean is closest to
|
||||
let mut min_dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut min_dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[0]]);
|
||||
let mut closest = candidates[0];
|
||||
for i in 1..k {
|
||||
let dist = BBDTree::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
let dist = euclidian::squared_distance(&self.nodes[node].center, ¢roids[candidates[i]]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
closest = candidates[i];
|
||||
@@ -146,20 +145,7 @@ impl BBDTree {
|
||||
}
|
||||
|
||||
return lhs >= 2f64 * rhs;
|
||||
}
|
||||
|
||||
fn squared_distance(x: &Vec<f64>,y: &Vec<f64>) -> f64 {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
|
||||
let mut sum = 0f64;
|
||||
for i in 0..x.len() {
|
||||
sum += (x[i] - y[i]).powf(2.);
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
fn build_node<M: Matrix>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
let (_, d) = data.shape();
|
||||
|
||||
@@ -72,9 +72,8 @@ impl Eq for KNNPoint {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distance;
|
||||
use ndarray::{arr1, Array1};
|
||||
use super::*;
|
||||
use crate::math::distance::euclidian;
|
||||
|
||||
struct SimpleDistance{}
|
||||
|
||||
@@ -92,11 +91,11 @@ mod tests {
|
||||
|
||||
assert_eq!(vec!(1, 2, 0), algorithm1.find(&2, 3));
|
||||
|
||||
let data2 = vec!(arr1(&[1, 1]), arr1(&[2, 2]), arr1(&[3, 3]), arr1(&[4, 4]), arr1(&[5, 5]));
|
||||
let data2 = vec!(vec![1., 1.], vec![2., 2.], vec![3., 3.], vec![4., 4.], vec![5., 5.]);
|
||||
|
||||
let algorithm2 = LinearKNNSearch::new(data2, &Array1::distance);
|
||||
let algorithm2 = LinearKNNSearch::new(data2, &euclidian::distance);
|
||||
|
||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&arr1(&[3, 3]), 3));
|
||||
assert_eq!(vec!(2, 3, 1), algorithm2.find(&vec![3., 3.], 3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user