Use Box in SVM and remove lifetimes (#228)

* Do not change external API
Authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-11-04 17:08:30 -05:00
committed by GitHub
parent 35fe68e024
commit 425c3c1d0b
3 changed files with 64 additions and 97 deletions
+20 -26
View File
@@ -58,7 +58,7 @@
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
//!
//! let knl = Kernels::linear();
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(&knl);
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
//! let svc = SVC::fit(&x, &y, params).unwrap();
//!
//! let y_hat = svc.predict(&x).unwrap();
@@ -91,15 +91,9 @@ use crate::rand_custom::get_rng_impl;
use crate::svm::Kernel;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug)]
/// SVC Parameters
pub struct SVCParameters<
'a,
TX: Number + RealNumber,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
/// Number of epochs.
pub epoch: usize,
/// Regularization parameter.
@@ -108,7 +102,7 @@ pub struct SVCParameters<
pub tol: TX,
#[cfg_attr(feature = "serde", serde(skip_deserializing))]
/// The kernel function.
pub kernel: Option<&'a dyn Kernel<'a>>,
pub kernel: Option<Box<dyn Kernel>>,
/// Unused parameter.
m: PhantomData<(X, Y, TY)>,
/// Controls the pseudo random number generation for shuffling the data for probability estimates
@@ -129,7 +123,7 @@ pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
classes: Option<Vec<TY>>,
instances: Option<Vec<Vec<TX>>>,
#[cfg_attr(feature = "serde", serde(skip))]
parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>,
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
w: Option<Vec<TX>>,
b: Option<TX>,
phantomdata: PhantomData<(X, Y)>,
@@ -155,7 +149,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
parameters: &'a SVCParameters<TX, TY, X, Y>,
svmin: usize,
svmax: usize,
gmin: TX,
@@ -165,8 +159,8 @@ struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y
recalculate_minmax_grad: bool,
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SVCParameters<'a, TX, TY, X, Y>
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SVCParameters<TX, TY, X, Y>
{
/// Number of epochs.
pub fn with_epoch(mut self, epoch: usize) -> Self {
@@ -184,8 +178,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
self
}
/// The kernel function.
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
self.kernel = Some(kernel);
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(Box::new(kernel));
self
}
@@ -196,8 +190,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
}
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
for SVCParameters<'a, TX, TY, X, Y>
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
for SVCParameters<TX, TY, X, Y>
{
fn default() -> Self {
SVCParameters {
@@ -212,7 +206,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<'a, TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
{
fn new() -> Self {
Self {
@@ -227,7 +221,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<Self, Failed> {
SVC::fit(x, y, parameters)
}
@@ -251,7 +245,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
pub fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let (n, _) = x.shape();
@@ -447,7 +441,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn new(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Optimizer<'a, TX, TY, X, Y> {
let (n, _) = x.shape();
@@ -979,7 +973,7 @@ mod tests {
let knl = Kernels::linear();
let params = SVCParameters::default()
.with_c(200.0)
.with_kernel(&knl)
.with_kernel(knl)
.with_seed(Some(100));
let y_hat = SVC::fit(&x, &y, &params)
@@ -1018,7 +1012,7 @@ mod tests {
&y,
&SVCParameters::default()
.with_c(200.0)
.with_kernel(&Kernels::linear()),
.with_kernel(Kernels::linear()),
)
.and_then(|lr| lr.decision_function(&x2))
.unwrap();
@@ -1073,7 +1067,7 @@ mod tests {
&y,
&SVCParameters::default()
.with_c(1.0)
.with_kernel(&Kernels::rbf().with_gamma(0.7)),
.with_kernel(Kernels::rbf().with_gamma(0.7)),
)
.and_then(|lr| lr.predict(&x))
.unwrap();
@@ -1122,7 +1116,7 @@ mod tests {
];
let knl = Kernels::linear();
let params = SVCParameters::default().with_kernel(&knl);
let params = SVCParameters::default().with_kernel(knl);
let svc = SVC::fit(&x, &y, &params).unwrap();
// serialization