fix: minor refactoring (use usize instead of u32/u16 for sample counts and min_samples_leaf, dropping the `as u32` casts)

This commit is contained in:
Volodymyr Orlov
2020-03-23 18:50:58 -07:00
parent 17200fe633
commit 18243e658b
2 changed files with 14 additions and 14 deletions
+2 -2
View File
@@ -9,7 +9,7 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree
pub struct RandomForestClassifierParameters { pub struct RandomForestClassifierParameters {
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: u16, pub min_samples_leaf: usize,
pub min_samples_split: usize, pub min_samples_split: usize,
pub n_trees: u16, pub n_trees: u16,
pub mtry: Option<usize> pub mtry: Option<usize>
@@ -97,7 +97,7 @@ impl RandomForestClassifier {
} }
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<u32>{ fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize>{
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let class_weight = vec![1.; num_classes]; let class_weight = vec![1.; num_classes];
let nrows = y.len(); let nrows = y.len();
+12 -12
View File
@@ -7,7 +7,7 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
pub struct DecisionTreeClassifierParameters { pub struct DecisionTreeClassifierParameters {
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: u16, pub min_samples_leaf: usize,
pub min_samples_split: usize pub min_samples_split: usize
} }
@@ -68,14 +68,14 @@ struct NodeVisitor<'a, M: Matrix> {
x: &'a M, x: &'a M,
y: &'a Vec<usize>, y: &'a Vec<usize>,
node: usize, node: usize,
samples: Vec<u32>, samples: Vec<usize>,
order: &'a Vec<Vec<usize>>, order: &'a Vec<Vec<usize>>,
true_child_output: usize, true_child_output: usize,
false_child_output: usize, false_child_output: usize,
level: u16 level: u16
} }
fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 { fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
let mut impurity = 0.; let mut impurity = 0.;
match criterion { match criterion {
@@ -112,7 +112,7 @@ fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
impl<'a, M: Matrix> NodeVisitor<'a, M> { impl<'a, M: Matrix> NodeVisitor<'a, M> {
fn new(node_id: usize, samples: Vec<u32>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self { fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
NodeVisitor { NodeVisitor {
x: x, x: x,
y: y, y: y,
@@ -127,7 +127,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
} }
pub(in crate) fn which_max(x: &Vec<u32>) -> usize { pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
let mut m = x[0]; let mut m = x[0];
let mut which = 0; let mut which = 0;
@@ -149,7 +149,7 @@ impl DecisionTreeClassifier {
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
let n = visitor.samples.iter().sum(); let n = visitor.samples.iter().sum();
if n <= self.parameters.min_samples_leaf as u32 { if n <= self.parameters.min_samples_leaf {
return false; return false;
} }
@@ -297,7 +297,7 @@ impl DecisionTreeClassifier {
} }
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: u32, count: &Vec<u32>, false_count: &mut Vec<u32>, parent_impurity: f64, j: usize){ fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: f64, j: usize){
let mut true_count = vec![0; self.num_classes]; let mut true_count = vec![0; self.num_classes];
let mut prevx = std::f64::NAN; let mut prevx = std::f64::NAN;
@@ -351,7 +351,7 @@ impl DecisionTreeClassifier {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
let mut fc = 0; let mut fc = 0;
let mut true_samples: Vec<u32> = vec![0; n]; let mut true_samples: Vec<usize> = vec![0; n];
for i in 0..n { for i in 0..n {
if visitor.samples[i] > 0 { if visitor.samples[i] > 0 {
@@ -365,7 +365,7 @@ impl DecisionTreeClassifier {
} }
} }
if tc < self.parameters.min_samples_leaf as u32 || fc < self.parameters.min_samples_leaf as u32 { if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
self.nodes[visitor.node].split_feature = 0; self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = std::f64::NAN; self.nodes[visitor.node].split_value = std::f64::NAN;
self.nodes[visitor.node].split_score = std::f64::NAN; self.nodes[visitor.node].split_score = std::f64::NAN;
@@ -384,13 +384,13 @@ impl DecisionTreeClassifier {
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
} }