feat: serialization/deserialization with Serde

This commit is contained in:
Volodymyr Orlov
2020-03-31 18:19:20 -07:00
parent 1257d2c19b
commit 8bb6013430
8 changed files with 281 additions and 28 deletions
+6 -6
View File
@@ -5,7 +5,7 @@ pub mod evd;
pub mod ndarray_bindings;
use std::ops::Range;
use std::fmt::Debug;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use crate::math::num::FloatExt;
@@ -13,7 +13,7 @@ use svd::SVDDecomposableMatrix;
use evd::EVDDecomposableMatrix;
use qr::QRDecomposableMatrix;
pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
type RowVector: Clone + Debug;
@@ -175,9 +175,9 @@ pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
}
pub trait Matrix<T: FloatExt + Debug>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> {}
pub trait Matrix<T: FloatExt>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> + PartialEq + Display {}
pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
RowIter{
m: m,
pos: 0,
@@ -186,14 +186,14 @@ pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
}
}
pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix<T>> {
pub struct RowIter<'a, T: FloatExt, M: BaseMatrix<T>> {
m: &'a M,
pos: usize,
max_pos: usize,
phantom: PhantomData<&'a T>
}
impl<'a, T: FloatExt + Debug, M: Matrix<T>> Iterator for RowIter<'a, T, M> {
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>;
+115
View File
@@ -2,6 +2,11 @@ extern crate num;
use std::ops::Range;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use serde::ser::{Serializer, SerializeStruct};
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
use crate::linalg::Matrix;
pub use crate::linalg::BaseMatrix;
@@ -87,6 +92,96 @@ impl<T: FloatExt + Debug> DenseMatrix<T> {
}
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field { NRows, NCols, Values }
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{
t: PhantomData<T>
}
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
type Value = DenseMatrix<T>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct DenseMatrix")
}
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
where
V: SeqAccess<'a>,
{
let nrows = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let ncols = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let values = seq.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
Ok(DenseMatrix::new(nrows, ncols, values))
}
fn visit_map<V>(self, mut map: V) -> Result<DenseMatrix<T>, V::Error>
where
V: MapAccess<'a>,
{
let mut nrows = None;
let mut ncols = None;
let mut values = None;
while let Some(key) = map.next_key()? {
match key {
Field::NRows => {
if nrows.is_some() {
return Err(serde::de::Error::duplicate_field("nrows"));
}
nrows = Some(map.next_value()?);
}
Field::NCols => {
if ncols.is_some() {
return Err(serde::de::Error::duplicate_field("ncols"));
}
ncols = Some(map.next_value()?);
}
Field::Values => {
if values.is_some() {
return Err(serde::de::Error::duplicate_field("values"));
}
values = Some(map.next_value()?);
}
}
}
let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
Ok(DenseMatrix::new(nrows, ncols, values))
}
}
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
t: PhantomData
})
}
}
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
S: Serializer {
let (nrows, ncols) = self.shape();
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
state.serialize_field("nrows", &nrows)?;
state.serialize_field("ncols", &ncols)?;
state.serialize_field("values", &self.values)?;
state.end()
}
}
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
@@ -772,4 +867,24 @@ mod tests {
assert_eq!(res, a);
}
#[test]
fn to_from_json() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
}
#[test]
fn to_from_bincode() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a);
}
#[test]
fn to_string() {
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
}
}
+5 -6
View File
@@ -1,5 +1,4 @@
use std::ops::Range;
use std::fmt::Debug;
use std::iter::Sum;
use std::ops::AddAssign;
use std::ops::SubAssign;
@@ -17,7 +16,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
@@ -286,13 +285,13 @@ impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + D
}
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
#[cfg(test)]
mod tests {