Initial implementation of predict_oob.
This commit is contained in:
@@ -53,7 +53,7 @@ use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
@@ -77,6 +77,8 @@ pub struct RandomForestClassifierParameters {
|
||||
pub n_trees: u16,
|
||||
/// Number of randomly sampled predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
}
|
||||
|
||||
/// Random Forest Classifier
|
||||
@@ -86,6 +88,7 @@ pub struct RandomForestClassifier<T: RealNumber> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
classes: Vec<T>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
}
|
||||
|
||||
impl RandomForestClassifierParameters {
|
||||
@@ -119,6 +122,12 @@ impl RandomForestClassifierParameters {
|
||||
self.m = Some(m);
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||
@@ -150,6 +159,7 @@ impl Default for RandomForestClassifierParameters {
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -205,8 +215,17 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
|
||||
let params = DecisionTreeClassifierParameters {
|
||||
criterion: parameters.criterion.clone(),
|
||||
max_depth: parameters.max_depth,
|
||||
@@ -221,6 +240,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
parameters,
|
||||
trees,
|
||||
classes,
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -248,6 +268,42 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Need samples=true for OOB predictions.",
|
||||
))
|
||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = M::zeros(self.classes.len(), 1);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
|
||||
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||
if !samples[row] {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let class_weight = vec![1.; num_classes];
|
||||
@@ -318,6 +374,7 @@ mod tests {
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user