feat: + cluster metrics
This commit is contained in:
+16
-8
@@ -189,15 +189,18 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let (n, m) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
let mut row = vec![T::zero(); m];
|
||||
|
||||
for i in 0..n {
|
||||
let mut min_dist = T::max_value();
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||
x.copy_row_as_vec(i, &mut row);
|
||||
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
best_cluster = j;
|
||||
@@ -211,19 +214,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let (n, _) = data.shape();
|
||||
let (n, m) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
||||
|
||||
let mut d = vec![T::max_value(); n];
|
||||
|
||||
let mut row = vec![T::zero(); m];
|
||||
|
||||
// pick the next center
|
||||
for j in 1..k {
|
||||
// Loop over the samples and compare them to the most recent center. Store
|
||||
// the distance from each sample to its closest center in scores.
|
||||
for i in 0..n {
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
|
||||
data.copy_row_as_vec(i, &mut row);
|
||||
let dist = Euclidian::squared_distance(&row, &centroid);
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
@@ -237,20 +243,22 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
}
|
||||
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
|
||||
let mut cost = T::zero();
|
||||
let index = 0;
|
||||
for index in 0..n {
|
||||
let mut index = 0;
|
||||
while index < n {
|
||||
cost = cost + d[index];
|
||||
if cost >= cutoff {
|
||||
break;
|
||||
}
|
||||
index += 1;
|
||||
}
|
||||
|
||||
centroid = data.get_row_as_vec(index);
|
||||
data.copy_row_as_vec(index, &mut centroid);
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
data.copy_row_as_vec(i, &mut row);
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), &centroid);
|
||||
let dist = Euclidian::squared_distance(&row, &centroid);
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
|
||||
Reference in New Issue
Block a user