diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs
index 1d7884b..b3c810a 100644
--- a/src/ensemble/random_forest_classifier.rs
+++ b/src/ensemble/random_forest_classifier.rs
@@ -53,7 +53,7 @@ use rand::Rng;
 use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
-use crate::error::Failed;
+use crate::error::{Failed, FailedError};
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
 use crate::tree::decision_tree_classifier::{
@@ -77,6 +77,8 @@ pub struct RandomForestClassifierParameters {
     pub n_trees: u16,
     /// Number of random sample of predictors to use as split candidates.
     pub m: Option<usize>,
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub keep_samples: bool,
 }
 
 /// Random Forest Classifier
@@ -86,6 +88,7 @@ pub struct RandomForestClassifier<T: RealNumber> {
     parameters: RandomForestClassifierParameters,
     trees: Vec<DecisionTreeClassifier<T>>,
     classes: Vec<T>,
+    samples: Option<Vec<Vec<bool>>>,
 }
 
 impl RandomForestClassifierParameters {
@@ -119,6 +122,12 @@ impl RandomForestClassifierParameters {
         self.m = Some(m);
         self
     }
+
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
+        self.keep_samples = keep_samples;
+        self
+    }
 }
 
 impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -150,6 +159,7 @@ impl Default for RandomForestClassifierParameters {
             min_samples_split: 2,
             n_trees: 100,
             m: Option::None,
+            keep_samples: false,
         }
     }
 }
@@ -205,8 +215,17 @@ impl<T: RealNumber> RandomForestClassifier<T> {
         let k = classes.len();
         let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
 
+        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
+        if parameters.keep_samples {
+            maybe_all_samples = Some(Vec::new());
+        }
+
         for _ in 0..parameters.n_trees {
             let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
+            if let Some(ref mut all_samples) = maybe_all_samples {
+                all_samples.push(samples.iter().map(|x| *x != 0).collect())
+            }
+
             let params = DecisionTreeClassifierParameters {
                 criterion: parameters.criterion.clone(),
                 max_depth: parameters.max_depth,
@@ -221,6 +240,7 @@
             parameters,
             trees,
             classes,
+            samples: maybe_all_samples,
         })
     }
 
@@ -248,6 +268,42 @@
         which_max(&result)
     }
 
+    /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
+    pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
+        let (n, _) = x.shape();
+        if self.samples.is_none() {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Need samples=true for OOB predictions.",
+            ))
+        } else if self.samples.as_ref().unwrap()[0].len() != n {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Prediction matrix must match matrix used in training for OOB predictions.",
+            ))
+        } else {
+            let mut result = M::zeros(1, n);
+
+            for i in 0..n {
+                result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
+            }
+
+            Ok(result.to_row_vector())
+        }
+    }
+
+    fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
+        let mut result = vec![0; self.classes.len()];
+
+        for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
+            if !samples[row] {
+                result[tree.predict_for_row(x, row)] += 1;
+            }
+        }
+
+        which_max(&result)
+    }
+
     fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
         let mut rng = rand::thread_rng();
         let class_weight = vec![1.; num_classes];
@@ -318,6 +374,7 @@ mod tests {
                 min_samples_split: 2,
                 n_trees: 100,
                 m: Option::None,
+                keep_samples: false,
             },
         )
         .unwrap();
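
Usage sketch (not part of the patch): `with_keep_samples` and `predict_oob` below come from the diff above, while the `smartcore` crate name, the `DenseMatrix` import path, and the toy dataset are assumptions for illustration only.

use smartcore::ensemble::random_forest_classifier::{
    RandomForestClassifier, RandomForestClassifierParameters,
};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    // Toy training data; values are illustrative only.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[6.2, 3.4, 5.4, 2.3],
        &[5.9, 3.0, 5.1, 1.8],
    ]);
    let y = vec![0., 0., 1., 1.];

    // keep_samples must be enabled so each tree's bootstrap membership is
    // retained; predict_oob fails otherwise.
    let forest = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters::default().with_keep_samples(true),
    )
    .unwrap();

    // OOB predictions are made against the same matrix that was used for
    // training; each row is scored only by trees that did not see it.
    let oob_predictions = forest.predict_oob(&x).unwrap();
    println!("{:?}", oob_predictions);
}

As the patch implements it, `predict_oob` returns a `PredictFailed` error when `keep_samples` was false during fitting or when the prediction matrix does not have the same number of rows as the training data.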