feat: consolidates API
This commit is contained in:
+43
-22
@@ -43,7 +43,7 @@
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
|
||||
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
|
||||
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||
//! ```
|
||||
//!
|
||||
@@ -59,6 +59,7 @@ use std::iter::Sum;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
@@ -101,11 +102,18 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
#[derive(Debug, Clone)]
|
||||
/// K-Means clustering algorithm parameters
|
||||
pub struct KMeansParameters {
|
||||
/// Number of clusters.
|
||||
pub k: usize,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
/// Number of clusters.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
|
||||
self.max_iter = max_iter;
|
||||
@@ -115,24 +123,37 @@ impl KMeansParameters {
|
||||
|
||||
impl Default for KMeansParameters {
|
||||
fn default() -> Self {
|
||||
KMeansParameters { max_iter: 100 }
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
KMeans::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum> KMeans<T> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `k` - number of clusters
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
data: &M,
|
||||
k: usize,
|
||||
parameters: KMeansParameters,
|
||||
) -> Result<KMeans<T>, Failed> {
|
||||
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if k < 2 {
|
||||
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
|
||||
if parameters.k < 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"invalid number of clusters: {}",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
if parameters.max_iter == 0 {
|
||||
@@ -145,9 +166,9 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, k);
|
||||
let mut size = vec![0; k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; k];
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
|
||||
for i in 0..n {
|
||||
size[y[i]] += 1;
|
||||
@@ -159,16 +180,16 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..k {
|
||||
for i in 0..parameters.k {
|
||||
for j in 0..d {
|
||||
centroids[i][j] /= T::from(size[i]).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![T::zero(); d]; k];
|
||||
let mut sums = vec![vec![T::zero(); d]; parameters.k];
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..k {
|
||||
for i in 0..parameters.k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..d {
|
||||
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
|
||||
@@ -184,7 +205,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
}
|
||||
|
||||
Ok(KMeans {
|
||||
k,
|
||||
k: parameters.k,
|
||||
y,
|
||||
size,
|
||||
distortion,
|
||||
@@ -280,10 +301,10 @@ mod tests {
|
||||
fn invalid_k() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
|
||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
|
||||
assert_eq!(
|
||||
"Fit failed: invalid number of clusters: 1",
|
||||
KMeans::fit(&x, 1, Default::default())
|
||||
KMeans::fit(&x, KMeansParameters::default().with_k(1))
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
);
|
||||
@@ -314,7 +335,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let y = kmeans.predict(&x).unwrap();
|
||||
|
||||
@@ -348,7 +369,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_kmeans: KMeans<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user