diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 3d189e6..996210e 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -270,7 +270,7 @@ impl DecisionTreeClassifier { let n = visitor.samples.iter().sum(); - if n <= self.parameters.min_samples_leaf { + if n <= self.parameters.min_samples_split { return false; } @@ -301,8 +301,7 @@ impl DecisionTreeClassifier { let mut true_count = vec![0; self.num_classes]; let mut prevx = std::f64::NAN; - let mut prevy = 0; - let node_size = 1; + let mut prevy = 0; for i in visitor.order[j].iter() { if visitor.samples[*i] > 0 { @@ -316,7 +315,7 @@ impl DecisionTreeClassifier { let tc = true_count.iter().sum(); let fc = n - tc; - if tc < node_size || fc < node_size { + if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { prevx = visitor.x.get(*i, j); prevy = visitor.y[*i]; true_count[visitor.y[*i]] += visitor.samples[*i];