diff --git a/src/classification/decision_tree.rs b/src/classification/decision_tree.rs
index a9cba83..70240a1 100644
--- a/src/classification/decision_tree.rs
+++ b/src/classification/decision_tree.rs
@@ -5,9 +5,9 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
 
 #[derive(Debug)]
 pub struct DecisionTreeParameters {
-    criterion: SplitCriterion,
-    max_depth: Option<u16>,
-    min_samples_leaf: u16
+    pub criterion: SplitCriterion,
+    pub max_depth: Option<u16>,
+    pub min_samples_leaf: u16
 }
 
 #[derive(Debug)]
@@ -19,7 +19,7 @@ pub struct DecisionTree {
     depth: u16
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum SplitCriterion {
     Gini,
     Entropy,
@@ -125,7 +125,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
 
 }
 
-fn which_max(x: &Vec<usize>) -> usize {
+pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
     let mut m = x[0];
     let mut which = 0;
 
@@ -142,9 +142,15 @@ fn which_max(x: &Vec<usize>) -> usize {
 impl DecisionTree {
 
     pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
+        let (x_nrows, num_attributes) = x.shape();
+        let samples = vec![1; x_nrows];
+        DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
+    }
+
+    pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
-        let (x_nrows, num_attributes) = x.shape();
+        let (_, num_attributes) = x.shape();
         let classes = y_m.unique();
         let k = classes.len();
         if k < 2 {
@@ -158,8 +164,7 @@ impl DecisionTree {
             yi[i] = classes.iter().position(|c| yc == *c).unwrap();
         }
 
-        let mut nodes: Vec<Node> = Vec::new();
-        let samples = vec![1; x_nrows];
+        let mut nodes: Vec<Node> = Vec::new();
 
         let mut count = vec![0; k];
         for i in 0..y_ncols {
@@ -186,13 +191,13 @@ impl DecisionTree {
 
         let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
 
-        if tree.find_best_cutoff(&mut visitor) {
+        if tree.find_best_cutoff(&mut visitor, mtry) {
             visitor_queue.push_back(visitor);
         }
 
         while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, &mut visitor_queue),
+                Some(node) => tree.split(node, mtry, &mut visitor_queue),
                 None => break
             };
         }
@@ -212,7 +217,7 @@ impl DecisionTree {
         result.to_row_vector()
     }
 
-    fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
+    pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
         let mut result = 0;
         let mut queue: LinkedList<usize> = LinkedList::new();
 
@@ -240,7 +245,7 @@ impl DecisionTree {
 
    }
 
-    fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>) -> bool {
+    fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool {
 
        let (n_rows, n_attr) = visitor.x.shape();
 
@@ -282,7 +287,7 @@ impl DecisionTree {
             variables[i] = i;
         }
 
-        for j in 0..n_attr {
+        for j in 0..mtry {
             self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
         }
 
@@ -340,7 +345,7 @@ impl DecisionTree {
 
    }
 
-    fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
+    fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
         let mut fc = 0;
@@ -377,13 +382,13 @@ impl DecisionTree {
 
         let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor) {
+        if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) {
             visitor_queue.push_back(true_visitor);
         }
 
         let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
 
-        if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor) {
+        if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) {
             visitor_queue.push_back(false_visitor);
         }
 
diff --git a/src/classification/mod.rs b/src/classification/mod.rs
index c181db8..696cd87 100644
--- a/src/classification/mod.rs
+++ b/src/classification/mod.rs
@@ -3,6 +3,7 @@ use crate::common::Nominal;
 pub mod knn;
 pub mod logistic_regression;
 pub mod decision_tree;
+pub mod random_forest;
 
 pub trait Classifier
 where
diff --git a/src/classification/random_forest.rs b/src/classification/random_forest.rs
new file mode 100644
index 0000000..5a46e59
--- /dev/null
+++ b/src/classification/random_forest.rs
@@ -0,0 +1,160 @@
+extern crate rand;
+
+use rand::Rng;
+use std::default::Default;
+use crate::linalg::Matrix;
+use crate::classification::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
+
+#[derive(Debug, Clone)]
+pub struct RandomForestParameters {
+    pub criterion: SplitCriterion,
+    pub max_depth: Option<u16>,
+    pub min_samples_leaf: u16,
+    pub n_trees: u16,
+    pub mtry: Option<usize>
+}
+
+#[derive(Debug)]
+pub struct RandomForest {
+    parameters: RandomForestParameters,
+    trees: Vec<DecisionTree>,
+    classes: Vec<f64>
+}
+
+impl Default for RandomForestParameters {
+    fn default() -> Self {
+        RandomForestParameters {
+            criterion: SplitCriterion::Gini,
+            max_depth: None,
+            min_samples_leaf: 1,
+            n_trees: 100,
+            mtry: Option::None
+        }
+    }
+}
+
+impl RandomForest {
+
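+    // Fits `n_trees` decision trees, each on a per-class bootstrap sample of
+    // the rows and limited to `mtry` candidate features per split
+    // (floor(sqrt(p)) when `mtry` is unset), then predicts by majority vote
+    // across the ensemble.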
+    pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: RandomForestParameters) -> RandomForest {
+        let (_, num_attributes) = x.shape();
+        let y_m = M::from_row_vector(y.clone());
+        let (_, y_ncols) = y_m.shape();
+        let mut yi: Vec<usize> = vec![0; y_ncols];
+        let classes = y_m.unique();
+
+        for i in 0..y_ncols {
+            let yc = y_m.get(0, i);
+            yi[i] = classes.iter().position(|c| yc == *c).unwrap();
+        }
+
+        let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+
+        let k = classes.len();
+        let mut trees: Vec<DecisionTree> = Vec::new();
+
+        for _ in 0..parameters.n_trees {
+            let samples = RandomForest::sample_with_replacement(&yi, k);
+            let params = DecisionTreeParameters {
+                criterion: parameters.criterion.clone(),
+                max_depth: parameters.max_depth,
+                min_samples_leaf: parameters.min_samples_leaf
+            };
+            let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
+            trees.push(tree);
+        }
+
+        RandomForest {
+            parameters: parameters,
+            trees: trees,
+            classes
+        }
+    }
+
+    pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
+        let mut result = M::zeros(1, x.shape().0);
+
+        let (n, _) = x.shape();
+
+        for i in 0..n {
+            result.set(0, i, self.classes[self.predict_for_row(x, i)]);
+        }
+
+        result.to_row_vector()
+    }
+
+    fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
+        let mut result = vec![0; self.classes.len()];
+
+        for tree in self.trees.iter() {
+            result[tree.predict_for_row(x, row)] += 1;
+        }
+
+        which_max(&result)
+    }
+
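+    // Stratified bootstrap: for every class, draws with replacement as many
+    // member row indices as the class has members, and returns how often each
+    // row was drawn; rows with a count of 0 stay out of that tree's sample.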
+    fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize> {
+        let mut rng = rand::thread_rng();
+        let class_weight = vec![1.; num_classes];
+        let nrows = y.len();
+        let mut samples = vec![0; nrows];
+        for l in 0..num_classes {
+            let mut nj = 0;
+            let mut cj: Vec<usize> = Vec::new();
+            for i in 0..nrows {
+                if y[i] == l {
+                    cj.push(i);
+                    nj += 1;
+                }
+            }
+
+            let size = ((nj as f64) / class_weight[l]) as usize;
+            for _ in 0..size {
+                let xi: usize = rng.gen_range(0, nj);
+                samples[cj[xi]] += 1;
+            }
+        }
+        samples
+    }
+
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::naive::dense_matrix::DenseMatrix;
+
+    #[test]
+    fn fit_predict_iris() {
+
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4]]);
+        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
+
+        let forest = RandomForest::fit(&x, &y, Default::default());
+
+        assert_eq!(y, forest.predict(&x));
+    }
+
+}
\ No newline at end of file