fix: minor bug in decision_tree_regressor

Volodymyr Orlov
2020-03-23 15:32:06 -07:00
parent 6577e22111
commit 18dc6bdb40
3 changed files with 25 additions and 14 deletions
src/tree/decision_tree_classifier.rs  +5 -3
@@ -7,7 +7,8 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
 pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
-    pub min_samples_leaf: u16
+    pub min_samples_leaf: u16,
+    pub min_samples_split: usize
 }
 #[derive(Debug)]
@@ -43,7 +44,8 @@ impl Default for DecisionTreeClassifierParameters {
         DecisionTreeClassifierParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
-            min_samples_leaf: 1
+            min_samples_leaf: 1,
+            min_samples_split: 2
         }
     }
 }
@@ -437,7 +439,7 @@ mod tests {
         assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
-        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
+        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth);
     }
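For callers, the visible change in this file is the new `min_samples_split` field: any code that builds `DecisionTreeClassifierParameters` as a struct literal must now supply it. A minimal usage sketch follows; the `use` paths and the toy data are assumptions inferred from the type names in the diff, not taken from this commit.

    // Sketch only: import paths are assumed from the crate layout suggested
    // by the diff (smartcore's tree module), not verified in this commit.
    use smartcore::linalg::naive::dense_matrix::DenseMatrix;
    use smartcore::tree::decision_tree_classifier::{
        DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
    };

    fn main() {
        // Tiny toy dataset: two features, two well-separated classes.
        let x = DenseMatrix::from_2d_array(&[
            &[1.0, 2.0],
            &[1.1, 2.1],
            &[8.0, 9.0],
            &[8.1, 9.1],
        ]);
        let y = vec![0.0, 0.0, 1.0, 1.0];

        let params = DecisionTreeClassifierParameters {
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2, // new: smallest node size eligible for a split
        };
        let tree = DecisionTreeClassifier::fit(&x, &y, params);
        let _y_hat = tree.predict(&x);
    }

Code that passes `Default::default()` is unaffected; the default now includes `min_samples_split: 2`.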
src/tree/decision_tree_regressor.rs  +15 -9
@@ -33,7 +33,7 @@ impl Default for DecisionTreeRegressorParameters {
     fn default() -> Self {
         DecisionTreeRegressorParameters {
             max_depth: None,
-            min_samples_leaf: 2,
+            min_samples_leaf: 1,
             min_samples_split: 2
         }
     }
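The regressor's default `min_samples_leaf` drops from 2 to 1, matching the classifier's default above and the common convention (scikit-learn also defaults to `min_samples_leaf=1`, `min_samples_split=2`).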
@@ -132,7 +132,7 @@ impl DecisionTreeRegressor {
         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue,),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue),
                 None => break
             };
         }
@@ -186,7 +186,7 @@ impl DecisionTreeRegressor {
         let n: usize = visitor.samples.iter().sum();
-        if n <= self.parameters.min_samples_split {
+        if n < self.parameters.min_samples_split {
             return false;
         }
@@ -289,13 +289,13 @@ impl DecisionTreeRegressor {
         let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-        if tc > self.parameters.min_samples_split && self.find_best_cutoff(&mut true_visitor, mtry) {
+        if self.find_best_cutoff(&mut true_visitor, mtry) {
             visitor_queue.push_back(true_visitor);
         }
         let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-        if fc > self.parameters.min_samples_split && self.find_best_cutoff(&mut false_visitor, mtry) {
+        if self.find_best_cutoff(&mut false_visitor, mtry) {
             visitor_queue.push_back(false_visitor);
         }
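These two hunks are one fix. Changing `<=` to `<` in find_best_cutoff means a node holding exactly `min_samples_split` samples is still eligible to split, which matches the parameter's usual meaning (the minimum node size required to attempt a split). With that check now centralized in find_best_cutoff, the per-child `tc`/`fc` guards in split are redundant and are removed. A boundary-behavior sketch with a hypothetical standalone function, not the crate's code:

    // Hypothetical helper illustrating the off-by-one at the boundary.
    fn splittable(n_samples: usize, min_samples_split: usize) -> bool {
        // `<` rather than `<=`: refuse only when strictly below the minimum.
        !(n_samples < min_samples_split) // i.e. n_samples >= min_samples_split
    }

    fn main() {
        assert!(splittable(2, 2));  // the old `<=` check wrongly returned false here
        assert!(!splittable(1, 2)); // too small under either comparison
    }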
@@ -331,11 +331,10 @@ mod tests {
             &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
         let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
-        let expected_y = vec![85.6, 89.0, 85.6, 89.0, 97.15, 97.15, 100.0, 100.0, 100.0, 107.9, 107.9, 107.9, 113.4, 113.4, 116.3, 116.3];
-        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
         for i in 0..y_hat.len() {
-            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+            assert!((y_hat[i] - y[i]).abs() < 0.1);
         }
         let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85];
@@ -343,7 +342,14 @@ mod tests {
         for i in 0..y_hat.len() {
             assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
         }
-    }
+
+        let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30];
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x);
+        for i in 0..y_hat.len() {
+            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+        }
+    }
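The new test exercises the added knob end to end: with a now-unrestricted depth (`min_samples_leaf: 1`, so the old default no longer masks small leaves) and `min_samples_split: 3`, the deepest splits are refused and the fitted values change accordingly. A minimal standalone sketch of the same call; the `use` paths and toy data are assumptions, only the parameter values mirror the test above.

    // Sketch only: import paths inferred from the type names in the diff.
    use smartcore::linalg::naive::dense_matrix::DenseMatrix;
    use smartcore::tree::decision_tree_regressor::{
        DecisionTreeRegressor, DecisionTreeRegressorParameters,
    };

    fn main() {
        let x = DenseMatrix::from_2d_array(&[
            &[1.0, 2.0],
            &[2.0, 3.0],
            &[3.0, 4.0],
            &[4.0, 5.0],
        ]);
        let y = vec![1.0, 2.0, 3.0, 4.0];

        // min_samples_split: 3 keeps nodes with fewer than 3 samples as
        // leaves, so the tree is coarser than with the default of 2.
        let params = DecisionTreeRegressorParameters {
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 3,
        };
        let y_hat = DecisionTreeRegressor::fit(&x, &y, params).predict(&x);
        assert_eq!(y_hat.len(), y.len());
    }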