This commit is contained in:
Lorenzo Mec-iS
2025-01-20 17:18:09 +00:00
parent 63fa00334b
commit 52b797d520
+43 -15
View File
@@ -55,11 +55,11 @@ use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError}; use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::MutArray;
use crate::linalg::basic::arrays::{Array1, Array2}; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::basic::matrix::DenseMatrix;
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber; use crate::numbers::floatnum::FloatNumber;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::basic::arrays::MutArray;
use crate::rand_custom::get_rng_impl; use crate::rand_custom::get_rng_impl;
use crate::tree::decision_tree_classifier::{ use crate::tree::decision_tree_classifier::{
@@ -667,16 +667,15 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
Ok(probas) Ok(probas)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;
use crate::ensemble::random_forest_classifier::RandomForestClassifier; use crate::ensemble::random_forest_classifier::RandomForestClassifier;
use crate::linalg::basic::arrays::Array; use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;
#[test] #[test]
fn search_parameters() { fn search_parameters() {
@@ -846,7 +845,8 @@ mod tests {
&[6.9, 3.1, 4.9, 1.5], &[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3], &[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5], &[6.5, 2.8, 4.6, 1.5],
]).unwrap(); ])
.unwrap();
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
@@ -858,12 +858,21 @@ mod tests {
// Test probability sum // Test probability sum
for i in 0..10 { for i in 0..10 {
let row_sum: f64 = probas.get_row(i).sum(); let row_sum: f64 = probas.get_row(i).sum();
assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); assert!(
(row_sum - 1.0).abs() < 1e-6,
"Row probabilities should sum to 1"
);
} }
// Test class prediction // Test class prediction
let predictions: Vec<u32> = (0..10) let predictions: Vec<u32> = (0..10)
.map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 }) .map(|i| {
if probas.get((i, 0)) > probas.get((i, 1)) {
0
} else {
1
}
})
.collect(); .collect();
let acc = accuracy(&y, &predictions); let acc = accuracy(&y, &predictions);
assert!(acc > 0.8, "Accuracy should be high for the training set"); assert!(acc > 0.8, "Accuracy should be high for the training set");
@@ -871,23 +880,42 @@ mod tests {
// Test probability values // Test probability values
// These values are approximate and based on typical random forest behavior // These values are approximate and based on typical random forest behavior
for i in 0..5 { for i in 0..5 {
assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0"); assert!(
assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1"); *probas.get((i, 0)) > 0.6,
"Class 0 samples should have high probability for class 0"
);
assert!(
*probas.get((i, 1)) < 0.4,
"Class 0 samples should have low probability for class 1"
);
} }
for i in 5..10 { for i in 5..10 {
assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1"); assert!(
assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0"); *probas.get((i, 1)) > 0.6,
"Class 1 samples should have high probability for class 1"
);
assert!(
*probas.get((i, 0)) < 0.4,
"Class 1 samples should have low probability for class 0"
);
} }
// Test with new data // Test with new data
let x_new = DenseMatrix::from_2d_array(&[ let x_new = DenseMatrix::from_2d_array(&[
&[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0
&[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1
]).unwrap(); ])
.unwrap();
let probas_new = forest.predict_proba(&x_new).unwrap(); let probas_new = forest.predict_proba(&x_new).unwrap();
assert_eq!(probas_new.shape(), (2, 2)); assert_eq!(probas_new.shape(), (2, 2));
assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0"); assert!(
assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1"); probas_new.get((0, 0)) > probas_new.get((0, 1)),
"First sample should be predicted as class 0"
);
assert!(
probas_new.get((1, 1)) > probas_new.get((1, 0)),
"Second sample should be predicted as class 1"
);
} }
#[cfg_attr( #[cfg_attr(