Add serde to CategoricalNB (#30)

* Add serde to CategoricalNB

* Implement PartialEq for CategoricalNBDistribution
This commit is contained in:
morenol
2020-11-19 16:07:10 -04:00
committed by GitHub
parent ad3ac49dde
commit 9db993939e
+58 -2
View File
@@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for categorical features
#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug)]
struct CategoricalNBDistribution<T: RealNumber> {
class_labels: Vec<T>,
class_priors: Vec<T>,
coefficients: Vec<Vec<Vec<T>>>,
}
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
fn eq(&self, other: &Self) -> bool {
if self.class_labels == other.class_labels && self.class_priors == other.class_priors {
if self.coefficients.len() != other.coefficients.len() {
return false;
}
for (a, b) in self.coefficients.iter().zip(other.coefficients.iter()) {
if a.len() != b.len() {
return false;
}
for (a_i, b_i) in a.iter().zip(b.iter()) {
if a_i.len() != b_i.len() {
return false;
}
for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
if (*a_i_j - *b_i_j).abs() > T::epsilon() {
return false;
}
}
}
}
true
} else {
false
}
}
}
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
fn prior(&self, class_index: usize) -> T {
if class_index >= self.class_labels.len() {
@@ -181,7 +209,7 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
}
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
}
@@ -269,4 +297,32 @@ mod tests {
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
);
}
#[test]
fn serde() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[3., 4., 0., 1.],
&[3., 0., 0., 1.],
&[4., 4., 1., 2.],
&[4., 2., 4., 3.],
&[4., 2., 4., 2.],
&[4., 1., 1., 0.],
&[1., 1., 1., 1.],
&[0., 4., 1., 0.],
&[0., 3., 2., 1.],
&[0., 3., 1., 1.],
&[3., 4., 0., 1.],
&[3., 4., 2., 4.],
&[0., 3., 1., 2.],
&[0., 4., 1., 2.],
]);
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
let deserialized_cnb: CategoricalNB<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap();
assert_eq!(cnb, deserialized_cnb);
}
}