feat: extends interface of Matrix to support a broad range of types
This commit is contained in:
+26
-20
@@ -1,18 +1,21 @@
|
||||
extern crate rand;
|
||||
|
||||
use rand::Rng;
|
||||
use std::iter::Sum;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian;
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct KMeans {
|
||||
pub struct KMeans<T: FloatExt> {
|
||||
k: usize,
|
||||
y: Vec<usize>,
|
||||
size: Vec<usize>,
|
||||
distortion: f64,
|
||||
centroids: Vec<Vec<f64>>
|
||||
distortion: T,
|
||||
centroids: Vec<Vec<T>>
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -28,8 +31,8 @@ impl Default for KMeansParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl KMeans{
|
||||
pub fn new<M: Matrix>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans {
|
||||
impl<T: FloatExt + Debug + Sum> KMeans<T>{
|
||||
pub fn new<M: Matrix<T>>(data: &M, k: usize, parameters: KMeansParameters) -> KMeans<T> {
|
||||
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
@@ -43,10 +46,10 @@ impl KMeans{
|
||||
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = std::f64::MAX;
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, k);
|
||||
let mut size = vec![0; k];
|
||||
let mut centroids = vec![vec![0f64; d]; k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; k];
|
||||
|
||||
for i in 0..n {
|
||||
size[y[i]] += 1;
|
||||
@@ -54,23 +57,23 @@ impl KMeans{
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..d {
|
||||
centroids[y[i]][j] += data.get(i, j);
|
||||
centroids[y[i]][j] = centroids[y[i]][j] + data.get(i, j);
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..k {
|
||||
for j in 0..d {
|
||||
centroids[i][j] /= size[i] as f64;
|
||||
centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![0f64; d]; k];
|
||||
let mut sums = vec![vec![T::zero(); d]; k];
|
||||
for _ in 1..= parameters.max_iter {
|
||||
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..d {
|
||||
centroids[i][j] = sums[i][j] as f64 / size[i] as f64;
|
||||
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,13 +95,13 @@ impl KMeans{
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let (n, _) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
for i in 0..n {
|
||||
|
||||
let mut min_dist = std::f64::MAX;
|
||||
let mut min_dist = T::max_value();
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
@@ -108,19 +111,19 @@ impl KMeans{
|
||||
best_cluster = j;
|
||||
}
|
||||
}
|
||||
result.set(0, i, best_cluster as f64);
|
||||
result.set(0, i, T::from(best_cluster).unwrap());
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix>(data: &M, k: usize) -> Vec<usize>{
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize>{
|
||||
let mut rng = rand::thread_rng();
|
||||
let (n, _) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
||||
|
||||
let mut d = vec![std::f64::MAX; n];
|
||||
let mut d = vec![T::max_value(); n];
|
||||
|
||||
// pick the next center
|
||||
for j in 1..k {
|
||||
@@ -136,12 +139,15 @@ impl KMeans{
|
||||
}
|
||||
}
|
||||
|
||||
let sum: f64 = d.iter().sum();
|
||||
let cutoff = rng.gen::<f64>() * sum;
|
||||
let mut cost = 0f64;
|
||||
let mut sum: T = T::zero();
|
||||
for i in d.iter(){
|
||||
sum = sum + *i;
|
||||
}
|
||||
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
|
||||
let mut cost = T::zero();
|
||||
let index = 0;
|
||||
for index in 0..n {
|
||||
cost += d[index];
|
||||
cost = cost + d[index];
|
||||
if cost >= cutoff {
|
||||
break;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user