feat: + cluster metrics

This commit is contained in:
Volodymyr Orlov
2020-09-22 20:23:51 -07:00
parent 0803532e79
commit 750015b861
15 changed files with 477 additions and 16 deletions
+16 -8
View File
@@ -189,15 +189,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform, of size _KxM_, where _K_ is the number of new samples and _M_ is the number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
let (n, m) = x.shape();
let mut result = M::zeros(1, n);
let mut row = vec![T::zero(); m];
for i in 0..n {
let mut min_dist = T::max_value();
let mut best_cluster = 0;
for j in 0..self.k {
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
x.copy_row_as_vec(i, &mut row);
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
if dist < min_dist {
min_dist = dist;
best_cluster = j;
@@ -211,19 +214,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
let mut rng = rand::thread_rng();
let (n, _) = data.shape();
let (n, m) = data.shape();
let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
let mut d = vec![T::max_value(); n];
let mut row = vec![T::zero(); m];
// pick the next center
for j in 1..k {
// Loop over the samples and compare them to the most recent center. Store
// the distance from each sample to its closest center in `d`.
for i in 0..n {
// compute the distance between this sample and the current center
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
data.copy_row_as_vec(i, &mut row);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
d[i] = dist;
@@ -237,20 +243,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
}
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
let mut cost = T::zero();
let index = 0;
for index in 0..n {
let mut index = 0;
while index < n {
cost = cost + d[index];
if cost >= cutoff {
break;
}
index += 1;
}
centroid = data.get_row_as_vec(index);
data.copy_row_as_vec(index, &mut centroid);
}
for i in 0..n {
data.copy_row_as_vec(i, &mut row);
// compute the distance between this sample and the current center
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
d[i] = dist;