diff --git a/src/ensemble/random_forest.rs b/src/ensemble/random_forest.rs
index b318597..09b47f6 100644
--- a/src/ensemble/random_forest.rs
+++ b/src/ensemble/random_forest.rs
@@ -9,7 +9,8 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree
 pub struct RandomForestParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
-    pub min_samples_leaf: u16,
+    pub min_samples_leaf: u16,
+    pub min_samples_split: usize,
     pub n_trees: u16,
     pub mtry: Option<usize>
 }
@@ -27,6 +28,7 @@ impl Default for RandomForestParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
             min_samples_leaf: 1,
+            min_samples_split: 2,
             n_trees: 100,
             mtry: Option::None
         }
@@ -58,7 +60,8 @@ impl RandomForest {
             let params = DecisionTreeClassifierParameters{
                 criterion: parameters.criterion.clone(),
                 max_depth: parameters.max_depth,
-                min_samples_leaf: parameters.min_samples_leaf
+                min_samples_leaf: parameters.min_samples_leaf,
+                min_samples_split: parameters.min_samples_split
             };
             let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
             trees.push(tree);
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index 5c1e483..4ac70c2 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -7,7 +7,8 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
 pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
-    pub min_samples_leaf: u16
+    pub min_samples_leaf: u16,
+    pub min_samples_split: usize
 }
 
 #[derive(Debug)]
@@ -43,7 +44,8 @@ impl Default for DecisionTreeClassifierParameters {
         DecisionTreeClassifierParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
-            min_samples_leaf: 1
+            min_samples_leaf: 1,
+            min_samples_split: 2
         }
     }
 }
@@ -437,7 +439,7 @@ mod tests {
 
         assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
 
-        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
+        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth);
 
     }
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index 4b5e8f2..752d201 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -33,7 +33,7 @@ impl Default for DecisionTreeRegressorParameters {
     fn default() -> Self {
         DecisionTreeRegressorParameters {
             max_depth: None,
-            min_samples_leaf: 2,
+            min_samples_leaf: 1,
             min_samples_split: 2
         }
     }
@@ -132,7 +132,7 @@ impl DecisionTreeRegressor {
 
         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue,),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue),
                 None => break
             };
         }
@@ -186,7 +186,7 @@ impl DecisionTreeRegressor {
 
         let n: usize = visitor.samples.iter().sum();
 
-        if n <= self.parameters.min_samples_split {
+        if n < self.parameters.min_samples_split {
             return false;
         }
@@ -289,13 +289,13 @@ impl DecisionTreeRegressor {
 
         let mut true_visitor = NodeVisitor::<T>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if tc > self.parameters.min_samples_split && self.find_best_cutoff(&mut true_visitor, mtry) {
+        if self.find_best_cutoff(&mut true_visitor, mtry) {
             visitor_queue.push_back(true_visitor);
         }
 
         let mut false_visitor = NodeVisitor::<T>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if fc > self.parameters.min_samples_split && self.find_best_cutoff(&mut false_visitor, mtry) {
+        if self.find_best_cutoff(&mut false_visitor, mtry) {
             visitor_queue.push_back(false_visitor);
         }
 
@@ -331,11 +331,10 @@ mod tests {
             &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
         let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
 
-        let expected_y = vec![85.6, 89.0, 85.6, 89.0, 97.15, 97.15, 100.0, 100.0, 100.0, 107.9, 107.9, 107.9, 113.4, 113.4, 116.3, 116.3];
-        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
 
         for i in 0..y_hat.len() {
-            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+            assert!((y_hat[i] - y[i]).abs() < 0.1);
         }
 
         let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85];
@@ -343,7 +342,14 @@ mod tests {
 
         for i in 0..y_hat.len() {
             assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
-        }
+        }
+
+        let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30];
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x);
+
+        for i in 0..y_hat.len() {
+            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+        }
     }
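
A quick usage sketch for review. The two parameter structs and their fields are exactly as introduced by this patch; the import paths, the helper function, and the chosen values are illustrative assumptions, not part of the change:

```rust
// Sketch only: field names and types come from this patch; module paths
// and the helper fn are assumptions for illustration.
use crate::ensemble::random_forest::RandomForestParameters;
use crate::tree::decision_tree_classifier::SplitCriterion;
use crate::tree::decision_tree_regressor::DecisionTreeRegressorParameters;

fn example_params() -> (RandomForestParameters, DecisionTreeRegressorParameters) {
    // Forest side: min_samples_split is copied verbatim into the
    // DecisionTreeClassifierParameters of every weak learner in fit.
    let forest = RandomForestParameters {
        criterion: SplitCriterion::Gini,
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 5, // hypothetical value; the new default is 2
        n_trees: 100,
        mtry: Option::None,
    };

    // Regressor side: find_best_cutoff now refuses a split only when
    // n < min_samples_split, so a 2-sample node can still split by default.
    let regressor = DecisionTreeRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 3, // nodes with fewer than 3 samples stay leaves
    };

    (forest, regressor)
}
```

Two behavioral notes worth calling out: switching the guard from `n <= min_samples_split` to `n < min_samples_split` gives the parameter its conventional meaning (the minimum number of samples a node must hold to be considered for a split); previously a node with exactly `min_samples_split` samples could never split. And dropping the `tc > ...`/`fc > ...` pre-checks appears safe because each child visitor is passed through `find_best_cutoff`, which now applies the same sample-count guard itself.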