Implement Display for NaiveBayes

This commit is contained in:
Lorenzo (Mec-iS)
2022-11-03 14:18:56 +00:00
parent d298709040
commit ba70bb941f
4 changed files with 62 additions and 0 deletions
+17
View File
@@ -364,6 +364,20 @@ pub struct BernoulliNB<
binarize: Option<TX>, binarize: Option<TX>,
} }
impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
    fmt::Display for BernoulliNB<TX, TY, X, Y>
{
    /// Writes a human-readable summary of the estimator.
    ///
    /// `inner` and `binarize` are only populated after `fit`, so both are
    /// inspected with `as_ref()` rather than `unwrap()` — `Display::fmt`
    /// must never panic, even on an unfitted estimator.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.inner.as_ref() {
            Some(inner) => match self.binarize.as_ref() {
                // Fitted with a binarize threshold: same output as before.
                Some(b) => writeln!(f, "BernoulliNB:\ninner: {:?}\nbinarize: {:?}", inner, b),
                // Fitted without a threshold: previously this panicked.
                None => writeln!(f, "BernoulliNB:\ninner: {:?}\nbinarize: None", inner),
            },
            // Not fitted yet: previously this panicked.
            None => writeln!(f, "BernoulliNB: <unfitted>"),
        }
    }
}
impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>> impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, BernoulliNBParameters<TX>> for BernoulliNB<TX, TY, X, Y> SupervisedEstimator<X, Y, BernoulliNBParameters<TX>> for BernoulliNB<TX, TY, X, Y>
{ {
@@ -594,6 +608,9 @@ mod tests {
] ]
); );
// test Display
println!("{}", &bnb);
let distribution = bnb.inner.clone().unwrap().distribution; let distribution = bnb.inner.clone().unwrap().distribution;
assert_eq!( assert_eq!(
+13
View File
@@ -139,6 +139,17 @@ impl<T: Number + Unsigned> NBDistribution<T, T> for CategoricalNBDistribution<T>
} }
} }
impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> fmt::Display for CategoricalNB<T, X, Y> {
    /// Writes a human-readable summary of the estimator.
    ///
    /// `inner` is only populated after `fit`; avoid the `unwrap()` that
    /// panicked on an unfitted estimator — `Display::fmt` must never panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.inner.as_ref() {
            // Fitted: same output as before.
            Some(inner) => writeln!(f, "CategoricalNB:\ninner: {:?}", inner),
            // Not fitted yet: previously this panicked.
            None => writeln!(f, "CategoricalNB: <unfitted>"),
        }
    }
}
impl<T: Number + Unsigned> CategoricalNBDistribution<T> { impl<T: Number + Unsigned> CategoricalNBDistribution<T> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
@@ -539,6 +550,8 @@ mod tests {
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
let y_hat = cnb.predict(&x).unwrap(); let y_hat = cnb.predict(&x).unwrap();
assert_eq!(y_hat, vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]); assert_eq!(y_hat, vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]);
println!("{}", &cnb);
} }
#[cfg_attr( #[cfg_attr(
+16
View File
@@ -271,6 +271,19 @@ pub struct GaussianNB<
inner: Option<BaseNaiveBayes<TX, TY, X, Y, GaussianNBDistribution<TY>>>, inner: Option<BaseNaiveBayes<TX, TY, X, Y, GaussianNBDistribution<TY>>>,
} }
impl<
        // NOTE: the original bound repeated `RealNumber` twice; the duplicate
        // is removed — the set of bounds (and thus the interface) is unchanged.
        TX: Number + RealNumber,
        TY: Number + Ord + Unsigned,
        X: Array2<TX>,
        Y: Array1<TY>,
    > fmt::Display for GaussianNB<TX, TY, X, Y>
{
    /// Writes a human-readable summary of the estimator.
    ///
    /// `inner` is only populated after `fit`; avoid the `unwrap()` that
    /// panicked on an unfitted estimator — `Display::fmt` must never panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.inner.as_ref() {
            // Fitted: same output as before.
            Some(inner) => writeln!(f, "GaussianNB:\ninner: {:?}", inner),
            // Not fitted yet: previously this panicked.
            None => writeln!(f, "GaussianNB: <unfitted>"),
        }
    }
}
impl< impl<
TX: Number + RealNumber + RealNumber, TX: Number + RealNumber + RealNumber,
TY: Number + Ord + Unsigned, TY: Number + Ord + Unsigned,
@@ -433,6 +446,9 @@ mod tests {
let gnb = GaussianNB::fit(&x, &y, parameters).unwrap(); let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();
assert_eq!(gnb.class_priors(), &priors); assert_eq!(gnb.class_priors(), &priors);
// test display for GNB
println!("{}", &gnb);
} }
#[cfg_attr( #[cfg_attr(
+16
View File
@@ -309,6 +309,19 @@ pub struct MultinomialNB<
inner: Option<BaseNaiveBayes<TX, TY, X, Y, MultinomialNBDistribution<TY>>>, inner: Option<BaseNaiveBayes<TX, TY, X, Y, MultinomialNBDistribution<TY>>>,
} }
impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>> fmt::Display
    for MultinomialNB<TX, TY, X, Y>
{
    /// Writes a human-readable summary of the estimator.
    ///
    /// `inner` is only populated after `fit`; avoid the `unwrap()` that
    /// panicked on an unfitted estimator — `Display::fmt` must never panic.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.inner.as_ref() {
            // Fitted: same output as before.
            Some(inner) => writeln!(f, "MultinomialNB:\ninner: {:?}", inner),
            // Not fitted yet: previously this panicked.
            None => writeln!(f, "MultinomialNB: <unfitted>"),
        }
    }
}
impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>> impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, MultinomialNBParameters> for MultinomialNB<TX, TY, X, Y> SupervisedEstimator<X, Y, MultinomialNBParameters> for MultinomialNB<TX, TY, X, Y>
{ {
@@ -500,6 +513,9 @@ mod tests {
] ]
); );
// test display
println!("{}", &nb);
let y_hat = nb.predict(&x).unwrap(); let y_hat = nb.predict(&x).unwrap();
let distribution = nb.inner.clone().unwrap().distribution; let distribution = nb.inner.clone().unwrap().distribution;