From 61db4ebd90be072fbe9a81ef998a99d6a47d8519 Mon Sep 17 00:00:00 2001 From: "Lorenzo (Mec-iS)" Date: Wed, 24 Aug 2022 12:34:56 +0100 Subject: [PATCH] Add test --- src/ensemble/random_forest_classifier.rs | 25 ++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dcfe41a..baf6901 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -317,7 +317,8 @@ impl RandomForestClassifier { which_max(&result) } - /// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class + /// Predict the per-class probabilties for each observation. + /// The probability is calculated as the fraction of trees that predicted a given class pub fn predict_probs>(&self, x: &M) -> Result, Failed> { let mut result = DenseMatrix::::zeros(x.shape().0, self.classes.len()); @@ -326,8 +327,8 @@ impl RandomForestClassifier { for i in 0..n { let row_probs = self.predict_probs_for_row(x, i); - for j in 0..row_probs.len() { - result.set(i, j, row_probs[j]); + for (j, item) in row_probs.iter().enumerate() { + result.set(i, j, *item); } } @@ -559,9 +560,25 @@ mod tests_prob { ) .unwrap(); - let results = classifier.predict_probs(&x).unwrap(); + println!("{:?}", classifier.classes); + let results = classifier.predict_probs(&x).unwrap(); + println!("{:?}", x.shape()); println!("{:?}", results); + println!("{:?}", results.shape()); + + assert_eq!( + results, + DenseMatrix::::from_array( + 20, + 2, + &[ + 1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0, + 0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04, + 0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98 + ] + ) + ); assert!(false); } }