Merge pull request #116 from mlondschien/issue-115
Add OOB predictions to random forests
This commit is contained in:
@@ -53,7 +53,7 @@ use rand::Rng;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
@@ -77,6 +77,8 @@ pub struct RandomForestClassifierParameters {
|
|||||||
pub n_trees: u16,
|
pub n_trees: u16,
|
||||||
/// Number of random sample of predictors to use as split candidates.
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
pub m: Option<usize>,
|
pub m: Option<usize>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Classifier
|
/// Random Forest Classifier
|
||||||
@@ -86,6 +88,7 @@ pub struct RandomForestClassifier<T: RealNumber> {
|
|||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
trees: Vec<DecisionTreeClassifier<T>>,
|
trees: Vec<DecisionTreeClassifier<T>>,
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
|
samples: Option<Vec<Vec<bool>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RandomForestClassifierParameters {
|
impl RandomForestClassifierParameters {
|
||||||
@@ -119,6 +122,12 @@ impl RandomForestClassifierParameters {
|
|||||||
self.m = Some(m);
|
self.m = Some(m);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||||
|
self.keep_samples = keep_samples;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||||
@@ -150,6 +159,7 @@ impl Default for RandomForestClassifierParameters {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -205,8 +215,17 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||||
|
|
||||||
|
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||||
|
if parameters.keep_samples {
|
||||||
|
maybe_all_samples = Some(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
for _ in 0..parameters.n_trees {
|
||||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
||||||
|
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||||
|
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||||
|
}
|
||||||
|
|
||||||
let params = DecisionTreeClassifierParameters {
|
let params = DecisionTreeClassifierParameters {
|
||||||
criterion: parameters.criterion.clone(),
|
criterion: parameters.criterion.clone(),
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
@@ -221,6 +240,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
parameters,
|
parameters,
|
||||||
trees,
|
trees,
|
||||||
classes,
|
classes,
|
||||||
|
samples: maybe_all_samples,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,6 +268,42 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
which_max(&result)
|
which_max(&result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
|
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
if self.samples.is_none() {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Need samples=true for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result.to_row_vector())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||||
|
let mut result = vec![0; self.classes.len()];
|
||||||
|
|
||||||
|
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||||
|
if !samples[row] {
|
||||||
|
result[tree.predict_for_row(x, row)] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
which_max(&result)
|
||||||
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
|
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let class_weight = vec![1.; num_classes];
|
let class_weight = vec![1.; num_classes];
|
||||||
@@ -318,6 +374,7 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -325,6 +382,55 @@ mod tests {
|
|||||||
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_iris_oob() {
    // Twenty rows with four features each.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    // The first eight rows are labelled 0, the remaining twelve are labelled 1.
    let y = vec![
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];

    // `keep_samples` must be on, otherwise `predict_oob` returns an error.
    let forest_params = RandomForestClassifierParameters {
        criterion: SplitCriterion::Gini,
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 100,
        m: Option::None,
        keep_samples: true,
    };
    let classifier = RandomForestClassifier::fit(&x, &y, forest_params).unwrap();

    let oob_accuracy = accuracy(&y, &classifier.predict_oob(&x).unwrap());
    let in_sample_accuracy = accuracy(&y, &classifier.predict(&x).unwrap());

    // Out-of-bag accuracy is expected to be strictly below in-sample accuracy.
    assert!(oob_accuracy < in_sample_accuracy);
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ use rand::Rng;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_regressor::{
|
use crate::tree::decision_tree_regressor::{
|
||||||
@@ -73,6 +73,8 @@ pub struct RandomForestRegressorParameters {
|
|||||||
pub n_trees: usize,
|
pub n_trees: usize,
|
||||||
/// Number of random sample of predictors to use as split candidates.
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
pub m: Option<usize>,
|
pub m: Option<usize>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Regressor
|
/// Random Forest Regressor
|
||||||
@@ -81,6 +83,7 @@ pub struct RandomForestRegressorParameters {
|
|||||||
pub struct RandomForestRegressor<T: RealNumber> {
|
pub struct RandomForestRegressor<T: RealNumber> {
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
trees: Vec<DecisionTreeRegressor<T>>,
|
trees: Vec<DecisionTreeRegressor<T>>,
|
||||||
|
samples: Option<Vec<Vec<bool>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RandomForestRegressorParameters {
|
impl RandomForestRegressorParameters {
|
||||||
@@ -109,6 +112,12 @@ impl RandomForestRegressorParameters {
|
|||||||
self.m = Some(m);
|
self.m = Some(m);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||||
|
self.keep_samples = keep_samples;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RandomForestRegressorParameters {
|
impl Default for RandomForestRegressorParameters {
|
||||||
@@ -119,6 +128,7 @@ impl Default for RandomForestRegressorParameters {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 10,
|
n_trees: 10,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,8 +184,16 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
|
|
||||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||||
|
|
||||||
|
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||||
|
if parameters.keep_samples {
|
||||||
|
maybe_all_samples = Some(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
for _ in 0..parameters.n_trees {
|
||||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
|
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
|
||||||
|
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||||
|
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||||
|
}
|
||||||
let params = DecisionTreeRegressorParameters {
|
let params = DecisionTreeRegressorParameters {
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
@@ -185,7 +203,11 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(RandomForestRegressor { parameters, trees })
|
Ok(RandomForestRegressor {
|
||||||
|
parameters,
|
||||||
|
trees,
|
||||||
|
samples: maybe_all_samples,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
@@ -214,6 +236,45 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
result / T::from(n_trees).unwrap()
|
result / T::from(n_trees).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
|
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
if self.samples.is_none() {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Need samples=true for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.predict_for_row_oob(x, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result.to_row_vector())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
|
let mut n_trees = 0;
|
||||||
|
let mut result = T::zero();
|
||||||
|
|
||||||
|
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||||
|
if !samples[row] {
|
||||||
|
result += tree.predict_for_row(x, row);
|
||||||
|
n_trees += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: What to do if there are no oob trees?
|
||||||
|
result / T::from(n_trees).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let mut samples = vec![0; nrows];
|
let mut samples = vec![0; nrows];
|
||||||
@@ -266,6 +327,7 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|rf| rf.predict(&x))
|
.and_then(|rf| rf.predict(&x))
|
||||||
@@ -274,6 +336,52 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_longley_oob() {
    // Sixteen rows with six features each.
    let x = DenseMatrix::from_2d_array(&[
        &[234.289, 235.6, 159., 107.608, 1947., 60.323],
        &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
        &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
        &[284.599, 335.1, 165., 110.929, 1950., 61.187],
        &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
        &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
        &[365.385, 187., 354.7, 115.094, 1953., 64.989],
        &[363.112, 357.8, 335., 116.219, 1954., 63.761],
        &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
        &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
        &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
        &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
        &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
        &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
        &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
        &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
    ]);
    let y = vec![
        83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
        114.2, 115.7, 116.9,
    ];

    // `keep_samples` must be on, otherwise `predict_oob` returns an error.
    let forest_params = RandomForestRegressorParameters {
        max_depth: None,
        min_samples_leaf: 1,
        min_samples_split: 2,
        n_trees: 1000,
        m: Option::None,
        keep_samples: true,
    };
    let regressor = RandomForestRegressor::fit(&x, &y, forest_params).unwrap();

    let y_hat = regressor.predict(&x).unwrap();
    let y_hat_oob = regressor.predict_oob(&x).unwrap();

    let in_sample_error = mean_absolute_error(&y, &y_hat);
    let oob_error = mean_absolute_error(&y, &y_hat_oob);

    // In-sample error is expected to be strictly below the out-of-bag error.
    assert!(in_sample_error < oob_error);
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
|
|||||||
@@ -1007,6 +1007,7 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
Reference in New Issue
Block a user