Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13bb222ca7 | ||
|
|
bf65fe3753 | ||
|
|
074cfaf14f | ||
|
|
393cf15534 | ||
|
|
80c406b37d | ||
|
|
50e040a7a2 | ||
|
|
8765bd2173 | ||
|
|
0e1bf6ce7f |
@@ -179,6 +179,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Return order dissimilarities from closest to furthest
|
||||||
|
///
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||||
|
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||||
|
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||||
|
let mut distances = self
|
||||||
|
.distances
|
||||||
|
.values()
|
||||||
|
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||||
|
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||||
|
distances.into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Compute distances from input to all other points in data-structure.
|
// Compute distances from input to all other points in data-structure.
|
||||||
// input is the row index of the sample matrix
|
// input is the row index of the sample matrix
|
||||||
@@ -590,4 +605,39 @@ mod tests_fastpair {
|
|||||||
|
|
||||||
assert_eq!(closest, min_dissimilarity);
|
assert_eq!(closest, min_dissimilarity);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `ordered_pairs` must yield distances in non-decreasing order.
#[test]
fn fastpair_ordered_pairs() {
    let x = DenseMatrix::<f64>::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
    ]);
    let fastpair = FastPair::new(&x).unwrap();

    let ordered = fastpair.ordered_pairs();

    // `None` marks "no distance seen yet"; this avoids a magic float sentinel
    // and an exact floating-point equality test to detect the first iteration.
    let mut previous: Option<f64> = None;
    for p in ordered {
        let current = p.distance.unwrap();
        if let Some(prev) = previous {
            assert!(current >= prev);
        }
        previous = Some(current);
    }
}
|
||||||
}
|
}
|
||||||
|
|||||||
+177
-1
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
use crate::linalg::basic::arrays::{Array1, Array2, Array};
|
||||||
use crate::metrics::distance::euclidian::*;
|
use crate::metrics::distance::euclidian::*;
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::rand_custom::get_rng_impl;
|
use crate::rand_custom::get_rng_impl;
|
||||||
@@ -322,6 +322,109 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||||
|
/// * `data` - training instances to cluster
|
||||||
|
/// * `parameters` - cluster parameters
|
||||||
|
/// * `centroids` - starting centroids
|
||||||
|
pub fn fit_with_centroids(
|
||||||
|
data: &X,
|
||||||
|
parameters: KMeansParameters,
|
||||||
|
centroids: Vec<Vec<f64>>,
|
||||||
|
) -> Result<KMeans<TX, TY, X, Y>, Failed> {
|
||||||
|
|
||||||
|
// TODO: reuse existing methods in `crate::metrics`
|
||||||
|
fn euclidean_distance(point1: &Vec<f64>, point2: &Vec<f64>) -> f64 {
|
||||||
|
let mut dist = 0.0;
|
||||||
|
for i in 0..point1.len() {
|
||||||
|
dist += (point1[i] - point2[i]).powi(2);
|
||||||
|
}
|
||||||
|
dist.sqrt()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn closest_centroid(point: &Vec<f64>, centroids: &Vec<Vec<f64>>) -> usize {
|
||||||
|
let mut closest_idx = 0;
|
||||||
|
let mut closest_dist = std::f64::MAX;
|
||||||
|
for (i, centroid) in centroids.iter().enumerate() {
|
||||||
|
let dist = euclidean_distance(point, centroid);
|
||||||
|
if dist < closest_dist {
|
||||||
|
closest_dist = dist;
|
||||||
|
closest_idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closest_idx
|
||||||
|
}
|
||||||
|
|
||||||
|
let bbd = BBDTree::new(data);
|
||||||
|
|
||||||
|
if centroids.len() != parameters.k {
|
||||||
|
return Err(Failed::fit(&format!(
|
||||||
|
"number of centroids ({}) must be equal to k ({})",
|
||||||
|
centroids.len(),
|
||||||
|
parameters.k
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut y = vec![0; data.shape().0];
|
||||||
|
for i in 0..data.shape().0 {
|
||||||
|
y[i] = closest_centroid(
|
||||||
|
&Vec::from_iterator(data.get_row(i).iterator(0).map(|e| e.to_f64().unwrap()),
|
||||||
|
data.shape().1), ¢roids
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut size = vec![0; parameters.k];
|
||||||
|
let mut new_centroids = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||||
|
|
||||||
|
for i in 0..data.shape().0 {
|
||||||
|
size[y[i]] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..data.shape().0 {
|
||||||
|
for j in 0..data.shape().1 {
|
||||||
|
new_centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..parameters.k {
|
||||||
|
for j in 0..data.shape().1 {
|
||||||
|
new_centroids[i][j] /= size[i] as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sums = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||||
|
let mut distortion = std::f64::MAX;
|
||||||
|
|
||||||
|
for _ in 1..=parameters.max_iter {
|
||||||
|
let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
|
||||||
|
for i in 0..parameters.k {
|
||||||
|
if size[i] > 0 {
|
||||||
|
for j in 0..data.shape().1 {
|
||||||
|
new_centroids[i][j] = sums[i][j] / size[i] as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if distortion <= dist {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
distortion = dist;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(KMeans {
|
||||||
|
k: parameters.k,
|
||||||
|
_y: y,
|
||||||
|
size,
|
||||||
|
_distortion: distortion,
|
||||||
|
centroids: new_centroids,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Predict clusters for `x`
|
/// Predict clusters for `x`
|
||||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
@@ -417,6 +520,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use crate::algorithm::neighbour::fastpair;
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
@@ -503,6 +607,78 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Seed k-means with the midpoints of the closest pairs found by `FastPair` and check
/// that predicting the training data reproduces the labels stored by the fit.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_with_centroids_predict() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);

    let parameters = KMeansParameters {
        k: 3,
        max_iter: 50,
        ..Default::default()
    };

    // compute pairs
    let fastpair = fastpair::FastPair::new(&x).unwrap();

    // compute centroids for the `k` closest pairs: the closest pair fills the last
    // slot, the next one the slot before it, and so on.
    let mut centroids = vec![vec![0f64; x.shape().1]; parameters.k];
    for (slot, p) in centroids.iter_mut().rev().zip(fastpair.ordered_pairs()) {
        let first = x.get_row(p.node);
        let second = x.get_row(p.neighbour.unwrap());
        // Element-wise midpoint of the two rows: one value per feature.
        *slot = first
            .iterator(0)
            .zip(second.iterator(0))
            .map(|(a, b)| (a + b) * 0.5f64)
            .collect();
    }

    let kmeans = KMeans::fit_with_centroids(&x, parameters, centroids).unwrap();

    let y: Vec<usize> = kmeans.predict(&x).unwrap();

    assert_eq!(y, kmeans._y);
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
Reference in New Issue
Block a user