From bb356e6a289209ad586eaf7af20217c12125a2ae Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Mon, 20 Jan 2025 17:29:29 +0000 Subject: [PATCH] fix test --- src/ensemble/random_forest_classifier.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 19d75f3..f398d13 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -833,8 +833,9 @@ mod tests { )] #[test] fn test_random_forest_predict_proba() { + use num_traits::FromPrimitive; // Iris-like dataset (subset) - let x = DenseMatrix::from_2d_array(&[ + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], @@ -881,21 +882,22 @@ mod tests { // These values are approximate and based on typical random forest behavior for i in 0..5 { assert!( - *probas.get((i, 0)) > 0.6, + *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), "Class 0 samples should have high probability for class 0" ); assert!( - *probas.get((i, 1)) < 0.4, + *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(), "Class 0 samples should have low probability for class 1" ); } + for i in 5..10 { assert!( - *probas.get((i, 1)) > 0.6, + *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), "Class 1 samples should have high probability for class 1" ); assert!( - *probas.get((i, 0)) < 0.4, + *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(), "Class 1 samples should have low probability for class 0" ); }