From d23931496758792938a1eebc989be760fc66c9b8 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Thu, 14 Oct 2021 09:59:26 +0200 Subject: [PATCH] Same for regressor. --- src/ensemble/random_forest_regressor.rs | 108 +++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 0351fc4..f1caa8e 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -51,7 +51,7 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::tree::decision_tree_regressor::{ @@ -73,6 +73,8 @@ pub struct RandomForestRegressorParameters { pub n_trees: usize, /// Number of random sample of predictors to use as split candidates. pub m: Option, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: bool, } /// Random Forest Regressor @@ -81,6 +83,7 @@ pub struct RandomForestRegressorParameters { pub struct RandomForestRegressor { parameters: RandomForestRegressorParameters, trees: Vec>, + samples: Option>>, } impl RandomForestRegressorParameters { @@ -109,6 +112,12 @@ impl RandomForestRegressorParameters { self.m = Some(m); self } + + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub fn with_keep_samples(mut self, keep_samples: bool) -> Self { + self.keep_samples = keep_samples; + self + } } impl Default for RandomForestRegressorParameters { @@ -119,6 +128,7 @@ impl Default for RandomForestRegressorParameters { min_samples_split: 2, n_trees: 10, m: Option::None, + keep_samples: false, } } } @@ -174,8 +184,16 @@ impl RandomForestRegressor { let mut trees: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { let samples = RandomForestRegressor::::sample_with_replacement(n_rows); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } let params = DecisionTreeRegressorParameters { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, @@ -185,7 +203,7 @@ impl RandomForestRegressor { trees.push(tree); } - Ok(RandomForestRegressor { parameters, trees }) + Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples }) } /// Predict class for `x` @@ -214,6 +232,46 @@ impl RandomForestRegressor { result / T::from(n_trees).unwrap() } + + /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. + pub fn predict_oob>(&self, x: &M) -> Result { + let (n, _) = x.shape(); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = M::zeros(1, n); + + for i in 0..n { + result.set(0, i, self.predict_for_row_oob(x, i)); + } + + Ok(result.to_row_vector()) + } + } + + fn predict_for_row_oob>(&self, x: &M, row: usize) -> T { + let mut n_trees = 0; + let mut result = T::zero(); + + for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { + if !samples[row] { + result += tree.predict_for_row(x, row); + n_trees += 1; + } + } + + // TODO: What to do if there are no oob trees? + result / T::from(n_trees).unwrap() + } + fn sample_with_replacement(nrows: usize) -> Vec { let mut rng = rand::thread_rng(); let mut samples = vec![0; nrows]; @@ -266,6 +324,7 @@ mod tests { min_samples_split: 2, n_trees: 1000, m: Option::None, + keep_samples: false, }, ) .and_then(|rf| rf.predict(&x)) @@ -274,6 +333,51 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < 1.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fit_predict_longley_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159., 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165., 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.27, 1952., 63.639], + &[365.385, 187., 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335., 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.18, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.95, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); + let y = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; + + let regressor = RandomForestRegressor::fit( + &x, + &y, + RandomForestRegressorParameters { + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 1000, + m: Option::None, + keep_samples: true, + }, + ).unwrap(); + + let y_hat = regressor.predict(&x).unwrap(); + let y_hat_oob = regressor.predict_oob(&x).unwrap(); + + assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob)); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] #[cfg(feature = "serde")]