From 8765bd21738770da3bd64815afced0d1ad23ab8d Mon Sep 17 00:00:00 2001
From: "Lorenzo (Mec-iS)"
Date: Tue, 21 Mar 2023 17:37:58 +0900
Subject: [PATCH] Add fit_with_centroids

---
Reviewer notes on this revision (this section is ignored by `git am`):
- `k == 0`, the number of centroids and the width of every centroid are
  validated before the BBD-tree is built.
- A cluster that receives no sample keeps its centroid instead of turning
  into NaN coordinates (division by a zero count).
- The test builds each starting centroid as the element-wise midpoint of a
  pair (zip) instead of nesting the two row loops, which produced M*M values.
- The post-image blob id on the `index` line predates these changes.

 src/cluster/kmeans.rs | 207 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 206 insertions(+), 1 deletion(-)

diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs
index c2470ab..5f1c91b 100644
--- a/src/cluster/kmeans.rs
+++ b/src/cluster/kmeans.rs
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
 use crate::algorithm::neighbour::bbd_tree::BBDTree;
 use crate::api::{Predictor, UnsupervisedEstimator};
 use crate::error::Failed;
-use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::linalg::basic::arrays::{Array, Array1, Array2};
 use crate::metrics::distance::euclidian::*;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
@@ -322,6 +322,127 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
         })
     }
 
+    /// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
+    /// Clustering starts from the supplied `centroids`.
+    /// * `data` - training instances to cluster
+    /// * `parameters` - cluster parameters
+    /// * `centroids` - starting centroids, `parameters.k` vectors with one value per feature
+    ///
+    /// # Errors
+    /// Returns `Failed::fit` when `parameters.k` is zero, when the number of centroids differs
+    /// from `parameters.k`, or when a centroid does not hold one value per column of `data`.
+    pub fn fit_with_centroids(
+        data: &X,
+        parameters: KMeansParameters,
+        centroids: Vec<Vec<f64>>,
+    ) -> Result<KMeans<TX, TY, X, Y>, Failed> {
+        // TODO: reuse existing methods in `crate::metrics`
+        // Squared Euclidean distance: skipping the square root does not change which
+        // centroid is the nearest.
+        fn squared_distance(point1: &[f64], point2: &[f64]) -> f64 {
+            point1.iter().zip(point2).map(|(a, b)| (a - b).powi(2)).sum()
+        }
+
+        // Index of the centroid nearest to `point`; the first one wins on ties.
+        fn closest_centroid(point: &[f64], centroids: &[Vec<f64>]) -> usize {
+            let mut closest_idx = 0;
+            let mut closest_dist = std::f64::MAX;
+            for (i, centroid) in centroids.iter().enumerate() {
+                let dist = squared_distance(point, centroid);
+                if dist < closest_dist {
+                    closest_dist = dist;
+                    closest_idx = i;
+                }
+            }
+            closest_idx
+        }
+
+        let (n, d) = data.shape();
+
+        // Check the cheap preconditions before paying for the BBD-tree construction.
+        if parameters.k == 0 {
+            return Err(Failed::fit("invalid number of clusters: 0"));
+        }
+        if centroids.len() != parameters.k {
+            return Err(Failed::fit(&format!(
+                "number of centroids ({}) must be equal to k ({})",
+                centroids.len(),
+                parameters.k
+            )));
+        }
+        if let Some(bad) = centroids.iter().find(|c| c.len() != d) {
+            return Err(Failed::fit(&format!(
+                "centroid has {} values but data has {} features",
+                bad.len(),
+                d
+            )));
+        }
+
+        let bbd = BBDTree::new(data);
+
+        // Initial assignment: each sample joins its nearest starting centroid, and the
+        // per-cluster counts and coordinate sums are accumulated in the same pass.
+        let mut y = vec![0usize; n];
+        let mut size = vec![0usize; parameters.k];
+        let mut new_centroids = vec![vec![0f64; d]; parameters.k];
+        let mut row = vec![0f64; d];
+        for (i, label) in y.iter_mut().enumerate() {
+            for (value, cell) in row.iter_mut().zip(data.get_row(i).iterator(0)) {
+                *value = cell.to_f64().unwrap();
+            }
+            *label = closest_centroid(&row, &centroids);
+            size[*label] += 1;
+            for (total, value) in new_centroids[*label].iter_mut().zip(&row) {
+                *total += value;
+            }
+        }
+
+        // Turn the sums into means. A cluster that received no sample keeps its supplied
+        // centroid, because dividing by a zero count would yield NaN coordinates.
+        for ((centroid, start), &count) in new_centroids.iter_mut().zip(centroids).zip(&size) {
+            if count == 0 {
+                *centroid = start;
+            } else {
+                for value in centroid.iter_mut() {
+                    *value /= count as f64;
+                }
+            }
+        }
+
+        let mut sums = vec![vec![0f64; d]; parameters.k];
+        let mut distortion = std::f64::MAX;
+
+        for _ in 1..=parameters.max_iter {
+            let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
+            for ((centroid, sum), &count) in new_centroids.iter_mut().zip(&sums).zip(&size) {
+                // A cluster left empty by this pass keeps its previous centroid.
+                if count > 0 {
+                    for (value, total) in centroid.iter_mut().zip(sum) {
+                        *value = total / count as f64;
+                    }
+                }
+            }
+
+            // Stop as soon as the distortion no longer decreases.
+            if distortion <= dist {
+                break;
+            }
+            distortion = dist;
+        }
+
+        Ok(KMeans {
+            k: parameters.k,
+            _y: y,
+            size,
+            _distortion: distortion,
+            centroids: new_centroids,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+            _phantom_x: PhantomData,
+            _phantom_y: PhantomData,
+        })
+    }
+
     /// Predict clusters for `x`
     /// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
@@ -417,6 +538,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
 mod tests {
     use super::*;
+    use crate::algorithm::neighbour::fastpair;
     use crate::linalg::basic::matrix::DenseMatrix;
 
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -503,6 +625,89 @@ mod tests {
         }
     }
 
+    #[cfg_attr(
+        all(target_arch = "wasm32", not(target_os = "wasi")),
+        wasm_bindgen_test::wasm_bindgen_test
+    )]
+    #[test]
+    fn fit_with_centroids_predict() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+
+        let parameters = KMeansParameters {
+            k: 3,
+            max_iter: 50,
+            ..Default::default()
+        };
+
+        // compute pairs
+        let fastpair = fastpair::FastPair::new(&x).unwrap();
+
+        // seed one centroid per closest pair: the midpoint of the pair's two rows
+        let centroids: Vec<Vec<f64>> = fastpair
+            .ordered_pairs()
+            .take(parameters.k)
+            .map(|p| {
+                x.get_row(p.node)
+                    .iterator(0)
+                    .zip(x.get_row(p.neighbour.unwrap()).iterator(0))
+                    .map(|(val1, val2)| (val1 + val2) * 0.5f64)
+                    .collect()
+            })
+            .collect();
+
+        let kmeans = KMeans::fit_with_centroids(&x, parameters, centroids).unwrap();
+
+        let y: Vec<usize> = kmeans.predict(&x).unwrap();
+
+        for (i, _y_i) in y.iter().enumerate() {
+            assert_eq!({ y[i] }, kmeans._y[i]);
+        }
+    }
+
+    #[cfg_attr(
+        all(target_arch = "wasm32", not(target_os = "wasi")),
+        wasm_bindgen_test::wasm_bindgen_test
+    )]
+    #[test]
+    fn fit_with_centroids_rejects_invalid_centroids() {
+        type Model = KMeans<f64, usize, DenseMatrix<f64>, Vec<usize>>;
+
+        let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[8.0, 9.0]]);
+        let parameters = || KMeansParameters {
+            k: 2,
+            ..Default::default()
+        };
+
+        // fewer centroids than clusters
+        let too_few = vec![vec![1.0, 2.0]];
+        assert!(Model::fit_with_centroids(&x, parameters(), too_few).is_err());
+
+        // a centroid with more values than `x` has columns
+        let wrong_width = vec![vec![1.0, 2.0], vec![8.0, 9.0, 0.0]];
+        assert!(Model::fit_with_centroids(&x, parameters(), wrong_width).is_err());
+    }
+
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test