Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests
* feat: add seed parameter to multiple algorithms
* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
@@ -52,10 +52,10 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
use ::rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -65,6 +65,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -108,6 +109,9 @@ pub struct KMeansParameters {
|
||||
pub k: usize,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Set to `Some(seed)`, where `seed` is a `u64`, to make the randomness deterministic.
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
@@ -128,6 +132,7 @@ impl Default for KMeansParameters {
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,7 +243,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
|
||||
@@ -311,8 +316,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(seed);
|
||||
let (n, m) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||
|
||||
Reference in New Issue
Block a user