Adds RandomForest
This commit is contained in:
@@ -5,9 +5,9 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeParameters {
|
||||
criterion: SplitCriterion,
|
||||
max_depth: Option<u16>,
|
||||
min_samples_leaf: u16
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: u16
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -19,7 +19,7 @@ pub struct DecisionTree {
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
Gini,
|
||||
Entropy,
|
||||
@@ -125,7 +125,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
||||
|
||||
}
|
||||
|
||||
fn which_max(x: &Vec<u32>) -> usize {
|
||||
pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
|
||||
let mut m = x[0];
|
||||
let mut which = 0;
|
||||
|
||||
@@ -142,9 +142,15 @@ fn which_max(x: &Vec<u32>) -> usize {
|
||||
impl DecisionTree {
|
||||
|
||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
if k < 2 {
|
||||
@@ -159,7 +165,6 @@ impl DecisionTree {
|
||||
}
|
||||
|
||||
let mut nodes: Vec<Node> = Vec::new();
|
||||
let samples = vec![1; x_nrows];
|
||||
|
||||
let mut count = vec![0; k];
|
||||
for i in 0..y_ncols {
|
||||
@@ -186,13 +191,13 @@ impl DecisionTree {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue,),
|
||||
None => break
|
||||
};
|
||||
}
|
||||
@@ -212,7 +217,7 @@ impl DecisionTree {
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||
pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = 0;
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -240,7 +245,7 @@ impl DecisionTree {
|
||||
|
||||
}
|
||||
|
||||
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>) -> bool {
|
||||
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool {
|
||||
|
||||
let (n_rows, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -282,7 +287,7 @@ impl DecisionTree {
|
||||
variables[i] = i;
|
||||
}
|
||||
|
||||
for j in 0..n_attr {
|
||||
for j in 0..mtry {
|
||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||
}
|
||||
|
||||
@@ -340,7 +345,7 @@ impl DecisionTree {
|
||||
|
||||
}
|
||||
|
||||
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
|
||||
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
let mut fc = 0;
|
||||
@@ -377,13 +382,13 @@ impl DecisionTree {
|
||||
|
||||
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor) {
|
||||
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor) {
|
||||
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::common::Nominal;
|
||||
pub mod knn;
|
||||
pub mod logistic_regression;
|
||||
pub mod decision_tree;
|
||||
pub mod random_forest;
|
||||
|
||||
pub trait Classifier<X, Y>
|
||||
where
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
extern crate rand;
|
||||
|
||||
use rand::Rng;
|
||||
use std::default::Default;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::classification::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
|
||||
|
||||
/// Hyper-parameters controlling how a `RandomForest` is grown.
#[derive(Debug, Clone)]
pub struct RandomForestParameters {
    /// Impurity measure used when evaluating candidate splits.
    pub criterion: SplitCriterion,
    /// Maximum depth of each tree; `None` means no explicit limit.
    pub max_depth: Option<u16>,
    /// Minimum number of samples a child must have for a split to be kept.
    pub min_samples_leaf: u16,
    /// Number of trees in the ensemble.
    pub n_trees: u16,
    /// Number of attributes considered per split; `None` selects
    /// floor(sqrt(num_attributes)) at fit time.
    pub mtry: Option<usize>
}
|
||||
|
||||
/// A fitted random-forest classifier: an ensemble of decision trees
/// whose majority vote determines the predicted class.
#[derive(Debug)]
pub struct RandomForest {
    // Parameters the forest was fit with.
    parameters: RandomForestParameters,
    // The individual weak learners; predictions are a vote over these.
    trees: Vec<DecisionTree>,
    // Distinct label values seen during fit; prediction indices map into this.
    classes: Vec<f64>
}
|
||||
|
||||
impl Default for RandomForestParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
n_trees: 100,
|
||||
mtry: Option::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RandomForest {
|
||||
|
||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: RandomForestParameters) -> RandomForest {
|
||||
let (_, num_attributes) = x.shape();
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
let classes = y_m.unique();
|
||||
|
||||
for i in 0..y_ncols {
|
||||
let yc = y_m.get(0, i);
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTree> = Vec::new();
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForest::sample_with_replacement(&yi, k);
|
||||
let params = DecisionTreeParameters{
|
||||
criterion: parameters.criterion.clone(),
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf
|
||||
};
|
||||
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
RandomForest {
|
||||
parameters: parameters,
|
||||
trees: trees,
|
||||
classes
|
||||
}
|
||||
}
|
||||
|
||||
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
|
||||
for tree in self.trees.iter() {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
return which_max(&result)
|
||||
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<u32>{
|
||||
let mut rng = rand::thread_rng();
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
let mut samples = vec![0; nrows];
|
||||
for l in 0..num_classes {
|
||||
let mut nj = 0;
|
||||
let mut cj: Vec<usize> = Vec::new();
|
||||
for i in 0..nrows {
|
||||
if y[i] == l {
|
||||
cj.push(i);
|
||||
nj += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let size = ((nj as f64) / class_weight[l]) as usize;
|
||||
for _ in 0..size {
|
||||
let xi: usize = rng.gen_range(0, nj);
|
||||
samples[cj[xi]] += 1;
|
||||
}
|
||||
}
|
||||
samples
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;

    /// Smoke test: a forest fit on a small, cleanly separable two-class
    /// iris subset should reproduce the training labels exactly.
    #[test]
    fn fit_predict_iris() {

        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4]]);
        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];

        // Fit once and reuse the model; the original fit the forest twice
        // and discarded the first result.
        let forest = RandomForest::fit(&x, &y, Default::default());

        assert_eq!(y, forest.predict(&x));

    }

}
|
||||
Reference in New Issue
Block a user