From 18243e658b5e923ce93ad22701f1e109fd34f54e Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Mon, 23 Mar 2020 18:50:58 -0700 Subject: [PATCH] fix: minor refactoring --- src/ensemble/random_forest_classifier.rs | 4 ++-- src/tree/decision_tree_classifier.rs | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 9175b28..d338bcd 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -9,7 +9,7 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree pub struct RandomForestClassifierParameters { pub criterion: SplitCriterion, pub max_depth: Option, - pub min_samples_leaf: u16, + pub min_samples_leaf: usize, pub min_samples_split: usize, pub n_trees: u16, pub mtry: Option @@ -97,7 +97,7 @@ impl RandomForestClassifier { } - fn sample_with_replacement(y: &Vec, num_classes: usize) -> Vec{ + fn sample_with_replacement(y: &Vec, num_classes: usize) -> Vec{ let mut rng = rand::thread_rng(); let class_weight = vec![1.; num_classes]; let nrows = y.len(); diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 4ac70c2..3d189e6 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -7,7 +7,7 @@ use crate::algorithm::sort::quick_sort::QuickArgSort; pub struct DecisionTreeClassifierParameters { pub criterion: SplitCriterion, pub max_depth: Option, - pub min_samples_leaf: u16, + pub min_samples_leaf: usize, pub min_samples_split: usize } @@ -68,14 +68,14 @@ struct NodeVisitor<'a, M: Matrix> { x: &'a M, y: &'a Vec, node: usize, - samples: Vec, + samples: Vec, order: &'a Vec>, true_child_output: usize, false_child_output: usize, level: u16 } -fn impurity(criterion: &SplitCriterion, count: &Vec, n: u32) -> f64 { +fn impurity(criterion: &SplitCriterion, count: &Vec, n: usize) -> f64 { let mut impurity = 0.; match criterion { @@ -112,7 +112,7 @@ fn impurity(criterion: &SplitCriterion, count: &Vec, n: u32) -> f64 { impl<'a, M: Matrix> NodeVisitor<'a, M> { - fn new(node_id: usize, samples: Vec, order: &'a Vec>, x: &'a M, y: &'a Vec, level: u16) -> Self { + fn new(node_id: usize, samples: Vec, order: &'a Vec>, x: &'a M, y: &'a Vec, level: u16) -> Self { NodeVisitor { x: x, y: y, @@ -127,7 +127,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> { } -pub(in crate) fn which_max(x: &Vec) -> usize { +pub(in crate) fn which_max(x: &Vec) -> usize { let mut m = x[0]; let mut which = 0; @@ -149,7 +149,7 @@ impl DecisionTreeClassifier { DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } - pub fn fit_weak_learner(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { + pub fn fit_weak_learner(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { let y_m = M::from_row_vector(y.clone()); let (_, y_ncols) = y_m.shape(); let (_, num_attributes) = x.shape(); @@ -270,7 +270,7 @@ impl DecisionTreeClassifier { let n = visitor.samples.iter().sum(); - if n <= self.parameters.min_samples_leaf as u32 { + if n <= self.parameters.min_samples_leaf { return false; } @@ -297,7 +297,7 @@ impl DecisionTreeClassifier { } - fn find_best_split(&mut self, visitor: &mut NodeVisitor, n: u32, count: &Vec, false_count: &mut Vec, parent_impurity: f64, j: usize){ + fn find_best_split(&mut self, visitor: &mut NodeVisitor, n: usize, count: &Vec, false_count: &mut Vec, parent_impurity: f64, j: usize){ let mut true_count = vec![0; self.num_classes]; let mut prevx = std::f64::NAN; @@ -351,7 +351,7 @@ impl DecisionTreeClassifier { let (n, _) = visitor.x.shape(); let mut tc = 0; let mut fc = 0; - let mut true_samples: Vec = vec![0; n]; + let mut true_samples: Vec = vec![0; n]; for i in 0..n { if visitor.samples[i] > 0 { @@ -365,7 +365,7 @@ impl DecisionTreeClassifier { } } - if tc < self.parameters.min_samples_leaf as u32 || fc < self.parameters.min_samples_leaf as u32 { + if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_value = std::f64::NAN; self.nodes[visitor.node].split_score = std::f64::NAN; @@ -384,13 +384,13 @@ impl DecisionTreeClassifier { let mut true_visitor = NodeVisitor::::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); - if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) { + if self.find_best_cutoff(&mut true_visitor, mtry) { visitor_queue.push_back(true_visitor); } let mut false_visitor = NodeVisitor::::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); - if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) { + if self.find_best_cutoff(&mut false_visitor, mtry) { visitor_queue.push_back(false_visitor); }