Added per-class probability prediction for random forests
This commit is contained in:
@@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::linalg::{BaseMatrix, Matrix};
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||||
@@ -316,6 +317,36 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
which_max(&result)
|
which_max(&result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class
|
||||||
|
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||||
|
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||||
|
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
let row_probs = self.predict_probs_for_row(x, i);
|
||||||
|
|
||||||
|
for j in 0..row_probs.len() {
|
||||||
|
result.set(i, j, row_probs[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
|
||||||
|
let mut result = vec![0; self.classes.len()];
|
||||||
|
|
||||||
|
for tree in self.trees.iter() {
|
||||||
|
result[tree.predict_for_row(x, row)] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
.iter()
|
||||||
|
.map(|n| *n as f64 / self.trees.len() as f64)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
let class_weight = vec![1.; num_classes];
|
let class_weight = vec![1.; num_classes];
|
||||||
let nrows = y.len();
|
let nrows = y.len();
|
||||||
|
|||||||
Reference in New Issue
Block a user