8 Commits

Author SHA1 Message Date
Lorenzo (Mec-iS)
13bb222ca7 Merge branch 'development' into kmeans-with-fastpair 2023-05-04 17:19:01 +01:00
Lorenzo (Mec-iS)
bf65fe3753 Merge branch 'march-2023-improvements' into kmeans-with-fastpair 2023-03-24 12:09:55 +09:00
Lorenzo (Mec-iS)
074cfaf14f rustfmt 2023-03-24 12:06:54 +09:00
Lorenzo
393cf15534 Merge branch 'development' into march-2023-improvements 2023-03-24 12:05:06 +09:00
Lorenzo (Mec-iS)
80c406b37d Merge branch 'development' of github.com:smartcorelib/smartcore into march-2023-improvements 2023-03-21 17:38:35 +09:00
Lorenzo (Mec-iS)
50e040a7a2 Merge branch 'development' of github.com:smartcorelib/smartcore into kmeans-with-fastpair 2023-03-21 17:38:06 +09:00
Lorenzo (Mec-iS)
8765bd2173 Add fit_with_centroids 2023-03-21 17:37:58 +09:00
Lorenzo (Mec-iS)
0e1bf6ce7f Add ordered_pairs method to FastPair 2023-03-21 14:46:33 +09:00
2 changed files with 227 additions and 1 deletions
+50
View File
@@ -179,6 +179,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
}
}
///
/// Return order dissimilarities from closest to furthest
///
#[allow(dead_code)]
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
let mut distances = self
.distances
.values()
.collect::<Vec<&PairwiseDistance<T>>>();
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
distances.into_iter()
}
//
// Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix
@@ -590,4 +605,39 @@ mod tests_fastpair {
assert_eq!(closest, min_dissimilarity);
}
#[test]
fn fastpair_ordered_pairs() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
]);
let fastpair = FastPair::new(&x).unwrap();
let ordered = fastpair.ordered_pairs();
let mut previous: f64 = -1.0;
for p in ordered {
if previous == -1.0 {
previous = p.distance.unwrap();
} else {
let current = p.distance.unwrap();
assert!(current >= previous);
previous = current;
}
}
}
}
+177 -1
View File
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::basic::arrays::{Array1, Array2, Array};
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
@@ -322,6 +322,109 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
})
}
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
/// * `data` - training instances to cluster
/// * `parameters` - cluster parameters
/// * `centroids` - starting centroids
pub fn fit_with_centroids(
data: &X,
parameters: KMeansParameters,
centroids: Vec<Vec<f64>>,
) -> Result<KMeans<TX, TY, X, Y>, Failed> {
// TODO: reuse existing methods in `crate::metrics`
fn euclidean_distance(point1: &Vec<f64>, point2: &Vec<f64>) -> f64 {
let mut dist = 0.0;
for i in 0..point1.len() {
dist += (point1[i] - point2[i]).powi(2);
}
dist.sqrt()
}
fn closest_centroid(point: &Vec<f64>, centroids: &Vec<Vec<f64>>) -> usize {
let mut closest_idx = 0;
let mut closest_dist = std::f64::MAX;
for (i, centroid) in centroids.iter().enumerate() {
let dist = euclidean_distance(point, centroid);
if dist < closest_dist {
closest_dist = dist;
closest_idx = i;
}
}
closest_idx
}
let bbd = BBDTree::new(data);
if centroids.len() != parameters.k {
return Err(Failed::fit(&format!(
"number of centroids ({}) must be equal to k ({})",
centroids.len(),
parameters.k
)));
}
let mut y = vec![0; data.shape().0];
for i in 0..data.shape().0 {
y[i] = closest_centroid(
&Vec::from_iterator(data.get_row(i).iterator(0).map(|e| e.to_f64().unwrap()),
data.shape().1), &centroids
);
}
let mut size = vec![0; parameters.k];
let mut new_centroids = vec![vec![0f64; data.shape().1]; parameters.k];
for i in 0..data.shape().0 {
size[y[i]] += 1;
}
for i in 0..data.shape().0 {
for j in 0..data.shape().1 {
new_centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
}
}
for i in 0..parameters.k {
for j in 0..data.shape().1 {
new_centroids[i][j] /= size[i] as f64;
}
}
let mut sums = vec![vec![0f64; data.shape().1]; parameters.k];
let mut distortion = std::f64::MAX;
for _ in 1..=parameters.max_iter {
let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
for i in 0..parameters.k {
if size[i] > 0 {
for j in 0..data.shape().1 {
new_centroids[i][j] = sums[i][j] / size[i] as f64;
}
}
}
if distortion <= dist {
break;
} else {
distortion = dist;
}
}
Ok(KMeans {
k: parameters.k,
_y: y,
size,
_distortion: distortion,
centroids: new_centroids,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
@@ -417,6 +520,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::algorithm::neighbour::fastpair;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -503,6 +607,78 @@ mod tests {
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_with_centroids_predict() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let parameters = KMeansParameters {
k: 3,
max_iter: 50,
..Default::default()
};
// compute pairs
let fastpair = fastpair::FastPair::new(&x).unwrap();
// compute centroids for N closest pairs
let mut n: isize = 2;
let mut centroids = vec![vec![0f64; x.shape().1]; n as usize + 1];
for p in fastpair.ordered_pairs() {
if n == -1 {
break
}
centroids[n as usize] = {
let mut result: Vec<f64> = Vec::with_capacity(x.shape().1);
for val1 in x.get_row(p.node).iterator(0) {
for val2 in x.get_row(p.neighbour.unwrap()).iterator(0) {
let sum = val1 + val2;
let avg = sum * 0.5f64;
result.push(avg);
}
}
result
};
n -= 1;
}
let kmeans = KMeans::fit_with_centroids(
&x, parameters, centroids).unwrap();
let y: Vec<usize> = kmeans.predict(&x).unwrap();
for (i, _y_i) in y.iter().enumerate() {
assert_eq!({ y[i] }, kmeans._y[i]);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test