Handle kernel serialization (#232)

* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
This commit is contained in:
morenol
2022-11-08 11:18:05 -05:00
committed by GitHub
parent 9eaae9ef35
commit 8efb959b3c
4 changed files with 30 additions and 50 deletions
+4 -1
View File
@@ -29,9 +29,12 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true } rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true } serde = { version = "1", features = ["derive"], optional = true }
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }
[features] [features]
default = [] default = []
serde = ["dep:serde"] serde = ["dep:serde", "dep:typetag"]
ndarray-bindings = ["dep:ndarray"] ndarray-bindings = ["dep:ndarray"]
datasets = ["dep:rand_distr", "std_rand", "serde"] datasets = ["dep:rand_distr", "std_rand", "serde"]
std_rand = ["rand/std_rng", "rand/std"] std_rand = ["rand/std_rng", "rand/std"]
+10 -36
View File
@@ -30,8 +30,6 @@ pub mod svr;
use core::fmt::Debug; use core::fmt::Debug;
#[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer};
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
/// Defines a kernel function. /// Defines a kernel function.
/// This is a object-safe trait. /// This is a object-safe trait.
pub trait Kernel { #[cfg_attr(
all(feature = "serde", not(target_arch = "wasm32")),
typetag::serde(tag = "type")
)]
pub trait Kernel: Debug {
#[allow(clippy::ptr_arg)] #[allow(clippy::ptr_arg)]
/// Apply kernel function to x_i and x_j /// Apply kernel function to x_i and x_j
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>; fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
/// Return a serializable name
fn name(&self) -> &'static str;
}
impl Debug for dyn Kernel {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Kernel<f64>")
}
}
#[cfg(feature = "serde")]
impl Serialize for dyn Kernel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("Kernel", 1)?;
s.serialize_field("type", &self.name())?;
s.end()
}
} }
/// Pre-defined kernel functions /// Pre-defined kernel functions
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Kernels {} pub struct Kernels;
impl Kernels { impl Kernels {
/// Return a default linear /// Return a default linear
@@ -211,15 +193,14 @@ impl SigmoidKernel {
} }
} }
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for LinearKernel { impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
Ok(x_i.dot(x_j)) Ok(x_i.dot(x_j))
} }
fn name(&self) -> &'static str {
"Linear"
}
} }
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for RBFKernel { impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() { if self.gamma.is_none() {
@@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
let v_diff = x_i.sub(x_j); let v_diff = x_i.sub(x_j);
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp()) Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
} }
fn name(&self) -> &'static str {
"RBF"
}
} }
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for PolynomialKernel { impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() { if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
@@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap())) Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
} }
fn name(&self) -> &'static str {
"Polynomial"
}
} }
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for SigmoidKernel { impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() { if self.gamma.is_none() || self.coef0.is_none() {
@@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh()) Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
} }
fn name(&self) -> &'static str {
"Sigmoid"
}
} }
#[cfg(test)] #[cfg(test)]
+8 -4
View File
@@ -100,8 +100,11 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
pub c: TX, pub c: TX,
/// Tolerance for stopping criterion. /// Tolerance for stopping criterion.
pub tol: TX, pub tol: TX,
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function. /// The kernel function.
#[cfg_attr(
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>, pub kernel: Option<Box<dyn Kernel>>,
/// Unused parameter. /// Unused parameter.
m: PhantomData<(X, Y, TY)>, m: PhantomData<(X, Y, TY)>,
@@ -1085,7 +1088,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
)] )]
#[test] #[test]
#[cfg(feature = "serde")] #[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
fn svc_serde() { fn svc_serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
@@ -1119,8 +1122,9 @@ mod tests {
let svc = SVC::fit(&x, &y, &params).unwrap(); let svc = SVC::fit(&x, &y, &params).unwrap();
// serialization // serialization
let serialized_svc = &serde_json::to_string(&svc).unwrap(); let deserialized_svc: SVC<f64, i32, _, _> =
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
println!("{:?}", serialized_svc); assert_eq!(svc, deserialized_svc);
} }
} }
+8 -9
View File
@@ -92,8 +92,11 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
pub c: T, pub c: T,
/// Tolerance for stopping criterion. /// Tolerance for stopping criterion.
pub tol: T, pub tol: T,
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function. /// The kernel function.
#[cfg_attr(
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>, pub kernel: Option<Box<dyn Kernel>>,
} }
@@ -668,7 +671,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
)] )]
#[test] #[test]
#[cfg(feature = "serde")] #[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
fn svr_serde() { fn svr_serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
@@ -699,13 +702,9 @@ mod tests {
let svr = SVR::fit(&x, &y, &params).unwrap(); let svr = SVR::fit(&x, &y, &params).unwrap();
let serialized = &serde_json::to_string(&svr).unwrap(); let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
println!("{}", &serialized); assert_eq!(svr, deserialized_svr);
// let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
// serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
// assert_eq!(svr, deserialized_svr);
} }
} }