Adds RandomForest
This commit is contained in:
@@ -5,9 +5,9 @@ use crate::algorithm::sort::quick_sort::QuickArgSort;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeParameters {
|
pub struct DecisionTreeParameters {
|
||||||
criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
min_samples_leaf: u16
|
pub min_samples_leaf: u16
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -19,7 +19,7 @@ pub struct DecisionTree {
|
|||||||
depth: u16
|
depth: u16
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
Gini,
|
Gini,
|
||||||
Entropy,
|
Entropy,
|
||||||
@@ -125,7 +125,7 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn which_max(x: &Vec<u32>) -> usize {
|
pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
|
||||||
let mut m = x[0];
|
let mut m = x[0];
|
||||||
let mut which = 0;
|
let mut which = 0;
|
||||||
|
|
||||||
@@ -142,9 +142,15 @@ fn which_max(x: &Vec<u32>) -> usize {
|
|||||||
impl DecisionTree {
|
impl DecisionTree {
|
||||||
|
|
||||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
|
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||||
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
|
let samples = vec![1; x_nrows];
|
||||||
|
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
if k < 2 {
|
if k < 2 {
|
||||||
@@ -159,7 +165,6 @@ impl DecisionTree {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut nodes: Vec<Node> = Vec::new();
|
let mut nodes: Vec<Node> = Vec::new();
|
||||||
let samples = vec![1; x_nrows];
|
|
||||||
|
|
||||||
let mut count = vec![0; k];
|
let mut count = vec![0; k];
|
||||||
for i in 0..y_ncols {
|
for i in 0..y_ncols {
|
||||||
@@ -186,13 +191,13 @@ impl DecisionTree {
|
|||||||
|
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
|
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor) {
|
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||||
visitor_queue.push_back(visitor);
|
visitor_queue.push_back(visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, &mut visitor_queue),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue,),
|
||||||
None => break
|
None => break
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -212,7 +217,7 @@ impl DecisionTree {
|
|||||||
result.to_row_vector()
|
result.to_row_vector()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||||
let mut result = 0;
|
let mut result = 0;
|
||||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||||
|
|
||||||
@@ -240,7 +245,7 @@ impl DecisionTree {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>) -> bool {
|
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool {
|
||||||
|
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
@@ -282,7 +287,7 @@ impl DecisionTree {
|
|||||||
variables[i] = i;
|
variables[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
for j in 0..n_attr {
|
for j in 0..mtry {
|
||||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,7 +345,7 @@ impl DecisionTree {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
|
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
|
||||||
let (n, _) = visitor.x.shape();
|
let (n, _) = visitor.x.shape();
|
||||||
let mut tc = 0;
|
let mut tc = 0;
|
||||||
let mut fc = 0;
|
let mut fc = 0;
|
||||||
@@ -377,13 +382,13 @@ impl DecisionTree {
|
|||||||
|
|
||||||
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||||
|
|
||||||
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor) {
|
if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||||
visitor_queue.push_back(true_visitor);
|
visitor_queue.push_back(true_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||||
|
|
||||||
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor) {
|
if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||||
visitor_queue.push_back(false_visitor);
|
visitor_queue.push_back(false_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use crate::common::Nominal;
|
|||||||
pub mod knn;
|
pub mod knn;
|
||||||
pub mod logistic_regression;
|
pub mod logistic_regression;
|
||||||
pub mod decision_tree;
|
pub mod decision_tree;
|
||||||
|
pub mod random_forest;
|
||||||
|
|
||||||
pub trait Classifier<X, Y>
|
pub trait Classifier<X, Y>
|
||||||
where
|
where
|
||||||
|
|||||||
@@ -0,0 +1,160 @@
|
|||||||
|
extern crate rand;
|
||||||
|
|
||||||
|
use rand::Rng;
|
||||||
|
use std::default::Default;
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::classification::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RandomForestParameters {
|
||||||
|
pub criterion: SplitCriterion,
|
||||||
|
pub max_depth: Option<u16>,
|
||||||
|
pub min_samples_leaf: u16,
|
||||||
|
pub n_trees: u16,
|
||||||
|
pub mtry: Option<usize>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RandomForest {
|
||||||
|
parameters: RandomForestParameters,
|
||||||
|
trees: Vec<DecisionTree>,
|
||||||
|
classes: Vec<f64>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RandomForestParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
RandomForestParameters {
|
||||||
|
criterion: SplitCriterion::Gini,
|
||||||
|
max_depth: None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
n_trees: 100,
|
||||||
|
mtry: Option::None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RandomForest {
|
||||||
|
|
||||||
|
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: RandomForestParameters) -> RandomForest {
|
||||||
|
let (_, num_attributes) = x.shape();
|
||||||
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
let (_, y_ncols) = y_m.shape();
|
||||||
|
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||||
|
let classes = y_m.unique();
|
||||||
|
|
||||||
|
for i in 0..y_ncols {
|
||||||
|
let yc = y_m.get(0, i);
|
||||||
|
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mtry = parameters.mtry.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||||
|
|
||||||
|
let classes = y_m.unique();
|
||||||
|
let k = classes.len();
|
||||||
|
let mut trees: Vec<DecisionTree> = Vec::new();
|
||||||
|
|
||||||
|
for _ in 0..parameters.n_trees {
|
||||||
|
let samples = RandomForest::sample_with_replacement(&yi, k);
|
||||||
|
let params = DecisionTreeParameters{
|
||||||
|
criterion: parameters.criterion.clone(),
|
||||||
|
max_depth: parameters.max_depth,
|
||||||
|
min_samples_leaf: parameters.min_samples_leaf
|
||||||
|
};
|
||||||
|
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
|
||||||
|
trees.push(tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
RandomForest {
|
||||||
|
parameters: parameters,
|
||||||
|
trees: trees,
|
||||||
|
classes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
|
||||||
|
let mut result = M::zeros(1, x.shape().0);
|
||||||
|
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.to_row_vector()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||||
|
let mut result = vec![0; self.classes.len()];
|
||||||
|
|
||||||
|
for tree in self.trees.iter() {
|
||||||
|
result[tree.predict_for_row(x, row)] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return which_max(&result)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<u32>{
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let class_weight = vec![1.; num_classes];
|
||||||
|
let nrows = y.len();
|
||||||
|
let mut samples = vec![0; nrows];
|
||||||
|
for l in 0..num_classes {
|
||||||
|
let mut nj = 0;
|
||||||
|
let mut cj: Vec<usize> = Vec::new();
|
||||||
|
for i in 0..nrows {
|
||||||
|
if y[i] == l {
|
||||||
|
cj.push(i);
|
||||||
|
nj += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let size = ((nj as f64) / class_weight[l]) as usize;
|
||||||
|
for _ in 0..size {
|
||||||
|
let xi: usize = rng.gen_range(0, nj);
|
||||||
|
samples[cj[xi]] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
samples
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;

    /// Fits a default forest on a 20-row, 4-column sample and checks that it
    /// reproduces the training labels.
    ///
    /// NOTE(review): training draws from `rand::thread_rng()`, so this exact
    /// equality is not deterministic. Rows 9 and 10 are labelled `1.` although
    /// their measurements resemble the first eight rows - confirm that is
    /// intended.
    #[test]
    fn fit_predict_iris() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4]]);
        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];

        // Train a single forest; a second, discarded fit would only add runtime.
        let forest = RandomForest::fit(&x, &y, Default::default());

        assert_eq!(y, forest.predict(&x));
    }
}
|
||||||
Reference in New Issue
Block a user