In Naive Bayes, avoid using Option::unwrap and so avoid panicking from NaN values (#274)
This commit is contained in:
@@ -3,9 +3,9 @@
|
||||
use crate::{
|
||||
api::{Predictor, SupervisedEstimator},
|
||||
error::{Failed, FailedError},
|
||||
linalg::basic::arrays::{Array2, Array1},
|
||||
numbers::realnum::RealNumber,
|
||||
linalg::basic::arrays::{Array1, Array2},
|
||||
numbers::basenum::Number,
|
||||
numbers::realnum::RealNumber,
|
||||
};
|
||||
|
||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||
|
||||
+84
-10
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
|
||||
use crate::numbers::basenum::Number;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
use std::{cmp::Ordering, marker::PhantomData};
|
||||
|
||||
/// Distribution used in the Naive Bayes classifier.
|
||||
pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
|
||||
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let y_classes = self.distribution.classes();
|
||||
let (rows, _) = x.shape();
|
||||
let predictions = (0..rows)
|
||||
.map(|row_index| {
|
||||
let row = x.get_row(row_index);
|
||||
let (prediction, _probability) = y_classes
|
||||
let predictions = x
|
||||
.row_iter()
|
||||
.map(|row| {
|
||||
y_classes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(class_index, class)| {
|
||||
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
|
||||
+ self.distribution.prior(class_index).ln(),
|
||||
)
|
||||
})
|
||||
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
|
||||
.unwrap();
|
||||
*prediction
|
||||
// `partial_cmp` returns `None` when either value is NaN, so calling `unwrap` on its result inside `max_by` panics.
|
||||
// NaN must be treated as the minimum value,
|
||||
// so that a NaN score is never chosen as the maximum while a non-NaN score is available.
|
||||
// We therefore handle the `None` case explicitly instead of calling `Option::unwrap`, which would panic.
|
||||
.max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
|
||||
Some(ordering) => ordering,
|
||||
None => {
|
||||
if p1.is_nan() {
|
||||
Ordering::Less
|
||||
} else if p2.is_nan() {
|
||||
Ordering::Greater
|
||||
} else {
|
||||
Ordering::Equal
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect::<Vec<TY>>();
|
||||
.map(|(prediction, _probability)| *prediction)
|
||||
.ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
|
||||
})
|
||||
.collect::<Result<Vec<TY>, Failed>>()?;
|
||||
let y_hat = Y::from_vec_slice(&predictions);
|
||||
Ok(y_hat)
|
||||
}
|
||||
@@ -119,3 +133,63 @@ pub mod bernoulli;
|
||||
pub mod categorical;
|
||||
pub mod gaussian;
|
||||
pub mod multinomial;
|
||||
|
||||
// Unit tests for `predict`, driven by a stub distribution whose scores are
// fixed per class label so that NaN handling can be exercised.
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use num_traits::float::Float;
|
||||

||||
// Model under test: integer features and labels, backed by the stub distribution below.
type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
|
||||

||||
// Stub distribution wrapping a borrowed vector that is returned as-is by `classes()`.
#[derive(Debug, PartialEq, Clone)]
|
||||
struct TestDistribution<'d>(&'d Vec<i32>);
|
||||

||||
impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
|
||||
// Constant prior of 1 for every class; the class index is ignored.
fn prior(&self, _class_index: usize) -> f64 {
|
||||
1.
|
||||
}
|
||||

||||
// Returns the value stored at `class_index` as an `f64` when it is 2, 10 or 20,
// and NaN for any other value. The sample row `_j` is ignored.
fn log_likelihood<'a>(
|
||||
&'a self,
|
||||
class_index: usize,
|
||||
_j: &'a Box<dyn ArrayView1<i32> + 'a>,
|
||||
) -> f64 {
|
||||
match self.0.get(class_index) {
|
||||
&v @ 2 | &v @ 10 | &v @ 20 => v as f64,
|
||||
_ => f64::nan(),
|
||||
}
|
||||
}
|
||||

||||
// The wrapped vector doubles as the list of class labels.
fn classes(&self) -> &Vec<i32> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||

||||
// Checks `predict` for three class lists: empty, partly NaN-scored, and fully finite.
#[test]
|
||||
fn test_predict() {
|
||||
let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
|
||||

||||
// Case 1: no classes at all, so `predict` is expected to return an error, not panic.
let val = vec![];
|
||||
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||
Ok(_) => panic!("Should return error in case of empty classes"),
|
||||
Err(err) => assert_eq!(
|
||||
err.to_string(),
|
||||
"Predict failed: Failed to predict, there is no result"
|
||||
),
|
||||
}
|
||||

||||
// Case 2: labels 1 and 3 get NaN log-likelihoods, only label 2 gets a number,
// so every row is expected to be predicted as 2.
let val = vec![1, 2, 3];
|
||||
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||
Ok(r) => assert_eq!(r, vec![2, 2, 2]),
|
||||
Err(_) => panic!("Should success in normal case with NaNs"),
|
||||
}
|
||||

||||
// Case 3: all labels get non-NaN log-likelihoods (20, 2, 10); the largest, 20,
// is expected for every row.
let val = vec![20, 2, 10];
|
||||
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||
Ok(r) => assert_eq!(r, vec![20, 20, 20]),
|
||||
Err(_) => panic!("Should success in normal case without NaNs"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user