Add fit_with_centroids
This commit is contained in:
+177
-1
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, Array};
|
||||
use crate::metrics::distance::euclidian::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
@@ -322,6 +322,109 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
})
|
||||
}
|
||||
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
/// * `centroids` - starting centroids
|
||||
pub fn fit_with_centroids(
|
||||
data: &X,
|
||||
parameters: KMeansParameters,
|
||||
centroids: Vec<Vec<f64>>,
|
||||
) -> Result<KMeans<TX, TY, X, Y>, Failed> {
|
||||
|
||||
// TODO: reuse existing methods in `crate::metrics`
|
||||
fn euclidean_distance(point1: &Vec<f64>, point2: &Vec<f64>) -> f64 {
|
||||
let mut dist = 0.0;
|
||||
for i in 0..point1.len() {
|
||||
dist += (point1[i] - point2[i]).powi(2);
|
||||
}
|
||||
dist.sqrt()
|
||||
}
|
||||
|
||||
fn closest_centroid(point: &Vec<f64>, centroids: &Vec<Vec<f64>>) -> usize {
|
||||
let mut closest_idx = 0;
|
||||
let mut closest_dist = std::f64::MAX;
|
||||
for (i, centroid) in centroids.iter().enumerate() {
|
||||
let dist = euclidean_distance(point, centroid);
|
||||
if dist < closest_dist {
|
||||
closest_dist = dist;
|
||||
closest_idx = i;
|
||||
}
|
||||
}
|
||||
closest_idx
|
||||
}
|
||||
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if centroids.len() != parameters.k {
|
||||
return Err(Failed::fit(&format!(
|
||||
"number of centroids ({}) must be equal to k ({})",
|
||||
centroids.len(),
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
let mut y = vec![0; data.shape().0];
|
||||
for i in 0..data.shape().0 {
|
||||
y[i] = closest_centroid(
|
||||
&Vec::from_iterator(data.get_row(i).iterator(0).map(|e| e.to_f64().unwrap()),
|
||||
data.shape().1), ¢roids
|
||||
);
|
||||
}
|
||||
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut new_centroids = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||
|
||||
for i in 0..data.shape().0 {
|
||||
size[y[i]] += 1;
|
||||
}
|
||||
|
||||
for i in 0..data.shape().0 {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..parameters.k {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[i][j] /= size[i] as f64;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||
let mut distortion = std::f64::MAX;
|
||||
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..parameters.k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[i][j] = sums[i][j] / size[i] as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if distortion <= dist {
|
||||
break;
|
||||
} else {
|
||||
distortion = dist;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(KMeans {
|
||||
k: parameters.k,
|
||||
_y: y,
|
||||
size,
|
||||
_distortion: distortion,
|
||||
centroids: new_centroids,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
@@ -417,6 +520,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::algorithm::neighbour::fastpair;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -503,6 +607,78 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_with_centroids_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let parameters = KMeansParameters {
|
||||
k: 3,
|
||||
max_iter: 50,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// compute pairs
|
||||
let fastpair = fastpair::FastPair::new(&x).unwrap();
|
||||
|
||||
// compute centroids for N closest pairs
|
||||
let mut n: isize = 2;
|
||||
let mut centroids = vec![vec![0f64; x.shape().1]; n as usize + 1];
|
||||
for p in fastpair.ordered_pairs() {
|
||||
if n == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
centroids[n as usize] = {
|
||||
let mut result: Vec<f64> = Vec::with_capacity(x.shape().1);
|
||||
for val1 in x.get_row(p.node).iterator(0) {
|
||||
for val2 in x.get_row(p.neighbour.unwrap()).iterator(0) {
|
||||
let sum = val1 + val2;
|
||||
let avg = sum * 0.5f64;
|
||||
result.push(avg);
|
||||
}
|
||||
}
|
||||
result
|
||||
};
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
|
||||
|
||||
let kmeans = KMeans::fit_with_centroids(
|
||||
&x, parameters, centroids).unwrap();
|
||||
|
||||
let y: Vec<usize> = kmeans.predict(&x).unwrap();
|
||||
|
||||
for (i, _y_i) in y.iter().enumerate() {
|
||||
assert_eq!({ y[i] }, kmeans._y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
|
||||
Reference in New Issue
Block a user