add proper error handling
This commit is contained in:
@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
///
|
///
|
||||||
/// Returns an error if the prediction process fails.
|
/// Returns an error if at least one row prediction process fails.
|
||||||
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
|
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
|
||||||
let (n_samples, _) = x.shape();
|
let (n_samples, _) = x.shape();
|
||||||
let n_classes = self.classes().len();
|
let n_classes = self.classes().len();
|
||||||
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
|
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
|
||||||
|
|
||||||
for i in 0..n_samples {
|
for i in 0..n_samples {
|
||||||
let probs = self.predict_proba_for_row(x, i);
|
let probs = self.predict_proba_for_row(x, i)?;
|
||||||
for (j, &prob) in probs.iter().enumerate() {
|
for (j, &prob) in probs.iter().enumerate() {
|
||||||
result.set((i, j), prob);
|
result.set((i, j), prob);
|
||||||
}
|
}
|
||||||
@@ -930,7 +930,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
///
|
///
|
||||||
/// A vector of probabilities, one for each class, representing the probability
|
/// A vector of probabilities, one for each class, representing the probability
|
||||||
/// of the input sample belonging to each class.
|
/// of the input sample belonging to each class.
|
||||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
|
fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
|
||||||
let mut node = 0;
|
let mut node = 0;
|
||||||
|
|
||||||
while let Some(current_node) = self.nodes().get(node) {
|
while let Some(current_node) = self.nodes().get(node) {
|
||||||
@@ -938,7 +938,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
// Leaf node reached
|
// Leaf node reached
|
||||||
let mut probs = vec![0.0; self.classes().len()];
|
let mut probs = vec![0.0; self.classes().len()];
|
||||||
probs[current_node.output] = 1.0;
|
probs[current_node.output] = 1.0;
|
||||||
return probs;
|
return Ok(probs);
|
||||||
}
|
}
|
||||||
|
|
||||||
let split_feature = current_node.split_feature;
|
let split_feature = current_node.split_feature;
|
||||||
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This should never happen if the tree is properly constructed
|
// This should never happen if the tree is properly constructed
|
||||||
vec![0.0; self.classes().len()]
|
Err(Failed::predict("Nodes iteration did not reach leaf"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user