diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 19d75f3..f398d13 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -833,8 +833,9 @@ mod tests { )] #[test] fn test_random_forest_predict_proba() { + use num_traits::FromPrimitive; // Iris-like dataset (subset) - let x = DenseMatrix::from_2d_array(&[ + let x: DenseMatrix = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], @@ -881,21 +882,22 @@ mod tests { // These values are approximate and based on typical random forest behavior for i in 0..5 { assert!( - *probas.get((i, 0)) > 0.6, + *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), "Class 0 samples should have high probability for class 0" ); assert!( - *probas.get((i, 1)) < 0.4, + *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(), "Class 0 samples should have low probability for class 1" ); } + for i in 5..10 { assert!( - *probas.get((i, 1)) > 0.6, + *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), "Class 1 samples should have high probability for class 1" ); assert!( - *probas.get((i, 0)) < 0.4, + *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(), "Class 1 samples should have low probability for class 0" ); }