Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for `StandardScaler`. * Add relevant unit test. Signed-off-by: Christos Katsakioris <ckatsak@gmail.com> Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d305406dfd
commit
4d5f64c758
@@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError};
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Configure Behaviour of `StandardScaler`.
|
/// Configure Behaviour of `StandardScaler`.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||||
pub struct StandardScalerParameters {
|
pub struct StandardScalerParameters {
|
||||||
/// Optionaly adjust mean to be zero.
|
/// Optionaly adjust mean to be zero.
|
||||||
@@ -54,6 +58,7 @@ impl Default for StandardScalerParameters {
|
|||||||
/// deviation of one. This can improve model training for
|
/// deviation of one. This can improve model training for
|
||||||
/// scaling sensitive models like neural network or nearest
|
/// scaling sensitive models like neural network or nearest
|
||||||
/// neighbors based models.
|
/// neighbors based models.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||||
pub struct StandardScaler<T: RealNumber> {
|
pub struct StandardScaler<T: RealNumber> {
|
||||||
means: Vec<T>,
|
means: Vec<T>,
|
||||||
@@ -400,5 +405,43 @@ mod tests {
|
|||||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
||||||
|
/// serialized and deserialized.
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
fn serde_fit_for_random_values() {
|
||||||
|
let fitted_scaler = StandardScaler::fit(
|
||||||
|
&DenseMatrix::from_2d_array(&[
|
||||||
|
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||||
|
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||||
|
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||||
|
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||||
|
]),
|
||||||
|
StandardScalerParameters::default(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let deserialized_scaler: StandardScaler<f64> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
deserialized_scaler.means,
|
||||||
|
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||||
|
&DenseMatrix::from_2d_array(&[&[
|
||||||
|
0.29426447500954,
|
||||||
|
0.16758497615485,
|
||||||
|
0.20820945786863,
|
||||||
|
0.23329718831165
|
||||||
|
],]),
|
||||||
|
0.00000000000001
|
||||||
|
)
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user