diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 6c0258e..b302ef4 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -662,7 +662,7 @@ impl, Y: Array1 f64::from_f32(0.6).unwrap(), + f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))), "Class 0 samples should have high probability for class 0" ); assert!( - *probas.get((i, 1)) < f64::from_f32(0.4).unwrap(), + f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))), "Class 0 samples should have low probability for class 1" ); } for i in (pro_n_rows / 2)..pro_n_rows { assert!( - *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), + f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))), "Class 1 samples should have high probability for class 1" ); assert!( - *probas.get((i, 0)) < f64::from_f32(0.4).unwrap(), + f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))), "Class 1 samples should have low probability for class 0" ); }