From 4841791b7e8d74b7182657610c54629bcba781bd Mon Sep 17 00:00:00 2001
From: Daniel Lacina <127566471+DanielLacina@users.noreply.github.com>
Date: Sat, 12 Jul 2025 13:37:11 -0400
Subject: [PATCH] implemented extra trees (#320)

* implemented extra trees

* implemented extra trees
---
 src/ensemble/extra_trees_regressor.rs | 318 ++++++++++++++++++++++++++
 src/ensemble/mod.rs                   |   1 +
 2 files changed, 319 insertions(+)
 create mode 100644 src/ensemble/extra_trees_regressor.rs

diff --git a/src/ensemble/extra_trees_regressor.rs b/src/ensemble/extra_trees_regressor.rs
new file mode 100644
index 0000000..818ac6c
--- /dev/null
+++ b/src/ensemble/extra_trees_regressor.rs
@@ -0,0 +1,318 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often speeds up training.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each candidate feature, rather than the optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! A bigger number of estimators generally improves the performance of the algorithm, at the cost of increased training time.
//! The random sample of _m_ predictors is typically set to \\(\sqrt{p}\\), where _p_ is the full number of predictors.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_2d_array(&[
//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
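//!
//! Hyperparameters can be set through the builder methods on `ExtraTreesRegressorParameters`
//! (a minimal sketch; the values below are illustrative, not recommendations):
//!
//! ```
//! use smartcore::ensemble::extra_trees_regressor::ExtraTreesRegressorParameters;
//!
//! let params = ExtraTreesRegressorParameters::default()
//!     .with_n_trees(100) // number of trees in the forest
//!     .with_max_depth(8) // cap the depth of each tree
//!     .with_m(3)         // predictors sampled as split candidates at each node
//!     .with_seed(42);    // fixed seed for reproducible results
//! ```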

use std::default::Default;
use std::fmt::Debug;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor.
/// Some parameters here are passed directly to the base estimator.
pub struct ExtraTreesRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of predictors to sample at random as split candidates at each node.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed for the random number generator used in feature selection and split generation for each tree.
    pub seed: u64,
}

/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}

impl ExtraTreesRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// The number of predictors to sample at random as split candidates at each node.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep the samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed for the random number generator used in feature selection and split generation for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}

impl Default for ExtraTreesRegressorParameters {
    fn default() -> Self {
        ExtraTreesRegressorParameters {
            max_depth: Option::None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: Option::None,
            keep_samples: false,
            seed: 0,
        }
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            forest_regressor: Option::None,
        }
    }

    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}

impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target values
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: ExtraTreesRegressorParameters,
    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
        let regressor_params = BaseForestRegressorParameters {
            max_depth: parameters.max_depth,
            min_samples_leaf: parameters.min_samples_leaf,
            min_samples_split: parameters.min_samples_split,
            n_trees: parameters.n_trees,
            m: parameters.m,
            keep_samples: parameters.keep_samples,
            seed: parameters.seed,
            // Extra Trees uses the whole dataset for every tree and random split points.
            bootstrap: false,
            splitter: Splitter::Random,
        };
        let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;

        Ok(ExtraTreesRegressor {
            forest_regressor: Some(forest_regressor),
        })
    }

    /// Predict target values for `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self
            .forest_regressor
            .as_ref()
            .ok_or_else(|| Failed::predict("ExtraTreesRegressor should be fitted first"))?;
        forest_regressor.predict(x)
    }

    /// Predict OOB target values for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let forest_regressor = self
            .forest_regressor
            .as_ref()
            .ok_or_else(|| Failed::predict("ExtraTreesRegressor should be fitted first"))?;
        forest_regressor.predict_oob(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;

    #[test]
    fn test_extra_trees_regressor_fit_predict() {
        // Use a simple, predictable dataset for unit testing.
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
            &[13., 14.],
            &[15., 16.],
        ])
        .unwrap();
        let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());
        // A basic check that the model is learning something:
        // the error should be significantly less than the variance of y.
        let mse = mean_squared_error(&y, &y_hat);
        // With this simple dataset, the error should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_fit_predict_higher_dims() {
        // Dataset with 10 features, but y depends only on the 3rd feature (index 2).
        let x = DenseMatrix::from_2d_array(&[
            // The 3rd column is the important one. The rest are noise.
            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
        ])
        .unwrap();
        let y = vec![10., 20., 30., 40., 55., 65.];

        let parameters = ExtraTreesRegressorParameters::default()
            .with_n_trees(100)
            .with_seed(42);

        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
        let y_hat = regressor.predict(&x).unwrap();

        assert_eq!(y_hat.len(), y.len());

        let mse = mean_squared_error(&y, &y_hat);

        // The model should be able to learn this simple relationship,
        // ignoring the noise features. The MSE should be very low.
        assert!(mse < 1.0);
    }

    #[test]
    fn test_reproducibility() {
        let x = DenseMatrix::from_2d_array(&[
            &[1., 2.],
            &[3., 4.],
            &[5., 6.],
            &[7., 8.],
            &[9., 10.],
            &[11., 12.],
        ])
        .unwrap();
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

        let params = ExtraTreesRegressorParameters::default().with_seed(42);

        // Two models fitted with the same seed must produce identical predictions.
        let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
        let y_hat1 = regressor1.predict(&x).unwrap();

        let regressor2 = ExtraTreesRegressor::fit(&x, &y, params).unwrap();
        let y_hat2 = regressor2.predict(&x).unwrap();

        assert_eq!(y_hat1, y_hat2);
    }
}
diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs
index dc03096..4f5eefc 100644
--- a/src/ensemble/mod.rs
+++ b/src/ensemble/mod.rs
@@ -17,6 +17,7 @@
 //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
 
 mod base_forest_regressor;
+pub mod extra_trees_regressor;
 /// Random forest classifier
 pub mod random_forest_classifier;
 /// Random forest regressor
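
For reference, a minimal sketch of the OOB path added above (values are illustrative; per the doc comments in the patch, `keep_samples` must be enabled and `predict_oob` must be called with the training data):

```rust
use smartcore::ensemble::extra_trees_regressor::{
    ExtraTreesRegressor, ExtraTreesRegressorParameters,
};
use smartcore::linalg::basic::matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1., 2.],
        &[3., 4.],
        &[5., 6.],
        &[7., 8.],
    ])
    .unwrap();
    let y = vec![1., 2., 3., 4.];

    // keep_samples retains per-tree sample bookkeeping, which OOB prediction needs.
    let params = ExtraTreesRegressorParameters::default()
        .with_keep_samples(true)
        .with_seed(42);

    let regressor = ExtraTreesRegressor::fit(&x, &y, params).unwrap();

    // predict_oob expects the same matrix the model was trained on.
    match regressor.predict_oob(&x) {
        Ok(oob) => println!("OOB predictions: {oob:?}"),
        Err(e) => println!("OOB prediction failed: {e}"),
    }
}
```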