From d427c91cef5bfdec9f0d629760e7da8b7703b810 Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Mon, 20 Jan 2025 20:12:41 +0000 Subject: [PATCH] try to fix test error --- .github/CONTRIBUTING.md | 12 ++++++++++++ src/ensemble/random_forest_classifier.rs | 10 ++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 895db0f..06d3e86 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 * **PRs on develop**: any change should be PRed first in `development` * **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment. + + +## Suggestions for debugging +1. Install `lldb` for your platform +2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line +3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber` +4. In lldb, run a single test with `r the_name_of_your_test` + +Display variables in scope: `frame variable ` + +Execute expression: `p ` + diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f398d13..6c0258e 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -856,8 +856,10 @@ mod tests { // Test shape assert_eq!(probas.shape(), (10, 2)); + let (pro_n_rows, _) = probas.shape(); + // Test probability sum - for i in 0..10 { + for i in 0..pro_n_rows { let row_sum: f64 = probas.get_row(i).sum(); assert!( (row_sum - 1.0).abs() < 1e-6, @@ -866,7 +868,7 @@ mod tests { } // Test class prediction - let predictions: Vec = (0..10) + let predictions: Vec = (0..pro_n_rows) .map(|i| { if probas.get((i, 0)) > probas.get((i, 1)) { 0 @@ -880,7 +882,7 @@ mod tests { // Test probability values // These values are approximate and based on typical random forest behavior - for i in 0..5 { + for i in 0..(pro_n_rows / 2) { assert!( *probas.get((i, 0)) > f64::from_f32(0.6).unwrap(), "Class 0 samples should have high probability for class 0" @@ -891,7 +893,7 @@ mod tests { ); } - for i in 5..10 { + for i in (pro_n_rows / 2)..pro_n_rows { assert!( *probas.get((i, 1)) > f64::from_f32(0.6).unwrap(), "Class 1 samples should have high probability for class 1"