feat: serialization/deserialization with Serde
This commit is contained in:
+4
-1
@@ -9,10 +9,13 @@ ndarray = "0.13"
|
|||||||
num-traits = "0.2.11"
|
num-traits = "0.2.11"
|
||||||
num = "0.2.1"
|
num = "0.2.1"
|
||||||
rand = "0.7.3"
|
rand = "0.7.3"
|
||||||
|
serde = { version = "1.0.105", features = ["derive"] }
|
||||||
|
serde_derive = "1.0.105"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
ndarray = "0.13"
|
|
||||||
criterion = "0.3"
|
criterion = "0.3"
|
||||||
|
serde_json = "1.0"
|
||||||
|
bincode = "1.2.1"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "distance"
|
name = "distance"
|
||||||
|
|||||||
+58
-1
@@ -4,12 +4,14 @@ use rand::Rng;
|
|||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian;
|
use crate::math::distance::euclidian;
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct KMeans<T: FloatExt> {
|
pub struct KMeans<T: FloatExt> {
|
||||||
k: usize,
|
k: usize,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
@@ -18,6 +20,29 @@ pub struct KMeans<T: FloatExt> {
|
|||||||
centroids: Vec<Vec<T>>
|
centroids: Vec<Vec<T>>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt> PartialEq for KMeans<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.k != other.k ||
|
||||||
|
self.size != other.size ||
|
||||||
|
self.centroids.len() != other.centroids.len() {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
let n_centroids = self.centroids.len();
|
||||||
|
for i in 0..n_centroids{
|
||||||
|
if self.centroids[i].len() != other.centroids[i].len(){
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for j in 0..self.centroids[i].len() {
|
||||||
|
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct KMeansParameters {
|
pub struct KMeansParameters {
|
||||||
pub max_iter: usize
|
pub max_iter: usize
|
||||||
@@ -211,4 +236,36 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde() {
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
|
&[5.2, 2.7, 3.9, 1.4]]);
|
||||||
|
|
||||||
|
let kmeans = KMeans::new(&x, 2, Default::default());
|
||||||
|
|
||||||
|
let deserialized_kmeans: KMeans<f64> = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(kmeans, deserialized_kmeans);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
+6
-6
@@ -5,7 +5,7 @@ pub mod evd;
|
|||||||
pub mod ndarray_bindings;
|
pub mod ndarray_bindings;
|
||||||
|
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::fmt::Debug;
|
use std::fmt::{Debug, Display};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
@@ -13,7 +13,7 @@ use svd::SVDDecomposableMatrix;
|
|||||||
use evd::EVDDecomposableMatrix;
|
use evd::EVDDecomposableMatrix;
|
||||||
use qr::QRDecomposableMatrix;
|
use qr::QRDecomposableMatrix;
|
||||||
|
|
||||||
pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
|
pub trait BaseMatrix<T: FloatExt>: Clone + Debug {
|
||||||
|
|
||||||
type RowVector: Clone + Debug;
|
type RowVector: Clone + Debug;
|
||||||
|
|
||||||
@@ -175,9 +175,9 @@ pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Matrix<T: FloatExt + Debug>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> {}
|
pub trait Matrix<T: FloatExt>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> + PartialEq + Display {}
|
||||||
|
|
||||||
pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
|
pub fn row_iter<F: FloatExt, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
|
||||||
RowIter{
|
RowIter{
|
||||||
m: m,
|
m: m,
|
||||||
pos: 0,
|
pos: 0,
|
||||||
@@ -186,14 +186,14 @@ pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
pub struct RowIter<'a, T: FloatExt, M: BaseMatrix<T>> {
|
||||||
m: &'a M,
|
m: &'a M,
|
||||||
pos: usize,
|
pos: usize,
|
||||||
max_pos: usize,
|
max_pos: usize,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> Iterator for RowIter<'a, T, M> {
|
impl<'a, T: FloatExt, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
||||||
|
|
||||||
type Item = Vec<T>;
|
type Item = Vec<T>;
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,11 @@ extern crate num;
|
|||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
use serde::ser::{Serializer, SerializeStruct};
|
||||||
|
use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess};
|
||||||
|
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
pub use crate::linalg::BaseMatrix;
|
pub use crate::linalg::BaseMatrix;
|
||||||
@@ -87,6 +92,96 @@ impl<T: FloatExt + Debug> DenseMatrix<T> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'de, T: FloatExt + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(field_identifier, rename_all = "lowercase")]
|
||||||
|
enum Field { NRows, NCols, Values }
|
||||||
|
|
||||||
|
struct DenseMatrixVisitor<T: FloatExt + fmt::Debug>{
|
||||||
|
t: PhantomData<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: FloatExt + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
||||||
|
type Value = DenseMatrix<T>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
formatter.write_str("struct DenseMatrix")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<V>(self, mut seq: V) -> Result<DenseMatrix<T>, V::Error>
|
||||||
|
where
|
||||||
|
V: SeqAccess<'a>,
|
||||||
|
{
|
||||||
|
let nrows = seq.next_element()?
|
||||||
|
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
|
||||||
|
let ncols = seq.next_element()?
|
||||||
|
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
|
||||||
|
let values = seq.next_element()?
|
||||||
|
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
|
||||||
|
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<V>(self, mut map: V) -> Result<DenseMatrix<T>, V::Error>
|
||||||
|
where
|
||||||
|
V: MapAccess<'a>,
|
||||||
|
{
|
||||||
|
let mut nrows = None;
|
||||||
|
let mut ncols = None;
|
||||||
|
let mut values = None;
|
||||||
|
while let Some(key) = map.next_key()? {
|
||||||
|
match key {
|
||||||
|
Field::NRows => {
|
||||||
|
if nrows.is_some() {
|
||||||
|
return Err(serde::de::Error::duplicate_field("nrows"));
|
||||||
|
}
|
||||||
|
nrows = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
Field::NCols => {
|
||||||
|
if ncols.is_some() {
|
||||||
|
return Err(serde::de::Error::duplicate_field("ncols"));
|
||||||
|
}
|
||||||
|
ncols = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
Field::Values => {
|
||||||
|
if values.is_some() {
|
||||||
|
return Err(serde::de::Error::duplicate_field("values"));
|
||||||
|
}
|
||||||
|
values = Some(map.next_value()?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let nrows = nrows.ok_or_else(|| serde::de::Error::missing_field("nrows"))?;
|
||||||
|
let ncols = ncols.ok_or_else(|| serde::de::Error::missing_field("ncols"))?;
|
||||||
|
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
|
||||||
|
Ok(DenseMatrix::new(nrows, ncols, values))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
||||||
|
deserializer.deserialize_struct("DenseMatrix", FIELDS, DenseMatrixVisitor {
|
||||||
|
t: PhantomData
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||||
|
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where
|
||||||
|
S: Serializer {
|
||||||
|
let (nrows, ncols) = self.shape();
|
||||||
|
let mut state = serializer.serialize_struct("DenseMatrix", 3)?;
|
||||||
|
state.serialize_field("nrows", &nrows)?;
|
||||||
|
state.serialize_field("ncols", &ncols)?;
|
||||||
|
state.serialize_field("values", &self.values)?;
|
||||||
|
state.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
impl<T: FloatExt + Debug> SVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
impl<T: FloatExt + Debug> EVDDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||||
@@ -772,4 +867,24 @@ mod tests {
|
|||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn to_from_json() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
let deserialized_a: DenseMatrix<f64> = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||||
|
assert_eq!(a, deserialized_a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn to_from_bincode() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
let deserialized_a: DenseMatrix<f64> = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||||
|
assert_eq!(a, deserialized_a);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn to_string() {
|
||||||
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
assert_eq!(format!("{}", a), "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::fmt::Debug;
|
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::AddAssign;
|
use std::ops::AddAssign;
|
||||||
use std::ops::SubAssign;
|
use std::ops::SubAssign;
|
||||||
@@ -17,7 +16,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
|
|||||||
use crate::linalg::qr::QRDecomposableMatrix;
|
use crate::linalg::qr::QRDecomposableMatrix;
|
||||||
|
|
||||||
|
|
||||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||||
{
|
{
|
||||||
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
||||||
|
|
||||||
@@ -286,13 +285,13 @@ impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + D
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
impl<T: FloatExt + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|||||||
@@ -1,22 +1,31 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub enum LinearRegressionSolver {
|
pub enum LinearRegressionSolver {
|
||||||
QR,
|
QR,
|
||||||
SVD
|
SVD
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LinearRegression<T: FloatExt + Debug, M: Matrix<T>> {
|
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
solver: LinearRegressionSolver
|
solver: LinearRegressionSolver
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug, M: Matrix<T>> LinearRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.coefficients == other.coefficients &&
|
||||||
|
self.intercept == other.intercept
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||||
|
|
||||||
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<T, M>{
|
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<T, M>{
|
||||||
|
|
||||||
@@ -90,4 +99,33 @@ mod tests {
|
|||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde(){
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||||
|
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||||
|
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||||
|
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||||
|
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||||
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||||
|
|
||||||
|
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
|
||||||
|
|
||||||
|
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
||||||
|
|
||||||
|
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(lr, deserialized_lr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use serde::{Serialize, Deserialize};
|
||||||
|
|
||||||
use crate::math::num::FloatExt;
|
use crate::math::num::FloatExt;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
@@ -8,15 +10,15 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
|||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct LogisticRegression<T: FloatExt + Debug, M: Matrix<T>> {
|
pub struct LogisticRegression<T: FloatExt, M: Matrix<T>> {
|
||||||
weights: M,
|
weights: M,
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
num_attributes: usize,
|
num_attributes: usize,
|
||||||
num_classes: usize
|
num_classes: usize
|
||||||
}
|
}
|
||||||
|
|
||||||
trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
|
trait ObjectiveFunction<T: FloatExt, M: Matrix<T>> {
|
||||||
fn f(&self, w_bias: &M) -> T;
|
fn f(&self, w_bias: &M) -> T;
|
||||||
fn df(&self, g: &mut M, w_bias: &M);
|
fn df(&self, g: &mut M, w_bias: &M);
|
||||||
|
|
||||||
@@ -31,13 +33,24 @@ trait ObjectiveFunction<T: FloatExt + Debug, M: Matrix<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct BinaryObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
struct BinaryObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
impl<T: FloatExt, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
|
||||||
|
self.num_classes == other.num_classes &&
|
||||||
|
self.classes == other.classes &&
|
||||||
|
self.num_attributes == other.num_attributes &&
|
||||||
|
self.weights == other.weights
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryObjectiveFunction<'a, T, M> {
|
||||||
|
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
@@ -72,14 +85,14 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for BinaryOb
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct MultiClassObjectiveFunction<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
struct MultiClassObjectiveFunction<'a, T: FloatExt, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
k: usize,
|
k: usize,
|
||||||
phantom: PhantomData<&'a T>
|
phantom: PhantomData<&'a T>
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
impl<'a, T: FloatExt, M: Matrix<T>> ObjectiveFunction<T, M> for MultiClassObjectiveFunction<'a, T, M> {
|
||||||
|
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
@@ -125,7 +138,7 @@ impl<'a, T: FloatExt + Debug, M: Matrix<T>> ObjectiveFunction<T, M> for MultiCla
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: FloatExt + Debug, M: Matrix<T>> LogisticRegression<T, M> {
|
impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
|
||||||
|
|
||||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
|
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M>{
|
||||||
|
|
||||||
@@ -371,6 +384,33 @@ mod tests {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde(){
|
||||||
|
let x = DenseMatrix::from_array(&[
|
||||||
|
&[1., -5.],
|
||||||
|
&[ 2., 5.],
|
||||||
|
&[ 3., -2.],
|
||||||
|
&[ 1., 2.],
|
||||||
|
&[ 2., 0.],
|
||||||
|
&[ 6., -5.],
|
||||||
|
&[ 7., 5.],
|
||||||
|
&[ 6., -2.],
|
||||||
|
&[ 7., 2.],
|
||||||
|
&[ 6., 0.],
|
||||||
|
&[ 8., -5.],
|
||||||
|
&[ 9., 5.],
|
||||||
|
&[10., -2.],
|
||||||
|
&[ 8., 2.],
|
||||||
|
&[ 9., 0.]]);
|
||||||
|
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
|
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> = serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(lr, deserialized_lr);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_iris() {
|
fn lr_fit_predict_iris() {
|
||||||
let x = arr2(&[
|
let x = arr2(&[
|
||||||
|
|||||||
+2
-1
@@ -1,7 +1,8 @@
|
|||||||
|
use std::fmt::{Debug, Display};
|
||||||
use num_traits::{Float, FromPrimitive};
|
use num_traits::{Float, FromPrimitive};
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
|
|
||||||
pub trait FloatExt: Float + FromPrimitive {
|
pub trait FloatExt: Float + FromPrimitive + Debug + Display {
|
||||||
|
|
||||||
fn copysign(self, sign: Self) -> Self;
|
fn copysign(self, sign: Self) -> Self;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user