fix: cargo fmt
This commit is contained in:
+54
-57
@@ -1,40 +1,41 @@
|
||||
extern crate rand;
|
||||
|
||||
use rand::Rng;
|
||||
use std::iter::Sum;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KMeans<T: FloatExt> {
|
||||
pub struct KMeans<T: FloatExt> {
|
||||
k: usize,
|
||||
y: Vec<usize>,
|
||||
size: Vec<usize>,
|
||||
distortion: T,
|
||||
centroids: Vec<Vec<T>>
|
||||
centroids: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for KMeans<T> {
|
||||
impl<T: FloatExt> PartialEq for KMeans<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.k != other.k ||
|
||||
self.size != other.size ||
|
||||
self.centroids.len() != other.centroids.len() {
|
||||
if self.k != other.k
|
||||
|| self.size != other.size
|
||||
|| self.centroids.len() != other.centroids.len()
|
||||
{
|
||||
false
|
||||
} else {
|
||||
let n_centroids = self.centroids.len();
|
||||
for i in 0..n_centroids{
|
||||
if self.centroids[i].len() != other.centroids[i].len(){
|
||||
return false
|
||||
for i in 0..n_centroids {
|
||||
if self.centroids[i].len() != other.centroids[i].len() {
|
||||
return false;
|
||||
}
|
||||
for j in 0..self.centroids[i].len() {
|
||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
|
||||
return false
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,21 +45,18 @@ impl<T: FloatExt> PartialEq for KMeans<T> {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KMeansParameters {
|
||||
pub max_iter: usize
|
||||
pub struct KMeansParameters {
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
impl Default for KMeansParameters {
|
||||
fn default() -> Self {
|
||||
KMeansParameters {
|
||||
max_iter: 100
|
||||
}
|
||||
}
|
||||
fn default() -> Self {
|
||||
KMeansParameters { max_iter: 100 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt + Sum> KMeans<T>{
|
||||
impl<T: FloatExt + Sum> KMeans<T> {
|
||||
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
|
||||
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if k < 2 {
|
||||
@@ -66,11 +64,14 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
}
|
||||
|
||||
if parameters.max_iter <= 0 {
|
||||
panic!("Invalid maximum number of iterations: {}", parameters.max_iter);
|
||||
panic!(
|
||||
"Invalid maximum number of iterations: {}",
|
||||
parameters.max_iter
|
||||
);
|
||||
}
|
||||
|
||||
let (n, d) = data.shape();
|
||||
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, k);
|
||||
let mut size = vec![0; k];
|
||||
@@ -90,10 +91,10 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
for j in 0..d {
|
||||
centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![T::zero(); d]; k];
|
||||
for _ in 1..= parameters.max_iter {
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..k {
|
||||
if size[i] > 0 {
|
||||
@@ -108,48 +109,46 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
} else {
|
||||
distortion = dist;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
KMeans{
|
||||
KMeans {
|
||||
k: k,
|
||||
y: y,
|
||||
size: size,
|
||||
distortion: distortion,
|
||||
centroids: centroids
|
||||
centroids: centroids,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let (n, _) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
let (n, _) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
for i in 0..n {
|
||||
|
||||
let mut min_dist = T::max_value();
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||
let dist = Euclidian::squared_distance(&x.get_row_as_vec(i), &self.centroids[j]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
best_cluster = j;
|
||||
}
|
||||
}
|
||||
}
|
||||
result.set(0, i, T::from(best_cluster).unwrap());
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize>{
|
||||
let mut rng = rand::thread_rng();
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let (n, _) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
||||
|
||||
let mut d = vec![T::max_value(); n];
|
||||
|
||||
|
||||
// pick the next center
|
||||
for j in 1..k {
|
||||
// Loop over the samples and compare them to the most recent center. Store
|
||||
@@ -157,7 +156,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
for i in 0..n {
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
y[i] = j - 1;
|
||||
@@ -165,7 +164,7 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
}
|
||||
|
||||
let mut sum: T = T::zero();
|
||||
for i in d.iter(){
|
||||
for i in d.iter() {
|
||||
sum = sum + *i;
|
||||
}
|
||||
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
|
||||
@@ -183,8 +182,8 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
|
||||
for i in 0..n {
|
||||
// compute the distance between this sample and the current center
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
|
||||
let dist = Euclidian::squared_distance(&data.get_row_as_vec(i), ¢roid);
|
||||
|
||||
if dist < d[i] {
|
||||
d[i] = dist;
|
||||
y[i] = k - 1;
|
||||
@@ -193,17 +192,15 @@ impl<T: FloatExt + Sum> KMeans<T>{
|
||||
|
||||
y
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
fn fit_predict_iris() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -224,7 +221,8 @@ mod tests {
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
||||
|
||||
@@ -232,12 +230,11 @@ mod tests {
|
||||
|
||||
for i in 0..y.len() {
|
||||
assert_eq!(y[i] as usize, kmeans.y[i]);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -258,14 +255,14 @@ mod tests {
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::new(&x, 2, Default::default());
|
||||
|
||||
let deserialized_kmeans: KMeans<f64> = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||
let deserialized_kmeans: KMeans<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(kmeans, deserialized_kmeans);
|
||||
|
||||
assert_eq!(kmeans, deserialized_kmeans);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user