fix test
This commit is contained in:
@@ -833,8 +833,9 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn test_random_forest_predict_proba() {
|
||||
use num_traits::FromPrimitive;
|
||||
// Iris-like dataset (subset)
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
@@ -881,21 +882,22 @@ mod tests {
|
||||
// These values are approximate and based on typical random forest behavior
|
||||
for i in 0..5 {
|
||||
assert!(
|
||||
*probas.get((i, 0)) > 0.6,
|
||||
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
|
||||
"Class 0 samples should have high probability for class 0"
|
||||
);
|
||||
assert!(
|
||||
*probas.get((i, 1)) < 0.4,
|
||||
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
|
||||
"Class 0 samples should have low probability for class 1"
|
||||
);
|
||||
}
|
||||
|
||||
for i in 5..10 {
|
||||
assert!(
|
||||
*probas.get((i, 1)) > 0.6,
|
||||
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
|
||||
"Class 1 samples should have high probability for class 1"
|
||||
);
|
||||
assert!(
|
||||
*probas.get((i, 0)) < 0.4,
|
||||
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
|
||||
"Class 1 samples should have low probability for class 0"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user