Implement SVR and SVR kernels with Enum. Add tests for argsort_mut (#303)
* Add tests for argsort_mut * Add formatting and cleaning up .github directory * fix clippy error. suggestion to use .contains() * define type explicitly for variable jstack * Implement kernel as enumerator * basic svr and svr_params implementation * Complete enum implementation for Kernels. Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Implement search parameters (#304) * Implement SVR kernels as enumerator * basic svr and svr_params implementation * Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Fix wasm32 typetag * fix typetag * Bump to version 0.4.2
This commit is contained in:
+26
-25
@@ -51,9 +51,9 @@
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
||||
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//! let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//!
|
||||
//! // let y_hat = svr.predict(&x).unwrap();
|
||||
//! let y_hat = svr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -80,11 +80,12 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
use crate::svm::{Kernel, Kernels};
|
||||
|
||||
/// SVR Parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// SVR Parameters
|
||||
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: T,
|
||||
@@ -97,7 +98,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
all(feature = "serde", target_arch = "wasm32"),
|
||||
serde(skip_serializing, skip_deserializing)
|
||||
)]
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
pub kernel: Option<Kernels>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -160,8 +161,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
pub fn with_kernel(mut self, kernel: Kernels) -> Self {
|
||||
self.kernel = Some(kernel);
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -597,25 +598,25 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_squared_error;
|
||||
use crate::svm::search::svr_params::SVRSearchParameters;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
// #[test]
|
||||
// fn search_parameters() {
|
||||
// let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
// SVRSearchParameters {
|
||||
// eps: vec![0., 1.],
|
||||
// kernel: vec![LinearKernel {}],
|
||||
// ..Default::default()
|
||||
// };
|
||||
// let mut iter = parameters.into_iter();
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 0.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 1.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// assert!(iter.next().is_none());
|
||||
// }
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
|
||||
eps: vec![0., 1.],
|
||||
kernel: vec![Kernels::linear()],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.eps, 0.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 1.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -648,7 +649,7 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let knl: Kernels = Kernels::linear();
|
||||
let y_hat = SVR::fit(
|
||||
&x,
|
||||
&y,
|
||||
|
||||
Reference in New Issue
Block a user