Implement predict_proba for RandomForestClassifier
This commit is contained in:
@@ -58,6 +58,8 @@ use crate::error::{Failed, FailedError};
|
|||||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::numbers::floatnum::FloatNumber;
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use crate::linalg::basic::arrays::MutArray;
|
||||||
|
|
||||||
use crate::rand_custom::get_rng_impl;
|
use crate::rand_custom::get_rng_impl;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
@@ -602,6 +604,72 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
|
|||||||
}
|
}
|
||||||
samples
|
samples
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict class probabilities for X.
|
||||||
|
///
|
||||||
|
/// The predicted class probabilities of an input sample are computed as
|
||||||
|
/// the mean predicted class probabilities of the trees in the forest.
|
||||||
|
/// The class probability of a single tree is the fraction of samples of
|
||||||
|
/// the same class in a leaf.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `x` - The input samples. A matrix of shape (n_samples, n_features).
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `Result<DenseMatrix<f64>, Failed>` - The class probabilities of the input samples.
|
||||||
|
/// The order of the classes corresponds to that in the attribute `classes_`.
|
||||||
|
/// The matrix has shape (n_samples, n_classes).
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
///
|
||||||
|
/// Returns a `Failed` error if:
|
||||||
|
/// * The model has not been fitted yet.
|
||||||
|
/// * The input `x` is not compatible with the model's expected input.
|
||||||
|
/// * Any of the tree predictions fail.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
|
||||||
|
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
/// use smartcore::linalg::basic::arrays::Array;
|
||||||
|
///
|
||||||
|
/// let x = DenseMatrix::from_2d_array(&[
|
||||||
|
/// &[5.1, 3.5, 1.4, 0.2],
|
||||||
|
/// &[4.9, 3.0, 1.4, 0.2],
|
||||||
|
/// &[7.0, 3.2, 4.7, 1.4],
|
||||||
|
/// ]).unwrap();
|
||||||
|
/// let y = vec![0, 0, 1];
|
||||||
|
///
|
||||||
|
/// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
/// let probas = forest.predict_proba(&x).unwrap();
|
||||||
|
///
|
||||||
|
/// assert_eq!(probas.shape(), (3, 2));
|
||||||
|
/// ```
|
||||||
|
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
|
||||||
|
let (n_samples, _) = x.shape();
|
||||||
|
let n_classes = self.classes.as_ref().unwrap().len();
|
||||||
|
let mut probas = DenseMatrix::<f64>::zeros(n_samples, n_classes);
|
||||||
|
|
||||||
|
for tree in self.trees.as_ref().unwrap().iter() {
|
||||||
|
let tree_predictions: Y = tree.predict(x).unwrap();
|
||||||
|
|
||||||
|
let mut i = 0;
|
||||||
|
for &class_idx in tree_predictions.iterator(0) {
|
||||||
|
let class_ = class_idx.to_usize().unwrap();
|
||||||
|
probas.add_element_mut((i, class_), 1.0);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let n_trees = self.trees.as_ref().unwrap().len() as f64;
|
||||||
|
probas.mul_scalar_mut(1.0 / n_trees);
|
||||||
|
|
||||||
|
Ok(probas)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -609,6 +677,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
|
use crate::ensemble::random_forest_classifier::RandomForestClassifier;
|
||||||
|
use crate::linalg::basic::arrays::Array;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn search_parameters() {
|
fn search_parameters() {
|
||||||
@@ -760,6 +830,68 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_random_forest_predict_proba() {
    // Ten Iris-like rows: the first five are labelled 0, the last five are labelled 1.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
    ])
    .unwrap();
    let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
    let probas = forest.predict_proba(&x).unwrap();

    // One row per sample, one column per class.
    assert_eq!(probas.shape(), (10, 2));

    // Every row must be a probability distribution.
    for row in 0..10 {
        let total: f64 = probas.get_row(row).sum();
        assert!((total - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
    }

    // Picking the more probable class should recover the training labels.
    let mut predictions: Vec<u32> = Vec::with_capacity(10);
    for row in 0..10 {
        let p0 = probas.get((row, 0));
        let p1 = probas.get((row, 1));
        predictions.push(if p0 > p1 { 0 } else { 1 });
    }
    let acc = accuracy(&y, &predictions);
    assert!(acc > 0.8, "Accuracy should be high for the training set");

    // The thresholds below are approximate; they only require a clear majority
    // for the true class of each training row.
    for row in 0..10 {
        let p0 = *probas.get((row, 0));
        let p1 = *probas.get((row, 1));
        if row < 5 {
            assert!(p0 > 0.6, "Class 0 samples should have high probability for class 0");
            assert!(p1 < 0.4, "Class 0 samples should have low probability for class 1");
        } else {
            assert!(p1 > 0.6, "Class 1 samples should have high probability for class 1");
            assert!(p0 < 0.4, "Class 1 samples should have low probability for class 0");
        }
    }

    // Unseen samples: the first resembles the class 0 rows, the second the class 1 rows.
    let x_new = DenseMatrix::from_2d_array(&[
        &[5.0, 3.4, 1.5, 0.2],
        &[6.3, 3.3, 4.7, 1.6],
    ])
    .unwrap();
    let probas_new = forest.predict_proba(&x_new).unwrap();
    assert_eq!(probas_new.shape(), (2, 2));

    let first_is_class_0 = probas_new.get((0, 0)) > probas_new.get((0, 1));
    assert!(first_is_class_0, "First sample should be predicted as class 0");
    let second_is_class_1 = probas_new.get((1, 1)) > probas_new.get((1, 0));
    assert!(second_is_class_1, "Second sample should be predicted as class 1");
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
Reference in New Issue
Block a user