fix: minor bug in decision_tree_regressor
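This change adds a `min_samples_split` parameter to `RandomForestParameters` and `DecisionTreeClassifierParameters` (default: 2) and threads it through to the forest's weak learners, fixes the regressor's `min_samples_leaf` default from 2 to 1, relaxes the split-eligibility guard from `n <= min_samples_split` to `n < min_samples_split`, and removes the now-redundant `tc`/`fc` pre-checks before `find_best_cutoff`. The regressor tests are updated and extended accordingly.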
@@ -9,7 +9,8 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree
 pub struct RandomForestParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
     pub min_samples_leaf: u16,
+    pub min_samples_split: usize,
     pub n_trees: u16,
     pub mtry: Option<usize>
 }
@@ -27,6 +28,7 @@ impl Default for RandomForestParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
             min_samples_leaf: 1,
+            min_samples_split: 2,
             n_trees: 100,
             mtry: Option::None
         }
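For reference, callers can now set the new knob while keeping the remaining defaults via struct-update syntax. A minimal sketch, assuming the `RandomForestParameters` struct and `Default` impl shown above:

    // Sketch only: override the new field, keep the other defaults.
    let params = RandomForestParameters {
        min_samples_split: 5,
        ..Default::default()
    };
    assert_eq!(params.n_trees, 100); // untouched default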
@@ -58,7 +60,8 @@ impl RandomForest {
             let params = DecisionTreeClassifierParameters{
                 criterion: parameters.criterion.clone(),
                 max_depth: parameters.max_depth,
-                min_samples_leaf: parameters.min_samples_leaf
+                min_samples_leaf: parameters.min_samples_leaf,
+                min_samples_split: parameters.min_samples_split
             };
             let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
             trees.push(tree);
@@ -7,7 +7,8 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
 pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
-    pub min_samples_leaf: u16
+    pub min_samples_leaf: u16,
+    pub min_samples_split: usize
 }
 
 #[derive(Debug)]
@@ -43,7 +44,8 @@ impl Default for DecisionTreeClassifierParameters {
         DecisionTreeClassifierParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
-            min_samples_leaf: 1
+            min_samples_leaf: 1,
+            min_samples_split: 2
         }
     }
 }
@@ -437,7 +439,7 @@ mod tests {
 
         assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
 
-        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
+        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth);
 
     }
 
@@ -33,7 +33,7 @@ impl Default for DecisionTreeRegressorParameters {
     fn default() -> Self {
         DecisionTreeRegressorParameters {
             max_depth: None,
-            min_samples_leaf: 2,
+            min_samples_leaf: 1,
            min_samples_split: 2
         }
     }
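This aligns the regressor's `min_samples_leaf` default with the classifier's (1). A minimal sketch of the resulting defaults, assuming the fields shown in this diff are accessible:

    let params = DecisionTreeRegressorParameters::default();
    assert_eq!(params.min_samples_leaf, 1); // was 2 before this fix
    assert_eq!(params.min_samples_split, 2);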
@@ -132,7 +132,7 @@ impl DecisionTreeRegressor {
 
        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
            match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue,),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue),
                None => break
            };
        }
@@ -186,7 +186,7 @@ impl DecisionTreeRegressor {
 
        let n: usize = visitor.samples.iter().sum();
 
-        if n <= self.parameters.min_samples_split {
+        if n < self.parameters.min_samples_split {
            return false;
        }
 
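The boundary case is the point of this fix: with the default `min_samples_split: 2`, a node holding exactly two samples should still be eligible to split, but the old `n <= min_samples_split` guard rejected it. A standalone sketch of the corrected predicate (hypothetical helper, not part of this diff):

    // A node may split only when it holds at least `min_samples_split` samples.
    fn can_split(n: usize, min_samples_split: usize) -> bool {
        !(n < min_samples_split) // i.e. n >= min_samples_split
    }

    assert!(can_split(2, 2));   // rejected by the old `n <= min_samples_split` check
    assert!(!can_split(1, 2));  // still too small to split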
@@ -289,13 +289,13 @@ impl DecisionTreeRegressor {
 
        let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if tc > self.parameters.min_samples_split && self.find_best_cutoff(&mut true_visitor, mtry) {
+        if self.find_best_cutoff(&mut true_visitor, mtry) {
            visitor_queue.push_back(true_visitor);
        }
 
        let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if fc > self.parameters.min_samples_split && self.find_best_cutoff(&mut false_visitor, mtry) {
+        if self.find_best_cutoff(&mut false_visitor, mtry) {
            visitor_queue.push_back(false_visitor);
        }
 
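The dropped `tc`/`fc` pre-checks appear redundant given the guard in the previous hunk: the `n < min_samples_split` check already returns `false` for undersized nodes, so the child visitors can be handed to `find_best_cutoff` unconditionally.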
@@ -331,11 +331,10 @@ mod tests {
            &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
        let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
 
-        let expected_y = vec![85.6, 89.0, 85.6, 89.0, 97.15, 97.15, 100.0, 100.0, 100.0, 107.9, 107.9, 107.9, 113.4, 113.4, 116.3, 116.3];
        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
 
        for i in 0..y_hat.len() {
-            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+            assert!((y_hat[i] - y[i]).abs() < 0.1);
        }
 
        let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85];
@@ -343,7 +342,14 @@ mod tests {
 
        for i in 0..y_hat.len() {
            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
        }
 
+        let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30];
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x);
+
+        for i in 0..y_hat.len() {
+            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
+        }
+
    }
 