Allow setting seed for RandomForestClassifier and Regressor (#120)

* Seed for the classifier.

* Seed for the regressor.

* Forgot one.

* typo.
Author: Malte Londschien
Date: 2021-11-11 01:51:24 +01:00
Committed by: GitHub
Parent: 521dab49ef
Commit: 12c102d02b

5 changed files with 72 additions and 23 deletions
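
For context, a minimal sketch of how the seeded classifier might be used once this change lands. The module paths, the `DenseMatrix` type, and the `fit`/`predict` signatures are assumed from smartcore's layout at the time of this commit; the toy data is illustrative only.

use smartcore::ensemble::random_forest_classifier::{
    RandomForestClassifier, RandomForestClassifierParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    // Toy training data: four rows, two features, two classes.
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[2.0, 1.0],
        &[3.0, 4.0],
        &[4.0, 3.0],
    ]);
    let y = vec![0.0, 0.0, 1.0, 1.0];

    // The new builder method fixes the seed used for bootstrap sampling
    // and per-tree feature selection, making the fit reproducible.
    let params = RandomForestClassifierParameters::default().with_seed(42);
    let forest = RandomForestClassifier::fit(&x, &y, params).unwrap();
    let _predictions = forest.predict(&x).unwrap();
}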
+19 -5
@@ -45,10 +45,11 @@
 //!
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 use std::default::Default;
 use std::fmt::Debug;
-use rand::Rng;

 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -79,6 +80,8 @@ pub struct RandomForestClassifierParameters {
     pub m: Option<usize>,
     /// Whether to keep samples used for tree generation. This is required for OOB prediction.
     pub keep_samples: bool,
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
 }

 /// Random Forest Classifier
@@ -128,6 +131,12 @@ impl RandomForestClassifierParameters {
         self.keep_samples = keep_samples;
         self
     }
+
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
 }

 impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -160,6 +169,7 @@ impl Default for RandomForestClassifierParameters {
             n_trees: 100,
             m: Option::None,
             keep_samples: false,
+            seed: 0,
         }
     }
 }
@@ -211,6 +221,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
                 .unwrap()
         });
+        let mut rng = StdRng::seed_from_u64(parameters.seed);
         let classes = y_m.unique();
         let k = classes.len();
         let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
@@ -221,7 +232,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
         }

         for _ in 0..parameters.n_trees {
-            let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
+            let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
             if let Some(ref mut all_samples) = maybe_all_samples {
                 all_samples.push(samples.iter().map(|x| *x != 0).collect())
             }
@@ -232,7 +243,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
             };
-            let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
+            let tree =
+                DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
             trees.push(tree);
         }
@@ -304,8 +316,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
         which_max(&result)
     }

-    fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
         let class_weight = vec![1.; num_classes];
         let nrows = y.len();
         let mut samples = vec![0; nrows];
@@ -375,6 +386,7 @@ mod tests {
                 n_trees: 100,
                 m: Option::None,
                 keep_samples: false,
+                seed: 87,
             },
         )
         .unwrap();
@@ -422,9 +434,11 @@ mod tests {
                 n_trees: 100,
                 m: Option::None,
                 keep_samples: true,
+                seed: 87,
             },
         )
         .unwrap();
+
         assert!(
             accuracy(&y, &classifier.predict_oob(&x).unwrap())
                 < accuracy(&y, &classifier.predict(&x).unwrap())
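
The reproducibility rests on `StdRng::seed_from_u64`: two generators built from the same seed emit identical streams, so every bootstrap index and feature shuffle is replayed exactly. A standalone illustration using plain rand 0.8 (not code from this diff):

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

fn main() {
    let mut a = StdRng::seed_from_u64(87);
    let mut b = StdRng::seed_from_u64(87);
    // Identical seeds yield identical draws, so the bootstrap indices
    // each tree sees are fully determined by `parameters.seed`.
    let draws_a: Vec<usize> = (0..5).map(|_| a.gen_range(0..10)).collect();
    let draws_b: Vec<usize> = (0..5).map(|_| b.gen_range(0..10)).collect();
    assert_eq!(draws_a, draws_b);
}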
+18 -6
@@ -43,10 +43,11 @@
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
 use std::default::Default;
 use std::fmt::Debug;
-use rand::Rng;

 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -75,6 +76,8 @@ pub struct RandomForestRegressorParameters {
     pub m: Option<usize>,
     /// Whether to keep samples used for tree generation. This is required for OOB prediction.
     pub keep_samples: bool,
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
 }

 /// Random Forest Regressor
@@ -118,8 +121,13 @@ impl RandomForestRegressorParameters {
         self.keep_samples = keep_samples;
         self
     }
-}
+
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
+}

 impl Default for RandomForestRegressorParameters {
     fn default() -> Self {
         RandomForestRegressorParameters {
@@ -129,6 +137,7 @@ impl Default for RandomForestRegressorParameters {
             n_trees: 10,
             m: Option::None,
             keep_samples: false,
+            seed: 0,
         }
     }
 }
@@ -182,6 +191,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
             .m
             .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+        let mut rng = StdRng::seed_from_u64(parameters.seed);
         let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
         let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
@@ -190,7 +200,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
         }

         for _ in 0..parameters.n_trees {
-            let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
+            let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
             if let Some(ref mut all_samples) = maybe_all_samples {
                 all_samples.push(samples.iter().map(|x| *x != 0).collect())
             }
@@ -199,7 +209,8 @@ impl<T: RealNumber> RandomForestRegressor<T> {
                 min_samples_leaf: parameters.min_samples_leaf,
                 min_samples_split: parameters.min_samples_split,
             };
-            let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
+            let tree =
+                DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
             trees.push(tree);
         }
@@ -275,8 +286,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
         result / T::from(n_trees).unwrap()
     }

-    fn sample_with_replacement(nrows: usize) -> Vec<usize> {
-        let mut rng = rand::thread_rng();
+    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
         let mut samples = vec![0; nrows];
         for _ in 0..nrows {
             let xi = rng.gen_range(0..nrows);
@@ -328,6 +338,7 @@ mod tests {
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: false,
+                seed: 87,
             },
         )
         .and_then(|rf| rf.predict(&x))
@@ -372,6 +383,7 @@ mod tests {
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: true,
+                seed: 87,
             },
         )
         .unwrap();
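
The key refactor in both forests is the sampler: it now borrows the caller's generator instead of constructing `rand::thread_rng()` internally. A self-contained version of the regressor's sampler follows; the loop body past `gen_range` is truncated in the hunk above, so the `samples[xi] += 1` count is an assumption from context.

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

// Count how often each row index is drawn in an n-out-of-n bootstrap.
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
    let mut samples = vec![0; nrows];
    for _ in 0..nrows {
        let xi = rng.gen_range(0..nrows);
        samples[xi] += 1;
    }
    samples
}

fn main() {
    // Caller-supplied, seeded generators reproduce the same bootstrap.
    let mut rng1 = StdRng::seed_from_u64(0);
    let mut rng2 = StdRng::seed_from_u64(0);
    assert_eq!(
        sample_with_replacement(8, &mut rng1),
        sample_with_replacement(8, &mut rng2)
    );
}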
+1
@@ -1008,6 +1008,7 @@ mod tests {
                 n_trees: 1000,
                 m: Option::None,
                 keep_samples: false,
+                seed: 0,
             },
         )
         .unwrap()
+17 -6
@@ -68,6 +68,7 @@ use std::fmt::Debug;
 use std::marker::PhantomData;

 use rand::seq::SliceRandom;
+use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -328,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
+        DecisionTreeClassifier::fit_weak_learner(
+            x,
+            y,
+            samples,
+            num_attributes,
+            parameters,
+            &mut rand::thread_rng(),
+        )
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -337,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeClassifierParameters,
+        rng: &mut impl Rng,
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
@@ -384,13 +393,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry) {
+        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
                 None => break,
             };
         }
@@ -443,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         &mut self,
         visitor: &mut NodeVisitor<'_, T, M>,
         mtry: usize,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n_rows, n_attr) = visitor.x.shape();
@@ -482,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         let mut variables = (0..n_attr).collect::<Vec<_>>();
         if mtry < n_attr {
-            variables.shuffle(&mut rand::thread_rng());
+            variables.shuffle(rng);
         }

         for variable in variables.iter().take(mtry) {
@@ -566,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         mut visitor: NodeVisitor<'a, T, M>,
         mtry: usize,
         visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
@@ -614,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
             visitor.level + 1,
         );

-        if self.find_best_cutoff(&mut true_visitor, mtry) {
+        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
             visitor_queue.push_back(true_visitor);
         }
@@ -627,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
             visitor.level + 1,
         );

-        if self.find_best_cutoff(&mut false_visitor, mtry) {
+        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
             visitor_queue.push_back(false_visitor);
         }
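
Note the pattern: the public `fit` keeps its signature by defaulting to `rand::thread_rng()`, while the crate-internal `fit_weak_learner` path accepts any generator so the forest can thread its seeded `StdRng` through. The same pattern in miniature (the function names here are hypothetical, for illustration only):

use rand::seq::SliceRandom;
use rand::Rng;

// Public entry point: behaviour unchanged, a thread-local RNG is supplied.
fn pick_features(n_attr: usize, mtry: usize) -> Vec<usize> {
    pick_features_with(n_attr, mtry, &mut rand::thread_rng())
}

// Internal worker: accepts any RNG, so a caller can pass a seeded StdRng.
fn pick_features_with(n_attr: usize, mtry: usize, rng: &mut impl Rng) -> Vec<usize> {
    let mut variables = (0..n_attr).collect::<Vec<_>>();
    if mtry < n_attr {
        // Shuffle with the caller's RNG rather than a fresh thread_rng().
        variables.shuffle(rng);
    }
    variables.truncate(mtry);
    variables
}

fn main() {
    assert_eq!(pick_features(10, 3).len(), 3);
}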
+17 -6
@@ -63,6 +63,7 @@ use std::default::Default;
 use std::fmt::Debug;

 use rand::seq::SliceRandom;
+use rand::Rng;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
@@ -242,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+        DecisionTreeRegressor::fit_weak_learner(
+            x,
+            y,
+            samples,
+            num_attributes,
+            parameters,
+            &mut rand::thread_rng(),
+        )
     }

     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -251,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
+        rng: &mut impl Rng,
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
@@ -284,13 +293,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();

-        if tree.find_best_cutoff(&mut visitor, mtry) {
+        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
             visitor_queue.push_back(visitor);
         }

         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
                 None => break,
             };
         }
@@ -343,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         &mut self,
         visitor: &mut NodeVisitor<'_, T, M>,
         mtry: usize,
+        rng: &mut impl Rng,
     ) -> bool {
         let (_, n_attr) = visitor.x.shape();
@@ -357,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         let mut variables = (0..n_attr).collect::<Vec<_>>();
         if mtry < n_attr {
-            variables.shuffle(&mut rand::thread_rng());
+            variables.shuffle(rng);
         }

         let parent_gain =
@@ -432,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         mut visitor: NodeVisitor<'a, T, M>,
         mtry: usize,
         visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+        rng: &mut impl Rng,
     ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
             visitor.level + 1,
         );

-        if self.find_best_cutoff(&mut true_visitor, mtry) {
+        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
             visitor_queue.push_back(true_visitor);
         }
@@ -493,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
             visitor.level + 1,
         );

-        if self.find_best_cutoff(&mut false_visitor, mtry) {
+        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
             visitor_queue.push_back(false_visitor);
         }
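
Taken together, the five files make an end-to-end fit deterministic. A sketch of the property the new parameter is meant to guarantee, mirroring the regressor test above (module paths assumed from smartcore's layout at the time; data illustrative):

use smartcore::ensemble::random_forest_regressor::{
    RandomForestRegressor, RandomForestRegressorParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[2.0, 1.0],
        &[3.0, 4.0],
        &[4.0, 3.0],
        &[5.0, 6.0],
    ]);
    let y = vec![1.0, 1.5, 3.0, 3.5, 5.0];

    let fit_predict = |seed| {
        RandomForestRegressor::fit(
            &x,
            &y,
            RandomForestRegressorParameters::default().with_seed(seed),
        )
        .and_then(|rf| rf.predict(&x))
        .unwrap()
    };

    // Same seed, identical bootstraps and splits, identical predictions.
    assert_eq!(fit_predict(87), fit_predict(87));
}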