diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 793e79d..09b53b6 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -39,6 +39,6 @@ jobs: command: tarpaulin args: --out Lcov --all-features -- --test-threads 1 - name: Upload to codecov.io - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v2 with: fail_ci_if_error: true diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 3e32d6b..8f2e013 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -54,7 +54,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; @@ -104,9 +104,10 @@ pub struct RandomForestClassifier< X: Array2, Y: Array1, > { - parameters: RandomForestClassifierParameters, - trees: Vec>, - classes: Vec, + parameters: Option, + trees: Option>>, + classes: Option>, + samples: Option>>, } impl RandomForestClassifierParameters { @@ -154,11 +155,13 @@ impl RandomForestClassifierParameters { } } -impl, Y: Array1> PartialEq - for RandomForestClassifier +impl, Y: Array1> + PartialEq for RandomForestClassifier { fn eq(&self, other: &Self) -> bool { - if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() { + if self.classes.as_ref().unwrap().len() != other.classes.as_ref().unwrap().len() + || self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() + { false } else { self.classes @@ -189,17 +192,25 @@ impl Default for RandomForestClassifierParameters { } } -impl, Y: Array1> +impl, Y: Array1> SupervisedEstimator for RandomForestClassifier { + fn new() -> Self { + Self { + parameters: Option::None, + trees: Option::None, + classes: Option::None, + samples: Option::None, + } + } fn fit(x: &X, y: &Y, parameters: RandomForestClassifierParameters) -> Result { RandomForestClassifier::fit(x, y, parameters) } } -impl, Y: Array1> Predictor - for RandomForestClassifier +impl, Y: Array1> + Predictor for RandomForestClassifier { fn predict(&self, x: &X) -> Result { self.predict(x) @@ -462,10 +473,22 @@ impl, Y: Array1> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + // TODO: use with_capacity here + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { - let samples = RandomForestClassifier::::sample_with_replacement(&yi, k, &mut rng); + let samples: Vec = + RandomForestClassifier::::sample_with_replacement(&yi, k, &mut rng); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } + let params = DecisionTreeClassifierParameters { criterion: parameters.criterion.clone(), max_depth: parameters.max_depth, @@ -478,9 +501,10 @@ impl, Y: Array1, Y: Array1 usize { - let mut result = vec![0; self.classes.len()]; + let mut result = vec![0; self.classes.as_ref().unwrap().len()]; - for tree in self.trees.iter() { + for tree in self.trees.as_ref().unwrap().iter() { result[tree.predict_for_row(x, row)] += 1; } @@ -511,38 +538,43 @@ impl, Y: Array1 Result { let (n, _) = x.shape(); - /* TODO: fix this: - if self.samples.is_none() { - Err(Failed::because( - FailedError::PredictFailed, - "Need samples=true for OOB predictions.", - )) - } else if self.samples.as_ref().unwrap()[0].len() != n { - Err(Failed::because( - FailedError::PredictFailed, - "Prediction matrix must match matrix used in training for OOB predictions.", - )) - } else { - */ - let mut result = Y::zeros(n); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = Y::zeros(n); - for i in 0..n { - result.set(i, self.classes[self.predict_for_row_oob(x, i)]); + for i in 0..n { + result.set( + i, + self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)], + ); + } + Ok(result) } - - Ok(result) - //} } fn predict_for_row_oob(&self, x: &X, row: usize) -> usize { - let mut result = vec![0; self.classes.len()]; + let mut result = vec![0; self.classes.as_ref().unwrap().len()]; - // TODO: FIX THIS - //for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { - // if !samples[row] { - // result[tree.predict_for_row(x, row)] += 1; - // } - // } + for (tree, samples) in self + .trees + .as_ref() + .unwrap() + .iter() + .zip(self.samples.as_ref().unwrap()) + { + if !samples[row] { + result[tree.predict_for_row(x, row)] += 1; + } + } which_max(&result) } @@ -671,9 +703,7 @@ mod tests { &[6.6, 2.9, 4.6, 1.3], &[5.2, 2.7, 3.9, 1.4], ]); - let y = vec![ - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - ]; + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let classifier = RandomForestClassifier::fit( &x, @@ -697,39 +727,39 @@ mod tests { ); } - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[5.1, 3.5, 1.4, 0.2], - // &[4.9, 3.0, 1.4, 0.2], - // &[4.7, 3.2, 1.3, 0.2], - // &[4.6, 3.1, 1.5, 0.2], - // &[5.0, 3.6, 1.4, 0.2], - // &[5.4, 3.9, 1.7, 0.4], - // &[4.6, 3.4, 1.4, 0.3], - // &[5.0, 3.4, 1.5, 0.2], - // &[4.4, 2.9, 1.4, 0.2], - // &[4.9, 3.1, 1.5, 0.1], - // &[7.0, 3.2, 4.7, 1.4], - // &[6.4, 3.2, 4.5, 1.5], - // &[6.9, 3.1, 4.9, 1.5], - // &[5.5, 2.3, 4.0, 1.3], - // &[6.5, 2.8, 4.6, 1.5], - // &[5.7, 2.8, 4.5, 1.3], - // &[6.3, 3.3, 4.7, 1.6], - // &[4.9, 2.4, 3.3, 1.0], - // &[6.6, 2.9, 4.6, 1.3], - // &[5.2, 2.7, 3.9, 1.4], - // ]); - // let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; - // let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); + let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_forest: RandomForestClassifier, Vec> = - // bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); + let deserialized_forest: RandomForestClassifier, Vec> = + bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); - // assert_eq!(forest, deserialized_forest); - // } + assert_eq!(forest, deserialized_forest); + } } diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index b323877..a54ac3a 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -51,7 +51,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; @@ -92,11 +92,15 @@ pub struct RandomForestRegressorParameters { /// Random Forest Regressor #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] -pub struct RandomForestRegressor, Y: Array1> -{ - parameters: RandomForestRegressorParameters, - trees: Vec>, - samples: Option>> +pub struct RandomForestRegressor< + TX: Number + FloatNumber + PartialOrd, + TY: Number, + X: Array2, + Y: Array1, +> { + parameters: Option, + trees: Option>>, + samples: Option>>, } impl RandomForestRegressorParameters { @@ -156,7 +160,7 @@ impl, Y: Array1 for RandomForestRegressor { fn eq(&self, other: &Self) -> bool { - if self.trees.len() != other.trees.len() { + if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() { false } else { self.trees @@ -171,13 +175,21 @@ impl, Y: Array1 SupervisedEstimator for RandomForestRegressor { + fn new() -> Self { + Self { + parameters: Option::None, + trees: Option::None, + samples: Option::None, + } + } + fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result { RandomForestRegressor::fit(x, y, parameters) } } -impl, Y: Array1> Predictor - for RandomForestRegressor +impl, Y: Array1> + Predictor for RandomForestRegressor { fn predict(&self, x: &X) -> Result { self.predict(x) @@ -396,17 +408,19 @@ impl, Y: Array1 let mut rng = get_rng_impl(Some(parameters.seed)); let mut trees: Vec> = Vec::new(); - let mut maybe_all_samples: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + // TODO: use with_capacity here + maybe_all_samples = Some(Vec::new()); + } for _ in 0..parameters.n_trees { - let samples = RandomForestRegressor::::sample_with_replacement( - n_rows, - &mut rng, - ); + let samples: Vec = + RandomForestRegressor::::sample_with_replacement(n_rows, &mut rng); // keep samples is flag is on - if parameters.keep_samples { - maybe_all_samples.push(samples); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) } let params = DecisionTreeRegressorParameters { @@ -419,17 +433,10 @@ impl, Y: Array1 trees.push(tree); } - let samples; - if maybe_all_samples.len() == 0 { - samples = Option::None; - } else { - samples = Some(maybe_all_samples) - } - Ok(RandomForestRegressor { - parameters: parameters, - trees, - samples + parameters: Some(parameters), + trees: Some(trees), + samples: maybe_all_samples, }) } @@ -448,11 +455,11 @@ impl, Y: Array1 } fn predict_for_row(&self, x: &X, row: usize) -> TY { - let n_trees = self.trees.len(); + let n_trees = self.trees.as_ref().unwrap().len(); let mut result = TY::zero(); - for tree in self.trees.iter() { + for tree in self.trees.as_ref().unwrap().iter() { result += tree.predict_for_row(x, row); } @@ -462,7 +469,6 @@ impl, Y: Array1 /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. pub fn predict_oob(&self, x: &X) -> Result { let (n, _) = x.shape(); - /* TODO: FIX THIS if self.samples.is_none() { Err(Failed::because( FailedError::PredictFailed, @@ -473,29 +479,32 @@ impl, Y: Array1 FailedError::PredictFailed, "Prediction matrix must match matrix used in training for OOB predictions.", )) - } else { - let mut result = Y::zeros(n); + } else { + let mut result = Y::zeros(n); - for i in 0..n { - result.set(i, self.predict_for_row_oob(x, i)); + for i in 0..n { + result.set(i, self.predict_for_row_oob(x, i)); + } + + Ok(result) } - - Ok(result) - }*/ - let result = Y::zeros(n); - Ok(result) } - //TODo: fix this fn predict_for_row_oob(&self, x: &X, row: usize) -> TY { let mut n_trees = 0; let mut result = TY::zero(); - for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { - if !samples[row] { - result += tree.predict_for_row(x, row); - n_trees += 1; - } + for (tree, samples) in self + .trees + .as_ref() + .unwrap() + .iter() + .zip(self.samples.as_ref().unwrap()) + { + if !samples[row] { + result += tree.predict_for_row(x, row); + n_trees += 1; + } } // TODO: What to do if there are no oob trees? @@ -636,39 +645,38 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob)); } - // TODO: missing deserialization for DenseMatrix - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[234.289, 235.6, 159., 107.608, 1947., 60.323], - // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], - // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], - // &[284.599, 335.1, 165., 110.929, 1950., 61.187], - // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], - // &[346.999, 193.2, 359.4, 113.27, 1952., 63.639], - // &[365.385, 187., 354.7, 115.094, 1953., 64.989], - // &[363.112, 357.8, 335., 116.219, 1954., 63.761], - // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], - // &[419.18, 282.2, 285.7, 118.734, 1956., 67.857], - // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], - // &[444.546, 468.1, 263.7, 121.95, 1958., 66.513], - // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], - // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], - // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], - // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], - // ]); - // let y = vec![ - // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, - // 114.2, 115.7, 116.9, - // ]; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159., 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165., 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.27, 1952., 63.639], + &[365.385, 187., 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335., 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.18, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.95, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); + let y = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; - // let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap(); + let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_forest: RandomForestRegressor, Vec> = - // bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); + let deserialized_forest: RandomForestRegressor, Vec> = + bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap(); - // assert_eq!(forest, deserialized_forest); - // } + assert_eq!(forest, deserialized_forest); + } } diff --git a/src/lib.rs b/src/lib.rs index c74c573..d665838 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,7 +80,7 @@ pub mod dataset; /// Matrix decomposition algorithms pub mod decomposition; /// Ensemble methods, including Random Forest classifier and regressor -// pub mod ensemble; +pub mod ensemble; pub mod error; /// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms pub mod linalg; diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 7fdbfc1..149c1fc 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -4,7 +4,7 @@ use std::ops::Range; use std::slice::Iter; use approx::{AbsDiffEq, RelativeEq}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::linalg::basic::arrays::{ Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2, @@ -19,7 +19,7 @@ use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; /// Dense matrix -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct DenseMatrix { ncols: usize, nrows: usize, diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 4f17d9a..1ded589 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -33,6 +33,8 @@ //! ## References: //! //! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html) +use std::fmt; + use num_traits::Unsigned; use crate::api::{Predictor, SupervisedEstimator}; @@ -62,6 +64,18 @@ struct BernoulliNBDistribution { n_features: usize, } +impl fmt::Display for BernoulliNBDistribution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "BernoulliNBDistribution: n_features: {:?}", + self.n_features + )?; + writeln!(f, "class_labels: {:?}", self.class_labels)?; + Ok(()) + } +} + impl PartialEq for BernoulliNBDistribution { fn eq(&self, other: &Self) -> bool { if self.class_labels == other.class_labels @@ -598,23 +612,22 @@ mod tests { assert_eq!(y_hat, vec!(2, 2, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0)); } - // TODO: implement serialization - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[1, 1, 0, 0, 0, 0], - // &[0, 1, 0, 0, 1, 0], - // &[0, 1, 0, 1, 0, 0], - // &[0, 1, 1, 0, 0, 1], - // ]); - // let y: Vec = vec![0, 0, 0, 1]; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[1, 1, 0, 0, 0, 0], + &[0, 1, 0, 0, 1, 0], + &[0, 1, 0, 1, 0, 0], + &[0, 1, 1, 0, 0, 1], + ]); + let y: Vec = vec![0, 0, 0, 1]; - // let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_bnb: BernoulliNB, Vec> = - // serde_json::from_str(&serde_json::to_string(&bnb).unwrap()).unwrap(); + let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap(); + let deserialized_bnb: BernoulliNB, Vec> = + serde_json::from_str(&serde_json::to_string(&bnb).unwrap()).unwrap(); - // assert_eq!(bnb, deserialized_bnb); - // } + assert_eq!(bnb, deserialized_bnb); + } } diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 77645f5..3196b3b 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -30,6 +30,8 @@ //! let nb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); //! let y_hat = nb.predict(&x).unwrap(); //! ``` +use std::fmt; + use num_traits::Unsigned; use crate::api::{Predictor, SupervisedEstimator}; @@ -61,6 +63,18 @@ struct CategoricalNBDistribution { category_count: Vec>>, } +impl fmt::Display for CategoricalNBDistribution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "CategoricalNBDistribution: n_features: {:?}", + self.n_features + )?; + writeln!(f, "class_labels: {:?}", self.class_labels)?; + Ok(()) + } +} + impl PartialEq for CategoricalNBDistribution { fn eq(&self, other: &Self) -> bool { if self.class_labels == other.class_labels @@ -521,34 +535,33 @@ mod tests { assert_eq!(y_hat, vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]); } - // TODO: implement serialization - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[3, 4, 0, 1], - // &[3, 0, 0, 1], - // &[4, 4, 1, 2], - // &[4, 2, 4, 3], - // &[4, 2, 4, 2], - // &[4, 1, 1, 0], - // &[1, 1, 1, 1], - // &[0, 4, 1, 0], - // &[0, 3, 2, 1], - // &[0, 3, 1, 1], - // &[3, 4, 0, 1], - // &[3, 4, 2, 4], - // &[0, 3, 1, 2], - // &[0, 4, 1, 2], - // ]); + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[3, 4, 0, 1], + &[3, 0, 0, 1], + &[4, 4, 1, 2], + &[4, 2, 4, 3], + &[4, 2, 4, 2], + &[4, 1, 1, 0], + &[1, 1, 1, 1], + &[0, 4, 1, 0], + &[0, 3, 2, 1], + &[0, 3, 1, 1], + &[3, 4, 0, 1], + &[3, 4, 2, 4], + &[0, 3, 1, 2], + &[0, 4, 1, 2], + ]); - // let y: Vec = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]; - // let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); + let y: Vec = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0]; + let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_cnb: CategoricalNB, Vec> = - // serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap(); + let deserialized_cnb: CategoricalNB, Vec> = + serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap(); - // assert_eq!(cnb, deserialized_cnb); - // } + assert_eq!(cnb, deserialized_cnb); + } } diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index aecef39..c8223fd 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -22,6 +22,8 @@ //! let nb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); //! let y_hat = nb.predict(&x).unwrap(); //! ``` +use std::fmt; + use num_traits::Unsigned; use crate::api::{Predictor, SupervisedEstimator}; @@ -49,6 +51,18 @@ struct GaussianNBDistribution { theta: Vec>, } +impl fmt::Display for GaussianNBDistribution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "GaussianNBDistribution: class_count: {:?}", + self.class_count + )?; + writeln!(f, "class_labels: {:?}", self.class_labels)?; + Ok(()) + } +} + impl NBDistribution for GaussianNBDistribution { @@ -415,25 +429,24 @@ mod tests { assert_eq!(gnb.class_priors(), &priors); } - // TODO: implement serialization - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::::from_2d_array(&[ - // &[-1., -1.], - // &[-2., -1.], - // &[-3., -2.], - // &[1., 1.], - // &[2., 1.], - // &[3., 2.], - // ]); - // let y: Vec = vec![1, 1, 1, 2, 2, 2]; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::::from_2d_array(&[ + &[-1., -1.], + &[-2., -1.], + &[-3., -2.], + &[1., 1.], + &[2., 1.], + &[3., 2.], + ]); + let y: Vec = vec![1, 1, 1, 2, 2, 2]; - // let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_gnb: GaussianNB, Vec> = - // serde_json::from_str(&serde_json::to_string(&gnb).unwrap()).unwrap(); + let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); + let deserialized_gnb: GaussianNB, Vec> = + serde_json::from_str(&serde_json::to_string(&gnb).unwrap()).unwrap(); - // assert_eq!(gnb, deserialized_gnb); - // } + assert_eq!(gnb, deserialized_gnb); + } } diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index bb13e7d..f82d4fc 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -33,6 +33,8 @@ //! ## References: //! //! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html) +use std::fmt; + use num_traits::Unsigned; use crate::api::{Predictor, SupervisedEstimator}; @@ -62,6 +64,18 @@ struct MultinomialNBDistribution { n_features: usize, } +impl fmt::Display for MultinomialNBDistribution { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!( + f, + "MultinomialNBDistribution: n_features: {:?}", + self.n_features + )?; + writeln!(f, "class_labels: {:?}", self.class_labels)?; + Ok(()) + } +} + impl NBDistribution for MultinomialNBDistribution { @@ -510,23 +524,22 @@ mod tests { assert_eq!(y_hat, vec!(2, 2, 0, 0, 0, 2, 2, 1, 0, 1, 0, 2, 0, 0, 2)); } - // TODO: implement serialization - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[1, 1, 0, 0, 0, 0], - // &[0, 1, 0, 0, 1, 0], - // &[0, 1, 0, 1, 0, 0], - // &[0, 1, 1, 0, 0, 1], - // ]); - // let y = vec![0, 0, 0, 1]; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[1, 1, 0, 0, 0, 0], + &[0, 1, 0, 0, 1, 0], + &[0, 1, 0, 1, 0, 0], + &[0, 1, 1, 0, 0, 1], + ]); + let y = vec![0, 0, 0, 1]; - // let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); - // let deserialized_mnb: MultinomialNB, Vec> = - // serde_json::from_str(&serde_json::to_string(&mnb).unwrap()).unwrap(); + let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); + let deserialized_mnb: MultinomialNB, Vec> = + serde_json::from_str(&serde_json::to_string(&mnb).unwrap()).unwrap(); - // assert_eq!(mnb, deserialized_mnb); - // } + assert_eq!(mnb, deserialized_mnb); + } } diff --git a/src/svm/svc.rs b/src/svm/svc.rs index aa4e5cc..256c3c3 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -1119,10 +1119,8 @@ mod tests { let svc = SVC::fit(&x, &y, ¶ms).unwrap(); // serialization - let _serialized_svc = &serde_json::to_string(&svc).unwrap(); + let serialized_svc = &serde_json::to_string(&svc).unwrap(); - // println!("{:?}", serialized_svc); - - // TODO: for deserialization, deserialization is needed for `linalg::basic::matrix::DenseMatrix` + println!("{:?}", serialized_svc); } }