15 Commits

Author SHA1 Message Date
Lorenzo
78780787db Update ci.yml 2025-01-22 12:12:07 +00:00
Lorenzo Mec-iS
4aee603ae4 fix test conditions 2025-01-22 12:08:11 +00:00
Lorenzo Mec-iS
4878042392 Merge branch 'issue-50-predict-proba-for-randomforest' of github.com:smartcorelib/smartcore into issue-50-predict-proba-for-randomforest 2025-01-20 20:13:04 +00:00
Lorenzo Mec-iS
d427c91cef try to fix test error 2025-01-20 20:12:41 +00:00
Lorenzo Mec-iS
0262dae872 Merge branch 'development' of github.com:smartcorelib/smartcore into issue-50-predict-proba-for-randomforest 2025-01-20 18:51:36 +00:00
Lorenzo
5d6ed49071 Merge branch 'development' into issue-50-predict-proba-for-randomforest 2025-01-20 18:51:06 +00:00
Lorenzo Mec-iS
bb356e6a28 fix test 2025-01-20 17:29:29 +00:00
Lorenzo Mec-iS
52b797d520 format 2025-01-20 17:18:09 +00:00
Lorenzo Mec-iS
63fa00334b Fix clippy error 2025-01-20 17:17:41 +00:00
Lorenzo Mec-iS
40ee35b04f Implement predict_proba for RandomForestClassifier 2025-01-20 17:15:52 +00:00
Lorenzo Mec-iS
5711788fd8 add proper error handling 2025-01-20 16:08:29 +00:00
Lorenzo Mec-iS
fc7f2e61d9 format 2025-01-20 15:27:39 +00:00
Lorenzo Mec-iS
609f8024bc more clippy fixes 2025-01-20 15:23:36 +00:00
Lorenzo Mec-iS
58ee0cb8d1 Some automated fixes suggested by cargo clippy --fix 2025-01-20 15:04:21 +00:00
Lorenzo Mec-iS
68fd27f8f4 Implement predict_proba for DecisionTreeClassifier 2025-01-20 14:59:50 +00:00
3 changed files with 175 additions and 0 deletions
+12
View File
@@ -70,3 +70,15 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
* **PRs on develop**: any change should be PRed first in `development`
* **testing**: everything should work and be tested as defined in the workflow. If any test is failing for unrelated reasons, annotate the failure in the PR comment.
## Suggestions for debugging
1. Install `lldb` for your platform
2. Run `rust-lldb target/debug/libsmartcore.rlib` in your command-line
3. In lldb, set up some breakpoints using `b func_name` or `b src/path/to/file.rs:linenumber`
4. In lldb, run a single test with `r the_name_of_your_test`
Display variables in scope: `frame variable <name>`
Execute expression: `p <expr>`
+1
View File
@@ -23,6 +23,7 @@ jobs:
]
env:
TZ: "/usr/share/zoneinfo/your/location"
RUST_BACKTRACE: "1"
steps:
- uses: actions/checkout@v3
- name: Cache .cargo and target
+162
View File
@@ -55,7 +55,9 @@ use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::MutArray;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::basic::matrix::DenseMatrix;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
@@ -602,11 +604,76 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
samples
}
/// Predict class probabilities for X.
///
/// The predicted class probabilities of an input sample are computed as
/// the mean predicted class probabilities of the trees in the forest.
/// The class probability of a single tree is the fraction of samples of
/// the same class in a leaf.
///
/// # Arguments
///
/// * `x` - The input samples. A matrix of shape (n_samples, n_features).
///
/// # Returns
///
/// * `Result<DenseMatrix<f64>, Failed>` - The class probabilities of the input samples.
/// The order of the classes corresponds to that in the attribute `classes_`.
/// The matrix has shape (n_samples, n_classes).
///
/// # Errors
///
/// Returns a `Failed` error if:
/// * The model has not been fitted yet.
/// * The input `x` is not compatible with the model's expected input.
/// * Any of the tree predictions fail.
///
/// # Examples
///
/// ```
/// use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::basic::arrays::Array;
///
/// let x = DenseMatrix::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
/// &[4.9, 3.0, 1.4, 0.2],
/// &[7.0, 3.2, 4.7, 1.4],
/// ]).unwrap();
/// let y = vec![0, 0, 1];
///
/// let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
/// let probas = forest.predict_proba(&x).unwrap();
///
/// assert_eq!(probas.shape(), (3, 2));
/// ```
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
let (n_samples, _) = x.shape();
let n_classes = self.classes.as_ref().unwrap().len();
let mut probas = DenseMatrix::<f64>::zeros(n_samples, n_classes);
for tree in self.trees.as_ref().unwrap().iter() {
let tree_predictions: Y = tree.predict(x).unwrap();
for (i, &class_idx) in tree_predictions.iterator(0).enumerate() {
let class_ = class_idx.to_usize().unwrap();
probas.add_element_mut((i, class_), 1.0);
}
}
let n_trees: f64 = self.trees.as_ref().unwrap().len() as f64;
probas.mul_scalar_mut(1.0 / n_trees);
Ok(probas)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ensemble::random_forest_classifier::RandomForestClassifier;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;
@@ -760,6 +827,101 @@ mod tests {
);
}
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_random_forest_predict_proba() {
    // Iris-like dataset (subset): rows 0..5 are class 0, rows 5..10 are class 1.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
    ])
    .unwrap();
    let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
    let probas = forest.predict_proba(&x).unwrap();

    // Shape: one row per sample, one column per class.
    assert_eq!(probas.shape(), (10, 2));
    let (pro_n_rows, _) = probas.shape();

    // Each row is a probability distribution and must sum to 1.
    for i in 0..pro_n_rows {
        let row_sum: f64 = probas.get_row(i).sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-6,
            "Row probabilities should sum to 1"
        );
    }

    // The argmax of each row should recover the training labels with high accuracy.
    let predictions: Vec<u32> = (0..pro_n_rows)
        .map(|i| {
            if probas.get((i, 0)) > probas.get((i, 1)) {
                0
            } else {
                1
            }
        })
        .collect();
    let acc = accuracy(&y, &predictions);
    assert!(acc > 0.8, "Accuracy should be high for the training set");

    // Probability magnitudes: approximate bounds based on typical random
    // forest behavior on this well-separated training set. Compare f64
    // values directly instead of going through num_traits conversions.
    for i in 0..(pro_n_rows / 2) {
        assert!(
            *probas.get((i, 0)) > 0.6,
            "Class 0 samples should have high probability for class 0"
        );
        assert!(
            *probas.get((i, 1)) < 0.4,
            "Class 0 samples should have low probability for class 1"
        );
    }
    for i in (pro_n_rows / 2)..pro_n_rows {
        assert!(
            *probas.get((i, 1)) > 0.6,
            "Class 1 samples should have high probability for class 1"
        );
        assert!(
            *probas.get((i, 0)) < 0.4,
            "Class 1 samples should have low probability for class 0"
        );
    }

    // Unseen samples should still be assigned to the expected class.
    let x_new = DenseMatrix::from_2d_array(&[
        &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0
        &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1
    ])
    .unwrap();
    let probas_new = forest.predict_proba(&x_new).unwrap();
    assert_eq!(probas_new.shape(), (2, 2));
    assert!(
        probas_new.get((0, 0)) > probas_new.get((0, 1)),
        "First sample should be predicted as class 0"
    );
    assert!(
        probas_new.get((1, 1)) > probas_new.get((1, 0)),
        "Second sample should be predicted as class 1"
    );
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test