add proper error handling

This commit is contained in:
Lorenzo Mec-iS
2025-01-20 16:08:29 +00:00
parent fc7f2e61d9
commit 5711788fd8
+5 -5
View File
@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
/// ///
/// # Errors /// # Errors
/// ///
/// Returns an error if the prediction process fails. /// Returns an error if at least one row prediction process fails.
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> { pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
let (n_samples, _) = x.shape(); let (n_samples, _) = x.shape();
let n_classes = self.classes().len(); let n_classes = self.classes().len();
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes); let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
for i in 0..n_samples { for i in 0..n_samples {
let probs = self.predict_proba_for_row(x, i); let probs = self.predict_proba_for_row(x, i)?;
for (j, &prob) in probs.iter().enumerate() { for (j, &prob) in probs.iter().enumerate() {
result.set((i, j), prob); result.set((i, j), prob);
} }
@@ -930,7 +930,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
/// ///
/// A vector of probabilities, one for each class, representing the probability /// A vector of probabilities, one for each class, representing the probability
/// of the input sample belonging to each class. /// of the input sample belonging to each class.
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> { fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
let mut node = 0; let mut node = 0;
while let Some(current_node) = self.nodes().get(node) { while let Some(current_node) = self.nodes().get(node) {
@@ -938,7 +938,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
// Leaf node reached // Leaf node reached
let mut probs = vec![0.0; self.classes().len()]; let mut probs = vec![0.0; self.classes().len()];
probs[current_node.output] = 1.0; probs[current_node.output] = 1.0;
return probs; return Ok(probs);
} }
let split_feature = current_node.split_feature; let split_feature = current_node.split_feature;
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
} }
// This should never happen if the tree is properly constructed // This should never happen if the tree is properly constructed
vec![0.0; self.classes().len()] Err(Failed::predict("Nodes iteration did not reach leaf"))
} }
} }