fix: refactors knn and distance functions

This commit is contained in:
Volodymyr Orlov
2020-02-21 18:54:50 -08:00
parent 0e89113297
commit fe50509d3b
8 changed files with 101 additions and 154 deletions
+4 -16
View File
@@ -3,6 +3,7 @@ extern crate rand;
use rand::Rng;
use crate::linalg::Matrix;
use crate::math::distance::euclidian;
use crate::algorithm::neighbour::bbd_tree::BBDTree;
#[derive(Debug)]
@@ -101,7 +102,7 @@ impl KMeans{
let mut best_cluster = 0;
for j in 0..self.k {
let dist = KMeans::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
let dist = euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
if dist < min_dist {
min_dist = dist;
best_cluster = j;
@@ -127,7 +128,7 @@ impl KMeans{
// the distance from each sample to its closest center in scores.
for i in 0..n {
// compute the distance between this sample and the current center
let dist = KMeans::squared_distance(&data.get_row_as_vec(i), &centroid);
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
if dist < d[i] {
d[i] = dist;
@@ -151,7 +152,7 @@ impl KMeans{
for i in 0..n {
// compute the distance between this sample and the current center
let dist = KMeans::squared_distance(&data.get_row_as_vec(i), &centroid);
let dist = euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
if dist < d[i] {
d[i] = dist;
@@ -161,19 +162,6 @@ impl KMeans{
y
}
fn squared_distance(x: &Vec<f64>,y: &Vec<f64>) -> f64 {
if x.len() != y.len() {
panic!("Input vector sizes are different.");
}
let mut sum = 0f64;
for i in 0..x.len() {
sum += (x[i] - y[i]).powf(2.);
}
return sum;
}
}