refactor: rename DecisionTree to DecisionTreeClassifier (module decision_tree -> decision_tree_classifier) and declare decision_tree_regressor module

This commit is contained in:
Volodymyr Orlov
2020-03-20 15:58:10 -07:00
parent a96f303dea
commit 6577e22111
3 changed files with 19 additions and 19 deletions
+5 -5
View File
@@ -3,7 +3,7 @@ extern crate rand;
use rand::Rng;
use std::default::Default;
use crate::linalg::Matrix;
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
#[derive(Debug, Clone)]
pub struct RandomForestParameters {
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
#[derive(Debug)]
pub struct RandomForest {
parameters: RandomForestParameters,
trees: Vec<DecisionTree>,
trees: Vec<DecisionTreeClassifier>,
classes: Vec<f64>
}
@@ -51,16 +51,16 @@ impl RandomForest {
let classes = y_m.unique();
let k = classes.len();
let mut trees: Vec<DecisionTree> = Vec::new();
let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
for _ in 0..parameters.n_trees {
let samples = RandomForest::sample_with_replacement(&yi, k);
let params = DecisionTreeParameters{
let params = DecisionTreeClassifierParameters{
criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf
};
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
trees.push(tree);
}
@@ -4,16 +4,16 @@ use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort;
#[derive(Debug)]
pub struct DecisionTreeParameters {
pub struct DecisionTreeClassifierParameters {
pub criterion: SplitCriterion,
pub max_depth: Option<u16>,
pub min_samples_leaf: u16
}
#[derive(Debug)]
pub struct DecisionTree {
pub struct DecisionTreeClassifier {
nodes: Vec<Node>,
parameters: DecisionTreeParameters,
parameters: DecisionTreeClassifierParameters,
num_classes: usize,
classes: Vec<f64>,
depth: u16
@@ -38,9 +38,9 @@ pub struct Node {
}
impl Default for DecisionTreeParameters {
impl Default for DecisionTreeClassifierParameters {
fn default() -> Self {
DecisionTreeParameters {
DecisionTreeClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1
@@ -139,15 +139,15 @@ pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
return which;
}
impl DecisionTree {
impl DecisionTreeClassifier {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows];
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape();
@@ -179,7 +179,7 @@ impl DecisionTree {
order.push(x.get_col_as_vec(i).quick_argsort());
}
let mut tree = DecisionTree{
let mut tree = DecisionTreeClassifier{
nodes: nodes,
parameters: parameters,
num_classes: k,
@@ -435,9 +435,9 @@ mod tests {
&[5.2, 2.7, 3.9, 1.4]]);
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
}
@@ -467,7 +467,7 @@ mod tests {
&[0.,0.,0.,1.]]);
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod decision_tree_regressor;
pub mod decision_tree;
pub mod decision_tree_classifier;