feat: serialization/deserialization with Serde
This commit is contained in:
+6
-6
@@ -5,7 +5,7 @@ pub mod evd;
|
||||
pub mod ndarray_bindings;
|
||||
|
||||
use std::ops::Range;
|
||||
use std::fmt::Debug;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
@@ -13,7 +13,7 @@ use svd::SVDDecomposableMatrix;
|
||||
use evd::EVDDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
|
||||
pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
|
||||
pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
|
||||
|
||||
type RowVector: Clone + Debug;
|
||||
|
||||
@@ -175,9 +175,9 @@ pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
|
||||
|
||||
}
|
||||
|
||||
pub trait Matrix<T: FloatExt + Debug>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> {}
|
||||
pub trait Matrix<T: FloatExt>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> + PartialEq + Display {}
|
||||
|
||||
pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
|
||||
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
|
||||
RowIter{
|
||||
m: m,
|
||||
pos: 0,
|
||||
@@ -186,14 +186,14 @@ pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
pub struct RowIter<'a, T: FloatExt, M: BaseMatrix<T>> {
|
||||
m: &'a M,
|
||||
pos: usize,
|
||||
max_pos: usize,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> Iterator for RowIter<'a, T, M> {
|
||||
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
||||
|
||||
type Item = Vec<T>;
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ extern crate num;
|
||||
use std::ops::Range;
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::ser::{Serializer, SerializeStruct};
|
||||
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
pub use crate::linalg::BaseMatrix;
|
||||
@@ -87,6 +92,96 @@ impl<T: FloatExt + Debug> DenseMatrix<T> {
|
||||
|
||||
}
|
||||
|
||||
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(field_identifier, rename_all = "lowercase")]
|
||||
enum Field { NRows, NCols, Values }
|
||||
|
||||
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{
|
||||
t: PhantomData<T>
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
||||
type Value = DenseMatrix<T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("struct DenseMatrix")
|
||||
}
|
||||
|
||||
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
|
||||
where
|
||||
V: SeqAccess<'a>,
|
||||
{
|
||||
let nrows = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
|
||||
let ncols = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
|
||||
let values = seq.next_element()?
|
||||
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
|
||||
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<DenseMatrix<T>, V::Error>
|
||||
where
|
||||
V: MapAccess<'a>,
|
||||
{
|
||||
let mut nrows = None;
|
||||
let mut ncols = None;
|
||||
let mut values = None;
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::NRows => {
|
||||
if nrows.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("nrows"));
|
||||
}
|
||||
nrows = Some(map.next_value()?);
|
||||
}
|
||||
Field::NCols => {
|
||||
if ncols.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("ncols"));
|
||||
}
|
||||
ncols = Some(map.next_value()?);
|
||||
}
|
||||
Field::Values => {
|
||||
if values.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("values"));
|
||||
}
|
||||
values = Some(map.next_value()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
|
||||
let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
|
||||
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
|
||||
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||
}
|
||||
}
|
||||
|
||||
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
||||
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
|
||||
t: PhantomData
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
|
||||
S: Serializer {
|
||||
let (nrows, ncols) = self.shape();
|
||||
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
|
||||
state.serialize_field("nrows", &nrows)?;
|
||||
state.serialize_field("ncols", &ncols)?;
|
||||
state.serialize_field("values", &self.values)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
@@ -772,4 +867,24 @@ mod tests {
|
||||
assert_eq!(res, a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_from_json() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_from_bincode() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_string() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::ops::Range;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
use std::ops::AddAssign;
|
||||
use std::ops::SubAssign;
|
||||
@@ -17,7 +16,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
||||
|
||||
@@ -286,13 +285,13 @@ impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + D
|
||||
|
||||
}
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
Reference in New Issue
Block a user