In Naive Bayes, avoid using Option::unwrap and so avoid panicking from NaN values (#274)
This commit is contained in:
@@ -3,9 +3,9 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
api::{Predictor, SupervisedEstimator},
|
api::{Predictor, SupervisedEstimator},
|
||||||
error::{Failed, FailedError},
|
error::{Failed, FailedError},
|
||||||
linalg::basic::arrays::{Array2, Array1},
|
linalg::basic::arrays::{Array1, Array2},
|
||||||
numbers::realnum::RealNumber,
|
|
||||||
numbers::basenum::Number,
|
numbers::basenum::Number,
|
||||||
|
numbers::realnum::RealNumber,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||||
|
|||||||
+84
-10
@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
|
|||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::marker::PhantomData;
|
use std::{cmp::Ordering, marker::PhantomData};
|
||||||
|
|
||||||
/// Distribution used in the Naive Bayes classifier.
|
/// Distribution used in the Naive Bayes classifier.
|
||||||
pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
|
pub(crate) trait NBDistribution<X: Number, Y: Number>: Clone {
|
||||||
@@ -92,11 +92,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
|
|||||||
/// Returns a vector of size N with class estimates.
|
/// Returns a vector of size N with class estimates.
|
||||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
let y_classes = self.distribution.classes();
|
let y_classes = self.distribution.classes();
|
||||||
let (rows, _) = x.shape();
|
let predictions = x
|
||||||
let predictions = (0..rows)
|
.row_iter()
|
||||||
.map(|row_index| {
|
.map(|row| {
|
||||||
let row = x.get_row(row_index);
|
y_classes
|
||||||
let (prediction, _probability) = y_classes
|
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(class_index, class)| {
|
.map(|(class_index, class)| {
|
||||||
@@ -106,11 +105,26 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
|
|||||||
+ self.distribution.prior(class_index).ln(),
|
+ self.distribution.prior(class_index).ln(),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
|
// For some reason, the max_by method cannot use NaNs for finding the maximum value, it panics.
|
||||||
.unwrap();
|
// NaN must be considered as minimum values,
|
||||||
*prediction
|
// therefore it's like NaNs would not be considered for choosing the maximum value.
|
||||||
|
// So we need to handle this case for avoiding panicking by using `Option::unwrap`.
|
||||||
|
.max_by(|(_, p1), (_, p2)| match p1.partial_cmp(p2) {
|
||||||
|
Some(ordering) => ordering,
|
||||||
|
None => {
|
||||||
|
if p1.is_nan() {
|
||||||
|
Ordering::Less
|
||||||
|
} else if p2.is_nan() {
|
||||||
|
Ordering::Greater
|
||||||
|
} else {
|
||||||
|
Ordering::Equal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.map(|(prediction, _probability)| *prediction)
|
||||||
|
.ok_or_else(|| Failed::predict("Failed to predict, there is no result"))
|
||||||
})
|
})
|
||||||
.collect::<Vec<TY>>();
|
.collect::<Result<Vec<TY>, Failed>>()?;
|
||||||
let y_hat = Y::from_vec_slice(&predictions);
|
let y_hat = Y::from_vec_slice(&predictions);
|
||||||
Ok(y_hat)
|
Ok(y_hat)
|
||||||
}
|
}
|
||||||
@@ -119,3 +133,63 @@ pub mod bernoulli;
|
|||||||
pub mod categorical;
|
pub mod categorical;
|
||||||
pub mod gaussian;
|
pub mod gaussian;
|
||||||
pub mod multinomial;
|
pub mod multinomial;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::basic::arrays::Array;
|
||||||
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use num_traits::float::Float;
|
||||||
|
|
||||||
|
type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
struct TestDistribution<'d>(&'d Vec<i32>);
|
||||||
|
|
||||||
|
impl<'d> NBDistribution<i32, i32> for TestDistribution<'d> {
|
||||||
|
fn prior(&self, _class_index: usize) -> f64 {
|
||||||
|
1.
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_likelihood<'a>(
|
||||||
|
&'a self,
|
||||||
|
class_index: usize,
|
||||||
|
_j: &'a Box<dyn ArrayView1<i32> + 'a>,
|
||||||
|
) -> f64 {
|
||||||
|
match self.0.get(class_index) {
|
||||||
|
&v @ 2 | &v @ 10 | &v @ 20 => v as f64,
|
||||||
|
_ => f64::nan(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn classes(&self) -> &Vec<i32> {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_predict() {
|
||||||
|
let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
|
||||||
|
|
||||||
|
let val = vec![];
|
||||||
|
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||||
|
Ok(_) => panic!("Should return error in case of empty classes"),
|
||||||
|
Err(err) => assert_eq!(
|
||||||
|
err.to_string(),
|
||||||
|
"Predict failed: Failed to predict, there is no result"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
let val = vec![1, 2, 3];
|
||||||
|
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||||
|
Ok(r) => assert_eq!(r, vec![2, 2, 2]),
|
||||||
|
Err(_) => panic!("Should success in normal case with NaNs"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let val = vec![20, 2, 10];
|
||||||
|
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
|
||||||
|
Ok(r) => assert_eq!(r, vec![20, 20, 20]),
|
||||||
|
Err(_) => panic!("Should success in normal case without NaNs"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user