Use Box in SVM and remove lifetimes (#228)
* Do not change external API Authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+28
-55
@@ -29,7 +29,6 @@ pub mod svr;
|
|||||||
// pub mod search;
|
// pub mod search;
|
||||||
|
|
||||||
use core::fmt::Debug;
|
use core::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::ser::{SerializeStruct, Serializer};
|
use serde::ser::{SerializeStruct, Serializer};
|
||||||
@@ -41,22 +40,22 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
|||||||
|
|
||||||
/// Defines a kernel function.
|
/// Defines a kernel function.
|
||||||
/// This is a object-safe trait.
|
/// This is a object-safe trait.
|
||||||
pub trait Kernel<'a> {
|
pub trait Kernel {
|
||||||
#[allow(clippy::ptr_arg)]
|
#[allow(clippy::ptr_arg)]
|
||||||
/// Apply kernel function to x_i and x_j
|
/// Apply kernel function to x_i and x_j
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||||
/// Return a serializable name
|
/// Return a serializable name
|
||||||
fn name(&self) -> &'a str;
|
fn name(&self) -> &'static str;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Debug for dyn Kernel<'_> + 'a {
|
impl Debug for dyn Kernel {
|
||||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||||
write!(f, "Kernel<f64>")
|
write!(f, "Kernel<f64>")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
impl<'a> Serialize for dyn Kernel<'_> + 'a {
|
impl Serialize for dyn Kernel {
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
S: Serializer,
|
S: Serializer,
|
||||||
@@ -72,21 +71,21 @@ impl<'a> Serialize for dyn Kernel<'_> + 'a {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Kernels {}
|
pub struct Kernels {}
|
||||||
|
|
||||||
impl<'a> Kernels {
|
impl Kernels {
|
||||||
/// Return a default linear
|
/// Return a default linear
|
||||||
pub fn linear() -> LinearKernel<'a> {
|
pub fn linear() -> LinearKernel {
|
||||||
LinearKernel::default()
|
LinearKernel::default()
|
||||||
}
|
}
|
||||||
/// Return a default RBF
|
/// Return a default RBF
|
||||||
pub fn rbf() -> RBFKernel<'a> {
|
pub fn rbf() -> RBFKernel {
|
||||||
RBFKernel::default()
|
RBFKernel::default()
|
||||||
}
|
}
|
||||||
/// Return a default polynomial
|
/// Return a default polynomial
|
||||||
pub fn polynomial() -> PolynomialKernel<'a> {
|
pub fn polynomial() -> PolynomialKernel {
|
||||||
PolynomialKernel::default()
|
PolynomialKernel::default()
|
||||||
}
|
}
|
||||||
/// Return a default sigmoid
|
/// Return a default sigmoid
|
||||||
pub fn sigmoid() -> SigmoidKernel<'a> {
|
pub fn sigmoid() -> SigmoidKernel {
|
||||||
SigmoidKernel::default()
|
SigmoidKernel::default()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -94,39 +93,19 @@ impl<'a> Kernels {
|
|||||||
/// Linear Kernel
|
/// Linear Kernel
|
||||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||||
pub struct LinearKernel<'a> {
|
pub struct LinearKernel;
|
||||||
phantom: PhantomData<&'a ()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Default for LinearKernel<'a> {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
phantom: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Radial basis function (Gaussian) kernel
|
/// Radial basis function (Gaussian) kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Default, Clone, PartialEq)]
|
||||||
pub struct RBFKernel<'a> {
|
pub struct RBFKernel {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: Option<f64>,
|
pub gamma: Option<f64>,
|
||||||
phantom: PhantomData<&'a ()>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> Default for RBFKernel<'a> {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
gamma: Option::None,
|
|
||||||
phantom: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
impl<'a> RBFKernel<'a> {
|
impl RBFKernel {
|
||||||
/// assign gamma parameter to kernel (required)
|
/// assign gamma parameter to kernel (required)
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use smartcore::svm::RBFKernel;
|
/// use smartcore::svm::RBFKernel;
|
||||||
@@ -141,29 +120,26 @@ impl<'a> RBFKernel<'a> {
|
|||||||
/// Polynomial kernel
|
/// Polynomial kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct PolynomialKernel<'a> {
|
pub struct PolynomialKernel {
|
||||||
/// degree of the polynomial
|
/// degree of the polynomial
|
||||||
pub degree: Option<f64>,
|
pub degree: Option<f64>,
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: Option<f64>,
|
pub gamma: Option<f64>,
|
||||||
/// independent term in kernel function
|
/// independent term in kernel function
|
||||||
pub coef0: Option<f64>,
|
pub coef0: Option<f64>,
|
||||||
phantom: PhantomData<&'a ()>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Default for PolynomialKernel<'a> {
|
impl Default for PolynomialKernel {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gamma: Option::None,
|
gamma: Option::None,
|
||||||
degree: Option::None,
|
degree: Option::None,
|
||||||
coef0: Some(1f64),
|
coef0: Some(1f64),
|
||||||
phantom: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
impl PolynomialKernel {
|
||||||
impl<'a> PolynomialKernel<'a> {
|
|
||||||
/// set parameters for kernel
|
/// set parameters for kernel
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use smartcore::svm::PolynomialKernel;
|
/// use smartcore::svm::PolynomialKernel;
|
||||||
@@ -197,26 +173,23 @@ impl<'a> PolynomialKernel<'a> {
|
|||||||
/// Sigmoid (hyperbolic tangent) kernel
|
/// Sigmoid (hyperbolic tangent) kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct SigmoidKernel<'a> {
|
pub struct SigmoidKernel {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: Option<f64>,
|
pub gamma: Option<f64>,
|
||||||
/// independent term in kernel function
|
/// independent term in kernel function
|
||||||
pub coef0: Option<f64>,
|
pub coef0: Option<f64>,
|
||||||
phantom: PhantomData<&'a ()>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Default for SigmoidKernel<'a> {
|
impl Default for SigmoidKernel {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
gamma: Option::None,
|
gamma: Option::None,
|
||||||
coef0: Some(1f64),
|
coef0: Some(1f64),
|
||||||
phantom: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
impl SigmoidKernel {
|
||||||
impl<'a> SigmoidKernel<'a> {
|
|
||||||
/// set parameters for kernel
|
/// set parameters for kernel
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// use smartcore::svm::SigmoidKernel;
|
/// use smartcore::svm::SigmoidKernel;
|
||||||
@@ -238,16 +211,16 @@ impl<'a> SigmoidKernel<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Kernel<'a> for LinearKernel<'a> {
|
impl Kernel for LinearKernel {
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||||
Ok(x_i.dot(x_j))
|
Ok(x_i.dot(x_j))
|
||||||
}
|
}
|
||||||
fn name(&self) -> &'a str {
|
fn name(&self) -> &'static str {
|
||||||
"Linear"
|
"Linear"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Kernel<'a> for RBFKernel<'a> {
|
impl Kernel for RBFKernel {
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||||
if self.gamma.is_none() {
|
if self.gamma.is_none() {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
@@ -258,12 +231,12 @@ impl<'a> Kernel<'a> for RBFKernel<'a> {
|
|||||||
let v_diff = x_i.sub(x_j);
|
let v_diff = x_i.sub(x_j);
|
||||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
||||||
}
|
}
|
||||||
fn name(&self) -> &'a str {
|
fn name(&self) -> &'static str {
|
||||||
"RBF"
|
"RBF"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Kernel<'a> for PolynomialKernel<'a> {
|
impl Kernel for PolynomialKernel {
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
@@ -274,12 +247,12 @@ impl<'a> Kernel<'a> for PolynomialKernel<'a> {
|
|||||||
let dot = x_i.dot(x_j);
|
let dot = x_i.dot(x_j);
|
||||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
||||||
}
|
}
|
||||||
fn name(&self) -> &'a str {
|
fn name(&self) -> &'static str {
|
||||||
"Polynomial"
|
"Polynomial"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Kernel<'a> for SigmoidKernel<'a> {
|
impl Kernel for SigmoidKernel {
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||||
if self.gamma.is_none() || self.coef0.is_none() {
|
if self.gamma.is_none() || self.coef0.is_none() {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
@@ -290,7 +263,7 @@ impl<'a> Kernel<'a> for SigmoidKernel<'a> {
|
|||||||
let dot = x_i.dot(x_j);
|
let dot = x_i.dot(x_j);
|
||||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
||||||
}
|
}
|
||||||
fn name(&self) -> &'a str {
|
fn name(&self) -> &'static str {
|
||||||
"Sigmoid"
|
"Sigmoid"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+20
-26
@@ -58,7 +58,7 @@
|
|||||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||||
//!
|
//!
|
||||||
//! let knl = Kernels::linear();
|
//! let knl = Kernels::linear();
|
||||||
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(&knl);
|
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||||
//! let svc = SVC::fit(&x, &y, params).unwrap();
|
//! let svc = SVC::fit(&x, &y, params).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = svc.predict(&x).unwrap();
|
//! let y_hat = svc.predict(&x).unwrap();
|
||||||
@@ -91,15 +91,9 @@ use crate::rand_custom::get_rng_impl;
|
|||||||
use crate::svm::Kernel;
|
use crate::svm::Kernel;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
/// SVC Parameters
|
/// SVC Parameters
|
||||||
pub struct SVCParameters<
|
pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
'a,
|
|
||||||
TX: Number + RealNumber,
|
|
||||||
TY: Number + Ord,
|
|
||||||
X: Array2<TX>,
|
|
||||||
Y: Array1<TY>,
|
|
||||||
> {
|
|
||||||
/// Number of epochs.
|
/// Number of epochs.
|
||||||
pub epoch: usize,
|
pub epoch: usize,
|
||||||
/// Regularization parameter.
|
/// Regularization parameter.
|
||||||
@@ -108,7 +102,7 @@ pub struct SVCParameters<
|
|||||||
pub tol: TX,
|
pub tol: TX,
|
||||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub kernel: Option<&'a dyn Kernel<'a>>,
|
pub kernel: Option<Box<dyn Kernel>>,
|
||||||
/// Unused parameter.
|
/// Unused parameter.
|
||||||
m: PhantomData<(X, Y, TY)>,
|
m: PhantomData<(X, Y, TY)>,
|
||||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||||
@@ -129,7 +123,7 @@ pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
|
|||||||
classes: Option<Vec<TY>>,
|
classes: Option<Vec<TY>>,
|
||||||
instances: Option<Vec<Vec<TX>>>,
|
instances: Option<Vec<Vec<TX>>>,
|
||||||
#[cfg_attr(feature = "serde", serde(skip))]
|
#[cfg_attr(feature = "serde", serde(skip))]
|
||||||
parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>,
|
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
|
||||||
w: Option<Vec<TX>>,
|
w: Option<Vec<TX>>,
|
||||||
b: Option<TX>,
|
b: Option<TX>,
|
||||||
phantomdata: PhantomData<(X, Y)>,
|
phantomdata: PhantomData<(X, Y)>,
|
||||||
@@ -155,7 +149,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
|
|||||||
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
svmin: usize,
|
svmin: usize,
|
||||||
svmax: usize,
|
svmax: usize,
|
||||||
gmin: TX,
|
gmin: TX,
|
||||||
@@ -165,8 +159,8 @@ struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y
|
|||||||
recalculate_minmax_grad: bool,
|
recalculate_minmax_grad: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||||
SVCParameters<'a, TX, TY, X, Y>
|
SVCParameters<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
/// Number of epochs.
|
/// Number of epochs.
|
||||||
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
||||||
@@ -184,8 +178,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
|
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||||
self.kernel = Some(kernel);
|
self.kernel = Some(Box::new(kernel));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,8 +190,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
||||||
for SVCParameters<'a, TX, TY, X, Y>
|
for SVCParameters<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
SVCParameters {
|
SVCParameters {
|
||||||
@@ -212,7 +206,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<'a, TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -227,7 +221,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
fn fit(
|
fn fit(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Result<Self, Failed> {
|
) -> Result<Self, Failed> {
|
||||||
SVC::fit(x, y, parameters)
|
SVC::fit(x, y, parameters)
|
||||||
}
|
}
|
||||||
@@ -251,7 +245,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
|||||||
pub fn fit(
|
pub fn fit(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -447,7 +441,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
fn new(
|
fn new(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Optimizer<'a, TX, TY, X, Y> {
|
) -> Optimizer<'a, TX, TY, X, Y> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -979,7 +973,7 @@ mod tests {
|
|||||||
let knl = Kernels::linear();
|
let knl = Kernels::linear();
|
||||||
let params = SVCParameters::default()
|
let params = SVCParameters::default()
|
||||||
.with_c(200.0)
|
.with_c(200.0)
|
||||||
.with_kernel(&knl)
|
.with_kernel(knl)
|
||||||
.with_seed(Some(100));
|
.with_seed(Some(100));
|
||||||
|
|
||||||
let y_hat = SVC::fit(&x, &y, ¶ms)
|
let y_hat = SVC::fit(&x, &y, ¶ms)
|
||||||
@@ -1018,7 +1012,7 @@ mod tests {
|
|||||||
&y,
|
&y,
|
||||||
&SVCParameters::default()
|
&SVCParameters::default()
|
||||||
.with_c(200.0)
|
.with_c(200.0)
|
||||||
.with_kernel(&Kernels::linear()),
|
.with_kernel(Kernels::linear()),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.decision_function(&x2))
|
.and_then(|lr| lr.decision_function(&x2))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1073,7 +1067,7 @@ mod tests {
|
|||||||
&y,
|
&y,
|
||||||
&SVCParameters::default()
|
&SVCParameters::default()
|
||||||
.with_c(1.0)
|
.with_c(1.0)
|
||||||
.with_kernel(&Kernels::rbf().with_gamma(0.7)),
|
.with_kernel(Kernels::rbf().with_gamma(0.7)),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1122,7 +1116,7 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let knl = Kernels::linear();
|
let knl = Kernels::linear();
|
||||||
let params = SVCParameters::default().with_kernel(&knl);
|
let params = SVCParameters::default().with_kernel(knl);
|
||||||
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
||||||
|
|
||||||
// serialization
|
// serialization
|
||||||
|
|||||||
+16
-16
@@ -50,7 +50,7 @@
|
|||||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||||
//!
|
//!
|
||||||
//! let knl = Kernels::linear();
|
//! let knl = Kernels::linear();
|
||||||
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(&knl);
|
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
||||||
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
||||||
//!
|
//!
|
||||||
//! // let y_hat = svr.predict(&x).unwrap();
|
//! // let y_hat = svr.predict(&x).unwrap();
|
||||||
@@ -83,9 +83,9 @@ use crate::numbers::floatnum::FloatNumber;
|
|||||||
use crate::svm::Kernel;
|
use crate::svm::Kernel;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
/// SVR Parameters
|
/// SVR Parameters
|
||||||
pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
|
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||||
/// Epsilon in the epsilon-SVR model.
|
/// Epsilon in the epsilon-SVR model.
|
||||||
pub eps: T,
|
pub eps: T,
|
||||||
/// Regularization parameter.
|
/// Regularization parameter.
|
||||||
@@ -94,7 +94,7 @@ pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
|
|||||||
pub tol: T,
|
pub tol: T,
|
||||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub kernel: Option<&'a dyn Kernel<'a>>,
|
pub kernel: Option<Box<dyn Kernel>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -103,7 +103,7 @@ pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
|
|||||||
pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> {
|
pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> {
|
||||||
instances: Option<Vec<Vec<f64>>>,
|
instances: Option<Vec<Vec<f64>>>,
|
||||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||||
parameters: Option<&'a SVRParameters<'a, T>>,
|
parameters: Option<&'a SVRParameters<T>>,
|
||||||
w: Option<Vec<T>>,
|
w: Option<Vec<T>>,
|
||||||
b: T,
|
b: T,
|
||||||
phantom: PhantomData<(X, Y)>,
|
phantom: PhantomData<(X, Y)>,
|
||||||
@@ -123,7 +123,7 @@ struct SupportVector<T> {
|
|||||||
struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> {
|
struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> {
|
||||||
tol: T,
|
tol: T,
|
||||||
c: T,
|
c: T,
|
||||||
parameters: Option<&'a SVRParameters<'a, T>>,
|
parameters: Option<&'a SVRParameters<T>>,
|
||||||
svmin: usize,
|
svmin: usize,
|
||||||
svmax: usize,
|
svmax: usize,
|
||||||
gmin: T,
|
gmin: T,
|
||||||
@@ -140,7 +140,7 @@ struct Cache<T: Clone> {
|
|||||||
data: Vec<RefCell<Option<Vec<T>>>>,
|
data: Vec<RefCell<Option<Vec<T>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> {
|
impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
|
||||||
/// Epsilon in the epsilon-SVR model.
|
/// Epsilon in the epsilon-SVR model.
|
||||||
pub fn with_eps(mut self, eps: T) -> Self {
|
pub fn with_eps(mut self, eps: T) -> Self {
|
||||||
self.eps = eps;
|
self.eps = eps;
|
||||||
@@ -157,13 +157,13 @@ impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
|
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||||
self.kernel = Some(kernel);
|
self.kernel = Some(Box::new(kernel));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T> {
|
impl<T: Number + FloatNumber + PartialOrd> Default for SVRParameters<T> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
SVRParameters {
|
SVRParameters {
|
||||||
eps: T::from_f64(0.1).unwrap(),
|
eps: T::from_f64(0.1).unwrap(),
|
||||||
@@ -175,7 +175,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
|
impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
|
||||||
SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<'a, T>> for SVR<'a, T, X, Y>
|
SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<T>> for SVR<'a, T, X, Y>
|
||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -186,7 +186,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
|
|||||||
phantom: PhantomData,
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<'a, T>) -> Result<Self, Failed> {
|
fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<T>) -> Result<Self, Failed> {
|
||||||
SVR::fit(x, y, parameters)
|
SVR::fit(x, y, parameters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,7 +208,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
|
|||||||
pub fn fit(
|
pub fn fit(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVRParameters<'a, T>,
|
parameters: &'a SVRParameters<T>,
|
||||||
) -> Result<SVR<'a, T, X, Y>, Failed> {
|
) -> Result<SVR<'a, T, X, Y>, Failed> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -324,7 +324,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Optimizer<'a, T> {
|
|||||||
fn new<X: Array2<T>, Y: Array1<T>>(
|
fn new<X: Array2<T>, Y: Array1<T>>(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVRParameters<'a, T>,
|
parameters: &'a SVRParameters<T>,
|
||||||
) -> Optimizer<'a, T> {
|
) -> Optimizer<'a, T> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -655,7 +655,7 @@ mod tests {
|
|||||||
&SVRParameters::default()
|
&SVRParameters::default()
|
||||||
.with_eps(2.0)
|
.with_eps(2.0)
|
||||||
.with_c(10.0)
|
.with_c(10.0)
|
||||||
.with_kernel(&knl),
|
.with_kernel(knl),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -697,7 +697,7 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let knl = Kernels::rbf().with_gamma(0.7);
|
let knl = Kernels::rbf().with_gamma(0.7);
|
||||||
let params = SVRParameters::default().with_kernel(&knl);
|
let params = SVRParameters::default().with_kernel(knl);
|
||||||
|
|
||||||
let svr = SVR::fit(&x, &y, ¶ms).unwrap();
|
let svr = SVR::fit(&x, &y, ¶ms).unwrap();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user