Handle kernel serialization (#232)
* Handle kernel serialization * Do not use typetag in WASM * enable tests for serialization * Update serde feature deps Co-authored-by: Luis Moreno <morenol@users.noreply.github.com> Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
This commit is contained in:
+10
-36
@@ -30,8 +30,6 @@ pub mod svr;
|
||||
|
||||
use core::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::ser::{SerializeStruct, Serializer};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
|
||||
/// Defines a kernel function.
|
||||
/// This is a object-safe trait.
|
||||
pub trait Kernel {
|
||||
#[cfg_attr(
|
||||
all(feature = "serde", not(target_arch = "wasm32")),
|
||||
typetag::serde(tag = "type")
|
||||
)]
|
||||
pub trait Kernel: Debug {
|
||||
#[allow(clippy::ptr_arg)]
|
||||
/// Apply kernel function to x_i and x_j
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||
/// Return a serializable name
|
||||
fn name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
impl Debug for dyn Kernel {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "Kernel<f64>")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl Serialize for dyn Kernel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut s = serializer.serialize_struct("Kernel", 1)?;
|
||||
s.serialize_field("type", &self.name())?;
|
||||
s.end()
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-defined kernel functions
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Kernels {}
|
||||
pub struct Kernels;
|
||||
|
||||
impl Kernels {
|
||||
/// Return a default linear
|
||||
@@ -211,15 +193,14 @@ impl SigmoidKernel {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for LinearKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
Ok(x_i.dot(x_j))
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Linear"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for RBFKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() {
|
||||
@@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"RBF"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for PolynomialKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
||||
@@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Polynomial"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for SigmoidKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() {
|
||||
@@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
+8
-4
@@ -100,8 +100,11 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
|
||||
pub c: TX,
|
||||
/// Tolerance for stopping criterion.
|
||||
pub tol: TX,
|
||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||
/// The kernel function.
|
||||
#[cfg_attr(
|
||||
all(feature = "serde", target_arch = "wasm32"),
|
||||
serde(skip_serializing, skip_deserializing)
|
||||
)]
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<(X, Y, TY)>,
|
||||
@@ -1085,7 +1088,7 @@ mod tests {
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||
fn svc_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -1119,8 +1122,9 @@ mod tests {
|
||||
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
||||
|
||||
// serialization
|
||||
let serialized_svc = &serde_json::to_string(&svc).unwrap();
|
||||
let deserialized_svc: SVC<f64, i32, _, _> =
|
||||
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
|
||||
|
||||
println!("{:?}", serialized_svc);
|
||||
assert_eq!(svc, deserialized_svc);
|
||||
}
|
||||
}
|
||||
|
||||
+8
-9
@@ -92,8 +92,11 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
pub c: T,
|
||||
/// Tolerance for stopping criterion.
|
||||
pub tol: T,
|
||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||
/// The kernel function.
|
||||
#[cfg_attr(
|
||||
all(feature = "serde", target_arch = "wasm32"),
|
||||
serde(skip_serializing, skip_deserializing)
|
||||
)]
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
}
|
||||
|
||||
@@ -668,7 +671,7 @@ mod tests {
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||
fn svr_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
@@ -699,13 +702,9 @@ mod tests {
|
||||
|
||||
let svr = SVR::fit(&x, &y, ¶ms).unwrap();
|
||||
|
||||
let serialized = &serde_json::to_string(&svr).unwrap();
|
||||
let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
|
||||
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
||||
|
||||
println!("{}", &serialized);
|
||||
|
||||
// let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
// serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
||||
|
||||
// assert_eq!(svr, deserialized_svr);
|
||||
assert_eq!(svr, deserialized_svr);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user