Merge pull request #1 from smartcorelib/alanrace-predict-probs
Add test to predict probabilities
This commit is contained in:
@@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::linalg::{BaseMatrix, Matrix};
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||||
@@ -316,6 +317,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
which_max(&result)
|
which_max(&result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict the per-class probabilties for each observation.
|
||||||
|
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||||
|
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||||
|
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||||
|
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
let row_probs = self.predict_probs_for_row(x, i);
|
||||||
|
|
||||||
|
for (j, item) in row_probs.iter().enumerate() {
|
||||||
|
result.set(i, j, *item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the class probabilities for a single row of `x`.
///
/// Every tree casts one vote for the class index it predicts for `row`; the
/// tally for each class is then divided by the number of trees.
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
    // votes[c] counts the trees that predicted class index `c`.
    let mut votes = vec![0_i32; self.classes.len()];
    self.trees
        .iter()
        .for_each(|tree| votes[tree.predict_for_row(x, row)] += 1);

    let n_trees = self.trees.len() as f64;
    votes
        .into_iter()
        .map(|count| f64::from(count) / n_trees)
        .collect()
}
|
||||||
|
|
||||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
let class_weight = vec![1.; num_classes];
|
let class_weight = vec![1.; num_classes];
|
||||||
let nrows = y.len();
|
let nrows = y.len();
|
||||||
@@ -341,7 +373,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests_prob {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
@@ -482,4 +514,71 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(forest, deserialized_forest);
|
assert_eq!(forest, deserialized_forest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fits a 100-tree forest on a 20-row, 4-feature data set with two classes
/// and checks the matrix returned by `predict_probs` against fixed values
/// recorded for seed 87.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
    let x = DenseMatrix::<f64>::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ]);
    // Eight observations labelled 0 followed by twelve labelled 1.
    let y = vec![
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];

    let classifier = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 100,
            m: Option::None,
            keep_samples: false,
            // Fixed seed: the expected probabilities below were recorded for it.
            seed: 87,
        },
    )
    .unwrap();

    let results = classifier.predict_probs(&x).unwrap();

    // One row per observation, one column per class.
    // NOTE(review): the 40 expected values look column-ordered (first 20 for
    // class 0, last 20 for class 1) - confirm against `DenseMatrix::from_array`.
    assert_eq!(
        results,
        DenseMatrix::<f64>::from_array(
            20,
            2,
            &[
                1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
                0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
                0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
            ]
        )
    );
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user