From 40ee35b04fe67776193934d9853fdcba3fe46e7b Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Mon, 20 Jan 2025 17:15:52 +0000 Subject: [PATCH] Implement predict_proba for RandomForestClassifier --- src/ensemble/random_forest_classifier.rs | 132 +++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dabb248..7f15be0 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -58,6 +58,8 @@ use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; +use crate::linalg::basic::matrix::DenseMatrix; +use crate::linalg::basic::arrays::MutArray; use crate::rand_custom::get_rng_impl; use crate::tree::decision_tree_classifier::{ @@ -602,6 +604,72 @@ impl, Y: Array1, Failed>` - The class probabilities of the input samples. + /// The order of the classes corresponds to that in the attribute `classes_`. + /// The matrix has shape (n_samples, n_classes). + /// + /// # Errors + /// + /// Returns a `Failed` error if: + /// * The model has not been fitted yet. + /// * The input `x` is not compatible with the model's expected input. + /// * Any of the tree predictions fail. + /// + /// # Examples + /// + /// ``` + /// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier; + /// use smartcore::linalg::basic::matrix::DenseMatrix; + /// use smartcore::linalg::basic::arrays::Array; + /// + /// let x = DenseMatrix::from_2d_array(&[ + /// &[5.1, 3.5, 1.4, 0.2], + /// &[4.9, 3.0, 1.4, 0.2], + /// &[7.0, 3.2, 4.7, 1.4], + /// ]).unwrap(); + /// let y = vec![0, 0, 1]; + /// + /// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + /// let probas = forest.predict_proba(&x).unwrap(); + /// + /// assert_eq!(probas.shape(), (3, 2)); + /// ``` + pub fn predict_proba(&self, x: &X) -> Result, Failed> { + let (n_samples, _) = x.shape(); + let n_classes = self.classes.as_ref().unwrap().len(); + let mut probas = DenseMatrix::::zeros(n_samples, n_classes); + + for tree in self.trees.as_ref().unwrap().iter() { + let tree_predictions: Y = tree.predict(x).unwrap(); + + let mut i = 0; + for &class_idx in tree_predictions.iterator(0) { + let class_ = class_idx.to_usize().unwrap(); + probas.add_element_mut((i, class_), 1.0); + i += 1; + } + } + + let n_trees = self.trees.as_ref().unwrap().len() as f64; + probas.mul_scalar_mut(1.0 / n_trees); + + Ok(probas) + } + } #[cfg(test)] @@ -609,6 +677,8 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::*; + use crate::ensemble::random_forest_classifier::RandomForestClassifier; + use crate::linalg::basic::arrays::Array; #[test] fn search_parameters() { @@ -760,6 +830,68 @@ mod tests { ); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn test_random_forest_predict_proba() { + // Iris-like dataset (subset) + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]).unwrap(); + let y: Vec = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; + + let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + let probas = forest.predict_proba(&x).unwrap(); + + // Test shape + assert_eq!(probas.shape(), (10, 2)); + + // Test probability sum + for i in 0..10 { + let row_sum: f64 = probas.get_row(i).sum(); + assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + } + + // Test class prediction + let predictions: Vec = (0..10) + .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 }) + .collect(); + let acc = accuracy(&y, &predictions); + assert!(acc > 0.8, "Accuracy should be high for the training set"); + + // Test probability values + // These values are approximate and based on typical random forest behavior + for i in 0..5 { + assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0"); + assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1"); + } + for i in 5..10 { + assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1"); + assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0"); + } + + // Test with new data + let x_new = DenseMatrix::from_2d_array(&[ + &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 + &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 + ]).unwrap(); + let probas_new = forest.predict_proba(&x_new).unwrap(); + assert_eq!(probas_new.shape(), (2, 2)); + assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0"); + assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1"); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test