try to fix test error

This commit is contained in:
Lorenzo Mec-iS
2025-01-20 20:12:41 +00:00
parent 0262dae872
commit d427c91cef
2 changed files with 18 additions and 4 deletions
+6 -4
View File
@@ -856,8 +856,10 @@ mod tests {
// Test shape
assert_eq!(probas.shape(), (10, 2));
let (pro_n_rows, _) = probas.shape();
// Test probability sum
for i in 0..10 {
for i in 0..pro_n_rows {
let row_sum: f64 = probas.get_row(i).sum();
assert!(
(row_sum - 1.0).abs() < 1e-6,
@@ -866,7 +868,7 @@ mod tests {
}
// Test class prediction
let predictions: Vec<u32> = (0..10)
let predictions: Vec<u32> = (0..pro_n_rows)
.map(|i| {
if probas.get((i, 0)) > probas.get((i, 1)) {
0
@@ -880,7 +882,7 @@ mod tests {
// Test probability values
// These values are approximate and based on typical random forest behavior
for i in 0..5 {
for i in 0..(pro_n_rows / 2) {
assert!(
*probas.get((i, 0)) > f64::from_f32(0.6).unwrap(),
"Class 0 samples should have high probability for class 0"
@@ -891,7 +893,7 @@ mod tests {
);
}
for i in 5..10 {
for i in (pro_n_rows / 2)..pro_n_rows {
assert!(
*probas.get((i, 1)) > f64::from_f32(0.6).unwrap(),
"Class 1 samples should have high probability for class 1"