Allow setting seed for RandomForestClassifier and Regressor (#120)

* Seed for the classifier.

* Seed for the regressor.

* Forgot one.

* typo.
This commit is contained in:
Malte Londschien
2021-11-11 01:51:24 +01:00
committed by GitHub
parent 521dab49ef
commit 12c102d02b
5 changed files with 72 additions and 23 deletions
+17 -6
View File
@@ -68,6 +68,7 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -328,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
) -> Result<DecisionTreeClassifier<T>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
DecisionTreeClassifier::fit_weak_learner(
x,
y,
samples,
num_attributes,
parameters,
&mut rand::thread_rng(),
)
}
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -337,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeClassifierParameters,
rng: &mut impl Rng,
) -> Result<DecisionTreeClassifier<T>, Failed> {
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
@@ -384,13 +393,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) {
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor);
}
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue),
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break,
};
}
@@ -443,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
&mut self,
visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize,
rng: &mut impl Rng,
) -> bool {
let (n_rows, n_attr) = visitor.x.shape();
@@ -482,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
variables.shuffle(&mut rand::thread_rng());
variables.shuffle(rng);
}
for variable in variables.iter().take(mtry) {
@@ -566,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
@@ -614,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry) {
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
@@ -627,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry) {
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}
+17 -6
View File
@@ -63,6 +63,7 @@ use std::default::Default;
use std::fmt::Debug;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -242,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
) -> Result<DecisionTreeRegressor<T>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
DecisionTreeRegressor::fit_weak_learner(
x,
y,
samples,
num_attributes,
parameters,
&mut rand::thread_rng(),
)
}
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -251,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
rng: &mut impl Rng,
) -> Result<DecisionTreeRegressor<T>, Failed> {
let y_m = M::from_row_vector(y.clone());
@@ -284,13 +293,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) {
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor);
}
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue),
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break,
};
}
@@ -343,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
&mut self,
visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize,
rng: &mut impl Rng,
) -> bool {
let (_, n_attr) = visitor.x.shape();
@@ -357,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
variables.shuffle(&mut rand::thread_rng());
variables.shuffle(rng);
}
let parent_gain =
@@ -432,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
mut visitor: NodeVisitor<'a, T, M>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry) {
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
@@ -493,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry) {
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}