From 01f753f86d7089613e1f3e3a815c60ce2c6cb1ca Mon Sep 17 00:00:00 2001 From: Christos Katsakioris Date: Tue, 6 Sep 2022 20:37:54 +0300 Subject: [PATCH] Add serde for StandardScaler (#148) * Derive `serde::Serialize` and `serde::Deserialize` for `StandardScaler`. * Add relevant unit test. Signed-off-by: Christos Katsakioris Signed-off-by: Christos Katsakioris --- src/preprocessing/numerical.rs | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index cc90b28..e2205c3 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + /// Configure Behaviour of `StandardScaler`. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Debug, Copy, Eq, PartialEq)] pub struct StandardScalerParameters { /// Optionaly adjust mean to be zero. @@ -54,6 +58,7 @@ impl Default for StandardScalerParameters { /// deviation of one. This can improve model training for /// scaling sensitive models like neural network or nearest /// neighbors based models. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct StandardScaler { means: Vec, @@ -400,5 +405,43 @@ mod tests { Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]])) ) } + + /// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been + /// serialized and deserialized. + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde_fit_for_random_values() { + let fitted_scaler = StandardScaler::fit( + &DenseMatrix::from_2d_array(&[ + &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793], + &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], + &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], + &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], + ]), + StandardScalerParameters::default(), + ) + .unwrap(); + + let deserialized_scaler: StandardScaler = + serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap(); + + assert_eq!( + deserialized_scaler.means, + vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625], + ); + + assert!( + &DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq( + &DenseMatrix::from_2d_array(&[&[ + 0.29426447500954, + 0.16758497615485, + 0.20820945786863, + 0.23329718831165 + ],]), + 0.00000000000001 + ) + ) + } } }