16 Commits

Author SHA1 Message Date
Lorenzo
cfc953b25c Merge branch 'development' into prdct-prb 2024-04-08 08:56:24 +01:00
Lorenzo
5bf7102fc2 Merge branch 'development' into prdct-prb 2023-03-21 14:03:04 +09:00
Lorenzo
97604a2d83 Merge branch 'development' into prdct-prb 2023-01-27 10:42:48 +00:00
morenol
dae556776c Merge branch 'development' into prdct-prb 2023-01-26 20:14:05 -04:00
Lorenzo
24d80a0c9a Merge branch 'development' into prdct-prb 2022-11-09 16:31:03 +00:00
Lorenzo
c56370dfca Merge branch 'development' into prdct-prb 2022-11-03 11:59:46 +00:00
Lorenzo (Mec-iS)
78e53a28e7 apply fmt 2022-10-31 19:28:24 +00:00
Lorenzo (Mec-iS)
a9f89a2e15 Fix conflicts 2022-10-31 19:22:06 +00:00
Luis Moreno
e9ed9e85ae Merge remote-tracking branch 'sm/development' into predict-probability 2022-09-22 12:20:56 -04:00
Alan Race
28c81eb358 Test case now passing without transpose 2022-08-30 11:08:35 +02:00
Alan Race
7f7b2edca0 Fixed test by transposing matrix 2022-08-29 16:25:21 +02:00
Alan Race
d46b830bcd Merge branch 'development' into predict-probability 2022-08-29 16:24:39 +02:00
Alan Race
b6fb8191eb Merge pull request #1 from smartcorelib/alanrace-predict-probs
Add test to predict probabilities
2022-08-29 15:57:24 +02:00
Lorenzo (Mec-iS)
61db4ebd90 Add test 2022-08-24 12:34:56 +01:00
Lorenzo (Mec-iS)
2603a1f42b Add test 2022-08-24 11:44:30 +01:00
Alan Race
663db0334d Added per-class probability prediction for random forests 2022-07-11 16:08:03 +02:00
+97
View File
@@ -580,6 +580,37 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
which_max(&result)
}
/// Predict the per-class probabilties for each observation.
/// The probability is calculated as the fraction of trees that predicted a given class
pub fn predict_proba<R: Array2<f64>>(&self, x: &X) -> Result<R, Failed> {
let mut result: R = R::zeros(x.shape().0, self.classes.as_ref().unwrap().len());
let (n, _) = x.shape();
for i in 0..n {
let row_probs = self.predict_proba_for_row(x, i);
for (j, item) in row_probs.iter().enumerate() {
result.set((i, j), *item);
}
}
Ok(result)
}
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
for tree in self.trees.as_ref().unwrap().iter() {
result[tree.predict_for_row(x, row)] += 1;
}
result
.iter()
.map(|n| *n as f64 / self.trees.as_ref().unwrap().len() as f64)
.collect()
}
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
let class_weight = vec![1.; num_classes];
let nrows = y.len();
@@ -607,6 +638,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;
@@ -799,4 +831,69 @@ mod tests {
assert_eq!(forest, deserialized_forest);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100, // this is n_estimators in sklearn
m: Option::None,
keep_samples: false,
seed: 0,
},
)
.unwrap();
println!("{:?}", classifier.classes);
let results: DenseMatrix<f64> = classifier.predict_proba(&x).unwrap();
println!("{:?}", x.shape());
println!("{:?}", results);
println!("{:?}", results.shape());
assert_eq!(
results,
DenseMatrix::<f64>::new(
20,
2,
vec![
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98
],
true
)
);
}
}