feat: serialization/deserialization with Serde

This commit is contained in:
Volodymyr Orlov
2020-03-31 18:19:20 -07:00
parent 1257d2c19b
commit 8bb6013430
8 changed files with 281 additions and 28 deletions
+115
View File
@@ -2,6 +2,11 @@ extern crate num;
use std::ops::Range;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use serde::ser::{Serializer, SerializeStruct};
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
use crate::linalg::Matrix;
pub use crate::linalg::BaseMatrix;
@@ -87,6 +92,96 @@ impl<T: FloatExt + Debug> DenseMatrix<T> {
}
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field { NRows, NCols, Values }
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{
t: PhantomData<T>
}
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
type Value = DenseMatrix<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct DenseMatrix")
}
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
where
V: SeqAccess<'a>,
{
let nrows = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let ncols = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let values = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
Ok(DenseMatrix::new(nrows, ncols, values))
}
fn visit_map<V>(self, mut map: V) -> Result<DenseMatrix<T>, V::Error>
where
V: MapAccess<'a>,
{
let mut nrows = None;
let mut ncols = None;
let mut values = None;
while let Some(key) = map.next_key()? {
match key {
Field::NRows => {
if nrows.is_some() {
return Err(serde::de::Error::duplicate_field("nrows"));
}
nrows = Some(map.next_value()?);
}
Field::NCols => {
if ncols.is_some() {
return Err(serde::de::Error::duplicate_field("ncols"));
}
ncols = Some(map.next_value()?);
}
Field::Values => {
if values.is_some() {
return Err(serde::de::Error::duplicate_field("values"));
}
values = Some(map.next_value()?);
}
}
}
let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
Ok(DenseMatrix::new(nrows, ncols, values))
}
}
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
t: PhantomData
})
}
}
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
S: Serializer {
let (nrows, ncols) = self.shape();
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
state.serialize_field("nrows", &nrows)?;
state.serialize_field("ncols", &ncols)?;
state.serialize_field("values", &self.values)?;
state.end()
}
}
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
@@ -772,4 +867,24 @@ mod tests {
assert_eq!(res, a);
}
#[test]
fn to_from_json() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
}
#[test]
fn to_from_bincode() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
}
#[test]
fn to_string() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
}
}