Allow setting seed for RandomForestClassifier and RandomForestRegressor (#120)
* Seed for the classifier.
* Seed for the regressor.
* Forgot one.
* Typo fix.
This commit is contained in:
@@ -68,6 +68,7 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -328,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeClassifier::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -337,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
@@ -384,13 +393,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -443,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n_rows, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -482,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
for variable in variables.iter().take(mtry) {
|
||||
@@ -566,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
@@ -614,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -627,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -242,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeRegressor::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -251,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -284,13 +293,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -343,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (_, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -357,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
let parent_gain =
|
||||
@@ -432,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
@@ -480,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -493,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user