From 663db0334d8b41895df0e955bb8e429b8044a327 Mon Sep 17 00:00:00 2001 From: Alan Race Date: Mon, 11 Jul 2022 16:08:03 +0200 Subject: [PATCH] Added per-class probability prediction for random forests --- src/ensemble/random_forest_classifier.rs | 33 +++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 247b502..87062f2 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; -use crate::linalg::Matrix; +use crate::linalg::naive::dense_matrix::DenseMatrix; +use crate::linalg::{BaseMatrix, Matrix}; use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::{ which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, @@ -316,6 +317,36 @@ impl RandomForestClassifier { which_max(&result) } + /// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class + pub fn predict_probs>(&self, x: &M) -> Result, Failed> { + let mut result = DenseMatrix::::zeros(x.shape().0, self.classes.len()); + + let (n, _) = x.shape(); + + for i in 0..n { + let row_probs = self.predict_probs_for_row(x, i); + + for j in 0..row_probs.len() { + result.set(i, j, row_probs[j]); + } + } + + Ok(result) + } + + fn predict_probs_for_row>(&self, x: &M, row: usize) -> Vec { + let mut result = vec![0; self.classes.len()]; + + for tree in self.trees.iter() { + result[tree.predict_for_row(x, row)] += 1; + } + + result + .iter() + .map(|n| *n as f64 / self.trees.len() as f64) + .collect() + } + fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec { let class_weight = vec![1.; num_classes]; let nrows = y.len();