fix test conditions

This commit is contained in:
Lorenzo Mec-iS
2025-01-22 12:08:11 +00:00
parent 4878042392
commit 4aee603ae4
+5 -5
View File
@@ -662,7 +662,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
}
let n_trees = self.trees.as_ref().unwrap().len() as f64;
let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64;
probas.mul_scalar_mut(1.0 / n_trees);
Ok(probas)
@@ -884,22 +884,22 @@ mod tests {
// These values are approximate and based on typical random forest behavior
for i in 0..(pro_n_rows / 2) {
assert!(
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
f64::from_f32(0.6).unwrap().lt(probas.get((i, 0))),
"Class 0 samples should have high probability for class 0"
);
assert!(
*probas.get((i, 1)) < f64::from_f32(0.4).unwrap(),
f64::from_f32(0.4).unwrap().gt(probas.get((i, 1))),
"Class 0 samples should have low probability for class 1"
);
}
for i in (pro_n_rows / 2)..pro_n_rows {
assert!(
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
f64::from_f32(0.6).unwrap().lt(probas.get((i, 1))),
"Class 1 samples should have high probability for class 1"
);
assert!(
*probas.get((i, 0)) < f64::from_f32(0.4).unwrap(),
f64::from_f32(0.4).unwrap().gt(probas.get((i, 0))),
"Class 1 samples should have low probability for class 0"
);
}