diff --git a/src/ensemble/random_forest.rs b/src/ensemble/random_forest.rs index f26acab..b318597 100644 --- a/src/ensemble/random_forest.rs +++ b/src/ensemble/random_forest.rs @@ -3,7 +3,7 @@ extern crate rand; use rand::Rng; use std::default::Default; use crate::linalg::Matrix; -use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max}; +use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max}; #[derive(Debug, Clone)] pub struct RandomForestParameters { @@ -17,7 +17,7 @@ pub struct RandomForestParameters { #[derive(Debug)] pub struct RandomForest { parameters: RandomForestParameters, - trees: Vec, + trees: Vec, classes: Vec } @@ -51,16 +51,16 @@ impl RandomForest { let classes = y_m.unique(); let k = classes.len(); - let mut trees: Vec = Vec::new(); + let mut trees: Vec = Vec::new(); for _ in 0..parameters.n_trees { let samples = RandomForest::sample_with_replacement(&yi, k); - let params = DecisionTreeParameters{ + let params = DecisionTreeClassifierParameters{ criterion: parameters.criterion.clone(), max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf }; - let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params); + let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params); trees.push(tree); } diff --git a/src/tree/decision_tree.rs b/src/tree/decision_tree_classifier.rs similarity index 93% rename from src/tree/decision_tree.rs rename to src/tree/decision_tree_classifier.rs index 80f9a4a..5c1e483 100644 --- a/src/tree/decision_tree.rs +++ b/src/tree/decision_tree_classifier.rs @@ -4,16 +4,16 @@ use crate::linalg::Matrix; use crate::algorithm::sort::quick_sort::QuickArgSort; #[derive(Debug)] -pub struct DecisionTreeParameters { +pub struct DecisionTreeClassifierParameters { pub criterion: SplitCriterion, pub max_depth: Option, pub min_samples_leaf: u16 } #[derive(Debug)] -pub struct DecisionTree { +pub struct DecisionTreeClassifier { nodes: Vec, - parameters: DecisionTreeParameters, + parameters: DecisionTreeClassifierParameters, num_classes: usize, classes: Vec, depth: u16 @@ -38,9 +38,9 @@ pub struct Node { } -impl Default for DecisionTreeParameters { +impl Default for DecisionTreeClassifierParameters { fn default() -> Self { - DecisionTreeParameters { + DecisionTreeClassifierParameters { criterion: SplitCriterion::Gini, max_depth: None, min_samples_leaf: 1 @@ -139,15 +139,15 @@ pub(in crate) fn which_max(x: &Vec) -> usize { return which; } -impl DecisionTree { +impl DecisionTreeClassifier { - pub fn fit(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree { + pub fn fit(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters) + DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } - pub fn fit_weak_learner(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree { + pub fn fit_weak_learner(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier { let y_m = M::from_row_vector(y.clone()); let (_, y_ncols) = y_m.shape(); let (_, num_attributes) = x.shape(); @@ -179,7 +179,7 @@ impl DecisionTree { order.push(x.get_col_as_vec(i).quick_argsort()); } - let mut tree = DecisionTree{ + let mut tree = DecisionTreeClassifier{ nodes: nodes, parameters: parameters, num_classes: k, @@ -435,9 +435,9 @@ mod tests { &[5.2, 2.7, 3.9, 1.4]]); let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; - assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x)); + assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)); - assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth); + assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth); } @@ -467,7 +467,7 @@ mod tests { &[0.,0.,0.,1.]]); let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.]; - assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x)); + assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)); } } \ No newline at end of file diff --git a/src/tree/mod.rs b/src/tree/mod.rs index bea1d51..624c16a 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs @@ -1,2 +1,2 @@ pub mod decision_tree_regressor; -pub mod decision_tree; \ No newline at end of file +pub mod decision_tree_classifier; \ No newline at end of file