Handle kernel serialization (#232)

* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
This commit is contained in:
morenol
2022-11-08 11:18:05 -05:00
committed by GitHub
parent 9eaae9ef35
commit 8efb959b3c
4 changed files with 30 additions and 50 deletions
+10 -36
View File
@@ -30,8 +30,6 @@ pub mod svr;
use core::fmt::Debug;
#[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
/// Defines a kernel function.
/// This is a object-safe trait.
pub trait Kernel {
#[cfg_attr(
all(feature = "serde", not(target_arch = "wasm32")),
typetag::serde(tag = "type")
)]
pub trait Kernel: Debug {
#[allow(clippy::ptr_arg)]
/// Apply kernel function to x_i and x_j
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
/// Return a serializable name
fn name(&self) -> &'static str;
}
impl Debug for dyn Kernel {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Kernel<f64>")
}
}
#[cfg(feature = "serde")]
impl Serialize for dyn Kernel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut s = serializer.serialize_struct("Kernel", 1)?;
s.serialize_field("type", &self.name())?;
s.end()
}
}
/// Pre-defined kernel functions
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Kernels {}
pub struct Kernels;
impl Kernels {
/// Return a default linear
@@ -211,15 +193,14 @@ impl SigmoidKernel {
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
Ok(x_i.dot(x_j))
}
fn name(&self) -> &'static str {
"Linear"
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() {
@@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
let v_diff = x_i.sub(x_j);
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
}
fn name(&self) -> &'static str {
"RBF"
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
@@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
let dot = x_i.dot(x_j);
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
}
fn name(&self) -> &'static str {
"Polynomial"
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() {
@@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
let dot = x_i.dot(x_j);
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
}
fn name(&self) -> &'static str {
"Sigmoid"
}
}
#[cfg(test)]