feat: serialization/deserialization with Serde
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::optimization::FunctionOrder;
|
||||
@@ -8,15 +10,15 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LogisticRegression<T: FloatExt + Debug, M: Matrix<T>> {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
|
||||
weights: M,
|
||||
classes: Vec<T>,
|
||||
num_attributes: usize,
|
||||
num_classes: usize
|
||||
}
|
||||
|
||||
trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
|
||||
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
||||
fn f(&self, w_bias: &M) -> T;
|
||||
fn df(&self, g: &mut M, w_bias: &M);
|
||||
|
||||
@@ -31,13 +33,24 @@ trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
|
||||
self.num_classes == other.num_classes &&
|
||||
self.classes == other.classes &&
|
||||
self.num_attributes == other.num_attributes &&
|
||||
self.weights == other.weights
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
@@ -72,14 +85,14 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryOb
|
||||
|
||||
}
|
||||
|
||||
struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
k: usize,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
||||
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
||||
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
@@ -125,7 +138,7 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiCla
|
||||
|
||||
}
|
||||
|
||||
impl<T: FloatExt + Debug, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
|
||||
|
||||
@@ -371,6 +384,33 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde(){
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., -5.],
|
||||
&[ 2., 5.],
|
||||
&[ 3., -2.],
|
||||
&[ 1., 2.],
|
||||
&[ 2., 0.],
|
||||
&[ 6., -5.],
|
||||
&[ 7., 5.],
|
||||
&[ 6., -2.],
|
||||
&[ 7., 2.],
|
||||
&[ 6., 0.],
|
||||
&[ 8., -5.],
|
||||
&[ 9., 5.],
|
||||
&[10., -2.],
|
||||
&[ 8., 2.],
|
||||
&[ 9., 0.]]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
|
||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lr_fit_predict_iris() {
|
||||
let x = arr2(&[
|
||||
@@ -396,7 +436,7 @@ mod tests {
|
||||
[5.2, 2.7, 3.9, 1.4]]);
|
||||
let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
|
||||
let y_hat = lr.predict(&x);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user