Initial implementation of predict_oob.

This commit is contained in:
Malte Londschien
2021-10-14 09:33:55 +02:00
parent 1208051fb5
commit e8cba343ca
+58 -1
View File
@@ -53,7 +53,7 @@ use rand::Rng;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed; use crate::error::{Failed, FailedError};
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::{ use crate::tree::decision_tree_classifier::{
@@ -77,6 +77,8 @@ pub struct RandomForestClassifierParameters {
pub n_trees: u16, pub n_trees: u16,
/// Number of random sample of predictors to use as split candidates. /// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>, pub m: Option<usize>,
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
} }
/// Random Forest Classifier /// Random Forest Classifier
@@ -86,6 +88,7 @@ pub struct RandomForestClassifier<T: RealNumber> {
parameters: RandomForestClassifierParameters, parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>, trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<T>, classes: Vec<T>,
samples: Option<Vec<Vec<bool>>>,
} }
impl RandomForestClassifierParameters { impl RandomForestClassifierParameters {
@@ -119,6 +122,12 @@ impl RandomForestClassifierParameters {
self.m = Some(m); self.m = Some(m);
self self
} }
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
} }
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> { impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -150,6 +159,7 @@ impl Default for RandomForestClassifierParameters {
min_samples_split: 2, min_samples_split: 2,
n_trees: 100, n_trees: 100,
m: Option::None, m: Option::None,
keep_samples: false,
} }
} }
} }
@@ -205,8 +215,17 @@ impl<T: RealNumber> RandomForestClassifier<T> {
let k = classes.len(); let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new(); let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k); let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeClassifierParameters { let params = DecisionTreeClassifierParameters {
criterion: parameters.criterion.clone(), criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
@@ -221,6 +240,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
parameters, parameters,
trees, trees,
classes, classes,
samples: maybe_all_samples,
}) })
} }
@@ -248,6 +268,42 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result) which_max(&result)
} }
/// Predict out-of-bag (OOB) classes for `x`. `x` is expected to be equal to the
/// dataset used in training: each row is classified only by the trees whose
/// bootstrap sample did not contain that row.
///
/// # Errors
///
/// Returns `FailedError::PredictFailed` if the forest was fit without
/// `keep_samples = true`, or if `x` does not have the same number of rows as
/// the training matrix.
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
    let (n, _) = x.shape();
    if self.samples.is_none() {
        Err(Failed::because(
            FailedError::PredictFailed,
            "Need samples=true for OOB predictions.",
        ))
    } else if self.samples.as_ref().unwrap()[0].len() != n {
        Err(Failed::because(
            FailedError::PredictFailed,
            "Prediction matrix must match matrix used in training for OOB predictions.",
        ))
    } else {
        // BUG FIX: the result must be a 1-by-n row matrix (one predicted label
        // per input row), matching `predict`. The previous
        // `M::zeros(self.classes.len(), 1)` shape made `set(0, i, ..)` write
        // out of bounds for n > classes.len() and produced a row vector of the
        // wrong length.
        let mut result = M::zeros(1, n);
        for i in 0..n {
            result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
        }
        Ok(result.to_row_vector())
    }
}
/// Tally class votes for `row` of `x` across every tree whose bootstrap
/// sample did NOT include that row, then return the index of the winning
/// class. Callers must ensure `self.samples` is `Some` before calling.
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
    let mut votes = vec![0; self.classes.len()];
    let all_samples = self.samples.as_ref().unwrap();
    // A `false` entry means `row` was out-of-bag for that tree, so only those
    // trees get to vote.
    for (tree, _) in self
        .trees
        .iter()
        .zip(all_samples)
        .filter(|(_, samples)| !samples[row])
    {
        votes[tree.predict_for_row(x, row)] += 1;
    }
    which_max(&votes)
}
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> { fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let class_weight = vec![1.; num_classes]; let class_weight = vec![1.; num_classes];
@@ -318,6 +374,7 @@ mod tests {
min_samples_split: 2, min_samples_split: 2,
n_trees: 100, n_trees: 100,
m: Option::None, m: Option::None,
keep_samples: false,
}, },
) )
.unwrap(); .unwrap();