From 19644245894be3fe4c849f51412d2c0992be7f6f Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Thu, 3 Nov 2022 11:48:40 +0000 Subject: [PATCH] Fix svr tests (#222) --- src/svm/mod.rs | 4 +- src/svm/search/mod.rs | 2 +- src/svm/search/svc_params.rs | 1 - src/svm/search/svr_params.rs | 2 +- src/svm/svc.rs | 5 - src/svm/svr.rs | 204 ++++++++++++++++++++--------------- 6 files changed, 121 insertions(+), 97 deletions(-) diff --git a/src/svm/mod.rs b/src/svm/mod.rs index d98a0ab..48e5907 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -22,10 +22,10 @@ //! //! //! -pub mod svc; -pub mod svr; /// search parameters pub mod search; +pub mod svc; +pub mod svr; use core::fmt::Debug; use std::marker::PhantomData; diff --git a/src/svm/search/mod.rs b/src/svm/search/mod.rs index 0d67cc4..6d86feb 100644 --- a/src/svm/search/mod.rs +++ b/src/svm/search/mod.rs @@ -1,4 +1,4 @@ /// SVC search parameters pub mod svc_params; /// SVC search parameters -pub mod svr_params; \ No newline at end of file +pub mod svr_params; diff --git a/src/svm/search/svc_params.rs b/src/svm/search/svc_params.rs index e8c836c..42f686b 100644 --- a/src/svm/search/svc_params.rs +++ b/src/svm/search/svc_params.rs @@ -135,7 +135,6 @@ // } // } - // #[cfg(test)] // mod tests { // use num::ToPrimitive; diff --git a/src/svm/search/svr_params.rs b/src/svm/search/svr_params.rs index 48d18ae..03d0ece 100644 --- a/src/svm/search/svr_params.rs +++ b/src/svm/search/svr_params.rs @@ -109,4 +109,4 @@ // serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", // deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", // )) -// )] \ No newline at end of file +// )] diff --git a/src/svm/svc.rs b/src/svm/svc.rs index ce1e57c..716f521 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -100,22 +100,17 @@ pub struct SVCParameters< X: Array2, Y: Array1, > { - #[cfg_attr(feature = "serde", serde(default))] /// Number of epochs. pub epoch: usize, - #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub c: TX, - #[cfg_attr(feature = "serde", serde(default))] /// Tolerance for stopping criterion. pub tol: TX, #[cfg_attr(feature = "serde", serde(skip_deserializing))] /// The kernel function. pub kernel: Option<&'a dyn Kernel<'a>>, - #[cfg_attr(feature = "serde", serde(default))] /// Unused parameter. m: PhantomData<(X, Y, TY)>, - #[cfg_attr(feature = "serde", serde(default))] /// Controls the pseudo random number generation for shuffling the data for probability estimates seed: Option, } diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 71bed36..cf35bde 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -79,13 +79,13 @@ use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow}; use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2, MutArray}; use crate::numbers::basenum::Number; -use crate::numbers::realnum::RealNumber; +use crate::numbers::floatnum::FloatNumber; use crate::svm::Kernel; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// SVR Parameters -pub struct SVRParameters<'a, T: Number + RealNumber> { +pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> { /// Epsilon in the epsilon-SVR model. pub eps: T, /// Regularization parameter. @@ -97,9 +97,12 @@ pub struct SVRParameters<'a, T: Number + RealNumber> { pub kernel: Option<&'a dyn Kernel<'a>>, } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Epsilon-Support Vector Regression -pub struct SVR<'a, T: Number + RealNumber, X: Array2, Y: Array1> { +pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> { instances: Option>>, + #[cfg_attr(feature = "serde", serde(skip_deserializing))] parameters: Option<&'a SVRParameters<'a, T>>, w: Option>, b: T, @@ -117,7 +120,7 @@ struct SupportVector { } /// Sequential Minimal Optimization algorithm -struct Optimizer<'a, T: Number + RealNumber> { +struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> { tol: T, c: T, parameters: Option<&'a SVRParameters<'a, T>>, @@ -129,13 +132,15 @@ struct Optimizer<'a, T: Number + RealNumber> { gmaxindex: usize, tau: T, sv: Vec>, + /// avoid infinite loop if SMO does not converge + max_iterations: usize, } struct Cache { data: Vec>>>, } -impl<'a, T: Number + RealNumber> SVRParameters<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> { /// Epsilon in the epsilon-SVR model. pub fn with_eps(mut self, eps: T) -> Self { self.eps = eps; @@ -158,7 +163,7 @@ impl<'a, T: Number + RealNumber> SVRParameters<'a, T> { } } -impl<'a, T: Number + RealNumber> Default for SVRParameters<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T> { fn default() -> Self { SVRParameters { eps: T::from_f64(0.1).unwrap(), @@ -169,7 +174,7 @@ impl<'a, T: Number + RealNumber> Default for SVRParameters<'a, T> { } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<'a, T>> for SVR<'a, T, X, Y> { fn new() -> Self { @@ -186,7 +191,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PredictorBorrow<'a, X, T> +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> PredictorBorrow<'a, X, T> for SVR<'a, T, X, Y> { fn predict(&self, x: &'a X) -> Result, Failed> { @@ -194,7 +199,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PredictorBorrow<'a, } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> SVR<'a, T, X, Y> { +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> SVR<'a, T, X, Y> { /// Fits SVR to your data. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `y` - target values @@ -275,7 +280,9 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> SVR<'a, T, X, Y> { } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PartialEq for SVR<'a, T, X, Y> { +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> PartialEq + for SVR<'a, T, X, Y> +{ fn eq(&self, other: &Self) -> bool { if (self.b - other.b).abs() > T::epsilon() * T::two() || self.w.as_ref().unwrap().len() != other.w.as_ref().unwrap().len() @@ -301,7 +308,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PartialEq for SVR<' } } -impl SupportVector { +impl SupportVector { fn new(i: usize, x: Vec, y: T, eps: T, k: f64) -> SupportVector { SupportVector { index: i, @@ -313,7 +320,7 @@ impl SupportVector { } } -impl<'a, T: Number + RealNumber> Optimizer<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> Optimizer<'a, T> { fn new, Y: Array1>( x: &'a X, y: &'a Y, @@ -355,12 +362,13 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { gmaxindex: 0, tau: T::from_f64(1e-12).unwrap(), sv: support_vectors, + max_iterations: 49999, } } fn find_min_max_gradient(&mut self) { - // self.gmin = ::max_value()(); - // self.gmax = ::min_value(); + self.gmin = ::max_value(); + self.gmax = ::min_value(); for i in 0..self.sv.len() { let v = &self.sv[i]; @@ -398,10 +406,13 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { /// * hyperplane parameters: w and b (computed with T) fn smo(mut self) -> (Vec>, Vec, T) { let cache: Cache = Cache::new(self.sv.len()); - + let mut n_iteration = 0usize; self.find_min_max_gradient(); while self.gmax - self.gmin > self.tol { + if n_iteration > self.max_iterations { + break; + } let v1 = self.svmax; let i = self.gmaxindex; let old_alpha_i = self.sv[v1].alpha[i]; @@ -546,6 +557,7 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { } self.find_min_max_gradient(); + n_iteration += 1; } let b = -(self.gmax + self.gmin) / T::two(); @@ -581,11 +593,11 @@ impl Cache { #[cfg(test)] mod tests { - // use super::*; - // use crate::linalg::basic::matrix::DenseMatrix; - // use crate::metrics::mean_squared_error; - // #[cfg(feature = "serde")] - // use crate::svm::*; + use super::*; + use crate::linalg::basic::matrix::DenseMatrix; + use crate::metrics::mean_squared_error; + #[cfg(feature = "serde")] + use crate::svm::Kernels; // #[test] // fn search_parameters() { @@ -605,79 +617,97 @@ mod tests { // assert!(iter.next().is_none()); // } - // TODO: had to disable this test as it runs for too long - // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // fn svr_fit_predict() { - // let x = DenseMatrix::from_2d_array(&[ - // &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], - // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], - // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], - // &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], - // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], - // &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], - // &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], - // &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], - // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], - // &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], - // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], - // &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], - // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], - // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], - // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], - // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], - // ]); + //TODO: had to disable this test as it runs for too long + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn svr_fit_predict() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], + &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); - // let y: Vec = vec![ - // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, - // 114.2, 115.7, 116.9, - // ]; + let y: Vec = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; - // let knl = Kernels::linear(); - // let y_hat = SVR::fit(&x, &y, &SVRParameters::default() - // .with_eps(2.0) - // .with_c(10.0) - // .with_kernel(&knl) - // ) - // .and_then(|lr| lr.predict(&x)) - // .unwrap(); + let knl = Kernels::linear(); + let y_hat = SVR::fit( + &x, + &y, + &SVRParameters::default() + .with_eps(2.0) + .with_c(10.0) + .with_kernel(&knl), + ) + .and_then(|lr| lr.predict(&x)) + .unwrap(); - // assert!(mean_squared_error(&y_hat, &y) < 2.5); - // } + let t = mean_squared_error(&y_hat, &y); + println!("{:?}", t); + assert!(t < 2.5); + } - // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn svr_serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], - // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], - // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], - // &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], - // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], - // &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], - // &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], - // &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], - // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], - // &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], - // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], - // &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], - // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], - // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], - // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], - // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], - // ]); + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[cfg(feature = "serde")] + fn svr_serde() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], + &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); - // let y: Vec = vec![ - // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, - // 114.2, 115.7, 116.9, - // ]; + let y: Vec = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; - // let svr = SVR::fit(&x, &y, Default::default()).unwrap(); + let knl = Kernels::rbf().with_gamma(0.7); + let params = SVRParameters::default().with_kernel(&knl); - // let deserialized_svr: SVR, LinearKernel> = - // serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + let svr = SVR::fit(&x, &y, ¶ms).unwrap(); - // assert_eq!(svr, deserialized_svr); - // } + let serialized = &serde_json::to_string(&svr).unwrap(); + + println!("{}", &serialized); + + // let deserialized_svr: SVR, LinearKernel> = + // serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + + // assert_eq!(svr, deserialized_svr); + } }