From 68fd27f8f4b9933986ce90eec789e1f487e9c7ca Mon Sep 17 00:00:00 2001
From: Lorenzo Mec-iS
Date: Mon, 20 Jan 2025 14:59:50 +0000
Subject: [PATCH] Implement predict_proba for DecisionTreeClassifier

---
 src/tree/decision_tree_classifier.rs | 121 +++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)

diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index c659651..712cd87 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -78,6 +78,8 @@ use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
 use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
+use crate::linalg::basic::arrays::MutArray;
+use crate::linalg::basic::matrix::DenseMatrix;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
@@ -887,12 +889,87 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }
         importances
     }
+
+    /// Predict class probabilities for the input samples.
+    ///
+    /// Each sample is routed down the tree; the leaf it lands in assigns the
+    /// whole probability mass to that leaf's output class, so every returned
+    /// row is a one-hot distribution over `self.classes()`.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input samples as a matrix where each row is a sample and
+    ///   each column is a feature.
+    ///
+    /// # Returns
+    ///
+    /// A `DenseMatrix` with one row per sample and one column per class; entry
+    /// `(i, j)` is the probability that sample `i` belongs to class `j`.
+    ///
+    /// # Errors
+    ///
+    /// This implementation currently never returns `Err`; the `Result` keeps
+    /// the signature consistent with the crate's other prediction APIs.
+    pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
+        let (n_samples, _) = x.shape();
+        let n_classes = self.classes().len();
+        let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
+
+        for i in 0..n_samples {
+            let probs = self.predict_proba_for_row(x, i);
+            for (j, &prob) in probs.iter().enumerate() {
+                result.set((i, j), prob);
+            }
+        }
+
+        Ok(result)
+    }
+
+    /// Predict class probabilities for a single input sample.
+    ///
+    /// # Arguments
+    ///
+    /// * `x` - The input matrix containing all samples.
+    /// * `row` - The index of the row in `x` for which to predict probabilities.
+    ///
+    /// # Returns
+    ///
+    /// A vector with one probability per class: 1.0 for the leaf's output
+    /// class and 0.0 for every other class.
+    fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
+        let mut node = 0;
+
+        while let Some(current_node) = self.nodes().get(node) {
+            if current_node.true_child.is_none() && current_node.false_child.is_none() {
+                // Leaf reached: all probability mass on the leaf's output class.
+                let mut probs = vec![0.0; self.classes().len()];
+                probs[current_node.output] = 1.0;
+                return probs;
+            }
+
+            let split_feature = current_node.split_feature;
+            // A missing split value becomes NaN; `value <= NaN` is false, so
+            // such samples are routed to the false child.
+            let split_value = current_node.split_value.unwrap_or(f64::NAN);
+
+            node = if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
+                current_node.true_child.unwrap()
+            } else {
+                current_node.false_child.unwrap()
+            };
+        }
+
+        // Unreachable for a well-formed tree (every root-to-leaf path ends at
+        // a leaf); fall back to an all-zero vector rather than panicking.
+        vec![0.0; self.classes().len()]
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::linalg::basic::matrix::DenseMatrix;
+    use crate::linalg::basic::arrays::Array;
 
     #[test]
     fn search_parameters() {
@@ -934,6 +1011,50 @@ mod tests {
         );
     }
 
+    #[cfg_attr(
+        all(target_arch = "wasm32", not(target_os = "wasi")),
+        wasm_bindgen_test::wasm_bindgen_test
+    )]
+    #[test]
+    fn test_predict_proba() {
+        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+        ])
+        .unwrap();
+        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
+
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
+        let probabilities = tree.predict_proba(&x).unwrap();
+
+        assert_eq!(probabilities.shape(), (10, 2));
+
+        // Every row must be a probability distribution: entries sum to 1.
+        for row in 0..10 {
+            let row_sum: f64 = probabilities.get_row(row).sum();
+            assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1");
+        }
+
+        // The first 5 samples are labelled 0 and were seen during fit, so
+        // class 0 should dominate.
+        for i in 0..5 {
+            assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
+        }
+
+        // Likewise the last 5 samples should favour class 1.
+        for i in 5..10 {
+            assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
+        }
+    }
+
 #[cfg_attr(
     all(target_arch = "wasm32", not(target_os = "wasi")),
     wasm_bindgen_test::wasm_bindgen_test