Use Box in SVM and remove lifetimes (#228)

* Do not change external API
Authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-11-04 17:08:30 -05:00
committed by GitHub
parent 35fe68e024
commit 425c3c1d0b
3 changed files with 64 additions and 97 deletions
+28 -55
View File
@@ -29,7 +29,6 @@ pub mod svr;
// pub mod search; // pub mod search;
use core::fmt::Debug; use core::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer}; use serde::ser::{SerializeStruct, Serializer};
@@ -41,22 +40,22 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
/// Defines a kernel function. /// Defines a kernel function.
/// This is an object-safe trait. /// This is an object-safe trait.
pub trait Kernel<'a> { pub trait Kernel {
#[allow(clippy::ptr_arg)] #[allow(clippy::ptr_arg)]
/// Apply kernel function to x_i and x_j /// Apply kernel function to x_i and x_j
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>; fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
/// Return a serializable name /// Return a serializable name
fn name(&self) -> &'a str; fn name(&self) -> &'static str;
} }
impl<'a> Debug for dyn Kernel<'_> + 'a { impl Debug for dyn Kernel {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "Kernel<f64>") write!(f, "Kernel<f64>")
} }
} }
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
impl<'a> Serialize for dyn Kernel<'_> + 'a { impl Serialize for dyn Kernel {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
S: Serializer, S: Serializer,
@@ -72,21 +71,21 @@ impl<'a> Serialize for dyn Kernel<'_> + 'a {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Kernels {} pub struct Kernels {}
impl<'a> Kernels { impl Kernels {
/// Return a default linear /// Return a default linear
pub fn linear() -> LinearKernel<'a> { pub fn linear() -> LinearKernel {
LinearKernel::default() LinearKernel::default()
} }
/// Return a default RBF /// Return a default RBF
pub fn rbf() -> RBFKernel<'a> { pub fn rbf() -> RBFKernel {
RBFKernel::default() RBFKernel::default()
} }
/// Return a default polynomial /// Return a default polynomial
pub fn polynomial() -> PolynomialKernel<'a> { pub fn polynomial() -> PolynomialKernel {
PolynomialKernel::default() PolynomialKernel::default()
} }
/// Return a default sigmoid /// Return a default sigmoid
pub fn sigmoid() -> SigmoidKernel<'a> { pub fn sigmoid() -> SigmoidKernel {
SigmoidKernel::default() SigmoidKernel::default()
} }
} }
@@ -94,39 +93,19 @@ impl<'a> Kernels {
/// Linear Kernel /// Linear Kernel
#[allow(clippy::derive_partial_eq_without_eq)] #[allow(clippy::derive_partial_eq_without_eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct LinearKernel<'a> { pub struct LinearKernel;
phantom: PhantomData<&'a ()>,
}
impl<'a> Default for LinearKernel<'a> {
fn default() -> Self {
Self {
phantom: PhantomData,
}
}
}
/// Radial basis function (Gaussian) kernel /// Radial basis function (Gaussian) kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Default, Clone, PartialEq)]
pub struct RBFKernel<'a> { pub struct RBFKernel {
/// kernel coefficient /// kernel coefficient
pub gamma: Option<f64>, pub gamma: Option<f64>,
phantom: PhantomData<&'a ()>,
}
impl<'a> Default for RBFKernel<'a> {
fn default() -> Self {
Self {
gamma: Option::None,
phantom: PhantomData,
}
}
} }
#[allow(dead_code)] #[allow(dead_code)]
impl<'a> RBFKernel<'a> { impl RBFKernel {
/// assign gamma parameter to kernel (required) /// assign gamma parameter to kernel (required)
/// ```rust /// ```rust
/// use smartcore::svm::RBFKernel; /// use smartcore::svm::RBFKernel;
@@ -141,29 +120,26 @@ impl<'a> RBFKernel<'a> {
/// Polynomial kernel /// Polynomial kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct PolynomialKernel<'a> { pub struct PolynomialKernel {
/// degree of the polynomial /// degree of the polynomial
pub degree: Option<f64>, pub degree: Option<f64>,
/// kernel coefficient /// kernel coefficient
pub gamma: Option<f64>, pub gamma: Option<f64>,
/// independent term in kernel function /// independent term in kernel function
pub coef0: Option<f64>, pub coef0: Option<f64>,
phantom: PhantomData<&'a ()>,
} }
impl<'a> Default for PolynomialKernel<'a> { impl Default for PolynomialKernel {
fn default() -> Self { fn default() -> Self {
Self { Self {
gamma: Option::None, gamma: Option::None,
degree: Option::None, degree: Option::None,
coef0: Some(1f64), coef0: Some(1f64),
phantom: PhantomData,
} }
} }
} }
#[allow(dead_code)] impl PolynomialKernel {
impl<'a> PolynomialKernel<'a> {
/// set parameters for kernel /// set parameters for kernel
/// ```rust /// ```rust
/// use smartcore::svm::PolynomialKernel; /// use smartcore::svm::PolynomialKernel;
@@ -197,26 +173,23 @@ impl<'a> PolynomialKernel<'a> {
/// Sigmoid (hyperbolic tangent) kernel /// Sigmoid (hyperbolic tangent) kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct SigmoidKernel<'a> { pub struct SigmoidKernel {
/// kernel coefficient /// kernel coefficient
pub gamma: Option<f64>, pub gamma: Option<f64>,
/// independent term in kernel function /// independent term in kernel function
pub coef0: Option<f64>, pub coef0: Option<f64>,
phantom: PhantomData<&'a ()>,
} }
impl<'a> Default for SigmoidKernel<'a> { impl Default for SigmoidKernel {
fn default() -> Self { fn default() -> Self {
Self { Self {
gamma: Option::None, gamma: Option::None,
coef0: Some(1f64), coef0: Some(1f64),
phantom: PhantomData,
} }
} }
} }
#[allow(dead_code)] impl SigmoidKernel {
impl<'a> SigmoidKernel<'a> {
/// set parameters for kernel /// set parameters for kernel
/// ```rust /// ```rust
/// use smartcore::svm::SigmoidKernel; /// use smartcore::svm::SigmoidKernel;
@@ -238,16 +211,16 @@ impl<'a> SigmoidKernel<'a> {
} }
} }
impl<'a> Kernel<'a> for LinearKernel<'a> { impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
Ok(x_i.dot(x_j)) Ok(x_i.dot(x_j))
} }
fn name(&self) -> &'a str { fn name(&self) -> &'static str {
"Linear" "Linear"
} }
} }
impl<'a> Kernel<'a> for RBFKernel<'a> { impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() { if self.gamma.is_none() {
return Err(Failed::because( return Err(Failed::because(
@@ -258,12 +231,12 @@ impl<'a> Kernel<'a> for RBFKernel<'a> {
let v_diff = x_i.sub(x_j); let v_diff = x_i.sub(x_j);
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp()) Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
} }
fn name(&self) -> &'a str { fn name(&self) -> &'static str {
"RBF" "RBF"
} }
} }
impl<'a> Kernel<'a> for PolynomialKernel<'a> { impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() { if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
return Err(Failed::because( return Err(Failed::because(
@@ -274,12 +247,12 @@ impl<'a> Kernel<'a> for PolynomialKernel<'a> {
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap())) Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
} }
fn name(&self) -> &'a str { fn name(&self) -> &'static str {
"Polynomial" "Polynomial"
} }
} }
impl<'a> Kernel<'a> for SigmoidKernel<'a> { impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() { if self.gamma.is_none() || self.coef0.is_none() {
return Err(Failed::because( return Err(Failed::because(
@@ -290,7 +263,7 @@ impl<'a> Kernel<'a> for SigmoidKernel<'a> {
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh()) Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
} }
fn name(&self) -> &'a str { fn name(&self) -> &'static str {
"Sigmoid" "Sigmoid"
} }
} }
+20 -26
View File
@@ -58,7 +58,7 @@
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
//! //!
//! let knl = Kernels::linear(); //! let knl = Kernels::linear();
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(&knl); //! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
//! let svc = SVC::fit(&x, &y, params).unwrap(); //! let svc = SVC::fit(&x, &y, params).unwrap();
//! //!
//! let y_hat = svc.predict(&x).unwrap(); //! let y_hat = svc.predict(&x).unwrap();
@@ -91,15 +91,9 @@ use crate::rand_custom::get_rng_impl;
use crate::svm::Kernel; use crate::svm::Kernel;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug)]
/// SVC Parameters /// SVC Parameters
pub struct SVCParameters< pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
'a,
TX: Number + RealNumber,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
/// Number of epochs. /// Number of epochs.
pub epoch: usize, pub epoch: usize,
/// Regularization parameter. /// Regularization parameter.
@@ -108,7 +102,7 @@ pub struct SVCParameters<
pub tol: TX, pub tol: TX,
#[cfg_attr(feature = "serde", serde(skip_deserializing))] #[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function. /// The kernel function.
pub kernel: Option<&'a dyn Kernel<'a>>, pub kernel: Option<Box<dyn Kernel>>,
/// Unused parameter. /// Unused parameter.
m: PhantomData<(X, Y, TY)>, m: PhantomData<(X, Y, TY)>,
/// Controls the pseudo random number generation for shuffling the data for probability estimates /// Controls the pseudo random number generation for shuffling the data for probability estimates
@@ -129,7 +123,7 @@ pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
classes: Option<Vec<TY>>, classes: Option<Vec<TY>>,
instances: Option<Vec<Vec<TX>>>, instances: Option<Vec<Vec<TX>>>,
#[cfg_attr(feature = "serde", serde(skip))] #[cfg_attr(feature = "serde", serde(skip))]
parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>, parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
w: Option<Vec<TX>>, w: Option<Vec<TX>>,
b: Option<TX>, b: Option<TX>,
phantomdata: PhantomData<(X, Y)>, phantomdata: PhantomData<(X, Y)>,
@@ -155,7 +149,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> { struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
svmin: usize, svmin: usize,
svmax: usize, svmax: usize,
gmin: TX, gmin: TX,
@@ -165,8 +159,8 @@ struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y
recalculate_minmax_grad: bool, recalculate_minmax_grad: bool,
} }
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SVCParameters<'a, TX, TY, X, Y> SVCParameters<TX, TY, X, Y>
{ {
/// Number of epochs. /// Number of epochs.
pub fn with_epoch(mut self, epoch: usize) -> Self { pub fn with_epoch(mut self, epoch: usize) -> Self {
@@ -184,8 +178,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
self self
} }
/// The kernel function. /// The kernel function.
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self { pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(kernel); self.kernel = Some(Box::new(kernel));
self self
} }
@@ -196,8 +190,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
} }
} }
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
for SVCParameters<'a, TX, TY, X, Y> for SVCParameters<TX, TY, X, Y>
{ {
fn default() -> Self { fn default() -> Self {
SVCParameters { SVCParameters {
@@ -212,7 +206,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
} }
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<'a, TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y> SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
@@ -227,7 +221,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn fit( fn fit(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<Self, Failed> { ) -> Result<Self, Failed> {
SVC::fit(x, y, parameters) SVC::fit(x, y, parameters)
} }
@@ -251,7 +245,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
pub fn fit( pub fn fit(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> { ) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -447,7 +441,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn new( fn new(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Optimizer<'a, TX, TY, X, Y> { ) -> Optimizer<'a, TX, TY, X, Y> {
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -979,7 +973,7 @@ mod tests {
let knl = Kernels::linear(); let knl = Kernels::linear();
let params = SVCParameters::default() let params = SVCParameters::default()
.with_c(200.0) .with_c(200.0)
.with_kernel(&knl) .with_kernel(knl)
.with_seed(Some(100)); .with_seed(Some(100));
let y_hat = SVC::fit(&x, &y, &params) let y_hat = SVC::fit(&x, &y, &params)
@@ -1018,7 +1012,7 @@ mod tests {
&y, &y,
&SVCParameters::default() &SVCParameters::default()
.with_c(200.0) .with_c(200.0)
.with_kernel(&Kernels::linear()), .with_kernel(Kernels::linear()),
) )
.and_then(|lr| lr.decision_function(&x2)) .and_then(|lr| lr.decision_function(&x2))
.unwrap(); .unwrap();
@@ -1073,7 +1067,7 @@ mod tests {
&y, &y,
&SVCParameters::default() &SVCParameters::default()
.with_c(1.0) .with_c(1.0)
.with_kernel(&Kernels::rbf().with_gamma(0.7)), .with_kernel(Kernels::rbf().with_gamma(0.7)),
) )
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
.unwrap(); .unwrap();
@@ -1122,7 +1116,7 @@ mod tests {
]; ];
let knl = Kernels::linear(); let knl = Kernels::linear();
let params = SVCParameters::default().with_kernel(&knl); let params = SVCParameters::default().with_kernel(knl);
let svc = SVC::fit(&x, &y, &params).unwrap(); let svc = SVC::fit(&x, &y, &params).unwrap();
// serialization // serialization
+16 -16
View File
@@ -50,7 +50,7 @@
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; //! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
//! //!
//! let knl = Kernels::linear(); //! let knl = Kernels::linear();
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(&knl); //! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
//! // let svr = SVR::fit(&x, &y, params).unwrap(); //! // let svr = SVR::fit(&x, &y, params).unwrap();
//! //!
//! // let y_hat = svr.predict(&x).unwrap(); //! // let y_hat = svr.predict(&x).unwrap();
@@ -83,9 +83,9 @@ use crate::numbers::floatnum::FloatNumber;
use crate::svm::Kernel; use crate::svm::Kernel;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug)]
/// SVR Parameters /// SVR Parameters
pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> { pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
/// Epsilon in the epsilon-SVR model. /// Epsilon in the epsilon-SVR model.
pub eps: T, pub eps: T,
/// Regularization parameter. /// Regularization parameter.
@@ -94,7 +94,7 @@ pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
pub tol: T, pub tol: T,
#[cfg_attr(feature = "serde", serde(skip_deserializing))] #[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function. /// The kernel function.
pub kernel: Option<&'a dyn Kernel<'a>>, pub kernel: Option<Box<dyn Kernel>>,
} }
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -103,7 +103,7 @@ pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> { pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> {
instances: Option<Vec<Vec<f64>>>, instances: Option<Vec<Vec<f64>>>,
#[cfg_attr(feature = "serde", serde(skip_deserializing))] #[cfg_attr(feature = "serde", serde(skip_deserializing))]
parameters: Option<&'a SVRParameters<'a, T>>, parameters: Option<&'a SVRParameters<T>>,
w: Option<Vec<T>>, w: Option<Vec<T>>,
b: T, b: T,
phantom: PhantomData<(X, Y)>, phantom: PhantomData<(X, Y)>,
@@ -123,7 +123,7 @@ struct SupportVector<T> {
struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> { struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> {
tol: T, tol: T,
c: T, c: T,
parameters: Option<&'a SVRParameters<'a, T>>, parameters: Option<&'a SVRParameters<T>>,
svmin: usize, svmin: usize,
svmax: usize, svmax: usize,
gmin: T, gmin: T,
@@ -140,7 +140,7 @@ struct Cache<T: Clone> {
data: Vec<RefCell<Option<Vec<T>>>>, data: Vec<RefCell<Option<Vec<T>>>>,
} }
impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> { impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
/// Epsilon in the epsilon-SVR model. /// Epsilon in the epsilon-SVR model.
pub fn with_eps(mut self, eps: T) -> Self { pub fn with_eps(mut self, eps: T) -> Self {
self.eps = eps; self.eps = eps;
@@ -157,13 +157,13 @@ impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> {
self self
} }
/// The kernel function. /// The kernel function.
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self { pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(kernel); self.kernel = Some(Box::new(kernel));
self self
} }
} }
impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T> { impl<T: Number + FloatNumber + PartialOrd> Default for SVRParameters<T> {
fn default() -> Self { fn default() -> Self {
SVRParameters { SVRParameters {
eps: T::from_f64(0.1).unwrap(), eps: T::from_f64(0.1).unwrap(),
@@ -175,7 +175,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T>
} }
impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<'a, T>> for SVR<'a, T, X, Y> SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<T>> for SVR<'a, T, X, Y>
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
@@ -186,7 +186,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
phantom: PhantomData, phantom: PhantomData,
} }
} }
fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<'a, T>) -> Result<Self, Failed> { fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<T>) -> Result<Self, Failed> {
SVR::fit(x, y, parameters) SVR::fit(x, y, parameters)
} }
} }
@@ -208,7 +208,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
pub fn fit( pub fn fit(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVRParameters<'a, T>, parameters: &'a SVRParameters<T>,
) -> Result<SVR<'a, T, X, Y>, Failed> { ) -> Result<SVR<'a, T, X, Y>, Failed> {
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -324,7 +324,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Optimizer<'a, T> {
fn new<X: Array2<T>, Y: Array1<T>>( fn new<X: Array2<T>, Y: Array1<T>>(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVRParameters<'a, T>, parameters: &'a SVRParameters<T>,
) -> Optimizer<'a, T> { ) -> Optimizer<'a, T> {
let (n, _) = x.shape(); let (n, _) = x.shape();
@@ -655,7 +655,7 @@ mod tests {
&SVRParameters::default() &SVRParameters::default()
.with_eps(2.0) .with_eps(2.0)
.with_c(10.0) .with_c(10.0)
.with_kernel(&knl), .with_kernel(knl),
) )
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
.unwrap(); .unwrap();
@@ -697,7 +697,7 @@ mod tests {
]; ];
let knl = Kernels::rbf().with_gamma(0.7); let knl = Kernels::rbf().with_gamma(0.7);
let params = SVRParameters::default().with_kernel(&knl); let params = SVRParameters::default().with_kernel(knl);
let svr = SVR::fit(&x, &y, &params).unwrap(); let svr = SVR::fit(&x, &y, &params).unwrap();