This commit is contained in:
Lorenzo Mec-iS
2025-01-20 17:29:29 +00:00
parent 52b797d520
commit bb356e6a28
+7 -5
View File
@@ -833,8 +833,9 @@ mod tests {
)]
#[test]
fn test_random_forest_predict_proba() {
use num_traits::FromPrimitive;
// Iris-like dataset (subset)
let x = DenseMatrix::from_2d_array(&[
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
@@ -881,21 +882,22 @@ mod tests {
// These values are approximate and based on typical random forest behavior
for i in 0..5 {
assert!(
*probas.get((i, 0)) > 0.6,
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
"Class 0 samples should have high probability for class 0"
);
assert!(
*probas.get((i, 1)) < 0.4,
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
"Class 0 samples should have low probability for class 1"
);
}
for i in 5..10 {
assert!(
*probas.get((i, 1)) > 0.6,
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
"Class 1 samples should have high probability for class 1"
);
assert!(
*probas.get((i, 0)) < 0.4,
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
"Class 1 samples should have low probability for class 0"
);
}