grid search (#154)
* grid search draft * hyperparam search for linear estimators
This commit is contained in:
@@ -135,6 +135,121 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// ElasticNet grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetSearchParameters<T: RealNumber> {
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: Vec<T>,
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<T>,
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// ElasticNet grid search iterator
|
||||
pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
|
||||
lasso_regression_search_parameters: ElasticNetSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_l1_ratio: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
|
||||
type Item = ElasticNetParameters<T>;
|
||||
type IntoIter = ElasticNetSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
ElasticNetSearchParametersIterator {
|
||||
lasso_regression_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_l1_ratio: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
|
||||
type Item = ElasticNetParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
||||
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = ElasticNetParameters {
|
||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
||||
l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
|
||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.lasso_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_l1_ratio += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = ElasticNetParameters::default();
|
||||
|
||||
ElasticNetSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
l1_ratio: vec![default_params.l1_ratio],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
@@ -291,6 +406,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = ElasticNetSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn elasticnet_longley() {
|
||||
|
||||
@@ -112,6 +112,105 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Lasso grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoSearchParameters<T: RealNumber> {
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: Vec<T>,
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<T>,
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Lasso grid search iterator
|
||||
pub struct LassoSearchParametersIterator<T: RealNumber> {
|
||||
lasso_regression_search_parameters: LassoSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
|
||||
type Item = LassoParameters<T>;
|
||||
type IntoIter = LassoSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LassoSearchParametersIterator {
|
||||
lasso_regression_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
|
||||
type Item = LassoParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LassoParameters {
|
||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.lasso_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LassoSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = LassoParameters::default();
|
||||
|
||||
LassoSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
/// Fits Lasso regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -226,6 +325,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LassoSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
|
||||
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionSearchParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LinearRegressionSolverName>,
|
||||
}
|
||||
|
||||
/// Linear Regression grid search iterator
|
||||
pub struct LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: LinearRegressionSearchParameters,
|
||||
current_solver: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LinearRegressionSearchParameters {
|
||||
type Item = LinearRegressionParameters;
|
||||
type IntoIter = LinearRegressionSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for LinearRegressionSearchParametersIterator {
|
||||
type Item = LinearRegressionParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LinearRegressionParameters {
|
||||
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
};
|
||||
|
||||
self.current_solver += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = LinearRegressionParameters::default();
|
||||
|
||||
LinearRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
@@ -200,6 +254,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LinearRegressionSearchParameters {
|
||||
solver: vec![
|
||||
LinearRegressionSolverName::QR,
|
||||
LinearRegressionSolverName::SVD,
|
||||
],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
|
||||
@@ -68,7 +68,7 @@ use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||
pub enum LogisticRegressionSolverName {
|
||||
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||
@@ -85,6 +85,77 @@ pub struct LogisticRegressionParameters<T: RealNumber> {
|
||||
pub alpha: T,
|
||||
}
|
||||
|
||||
/// Logistic Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionSearchParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LogisticRegressionSolverName>,
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
}
|
||||
|
||||
/// Logistic Regression grid search iterator
|
||||
pub struct LogisticRegressionSearchParametersIterator<T: RealNumber> {
|
||||
logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
|
||||
current_solver: usize,
|
||||
current_alpha: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
|
||||
type Item = LogisticRegressionParameters<T>;
|
||||
type IntoIter = LogisticRegressionSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LogisticRegressionSearchParametersIterator {
|
||||
logistic_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
current_alpha: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
|
||||
type Item = LogisticRegressionParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
|
||||
&& self.current_solver == self.logistic_regression_search_parameters.solver.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LogisticRegressionParameters {
|
||||
solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_solver += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_solver += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LogisticRegressionSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = LogisticRegressionParameters::default();
|
||||
|
||||
LogisticRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
alpha: vec![default_params.alpha],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
@@ -452,6 +523,21 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::accuracy;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LogisticRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().alpha, 0.);
|
||||
assert_eq!(
|
||||
iter.next().unwrap().solver,
|
||||
LogisticRegressionSolverName::LBFGS
|
||||
);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn multiclass_objective_f() {
|
||||
|
||||
@@ -68,7 +68,7 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
@@ -90,6 +90,90 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionSearchParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<RidgeRegressionSolverName>,
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search iterator
|
||||
pub struct RidgeRegressionSearchParametersIterator<T: RealNumber> {
|
||||
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
|
||||
current_solver: usize,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RidgeRegressionSearchParametersIterator {
|
||||
ridge_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
|
||||
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RidgeRegressionParameters {
|
||||
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_solver += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.ridge_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_solver = 0;
|
||||
self.current_normalize += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_solver += 1;
|
||||
self.current_normalize += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for RidgeRegressionSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = RidgeRegressionParameters::default();
|
||||
|
||||
RidgeRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ridge regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
@@ -274,6 +358,21 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RidgeRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().alpha, 0.);
|
||||
assert_eq!(
|
||||
iter.next().unwrap().solver,
|
||||
RidgeRegressionSolverName::Cholesky
|
||||
);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ridge_fit_predict() {
|
||||
|
||||
Reference in New Issue
Block a user