fix: minor refactoring
This commit is contained in:
@@ -9,7 +9,7 @@ use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTree
|
||||
pub struct RandomForestClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
- pub min_samples_leaf: u16,
|
||||
+ pub min_samples_leaf: usize,
|
||||
pub min_samples_split: usize,
|
||||
pub n_trees: u16,
|
||||
pub mtry: Option<usize>
|
||||
@@ -97,7 +97,7 @@ impl RandomForestClassifier {
|
||||
|
||||
}
|
||||
|
||||
- fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<u32>{
|
||||
+ fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize>{
|
||||
let mut rng = rand::thread_rng();
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
- pub min_samples_leaf: u16,
|
||||
+ pub min_samples_leaf: usize,
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
@@ -68,14 +68,14 @@ struct NodeVisitor<'a, M: Matrix> {
|
||||
x: &'a M,
|
||||
y: &'a Vec<usize>,
|
||||
node: usize,
|
||||
- samples: Vec<u32>,
|
||||
+ samples: Vec<usize>,
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
true_child_output: usize,
|
||||
false_child_output: usize,
|
||||
level: u16
|
||||
}
|
||||
|
||||
- fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
|
||||
+ fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
|
||||
let mut impurity = 0.;
|
||||
|
||||
match criterion {
|
||||
@@ -112,7 +112,7 @@ fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
|
||||
|
||||
impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
||||
|
||||
- fn new(node_id: usize, samples: Vec<u32>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
|
||||
+ fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
|
||||
NodeVisitor {
|
||||
x: x,
|
||||
y: y,
|
||||
@@ -127,7 +127,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
||||
|
||||
}
|
||||
|
||||
- pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
|
||||
+ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||
let mut m = x[0];
|
||||
let mut which = 0;
|
||||
|
||||
@@ -149,7 +149,7 @@ impl DecisionTreeClassifier {
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
- pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
+ pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
@@ -270,7 +270,7 @@ impl DecisionTreeClassifier {
|
||||
|
||||
let n = visitor.samples.iter().sum();
|
||||
|
||||
- if n <= self.parameters.min_samples_leaf as u32 {
|
||||
+ if n <= self.parameters.min_samples_leaf {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -297,7 +297,7 @@ impl DecisionTreeClassifier {
|
||||
|
||||
}
|
||||
|
||||
- fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: u32, count: &Vec<u32>, false_count: &mut Vec<u32>, parent_impurity: f64, j: usize){
|
||||
+ fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: f64, j: usize){
|
||||
|
||||
let mut true_count = vec![0; self.num_classes];
|
||||
let mut prevx = std::f64::NAN;
|
||||
@@ -351,7 +351,7 @@ impl DecisionTreeClassifier {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
let mut fc = 0;
|
||||
- let mut true_samples: Vec<u32> = vec![0; n];
|
||||
+ let mut true_samples: Vec<usize> = vec![0; n];
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
@@ -365,7 +365,7 @@ impl DecisionTreeClassifier {
|
||||
}
|
||||
}
|
||||
|
||||
- if tc < self.parameters.min_samples_leaf as u32 || fc < self.parameters.min_samples_leaf as u32 {
|
||||
+ if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = std::f64::NAN;
|
||||
self.nodes[visitor.node].split_score = std::f64::NAN;
|
||||
@@ -384,13 +384,13 @@ impl DecisionTreeClassifier {
|
||||
|
||||
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
- if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
+ if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
- if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
+ if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user