Add serde to CategoricalNB (#30)
* Add serde to CategoricalNB * Implement PartialEq for CategoricalNBDistribution
This commit is contained in:
@@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for categorical features
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct CategoricalNBDistribution<T: RealNumber> {
|
||||
class_labels: Vec<T>,
|
||||
class_priors: Vec<T>,
|
||||
coefficients: Vec<Vec<Vec<T>>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.class_labels == other.class_labels && self.class_priors == other.class_priors {
|
||||
if self.coefficients.len() != other.coefficients.len() {
|
||||
return false;
|
||||
}
|
||||
for (a, b) in self.coefficients.iter().zip(other.coefficients.iter()) {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
for (a_i, b_i) in a.iter().zip(b.iter()) {
|
||||
if a_i.len() != b_i.len() {
|
||||
return false;
|
||||
}
|
||||
for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
|
||||
if (*a_i_j - *b_i_j).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
|
||||
fn prior(&self, class_index: usize) -> T {
|
||||
if class_index >= self.class_labels.len() {
|
||||
@@ -181,7 +209,7 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
||||
}
|
||||
|
||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||
}
|
||||
@@ -269,4 +297,32 @@ mod tests {
|
||||
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 0., 0., 1.],
|
||||
&[4., 4., 1., 2.],
|
||||
&[4., 2., 4., 3.],
|
||||
&[4., 2., 4., 2.],
|
||||
&[4., 1., 1., 0.],
|
||||
&[1., 1., 1., 1.],
|
||||
&[0., 4., 1., 0.],
|
||||
&[0., 3., 2., 1.],
|
||||
&[0., 3., 1., 1.],
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 4., 2., 4.],
|
||||
&[0., 3., 1., 2.],
|
||||
&[0., 4., 1., 2.],
|
||||
]);
|
||||
|
||||
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_cnb: CategoricalNB<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(cnb, deserialized_cnb);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user