fix test
This commit is contained in:
@@ -833,8 +833,9 @@ mod tests {
|
|||||||
)]
|
)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_random_forest_predict_proba() {
|
fn test_random_forest_predict_proba() {
|
||||||
|
use num_traits::FromPrimitive;
|
||||||
// Iris-like dataset (subset)
|
// Iris-like dataset (subset)
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
@@ -881,21 +882,22 @@ mod tests {
|
|||||||
// These values are approximate and based on typical random forest behavior
|
// These values are approximate and based on typical random forest behavior
|
||||||
for i in 0..5 {
|
for i in 0..5 {
|
||||||
assert!(
|
assert!(
|
||||||
*probas.get((i, 0)) > 0.6,
|
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
|
||||||
"Class 0 samples should have high probability for class 0"
|
"Class 0 samples should have high probability for class 0"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
*probas.get((i, 1)) < 0.4,
|
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
|
||||||
"Class 0 samples should have low probability for class 1"
|
"Class 0 samples should have low probability for class 1"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
for i in 5..10 {
|
for i in 5..10 {
|
||||||
assert!(
|
assert!(
|
||||||
*probas.get((i, 1)) > 0.6,
|
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
|
||||||
"Class 1 samples should have high probability for class 1"
|
"Class 1 samples should have high probability for class 1"
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
*probas.get((i, 0)) < 0.4,
|
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
|
||||||
"Class 1 samples should have low probability for class 0"
|
"Class 1 samples should have low probability for class 0"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user