Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
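Below is a minimal usage sketch of the new seed field on DecisionTreeClassifierParameters, assuming smartcore's DenseMatrix API and module paths for this release; the toy data and seed value are illustrative only, not taken from this change.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::{
    DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};

fn main() {
    // Toy data: 4 samples, 2 features, 2 classes.
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[2.0, 1.0],
        &[5.0, 6.0],
        &[6.0, 5.0],
    ]);
    let y = vec![0.0, 0.0, 1.0, 1.0];

    // seed: Some(..) makes tree construction reproducible across runs;
    // seed: None keeps the previous behaviour (non-deterministic RNG).
    let params = DecisionTreeClassifierParameters {
        criterion: SplitCriterion::Gini,
        max_depth: Some(3),
        min_samples_leaf: 1,
        min_samples_split: 2,
        seed: Some(42),
    };

    let tree = DecisionTreeClassifier::fit(&x, &y, params).unwrap();
    let predictions = tree.predict(&x).unwrap();
    println!("{:?}", predictions);
}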
@@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters {
     pub min_samples_leaf: usize,
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    /// Controls the randomness of the estimator
+    pub seed: Option<u64>,
 }
 
 /// Decision Tree
@@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters {
             max_depth: None,
             min_samples_leaf: 1,
             min_samples_split: 2,
+            seed: None,
         }
     }
 }
@@ -467,14 +471,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeClassifier::fit_weak_learner(
-            x,
-            y,
-            samples,
-            num_attributes,
-            parameters,
-            &mut rand::thread_rng(),
-        )
+        DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
 
     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -483,7 +480,6 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeClassifierParameters,
-        rng: &mut impl Rng,
     ) -> Result<DecisionTreeClassifier<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
@@ -497,6 +493,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
             )));
         }
 
+        let mut rng = get_rng_impl(parameters.seed);
         let mut yi: Vec<usize> = vec![0; y_ncols];
 
         for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
@@ -531,13 +528,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
 
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
 
-        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
+        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
             visitor_queue.push_back(visitor);
         }
 
         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
                 None => break,
             };
         }
@@ -874,7 +871,8 @@ mod tests {
                 criterion: SplitCriterion::Entropy,
                 max_depth: Some(3),
                 min_samples_leaf: 1,
-                min_samples_split: 2
+                min_samples_split: 2,
+                seed: None
             }
         )
         .unwrap()
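With this change, fit no longer threads an external RNG into fit_weak_learner; the RNG is built inside from parameters.seed via get_rng_impl. That helper is not part of this diff, so the sketch below only illustrates the general pattern with rand's StdRng; the name rng_from_seed is hypothetical and not the crate's actual implementation.

use rand::rngs::StdRng;
use rand::SeedableRng;

// Illustrative only: derive a concrete RNG from an optional seed.
// A fixed seed yields a reproducible stream; no seed falls back to OS entropy.
fn rng_from_seed(seed: Option<u64>) -> StdRng {
    match seed {
        Some(s) => StdRng::seed_from_u64(s),
        None => StdRng::from_entropy(),
    }
}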
@@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters {
     pub min_samples_leaf: usize,
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    /// Controls the randomness of the estimator
+    pub seed: Option<u64>,
 }
 
 /// Regression Tree
@@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters {
             max_depth: None,
             min_samples_leaf: 1,
             min_samples_split: 2,
+            seed: None,
         }
     }
 }
@@ -357,14 +361,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(
-            x,
-            y,
-            samples,
-            num_attributes,
-            parameters,
-            &mut rand::thread_rng(),
-        )
+        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
 
     pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -373,7 +370,6 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
-        rng: &mut impl Rng,
     ) -> Result<DecisionTreeRegressor<T>, Failed> {
         let y_m = M::from_row_vector(y.clone());
 
@@ -381,6 +377,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
         let (_, num_attributes) = x.shape();
 
         let mut nodes: Vec<Node<T>> = Vec::new();
+        let mut rng = get_rng_impl(parameters.seed);
 
         let mut n = 0;
         let mut sum = T::zero();
@@ -407,13 +404,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
 
         let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
 
-        if tree.find_best_cutoff(&mut visitor, mtry, rng) {
+        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
             visitor_queue.push_back(visitor);
         }
 
         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
                 None => break,
             };
         }
@@ -699,6 +696,7 @@ mod tests {
                 max_depth: Option::None,
                 min_samples_leaf: 2,
                 min_samples_split: 6,
+                seed: None,
             },
         )
         .and_then(|t| t.predict(&x))
@@ -719,6 +717,7 @@ mod tests {
                 max_depth: Option::None,
                 min_samples_leaf: 1,
                 min_samples_split: 3,
+                seed: None,
             },
         )
         .and_then(|t| t.predict(&x))
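The regressor follows the same pattern. A small sketch of what the seed buys, again with assumed module paths and illustrative data: fitting twice with the same seed should build identical trees and therefore return identical predictions.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::tree::decision_tree_regressor::{
    DecisionTreeRegressor, DecisionTreeRegressorParameters,
};

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[2.0, 1.0],
        &[5.0, 6.0],
        &[6.0, 5.0],
    ]);
    let y = vec![0.5, 0.7, 2.1, 2.3];

    let params = DecisionTreeRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        seed: Some(7),
    };

    // Same seed, same data: the two fits should choose identical splits.
    let p1 = DecisionTreeRegressor::fit(&x, &y, params.clone())
        .and_then(|t| t.predict(&x))
        .unwrap();
    let p2 = DecisionTreeRegressor::fit(&x, &y, params)
        .and_then(|t| t.predict(&x))
        .unwrap();
    assert_eq!(p1, p2);
}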