fix: minor refactoring

This commit is contained in:
Volodymyr Orlov
2020-03-23 18:50:58 -07:00
parent 17200fe633
commit 18243e658b
2 changed files with 14 additions and 14 deletions
+2 -2
View File
@@ -9,7 +9,7 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree
pub struct RandomForestClassifierParameters {
pub criterion: SplitCriterion,
pub max_depth: Option<u16>,
pub min_samples_leaf: u16,
pub min_samples_leaf: usize,
pub min_samples_split: usize,
pub n_trees: u16,
pub mtry: Option<usize>
@@ -97,7 +97,7 @@ impl RandomForestClassifier {
}
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<u32>{
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize>{
let mut rng = rand::thread_rng();
let class_weight = vec![1.; num_classes];
let nrows = y.len();
+12 -12
View File
@@ -7,7 +7,7 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
pub struct DecisionTreeClassifierParameters {
pub criterion: SplitCriterion,
pub max_depth: Option<u16>,
pub min_samples_leaf: u16,
pub min_samples_leaf: usize,
pub min_samples_split: usize
}
@@ -68,14 +68,14 @@ struct NodeVisitor<'a, M: Matrix> {
x: &'a M,
y: &'a Vec<usize>,
node: usize,
samples: Vec<u32>,
samples: Vec<usize>,
order: &'a Vec<Vec<usize>>,
true_child_output: usize,
false_child_output: usize,
level: u16
}
fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
let mut impurity = 0.;
match criterion {
@@ -112,7 +112,7 @@ fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
impl<'a, M: Matrix> NodeVisitor<'a, M> {
fn new(node_id: usize, samples: Vec<u32>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
NodeVisitor {
x: x,
y: y,
@@ -127,7 +127,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
}
pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
let mut m = x[0];
let mut which = 0;
@@ -149,7 +149,7 @@ impl DecisionTreeClassifier {
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape();
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
let n = visitor.samples.iter().sum();
if n <= self.parameters.min_samples_leaf as u32 {
if n <= self.parameters.min_samples_leaf {
return false;
}
@@ -297,7 +297,7 @@ impl DecisionTreeClassifier {
}
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: u32, count: &Vec<u32>, false_count: &mut Vec<u32>, parent_impurity: f64, j: usize){
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: f64, j: usize){
let mut true_count = vec![0; self.num_classes];
let mut prevx = std::f64::NAN;
@@ -351,7 +351,7 @@ impl DecisionTreeClassifier {
let (n, _) = visitor.x.shape();
let mut tc = 0;
let mut fc = 0;
let mut true_samples: Vec<u32> = vec![0; n];
let mut true_samples: Vec<usize> = vec![0; n];
for i in 0..n {
if visitor.samples[i] > 0 {
@@ -365,7 +365,7 @@ impl DecisionTreeClassifier {
}
}
if tc < self.parameters.min_samples_leaf as u32 || fc < self.parameters.min_samples_leaf as u32 {
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = std::f64::NAN;
self.nodes[visitor.node].split_score = std::f64::NAN;
@@ -384,13 +384,13 @@ impl DecisionTreeClassifier {
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) {
if self.find_best_cutoff(&mut true_visitor, mtry) {
visitor_queue.push_back(true_visitor);
}
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) {
if self.find_best_cutoff(&mut false_visitor, mtry) {
visitor_queue.push_back(false_visitor);
}