Allow setting seed for RandomForestClassifier and Regressor (#120)
* Seed for the classifier. * Seed for the regressor. * Forgot one. * typo.
This commit is contained in:
@@ -45,10 +45,11 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -79,6 +80,8 @@ pub struct RandomForestClassifierParameters {
|
||||
pub m: Option<usize>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Random Forest Classifier
|
||||
@@ -128,6 +131,12 @@ impl RandomForestClassifierParameters {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||
@@ -160,6 +169,7 @@ impl Default for RandomForestClassifierParameters {
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -211,6 +221,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
@@ -221,7 +232,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
@@ -232,7 +243,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
};
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
let tree =
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
@@ -304,8 +316,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
let mut samples = vec![0; nrows];
|
||||
@@ -375,6 +386,7 @@ mod tests {
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@@ -422,9 +434,11 @@ mod tests {
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: true,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
accuracy(&y, &classifier.predict_oob(&x).unwrap())
|
||||
< accuracy(&y, &classifier.predict(&x).unwrap())
|
||||
|
||||
@@ -43,10 +43,11 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -75,6 +76,8 @@ pub struct RandomForestRegressorParameters {
|
||||
pub m: Option<usize>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Random Forest Regressor
|
||||
@@ -118,8 +121,13 @@ impl RandomForestRegressorParameters {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Default for RandomForestRegressorParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestRegressorParameters {
|
||||
@@ -129,6 +137,7 @@ impl Default for RandomForestRegressorParameters {
|
||||
n_trees: 10,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -182,6 +191,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
@@ -190,7 +200,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
|
||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
@@ -199,7 +209,8 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
};
|
||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
let tree =
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
@@ -275,8 +286,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
result / T::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let mut samples = vec![0; nrows];
|
||||
for _ in 0..nrows {
|
||||
let xi = rng.gen_range(0..nrows);
|
||||
@@ -328,6 +338,7 @@ mod tests {
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.and_then(|rf| rf.predict(&x))
|
||||
@@ -372,6 +383,7 @@ mod tests {
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: true,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1008,6 +1008,7 @@ mod tests {
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
|
||||
@@ -68,6 +68,7 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -328,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeClassifier::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -337,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
@@ -384,13 +393,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -443,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n_rows, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -482,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
for variable in variables.iter().take(mtry) {
|
||||
@@ -566,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
@@ -614,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -627,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -242,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeRegressor::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -251,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -284,13 +293,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -343,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (_, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -357,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
let parent_gain =
|
||||
@@ -432,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -493,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user