fix: refactors decision_tree_classifier

This commit is contained in:
Volodymyr Orlov
2020-03-24 11:07:05 -07:00
parent 18243e658b
commit 84ffd331cd
+2 -3
View File
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
let n = visitor.samples.iter().sum();
if n <= self.parameters.min_samples_leaf {
if n <= self.parameters.min_samples_split {
return false;
}
@@ -302,7 +302,6 @@ impl DecisionTreeClassifier {
let mut true_count = vec![0; self.num_classes];
let mut prevx = std::f64::NAN;
let mut prevy = 0;
let node_size = 1;
for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 {
@@ -316,7 +315,7 @@ impl DecisionTreeClassifier {
let tc = true_count.iter().sum();
let fc = n - tc;
if tc < node_size || fc < node_size {
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
prevx = visitor.x.get(*i, j);
prevy = visitor.y[*i];
true_count[visitor.y[*i]] += visitor.samples[*i];