fix: refactors decision_tree_classifier

This commit is contained in:
Volodymyr Orlov
2020-03-24 11:07:05 -07:00
parent 18243e658b
commit 84ffd331cd
+2 -3
View File
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
let n = visitor.samples.iter().sum(); let n = visitor.samples.iter().sum();
if n <= self.parameters.min_samples_leaf { if n <= self.parameters.min_samples_split {
return false; return false;
} }
@@ -302,7 +302,6 @@ impl DecisionTreeClassifier {
let mut true_count = vec![0; self.num_classes]; let mut true_count = vec![0; self.num_classes];
let mut prevx = std::f64::NAN; let mut prevx = std::f64::NAN;
let mut prevy = 0; let mut prevy = 0;
let node_size = 1;
for i in visitor.order[j].iter() { for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 { if visitor.samples[*i] > 0 {
@@ -316,7 +315,7 @@ impl DecisionTreeClassifier {
let tc = true_count.iter().sum(); let tc = true_count.iter().sum();
let fc = n - tc; let fc = n - tc;
if tc < node_size || fc < node_size { if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
prevx = visitor.x.get(*i, j); prevx = visitor.x.get(*i, j);
prevy = visitor.y[*i]; prevy = visitor.y[*i];
true_count[visitor.y[*i]] += visitor.samples[*i]; true_count[visitor.y[*i]] += visitor.samples[*i];