Lmm/add seeds in more algorithms (#164)

* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Author: morenol
Date: 2022-09-21 15:35:22 -04:00
Committed by: GitHub
parent 48514d1b15
commit 3a44161406
14 changed files with 139 additions and 64 deletions
+4 -3
@@ -2,16 +2,17 @@ name: CI
 on:
   push:
-    branches: [ main, development ]
+    branches: [main, development]
   pull_request:
-    branches: [ development ]
+    branches: [development]

 jobs:
   tests:
     runs-on: "${{ matrix.platform.os }}-latest"
     strategy:
       matrix:
-        platform: [
+        platform:
+          [
             { os: "windows", target: "x86_64-pc-windows-msvc" },
             { os: "windows", target: "i686-pc-windows-msvc" },
             { os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
+9
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
+
+## Added
+- Seeds to multiple algorithms that depend on random number generation.
+- Added feature `js` to use WASM in the browser
+
+## BREAKING CHANGE
+- Added a new parameter to `train_test_split` to define the seed.
+
 ## [0.2.1] - 2022-05-10

 ## Added
 - L2 regularization penalty to the Logistic Regression
 - Getters for the naive bayes structs
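For downstream users, the `train_test_split` change is the breaking one: the function now takes a fifth argument. A minimal migration sketch (the signature comes from the diff below; import paths assume the 0.2-era `DenseMatrix` layout, and the data is illustrative):

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::model_selection::train_test_split;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.],
    ]);
    let y = vec![0., 0., 1., 1., 1.];
    // New fifth argument: Some(seed) pins the shuffle so the split is
    // reproducible across runs; None keeps the old unseeded behavior.
    let (_x_train, _x_test, y_train, _y_test) =
        train_test_split(&x, &y, 0.2, true, Some(42));
    assert_eq!(y_train.len(), 4); // 20% of 5 rows -> 1 test sample
}
```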
+7 -3
@@ -16,21 +16,25 @@ categories = ["science"]
 default = ["datasets"]
 ndarray-bindings = ["ndarray"]
 nalgebra-bindings = ["nalgebra"]
-datasets = ["rand_distr"]
+datasets = ["rand_distr", "std"]
 fp_bench = ["itertools"]
+std = ["rand/std", "rand/std_rng"]
+# wasm32 only
+js = ["getrandom/js"]

 [dependencies]
 ndarray = { version = "0.15", optional = true }
 nalgebra = { version = "0.31", optional = true }
 num-traits = "0.2"
 num = "0.4"
-rand = "0.8"
+rand = { version = "0.8", default-features = false, features = ["small_rng"] }
 rand_distr = { version = "0.4", optional = true }
 serde = { version = "1", features = ["derive"], optional = true }
 itertools = { version = "0.10.3", optional = true }
+cfg-if = "1.0.0"

 [target.'cfg(target_arch = "wasm32")'.dependencies]
-getrandom = { version = "0.2", features = ["js"] }
+getrandom = { version = "0.2", optional = true }

 [dev-dependencies]
 smartcore = { path = ".", features = ["fp_bench"] }
+9 -4
@@ -52,10 +52,10 @@
 //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
 //! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
-use rand::Rng;
 use std::fmt::Debug;
 use std::iter::Sum;
+use ::rand::Rng;

 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -65,6 +65,7 @@ use crate::error::Failed;
 use crate::linalg::Matrix;
 use crate::math::distance::euclidian::*;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;

 /// K-Means clustering algorithm
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -108,6 +109,9 @@ pub struct KMeansParameters {
     pub k: usize,
     /// Maximum number of iterations of the k-means algorithm for a single run.
     pub max_iter: usize,
+    /// Determines random number generation for centroid initialization.
+    /// Use an int to make the randomness deterministic.
+    pub seed: Option<u64>,
 }

 impl KMeansParameters {
@@ -128,6 +132,7 @@ impl Default for KMeansParameters {
         KMeansParameters {
             k: 2,
             max_iter: 100,
+            seed: None,
         }
     }
 }
@@ -238,7 +243,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
         let (n, d) = data.shape();
         let mut distortion = T::max_value();
-        let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
+        let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
         let mut size = vec![0; parameters.k];
         let mut centroids = vec![vec![T::zero(); d]; parameters.k];
@@ -311,8 +316,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
         Ok(result.to_row_vector())
     }

-    fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
+        let mut rng = get_rng_impl(seed);
         let (n, m) = data.shape();
         let mut y = vec![0; n];
         let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
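With the seed wired through to `kmeans_plus_plus`, centroid initialization, and therefore the fitted labels, become repeatable. A quick sketch (fields per the struct above; import paths assume the 0.2-era layout; the closure just avoids assuming a `Clone` derive on the parameters):

```rust
use smartcore::cluster::kmeans::{KMeans, KMeansParameters};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 1.1], &[1.2, 0.9], &[0.8, 1.0],
        &[8.0, 8.2], &[7.9, 8.1], &[8.3, 7.8],
    ]);
    let params = || KMeansParameters { k: 2, max_iter: 100, seed: Some(7) };
    // Two fits with the same seed pick the same initial centroids and
    // converge to the same labeling.
    let a = KMeans::fit(&x, params()).and_then(|m| m.predict(&x)).unwrap();
    let b = KMeans::fit(&x, params()).and_then(|m| m.predict(&x)).unwrap();
    assert_eq!(a, b);
}
```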
+6 -5
@@ -45,8 +45,8 @@
 //!
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
-use rand::rngs::StdRng;
-use rand::{Rng, SeedableRng};
+use rand::Rng;
 use std::default::Default;
 use std::fmt::Debug;
@@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::{Failed, FailedError};
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 use crate::tree::decision_tree_classifier::{
     which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
 };
@@ -441,7 +442,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
                 .unwrap()
         });
-        let mut rng = StdRng::seed_from_u64(parameters.seed);
+        let mut rng = get_rng_impl(Some(parameters.seed));
         let classes = y_m.unique();
         let k = classes.len();
         let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
@@ -462,9 +463,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
                 max_depth: parameters.max_depth,
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
+                seed: Some(parameters.seed),
             };
-            let tree =
-                DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
+            let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
             trees.push(tree);
         }
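The forest-level seed (a plain `u64`, per `Some(parameters.seed)` above) is now also forwarded into every weak learner. A sketch of opting in, assuming the crate's usual all-public parameter fields and `Default` impl; data and values are illustrative:

```rust
use smartcore::ensemble::random_forest_classifier::{
    RandomForestClassifier, RandomForestClassifierParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5], &[4.9, 3.0], &[4.7, 3.2],
        &[6.2, 3.4], &[5.9, 3.0], &[6.5, 2.8],
    ]);
    let y = vec![0., 0., 0., 1., 1., 1.];
    // Only the seed deviates from the defaults; a re-fit with the same
    // seed now grows the same trees.
    let params = RandomForestClassifierParameters {
        seed: 42,
        ..Default::default()
    };
    let forest = RandomForestClassifier::fit(&x, &y, params).unwrap();
    assert_eq!(forest.predict(&x).unwrap().len(), y.len());
}
```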
+6 -5
@@ -43,8 +43,8 @@
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
-use rand::rngs::StdRng;
-use rand::{Rng, SeedableRng};
+use rand::Rng;
 use std::default::Default;
 use std::fmt::Debug;
@@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::{Failed, FailedError};
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 use crate::tree::decision_tree_regressor::{
     DecisionTreeRegressor, DecisionTreeRegressorParameters,
 };
@@ -376,7 +377,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
             .m
             .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
-        let mut rng = StdRng::seed_from_u64(parameters.seed);
+        let mut rng = get_rng_impl(Some(parameters.seed));
         let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
         let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
@@ -393,9 +394,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
                 max_depth: parameters.max_depth,
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
+                seed: Some(parameters.seed),
             };
-            let tree =
-                DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
+            let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
             trees.push(tree);
         }
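Same wiring on the regression side. One way to see the effect end to end (a sketch under the same assumptions as above: public fields, a `Default` impl, illustrative data):

```rust
use smartcore::ensemble::random_forest_regressor::{
    RandomForestRegressor, RandomForestRegressorParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1., 1.], &[2., 1.], &[3., 2.], &[4., 3.], &[5., 5.], &[6., 8.],
    ]);
    let y = vec![1., 2., 3., 4., 5., 6.];
    let fit = |seed| {
        let params = RandomForestRegressorParameters { seed, ..Default::default() };
        RandomForestRegressor::fit(&x, &y, params)
            .and_then(|f| f.predict(&x))
            .unwrap()
    };
    // With the seed forwarded into every weak learner, two fits with the
    // same seed should be bit-for-bit reproducible.
    assert_eq!(fit(5), fit(5));
}
```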
+2
@@ -101,3 +101,5 @@ pub mod readers;
 pub mod svm;
 /// Supervised tree-based learning methods
 pub mod tree;
+
+pub(crate) mod rand;
+4 -2
@@ -9,6 +9,8 @@ use std::iter::{Product, Sum};
 use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
 use std::str::FromStr;

+use crate::rand::get_rng_impl;
+
 /// Defines real number
 /// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
 pub trait RealNumber:
@@ -79,7 +81,7 @@ impl RealNumber for f64 {
     }

     fn rand() -> f64 {
-        let mut rng = rand::thread_rng();
+        let mut rng = get_rng_impl(None);
         rng.gen()
     }
@@ -124,7 +126,7 @@ impl RealNumber for f32 {
     }

     fn rand() -> f32 {
-        let mut rng = rand::thread_rng();
+        let mut rng = get_rng_impl(None);
         rng.gen()
     }
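With `thread_rng` gone, unseeded draws route through `get_rng_impl(None)`. A small sketch of what that means at the call site (trait path per the imports above; the no-std caveat comes from the new `rand` module further down):

```rust
use smartcore::math::num::RealNumber;

fn main() {
    // On std builds this seeds a fresh RNG from thread_rng() entropy; on
    // no-std builds get_rng_impl(None) panics, so RealNumber::rand (and
    // anything built on it, like DenseMatrix::rand) needs the `std` feature.
    let r = f64::rand();
    assert!(r.is_finite());
}
```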
+19 -2
@@ -5,8 +5,8 @@
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
 use crate::model_selection::BaseKFold;
+use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
-use rand::thread_rng;

 /// K-Folds cross-validator
 pub struct KFold {
@@ -14,6 +14,9 @@ pub struct KFold {
     pub n_splits: usize, // cannot exceed std::usize::MAX
     /// Whether to shuffle the data before splitting into batches
     pub shuffle: bool,
+    /// When shuffle is true, the seed affects the ordering of the indices,
+    /// which controls the randomness of each fold.
+    pub seed: Option<u64>,
 }

 impl KFold {
@@ -23,8 +26,10 @@ impl KFold {
         // initialise indices
         let mut indices: Vec<usize> = (0..n_samples).collect();

+        let mut rng = get_rng_impl(self.seed);
         if self.shuffle {
-            indices.shuffle(&mut thread_rng());
+            indices.shuffle(&mut rng);
         }
         // return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
         let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
@@ -66,6 +71,7 @@ impl Default for KFold {
         KFold {
             n_splits: 3,
             shuffle: true,
+            seed: None,
         }
     }
 }
@@ -81,6 +87,12 @@ impl KFold {
         self.shuffle = shuffle;
         self
     }
+
+    /// When shuffle is true, the seed affects the ordering of the indices.
+    pub fn with_seed(mut self, seed: Option<u64>) -> Self {
+        self.seed = seed;
+        self
+    }
 }

 /// An iterator over indices that split data into training and test set.
@@ -150,6 +162,7 @@ mod tests {
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
         let test_indices = k.test_indices(&x);
@@ -165,6 +178,7 @@ mod tests {
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
         let test_indices = k.test_indices(&x);
@@ -180,6 +194,7 @@ mod tests {
         let k = KFold {
             n_splits: 2,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
         let test_masks = k.test_masks(&x);
@@ -206,6 +221,7 @@ mod tests {
         let k = KFold {
             n_splits: 2,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
         let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
@@ -238,6 +254,7 @@ mod tests {
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
         let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
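A short sketch of the new builder in use. `with_seed` and the defaults come from the diff above; `DenseMatrix::rand` and `split` are used the same way in the tests, and the `BaseKFold` import path is assumed from the module layout:

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::model_selection::{BaseKFold, KFold};

fn main() {
    let x: DenseMatrix<f64> = DenseMatrix::rand(12, 3);
    // Defaults are n_splits: 3, shuffle: true; the seed pins the shuffle,
    // so the fold assignment is the same on every run.
    let k = KFold::default().with_seed(Some(7));
    for (train_idx, test_idx) in k.split(&x) {
        // Every sample lands in exactly one of the two index sets per fold.
        assert_eq!(train_idx.len() + test_idx.len(), 12);
    }
}
```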
+7 -4
@@ -41,7 +41,7 @@
 //!             0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
 //!         ];
 //!
-//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
+//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
 //!
 //! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
 //!             x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
@@ -107,8 +107,8 @@ use crate::error::Failed;
 use crate::linalg::BaseVector;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
-use rand::thread_rng;

 pub(crate) mod kfold;
@@ -130,11 +130,13 @@ pub trait BaseKFold {
 /// * `y` - target values, should be of size _N_
 /// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
 /// * `shuffle`, - whether or not to shuffle the data before splitting
+/// * `seed` - controls the shuffling applied to the data before the split. Pass an int for reproducible output across multiple function calls.
 pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
     x: &M,
     y: &M::RowVector,
     test_size: f32,
     shuffle: bool,
+    seed: Option<u64>,
 ) -> (M, M, M::RowVector, M::RowVector) {
     if x.shape().0 != y.len() {
         panic!(
@@ -143,6 +145,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
             y.len()
         );
     }
+    let mut rng = get_rng_impl(seed);

     if test_size <= 0. || test_size > 1.0 {
         panic!("test_size should be between 0 and 1");
@@ -159,7 +162,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
     let mut indices: Vec<usize> = (0..n).collect();

     if shuffle {
-        indices.shuffle(&mut thread_rng());
+        indices.shuffle(&mut rng);
     }

     let x_train = x.take(&indices[n_test..n], 0);
@@ -292,7 +295,7 @@ mod tests {
         let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
         let y = vec![0f64; n];
-        let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
+        let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
         assert!(
             x_train.shape().0 > (n as f64 * 0.65) as usize
+21
@@ -0,0 +1,21 @@
+use ::rand::SeedableRng;
+
+#[cfg(not(feature = "std"))]
+use rand::rngs::SmallRng as RngImpl;
+#[cfg(feature = "std")]
+use rand::rngs::StdRng as RngImpl;
+
+pub(crate) fn get_rng_impl(seed: Option<u64>) -> RngImpl {
+    match seed {
+        Some(seed) => RngImpl::seed_from_u64(seed),
+        None => {
+            cfg_if::cfg_if! {
+                if #[cfg(feature = "std")] {
+                    use rand::RngCore;
+                    RngImpl::seed_from_u64(rand::thread_rng().next_u64())
+                } else {
+                    panic!("seed number needed for non-std build");
+                }
+            }
+        }
+    }
+}
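The helper itself is crate-private, but its behavior is easy to pin down with plain `rand` 0.8. This standalone sketch mirrors, rather than calls, `get_rng_impl`:

```rust
use rand::rngs::StdRng;
use rand::{RngCore, SeedableRng};

fn main() {
    // get_rng_impl(Some(seed)): the same u64 yields the same stream.
    let mut a = StdRng::seed_from_u64(100);
    let mut b = StdRng::seed_from_u64(100);
    assert_eq!(a.next_u64(), b.next_u64());

    // get_rng_impl(None) on std builds: a fresh RNG seeded from
    // thread_rng() entropy, so each call starts an unrelated stream.
    let mut c = StdRng::seed_from_u64(rand::thread_rng().next_u64());
    let _ = c.next_u64();
    // On no-std builds (SmallRng), None instead panics: a seed is required.
}
```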
+17 -5
@@ -84,6 +84,7 @@ use crate::error::Failed;
 use crate::linalg::BaseVector;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 use crate::svm::{Kernel, Kernels, LinearKernel};

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -100,6 +101,8 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
     pub kernel: K,
     /// Unused parameter.
     m: PhantomData<M>,
+    /// Controls the pseudo random number generation used to shuffle the data during optimization
+    seed: Option<u64>,
 }

 /// SVC grid search parameters
@@ -279,8 +282,15 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
             tol: self.tol,
             kernel,
             m: PhantomData,
+            seed: self.seed,
         }
     }
+
+    /// Seed for the pseudo random number generator.
+    pub fn with_seed(mut self, seed: Option<u64>) -> Self {
+        self.seed = seed;
+        self
+    }
 }

 impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
@@ -291,6 +301,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel>
             tol: T::from_f64(1e-3).unwrap(),
             kernel: Kernels::linear(),
             m: PhantomData,
+            seed: None,
         }
     }
 }
@@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
         let good_enough = T::from_i32(1000).unwrap();
         for _ in 0..self.parameters.epoch {
-            for i in Self::permutate(n) {
+            for i in self.permutate(n) {
                 self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
                 loop {
                     self.reprocess(tol, &mut cache);
@@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
         let mut cp = 0;
         let mut cn = 0;
-        for i in Self::permutate(n) {
+        for i in self.permutate(n) {
             if self.y.get(i) == T::one() && cp < few {
                 if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
                     cp += 1;
@@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
         self.recalculate_minmax_grad = true;
     }

-    fn permutate(n: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn permutate(&self, n: usize) -> Vec<usize> {
+        let mut rng = get_rng_impl(self.parameters.seed);
         let mut range: Vec<usize> = (0..n).collect();
         range.shuffle(&mut rng);
         range
@@ -893,7 +904,8 @@ mod tests {
             &y,
             SVCParameters::default()
                 .with_c(200.0)
-                .with_kernel(Kernels::linear()),
+                .with_kernel(Kernels::linear())
+                .with_seed(Some(100)),
         )
         .and_then(|lr| lr.predict(&x))
         .unwrap();
+10 -12
@@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters {
     pub min_samples_leaf: usize,
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    /// Controls the randomness of the estimator
+    pub seed: Option<u64>,
 }

 /// Decision Tree
@@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters {
             max_depth: None,
             min_samples_leaf: 1,
             min_samples_split: 2,
+            seed: None,
         }
     }
 }
@@ -467,14 +471,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeClassifier::fit_weak_learner(
-            x,
-            y,
-            samples,
-            num_attributes,
-            parameters,
-            &mut rand::thread_rng(),
-        )
+        DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -483,7 +480,6 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeClassifierParameters,
-        rng: &mut impl Rng,
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
@@ -497,6 +493,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
             )));
         }
+        let mut rng = get_rng_impl(parameters.seed);

         let mut yi: Vec<usize> = vec![0; y_ncols];
         for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
@@ -531,13 +528,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
+        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
                 None => break,
             };
         }
@@ -874,7 +871,8 @@ mod tests {
                 criterion: SplitCriterion::Entropy,
                 max_depth: Some(3),
                 min_samples_leaf: 1,
-                min_samples_split: 2
+                min_samples_split: 2,
+                seed: None
             }
         )
         .unwrap()
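Since the parameters struct stays all-public with a `Default` impl, struct-update syntax is the least intrusive way to opt in. A sketch (import paths assume the 0.2-era layout; data is illustrative):

```rust
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::{
    DecisionTreeClassifier, DecisionTreeClassifierParameters,
};

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1., 1.], &[1., 2.], &[2., 1.], &[5., 5.], &[6., 5.], &[5., 6.],
    ]);
    let y = vec![0., 0., 0., 1., 1., 1.];
    // Only `seed` changes from the defaults. In a plain fit the rng mostly
    // breaks ties between candidate splits, but a fixed seed still makes
    // repeated fits deterministic.
    let params = DecisionTreeClassifierParameters {
        seed: Some(3),
        ..Default::default()
    };
    let tree = DecisionTreeClassifier::fit(&x, &y, params).unwrap();
    assert_eq!(tree.predict(&x).unwrap(), y);
}
```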
+10 -11
@@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters {
     pub min_samples_leaf: usize,
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    /// Controls the randomness of the estimator
+    pub seed: Option<u64>,
 }

 /// Regression Tree
@@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters {
             max_depth: None,
             min_samples_leaf: 1,
             min_samples_split: 2,
+            seed: None,
         }
     }
 }
@@ -357,14 +361,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(
-            x,
-            y,
-            samples,
-            num_attributes,
-            parameters,
-            &mut rand::thread_rng(),
-        )
+        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -373,7 +370,6 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
-        rng: &mut impl Rng,
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
@@ -381,6 +377,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         let (_, num_attributes) = x.shape();
         let mut nodes: Vec<Node<T>> = Vec::new();
+        let mut rng = get_rng_impl(parameters.seed);

         let mut n = 0;
         let mut sum = T::zero();
@@ -407,13 +404,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
+        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
                 None => break,
             };
         }
@@ -699,6 +696,7 @@ mod tests {
                 max_depth: Option::None,
                 min_samples_leaf: 2,
                 min_samples_split: 6,
+                seed: None,
             },
         )
         .and_then(|t| t.predict(&x))
@@ -719,6 +717,7 @@ mod tests {
                 max_depth: Option::None,
                 min_samples_leaf: 1,
                 min_samples_split: 3,
+                seed: None,
             },
         )
        .and_then(|t| t.predict(&x))