From 4bae62ab2f7776f56363c472a4257ee0c069fee7 Mon Sep 17 00:00:00 2001 From: Malte Londschien Date: Thu, 14 Oct 2021 09:47:00 +0200 Subject: [PATCH] Test. --- src/ensemble/random_forest_classifier.rs | 51 +++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index b3c810a..f70604c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -282,7 +282,7 @@ impl RandomForestClassifier { "Prediction matrix must match matrix used in training for OOB predictions.", )) } else { - let mut result = M::zeros(self.classes.len(), 1); + let mut result = M::zeros(1, n); for i in 0..n { result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]); @@ -382,6 +382,55 @@ mod tests { assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fit_predict_iris_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![ + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, + m: Option::None, + keep_samples: true, + }, + ) + .unwrap(); + assert!( + accuracy(&y, &classifier.predict_oob(&x).unwrap()) + < accuracy(&y, &classifier.predict(&x).unwrap()) + ); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] #[cfg(feature = "serde")]