From 9db993939e657d50c3cdd17f7d80bd7cd9003496 Mon Sep 17 00:00:00 2001 From: morenol Date: Thu, 19 Nov 2020 16:07:10 -0400 Subject: [PATCH] Add serde to CategoricalNB (#30) * Add serde to CategoricalNB * Implement PartialEq for CategoricalNBDistribution --- src/naive_bayes/categorical.rs | 60 ++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index ae6eb0c..d32c34d 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; use serde::{Deserialize, Serialize}; /// Naive Bayes classifier for categorical features -#[derive(Debug)] +#[derive(Serialize, Deserialize, Debug)] struct CategoricalNBDistribution { class_labels: Vec, class_priors: Vec, coefficients: Vec>>, } +impl PartialEq for CategoricalNBDistribution { + fn eq(&self, other: &Self) -> bool { + if self.class_labels == other.class_labels && self.class_priors == other.class_priors { + if self.coefficients.len() != other.coefficients.len() { + return false; + } + for (a, b) in self.coefficients.iter().zip(other.coefficients.iter()) { + if a.len() != b.len() { + return false; + } + for (a_i, b_i) in a.iter().zip(b.iter()) { + if a_i.len() != b_i.len() { + return false; + } + for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) { + if (*a_i_j - *b_i_j).abs() > T::epsilon() { + return false; + } + } + } + } + true + } else { + false + } + } +} + impl> NBDistribution for CategoricalNBDistribution { fn prior(&self, class_index: usize) -> T { if class_index >= self.class_labels.len() { @@ -181,7 +209,7 @@ impl Default for CategoricalNBParameters { } /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. -#[derive(Debug)] +#[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct CategoricalNB> { inner: BaseNaiveBayes>, } @@ -269,4 +297,32 @@ mod tests { vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.] ); } + + #[test] + fn serde() { + let x = DenseMatrix::::from_2d_array(&[ + &[3., 4., 0., 1.], + &[3., 0., 0., 1.], + &[4., 4., 1., 2.], + &[4., 2., 4., 3.], + &[4., 2., 4., 2.], + &[4., 1., 1., 0.], + &[1., 1., 1., 1.], + &[0., 4., 1., 0.], + &[0., 3., 2., 1.], + &[0., 3., 1., 1.], + &[3., 4., 0., 1.], + &[3., 4., 2., 4.], + &[0., 3., 1., 2.], + &[0., 4., 1., 2.], + ]); + + let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.]; + let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); + + let deserialized_cnb: CategoricalNB> = + serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap(); + + assert_eq!(cnb, deserialized_cnb); + } }