Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for `StandardScaler`. * Add relevant unit test. Signed-off-by: Christos Katsakioris <ckatsak@gmail.com> Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d305406dfd
commit
4d5f64c758
@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configure Behaviour of `StandardScaler`.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||
pub struct StandardScalerParameters {
|
||||
/// Optionaly adjust mean to be zero.
|
||||
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
|
||||
/// deviation of one. This can improve model training for
|
||||
/// scaling sensitive models like neural network or nearest
|
||||
/// neighbors based models.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub struct StandardScaler<T: RealNumber> {
|
||||
means: Vec<T>,
|
||||
@@ -400,5 +405,43 @@ mod tests {
|
||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||
)
|
||||
}
|
||||
|
||||
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
||||
/// serialized and deserialized.
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde_fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let deserialized_scaler: StandardScaler<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
deserialized_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user