From 3a44161406b7f7b80226f4e576afd56c2c899ffe Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:35:22 -0400 Subject: [PATCH] Lmm/add seeds in more algorithms (#164) * Provide better output in flaky tests * feat: add seed parameter to multiple algorithms * Update changelog Co-authored-by: Luis Moreno --- .github/workflows/ci.yml | 23 ++++++++++++----------- CHANGELOG.md | 9 +++++++++ Cargo.toml | 10 +++++++--- src/cluster/kmeans.rs | 13 +++++++++---- src/ensemble/random_forest_classifier.rs | 11 ++++++----- src/ensemble/random_forest_regressor.rs | 11 ++++++----- src/lib.rs | 2 ++ src/math/num.rs | 6 ++++-- src/model_selection/kfold.rs | 21 +++++++++++++++++++-- src/model_selection/mod.rs | 11 +++++++---- src/rand.rs | 21 +++++++++++++++++++++ src/svm/svc.rs | 22 +++++++++++++++++----- src/tree/decision_tree_classifier.rs | 22 ++++++++++------------ src/tree/decision_tree_regressor.rs | 21 ++++++++++----------- 14 files changed, 139 insertions(+), 64 deletions(-) create mode 100644 src/rand.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5041117..82d0eab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,23 +2,24 @@ name: CI on: push: - branches: [ main, development ] + branches: [main, development] pull_request: - branches: [ development ] + branches: [development] jobs: tests: runs-on: "${{ matrix.platform.os }}-latest" strategy: matrix: - platform: [ - { os: "windows", target: "x86_64-pc-windows-msvc" }, - { os: "windows", target: "i686-pc-windows-msvc" }, - { os: "ubuntu", target: "x86_64-unknown-linux-gnu" }, - { os: "ubuntu", target: "i686-unknown-linux-gnu" }, - { os: "ubuntu", target: "wasm32-unknown-unknown" }, - { os: "macos", target: "aarch64-apple-darwin" }, - ] + platform: + [ + { os: "windows", target: "x86_64-pc-windows-msvc" }, + { os: "windows", target: "i686-pc-windows-msvc" }, + { os: "ubuntu", target: "x86_64-unknown-linux-gnu" }, + { os: "ubuntu", target: "i686-unknown-linux-gnu" }, + { os: "ubuntu", target: "wasm32-unknown-unknown" }, + { os: "macos", target: "aarch64-apple-darwin" }, + ] env: TZ: "/usr/share/zoneinfo/your/location" steps: @@ -40,7 +41,7 @@ jobs: default: true - name: Install test runner for wasm if: matrix.platform.target == 'wasm32-unknown-unknown' - run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh - name: Stable Build uses: actions-rs/cargo@v1 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index ade6825..79e77e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## Added +- Seeds to multiple algorithims that depend on random number generation. +- Added feature `js` to use WASM in browser + +## BREAKING CHANGE +- Added a new parameter to `train_test_split` to define the seed. + +## [0.2.1] - 2022-05-10 + ## Added - L2 regularization penalty to the Logistic Regression - Getters for the naive bayes structs diff --git a/Cargo.toml b/Cargo.toml index a0ad984..51b9887 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,21 +16,25 @@ categories = ["science"] default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] -datasets = ["rand_distr"] +datasets = ["rand_distr", "std"] fp_bench = ["itertools"] +std = ["rand/std", "rand/std_rng"] +# wasm32 only +js = ["getrandom/js"] [dependencies] ndarray = { version = "0.15", optional = true } nalgebra = { version = "0.31", optional = true } num-traits = "0.2" num = "0.4" -rand = "0.8" +rand = { version = "0.8", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } itertools = { version = "0.10.3", optional = true } +cfg-if = "1.0.0" [target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2", features = ["js"] } +getrandom = { version = "0.2", optional = true } [dev-dependencies] smartcore = { path = ".", features = ["fp_bench"] } diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 8ecbb2e..fee1425 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -52,10 +52,10 @@ //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/) //! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf) -use rand::Rng; use std::fmt::Debug; use std::iter::Sum; +use ::rand::Rng; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -65,6 +65,7 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::math::distance::euclidian::*; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; /// K-Means clustering algorithm #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -108,6 +109,9 @@ pub struct KMeansParameters { pub k: usize, /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: usize, + /// Determines random number generation for centroid initialization. + /// Use an int to make the randomness deterministic + pub seed: Option, } impl KMeansParameters { @@ -128,6 +132,7 @@ impl Default for KMeansParameters { KMeansParameters { k: 2, max_iter: 100, + seed: None, } } } @@ -238,7 +243,7 @@ impl KMeans { let (n, d) = data.shape(); let mut distortion = T::max_value(); - let mut y = KMeans::kmeans_plus_plus(data, parameters.k); + let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed); let mut size = vec![0; parameters.k]; let mut centroids = vec![vec![T::zero(); d]; parameters.k]; @@ -311,8 +316,8 @@ impl KMeans { Ok(result.to_row_vector()) } - fn kmeans_plus_plus>(data: &M, k: usize) -> Vec { - let mut rng = rand::thread_rng(); + fn kmeans_plus_plus>(data: &M, k: usize, seed: Option) -> Vec { + let mut rng = get_rng_impl(seed); let (n, m) = data.shape(); let mut y = vec![0; n]; let mut centroid = data.get_row_as_vec(rng.gen_range(0..n)); diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index a4d6e75..331dab7 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -45,8 +45,8 @@ //! //! //! -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::Rng; + use std::default::Default; use std::fmt::Debug; @@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::tree::decision_tree_classifier::{ which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, }; @@ -441,7 +442,7 @@ impl RandomForestClassifier { .unwrap() }); - let mut rng = StdRng::seed_from_u64(parameters.seed); + let mut rng = get_rng_impl(Some(parameters.seed)); let classes = y_m.unique(); let k = classes.len(); let mut trees: Vec> = Vec::new(); @@ -462,9 +463,9 @@ impl RandomForestClassifier { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, + seed: Some(parameters.seed), }; - let tree = - DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; + let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?; trees.push(tree); } diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index ec78137..1270685 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -43,8 +43,8 @@ //! //! -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::Rng; + use std::default::Default; use std::fmt::Debug; @@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::tree::decision_tree_regressor::{ DecisionTreeRegressor, DecisionTreeRegressorParameters, }; @@ -376,7 +377,7 @@ impl RandomForestRegressor { .m .unwrap_or((num_attributes as f64).sqrt().floor() as usize); - let mut rng = StdRng::seed_from_u64(parameters.seed); + let mut rng = get_rng_impl(Some(parameters.seed)); let mut trees: Vec> = Vec::new(); let mut maybe_all_samples: Option>> = Option::None; @@ -393,9 +394,9 @@ impl RandomForestRegressor { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, + seed: Some(parameters.seed), }; - let tree = - DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; + let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; trees.push(tree); } diff --git a/src/lib.rs b/src/lib.rs index e9e1c3d..b46ee10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,3 +101,5 @@ pub mod readers; pub mod svm; /// Supervised tree-based learning methods pub mod tree; + +pub(crate) mod rand; diff --git a/src/math/num.rs b/src/math/num.rs index 433ad28..1ec20fb 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -9,6 +9,8 @@ use std::iter::{Product, Sum}; use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; use std::str::FromStr; +use crate::rand::get_rng_impl; + /// Defines real number /// pub trait RealNumber: @@ -79,7 +81,7 @@ impl RealNumber for f64 { } fn rand() -> f64 { - let mut rng = rand::thread_rng(); + let mut rng = get_rng_impl(None); rng.gen() } @@ -124,7 +126,7 @@ impl RealNumber for f32 { } fn rand() -> f32 { - let mut rng = rand::thread_rng(); + let mut rng = get_rng_impl(None); rng.gen() } diff --git a/src/model_selection/kfold.rs b/src/model_selection/kfold.rs index 8706954..ef48b87 100644 --- a/src/model_selection/kfold.rs +++ b/src/model_selection/kfold.rs @@ -5,8 +5,8 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::model_selection::BaseKFold; +use crate::rand::get_rng_impl; use rand::seq::SliceRandom; -use rand::thread_rng; /// K-Folds cross-validator pub struct KFold { @@ -14,6 +14,9 @@ pub struct KFold { pub n_splits: usize, // cannot exceed std::usize::MAX /// Whether to shuffle the data before splitting into batches pub shuffle: bool, + /// When shuffle is True, seed affects the ordering of the indices. + /// Which controls the randomness of each fold + pub seed: Option, } impl KFold { @@ -23,8 +26,10 @@ impl KFold { // initialise indices let mut indices: Vec = (0..n_samples).collect(); + let mut rng = get_rng_impl(self.seed); + if self.shuffle { - indices.shuffle(&mut thread_rng()); + indices.shuffle(&mut rng); } // return a new array of given shape n_split, filled with each element of n_samples divided by n_splits. let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits]; @@ -66,6 +71,7 @@ impl Default for KFold { KFold { n_splits: 3, shuffle: true, + seed: None, } } } @@ -81,6 +87,12 @@ impl KFold { self.shuffle = shuffle; self } + + /// When shuffle is True, random_state affects the ordering of the indices. + pub fn with_seed(mut self, seed: Option) -> Self { + self.seed = seed; + self + } } /// An iterator over indices that split data into training and test set. @@ -150,6 +162,7 @@ mod tests { let k = KFold { n_splits: 3, shuffle: false, + seed: None, }; let x: DenseMatrix = DenseMatrix::rand(33, 100); let test_indices = k.test_indices(&x); @@ -165,6 +178,7 @@ mod tests { let k = KFold { n_splits: 3, shuffle: false, + seed: None, }; let x: DenseMatrix = DenseMatrix::rand(34, 100); let test_indices = k.test_indices(&x); @@ -180,6 +194,7 @@ mod tests { let k = KFold { n_splits: 2, shuffle: false, + seed: None, }; let x: DenseMatrix = DenseMatrix::rand(22, 100); let test_masks = k.test_masks(&x); @@ -206,6 +221,7 @@ mod tests { let k = KFold { n_splits: 2, shuffle: false, + seed: None, }; let x: DenseMatrix = DenseMatrix::rand(22, 100); let train_test_splits: Vec<(Vec, Vec)> = k.split(&x).collect(); @@ -238,6 +254,7 @@ mod tests { let k = KFold { n_splits: 3, shuffle: false, + seed: None, }; let x: DenseMatrix = DenseMatrix::rand(10, 4); let expected: Vec<(Vec, Vec)> = vec![ diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 6f737d6..21cf7ed 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -41,7 +41,7 @@ //! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., //! ]; //! -//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true); +//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None); //! //! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}", //! x_train.shape(), y_train.len(), x_test.shape(), y_test.len()); @@ -107,8 +107,8 @@ use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use rand::seq::SliceRandom; -use rand::thread_rng; pub(crate) mod kfold; @@ -130,11 +130,13 @@ pub trait BaseKFold { /// * `y` - target values, should be of size _N_ /// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split. /// * `shuffle`, - whether or not to shuffle the data before splitting +/// * `seed` - Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls pub fn train_test_split>( x: &M, y: &M::RowVector, test_size: f32, shuffle: bool, + seed: Option, ) -> (M, M, M::RowVector, M::RowVector) { if x.shape().0 != y.len() { panic!( @@ -143,6 +145,7 @@ pub fn train_test_split>( y.len() ); } + let mut rng = get_rng_impl(seed); if test_size <= 0. || test_size > 1.0 { panic!("test_size should be between 0 and 1"); @@ -159,7 +162,7 @@ pub fn train_test_split>( let mut indices: Vec = (0..n).collect(); if shuffle { - indices.shuffle(&mut thread_rng()); + indices.shuffle(&mut rng); } let x_train = x.take(&indices[n_test..n], 0); @@ -292,7 +295,7 @@ mod tests { let x: DenseMatrix = DenseMatrix::rand(n, 3); let y = vec![0f64; n]; - let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true); + let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None); assert!( x_train.shape().0 > (n as f64 * 0.65) as usize diff --git a/src/rand.rs b/src/rand.rs new file mode 100644 index 0000000..d90e9c9 --- /dev/null +++ b/src/rand.rs @@ -0,0 +1,21 @@ +use ::rand::SeedableRng; +#[cfg(not(feature = "std"))] +use rand::rngs::SmallRng as RngImpl; +#[cfg(feature = "std")] +use rand::rngs::StdRng as RngImpl; + +pub(crate) fn get_rng_impl(seed: Option) -> RngImpl { + match seed { + Some(seed) => RngImpl::seed_from_u64(seed), + None => { + cfg_if::cfg_if! { + if #[cfg(feature = "std")] { + use rand::RngCore; + RngImpl::seed_from_u64(rand::thread_rng().next_u64()) + } else { + panic!("seed number needed for non-std build"); + } + } + } + } +} diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 46b0b68..94c6d9e 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -84,6 +84,7 @@ use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::svm::{Kernel, Kernels, LinearKernel}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -100,6 +101,8 @@ pub struct SVCParameters, K: Kernel pub kernel: K, /// Unused parameter. m: PhantomData, + /// Controls the pseudo random number generation for shuffling the data for probability estimates + seed: Option, } /// SVC grid search parameters @@ -279,8 +282,15 @@ impl, K: Kernel> SVCParameters) -> Self { + self.seed = seed; + self + } } impl> Default for SVCParameters { @@ -291,6 +301,7 @@ impl> Default for SVCParameters tol: T::from_f64(1e-3).unwrap(), kernel: Kernels::linear(), m: PhantomData, + seed: None, } } } @@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, let good_enough = T::from_i32(1000).unwrap(); for _ in 0..self.parameters.epoch { - for i in Self::permutate(n) { + for i in self.permutate(n) { self.process(i, self.x.get_row(i), self.y.get(i), &mut cache); loop { self.reprocess(tol, &mut cache); @@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, let mut cp = 0; let mut cn = 0; - for i in Self::permutate(n) { + for i in self.permutate(n) { if self.y.get(i) == T::one() && cp < few { if self.process(i, self.x.get_row(i), self.y.get(i), cache) { cp += 1; @@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, self.recalculate_minmax_grad = true; } - fn permutate(n: usize) -> Vec { - let mut rng = rand::thread_rng(); + fn permutate(&self, n: usize) -> Vec { + let mut rng = get_rng_impl(self.parameters.seed); let mut range: Vec = (0..n).collect(); range.shuffle(&mut rng); range @@ -893,7 +904,8 @@ mod tests { &y, SVCParameters::default() .with_c(200.0) - .with_kernel(Kernels::linear()), + .with_kernel(Kernels::linear()) + .with_seed(Some(100)), ) .and_then(|lr| lr.predict(&x)) .unwrap(); diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index a1699af..a14c104 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters { pub min_samples_leaf: usize, /// The minimum number of samples required to split an internal node. pub min_samples_split: usize, + /// Controls the randomness of the estimator + pub seed: Option, } /// Decision Tree @@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters { max_depth: None, min_samples_leaf: 1, min_samples_split: 2, + seed: None, } } } @@ -467,14 +471,7 @@ impl DecisionTreeClassifier { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeClassifier::fit_weak_learner( - x, - y, - samples, - num_attributes, - parameters, - &mut rand::thread_rng(), - ) + DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } pub(crate) fn fit_weak_learner>( @@ -483,7 +480,6 @@ impl DecisionTreeClassifier { samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters, - rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); let (_, y_ncols) = y_m.shape(); @@ -497,6 +493,7 @@ impl DecisionTreeClassifier { ))); } + let mut rng = get_rng_impl(parameters.seed); let mut yi: Vec = vec![0; y_ncols]; for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) { @@ -531,13 +528,13 @@ impl DecisionTreeClassifier { let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, rng) { + if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), + Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } @@ -874,7 +871,8 @@ mod tests { criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, - min_samples_split: 2 + min_samples_split: 2, + seed: None } ) .unwrap() diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index f48de33..7d88c40 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters { pub min_samples_leaf: usize, /// The minimum number of samples required to split an internal node. pub min_samples_split: usize, + /// Controls the randomness of the estimator + pub seed: Option, } /// Regression Tree @@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters { max_depth: None, min_samples_leaf: 1, min_samples_split: 2, + seed: None, } } } @@ -357,14 +361,7 @@ impl DecisionTreeRegressor { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeRegressor::fit_weak_learner( - x, - y, - samples, - num_attributes, - parameters, - &mut rand::thread_rng(), - ) + DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) } pub(crate) fn fit_weak_learner>( @@ -373,7 +370,6 @@ impl DecisionTreeRegressor { samples: Vec, mtry: usize, parameters: DecisionTreeRegressorParameters, - rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); @@ -381,6 +377,7 @@ impl DecisionTreeRegressor { let (_, num_attributes) = x.shape(); let mut nodes: Vec> = Vec::new(); + let mut rng = get_rng_impl(parameters.seed); let mut n = 0; let mut sum = T::zero(); @@ -407,13 +404,13 @@ impl DecisionTreeRegressor { let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, rng) { + if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), + Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } @@ -699,6 +696,7 @@ mod tests { max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6, + seed: None, }, ) .and_then(|t| t.predict(&x)) @@ -719,6 +717,7 @@ mod tests { max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3, + seed: None, }, ) .and_then(|t| t.predict(&x))