Handle kernel serialization (#232)
* Handle kernel serialization * Do not use typetag in WASM * enable tests for serialization * Update serde feature deps Co-authored-by: Luis Moreno <morenol@users.noreply.github.com> Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
This commit is contained in:
+10
-36
@@ -30,8 +30,6 @@ pub mod svr;
|
||||
|
||||
use core::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::ser::{SerializeStruct, Serializer};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -40,36 +38,20 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
|
||||
/// Defines a kernel function.
|
||||
/// This is a object-safe trait.
|
||||
pub trait Kernel {
|
||||
#[cfg_attr(
|
||||
all(feature = "serde", not(target_arch = "wasm32")),
|
||||
typetag::serde(tag = "type")
|
||||
)]
|
||||
pub trait Kernel: Debug {
|
||||
#[allow(clippy::ptr_arg)]
|
||||
/// Apply kernel function to x_i and x_j
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||
/// Return a serializable name
|
||||
fn name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
impl Debug for dyn Kernel {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
write!(f, "Kernel<f64>")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl Serialize for dyn Kernel {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut s = serializer.serialize_struct("Kernel", 1)?;
|
||||
s.serialize_field("type", &self.name())?;
|
||||
s.end()
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-defined kernel functions
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Kernels {}
|
||||
pub struct Kernels;
|
||||
|
||||
impl Kernels {
|
||||
/// Return a default linear
|
||||
@@ -211,15 +193,14 @@ impl SigmoidKernel {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for LinearKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
Ok(x_i.dot(x_j))
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Linear"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for RBFKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() {
|
||||
@@ -231,11 +212,9 @@ impl Kernel for RBFKernel {
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"RBF"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for PolynomialKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
||||
@@ -247,11 +226,9 @@ impl Kernel for PolynomialKernel {
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Polynomial"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for SigmoidKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() {
|
||||
@@ -263,9 +240,6 @@ impl Kernel for SigmoidKernel {
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
||||
}
|
||||
fn name(&self) -> &'static str {
|
||||
"Sigmoid"
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
Reference in New Issue
Block a user