Implement SVR and SVR kernels with Enum. Add tests for argsort_mut (#303)
* Add tests for argsort_mut * Add formatting and cleaning up .github directory * fix clippy error. suggestion to use .contains() * define type explicitly for variable jstack * Implement kernel as enumerator * basic svr and svr_params implementation * Complete enum implementation for Kernels. Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Implement search parameters (#304) * Implement SVR kernels as enumerator * basic svr and svr_params implementation * Implement search grid for SVR. Add documentation. * Fix serde configuration in cargo clippy * Fix wasm32 typetag * fix typetag * Bump to version 0.4.2
This commit is contained in:
@@ -2,6 +2,5 @@
|
|||||||
# the repo. Unless a later match takes precedence,
|
# the repo. Unless a later match takes precedence,
|
||||||
# Developers in this list will be requested for
|
# Developers in this list will be requested for
|
||||||
# review when someone opens a pull request.
|
# review when someone opens a pull request.
|
||||||
* @VolodymyrOrlov
|
|
||||||
* @morenol
|
* @morenol
|
||||||
* @Mec-iS
|
* @Mec-iS
|
||||||
|
|||||||
@@ -50,9 +50,9 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
|
|||||||
|
|
||||||
1. After a PR is opened maintainers are notified
|
1. After a PR is opened maintainers are notified
|
||||||
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
|
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
|
||||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
|
||||||
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
|
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
|
||||||
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
|
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
|
||||||
|
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||||
* **Testing**: multiple test pipelines are run for different targets
|
* **Testing**: multiple test pipelines are run for different targets
|
||||||
3. When everything is OK, code is merged.
|
3. When everything is OK, code is merged.
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.1"
|
version = "0.4.2"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -619,7 +619,7 @@ pub trait MutArrayView1<T: Debug + Display + Copy + Sized>:
|
|||||||
T: Number + PartialOrd,
|
T: Number + PartialOrd,
|
||||||
{
|
{
|
||||||
let stack_size = 64;
|
let stack_size = 64;
|
||||||
let mut jstack = -1;
|
let mut jstack: i32 = -1;
|
||||||
let mut l = 0;
|
let mut l = 0;
|
||||||
let mut istack = vec![0; stack_size];
|
let mut istack = vec![0; stack_size];
|
||||||
let mut ir = self.shape() - 1;
|
let mut ir = self.shape() - 1;
|
||||||
@@ -2190,4 +2190,29 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(result, [65, 581, 30])
|
assert_eq!(result, [65, 581, 30])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_argsort_mut_exact_boundary() {
|
||||||
|
// Test index == length - 1 case
|
||||||
|
let boundary =
|
||||||
|
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, f64::MAX], &[3.0, f64::MAX, 0.0, 2.0]])
|
||||||
|
.unwrap();
|
||||||
|
let mut view0: Vec<f64> = boundary.get_col(0).iterator(0).copied().collect();
|
||||||
|
let indices = view0.argsort_mut();
|
||||||
|
assert_eq!(indices.last(), Some(&1));
|
||||||
|
assert_eq!(indices.first(), Some(&0));
|
||||||
|
|
||||||
|
let mut view1: Vec<f64> = boundary.get_col(3).iterator(0).copied().collect();
|
||||||
|
let indices = view1.argsort_mut();
|
||||||
|
assert_eq!(indices.last(), Some(&0));
|
||||||
|
assert_eq!(indices.first(), Some(&1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_argsort_mut_filled_array() {
|
||||||
|
let matrix = DenseMatrix::<f64>::rand(1000, 1000);
|
||||||
|
let mut view: Vec<f64> = matrix.get_col(0).iterator(0).copied().collect();
|
||||||
|
let sorted = view.argsort_mut();
|
||||||
|
assert_eq!(sorted.len(), 1000);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ impl KNNWeightFunction {
|
|||||||
KNNWeightFunction::Distance => {
|
KNNWeightFunction::Distance => {
|
||||||
// if there are any points that has zero distance from one or more training points,
|
// if there are any points that has zero distance from one or more training points,
|
||||||
// those training points are weighted as 1.0 and the other points as 0.0
|
// those training points are weighted as 1.0 and the other points as 0.0
|
||||||
if distances.iter().any(|&e| e == 0f64) {
|
if distances.contains(&0f64) {
|
||||||
distances
|
distances
|
||||||
.iter()
|
.iter()
|
||||||
.map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
|
.map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
|
||||||
|
|||||||
+282
-178
@@ -25,14 +25,18 @@
|
|||||||
/// search parameters
|
/// search parameters
|
||||||
pub mod svc;
|
pub mod svc;
|
||||||
pub mod svr;
|
pub mod svr;
|
||||||
// /// search parameters space
|
// search parameters space
|
||||||
// pub mod search;
|
pub mod search;
|
||||||
|
|
||||||
use core::fmt::Debug;
|
use core::fmt::Debug;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// Only import typetag if not compiling for wasm32 and serde is enabled
|
||||||
|
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||||
|
use typetag;
|
||||||
|
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||||
|
|
||||||
@@ -48,197 +52,281 @@ pub trait Kernel: Debug {
|
|||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pre-defined kernel functions
|
/// A enumerator for all the kernels type to support.
|
||||||
|
/// This allows kernel selection and parameterization ergonomic, type-safe, and ready for use in parameter structs like SVRParameters.
|
||||||
|
/// You can construct kernels using the provided variants and builder-style methods.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::svm::Kernels;
|
||||||
|
///
|
||||||
|
/// let linear = Kernels::linear();
|
||||||
|
/// let rbf = Kernels::rbf().with_gamma(0.5);
|
||||||
|
/// let poly = Kernels::polynomial().with_degree(3.0).with_gamma(0.5).with_coef0(1.0);
|
||||||
|
/// let sigmoid = Kernels::sigmoid().with_gamma(0.2).with_coef0(0.0);
|
||||||
|
/// ```
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct Kernels;
|
pub enum Kernels {
|
||||||
|
/// Linear kernel (default).
|
||||||
|
///
|
||||||
|
/// Computes the standard dot product between vectors.
|
||||||
|
Linear,
|
||||||
|
|
||||||
|
/// Radial Basis Function (RBF) kernel.
|
||||||
|
///
|
||||||
|
/// Formula: K(x, y) = exp(-gamma * ||x-y||²)
|
||||||
|
RBF {
|
||||||
|
/// Controls the width of the Gaussian RBF kernel.
|
||||||
|
///
|
||||||
|
/// Larger values of gamma lead to higher bias and lower variance.
|
||||||
|
/// This parameter is inversely proportional to the radius of influence
|
||||||
|
/// of samples selected by the model as support vectors.
|
||||||
|
gamma: Option<f64>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Polynomial kernel.
|
||||||
|
///
|
||||||
|
/// Formula: K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||||
|
Polynomial {
|
||||||
|
/// The degree of the polynomial kernel.
|
||||||
|
///
|
||||||
|
/// Integer values are typical (2 = quadratic, 3 = cubic), but any positive real value is valid.
|
||||||
|
/// Higher degree values create decision boundaries with higher complexity.
|
||||||
|
degree: Option<f64>,
|
||||||
|
|
||||||
|
/// Kernel coefficient for the dot product.
|
||||||
|
///
|
||||||
|
/// Controls the influence of higher-degree versus lower-degree terms in the polynomial.
|
||||||
|
/// If None, a default value will be used.
|
||||||
|
gamma: Option<f64>,
|
||||||
|
|
||||||
|
/// Independent term in the polynomial kernel.
|
||||||
|
///
|
||||||
|
/// Controls the influence of higher-degree versus lower-degree terms.
|
||||||
|
/// If None, a default value of 1.0 will be used.
|
||||||
|
coef0: Option<f64>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Sigmoid kernel.
|
||||||
|
///
|
||||||
|
/// Formula: K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||||
|
Sigmoid {
|
||||||
|
/// Kernel coefficient for the dot product.
|
||||||
|
///
|
||||||
|
/// Controls the scaling of the dot product in the sigmoid function.
|
||||||
|
/// If None, a default value will be used.
|
||||||
|
gamma: Option<f64>,
|
||||||
|
|
||||||
|
/// Independent term in the sigmoid kernel.
|
||||||
|
///
|
||||||
|
/// Acts as a threshold/bias term in the sigmoid function.
|
||||||
|
/// If None, a default value of 1.0 will be used.
|
||||||
|
coef0: Option<f64>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
impl Kernels {
|
impl Kernels {
|
||||||
/// Return a default linear
|
/// Create a linear kernel.
|
||||||
pub fn linear() -> LinearKernel {
|
///
|
||||||
LinearKernel
|
/// The linear kernel computes the dot product between two vectors:
|
||||||
|
/// K(x, y) = <x, y>
|
||||||
|
pub fn linear() -> Self {
|
||||||
|
Kernels::Linear
|
||||||
}
|
}
|
||||||
/// Return a default RBF
|
|
||||||
pub fn rbf() -> RBFKernel {
|
/// Create an RBF kernel with unspecified gamma.
|
||||||
RBFKernel::default()
|
///
|
||||||
|
/// The RBF kernel is defined as:
|
||||||
|
/// K(x, y) = exp(-gamma * ||x-y||²)
|
||||||
|
///
|
||||||
|
/// You should specify gamma using `with_gamma()` before using this kernel.
|
||||||
|
pub fn rbf() -> Self {
|
||||||
|
Kernels::RBF { gamma: None }
|
||||||
}
|
}
|
||||||
/// Return a default polynomial
|
|
||||||
pub fn polynomial() -> PolynomialKernel {
|
/// Create a polynomial kernel with default parameters.
|
||||||
PolynomialKernel::default()
|
///
|
||||||
|
/// The polynomial kernel is defined as:
|
||||||
|
/// K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||||
|
///
|
||||||
|
/// Default values:
|
||||||
|
/// - gamma: None (must be specified)
|
||||||
|
/// - degree: None (must be specified)
|
||||||
|
/// - coef0: 1.0
|
||||||
|
pub fn polynomial() -> Self {
|
||||||
|
Kernels::Polynomial {
|
||||||
|
gamma: None,
|
||||||
|
degree: None,
|
||||||
|
coef0: Some(1.0),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/// Return a default sigmoid
|
|
||||||
pub fn sigmoid() -> SigmoidKernel {
|
/// Create a sigmoid kernel with default parameters.
|
||||||
SigmoidKernel::default()
|
///
|
||||||
|
/// The sigmoid kernel is defined as:
|
||||||
|
/// K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||||
|
///
|
||||||
|
/// Default values:
|
||||||
|
/// - gamma: None (must be specified)
|
||||||
|
/// - coef0: 1.0
|
||||||
|
///
|
||||||
|
pub fn sigmoid() -> Self {
|
||||||
|
Kernels::Sigmoid {
|
||||||
|
gamma: None,
|
||||||
|
coef0: Some(1.0),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Linear Kernel
|
/// Set the `gamma` parameter for RBF, polynomial, or sigmoid kernels.
|
||||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
///
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
/// The gamma parameter has different interpretations depending on the kernel:
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
/// - For RBF: Controls the width of the Gaussian. Larger values mean tighter fit.
|
||||||
pub struct LinearKernel;
|
/// - For Polynomial: Scaling factor for the dot product.
|
||||||
|
/// - For Sigmoid: Scaling factor for the dot product.
|
||||||
/// Radial basis function (Gaussian) kernel
|
///
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
pub fn with_gamma(self, gamma: f64) -> Self {
|
||||||
#[derive(Debug, Default, Clone, PartialEq)]
|
match self {
|
||||||
pub struct RBFKernel {
|
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) },
|
||||||
/// kernel coefficient
|
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial {
|
||||||
pub gamma: Option<f64>,
|
gamma: Some(gamma),
|
||||||
}
|
degree,
|
||||||
|
coef0,
|
||||||
#[allow(dead_code)]
|
},
|
||||||
impl RBFKernel {
|
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid {
|
||||||
/// assign gamma parameter to kernel (required)
|
gamma: Some(gamma),
|
||||||
/// ```rust
|
coef0,
|
||||||
/// use smartcore::svm::RBFKernel;
|
},
|
||||||
/// let knl = RBFKernel::default().with_gamma(0.7);
|
other => other,
|
||||||
/// ```
|
}
|
||||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
|
||||||
self.gamma = Some(gamma);
|
|
||||||
self
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Polynomial kernel
|
/// Set the `degree` parameter for the polynomial kernel.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
///
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
/// The degree parameter controls the flexibility of the decision boundary.
|
||||||
pub struct PolynomialKernel {
|
/// Higher degrees create more complex boundaries but may lead to overfitting.
|
||||||
/// degree of the polynomial
|
///
|
||||||
pub degree: Option<f64>,
|
pub fn with_degree(self, degree: f64) -> Self {
|
||||||
/// kernel coefficient
|
match self {
|
||||||
pub gamma: Option<f64>,
|
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial {
|
||||||
/// independent term in kernel function
|
degree: Some(degree),
|
||||||
pub coef0: Option<f64>,
|
gamma,
|
||||||
}
|
coef0,
|
||||||
|
},
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for PolynomialKernel {
|
/// Set the `coef0` parameter for polynomial or sigmoid kernels.
|
||||||
fn default() -> Self {
|
///
|
||||||
Self {
|
/// The coef0 parameter is the independent term in the kernel function:
|
||||||
gamma: Option::None,
|
/// - For Polynomial: Controls the influence of higher-degree vs. lower-degree terms.
|
||||||
degree: Option::None,
|
/// - For Sigmoid: Acts as a threshold/bias term.
|
||||||
coef0: Some(1f64),
|
///
|
||||||
|
pub fn with_coef0(self, coef0: f64) -> Self {
|
||||||
|
match self {
|
||||||
|
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
|
||||||
|
degree,
|
||||||
|
gamma,
|
||||||
|
coef0: Some(coef0),
|
||||||
|
},
|
||||||
|
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
|
||||||
|
gamma,
|
||||||
|
coef0: Some(coef0),
|
||||||
|
},
|
||||||
|
other => other,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PolynomialKernel {
|
/// Implementation of the [`Kernel`] trait for the [`Kernels`] enum in smartcore.
|
||||||
/// set parameters for kernel
|
///
|
||||||
/// ```rust
|
/// This method computes the value of the kernel function between two feature vectors `x_i` and `x_j`,
|
||||||
/// use smartcore::svm::PolynomialKernel;
|
/// according to the variant and parameters of the [`Kernels`] enum. This enables flexible and type-safe
|
||||||
/// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
|
/// selection of kernel functions for SVM and SVR models in smartcore.
|
||||||
/// ```
|
///
|
||||||
pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
|
/// # Supported Kernels
|
||||||
self.degree = Some(degree);
|
///
|
||||||
self.gamma = Some(gamma);
|
/// - [`Kernels::Linear`]: Computes the standard dot product between `x_i` and `x_j`.
|
||||||
self.coef0 = Some(coef0);
|
/// - [`Kernels::RBF`]: Computes the Radial Basis Function (Gaussian) kernel. Requires `gamma`.
|
||||||
self
|
/// - [`Kernels::Polynomial`]: Computes the polynomial kernel. Requires `degree`, `gamma`, and `coef0`.
|
||||||
}
|
/// - [`Kernels::Sigmoid`]: Computes the sigmoid kernel. Requires `gamma` and `coef0`.
|
||||||
/// set gamma parameter for kernel
|
///
|
||||||
/// ```rust
|
/// # Parameters
|
||||||
/// use smartcore::svm::PolynomialKernel;
|
///
|
||||||
/// let knl = PolynomialKernel::default().with_gamma(0.7);
|
/// - `x_i`: First input vector (feature vector).
|
||||||
/// ```
|
/// - `x_j`: Second input vector (feature vector).
|
||||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
///
|
||||||
self.gamma = Some(gamma);
|
/// # Returns
|
||||||
self
|
///
|
||||||
}
|
/// - `Ok(f64)`: The computed kernel value.
|
||||||
/// set degree parameter for kernel
|
/// - `Err(Failed)`: If any required kernel parameter is missing.
|
||||||
/// ```rust
|
///
|
||||||
/// use smartcore::svm::PolynomialKernel;
|
/// # Errors
|
||||||
/// let knl = PolynomialKernel::default().with_degree(3.0, 100);
|
///
|
||||||
/// ```
|
/// Returns `Err(Failed)` if a required parameter (such as `gamma`, `degree`, or `coef0`)
|
||||||
pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
|
/// is `None` for the selected kernel variant.
|
||||||
self.with_params(degree, 1f64, 1f64 / n_features as f64)
|
///
|
||||||
}
|
/// # Example
|
||||||
}
|
///
|
||||||
|
/// ```
|
||||||
/// Sigmoid (hyperbolic tangent) kernel
|
/// use smartcore::svm::Kernels;
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
/// use smartcore::svm::Kernel;
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
///
|
||||||
pub struct SigmoidKernel {
|
/// let x = vec![1.0, 2.0, 3.0];
|
||||||
/// kernel coefficient
|
/// let y = vec![4.0, 5.0, 6.0];
|
||||||
pub gamma: Option<f64>,
|
/// let kernel = Kernels::rbf().with_gamma(0.5);
|
||||||
/// independent term in kernel function
|
/// let value = kernel.apply(&x, &y).unwrap();
|
||||||
pub coef0: Option<f64>,
|
/// ```
|
||||||
}
|
///
|
||||||
|
/// # Notes
|
||||||
impl Default for SigmoidKernel {
|
///
|
||||||
fn default() -> Self {
|
/// - This implementation follows smartcore's philosophy: pure Rust, no macros, no unsafe code,
|
||||||
Self {
|
/// and an accessible, pythonic API surface for both ML practitioners and Rust beginners.
|
||||||
gamma: Option::None,
|
/// - All kernel parameters must be set before calling `apply`; missing parameters will result in an error.
|
||||||
coef0: Some(1f64),
|
///
|
||||||
}
|
/// See the [`Kernels`] enum documentation for more details on each kernel type and its parameters.
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SigmoidKernel {
|
|
||||||
/// set parameters for kernel
|
|
||||||
/// ```rust
|
|
||||||
/// use smartcore::svm::SigmoidKernel;
|
|
||||||
/// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
|
|
||||||
/// ```
|
|
||||||
pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
|
|
||||||
self.gamma = Some(gamma);
|
|
||||||
self.coef0 = Some(coef0);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
/// set gamma parameter for kernel
|
|
||||||
/// ```rust
|
|
||||||
/// use smartcore::svm::SigmoidKernel;
|
|
||||||
/// let knl = SigmoidKernel::default().with_gamma(0.7);
|
|
||||||
/// ```
|
|
||||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
|
||||||
self.gamma = Some(gamma);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||||
impl Kernel for LinearKernel {
|
impl Kernel for Kernels {
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||||
Ok(x_i.dot(x_j))
|
match self {
|
||||||
}
|
Kernels::Linear => Ok(x_i.dot(x_j)),
|
||||||
}
|
Kernels::RBF { gamma } => {
|
||||||
|
let gamma = gamma.ok_or_else(|| {
|
||||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||||
impl Kernel for RBFKernel {
|
})?;
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
let v_diff = x_i.sub(x_j);
|
||||||
if self.gamma.is_none() {
|
Ok((-gamma * v_diff.mul(&v_diff).sum()).exp())
|
||||||
return Err(Failed::because(
|
}
|
||||||
FailedError::ParametersError,
|
Kernels::Polynomial {
|
||||||
"gamma should be set, use {Kernel}::default().with_gamma(..)",
|
degree,
|
||||||
));
|
gamma,
|
||||||
|
coef0,
|
||||||
|
} => {
|
||||||
|
let degree = degree.ok_or_else(|| {
|
||||||
|
Failed::because(FailedError::ParametersError, "degree not set")
|
||||||
|
})?;
|
||||||
|
let gamma = gamma.ok_or_else(|| {
|
||||||
|
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||||
|
})?;
|
||||||
|
let coef0 = coef0.ok_or_else(|| {
|
||||||
|
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||||
|
})?;
|
||||||
|
let dot = x_i.dot(x_j);
|
||||||
|
Ok((gamma * dot + coef0).powf(degree))
|
||||||
|
}
|
||||||
|
Kernels::Sigmoid { gamma, coef0 } => {
|
||||||
|
let gamma = gamma.ok_or_else(|| {
|
||||||
|
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||||
|
})?;
|
||||||
|
let coef0 = coef0.ok_or_else(|| {
|
||||||
|
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||||
|
})?;
|
||||||
|
let dot = x_i.dot(x_j);
|
||||||
|
Ok((gamma * dot + coef0).tanh())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let v_diff = x_i.sub(x_j);
|
|
||||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
|
||||||
impl Kernel for PolynomialKernel {
|
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
|
||||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
|
||||||
return Err(Failed::because(
|
|
||||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
|
||||||
use {Kernel}::default().with_{parameter}(..)")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let dot = x_i.dot(x_j);
|
|
||||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
|
||||||
impl Kernel for SigmoidKernel {
|
|
||||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
|
||||||
if self.gamma.is_none() || self.coef0.is_none() {
|
|
||||||
return Err(Failed::because(
|
|
||||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
|
||||||
use {Kernel}::default().with_{parameter}(..)")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let dot = x_i.dot(x_j);
|
|
||||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,6 +335,18 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::svm::Kernels;
|
use crate::svm::Kernels;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rbf_kernel() {
|
||||||
|
let v1 = vec![1., 2., 3.];
|
||||||
|
let v2 = vec![4., 5., 6.];
|
||||||
|
let result = Kernels::rbf()
|
||||||
|
.with_gamma(0.055)
|
||||||
|
.apply(&v1, &v2)
|
||||||
|
.unwrap()
|
||||||
|
.abs();
|
||||||
|
assert!((0.2265f64 - result) < 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
@@ -264,7 +364,7 @@ mod tests {
|
|||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
)]
|
)]
|
||||||
#[test]
|
#[test]
|
||||||
fn rbf_kernel() {
|
fn test_rbf_kernel() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
let v2 = vec![4., 5., 6.];
|
let v2 = vec![4., 5., 6.];
|
||||||
|
|
||||||
@@ -287,7 +387,10 @@ mod tests {
|
|||||||
let v2 = vec![4., 5., 6.];
|
let v2 = vec![4., 5., 6.];
|
||||||
|
|
||||||
let result = Kernels::polynomial()
|
let result = Kernels::polynomial()
|
||||||
.with_params(3.0, 0.5, 1.0)
|
.with_gamma(0.5)
|
||||||
|
.with_degree(3.0)
|
||||||
|
.with_coef0(1.0)
|
||||||
|
//.with_params(3.0, 0.5, 1.0)
|
||||||
.apply(&v1, &v2)
|
.apply(&v1, &v2)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.abs();
|
.abs();
|
||||||
@@ -305,7 +408,8 @@ mod tests {
|
|||||||
let v2 = vec![4., 5., 6.];
|
let v2 = vec![4., 5., 6.];
|
||||||
|
|
||||||
let result = Kernels::sigmoid()
|
let result = Kernels::sigmoid()
|
||||||
.with_params(0.01, 0.1)
|
.with_gamma(0.01)
|
||||||
|
.with_coef0(0.1)
|
||||||
.apply(&v1, &v2)
|
.apply(&v1, &v2)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.abs();
|
.abs();
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//! SVC and Grid Search
|
||||||
|
|
||||||
/// SVC search parameters
|
/// SVC search parameters
|
||||||
pub mod svc_params;
|
pub mod svc_params;
|
||||||
/// SVC search parameters
|
/// SVC search parameters
|
||||||
|
|||||||
+282
-101
@@ -1,112 +1,293 @@
|
|||||||
// /// SVR grid search parameters
|
//! # SVR Grid Search Parameters
|
||||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
//!
|
||||||
// #[derive(Debug, Clone)]
|
//! This module provides utilities for defining and iterating over grid search parameter spaces
|
||||||
// pub struct SVRSearchParameters<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
//! for Support Vector Regression (SVR) models in [smartcore](https://github.com/smartcorelib/smartcore).
|
||||||
// /// Epsilon in the epsilon-SVR model.
|
//!
|
||||||
// pub eps: Vec<T>,
|
//! The main struct, [`SVRSearchParameters`], allows users to specify multiple values for each
|
||||||
// /// Regularization parameter.
|
//! SVR hyperparameter (epsilon, regularization parameter C, tolerance, and kernel function).
|
||||||
// pub c: Vec<T>,
|
//! The provided iterator yields all possible combinations (the Cartesian product) of these parameters,
|
||||||
// /// Tolerance for stopping eps.
|
//! enabling exhaustive grid search for hyperparameter tuning.
|
||||||
// pub tol: Vec<T>,
|
//!
|
||||||
// /// The kernel function.
|
//!
|
||||||
// pub kernel: Vec<K>,
|
//! ## Example
|
||||||
// /// Unused parameter.
|
//! ```
|
||||||
// m: PhantomData<M>,
|
//! use smartcore::svm::Kernels;
|
||||||
// }
|
//! use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||||
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
//!
|
||||||
|
//! let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||||
|
//! eps: vec![0.1, 0.2],
|
||||||
|
//! c: vec![1.0, 10.0],
|
||||||
|
//! tol: vec![1e-3],
|
||||||
|
//! kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||||
|
//! m: std::marker::PhantomData,
|
||||||
|
//! };
|
||||||
|
//!
|
||||||
|
//! // for param_set in params.into_iter() {
|
||||||
|
//! // Use param_set (of type svr::SVRParameters) to fit and evaluate your SVR model.
|
||||||
|
//! // }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//!
|
||||||
|
//! ## Note
|
||||||
|
//! This module is intended for use with smartcore version 0.4 or later. The API is not compatible with older versions[1].
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// /// SVR grid search iterator
|
use crate::linalg::basic::arrays::Array2;
|
||||||
// pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
use crate::numbers::basenum::Number;
|
||||||
// svr_search_parameters: SVRSearchParameters<T, M, K>,
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
// current_eps: usize,
|
use crate::numbers::realnum::RealNumber;
|
||||||
// current_c: usize,
|
use crate::svm::{svr, Kernels};
|
||||||
// current_tol: usize,
|
use std::marker::PhantomData;
|
||||||
// current_kernel: usize,
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
/// ## SVR grid search parameters
|
||||||
// for SVRSearchParameters<T, M, K>
|
/// A struct representing a grid of hyperparameters for SVR grid search in smartcore.
|
||||||
// {
|
///
|
||||||
// type Item = SVRParameters<T, M, K>;
|
/// Each field is a vector of possible values for the corresponding SVR hyperparameter.
|
||||||
// type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
/// The [`IntoIterator`] implementation yields every possible combination of these parameters
|
||||||
|
/// as an `svr::SVRParameters` struct, suitable for use in model selection routines.
|
||||||
|
///
|
||||||
|
/// # Type Parameters
|
||||||
|
/// - `T`: Numeric type for parameters (e.g., `f64`)
|
||||||
|
/// - `M`: Matrix type implementing [`Array2<T>`]
|
||||||
|
///
|
||||||
|
/// # Fields
|
||||||
|
/// - `eps`: Vector of epsilon values for the epsilon-insensitive loss in SVR.
|
||||||
|
/// - `c`: Vector of regularization parameters (C) for SVR.
|
||||||
|
/// - `tol`: Vector of tolerance values for the stopping criterion.
|
||||||
|
/// - `kernel`: Vector of kernel function variants (see [`Kernels`]).
|
||||||
|
/// - `m`: Phantom data for the matrix type parameter.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::svm::Kernels;
|
||||||
|
/// use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||||
|
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
///
|
||||||
|
/// let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||||
|
/// eps: vec![0.1, 0.2],
|
||||||
|
/// c: vec![1.0, 10.0],
|
||||||
|
/// tol: vec![1e-3],
|
||||||
|
/// kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||||
|
/// m: std::marker::PhantomData,
|
||||||
|
/// };
|
||||||
|
/// ```
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SVRSearchParameters<T: Number + RealNumber, M: Array2<T>> {
|
||||||
|
/// Epsilon in the epsilon-SVR model.
|
||||||
|
pub eps: Vec<T>,
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub c: Vec<T>,
|
||||||
|
/// Tolerance for stopping eps.
|
||||||
|
pub tol: Vec<T>,
|
||||||
|
/// The kernel function.
|
||||||
|
pub kernel: Vec<Kernels>,
|
||||||
|
/// Unused parameter.
|
||||||
|
pub m: PhantomData<M>,
|
||||||
|
}
|
||||||
|
|
||||||
// fn into_iter(self) -> Self::IntoIter {
|
/// SVR grid search iterator
|
||||||
// SVRSearchParametersIterator {
|
pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Array2<T>> {
|
||||||
// svr_search_parameters: self,
|
svr_search_parameters: SVRSearchParameters<T, M>,
|
||||||
// current_eps: 0,
|
current_eps: usize,
|
||||||
// current_c: 0,
|
current_c: usize,
|
||||||
// current_tol: 0,
|
current_tol: usize,
|
||||||
// current_kernel: 0,
|
current_kernel: usize,
|
||||||
// }
|
}
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> IntoIterator
|
||||||
// for SVRSearchParametersIterator<T, M, K>
|
for SVRSearchParameters<T, M>
|
||||||
// {
|
{
|
||||||
// type Item = SVRParameters<T, M, K>;
|
type Item = svr::SVRParameters<T>;
|
||||||
|
type IntoIter = SVRSearchParametersIterator<T, M>;
|
||||||
|
|
||||||
// fn next(&mut self) -> Option<Self::Item> {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
// if self.current_eps == self.svr_search_parameters.eps.len()
|
SVRSearchParametersIterator {
|
||||||
// && self.current_c == self.svr_search_parameters.c.len()
|
svr_search_parameters: self,
|
||||||
// && self.current_tol == self.svr_search_parameters.tol.len()
|
current_eps: 0,
|
||||||
// && self.current_kernel == self.svr_search_parameters.kernel.len()
|
current_c: 0,
|
||||||
// {
|
current_tol: 0,
|
||||||
// return None;
|
current_kernel: 0,
|
||||||
// }
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// let next = SVRParameters::<T, M, K> {
|
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Iterator
|
||||||
// eps: self.svr_search_parameters.eps[self.current_eps],
|
for SVRSearchParametersIterator<T, M>
|
||||||
// c: self.svr_search_parameters.c[self.current_c],
|
{
|
||||||
// tol: self.svr_search_parameters.tol[self.current_tol],
|
type Item = svr::SVRParameters<T>;
|
||||||
// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
|
||||||
// m: PhantomData,
|
|
||||||
// };
|
|
||||||
|
|
||||||
// if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
// self.current_eps += 1;
|
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||||
// } else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
&& self.current_c == self.svr_search_parameters.c.len()
|
||||||
// self.current_eps = 0;
|
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||||
// self.current_c += 1;
|
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||||
// } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
{
|
||||||
// self.current_eps = 0;
|
return None;
|
||||||
// self.current_c = 0;
|
}
|
||||||
// self.current_tol += 1;
|
|
||||||
// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
|
||||||
// self.current_eps = 0;
|
|
||||||
// self.current_c = 0;
|
|
||||||
// self.current_tol = 0;
|
|
||||||
// self.current_kernel += 1;
|
|
||||||
// } else {
|
|
||||||
// self.current_eps += 1;
|
|
||||||
// self.current_c += 1;
|
|
||||||
// self.current_tol += 1;
|
|
||||||
// self.current_kernel += 1;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Some(next)
|
let next = svr::SVRParameters::<T> {
|
||||||
// }
|
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||||
// }
|
c: self.svr_search_parameters.c[self.current_c],
|
||||||
|
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||||
|
kernel: Some(self.svr_search_parameters.kernel[self.current_kernel].clone()),
|
||||||
|
};
|
||||||
|
|
||||||
// impl<T: Number + RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||||
// fn default() -> Self {
|
self.current_eps += 1;
|
||||||
// let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c += 1;
|
||||||
|
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol += 1;
|
||||||
|
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
} else {
|
||||||
|
self.current_eps += 1;
|
||||||
|
self.current_c += 1;
|
||||||
|
self.current_tol += 1;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
}
|
||||||
|
|
||||||
// SVRSearchParameters {
|
Some(next)
|
||||||
// eps: vec![default_params.eps],
|
}
|
||||||
// c: vec![default_params.c],
|
}
|
||||||
// tol: vec![default_params.tol],
|
|
||||||
// kernel: vec![default_params.kernel],
|
|
||||||
// m: PhantomData,
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Default for SVRSearchParameters<T, M> {
|
||||||
// #[derive(Debug)]
|
fn default() -> Self {
|
||||||
// #[cfg_attr(
|
let default_params: svr::SVRParameters<T> = svr::SVRParameters::default();
|
||||||
// feature = "serde",
|
|
||||||
// serde(bound(
|
SVRSearchParameters {
|
||||||
// serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
eps: vec![default_params.eps],
|
||||||
// deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
c: vec![default_params.c],
|
||||||
// ))
|
tol: vec![default_params.tol],
|
||||||
// )]
|
kernel: vec![default_params.kernel.unwrap_or_else(Kernels::linear)],
|
||||||
|
m: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use crate::svm::Kernels;
|
||||||
|
|
||||||
|
type T = f64;
|
||||||
|
type M = DenseMatrix<T>;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_default_parameters() {
|
||||||
|
let params = SVRSearchParameters::<T, M>::default();
|
||||||
|
assert_eq!(params.eps.len(), 1);
|
||||||
|
assert_eq!(params.c.len(), 1);
|
||||||
|
assert_eq!(params.tol.len(), 1);
|
||||||
|
assert_eq!(params.kernel.len(), 1);
|
||||||
|
// Check that the default kernel is linear
|
||||||
|
assert_eq!(params.kernel[0], Kernels::linear());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_single_grid_iteration() {
|
||||||
|
let params = SVRSearchParameters::<T, M> {
|
||||||
|
eps: vec![0.1],
|
||||||
|
c: vec![1.0],
|
||||||
|
tol: vec![1e-3],
|
||||||
|
kernel: vec![Kernels::rbf().with_gamma(0.5)],
|
||||||
|
m: PhantomData,
|
||||||
|
};
|
||||||
|
let mut iter = params.into_iter();
|
||||||
|
let param = iter.next().unwrap();
|
||||||
|
assert_eq!(param.eps, 0.1);
|
||||||
|
assert_eq!(param.c, 1.0);
|
||||||
|
assert_eq!(param.tol, 1e-3);
|
||||||
|
assert_eq!(param.kernel, Some(Kernels::rbf().with_gamma(0.5)));
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cartesian_grid_iteration() {
|
||||||
|
let params = SVRSearchParameters::<T, M> {
|
||||||
|
eps: vec![0.1, 0.2],
|
||||||
|
c: vec![1.0, 2.0],
|
||||||
|
tol: vec![1e-3],
|
||||||
|
kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||||
|
m: PhantomData,
|
||||||
|
};
|
||||||
|
let expected_count =
|
||||||
|
params.eps.len() * params.c.len() * params.tol.len() * params.kernel.len();
|
||||||
|
let results: Vec<_> = params.into_iter().collect();
|
||||||
|
assert_eq!(results.len(), expected_count);
|
||||||
|
|
||||||
|
// Check that all parameter combinations are present
|
||||||
|
let mut seen = vec![];
|
||||||
|
for p in &results {
|
||||||
|
seen.push((p.eps, p.c, p.tol, p.kernel.clone().unwrap()));
|
||||||
|
}
|
||||||
|
for &eps in &[0.1, 0.2] {
|
||||||
|
for &c in &[1.0, 2.0] {
|
||||||
|
for &tol in &[1e-3] {
|
||||||
|
for kernel in &[Kernels::linear(), Kernels::rbf().with_gamma(0.5)] {
|
||||||
|
assert!(seen.contains(&(eps, c, tol, kernel.clone())));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_grid() {
|
||||||
|
let params = SVRSearchParameters::<T, M> {
|
||||||
|
eps: vec![],
|
||||||
|
c: vec![],
|
||||||
|
tol: vec![],
|
||||||
|
kernel: vec![],
|
||||||
|
m: PhantomData,
|
||||||
|
};
|
||||||
|
let mut iter = params.into_iter();
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_kernel_enum_variants() {
|
||||||
|
let lin = Kernels::linear();
|
||||||
|
let rbf = Kernels::rbf().with_gamma(0.2);
|
||||||
|
let poly = Kernels::polynomial()
|
||||||
|
.with_degree(2.0)
|
||||||
|
.with_gamma(1.0)
|
||||||
|
.with_coef0(0.5);
|
||||||
|
let sig = Kernels::sigmoid().with_gamma(0.3).with_coef0(0.1);
|
||||||
|
|
||||||
|
assert_eq!(lin, Kernels::Linear);
|
||||||
|
match rbf {
|
||||||
|
Kernels::RBF { gamma } => assert_eq!(gamma, Some(0.2)),
|
||||||
|
_ => panic!("Not RBF"),
|
||||||
|
}
|
||||||
|
match poly {
|
||||||
|
Kernels::Polynomial {
|
||||||
|
degree,
|
||||||
|
gamma,
|
||||||
|
coef0,
|
||||||
|
} => {
|
||||||
|
assert_eq!(degree, Some(2.0));
|
||||||
|
assert_eq!(gamma, Some(1.0));
|
||||||
|
assert_eq!(coef0, Some(0.5));
|
||||||
|
}
|
||||||
|
_ => panic!("Not Polynomial"),
|
||||||
|
}
|
||||||
|
match sig {
|
||||||
|
Kernels::Sigmoid { gamma, coef0 } => {
|
||||||
|
assert_eq!(gamma, Some(0.3));
|
||||||
|
assert_eq!(coef0, Some(0.1));
|
||||||
|
}
|
||||||
|
_ => panic!("Not Sigmoid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+26
-25
@@ -51,9 +51,9 @@
|
|||||||
//!
|
//!
|
||||||
//! let knl = Kernels::linear();
|
//! let knl = Kernels::linear();
|
||||||
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
||||||
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
//! let svr = SVR::fit(&x, &y, params).unwrap();
|
||||||
//!
|
//!
|
||||||
//! // let y_hat = svr.predict(&x).unwrap();
|
//! let y_hat = svr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -80,11 +80,12 @@ use crate::error::{Failed, FailedError};
|
|||||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::numbers::floatnum::FloatNumber;
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
use crate::svm::Kernel;
|
|
||||||
|
|
||||||
|
use crate::svm::{Kernel, Kernels};
|
||||||
|
|
||||||
|
/// SVR Parameters
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
/// SVR Parameters
|
|
||||||
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||||
/// Epsilon in the epsilon-SVR model.
|
/// Epsilon in the epsilon-SVR model.
|
||||||
pub eps: T,
|
pub eps: T,
|
||||||
@@ -97,7 +98,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
|||||||
all(feature = "serde", target_arch = "wasm32"),
|
all(feature = "serde", target_arch = "wasm32"),
|
||||||
serde(skip_serializing, skip_deserializing)
|
serde(skip_serializing, skip_deserializing)
|
||||||
)]
|
)]
|
||||||
pub kernel: Option<Box<dyn Kernel>>,
|
pub kernel: Option<Kernels>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -160,8 +161,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
pub fn with_kernel(mut self, kernel: Kernels) -> Self {
|
||||||
self.kernel = Some(Box::new(kernel));
|
self.kernel = Some(kernel);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -597,25 +598,25 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::metrics::mean_squared_error;
|
use crate::metrics::mean_squared_error;
|
||||||
|
use crate::svm::search::svr_params::SVRSearchParameters;
|
||||||
use crate::svm::Kernels;
|
use crate::svm::Kernels;
|
||||||
|
|
||||||
// #[test]
|
#[test]
|
||||||
// fn search_parameters() {
|
fn search_parameters() {
|
||||||
// let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
|
||||||
// SVRSearchParameters {
|
eps: vec![0., 1.],
|
||||||
// eps: vec![0., 1.],
|
kernel: vec![Kernels::linear()],
|
||||||
// kernel: vec![LinearKernel {}],
|
..Default::default()
|
||||||
// ..Default::default()
|
};
|
||||||
// };
|
let mut iter = parameters.into_iter();
|
||||||
// let mut iter = parameters.into_iter();
|
let next = iter.next().unwrap();
|
||||||
// let next = iter.next().unwrap();
|
assert_eq!(next.eps, 0.);
|
||||||
// assert_eq!(next.eps, 0.);
|
// assert_eq!(next.kernel, LinearKernel {});
|
||||||
// assert_eq!(next.kernel, LinearKernel {});
|
// let next = iter.next().unwrap();
|
||||||
// let next = iter.next().unwrap();
|
// assert_eq!(next.eps, 1.);
|
||||||
// assert_eq!(next.eps, 1.);
|
// assert_eq!(next.kernel, LinearKernel {});
|
||||||
// assert_eq!(next.kernel, LinearKernel {});
|
// assert!(iter.next().is_none());
|
||||||
// assert!(iter.next().is_none());
|
}
|
||||||
// }
|
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
@@ -648,7 +649,7 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
let knl = Kernels::linear();
|
let knl: Kernels = Kernels::linear();
|
||||||
let y_hat = SVR::fit(
|
let y_hat = SVR::fit(
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
|
|||||||
Reference in New Issue
Block a user