diff --git a/Cargo.toml b/Cargo.toml index 59857cd..47283cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,13 @@ ndarray = "0.13" num-traits = "0.2.11" num = "0.2.1" rand = "0.7.3" +serde = { version = "1.0.105", features = ["derive"] } +serde_derive = "1.0.105" [dev-dependencies] -ndarray = "0.13" criterion = "0.3" +serde_json = "1.0" +bincode = "1.2.1" [[bench]] name = "distance" diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index b264d4b..89ad5ef 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -4,12 +4,14 @@ use rand::Rng; use std::iter::Sum; use std::fmt::Debug; +use serde::{Serialize, Deserialize}; + use crate::math::num::FloatExt; use crate::linalg::Matrix; use crate::math::distance::euclidian; use crate::algorithm::neighbour::bbd_tree::BBDTree; -#[derive(Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct KMeans { k: usize, y: Vec, @@ -18,6 +20,29 @@ pub struct KMeans { centroids: Vec> } +impl PartialEq for KMeans { + fn eq(&self, other: &Self) -> bool { + if self.k != other.k || + self.size != other.size || + self.centroids.len() != other.centroids.len() { + false + } else { + let n_centroids = self.centroids.len(); + for i in 0..n_centroids{ + if self.centroids[i].len() != other.centroids[i].len(){ + return false + } + for j in 0..self.centroids[i].len() { + if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() { + return false + } + } + } + true + } + } +} + #[derive(Debug, Clone)] pub struct KMeansParameters { pub max_iter: usize @@ -210,5 +235,37 @@ mod tests { } } + + #[test] + fn serde() { + let x = DenseMatrix::from_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4]]); + + let kmeans = KMeans::new(&x, 2, Default::default()); + + let deserialized_kmeans: KMeans = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap(); + + assert_eq!(kmeans, deserialized_kmeans); + + } } \ No newline at end of file diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 0095333..5b28f6a 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -5,7 +5,7 @@ pub mod evd; pub mod ndarray_bindings; use std::ops::Range; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::marker::PhantomData; use crate::math::num::FloatExt; @@ -13,7 +13,7 @@ use svd::SVDDecomposableMatrix; use evd::EVDDecomposableMatrix; use qr::QRDecomposableMatrix; -pub trait BaseMatrix: Clone + Debug { +pub trait BaseMatrix: Clone + Debug { type RowVector: Clone + Debug; @@ -175,9 +175,9 @@ pub trait BaseMatrix: Clone + Debug { } -pub trait Matrix: BaseMatrix + SVDDecomposableMatrix + EVDDecomposableMatrix + QRDecomposableMatrix {} +pub trait Matrix: BaseMatrix + SVDDecomposableMatrix + EVDDecomposableMatrix + QRDecomposableMatrix + PartialEq + Display {} -pub fn row_iter>(m: &M) -> RowIter { +pub fn row_iter>(m: &M) -> RowIter { RowIter{ m: m, pos: 0, @@ -186,14 +186,14 @@ pub fn row_iter>(m: &M) -> RowIter { } } -pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix> { +pub struct RowIter<'a, T: FloatExt, M: BaseMatrix> { m: &'a M, pos: usize, max_pos: usize, phantom: PhantomData<&'a T> } -impl<'a, T: FloatExt + Debug, M: Matrix> Iterator for RowIter<'a, T, M> { +impl<'a, T: FloatExt, M: BaseMatrix> Iterator for RowIter<'a, T, M> { type Item = Vec; diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 400e5ea..3be158e 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -2,6 +2,11 @@ extern crate num; use std::ops::Range; use std::fmt; use std::fmt::Debug; +use std::marker::PhantomData; + +use serde::{Serialize, Deserialize}; +use serde::ser::{Serializer, SerializeStruct}; +use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess}; use crate::linalg::Matrix; pub use crate::linalg::BaseMatrix; @@ -87,6 +92,96 @@ impl DenseMatrix { } +impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { NRows, NCols, Values } + + struct DenseMatrixVisitor{ + t: PhantomData + } + + impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor { + type Value = DenseMatrix; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct DenseMatrix") + } + + fn visit_seq(self, mut seq: V) -> Result, V::Error> + where + V: SeqAccess<'a>, + { + let nrows = seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + let ncols = seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + let values = seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(2, &self))?; + Ok(DenseMatrix::new(nrows, ncols, values)) + } + + fn visit_map(self, mut map: V) -> Result, V::Error> + where + V: MapAccess<'a>, + { + let mut nrows = None; + let mut ncols = None; + let mut values = None; + while let Some(key) = map.next_key()? { + match key { + Field::NRows => { + if nrows.is_some() { + return Err(serde::de::Error::duplicate_field("nrows")); + } + nrows = Some(map.next_value()?); + } + Field::NCols => { + if ncols.is_some() { + return Err(serde::de::Error::duplicate_field("ncols")); + } + ncols = Some(map.next_value()?); + } + Field::Values => { + if values.is_some() { + return Err(serde::de::Error::duplicate_field("values")); + } + values = Some(map.next_value()?); + } + } + } + let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?; + let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?; + let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?; + Ok(DenseMatrix::new(nrows, ncols, values)) + } + } + + const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"]; + deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor { + t: PhantomData + }) + } +} + +impl Serialize for DenseMatrix { + + fn serialize(&self, serializer: S) -> Result where + S: Serializer { + let (nrows, ncols) = self.shape(); + let mut state = serializer.serialize_struct("DenseMatrix", 3)?; + state.serialize_field("nrows", &nrows)?; + state.serialize_field("ncols", &ncols)?; + state.serialize_field("values", &self.values)?; + state.end() + } +} + impl SVDDecomposableMatrix for DenseMatrix {} impl EVDDecomposableMatrix for DenseMatrix {} @@ -772,4 +867,24 @@ mod tests { assert_eq!(res, a); } + #[test] + fn to_from_json() { + let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); + let deserialized_a: DenseMatrix = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); + assert_eq!(a, deserialized_a); + } + + #[test] + fn to_from_bincode() { + let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); + let deserialized_a: DenseMatrix = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap(); + assert_eq!(a, deserialized_a); + } + + #[test] + fn to_string() { + let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); + assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"); + } + } diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index fb98c60..7f1c9b0 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -1,5 +1,4 @@ use std::ops::Range; -use std::fmt::Debug; use std::iter::Sum; use std::ops::AddAssign; use std::ops::SubAssign; @@ -17,7 +16,7 @@ use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; -impl BaseMatrix for ArrayBase, Ix2> +impl BaseMatrix for ArrayBase, Ix2> { type RowVector = ArrayBase, Ix1>; @@ -286,13 +285,13 @@ impl SVDDecomposableMatrix for ArrayBase, Ix2> {} +impl SVDDecomposableMatrix for ArrayBase, Ix2> {} -impl EVDDecomposableMatrix for ArrayBase, Ix2> {} +impl EVDDecomposableMatrix for ArrayBase, Ix2> {} -impl QRDecomposableMatrix for ArrayBase, Ix2> {} +impl QRDecomposableMatrix for ArrayBase, Ix2> {} -impl Matrix for ArrayBase, Ix2> {} +impl Matrix for ArrayBase, Ix2> {} #[cfg(test)] mod tests { diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index 34ef45a..2fd7506 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -1,22 +1,31 @@ use std::fmt::Debug; +use serde::{Serialize, Deserialize}; + use crate::math::num::FloatExt; use crate::linalg::Matrix; -#[derive(Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum LinearRegressionSolver { QR, SVD } -#[derive(Debug)] -pub struct LinearRegression> { +#[derive(Serialize, Deserialize, Debug)] +pub struct LinearRegression> { coefficients: M, intercept: T, solver: LinearRegressionSolver } -impl> LinearRegression { +impl> PartialEq for LinearRegression { + fn eq(&self, other: &Self) -> bool { + self.coefficients == other.coefficients && + self.intercept == other.intercept + } +} + +impl> LinearRegression { pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression{ @@ -90,4 +99,33 @@ mod tests { } + + #[test] + fn serde(){ + let x = DenseMatrix::from_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], + &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); + + let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]); + + let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR); + + let deserialized_lr: LinearRegression> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap(); + + assert_eq!(lr, deserialized_lr); + } } \ No newline at end of file diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index a1c4973..5116e2d 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -1,6 +1,8 @@ use std::fmt::Debug; use std::marker::PhantomData; +use serde::{Serialize, Deserialize}; + use crate::math::num::FloatExt; use crate::linalg::Matrix; use crate::optimization::FunctionOrder; @@ -8,15 +10,15 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::line_search::Backtracking; use crate::optimization::first_order::lbfgs::LBFGS; -#[derive(Debug)] -pub struct LogisticRegression> { +#[derive(Serialize, Deserialize, Debug)] +pub struct LogisticRegression> { weights: M, classes: Vec, num_attributes: usize, num_classes: usize } -trait ObjectiveFunction> { +trait ObjectiveFunction> { fn f(&self, w_bias: &M) -> T; fn df(&self, g: &mut M, w_bias: &M); @@ -31,13 +33,24 @@ trait ObjectiveFunction> { } } -struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix> { +struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix> { x: &'a M, y: Vec, phantom: PhantomData<&'a T> } -impl<'a, T: FloatExt + Debug, M: Matrix> ObjectiveFunction for BinaryObjectiveFunction<'a, T, M> { +impl> PartialEq for LogisticRegression { + fn eq(&self, other: &Self) -> bool { + + self.num_classes == other.num_classes && + self.classes == other.classes && + self.num_attributes == other.num_attributes && + self.weights == other.weights + + } +} + +impl<'a, T: FloatExt, M: Matrix> ObjectiveFunction for BinaryObjectiveFunction<'a, T, M> { fn f(&self, w_bias: &M) -> T { let mut f = T::zero(); @@ -72,14 +85,14 @@ impl<'a, T: FloatExt + Debug, M: Matrix> ObjectiveFunction for BinaryOb } -struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix> { +struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix> { x: &'a M, y: Vec, k: usize, phantom: PhantomData<&'a T> } -impl<'a, T: FloatExt + Debug, M: Matrix> ObjectiveFunction for MultiClassObjectiveFunction<'a, T, M> { +impl<'a, T: FloatExt, M: Matrix> ObjectiveFunction for MultiClassObjectiveFunction<'a, T, M> { fn f(&self, w_bias: &M) -> T { let mut f = T::zero(); @@ -125,7 +138,7 @@ impl<'a, T: FloatExt + Debug, M: Matrix> ObjectiveFunction for MultiCla } -impl> LogisticRegression { +impl> LogisticRegression { pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression{ @@ -371,6 +384,33 @@ mod tests { } + #[test] + fn serde(){ + let x = DenseMatrix::from_array(&[ + &[1., -5.], + &[ 2., 5.], + &[ 3., -2.], + &[ 1., 2.], + &[ 2., 0.], + &[ 6., -5.], + &[ 7., 5.], + &[ 6., -2.], + &[ 7., 2.], + &[ 6., 0.], + &[ 8., -5.], + &[ 9., 5.], + &[10., -2.], + &[ 8., 2.], + &[ 9., 0.]]); + let y: Vec = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]; + + let lr = LogisticRegression::fit(&x, &y); + + let deserialized_lr: LogisticRegression> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap(); + + assert_eq!(lr, deserialized_lr); + } + #[test] fn lr_fit_predict_iris() { let x = arr2(&[ @@ -396,7 +436,7 @@ mod tests { [5.2, 2.7, 3.9, 1.4]]); let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]); - let lr = LogisticRegression::fit(&x, &y); + let lr = LogisticRegression::fit(&x, &y); let y_hat = lr.predict(&x); diff --git a/src/math/num.rs b/src/math/num.rs index eaa851c..7980853 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -1,7 +1,8 @@ +use std::fmt::{Debug, Display}; use num_traits::{Float, FromPrimitive}; use rand::prelude::*; -pub trait FloatExt: Float + FromPrimitive { +pub trait FloatExt: Float + FromPrimitive + Debug + Display { fn copysign(self, sign: Self) -> Self;