From e8cba343ca83ea2af06b0d0af97a3c538b17084f Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Thu, 14 Oct 2021 09:33:55 +0200 Subject: [PATCH 1/6] Initial implementation of predict_oob. --- src/ensemble/random_forest_classifier.rs | 59 +++++++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 1d7884b..b3c810a 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -53,7 +53,7 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::{ @@ -77,6 +77,8 @@ pub struct RandomForestClassifierParameters { pub n_trees: u16, /// Number of random sample of predictors to use as split candidates. pub m: Option, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: bool, } /// Random Forest Classifier @@ -86,6 +88,7 @@ pub struct RandomForestClassifier { parameters: RandomForestClassifierParameters, trees: Vec>, classes: Vec, + samples: Option>>, } impl RandomForestClassifierParameters { @@ -119,6 +122,12 @@ impl RandomForestClassifierParameters { self.m = Some(m); self } + + /// Whether to keep samples used for tree generation. This is required for OOB prediction. 
+ pub fn with_keep_samples(mut self, keep_samples: bool) -> Self { + self.keep_samples = keep_samples; + self + } } impl PartialEq for RandomForestClassifier { @@ -150,6 +159,7 @@ impl Default for RandomForestClassifierParameters { min_samples_split: 2, n_trees: 100, m: Option::None, + keep_samples: false, } } } @@ -205,8 +215,17 @@ impl RandomForestClassifier { let k = classes.len(); let mut trees: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { let samples = RandomForestClassifier::::sample_with_replacement(&yi, k); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } + let params = DecisionTreeClassifierParameters { criterion: parameters.criterion.clone(), max_depth: parameters.max_depth, @@ -221,6 +240,7 @@ impl RandomForestClassifier { parameters, trees, classes, + samples: maybe_all_samples, }) } @@ -248,6 +268,42 @@ impl RandomForestClassifier { which_max(&result) } + /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. 
+ pub fn predict_oob>(&self, x: &M) -> Result { + let (n, _) = x.shape(); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need keep_samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = M::zeros(self.classes.len(), 1); + + for i in 0..n { + result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]); + } + + Ok(result.to_row_vector()) + } + } + + fn predict_for_row_oob>(&self, x: &M, row: usize) -> usize { + let mut result = vec![0; self.classes.len()]; + + for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { + if !samples[row] { + result[tree.predict_for_row(x, row)] += 1; + } + } + + which_max(&result) + } + fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec { let mut rng = rand::thread_rng(); let class_weight = vec![1.; num_classes]; @@ -318,6 +374,7 @@ mod tests { min_samples_split: 2, n_trees: 100, m: Option::None, + keep_samples: false, }, ) .unwrap(); From 4bae62ab2f7776f56363c472a4257ee0c069fee7 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Thu, 14 Oct 2021 09:47:00 +0200 Subject: [PATCH 2/6] Test. 
--- src/ensemble/random_forest_classifier.rs | 51 +++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index b3c810a..f70604c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -282,7 +282,7 @@ impl RandomForestClassifier { "Prediction matrix must match matrix used in training for OOB predictions.", )) } else { - let mut result = M::zeros(self.classes.len(), 1); + let mut result = M::zeros(1, n); for i in 0..n { result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]); @@ -382,6 +382,55 @@ mod tests { assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fit_predict_iris_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![ + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, + m: Option::None, + keep_samples: true, + }, + ) + .unwrap(); + assert!( + accuracy(&y, &classifier.predict_oob(&x).unwrap()) + < accuracy(&y, &classifier.predict(&x).unwrap()) + ); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] 
#[test] #[cfg(feature = "serde")] From d23931496758792938a1eebc989be760fc66c9b8 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Thu, 14 Oct 2021 09:59:26 +0200 Subject: [PATCH 3/6] Same for regressor. --- src/ensemble/random_forest_regressor.rs | 108 +++++++++++++++++++++++- 1 file changed, 106 insertions(+), 2 deletions(-) diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 0351fc4..f1caa8e 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -51,7 +51,7 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::tree::decision_tree_regressor::{ @@ -73,6 +73,8 @@ pub struct RandomForestRegressorParameters { pub n_trees: usize, /// Number of random sample of predictors to use as split candidates. pub m: Option, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: bool, } /// Random Forest Regressor @@ -81,6 +83,7 @@ pub struct RandomForestRegressorParameters { pub struct RandomForestRegressor { parameters: RandomForestRegressorParameters, trees: Vec>, + samples: Option>>, } impl RandomForestRegressorParameters { @@ -109,6 +112,12 @@ impl RandomForestRegressorParameters { self.m = Some(m); self } + + /// Whether to keep samples used for tree generation. This is required for OOB prediction. 
+ pub fn with_keep_samples(mut self, keep_samples: bool) -> Self { + self.keep_samples = keep_samples; + self + } } impl Default for RandomForestRegressorParameters { @@ -119,6 +128,7 @@ impl Default for RandomForestRegressorParameters { min_samples_split: 2, n_trees: 10, m: Option::None, + keep_samples: false, } } } @@ -174,8 +184,16 @@ impl RandomForestRegressor { let mut trees: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { let samples = RandomForestRegressor::::sample_with_replacement(n_rows); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } let params = DecisionTreeRegressorParameters { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, @@ -185,7 +203,7 @@ impl RandomForestRegressor { trees.push(tree); } - Ok(RandomForestRegressor { parameters, trees }) + Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples }) } /// Predict class for `x` @@ -214,6 +232,46 @@ impl RandomForestRegressor { result / T::from(n_trees).unwrap() } + + /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. 
+ pub fn predict_oob>(&self, x: &M) -> Result { + let (n, _) = x.shape(); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need keep_samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = M::zeros(1, n); + + for i in 0..n { + result.set(0, i, self.predict_for_row_oob(x, i)); + } + + Ok(result.to_row_vector()) + } + } + + fn predict_for_row_oob>(&self, x: &M, row: usize) -> T { + let mut n_trees = 0; + let mut result = T::zero(); + + for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { + if !samples[row] { + result += tree.predict_for_row(x, row); + n_trees += 1; + } + } + + // TODO: What to do if there are no oob trees? + result / T::from(n_trees).unwrap() + } + fn sample_with_replacement(nrows: usize) -> Vec { let mut rng = rand::thread_rng(); let mut samples = vec![0; nrows]; @@ -266,6 +324,7 @@ mod tests { min_samples_split: 2, n_trees: 1000, m: Option::None, + keep_samples: false, }, ) .and_then(|rf| rf.predict(&x)) @@ -274,6 +333,51 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < 1.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fit_predict_longley_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159., 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165., 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.27, 1952., 63.639], + &[365.385, 187., 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335., 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.18, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 
1957., 68.169], + &[444.546, 468.1, 263.7, 121.95, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); + let y = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; + + let regressor = RandomForestRegressor::fit( + &x, + &y, + RandomForestRegressorParameters { + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 1000, + m: Option::None, + keep_samples: true, + }, + ).unwrap(); + + let y_hat = regressor.predict(&x).unwrap(); + let y_hat_oob = regressor.predict_oob(&x).unwrap(); + + assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob)); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] #[cfg(feature = "serde")] From 85b9fde9a78c6bc3c9b8e2899240fc77a7d5adf4 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Wed, 20 Oct 2021 17:04:24 +0200 Subject: [PATCH 4/6] Another format. --- src/ensemble/random_forest_regressor.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index f1caa8e..90ac479 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -203,7 +203,11 @@ impl RandomForestRegressor { trees.push(tree); } - Ok(RandomForestRegressor { parameters, trees, samples: maybe_all_samples }) + Ok(RandomForestRegressor { + parameters, + trees, + samples: maybe_all_samples, + }) } /// Predict class for `x` @@ -232,7 +236,6 @@ impl RandomForestRegressor { result / T::from(n_trees).unwrap() } - /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. 
pub fn predict_oob>(&self, x: &M) -> Result { let (n, _) = x.shape(); @@ -370,7 +373,8 @@ mod tests { m: Option::None, keep_samples: true, }, - ).unwrap(); + ) + .unwrap(); let y_hat = regressor.predict(&x).unwrap(); let y_hat_oob = regressor.predict_oob(&x).unwrap(); From d0a4ccbe202263ff2103891cbf163a691befe594 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Wed, 20 Oct 2021 17:08:52 +0200 Subject: [PATCH 5/6] Set keep_samples attribute. --- src/linalg/ndarray_bindings.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 0aa97aa..2ec8e3a 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -1007,6 +1007,7 @@ mod tests { min_samples_split: 2, n_trees: 1000, m: Option::None, + keep_samples: false, }, ) .unwrap() From 14245e15ad2452232052f56b22959a6aaa89af2a Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Wed, 20 Oct 2021 17:13:00 +0200 Subject: [PATCH 6/6] type error. --- src/linalg/ndarray_bindings.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 2ec8e3a..f5b1c69 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -966,7 +966,7 @@ mod tests { let error: f64 = y .into_iter() .zip(y_hat.into_iter()) - .map(|(&a, &b)| (a - b).abs()) + .map(|(a, b)| (a - b).abs()) .sum(); assert!(error <= 1.0);