From 12c102d02b11413608bdeb231e2fbd145a3d7d6b Mon Sep 17 00:00:00 2001
From: Malte Londschien <61679398+mlondschien@users.noreply.github.com>
Date: Thu, 11 Nov 2021 01:51:24 +0100
Subject: [PATCH] Allow setting seed for `RandomForestClassifier` and
`Regressor` (#120)
* Seed for the classifier.
* Seed for the regressor.
* Forgot one.
* typo.
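For reference, a minimal usage sketch of the new builder method (the toy data,
seed value, and `DenseMatrix` helpers are illustrative, not part of this patch):

    use smartcore::ensemble::random_forest_classifier::{
        RandomForestClassifier, RandomForestClassifierParameters,
    };
    use smartcore::linalg::naive::dense_matrix::DenseMatrix;

    fn main() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5], &[4.9, 3.0], &[6.2, 2.9], &[5.9, 3.0],
        ]);
        let y = vec![0., 0., 1., 1.];
        // Same seed => same bootstrap samples and feature subsets => same forest.
        let a = RandomForestClassifier::fit(
            &x, &y, RandomForestClassifierParameters::default().with_seed(42),
        ).unwrap();
        let b = RandomForestClassifier::fit(
            &x, &y, RandomForestClassifierParameters::default().with_seed(42),
        ).unwrap();
        assert_eq!(a.predict(&x).unwrap(), b.predict(&x).unwrap());
    }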
---
src/ensemble/random_forest_classifier.rs | 24 +++++++++++++++++++-----
src/ensemble/random_forest_regressor.rs | 24 ++++++++++++++++++------
src/linalg/ndarray_bindings.rs | 1 +
src/tree/decision_tree_classifier.rs | 23 +++++++++++++++++------
src/tree/decision_tree_regressor.rs | 23 +++++++++++++++++------
5 files changed, 72 insertions(+), 23 deletions(-)
diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index f70604c..5cebced 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -45,10 +45,11 @@
//!
//!
//!
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
use std::default::Default;
use std::fmt::Debug;
-use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -79,6 +80,8 @@ pub struct RandomForestClassifierParameters {
pub m: Option<usize>,
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub seed: u64,
}
/// Random Forest Classifier
@@ -128,6 +131,12 @@ impl RandomForestClassifierParameters {
self.keep_samples = keep_samples;
self
}
+
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub fn with_seed(mut self, seed: u64) -> Self {
+ self.seed = seed;
+ self
+ }
}
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -160,6 +169,7 @@ impl Default for RandomForestClassifierParameters {
n_trees: 100,
m: Option::None,
keep_samples: false,
+ seed: 0,
}
}
}
@@ -211,6 +221,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
.unwrap()
});
+ let mut rng = StdRng::seed_from_u64(parameters.seed);
let classes = y_m.unique();
let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
@@ -221,7 +232,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
}
for _ in 0..parameters.n_trees {
- let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
+ let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
@@ -232,7 +243,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
};
- let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
+ let tree =
+ DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree);
}
@@ -304,8 +316,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result)
}
- fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
- let mut rng = rand::thread_rng();
+ fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
let class_weight = vec![1.; num_classes];
let nrows = y.len();
let mut samples = vec![0; nrows];
@@ -375,6 +386,7 @@ mod tests {
n_trees: 100,
m: Option::None,
keep_samples: false,
+ seed: 87,
},
)
.unwrap();
@@ -422,9 +434,11 @@ mod tests {
n_trees: 100,
m: Option::None,
keep_samples: true,
+ seed: 87,
},
)
.unwrap();
+
assert!(
accuracy(&y, &classifier.predict_oob(&x).unwrap())
< accuracy(&y, &classifier.predict(&x).unwrap())
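Design note on the classifier changes above: `fit` now seeds a single `StdRng`
once and threads it mutably through every bootstrap draw and every tree's
feature shuffle, so the fitted ensemble is a deterministic function of `seed`.
A self-contained sketch of that property, using only the rand crate
(0.8-style API; names are illustrative):

    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    // Draw `n` values from a freshly seeded StdRng.
    fn draws(seed: u64, n: usize) -> Vec<usize> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n).map(|_| rng.gen_range(0..100)).collect()
    }

    fn main() {
        // Equal seeds yield equal random streams, hence equal forests.
        assert_eq!(draws(87, 10), draws(87, 10));
        // Different seeds (almost surely) yield different streams.
        assert_ne!(draws(87, 10), draws(88, 10));
    }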
diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs
index 90ac479..c923cd8 100644
--- a/src/ensemble/random_forest_regressor.rs
+++ b/src/ensemble/random_forest_regressor.rs
@@ -43,10 +43,11 @@
//!
//!
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
use std::default::Default;
use std::fmt::Debug;
-use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -75,6 +76,8 @@ pub struct RandomForestRegressorParameters {
pub m: Option<usize>,
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub seed: u64,
}
/// Random Forest Regressor
@@ -118,8 +121,13 @@ impl RandomForestRegressorParameters {
self.keep_samples = keep_samples;
self
}
-}
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub fn with_seed(mut self, seed: u64) -> Self {
+ self.seed = seed;
+ self
+ }
+}
impl Default for RandomForestRegressorParameters {
fn default() -> Self {
RandomForestRegressorParameters {
@@ -129,6 +137,7 @@ impl Default for RandomForestRegressorParameters {
n_trees: 10,
m: Option::None,
keep_samples: false,
+ seed: 0,
}
}
}
@@ -182,6 +191,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+ let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
@@ -190,7 +200,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
}
for _ in 0..parameters.n_trees {
- let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
+ let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
@@ -199,7 +209,8 @@ impl<T: RealNumber> RandomForestRegressor<T> {
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
};
- let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
+ let tree =
+ DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree);
}
@@ -275,8 +286,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
result / T::from(n_trees).unwrap()
}
- fn sample_with_replacement(nrows: usize) -> Vec<usize> {
- let mut rng = rand::thread_rng();
+ fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
@@ -328,6 +338,7 @@ mod tests {
n_trees: 1000,
m: Option::None,
keep_samples: false,
+ seed: 87,
},
)
.and_then(|rf| rf.predict(&x))
@@ -372,6 +383,7 @@ mod tests {
n_trees: 1000,
m: Option::None,
keep_samples: true,
+ seed: 87,
},
)
.unwrap();
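The regressor mirrors the classifier, and both `sample_with_replacement`
variants return a per-row multiplicity vector rather than a list of indices:
`samples[i]` counts how often row `i` was drawn, and a zero marks an
out-of-bag row, which is what `keep_samples` and `predict_oob` rely on.
A runnable sketch of the same scheme (function name is illustrative):

    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    // Draw `nrows` rows with replacement; return how often each row was picked.
    fn bootstrap_counts(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
        let mut samples = vec![0; nrows];
        for _ in 0..nrows {
            samples[rng.gen_range(0..nrows)] += 1;
        }
        samples
    }

    fn main() {
        let mut rng = StdRng::seed_from_u64(0);
        let counts = bootstrap_counts(8, &mut rng);
        assert_eq!(counts.iter().sum::<usize>(), 8); // exactly nrows draws
        println!("{:?}", counts); // zeros are the out-of-bag rows
    }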
diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs
index 091aaaf..99e0918 100644
--- a/src/linalg/ndarray_bindings.rs
+++ b/src/linalg/ndarray_bindings.rs
@@ -1008,6 +1008,7 @@ mod tests {
n_trees: 1000,
m: Option::None,
keep_samples: false,
+ seed: 0,
},
)
.unwrap()
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 200fee5..751d5d1 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -68,6 +68,7 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
+use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -328,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
) -> Result<DecisionTreeClassifier<T>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows];
- DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
+ DecisionTreeClassifier::fit_weak_learner(
+ x,
+ y,
+ samples,
+ num_attributes,
+ parameters,
+ &mut rand::thread_rng(),
+ )
}
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -337,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeClassifierParameters,
+ rng: &mut impl Rng,
) -> Result<DecisionTreeClassifier<T>, Failed> {
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
@@ -384,13 +393,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
- if tree.find_best_cutoff(&mut visitor, mtry) {
+ if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor);
}
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() {
- Some(node) => tree.split(node, mtry, &mut visitor_queue),
+ Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break,
};
}
@@ -443,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
&mut self,
visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize,
+ rng: &mut impl Rng,
) -> bool {
let (n_rows, n_attr) = visitor.x.shape();
@@ -482,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
- variables.shuffle(&mut rand::thread_rng());
+ variables.shuffle(rng);
}
for variable in variables.iter().take(mtry) {
@@ -566,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+ rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
@@ -614,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1,
);
- if self.find_best_cutoff(&mut true_visitor, mtry) {
+ if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
@@ -627,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1,
);
- if self.find_best_cutoff(&mut false_visitor, mtry) {
+ if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}
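Note the split of responsibilities in the tree code: the public `fit` keeps
its signature and supplies `&mut rand::thread_rng()`, while the crate-internal
`fit_weak_learner` now takes `rng: &mut impl Rng` so the forest can share one
seeded generator across all trees. Taking `&mut impl Rng` instead of owning an
RNG is what lets `variables.shuffle(rng)` consume randomness from the caller's
stream. A trimmed-down sketch of the pattern (names are illustrative):

    use rand::rngs::StdRng;
    use rand::seq::SliceRandom;
    use rand::{Rng, SeedableRng};

    // Shuffle candidate feature indices with whatever RNG the caller provides.
    fn shuffled_features(n_attr: usize, rng: &mut impl Rng) -> Vec<usize> {
        let mut variables = (0..n_attr).collect::<Vec<_>>();
        variables.shuffle(rng);
        variables
    }

    fn main() {
        let mut seeded = StdRng::seed_from_u64(87);
        println!("{:?}", shuffled_features(5, &mut seeded)); // reproducible
        println!("{:?}", shuffled_features(5, &mut rand::thread_rng())); // not
    }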
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index 6a0705f..34f58a9 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -63,6 +63,7 @@ use std::default::Default;
use std::fmt::Debug;
use rand::seq::SliceRandom;
+use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -242,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
) -> Result<DecisionTreeRegressor<T>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows];
- DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+ DecisionTreeRegressor::fit_weak_learner(
+ x,
+ y,
+ samples,
+ num_attributes,
+ parameters,
+ &mut rand::thread_rng(),
+ )
}
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -251,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
+ rng: &mut impl Rng,
) -> Result<DecisionTreeRegressor<T>, Failed> {
let y_m = M::from_row_vector(y.clone());
@@ -284,13 +293,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
- if tree.find_best_cutoff(&mut visitor, mtry) {
+ if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor);
}
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() {
- Some(node) => tree.split(node, mtry, &mut visitor_queue),
+ Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break,
};
}
@@ -343,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
&mut self,
visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize,
+ rng: &mut impl Rng,
) -> bool {
let (_, n_attr) = visitor.x.shape();
@@ -357,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
- variables.shuffle(&mut rand::thread_rng());
+ variables.shuffle(rng);
}
let parent_gain =
@@ -432,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+ rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1,
);
- if self.find_best_cutoff(&mut true_visitor, mtry) {
+ if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
@@ -493,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1,
);
- if self.find_best_cutoff(&mut false_visitor, mtry) {
+ if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}
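With the above in place, two fits with equal parameters, including the seed,
produce equal models, and the existing `PartialEq` impls make this directly
testable. A hedged test sketch (the data, seed, and import paths are
illustrative, not part of this patch):

    #[test]
    fn same_seed_same_forest() {
        use smartcore::ensemble::random_forest_classifier::{
            RandomForestClassifier, RandomForestClassifierParameters,
        };
        use smartcore::linalg::naive::dense_matrix::DenseMatrix;

        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.], &[2., 1.], &[3., 4.], &[4., 3.],
            &[5., 6.], &[6., 5.], &[7., 8.], &[8., 7.],
        ]);
        let y = vec![0., 0., 0., 0., 1., 1., 1., 1.];
        let fit = || {
            RandomForestClassifier::fit(
                &x, &y,
                RandomForestClassifierParameters::default().with_seed(87),
            )
            .unwrap()
        };
        assert!(fit() == fit()); // PartialEq over the fitted trees
    }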