refactored random forest regressor into reusable compoennts (#318)

This commit is contained in:
Daniel Lacina
2025-07-12 10:56:49 -04:00
committed by GitHub
parent c5816b0e1b
commit 9fef05ecc6
5 changed files with 239 additions and 157 deletions
+214
View File
@@ -0,0 +1,214 @@
use rand::Rng;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct BaseForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
#[cfg_attr(feature = "serde", serde(default))]
pub bootstrap: bool,
#[cfg_attr(feature = "serde", serde(default))]
pub splitter: Splitter,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for BaseForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
}
}
/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
samples: Option<Vec<Vec<bool>>>,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseForestRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseForestRegressorParameters,
) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();
if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
let mut samples: Vec<usize> = (0..n_rows).map(|_| 1).collect();
for _ in 0..parameters.n_trees {
if parameters.bootstrap {
samples =
BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
}
// keep samples is flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
splitter: parameters.splitter.clone(),
};
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
trees.push(tree);
}
Ok(BaseForestRegressor {
trees: Some(trees),
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
}
}
+1
View File
@@ -16,6 +16,7 @@
//! //!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/) //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
mod base_forest_regressor;
/// Random forest classifier /// Random forest classifier
pub mod random_forest_classifier; pub mod random_forest_classifier;
/// Random forest regressor /// Random forest regressor
+19 -125
View File
@@ -43,7 +43,6 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::Rng;
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
@@ -51,15 +50,12 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError}; use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2}; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber; use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
use crate::rand_custom::get_rng_impl;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -98,8 +94,7 @@ pub struct RandomForestRegressor<
X: Array2<TX>, X: Array2<TX>,
Y: Array1<TY>, Y: Array1<TY>,
> { > {
trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>, forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
samples: Option<Vec<Vec<bool>>>,
} }
impl RandomForestRegressorParameters { impl RandomForestRegressorParameters {
@@ -159,14 +154,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
for RandomForestRegressor<TX, TY, X, Y> for RandomForestRegressor<TX, TY, X, Y>
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() { self.forest_regressor == other.forest_regressor
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
} }
} }
@@ -176,8 +164,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
trees: Option::None, forest_regressor: Option::None,
samples: Option::None,
} }
} }
@@ -397,128 +384,35 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
y: &Y, y: &Y,
parameters: RandomForestRegressorParameters, parameters: RandomForestRegressorParameters,
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> { ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape(); let regressor_params = BaseForestRegressorParameters {
if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples: Vec<usize> =
RandomForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
// keep samples is flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split, min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed), n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: true,
splitter: Splitter::Best,
}; };
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
trees.push(tree);
}
Ok(RandomForestRegressor { Ok(RandomForestRegressor {
trees: Some(trees), forest_regressor: Some(forest_regressor),
samples: maybe_all_samples,
}) })
} }
/// Predict class for `x` /// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> { pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0); let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
} }
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> { pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape(); let forest_regressor = self.forest_regressor.as_ref().unwrap();
if self.samples.is_none() { forest_regressor.predict_oob(x)
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
} }
} }
-27
View File
@@ -312,38 +312,11 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
}) })
} }
pub(crate) fn fit_weak_learner(
x: &X,
y: &Y,
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let tree_parameters = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: parameters.seed,
splitter: Splitter::Best,
};
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples, mtry, tree_parameters)?;
Ok(Self {
tree_regressor: Some(tree),
})
}
/// Predict regression value for `x`. /// Predict regression value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> { pub fn predict(&self, x: &X) -> Result<Y, Failed> {
self.tree_regressor.as_ref().unwrap().predict(x) self.tree_regressor.as_ref().unwrap().predict(x)
} }
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
self.tree_regressor
.as_ref()
.unwrap()
.predict_for_row(x, row)
}
} }
#[cfg(test)] #[cfg(test)]
+1 -1
View File
@@ -19,7 +19,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
mod base_tree_regressor; pub(crate) mod base_tree_regressor;
/// Classification tree for dependent variables that take a finite number of unordered values. /// Classification tree for dependent variables that take a finite number of unordered values.
pub mod decision_tree_classifier; pub mod decision_tree_classifier;
/// Regression tree for for dependent variables that take continuous or ordered discrete values. /// Regression tree for for dependent variables that take continuous or ordered discrete values.