Implement SVR and SVR kernels with Enum. Add tests for argsort_mut (#303)
* Add tests for argsort_mut * Add formatting and cleaning up .github directory * fix clippy error. suggestion to use .contains() * define type explicitly for variable jstack * Implement kernel as enumerator * basic svr and svr_params implementation * Complete enum implementation for Kernels. Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Implement search parameters (#304) * Implement SVR kernels as enumerator * basic svr and svr_params implementation * Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Fix wasm32 typetag * fix typetag * Bump to version 0.4.2
This commit is contained in:
@@ -2,6 +2,5 @@
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# Developers in this list will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @VolodymyrOrlov
|
||||
* @morenol
|
||||
* @Mec-iS
|
||||
|
||||
@@ -50,9 +50,9 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
|
||||
|
||||
1. After a PR is opened maintainers are notified
|
||||
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
|
||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
|
||||
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
|
||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||
* **Testing**: multiple test pipelines are run for different targets
|
||||
3. When everything is OK, code is merged.
|
||||
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@
|
||||
name = "smartcore"
|
||||
description = "Machine Learning in Rust."
|
||||
homepage = "https://smartcorelib.org"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
authors = ["smartcore Developers"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -619,7 +619,7 @@ pub trait MutArrayView1<T: Debug + Display + Copy + Sized>:
|
||||
T: Number + PartialOrd,
|
||||
{
|
||||
let stack_size = 64;
|
||||
let mut jstack = -1;
|
||||
let mut jstack: i32 = -1;
|
||||
let mut l = 0;
|
||||
let mut istack = vec![0; stack_size];
|
||||
let mut ir = self.shape() - 1;
|
||||
@@ -2190,4 +2190,29 @@ mod tests {
|
||||
|
||||
assert_eq!(result, [65, 581, 30])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argsort_mut_exact_boundary() {
|
||||
// Test index == length - 1 case
|
||||
let boundary =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, f64::MAX], &[3.0, f64::MAX, 0.0, 2.0]])
|
||||
.unwrap();
|
||||
let mut view0: Vec<f64> = boundary.get_col(0).iterator(0).copied().collect();
|
||||
let indices = view0.argsort_mut();
|
||||
assert_eq!(indices.last(), Some(&1));
|
||||
assert_eq!(indices.first(), Some(&0));
|
||||
|
||||
let mut view1: Vec<f64> = boundary.get_col(3).iterator(0).copied().collect();
|
||||
let indices = view1.argsort_mut();
|
||||
assert_eq!(indices.last(), Some(&0));
|
||||
assert_eq!(indices.first(), Some(&1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_argsort_mut_filled_array() {
|
||||
let matrix = DenseMatrix::<f64>::rand(1000, 1000);
|
||||
let mut view: Vec<f64> = matrix.get_col(0).iterator(0).copied().collect();
|
||||
let sorted = view.argsort_mut();
|
||||
assert_eq!(sorted.len(), 1000);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ impl KNNWeightFunction {
|
||||
KNNWeightFunction::Distance => {
|
||||
// if there are any points that has zero distance from one or more training points,
|
||||
// those training points are weighted as 1.0 and the other points as 0.0
|
||||
if distances.iter().any(|&e| e == 0f64) {
|
||||
if distances.contains(&0f64) {
|
||||
distances
|
||||
.iter()
|
||||
.map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
|
||||
|
||||
+282
-178
@@ -25,14 +25,18 @@
|
||||
/// search parameters
|
||||
pub mod svc;
|
||||
pub mod svr;
|
||||
// /// search parameters space
|
||||
// pub mod search;
|
||||
// search parameters space
|
||||
pub mod search;
|
||||
|
||||
use core::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Only import typetag if not compiling for wasm32 and serde is enabled
|
||||
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||
use typetag;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
|
||||
@@ -48,197 +52,281 @@ pub trait Kernel: Debug {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||
}
|
||||
|
||||
/// Pre-defined kernel functions
|
||||
/// A enumerator for all the kernels type to support.
|
||||
/// This allows kernel selection and parameterization ergonomic, type-safe, and ready for use in parameter structs like SVRParameters.
|
||||
/// You can construct kernels using the provided variants and builder-style methods.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
///
|
||||
/// let linear = Kernels::linear();
|
||||
/// let rbf = Kernels::rbf().with_gamma(0.5);
|
||||
/// let poly = Kernels::polynomial().with_degree(3.0).with_gamma(0.5).with_coef0(1.0);
|
||||
/// let sigmoid = Kernels::sigmoid().with_gamma(0.2).with_coef0(0.0);
|
||||
/// ```
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Kernels;
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Kernels {
|
||||
/// Linear kernel (default).
|
||||
///
|
||||
/// Computes the standard dot product between vectors.
|
||||
Linear,
|
||||
|
||||
/// Radial Basis Function (RBF) kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = exp(-gamma * ||x-y||²)
|
||||
RBF {
|
||||
/// Controls the width of the Gaussian RBF kernel.
|
||||
///
|
||||
/// Larger values of gamma lead to higher bias and lower variance.
|
||||
/// This parameter is inversely proportional to the radius of influence
|
||||
/// of samples selected by the model as support vectors.
|
||||
gamma: Option<f64>,
|
||||
},
|
||||
|
||||
/// Polynomial kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||
Polynomial {
|
||||
/// The degree of the polynomial kernel.
|
||||
///
|
||||
/// Integer values are typical (2 = quadratic, 3 = cubic), but any positive real value is valid.
|
||||
/// Higher degree values create decision boundaries with higher complexity.
|
||||
degree: Option<f64>,
|
||||
|
||||
/// Kernel coefficient for the dot product.
|
||||
///
|
||||
/// Controls the influence of higher-degree versus lower-degree terms in the polynomial.
|
||||
/// If None, a default value will be used.
|
||||
gamma: Option<f64>,
|
||||
|
||||
/// Independent term in the polynomial kernel.
|
||||
///
|
||||
/// Controls the influence of higher-degree versus lower-degree terms.
|
||||
/// If None, a default value of 1.0 will be used.
|
||||
coef0: Option<f64>,
|
||||
},
|
||||
|
||||
/// Sigmoid kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||
Sigmoid {
|
||||
/// Kernel coefficient for the dot product.
|
||||
///
|
||||
/// Controls the scaling of the dot product in the sigmoid function.
|
||||
/// If None, a default value will be used.
|
||||
gamma: Option<f64>,
|
||||
|
||||
/// Independent term in the sigmoid kernel.
|
||||
///
|
||||
/// Acts as a threshold/bias term in the sigmoid function.
|
||||
/// If None, a default value of 1.0 will be used.
|
||||
coef0: Option<f64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
/// Return a default linear
|
||||
pub fn linear() -> LinearKernel {
|
||||
LinearKernel
|
||||
/// Create a linear kernel.
|
||||
///
|
||||
/// The linear kernel computes the dot product between two vectors:
|
||||
/// K(x, y) = <x, y>
|
||||
pub fn linear() -> Self {
|
||||
Kernels::Linear
|
||||
}
|
||||
/// Return a default RBF
|
||||
pub fn rbf() -> RBFKernel {
|
||||
RBFKernel::default()
|
||||
|
||||
/// Create an RBF kernel with unspecified gamma.
|
||||
///
|
||||
/// The RBF kernel is defined as:
|
||||
/// K(x, y) = exp(-gamma * ||x-y||²)
|
||||
///
|
||||
/// You should specify gamma using `with_gamma()` before using this kernel.
|
||||
pub fn rbf() -> Self {
|
||||
Kernels::RBF { gamma: None }
|
||||
}
|
||||
/// Return a default polynomial
|
||||
pub fn polynomial() -> PolynomialKernel {
|
||||
PolynomialKernel::default()
|
||||
|
||||
/// Create a polynomial kernel with default parameters.
|
||||
///
|
||||
/// The polynomial kernel is defined as:
|
||||
/// K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||
///
|
||||
/// Default values:
|
||||
/// - gamma: None (must be specified)
|
||||
/// - degree: None (must be specified)
|
||||
/// - coef0: 1.0
|
||||
pub fn polynomial() -> Self {
|
||||
Kernels::Polynomial {
|
||||
gamma: None,
|
||||
degree: None,
|
||||
coef0: Some(1.0),
|
||||
}
|
||||
}
|
||||
/// Return a default sigmoid
|
||||
pub fn sigmoid() -> SigmoidKernel {
|
||||
SigmoidKernel::default()
|
||||
|
||||
/// Create a sigmoid kernel with default parameters.
|
||||
///
|
||||
/// The sigmoid kernel is defined as:
|
||||
/// K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||
///
|
||||
/// Default values:
|
||||
/// - gamma: None (must be specified)
|
||||
/// - coef0: 1.0
|
||||
///
|
||||
pub fn sigmoid() -> Self {
|
||||
Kernels::Sigmoid {
|
||||
gamma: None,
|
||||
coef0: Some(1.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear Kernel
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct LinearKernel;
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone, PartialEq)]
|
||||
pub struct RBFKernel {
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl RBFKernel {
|
||||
/// assign gamma parameter to kernel (required)
|
||||
/// ```rust
|
||||
/// use smartcore::svm::RBFKernel;
|
||||
/// let knl = RBFKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
/// Set the `gamma` parameter for RBF, polynomial, or sigmoid kernels.
|
||||
///
|
||||
/// The gamma parameter has different interpretations depending on the kernel:
|
||||
/// - For RBF: Controls the width of the Gaussian. Larger values mean tighter fit.
|
||||
/// - For Polynomial: Scaling factor for the dot product.
|
||||
/// - For Sigmoid: Scaling factor for the dot product.
|
||||
///
|
||||
pub fn with_gamma(self, gamma: f64) -> Self {
|
||||
match self {
|
||||
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) },
|
||||
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial {
|
||||
gamma: Some(gamma),
|
||||
degree,
|
||||
coef0,
|
||||
},
|
||||
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid {
|
||||
gamma: Some(gamma),
|
||||
coef0,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PolynomialKernel {
|
||||
/// degree of the polynomial
|
||||
pub degree: Option<f64>,
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
/// independent term in kernel function
|
||||
pub coef0: Option<f64>,
|
||||
}
|
||||
/// Set the `degree` parameter for the polynomial kernel.
|
||||
///
|
||||
/// The degree parameter controls the flexibility of the decision boundary.
|
||||
/// Higher degrees create more complex boundaries but may lead to overfitting.
|
||||
///
|
||||
pub fn with_degree(self, degree: f64) -> Self {
|
||||
match self {
|
||||
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial {
|
||||
degree: Some(degree),
|
||||
gamma,
|
||||
coef0,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PolynomialKernel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gamma: Option::None,
|
||||
degree: Option::None,
|
||||
coef0: Some(1f64),
|
||||
/// Set the `coef0` parameter for polynomial or sigmoid kernels.
|
||||
///
|
||||
/// The coef0 parameter is the independent term in the kernel function:
|
||||
/// - For Polynomial: Controls the influence of higher-degree vs. lower-degree terms.
|
||||
/// - For Sigmoid: Acts as a threshold/bias term.
|
||||
///
|
||||
pub fn with_coef0(self, coef0: f64) -> Self {
|
||||
match self {
|
||||
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0: Some(coef0),
|
||||
},
|
||||
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
|
||||
gamma,
|
||||
coef0: Some(coef0),
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PolynomialKernel {
|
||||
/// set parameters for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
|
||||
/// ```
|
||||
pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
|
||||
self.degree = Some(degree);
|
||||
self.gamma = Some(gamma);
|
||||
self.coef0 = Some(coef0);
|
||||
self
|
||||
}
|
||||
/// set gamma parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
}
|
||||
/// set degree parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_degree(3.0, 100);
|
||||
/// ```
|
||||
pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
|
||||
self.with_params(degree, 1f64, 1f64 / n_features as f64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SigmoidKernel {
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
/// independent term in kernel function
|
||||
pub coef0: Option<f64>,
|
||||
}
|
||||
|
||||
impl Default for SigmoidKernel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gamma: Option::None,
|
||||
coef0: Some(1f64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SigmoidKernel {
|
||||
/// set parameters for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::SigmoidKernel;
|
||||
/// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
|
||||
/// ```
|
||||
pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self.coef0 = Some(coef0);
|
||||
self
|
||||
}
|
||||
/// set gamma parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::SigmoidKernel;
|
||||
/// let knl = SigmoidKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of the [`Kernel`] trait for the [`Kernels`] enum in smartcore.
|
||||
///
|
||||
/// This method computes the value of the kernel function between two feature vectors `x_i` and `x_j`,
|
||||
/// according to the variant and parameters of the [`Kernels`] enum. This enables flexible and type-safe
|
||||
/// selection of kernel functions for SVM and SVR models in smartcore.
|
||||
///
|
||||
/// # Supported Kernels
|
||||
///
|
||||
/// - [`Kernels::Linear`]: Computes the standard dot product between `x_i` and `x_j`.
|
||||
/// - [`Kernels::RBF`]: Computes the Radial Basis Function (Gaussian) kernel. Requires `gamma`.
|
||||
/// - [`Kernels::Polynomial`]: Computes the polynomial kernel. Requires `degree`, `gamma`, and `coef0`.
|
||||
/// - [`Kernels::Sigmoid`]: Computes the sigmoid kernel. Requires `gamma` and `coef0`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `x_i`: First input vector (feature vector).
|
||||
/// - `x_j`: Second input vector (feature vector).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - `Ok(f64)`: The computed kernel value.
|
||||
/// - `Err(Failed)`: If any required kernel parameter is missing.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err(Failed)` if a required parameter (such as `gamma`, `degree`, or `coef0`)
|
||||
/// is `None` for the selected kernel variant.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
/// use smartcore::svm::Kernel;
|
||||
///
|
||||
/// let x = vec![1.0, 2.0, 3.0];
|
||||
/// let y = vec![4.0, 5.0, 6.0];
|
||||
/// let kernel = Kernels::rbf().with_gamma(0.5);
|
||||
/// let value = kernel.apply(&x, &y).unwrap();
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// - This implementation follows smartcore's philosophy: pure Rust, no macros, no unsafe code,
|
||||
/// and an accessible, pythonic API surface for both ML practitioners and Rust beginners.
|
||||
/// - All kernel parameters must be set before calling `apply`; missing parameters will result in an error.
|
||||
///
|
||||
/// See the [`Kernels`] enum documentation for more details on each kernel type and its parameters.
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for LinearKernel {
|
||||
impl Kernel for Kernels {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
Ok(x_i.dot(x_j))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for RBFKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
"gamma should be set, use {Kernel}::default().with_gamma(..)",
|
||||
));
|
||||
match self {
|
||||
Kernels::Linear => Ok(x_i.dot(x_j)),
|
||||
Kernels::RBF { gamma } => {
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-gamma * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0,
|
||||
} => {
|
||||
let degree = degree.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "degree not set")
|
||||
})?;
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let coef0 = coef0.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||
})?;
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((gamma * dot + coef0).powf(degree))
|
||||
}
|
||||
Kernels::Sigmoid { gamma, coef0 } => {
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let coef0 = coef0.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||
})?;
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((gamma * dot + coef0).tanh())
|
||||
}
|
||||
}
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for PolynomialKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
||||
use {Kernel}::default().with_{parameter}(..)")
|
||||
);
|
||||
}
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for SigmoidKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
||||
use {Kernel}::default().with_{parameter}(..)")
|
||||
);
|
||||
}
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +335,18 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
#[test]
|
||||
fn rbf_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
let result = Kernels::rbf()
|
||||
.with_gamma(0.055)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
assert!((0.2265f64 - result) < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -264,7 +364,7 @@ mod tests {
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn rbf_kernel() {
|
||||
fn test_rbf_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
@@ -287,7 +387,10 @@ mod tests {
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
let result = Kernels::polynomial()
|
||||
.with_params(3.0, 0.5, 1.0)
|
||||
.with_gamma(0.5)
|
||||
.with_degree(3.0)
|
||||
.with_coef0(1.0)
|
||||
//.with_params(3.0, 0.5, 1.0)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
@@ -305,7 +408,8 @@ mod tests {
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
let result = Kernels::sigmoid()
|
||||
.with_params(0.01, 0.1)
|
||||
.with_gamma(0.01)
|
||||
.with_coef0(0.1)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//! SVC and Grid Search
|
||||
|
||||
/// SVC search parameters
|
||||
pub mod svc_params;
|
||||
/// SVC search parameters
|
||||
|
||||
+282
-101
@@ -1,112 +1,293 @@
|
||||
// /// SVR grid search parameters
|
||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
// #[derive(Debug, Clone)]
|
||||
// pub struct SVRSearchParameters<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
// /// Epsilon in the epsilon-SVR model.
|
||||
// pub eps: Vec<T>,
|
||||
// /// Regularization parameter.
|
||||
// pub c: Vec<T>,
|
||||
// /// Tolerance for stopping eps.
|
||||
// pub tol: Vec<T>,
|
||||
// /// The kernel function.
|
||||
// pub kernel: Vec<K>,
|
||||
// /// Unused parameter.
|
||||
// m: PhantomData<M>,
|
||||
// }
|
||||
//! # SVR Grid Search Parameters
|
||||
//!
|
||||
//! This module provides utilities for defining and iterating over grid search parameter spaces
|
||||
//! for Support Vector Regression (SVR) models in [smartcore](https://github.com/smartcorelib/smartcore).
|
||||
//!
|
||||
//! The main struct, [`SVRSearchParameters`], allows users to specify multiple values for each
|
||||
//! SVR hyperparameter (epsilon, regularization parameter C, tolerance, and kernel function).
|
||||
//! The provided iterator yields all possible combinations (the Cartesian product) of these parameters,
|
||||
//! enabling exhaustive grid search for hyperparameter tuning.
|
||||
//!
|
||||
//!
|
||||
//! ## Example
|
||||
//! ```
|
||||
//! use smartcore::svm::Kernels;
|
||||
//! use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//!
|
||||
//! let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||
//! eps: vec![0.1, 0.2],
|
||||
//! c: vec![1.0, 10.0],
|
||||
//! tol: vec![1e-3],
|
||||
//! kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
//! m: std::marker::PhantomData,
|
||||
//! };
|
||||
//!
|
||||
//! // for param_set in params.into_iter() {
|
||||
//! // Use param_set (of type svr::SVRParameters) to fit and evaluate your SVR model.
|
||||
//! // }
|
||||
//! ```
|
||||
//!
|
||||
//!
|
||||
//! ## Note
|
||||
//! This module is intended for use with smartcore version 0.4 or later. The API is not compatible with older versions[1].
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// /// SVR grid search iterator
|
||||
// pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
// svr_search_parameters: SVRSearchParameters<T, M, K>,
|
||||
// current_eps: usize,
|
||||
// current_c: usize,
|
||||
// current_tol: usize,
|
||||
// current_kernel: usize,
|
||||
// }
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::svm::{svr, Kernels};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||
// for SVRSearchParameters<T, M, K>
|
||||
// {
|
||||
// type Item = SVRParameters<T, M, K>;
|
||||
// type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
||||
/// ## SVR grid search parameters
|
||||
/// A struct representing a grid of hyperparameters for SVR grid search in smartcore.
|
||||
///
|
||||
/// Each field is a vector of possible values for the corresponding SVR hyperparameter.
|
||||
/// The [`IntoIterator`] implementation yields every possible combination of these parameters
|
||||
/// as an `svr::SVRParameters` struct, suitable for use in model selection routines.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `T`: Numeric type for parameters (e.g., `f64`)
|
||||
/// - `M`: Matrix type implementing [`Array2<T>`]
|
||||
///
|
||||
/// # Fields
|
||||
/// - `eps`: Vector of epsilon values for the epsilon-insensitive loss in SVR.
|
||||
/// - `c`: Vector of regularization parameters (C) for SVR.
|
||||
/// - `tol`: Vector of tolerance values for the stopping criterion.
|
||||
/// - `kernel`: Vector of kernel function variants (see [`Kernels`]).
|
||||
/// - `m`: Phantom data for the matrix type parameter.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
/// use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
///
|
||||
/// let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||
/// eps: vec![0.1, 0.2],
|
||||
/// c: vec![1.0, 10.0],
|
||||
/// tol: vec![1e-3],
|
||||
/// kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
/// m: std::marker::PhantomData,
|
||||
/// };
|
||||
/// ```
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVRSearchParameters<T: Number + RealNumber, M: Array2<T>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: Vec<T>,
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
/// Tolerance for stopping eps.
|
||||
pub tol: Vec<T>,
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<Kernels>,
|
||||
/// Unused parameter.
|
||||
pub m: PhantomData<M>,
|
||||
}
|
||||
|
||||
// fn into_iter(self) -> Self::IntoIter {
|
||||
// SVRSearchParametersIterator {
|
||||
// svr_search_parameters: self,
|
||||
// current_eps: 0,
|
||||
// current_c: 0,
|
||||
// current_tol: 0,
|
||||
// current_kernel: 0,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
/// SVR grid search iterator
|
||||
pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Array2<T>> {
|
||||
svr_search_parameters: SVRSearchParameters<T, M>,
|
||||
current_eps: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
}
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
// for SVRSearchParametersIterator<T, M, K>
|
||||
// {
|
||||
// type Item = SVRParameters<T, M, K>;
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> IntoIterator
|
||||
for SVRSearchParameters<T, M>
|
||||
{
|
||||
type Item = svr::SVRParameters<T>;
|
||||
type IntoIter = SVRSearchParametersIterator<T, M>;
|
||||
|
||||
// fn next(&mut self) -> Option<Self::Item> {
|
||||
// if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
// && self.current_c == self.svr_search_parameters.c.len()
|
||||
// && self.current_tol == self.svr_search_parameters.tol.len()
|
||||
// && self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
// {
|
||||
// return None;
|
||||
// }
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVRSearchParametersIterator {
|
||||
svr_search_parameters: self,
|
||||
current_eps: 0,
|
||||
current_c: 0,
|
||||
current_tol: 0,
|
||||
current_kernel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// let next = SVRParameters::<T, M, K> {
|
||||
// eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
// c: self.svr_search_parameters.c[self.current_c],
|
||||
// tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
||||
// m: PhantomData,
|
||||
// };
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Iterator
|
||||
for SVRSearchParametersIterator<T, M>
|
||||
{
|
||||
type Item = svr::SVRParameters<T>;
|
||||
|
||||
// if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
// self.current_eps += 1;
|
||||
// } else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c += 1;
|
||||
// } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c = 0;
|
||||
// self.current_tol += 1;
|
||||
// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c = 0;
|
||||
// self.current_tol = 0;
|
||||
// self.current_kernel += 1;
|
||||
// } else {
|
||||
// self.current_eps += 1;
|
||||
// self.current_c += 1;
|
||||
// self.current_tol += 1;
|
||||
// self.current_kernel += 1;
|
||||
// }
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
&& self.current_c == self.svr_search_parameters.c.len()
|
||||
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Some(next)
|
||||
// }
|
||||
// }
|
||||
let next = svr::SVRParameters::<T> {
|
||||
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
c: self.svr_search_parameters.c[self.current_c],
|
||||
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
kernel: Some(self.svr_search_parameters.kernel[self.current_kernel].clone()),
|
||||
};
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
||||
// fn default() -> Self {
|
||||
// let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
||||
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
self.current_eps += 1;
|
||||
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else {
|
||||
self.current_eps += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
}
|
||||
|
||||
// SVRSearchParameters {
|
||||
// eps: vec![default_params.eps],
|
||||
// c: vec![default_params.c],
|
||||
// tol: vec![default_params.tol],
|
||||
// kernel: vec![default_params.kernel],
|
||||
// m: PhantomData,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
// #[derive(Debug)]
|
||||
// #[cfg_attr(
|
||||
// feature = "serde",
|
||||
// serde(bound(
|
||||
// serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
// deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
// ))
|
||||
// )]
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Default for SVRSearchParameters<T, M> {
|
||||
fn default() -> Self {
|
||||
let default_params: svr::SVRParameters<T> = svr::SVRParameters::default();
|
||||
|
||||
SVRSearchParameters {
|
||||
eps: vec![default_params.eps],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel.unwrap_or_else(Kernels::linear)],
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
type T = f64;
|
||||
type M = DenseMatrix<T>;
|
||||
|
||||
#[test]
|
||||
fn test_default_parameters() {
|
||||
let params = SVRSearchParameters::<T, M>::default();
|
||||
assert_eq!(params.eps.len(), 1);
|
||||
assert_eq!(params.c.len(), 1);
|
||||
assert_eq!(params.tol.len(), 1);
|
||||
assert_eq!(params.kernel.len(), 1);
|
||||
// Check that the default kernel is linear
|
||||
assert_eq!(params.kernel[0], Kernels::linear());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_grid_iteration() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![0.1],
|
||||
c: vec![1.0],
|
||||
tol: vec![1e-3],
|
||||
kernel: vec![Kernels::rbf().with_gamma(0.5)],
|
||||
m: PhantomData,
|
||||
};
|
||||
let mut iter = params.into_iter();
|
||||
let param = iter.next().unwrap();
|
||||
assert_eq!(param.eps, 0.1);
|
||||
assert_eq!(param.c, 1.0);
|
||||
assert_eq!(param.tol, 1e-3);
|
||||
assert_eq!(param.kernel, Some(Kernels::rbf().with_gamma(0.5)));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cartesian_grid_iteration() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![0.1, 0.2],
|
||||
c: vec![1.0, 2.0],
|
||||
tol: vec![1e-3],
|
||||
kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
m: PhantomData,
|
||||
};
|
||||
let expected_count =
|
||||
params.eps.len() * params.c.len() * params.tol.len() * params.kernel.len();
|
||||
let results: Vec<_> = params.into_iter().collect();
|
||||
assert_eq!(results.len(), expected_count);
|
||||
|
||||
// Check that all parameter combinations are present
|
||||
let mut seen = vec![];
|
||||
for p in &results {
|
||||
seen.push((p.eps, p.c, p.tol, p.kernel.clone().unwrap()));
|
||||
}
|
||||
for &eps in &[0.1, 0.2] {
|
||||
for &c in &[1.0, 2.0] {
|
||||
for &tol in &[1e-3] {
|
||||
for kernel in &[Kernels::linear(), Kernels::rbf().with_gamma(0.5)] {
|
||||
assert!(seen.contains(&(eps, c, tol, kernel.clone())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_grid() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![],
|
||||
c: vec![],
|
||||
tol: vec![],
|
||||
kernel: vec![],
|
||||
m: PhantomData,
|
||||
};
|
||||
let mut iter = params.into_iter();
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kernel_enum_variants() {
|
||||
let lin = Kernels::linear();
|
||||
let rbf = Kernels::rbf().with_gamma(0.2);
|
||||
let poly = Kernels::polynomial()
|
||||
.with_degree(2.0)
|
||||
.with_gamma(1.0)
|
||||
.with_coef0(0.5);
|
||||
let sig = Kernels::sigmoid().with_gamma(0.3).with_coef0(0.1);
|
||||
|
||||
assert_eq!(lin, Kernels::Linear);
|
||||
match rbf {
|
||||
Kernels::RBF { gamma } => assert_eq!(gamma, Some(0.2)),
|
||||
_ => panic!("Not RBF"),
|
||||
}
|
||||
match poly {
|
||||
Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0,
|
||||
} => {
|
||||
assert_eq!(degree, Some(2.0));
|
||||
assert_eq!(gamma, Some(1.0));
|
||||
assert_eq!(coef0, Some(0.5));
|
||||
}
|
||||
_ => panic!("Not Polynomial"),
|
||||
}
|
||||
match sig {
|
||||
Kernels::Sigmoid { gamma, coef0 } => {
|
||||
assert_eq!(gamma, Some(0.3));
|
||||
assert_eq!(coef0, Some(0.1));
|
||||
}
|
||||
_ => panic!("Not Sigmoid"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+26
-25
@@ -51,9 +51,9 @@
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
||||
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//! let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//!
|
||||
//! // let y_hat = svr.predict(&x).unwrap();
|
||||
//! let y_hat = svr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -80,11 +80,12 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
use crate::svm::{Kernel, Kernels};
|
||||
|
||||
/// SVR Parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// SVR Parameters
|
||||
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: T,
|
||||
@@ -97,7 +98,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
all(feature = "serde", target_arch = "wasm32"),
|
||||
serde(skip_serializing, skip_deserializing)
|
||||
)]
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
pub kernel: Option<Kernels>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -160,8 +161,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
pub fn with_kernel(mut self, kernel: Kernels) -> Self {
|
||||
self.kernel = Some(kernel);
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -597,25 +598,25 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_squared_error;
|
||||
use crate::svm::search::svr_params::SVRSearchParameters;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
// #[test]
|
||||
// fn search_parameters() {
|
||||
// let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
// SVRSearchParameters {
|
||||
// eps: vec![0., 1.],
|
||||
// kernel: vec![LinearKernel {}],
|
||||
// ..Default::default()
|
||||
// };
|
||||
// let mut iter = parameters.into_iter();
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 0.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 1.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// assert!(iter.next().is_none());
|
||||
// }
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
|
||||
eps: vec![0., 1.],
|
||||
kernel: vec![Kernels::linear()],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.eps, 0.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 1.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -648,7 +649,7 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let knl: Kernels = Kernels::linear();
|
||||
let y_hat = SVR::fit(
|
||||
&x,
|
||||
&y,
|
||||
|
||||
Reference in New Issue
Block a user