Use Box in SVM and remove lifetimes (#228)
* Do not change external API Authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+20
-26
@@ -58,7 +58,7 @@
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(&knl);
|
||||
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||
//! let svc = SVC::fit(&x, &y, params).unwrap();
|
||||
//!
|
||||
//! let y_hat = svc.predict(&x).unwrap();
|
||||
@@ -91,15 +91,9 @@ use crate::rand_custom::get_rng_impl;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
/// SVC Parameters
|
||||
pub struct SVCParameters<
|
||||
'a,
|
||||
TX: Number + RealNumber,
|
||||
TY: Number + Ord,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
/// Number of epochs.
|
||||
pub epoch: usize,
|
||||
/// Regularization parameter.
|
||||
@@ -108,7 +102,7 @@ pub struct SVCParameters<
|
||||
pub tol: TX,
|
||||
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
|
||||
/// The kernel function.
|
||||
pub kernel: Option<&'a dyn Kernel<'a>>,
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<(X, Y, TY)>,
|
||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||
@@ -129,7 +123,7 @@ pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
|
||||
classes: Option<Vec<TY>>,
|
||||
instances: Option<Vec<Vec<TX>>>,
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>,
|
||||
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
|
||||
w: Option<Vec<TX>>,
|
||||
b: Option<TX>,
|
||||
phantomdata: PhantomData<(X, Y)>,
|
||||
@@ -155,7 +149,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
|
||||
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
svmin: usize,
|
||||
svmax: usize,
|
||||
gmin: TX,
|
||||
@@ -165,8 +159,8 @@ struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y
|
||||
recalculate_minmax_grad: bool,
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SVCParameters<'a, TX, TY, X, Y>
|
||||
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SVCParameters<TX, TY, X, Y>
|
||||
{
|
||||
/// Number of epochs.
|
||||
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
||||
@@ -184,8 +178,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
|
||||
self.kernel = Some(kernel);
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -196,8 +190,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
||||
for SVCParameters<'a, TX, TY, X, Y>
|
||||
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
||||
for SVCParameters<TX, TY, X, Y>
|
||||
{
|
||||
fn default() -> Self {
|
||||
SVCParameters {
|
||||
@@ -212,7 +206,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<'a, TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
@@ -227,7 +221,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<Self, Failed> {
|
||||
SVC::fit(x, y, parameters)
|
||||
}
|
||||
@@ -251,7 +245,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
||||
pub fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
@@ -447,7 +441,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
fn new(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Optimizer<'a, TX, TY, X, Y> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
@@ -979,7 +973,7 @@ mod tests {
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(&knl)
|
||||
.with_kernel(knl)
|
||||
.with_seed(Some(100));
|
||||
|
||||
let y_hat = SVC::fit(&x, &y, ¶ms)
|
||||
@@ -1018,7 +1012,7 @@ mod tests {
|
||||
&y,
|
||||
&SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(&Kernels::linear()),
|
||||
.with_kernel(Kernels::linear()),
|
||||
)
|
||||
.and_then(|lr| lr.decision_function(&x2))
|
||||
.unwrap();
|
||||
@@ -1073,7 +1067,7 @@ mod tests {
|
||||
&y,
|
||||
&SVCParameters::default()
|
||||
.with_c(1.0)
|
||||
.with_kernel(&Kernels::rbf().with_gamma(0.7)),
|
||||
.with_kernel(Kernels::rbf().with_gamma(0.7)),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
@@ -1122,7 +1116,7 @@ mod tests {
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default().with_kernel(&knl);
|
||||
let params = SVCParameters::default().with_kernel(knl);
|
||||
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
||||
|
||||
// serialization
|
||||
|
||||
Reference in New Issue
Block a user