diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index f70604c..5cebced 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -45,10 +45,11 @@
 //!
 //!
 //!
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 use std::default::Default;
 use std::fmt::Debug;

-use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -79,6 +80,8 @@ pub struct RandomForestClassifierParameters {
     pub m: Option<usize>,
     /// Whether to keep samples used for tree generation. This is required for OOB prediction.
     pub keep_samples: bool,
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
 }

 /// Random Forest Classifier
@@ -128,6 +131,12 @@ impl RandomForestClassifierParameters {
         self.keep_samples = keep_samples;
         self
     }
+
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
 }

 impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -160,6 +169,7 @@ impl Default for RandomForestClassifierParameters {
             n_trees: 100,
             m: Option::None,
             keep_samples: false,
+            seed: 0,
         }
     }
 }
@@ -211,6 +221,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
                 .unwrap()
         });

+        let mut rng = StdRng::seed_from_u64(parameters.seed);
         let classes = y_m.unique();
         let k = classes.len();
         let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
@@ -221,7 +232,7 @@
         }

         for _ in 0..parameters.n_trees {
-            let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
+            let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
             if let Some(ref mut all_samples) = maybe_all_samples {
                 all_samples.push(samples.iter().map(|x| *x != 0).collect())
             }
@@ -232,7 +243,8 @@
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
             };
-            let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
+            let tree =
+                DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
             trees.push(tree);
         }

@@ -304,8 +316,7 @@
         which_max(&result)
     }

-    fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
         let class_weight = vec![1.; num_classes];
         let nrows = y.len();
         let mut samples = vec![0; nrows];
@@ -375,6 +386,7 @@ mod tests {
                     n_trees: 100,
                     m: Option::None,
                     keep_samples: false,
+                    seed: 87,
                 },
             )
             .unwrap();
@@ -422,9 +434,11 @@
                     n_trees: 100,
                     m: Option::None,
                     keep_samples: true,
+                    seed: 87,
                 },
             )
             .unwrap();
+
         assert!(
             accuracy(&y, &classifier.predict_oob(&x).unwrap())
                 < accuracy(&y, &classifier.predict(&x).unwrap())
diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs
index 90ac479..c923cd8 100644
--- a/src/ensemble/random_forest_regressor.rs
+++ b/src/ensemble/random_forest_regressor.rs
@@ -43,10 +43,11 @@
 //!
 //!
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 use std::default::Default;
 use std::fmt::Debug;

-use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -75,6 +76,8 @@ pub struct RandomForestRegressorParameters {
     pub m: Option<usize>,
     /// Whether to keep samples used for tree generation. This is required for OOB prediction.
     pub keep_samples: bool,
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
 }

 /// Random Forest Regressor
@@ -118,8 +121,13 @@ impl RandomForestRegressorParameters {
         self.keep_samples = keep_samples;
         self
     }
-}
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
+}

 impl Default for RandomForestRegressorParameters {
     fn default() -> Self {
         RandomForestRegressorParameters {
@@ -129,6 +137,7 @@
             n_trees: 10,
             m: Option::None,
             keep_samples: false,
+            seed: 0,
         }
     }
 }
@@ -182,6 +191,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
             .m
             .unwrap_or((num_attributes as f64).sqrt().floor() as usize);

+        let mut rng = StdRng::seed_from_u64(parameters.seed);
         let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
         let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
@@ -190,7 +200,7 @@
         }

         for _ in 0..parameters.n_trees {
-            let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
+            let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
             if let Some(ref mut all_samples) = maybe_all_samples {
                 all_samples.push(samples.iter().map(|x| *x != 0).collect())
             }
@@ -199,7 +209,8 @@
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
             };
-            let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
+            let tree =
+                DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
             trees.push(tree);
         }

@@ -275,8 +286,7 @@
         result / T::from(n_trees).unwrap()
     }

-    fn sample_with_replacement(nrows: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
         let mut samples = vec![0; nrows];
         for _ in 0..nrows {
             let xi = rng.gen_range(0..nrows);
@@ -328,6 +338,7 @@ mod tests {
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: false,
+                seed: 87,
             },
         )
         .and_then(|rf| rf.predict(&x))
@@ -372,6 +383,7 @@
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: true,
+                seed: 87,
             },
         )
         .unwrap();
diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs
index 091aaaf..99e0918 100644
--- a/src/linalg/ndarray_bindings.rs
+++ b/src/linalg/ndarray_bindings.rs
@@ -1008,6 +1008,7 @@ mod tests {
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: false,
+                seed: 0,
             },
         )
         .unwrap()
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 200fee5..751d5d1 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -68,6 +68,7 @@
 use std::fmt::Debug;
 use std::marker::PhantomData;

 use rand::seq::SliceRandom;
+use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -328,7 +329,14 @@
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
+        DecisionTreeClassifier::fit_weak_learner(
+            x,
+            y,
+            samples,
+            num_attributes,
+            parameters,
+            &mut rand::thread_rng(),
+        )
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -337,6 +345,7 @@
         x: &M,
         y: &M::RowVector,
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeClassifierParameters,
+        rng: &mut impl Rng,
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
@@ -384,13 +393,13 @@
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry) {
+        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
                 None => break,
             };
         }

@@ -443,6 +452,7 @@
         &mut self,
         visitor: &mut NodeVisitor<'_, T, M>,
         mtry: usize,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n_rows, n_attr) = visitor.x.shape();

@@ -482,7 +492,7 @@
         let mut variables = (0..n_attr).collect::<Vec<usize>>();

         if mtry < n_attr {
-            variables.shuffle(&mut rand::thread_rng());
+            variables.shuffle(rng);
         }

         for variable in variables.iter().take(mtry) {
@@ -566,6 +576,7 @@
         mut visitor: NodeVisitor<'a, T, M>,
         mtry: usize,
         visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
@@ -614,7 +625,7 @@
                 visitor.level + 1,
             );

-            if self.find_best_cutoff(&mut true_visitor, mtry) {
+            if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
                 visitor_queue.push_back(true_visitor);
             }

@@ -627,7 +638,7 @@
                 visitor.level + 1,
             );

-            if self.find_best_cutoff(&mut false_visitor, mtry) {
+            if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
                 visitor_queue.push_back(false_visitor);
             }

diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index 6a0705f..34f58a9 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -63,6 +63,7 @@
 use std::default::Default;
 use std::fmt::Debug;

 use rand::seq::SliceRandom;
+use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -242,7 +243,14 @@
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+        DecisionTreeRegressor::fit_weak_learner(
+            x,
+            y,
+            samples,
+            num_attributes,
+            parameters,
+            &mut rand::thread_rng(),
+        )
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -251,6 +259,7 @@
         x: &M,
         y: &M::RowVector,
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
+        rng: &mut impl Rng,
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
@@ -284,13 +293,13 @@
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry) {
+        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
                 None => break,
             };
         }

@@ -343,6 +352,7 @@
         &mut self,
         visitor: &mut NodeVisitor<'_, T, M>,
         mtry: usize,
+        rng: &mut impl Rng,
     ) -> bool {
         let (_, n_attr) = visitor.x.shape();

@@ -357,7 +367,7 @@
         let mut variables = (0..n_attr).collect::<Vec<usize>>();

         if mtry < n_attr {
-            variables.shuffle(&mut rand::thread_rng());
+            variables.shuffle(rng);
         }

         let parent_gain =
@@ -432,6 +442,7 @@
         mut visitor: NodeVisitor<'a, T, M>,
         mtry: usize,
         visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
                 visitor.level + 1,
             );

-            if self.find_best_cutoff(&mut true_visitor, mtry) {
+            if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
                 visitor_queue.push_back(true_visitor);
             }

@@ -493,7 +504,7 @@
                 visitor.level + 1,
             );

-            if self.find_best_cutoff(&mut false_visitor, mtry) {
+            if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
                 visitor_queue.push_back(false_visitor);
             }
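For reference, here is a minimal sketch (not part of the diff) of how the new `with_seed` builder could be exercised. It assumes the smartcore module layout these files belong to, with `DenseMatrix` available from `smartcore::linalg::naive::dense_matrix`. Because bootstrap sampling and feature shuffling now draw from `StdRng::seed_from_u64(parameters.seed)` instead of `rand::thread_rng()`, two fits with the same seed and data should produce forests that predict identically:

```rust
use smartcore::ensemble::random_forest_classifier::{
    RandomForestClassifier, RandomForestClassifierParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    // Toy data: two rows per class.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
    ]);
    let y = vec![0., 0., 1., 1.];

    // Fit twice with the same seed; all randomness flows from the seeded StdRng.
    let a = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters::default().with_seed(87),
    )
    .unwrap();
    let b = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters::default().with_seed(87),
    )
    .unwrap();

    // Same seed, same data: the two forests should make identical predictions.
    assert_eq!(a.predict(&x).unwrap(), b.predict(&x).unwrap());
}
```

Note that the plain `DecisionTreeClassifier::fit` / `DecisionTreeRegressor::fit` entry points keep their old behavior: they forward to `fit_weak_learner` with `&mut rand::thread_rng()`, so only the forest APIs become seedable through this change.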