Add serde to CategoricalNB (#30)
* Add serde to CategoricalNB
* Implement PartialEq for CategoricalNBDistribution
This commit is contained in:
@@ -6,13 +6,41 @@ use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for categorical features
|
/// Naive Bayes classifier for categorical features
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
struct CategoricalNBDistribution<T: RealNumber> {
|
struct CategoricalNBDistribution<T: RealNumber> {
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
coefficients: Vec<Vec<Vec<T>>>,
|
coefficients: Vec<Vec<Vec<T>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
    /// Compares two distributions.
    ///
    /// `class_labels` and `class_priors` are compared exactly with `==`.
    /// `coefficients` must have identical nesting lengths at every level, and
    /// each pair of corresponding values must not differ by more than
    /// `T::epsilon()` in absolute value.
    fn eq(&self, other: &Self) -> bool {
        let same_classes =
            self.class_labels == other.class_labels && self.class_priors == other.class_priors;
        if !same_classes {
            return false;
        }

        // Deliberately written as `!(diff > eps)` rather than `diff <= eps`:
        // a pair is rejected only when the difference is strictly greater than
        // epsilon, so a difference that cannot be ordered (NaN) is not rejected.
        let close_enough = |lhs: &T, rhs: &T| !((*lhs - *rhs).abs() > T::epsilon());

        // Lengths are checked before zipping at every level because `zip`
        // silently stops at the shorter of the two sequences.
        self.coefficients.len() == other.coefficients.len()
            && self
                .coefficients
                .iter()
                .zip(other.coefficients.iter())
                .all(|(outer_lhs, outer_rhs)| {
                    outer_lhs.len() == outer_rhs.len()
                        && outer_lhs
                            .iter()
                            .zip(outer_rhs.iter())
                            .all(|(inner_lhs, inner_rhs)| {
                                inner_lhs.len() == inner_rhs.len()
                                    && inner_lhs
                                        .iter()
                                        .zip(inner_rhs.iter())
                                        .all(|(lhs, rhs)| close_enough(lhs, rhs))
                            })
                })
    }
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
|
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
|
||||||
fn prior(&self, class_index: usize) -> T {
|
fn prior(&self, class_index: usize) -> T {
|
||||||
if class_index >= self.class_labels.len() {
|
if class_index >= self.class_labels.len() {
|
||||||
@@ -181,7 +209,7 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -269,4 +297,32 @@ mod tests {
|
|||||||
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
|
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn serde() {
    // Round-trip check: a fitted classifier serialized to JSON and parsed back
    // must compare equal to the original.
    let features = DenseMatrix::<f64>::from_2d_array(&[
        &[3., 4., 0., 1.],
        &[3., 0., 0., 1.],
        &[4., 4., 1., 2.],
        &[4., 2., 4., 3.],
        &[4., 2., 4., 2.],
        &[4., 1., 1., 0.],
        &[1., 1., 1., 1.],
        &[0., 4., 1., 0.],
        &[0., 3., 2., 1.],
        &[0., 3., 1., 1.],
        &[3., 4., 0., 1.],
        &[3., 4., 2., 4.],
        &[0., 3., 1., 2.],
        &[0., 4., 1., 2.],
    ]);
    // One target value per row of `features`.
    let targets = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];

    let cnb = CategoricalNB::fit(&features, &targets, Default::default()).unwrap();

    // Serialize, then deserialize into the same concrete type.
    let json = serde_json::to_string(&cnb).unwrap();
    let deserialized_cnb: CategoricalNB<f64, DenseMatrix<f64>> =
        serde_json::from_str(&json).unwrap();

    assert_eq!(cnb, deserialized_cnb);
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user