Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cfc953b25c | ||
|
|
5bf7102fc2 | ||
|
|
97604a2d83 | ||
|
|
dae556776c | ||
|
|
24d80a0c9a | ||
|
|
c56370dfca | ||
|
|
78e53a28e7 | ||
|
|
a9f89a2e15 | ||
|
|
e9ed9e85ae | ||
|
|
28c81eb358 | ||
|
|
7f7b2edca0 | ||
|
|
d46b830bcd | ||
|
|
b6fb8191eb | ||
|
|
61db4ebd90 | ||
|
|
2603a1f42b | ||
|
|
663db0334d |
@@ -580,6 +580,37 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict the per-class probabilties for each observation.
|
||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||
pub fn predict_proba<R: Array2<f64>>(&self, x: &X) -> Result<R, Failed> {
|
||||
let mut result: R = R::zeros(x.shape().0, self.classes.as_ref().unwrap().len());
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
let row_probs = self.predict_proba_for_row(x, i);
|
||||
|
||||
for (j, item) in row_probs.iter().enumerate() {
|
||||
result.set((i, j), *item);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
|
||||
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
|
||||
|
||||
for tree in self.trees.as_ref().unwrap().iter() {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
result
|
||||
.iter()
|
||||
.map(|n| *n as f64 / self.trees.as_ref().unwrap().len() as f64)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
@@ -607,6 +638,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
|
||||
@@ -799,4 +831,69 @@ mod tests {
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_probabilities() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100, // this is n_estimators in sklearn
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", classifier.classes);
|
||||
|
||||
let results: DenseMatrix<f64> = classifier.predict_proba(&x).unwrap();
|
||||
println!("{:?}", x.shape());
|
||||
println!("{:?}", results);
|
||||
println!("{:?}", results.shape());
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
DenseMatrix::<f64>::new(
|
||||
20,
|
||||
2,
|
||||
vec![
|
||||
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
|
||||
0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0,
|
||||
0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98
|
||||
],
|
||||
true
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user