Add serde for StandardScaler (#148)

* Derive `serde::Serialize` and `serde::Deserialize` for
  `StandardScaler`.
* Add relevant unit test.

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
This commit is contained in:
Christos Katsakioris
2022-09-06 20:37:54 +03:00
committed by GitHub
parent d305406dfd
commit 4d5f64c758
+43
View File
@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Configure Behaviour of `StandardScaler`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub struct StandardScalerParameters {
/// Optionally adjust mean to be zero.
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
/// deviation of one. This can improve model training for
/// scaling-sensitive models like neural networks or nearest
/// neighbors based models.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct StandardScaler<T: RealNumber> {
means: Vec<T>,
@@ -400,5 +405,43 @@ mod tests {
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
)
}
/// Mirrors the `fit_for_random_values` test, except that the fitted
/// `StandardScaler` is first serialized to JSON and deserialized again; the
/// assertions run against the deserialized copy.
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde_fit_for_random_values() {
    // Fit on the fixed 4x4 input matrix using the default parameters.
    let samples = DenseMatrix::from_2d_array(&[
        &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
        &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
        &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
        &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
    ]);
    let original = StandardScaler::fit(&samples, StandardScalerParameters::default()).unwrap();

    // Round-trip through a JSON string.
    let json = serde_json::to_string(&original).unwrap();
    let restored: StandardScaler<f64> = serde_json::from_str(&json).unwrap();

    // The means are compared exactly against the expected per-column values.
    let expected_means = vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625];
    assert_eq!(restored.means, expected_means);

    // The standard deviations are compared as single-row matrices so that
    // `approximate_eq` can be used with a tolerance.
    let restored_stds = DenseMatrix::from_2d_vec(&vec![restored.stds]);
    let expected_stds = DenseMatrix::from_2d_array(&[&[
        0.29426447500954,
        0.16758497615485,
        0.20820945786863,
        0.23329718831165,
    ]]);
    let tolerance = 0.00000000000001;
    assert!(restored_stds.approximate_eq(&expected_stds, tolerance));
}
}
}