feat: serialization/deserialization with Serde
This commit is contained in:
@@ -2,6 +2,11 @@ extern crate num;
|
||||
use std::ops::Range;
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::ser::{Serializer, SerializeStruct};
|
||||
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
pub use crate::linalg::BaseMatrix;
|
||||
@@ -87,6 +92,96 @@ impl<T: FloatExt + Debug> DenseMatrix<T> {
|
||||
|
||||
}
|
||||
|
||||
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(field_identifier, rename_all = "lowercase")]
|
||||
enum Field { NRows, NCols, Values }
|
||||
|
||||
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{
|
||||
t: PhantomData<T>
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
||||
type Value = DenseMatrix<T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("struct DenseMatrix")
|
||||
}
|
||||
|
||||
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
|
||||
where
|
||||
V: SeqAccess<'a>,
|
||||
{
|
||||
let nrows = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
|
||||
let ncols = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
|
||||
let values = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
|
||||
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<DenseMatrix<T>, V::Error>
|
||||
where
|
||||
V: MapAccess<'a>,
|
||||
{
|
||||
let mut nrows = None;
|
||||
let mut ncols = None;
|
||||
let mut values = None;
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::NRows => {
|
||||
if nrows.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("nrows"));
|
||||
}
|
||||
nrows = Some(map.next_value()?);
|
||||
}
|
||||
Field::NCols => {
|
||||
if ncols.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("ncols"));
|
||||
}
|
||||
ncols = Some(map.next_value()?);
|
||||
}
|
||||
Field::Values => {
|
||||
if values.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("values"));
|
||||
}
|
||||
values = Some(map.next_value()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
|
||||
let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
|
||||
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
|
||||
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||
}
|
||||
}
|
||||
|
||||
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
||||
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
|
||||
t: PhantomData
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
|
||||
S: Serializer {
|
||||
let (nrows, ncols) = self.shape();
|
||||
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
|
||||
state.serialize_field("nrows", &nrows)?;
|
||||
state.serialize_field("ncols", &ncols)?;
|
||||
state.serialize_field("values", &self.values)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
@@ -772,4 +867,24 @@ mod tests {
|
||||
assert_eq!(res, a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_from_json() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_from_bincode() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_string() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user