Merge branch 'development' into predict-probability

This commit is contained in:
Alan Race
2022-08-29 16:24:39 +02:00
+72 -4
View File
@@ -317,7 +317,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result) which_max(&result)
} }
/// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class /// Predict the per-class probabilties for each observation.
/// The probability is calculated as the fraction of trees that predicted a given class
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> { pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len()); let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
@@ -326,8 +327,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
for i in 0..n { for i in 0..n {
let row_probs = self.predict_probs_for_row(x, i); let row_probs = self.predict_probs_for_row(x, i);
for j in 0..row_probs.len() { for (j, item) in row_probs.iter().enumerate() {
result.set(i, j, row_probs[j]); result.set(i, j, *item);
} }
} }
@@ -372,7 +373,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests_prob {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::*; use crate::metrics::*;
@@ -513,4 +514,71 @@ mod tests {
assert_eq!(forest, deserialized_forest); assert_eq!(forest, deserialized_forest);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
)
.unwrap();
println!("{:?}", classifier.classes);
let results = classifier.predict_probs(&x).unwrap();
println!("{:?}", x.shape());
println!("{:?}", results);
println!("{:?}", results.shape());
assert_eq!(
results,
DenseMatrix::<f64>::from_array(
20,
2,
&[
1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
]
)
);
assert!(false);
}
} }