Implement SVR and SVR kernels with Enum. Add tests for argsort_mut (#303)

* Add tests for argsort_mut
* Add formatting and cleaning up .github directory
* fix clippy error. suggestion to use .contains()
* define type explicitly for variable jstack
* Implement kernel as enumerator
* basic svr and svr_params implementation
* Complete enum implementation for Kernels. Implement search grid for SVR. Add documentation.
* Fix serde configuration in cargo clippy
*  Implement search parameters (#304)
* Implement SVR kernels as enumerator
* basic svr and svr_params implementation
* Implement search grid for SVR. Add documentation.
* Fix serde configuration in cargo clippy
* Fix wasm32 typetag
* fix typetag
* Bump to version 0.4.2
This commit is contained in:
Lorenzo
2025-06-02 19:01:46 +09:00
committed by GitHub
parent 76d1ef610d
commit 44424807a0
9 changed files with 621 additions and 309 deletions
+26 -25
View File
@@ -51,9 +51,9 @@
//!
//! let knl = Kernels::linear();
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
//! // let svr = SVR::fit(&x, &y, params).unwrap();
//! let svr = SVR::fit(&x, &y, params).unwrap();
//!
//! // let y_hat = svr.predict(&x).unwrap();
//! let y_hat = svr.predict(&x).unwrap();
//! ```
//!
//! ## References:
@@ -80,11 +80,12 @@ use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::svm::Kernel;
use crate::svm::{Kernel, Kernels};
/// SVR Parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// SVR Parameters
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
/// Epsilon in the epsilon-SVR model.
pub eps: T,
@@ -97,7 +98,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing)
)]
pub kernel: Option<Box<dyn Kernel>>,
pub kernel: Option<Kernels>,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -160,8 +161,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
self
}
/// The kernel function.
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(Box::new(kernel));
pub fn with_kernel(mut self, kernel: Kernels) -> Self {
self.kernel = Some(kernel);
self
}
}
@@ -597,25 +598,25 @@ mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_squared_error;
use crate::svm::search::svr_params::SVRSearchParameters;
use crate::svm::Kernels;
// #[test]
// fn search_parameters() {
// let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
// SVRSearchParameters {
// eps: vec![0., 1.],
// kernel: vec![LinearKernel {}],
// ..Default::default()
// };
// let mut iter = parameters.into_iter();
// let next = iter.next().unwrap();
// assert_eq!(next.eps, 0.);
// assert_eq!(next.kernel, LinearKernel {});
// let next = iter.next().unwrap();
// assert_eq!(next.eps, 1.);
// assert_eq!(next.kernel, LinearKernel {});
// assert!(iter.next().is_none());
// }
#[test]
fn search_parameters() {
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
eps: vec![0., 1.],
kernel: vec![Kernels::linear()],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.eps, 0.);
// assert_eq!(next.kernel, LinearKernel {});
// let next = iter.next().unwrap();
// assert_eq!(next.eps, 1.);
// assert_eq!(next.kernel, LinearKernel {});
// assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -648,7 +649,7 @@ mod tests {
114.2, 115.7, 116.9,
];
let knl = Kernels::linear();
let knl: Kernels = Kernels::linear();
let y_hat = SVR::fit(
&x,
&y,