Add test
This commit is contained in:
@@ -317,7 +317,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class
|
||||
/// Predict the per-class probabilties for each observation.
|
||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||
|
||||
@@ -326,8 +327,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
for i in 0..n {
|
||||
let row_probs = self.predict_probs_for_row(x, i);
|
||||
|
||||
for j in 0..row_probs.len() {
|
||||
result.set(i, j, row_probs[j]);
|
||||
for (j, item) in row_probs.iter().enumerate() {
|
||||
result.set(i, j, *item);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -559,9 +560,25 @@ mod tests_prob {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let results = classifier.predict_probs(&x).unwrap();
|
||||
println!("{:?}", classifier.classes);
|
||||
|
||||
let results = classifier.predict_probs(&x).unwrap();
|
||||
println!("{:?}", x.shape());
|
||||
println!("{:?}", results);
|
||||
println!("{:?}", results.shape());
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
DenseMatrix::<f64>::from_array(
|
||||
20,
|
||||
2,
|
||||
&[
|
||||
1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
|
||||
0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
|
||||
0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
|
||||
]
|
||||
)
|
||||
);
|
||||
assert!(false);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user