fix: refactors decision_tree_classifier
This commit is contained in:
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
|
|||||||
|
|
||||||
let n = visitor.samples.iter().sum();
|
let n = visitor.samples.iter().sum();
|
||||||
|
|
||||||
if n <= self.parameters.min_samples_leaf {
|
if n <= self.parameters.min_samples_split {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +302,6 @@ impl DecisionTreeClassifier {
|
|||||||
let mut true_count = vec![0; self.num_classes];
|
let mut true_count = vec![0; self.num_classes];
|
||||||
let mut prevx = std::f64::NAN;
|
let mut prevx = std::f64::NAN;
|
||||||
let mut prevy = 0;
|
let mut prevy = 0;
|
||||||
let node_size = 1;
|
|
||||||
|
|
||||||
for i in visitor.order[j].iter() {
|
for i in visitor.order[j].iter() {
|
||||||
if visitor.samples[*i] > 0 {
|
if visitor.samples[*i] > 0 {
|
||||||
@@ -316,7 +315,7 @@ impl DecisionTreeClassifier {
|
|||||||
let tc = true_count.iter().sum();
|
let tc = true_count.iter().sum();
|
||||||
let fc = n - tc;
|
let fc = n - tc;
|
||||||
|
|
||||||
if tc < node_size || fc < node_size {
|
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||||
prevx = visitor.x.get(*i, j);
|
prevx = visitor.x.get(*i, j);
|
||||||
prevy = visitor.y[*i];
|
prevy = visitor.y[*i];
|
||||||
true_count[visitor.y[*i]] += visitor.samples[*i];
|
true_count[visitor.y[*i]] += visitor.samples[*i];
|
||||||
|
|||||||
Reference in New Issue
Block a user