Add fit_with_centroids

This commit is contained in:
Lorenzo (Mec-iS)
2023-03-21 17:37:58 +09:00
parent 0e1bf6ce7f
commit 8765bd2173
+177 -1
View File
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::basic::arrays::{Array1, Array2, Array};
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
@@ -322,6 +322,109 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
})
}
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features,
/// starting from caller-supplied centroids.
///
/// Every sample is first assigned to its nearest starting centroid, the centroids are then
/// recomputed as the mean of their assigned samples, and finally the iterative refinement
/// runs for at most `parameters.max_iter` rounds or until the distortion stops decreasing.
/// A starting centroid that attracts no samples is kept unchanged rather than being
/// replaced by an undefined (NaN) mean.
///
/// * `data` - training instances to cluster
/// * `parameters` - cluster parameters
/// * `centroids` - starting centroids: exactly `parameters.k` vectors, each holding at least
///   _M_ components (only the first _M_ components of every vector are used)
///
/// # Errors
/// Returns a fit failure when `centroids` is empty, when the number of centroids differs
/// from `parameters.k`, or when a centroid has fewer components than `data` has features.
pub fn fit_with_centroids(
    data: &X,
    parameters: KMeansParameters,
    centroids: Vec<Vec<f64>>,
) -> Result<KMeans<TX, TY, X, Y>, Failed> {
    // TODO: reuse existing methods in `crate::metrics`
    /// Squared Euclidean distance over the components both slices share.
    /// The square root is skipped because only the ordering of distances matters here.
    fn squared_distance(point: &[f64], centroid: &[f64]) -> f64 {
        point
            .iter()
            .zip(centroid)
            .map(|(p, c)| (p - c).powi(2))
            .sum()
    }

    /// Index of the centroid nearest to `point`; the first one wins on ties.
    fn closest_centroid(point: &[f64], centroids: &[Vec<f64>]) -> usize {
        let mut closest_idx = 0;
        let mut closest_dist = std::f64::MAX;
        for (idx, centroid) in centroids.iter().enumerate() {
            let dist = squared_distance(point, centroid);
            if dist < closest_dist {
                closest_dist = dist;
                closest_idx = idx;
            }
        }
        closest_idx
    }

    let (n, d) = data.shape();

    // Validate the inputs before doing any expensive work.
    if centroids.len() != parameters.k {
        return Err(Failed::fit(&format!(
            "number of centroids ({}) must be equal to k ({})",
            centroids.len(),
            parameters.k
        )));
    }

    // With no centroids every per-cluster buffer below would be empty and indexing it
    // would panic.
    if centroids.is_empty() {
        return Err(Failed::fit("at least one centroid is required"));
    }

    // A centroid shorter than a sample cannot be compared against it.
    if let Some((idx, centroid)) = centroids
        .iter()
        .enumerate()
        .find(|(_, centroid)| centroid.len() < d)
    {
        return Err(Failed::fit(&format!(
            "centroid {} has {} components but data has {} features",
            idx,
            centroid.len(),
            d
        )));
    }

    let bbd = BBDTree::new(data);

    // Initial assignment: label each sample with its nearest starting centroid while
    // accumulating the per-cluster counts and coordinate sums in the same pass.
    let mut y = vec![0; n];
    let mut size = vec![0; parameters.k];
    let mut new_centroids = vec![vec![0f64; d]; parameters.k];
    // Scratch buffer reused for every row to avoid one allocation per sample.
    let mut point = vec![0f64; d];
    for (i, label) in y.iter_mut().enumerate() {
        for (j, value) in point.iter_mut().enumerate() {
            *value = data.get((i, j)).to_f64().unwrap();
        }
        let cluster = closest_centroid(&point, &centroids);
        *label = cluster;
        size[cluster] += 1;
        for (acc, value) in new_centroids[cluster].iter_mut().zip(&point) {
            *acc += value;
        }
    }

    // Turn the sums into means. An empty cluster has no mean (0 / 0 would be NaN), so it
    // keeps the centroid the caller supplied.
    for (i, centroid) in new_centroids.iter_mut().enumerate() {
        if size[i] > 0 {
            for value in centroid.iter_mut() {
                *value /= size[i] as f64;
            }
        } else {
            centroid.copy_from_slice(&centroids[i][..d]);
        }
    }

    let mut sums = vec![vec![0f64; d]; parameters.k];
    let mut distortion = std::f64::MAX;

    for _ in 1..=parameters.max_iter {
        let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
        for i in 0..parameters.k {
            // Clusters that received no samples keep their previous centroid.
            if size[i] > 0 {
                for j in 0..d {
                    new_centroids[i][j] = sums[i][j] / size[i] as f64;
                }
            }
        }

        // Stop as soon as the distortion no longer improves.
        if distortion <= dist {
            break;
        }
        distortion = dist;
    }

    Ok(KMeans {
        k: parameters.k,
        _y: y,
        size,
        _distortion: distortion,
        centroids: new_centroids,
        _phantom_tx: PhantomData,
        _phantom_ty: PhantomData,
        _phantom_x: PhantomData,
        _phantom_y: PhantomData,
    })
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
@@ -417,6 +520,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::algorithm::neighbour::fastpair;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -503,6 +607,78 @@ mod tests {
}
}
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
/// Seeds k-means with the midpoints of the closest pairs reported by `FastPair` and checks
/// that predicting on the training data reproduces the labels stored by the fit.
fn fit_with_centroids_predict() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);

    let parameters = KMeansParameters {
        k: 3,
        max_iter: 50,
        ..Default::default()
    };

    // compute pairs
    let pair_index = fastpair::FastPair::new(&x).unwrap();

    // One starting centroid per cluster: the element-wise midpoint of each of the first
    // `k` ordered pairs. Slots are filled from the last index down, so the first pair
    // lands in the last slot; slots without a pair stay at the origin.
    let k = parameters.k;
    let mut centroids = vec![vec![0f64; x.shape().1]; k];
    for (slot, pair) in (0..k).rev().zip(pair_index.ordered_pairs()) {
        let neighbour = pair
            .neighbour
            .expect("an ordered pair should carry a neighbour index");
        let first = x.get_row(pair.node);
        let second = x.get_row(neighbour);
        // Pair up matching features; each centroid must have exactly one value per feature.
        centroids[slot] = first
            .iterator(0)
            .zip(second.iterator(0))
            .map(|(a, b)| (*a + *b) * 0.5f64)
            .collect();
    }

    let kmeans = KMeans::fit_with_centroids(&x, parameters, centroids).unwrap();

    let y: Vec<usize> = kmeans.predict(&x).unwrap();

    // Whole-vector comparison also checks that both label vectors have the same length.
    assert_eq!(y, kmeans._y);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test