feat: serialization/deserialization with Serde

This commit is contained in:
Volodymyr Orlov
2020-03-31 18:19:20 -07:00
parent 1257d2c19b
commit 8bb6013430
8 changed files with 281 additions and 28 deletions
+42 -4
View File
@@ -1,22 +1,31 @@
use std::fmt::Debug;
use serde::{Serialize, Deserialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
#[derive(Debug)]
#[derive(Serialize, Deserialize, Debug)]
pub enum LinearRegressionSolver {
QR,
SVD
}
#[derive(Debug)]
pub struct LinearRegression<T: FloatExt + Debug, M: Matrix<T>> {
#[derive(Serialize, Deserialize, Debug)]
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
coefficients: M,
intercept: T,
solver: LinearRegressionSolver
}
impl<T: FloatExt + Debug, M: Matrix<T>> LinearRegression<T, M> {
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients &&
self.intercept == other.intercept
}
}
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<T, M>{
@@ -90,4 +99,33 @@ mod tests {
}
#[test]
fn serde(){
let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
}
+49 -9
View File
@@ -1,6 +1,8 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Serialize, Deserialize};
use crate::math::num::FloatExt;
use crate::linalg::Matrix;
use crate::optimization::FunctionOrder;
@@ -8,15 +10,15 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking;
use crate::optimization::first_order::lbfgs::LBFGS;
#[derive(Debug)]
pub struct LogisticRegression<T: FloatExt + Debug, M: Matrix<T>> {
#[derive(Serialize, Deserialize, Debug)]
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
weights: M,
classes: Vec<T>,
num_attributes: usize,
num_classes: usize
}
trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
fn f(&self, w_bias: &M) -> T;
fn df(&self, g: &mut M, w_bias: &M);
@@ -31,13 +33,24 @@ trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
}
}
struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
x: &'a M,
y: Vec<usize>,
phantom: PhantomData<&'a T>
}
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.num_classes == other.num_classes &&
self.classes == other.classes &&
self.num_attributes == other.num_attributes &&
self.weights == other.weights
}
}
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
fn f(&self, w_bias: &M) -> T {
let mut f = T::zero();
@@ -72,14 +85,14 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryOb
}
struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
x: &'a M,
y: Vec<usize>,
k: usize,
phantom: PhantomData<&'a T>
}
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
fn f(&self, w_bias: &M) -> T {
let mut f = T::zero();
@@ -125,7 +138,7 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiCla
}
impl<T: FloatExt + Debug, M: Matrix<T>> LogisticRegression<T, M> {
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
@@ -371,6 +384,33 @@ mod tests {
}
#[test]
fn serde(){
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[ 2., 5.],
&[ 3., -2.],
&[ 1., 2.],
&[ 2., 0.],
&[ 6., -5.],
&[ 7., 5.],
&[ 6., -2.],
&[ 7., 2.],
&[ 6., 0.],
&[ 8., -5.],
&[ 9., 5.],
&[10., -2.],
&[ 8., 2.],
&[ 9., 0.]]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
#[test]
fn lr_fit_predict_iris() {
let x = arr2(&[
@@ -396,7 +436,7 @@ mod tests {
[5.2, 2.7, 3.9, 1.4]]);
let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y);
let y_hat = lr.predict(&x);