fix: minor refactoring

This commit is contained in:
Volodymyr Orlov
2020-03-20 15:58:10 -07:00
parent a96f303dea
commit 6577e22111
3 changed files with 19 additions and 19 deletions
+5 -5
View File
@@ -3,7 +3,7 @@ extern crate rand;
use rand::Rng; use rand::Rng;
use std::default::Default; use std::default::Default;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max}; use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RandomForestParameters { pub struct RandomForestParameters {
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
#[derive(Debug)] #[derive(Debug)]
pub struct RandomForest { pub struct RandomForest {
parameters: RandomForestParameters, parameters: RandomForestParameters,
trees: Vec<DecisionTree>, trees: Vec<DecisionTreeClassifier>,
classes: Vec<f64> classes: Vec<f64>
} }
@@ -51,16 +51,16 @@ impl RandomForest {
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
let mut trees: Vec<DecisionTree> = Vec::new(); let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForest::sample_with_replacement(&yi, k); let samples = RandomForest::sample_with_replacement(&yi, k);
let params = DecisionTreeParameters{ let params = DecisionTreeClassifierParameters{
criterion: parameters.criterion.clone(), criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf min_samples_leaf: parameters.min_samples_leaf
}; };
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params); let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
trees.push(tree); trees.push(tree);
} }
@@ -4,16 +4,16 @@ use crate::linalg::Matrix;
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
#[derive(Debug)] #[derive(Debug)]
pub struct DecisionTreeParameters { pub struct DecisionTreeClassifierParameters {
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
pub max_depth: Option<u16>, pub max_depth: Option<u16>,
pub min_samples_leaf: u16 pub min_samples_leaf: u16
} }
#[derive(Debug)] #[derive(Debug)]
pub struct DecisionTree { pub struct DecisionTreeClassifier {
nodes: Vec<Node>, nodes: Vec<Node>,
parameters: DecisionTreeParameters, parameters: DecisionTreeClassifierParameters,
num_classes: usize, num_classes: usize,
classes: Vec<f64>, classes: Vec<f64>,
depth: u16 depth: u16
@@ -38,9 +38,9 @@ pub struct Node {
} }
impl Default for DecisionTreeParameters { impl Default for DecisionTreeClassifierParameters {
fn default() -> Self { fn default() -> Self {
DecisionTreeParameters { DecisionTreeClassifierParameters {
criterion: SplitCriterion::Gini, criterion: SplitCriterion::Gini,
max_depth: None, max_depth: None,
min_samples_leaf: 1 min_samples_leaf: 1
@@ -139,15 +139,15 @@ pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
return which; return which;
} }
impl DecisionTree { impl DecisionTreeClassifier {
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree { pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
} }
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree { pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
let (_, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
@@ -179,7 +179,7 @@ impl DecisionTree {
order.push(x.get_col_as_vec(i).quick_argsort()); order.push(x.get_col_as_vec(i).quick_argsort());
} }
let mut tree = DecisionTree{ let mut tree = DecisionTreeClassifier{
nodes: nodes, nodes: nodes,
parameters: parameters, parameters: parameters,
num_classes: k, num_classes: k,
@@ -435,9 +435,9 @@ mod tests {
&[5.2, 2.7, 3.9, 1.4]]); &[5.2, 2.7, 3.9, 1.4]]);
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x)); assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth); assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
} }
@@ -467,7 +467,7 @@ mod tests {
&[0.,0.,0.,1.]]); &[0.,0.,0.,1.]]);
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.]; let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x)); assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
} }
} }
+1 -1
View File
@@ -1,2 +1,2 @@
pub mod decision_tree_regressor; pub mod decision_tree_regressor;
pub mod decision_tree; pub mod decision_tree_classifier;