add proper error handling
This commit is contained in:
@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the prediction process fails.
|
||||
/// Returns an error if at least one row prediction process fails.
|
||||
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
|
||||
let (n_samples, _) = x.shape();
|
||||
let n_classes = self.classes().len();
|
||||
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
|
||||
|
||||
for i in 0..n_samples {
|
||||
let probs = self.predict_proba_for_row(x, i);
|
||||
let probs = self.predict_proba_for_row(x, i)?;
|
||||
for (j, &prob) in probs.iter().enumerate() {
|
||||
result.set((i, j), prob);
|
||||
}
|
||||
@@ -930,7 +930,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
///
|
||||
/// A vector of probabilities, one for each class, representing the probability
|
||||
/// of the input sample belonging to each class.
|
||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
|
||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
|
||||
let mut node = 0;
|
||||
|
||||
while let Some(current_node) = self.nodes().get(node) {
|
||||
@@ -938,7 +938,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
// Leaf node reached
|
||||
let mut probs = vec![0.0; self.classes().len()];
|
||||
probs[current_node.output] = 1.0;
|
||||
return probs;
|
||||
return Ok(probs);
|
||||
}
|
||||
|
||||
let split_feature = current_node.split_feature;
|
||||
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
}
|
||||
|
||||
// This should never happen if the tree is properly constructed
|
||||
vec![0.0; self.classes().len()]
|
||||
Err(Failed::predict("Nodes iteration did not reach leaf"))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user