fix: refactors decision_tree_classifier
This commit is contained in:
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
|
||||
|
||||
let n = visitor.samples.iter().sum();
|
||||
|
||||
if n <= self.parameters.min_samples_leaf {
|
||||
if n <= self.parameters.min_samples_split {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -302,7 +302,6 @@ impl DecisionTreeClassifier {
|
||||
let mut true_count = vec![0; self.num_classes];
|
||||
let mut prevx = std::f64::NAN;
|
||||
let mut prevy = 0;
|
||||
let node_size = 1;
|
||||
|
||||
for i in visitor.order[j].iter() {
|
||||
if visitor.samples[*i] > 0 {
|
||||
@@ -316,7 +315,7 @@ impl DecisionTreeClassifier {
|
||||
let tc = true_count.iter().sum();
|
||||
let fc = n - tc;
|
||||
|
||||
if tc < node_size || fc < node_size {
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
prevx = visitor.x.get(*i, j);
|
||||
prevy = visitor.y[*i];
|
||||
true_count[visitor.y[*i]] += visitor.samples[*i];
|
||||
|
||||
Reference in New Issue
Block a user