diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs
index ce13435..0e9cb57 100644
--- a/src/linear/elastic_net.rs
+++ b/src/linear/elastic_net.rs
@@ -135,6 +135,121 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
     }
 }

+/// ElasticNet grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct ElasticNetSearchParameters<T: RealNumber> {
+    /// Regularization parameter.
+    pub alpha: Vec<T>,
+    /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
+    /// For l1_ratio = 0 the penalty is an L2 penalty.
+    /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
+    pub l1_ratio: Vec<T>,
+    /// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+    /// The tolerance for the optimization
+    pub tol: Vec<T>,
+    /// The maximum number of iterations
+    pub max_iter: Vec<usize>,
+}
+
+/// ElasticNet grid search iterator
+pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
+    lasso_regression_search_parameters: ElasticNetSearchParameters<T>,
+    current_alpha: usize,
+    current_l1_ratio: usize,
+    current_normalize: usize,
+    current_tol: usize,
+    current_max_iter: usize,
+}
+
+impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
+    type Item = ElasticNetParameters<T>;
+    type IntoIter = ElasticNetSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        ElasticNetSearchParametersIterator {
+            lasso_regression_search_parameters: self,
+            current_alpha: 0,
+            current_l1_ratio: 0,
+            current_normalize: 0,
+            current_tol: 0,
+            current_max_iter: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
+    type Item = ElasticNetParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
+            && self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
+            && self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
+            && self.current_tol == self.lasso_regression_search_parameters.tol.len()
+            && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
+        {
+            return None;
+        }
+
+        let next = ElasticNetParameters {
+            alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
+            l1_ratio: self.lasso_regression_search_parameters.l1_ratio[self.current_l1_ratio],
+            normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
+            tol: self.lasso_regression_search_parameters.tol[self.current_tol],
+            max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
+        };
+
+        if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio += 1;
+        } else if self.current_normalize + 1
+            < self.lasso_regression_search_parameters.normalize.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize += 1;
+        } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol += 1;
+        } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol = 0;
+            self.current_max_iter += 1;
+        } else {
+            self.current_alpha += 1;
+            self.current_l1_ratio += 1;
+            self.current_normalize += 1;
+            self.current_tol += 1;
+            self.current_max_iter += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = ElasticNetParameters::default();
+
+        ElasticNetSearchParameters {
+            alpha: vec![default_params.alpha],
+            l1_ratio: vec![default_params.l1_ratio],
+            normalize: vec![default_params.normalize],
+            tol: vec![default_params.tol],
+            max_iter: vec![default_params.max_iter],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
     fn eq(&self, other: &Self) -> bool {
         self.coefficients == other.coefficients
@@ -291,6 +406,29 @@ mod tests {
     use crate::linalg::naive::dense_matrix::*;
     use crate::metrics::mean_absolute_error;

+    #[test]
+    fn search_parameters() {
+        let parameters = ElasticNetSearchParameters {
+            alpha: vec![0., 1.],
+            max_iter: vec![10, 100],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 100);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 100);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn elasticnet_longley() {
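Reviewer note: once this patch is applied, the grid above enumerates the full Cartesian product of the axes, with `alpha` as the fastest-moving one. A minimal usage sketch, not part of the patch; the `smartcore` paths assume the crate layout shown in the diff:

```rust
// Sketch: enumerating an ElasticNet parameter grid (assumes this patch is applied).
use smartcore::linear::elastic_net::ElasticNetSearchParameters;

fn main() {
    let grid = ElasticNetSearchParameters {
        alpha: vec![0.1, 1.0],
        l1_ratio: vec![0.25, 0.75],
        ..Default::default() // single default values for normalize, tol, max_iter
    };
    // alpha varies fastest: (0.1, 0.25), (1.0, 0.25), (0.1, 0.75), (1.0, 0.75)
    for params in grid {
        println!("alpha={}, l1_ratio={}", params.alpha, params.l1_ratio);
    }
}
```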
diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs
index 7edd325..7e80a8b 100644
--- a/src/linear/lasso.rs
+++ b/src/linear/lasso.rs
@@ -112,6 +112,105 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
     }
 }

+/// Lasso grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct LassoSearchParameters<T: RealNumber> {
+    /// Controls the strength of the penalty to the loss function.
+    pub alpha: Vec<T>,
+    /// If true the regressors X will be normalized before regression
+    /// by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+    /// The tolerance for the optimization
+    pub tol: Vec<T>,
+    /// The maximum number of iterations
+    pub max_iter: Vec<usize>,
+}
+
+/// Lasso grid search iterator
+pub struct LassoSearchParametersIterator<T: RealNumber> {
+    lasso_regression_search_parameters: LassoSearchParameters<T>,
+    current_alpha: usize,
+    current_normalize: usize,
+    current_tol: usize,
+    current_max_iter: usize,
+}
+
+impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
+    type Item = LassoParameters<T>;
+    type IntoIter = LassoSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LassoSearchParametersIterator {
+            lasso_regression_search_parameters: self,
+            current_alpha: 0,
+            current_normalize: 0,
+            current_tol: 0,
+            current_max_iter: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
+    type Item = LassoParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
+            && self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
+            && self.current_tol == self.lasso_regression_search_parameters.tol.len()
+            && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
+        {
+            return None;
+        }
+
+        let next = LassoParameters {
+            alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
+            normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
+            tol: self.lasso_regression_search_parameters.tol[self.current_tol],
+            max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
+        };
+
+        if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_normalize + 1
+            < self.lasso_regression_search_parameters.normalize.len()
+        {
+            self.current_alpha = 0;
+            self.current_normalize += 1;
+        } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
+            self.current_alpha = 0;
+            self.current_normalize = 0;
+            self.current_tol += 1;
+        } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
+        {
+            self.current_alpha = 0;
+            self.current_normalize = 0;
+            self.current_tol = 0;
+            self.current_max_iter += 1;
+        } else {
+            self.current_alpha += 1;
+            self.current_normalize += 1;
+            self.current_tol += 1;
+            self.current_max_iter += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for LassoSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = LassoParameters::default();
+
+        LassoSearchParameters {
+            alpha: vec![default_params.alpha],
+            normalize: vec![default_params.normalize],
+            tol: vec![default_params.tol],
+            max_iter: vec![default_params.max_iter],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
     /// Fits Lasso regression to your data.
     /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -226,6 +325,29 @@ mod tests {
     use crate::linalg::naive::dense_matrix::*;
     use crate::metrics::mean_absolute_error;

+    #[test]
+    fn search_parameters() {
+        let parameters = LassoSearchParameters {
+            alpha: vec![0., 1.],
+            max_iter: vec![10, 100],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 10);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 0.);
+        assert_eq!(next.max_iter, 100);
+        let next = iter.next().unwrap();
+        assert_eq!(next.alpha, 1.);
+        assert_eq!(next.max_iter, 100);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn lasso_fit_predict() {
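The `next` implementations in this patch are hand-rolled mixed-radix counters: the first axis advances on every call, and when an axis wraps it resets all axes before it and carries into the next. The same traversal, reduced to two index axes as a standalone illustration (hypothetical helper, not part of the patch):

```rust
// Same traversal as LassoSearchParametersIterator, reduced to two index axes.
fn cartesian_indices(len_a: usize, len_b: usize) -> Vec<(usize, usize)> {
    let (mut a, mut b) = (0, 0);
    let mut out = Vec::new();
    while a < len_a && b < len_b {
        out.push((a, b));
        if a + 1 < len_a {
            a += 1; // fast axis ticks
        } else {
            a = 0; // wrap and carry into the slow axis
            b += 1;
        }
    }
    out
}

fn main() {
    // A 2 x 2 grid, in the same order the Lasso iterator produces:
    assert_eq!(
        cartesian_indices(2, 2),
        vec![(0, 0), (1, 0), (0, 1), (1, 1)]
    );
}
```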
diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs
index b1f7c51..c95e6e1 100644
--- a/src/linear/linear_regression.rs
+++ b/src/linear/linear_regression.rs
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
 use crate::math::num::RealNumber;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
 pub enum LinearRegressionSolverName {
     /// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
     }
 }

+/// Linear Regression grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct LinearRegressionSearchParameters {
+    /// Solver to use for estimation of regression coefficients.
+    pub solver: Vec<LinearRegressionSolverName>,
+}
+
+/// Linear Regression grid search iterator
+pub struct LinearRegressionSearchParametersIterator {
+    linear_regression_search_parameters: LinearRegressionSearchParameters,
+    current_solver: usize,
+}
+
+impl IntoIterator for LinearRegressionSearchParameters {
+    type Item = LinearRegressionParameters;
+    type IntoIter = LinearRegressionSearchParametersIterator;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LinearRegressionSearchParametersIterator {
+            linear_regression_search_parameters: self,
+            current_solver: 0,
+        }
+    }
+}
+
+impl Iterator for LinearRegressionSearchParametersIterator {
+    type Item = LinearRegressionParameters;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_solver == self.linear_regression_search_parameters.solver.len() {
+            return None;
+        }
+
+        let next = LinearRegressionParameters {
+            solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
+        };
+
+        self.current_solver += 1;
+
+        Some(next)
+    }
+}
+
+impl Default for LinearRegressionSearchParameters {
+    fn default() -> Self {
+        let default_params = LinearRegressionParameters::default();
+
+        LinearRegressionSearchParameters {
+            solver: vec![default_params.solver],
+        }
+    }
+}
+
 impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
     fn eq(&self, other: &Self) -> bool {
         self.coefficients == other.coefficients
@@ -200,6 +254,20 @@ mod tests {
     use super::*;
     use crate::linalg::naive::dense_matrix::*;

+    #[test]
+    fn search_parameters() {
+        let parameters = LinearRegressionSearchParameters {
+            solver: vec![
+                LinearRegressionSolverName::QR,
+                LinearRegressionSolverName::SVD,
+            ],
+        };
+        let mut iter = parameters.into_iter();
+        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
+        assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn ols_fit_predict() {
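With a single `solver` axis, the linear-regression iterator degenerates to a plain walk over the vector. A quick sketch assuming the patch is applied:

```rust
use smartcore::linear::linear_regression::{
    LinearRegressionSearchParameters, LinearRegressionSolverName,
};

fn main() {
    let grid = LinearRegressionSearchParameters {
        solver: vec![
            LinearRegressionSolverName::QR,
            LinearRegressionSolverName::SVD,
        ],
    };
    // One candidate per solver, in declaration order.
    assert_eq!(grid.into_iter().count(), 2);
}
```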
diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs
index 1a20077..3a4c706 100644
--- a/src/linear/logistic_regression.rs
+++ b/src/linear/logistic_regression.rs
@@ -68,7 +68,7 @@ use crate::optimization::line_search::Backtracking;
 use crate::optimization::FunctionOrder;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Solver options for Logistic regression. Right now only LBFGS solver is supported.
 pub enum LogisticRegressionSolverName {
     /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
@@ -85,6 +85,77 @@ pub struct LogisticRegressionParameters<T: RealNumber> {
     pub alpha: T,
 }

+/// Logistic Regression grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct LogisticRegressionSearchParameters<T: RealNumber> {
+    /// Solver to use for estimation of regression coefficients.
+    pub solver: Vec<LogisticRegressionSolverName>,
+    /// Regularization parameter.
+    pub alpha: Vec<T>,
+}
+
+/// Logistic Regression grid search iterator
+pub struct LogisticRegressionSearchParametersIterator<T: RealNumber> {
+    logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
+    current_solver: usize,
+    current_alpha: usize,
+}
+
+impl<T: RealNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
+    type Item = LogisticRegressionParameters<T>;
+    type IntoIter = LogisticRegressionSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        LogisticRegressionSearchParametersIterator {
+            logistic_regression_search_parameters: self,
+            current_solver: 0,
+            current_alpha: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
+    type Item = LogisticRegressionParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
+            && self.current_solver == self.logistic_regression_search_parameters.solver.len()
+        {
+            return None;
+        }
+
+        let next = LogisticRegressionParameters {
+            solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
+            alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
+        };
+
+        if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
+        {
+            self.current_alpha = 0;
+            self.current_solver += 1;
+        } else {
+            self.current_alpha += 1;
+            self.current_solver += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for LogisticRegressionSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = LogisticRegressionParameters::default();
+
+        LogisticRegressionSearchParameters {
+            solver: vec![default_params.solver],
+            alpha: vec![default_params.alpha],
+        }
+    }
+}
+
 /// Logistic Regression
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug)]
@@ -452,6 +523,21 @@ mod tests {
     use crate::linalg::naive::dense_matrix::*;
     use crate::metrics::accuracy;

+    #[test]
+    fn search_parameters() {
+        let parameters = LogisticRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        assert_eq!(iter.next().unwrap().alpha, 0.);
+        assert_eq!(
+            iter.next().unwrap().solver,
+            LogisticRegressionSolverName::LBFGS
+        );
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn multiclass_objective_f() {
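The logistic grid yields `solver.len() * alpha.len()` candidates, with `alpha` varying fastest; a sketch assuming the patch is applied:

```rust
use smartcore::linear::logistic_regression::{
    LogisticRegressionSearchParameters, LogisticRegressionSolverName,
};

fn main() {
    let grid = LogisticRegressionSearchParameters {
        solver: vec![LogisticRegressionSolverName::LBFGS],
        alpha: vec![0.0, 0.5, 1.0],
    };
    // 1 solver x 3 alphas = 3 candidates; alpha is the fast axis.
    let candidates: Vec<_> = grid.into_iter().collect();
    assert_eq!(candidates.len(), 3);
    assert_eq!(candidates[1].alpha, 0.5);
}
```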
diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs
index ecad250..4c3d4ff 100644
--- a/src/linear/ridge_regression.rs
+++ b/src/linear/ridge_regression.rs
@@ -68,7 +68,7 @@ use crate::linalg::Matrix;
 use crate::math::num::RealNumber;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
 pub enum RidgeRegressionSolverName {
     /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
@@ -90,6 +90,90 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
     pub normalize: bool,
 }

+/// Ridge Regression grid search parameters
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+pub struct RidgeRegressionSearchParameters<T: RealNumber> {
+    /// Solver to use for estimation of regression coefficients.
+    pub solver: Vec<RidgeRegressionSolverName>,
+    /// Regularization parameter.
+    pub alpha: Vec<T>,
+    /// If true the regressors X will be normalized before regression
+    /// by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+}
+
+/// Ridge Regression grid search iterator
+pub struct RidgeRegressionSearchParametersIterator<T: RealNumber> {
+    ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
+    current_solver: usize,
+    current_alpha: usize,
+    current_normalize: usize,
+}
+
+impl<T: RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
+    type Item = RidgeRegressionParameters<T>;
+    type IntoIter = RidgeRegressionSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        RidgeRegressionSearchParametersIterator {
+            ridge_regression_search_parameters: self,
+            current_solver: 0,
+            current_alpha: 0,
+            current_normalize: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
+    type Item = RidgeRegressionParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
+            && self.current_solver == self.ridge_regression_search_parameters.solver.len()
+        {
+            return None;
+        }
+
+        let next = RidgeRegressionParameters {
+            solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
+            alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
+            normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
+        };
+
+        if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
+            self.current_alpha = 0;
+            self.current_solver += 1;
+        } else if self.current_normalize + 1
+            < self.ridge_regression_search_parameters.normalize.len()
+        {
+            self.current_alpha = 0;
+            self.current_solver = 0;
+            self.current_normalize += 1;
+        } else {
+            self.current_alpha += 1;
+            self.current_solver += 1;
+            self.current_normalize += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for RidgeRegressionSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = RidgeRegressionParameters::default();
+
+        RidgeRegressionSearchParameters {
+            solver: vec![default_params.solver],
+            alpha: vec![default_params.alpha],
+            normalize: vec![default_params.normalize],
+        }
+    }
+}
+
 /// Ridge regression
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug)]
@@ -274,6 +358,21 @@ mod tests {
     use crate::linalg::naive::dense_matrix::*;
     use crate::metrics::mean_absolute_error;

+    #[test]
+    fn search_parameters() {
+        let parameters = RidgeRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
+        let mut iter = parameters.into_iter();
+        assert_eq!(iter.next().unwrap().alpha, 0.);
+        assert_eq!(
+            iter.next().unwrap().solver,
+            RidgeRegressionSolverName::Cholesky
+        );
+        assert!(iter.next().is_none());
+    }
+
     #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
     #[test]
     fn ridge_fit_predict() {
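Likewise for ridge, the candidate count is the product of all three axis lengths; a sketch assuming the patch is applied:

```rust
use smartcore::linear::ridge_regression::RidgeRegressionSearchParameters;

fn main() {
    let grid = RidgeRegressionSearchParameters {
        alpha: vec![0.1, 1.0, 10.0],
        normalize: vec![true, false],
        ..Default::default() // default solver axis: [Cholesky]
    };
    // 3 alphas x 1 solver x 2 normalize flags = 6 candidates.
    assert_eq!(grid.into_iter().count(), 6);
}
```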
diff --git a/src/model_selection/hyper_tuning.rs b/src/model_selection/hyper_tuning.rs
new file mode 100644
index 0000000..3093fbd
--- /dev/null
+++ b/src/model_selection/hyper_tuning.rs
@@ -0,0 +1,127 @@
+use crate::api::Predictor;
+use crate::error::{Failed, FailedError};
+use crate::linalg::Matrix;
+use crate::math::num::RealNumber;
+use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+
+/// Grid search results.
+#[derive(Clone, Debug)]
+pub struct GridSearchResult<T: RealNumber, I> {
+    /// Cross-validation scores of the best estimator on each split
+    pub cross_validation_result: CrossValidationResult<T>,
+    /// Parameters of the best estimator
+    pub parameters: I,
+}
+
+/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
+/// * `fit_estimator` - a `fit` function of an estimator
+/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
+/// * `y` - target values, should be of size _N_
+/// * `parameter_search` - an iterator for parameters that will be tested.
+/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
+/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
+pub fn grid_search<T, M, I, E, K, F, S>(
+    fit_estimator: F,
+    x: &M,
+    y: &M::RowVector,
+    parameter_search: I,
+    cv: K,
+    score: S,
+) -> Result<GridSearchResult<T, I::Item>, Failed>
+where
+    T: RealNumber,
+    M: Matrix<T>,
+    I: Iterator,
+    I::Item: Clone,
+    E: Predictor<M, M::RowVector>,
+    K: BaseKFold,
+    F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
+    S: Fn(&M::RowVector, &M::RowVector) -> T,
+{
+    let mut best_result: Option<CrossValidationResult<T>> = None;
+    let mut best_parameters = None;
+
+    for parameters in parameter_search {
+        let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
+        if best_result.is_none()
+            || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+        {
+            best_parameters = Some(parameters);
+            best_result = Some(result);
+        }
+    }
+
+    if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
+        Ok(GridSearchResult {
+            cross_validation_result,
+            parameters,
+        })
+    } else {
+        Err(Failed::because(
+            FailedError::FindFailed,
+            "there were no parameter sets found",
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::naive::dense_matrix::*;
+    use crate::linear::logistic_regression::{
+        LogisticRegression, LogisticRegressionSearchParameters,
+    };
+    use crate::metrics::accuracy;
+    use crate::model_selection::kfold::KFold;
+
+    #[test]
+    fn test_grid_search() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+        let y = vec![
+            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+        ];
+
+        let cv = KFold {
+            n_splits: 5,
+            ..KFold::default()
+        };
+
+        let parameters = LogisticRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
+
+        let results = grid_search(
+            LogisticRegression::fit,
+            &x,
+            &y,
+            parameters.into_iter(),
+            cv,
+            &accuracy,
+        )
+        .unwrap();
+
+        assert!([0., 1.].contains(&results.parameters.alpha));
+    }
+}
\ No newline at end of file
diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs
index d283176..68f0635 100644
--- a/src/model_selection/mod.rs
+++ b/src/model_selection/mod.rs
@@ -91,8 +91,8 @@
 //!
 //! let results = cross_validate(LogisticRegression::fit,   //estimator
 //!                              &x, &y,                     //data
-//!                              Default::default(),         //hyperparameters
-//!                              cv,                         //cross validation split
+//!                              &Default::default(),        //hyperparameters
+//!                              &cv,                        //cross validation split
 //!                              &accuracy).unwrap();        //metric
 //!
 //! println!("Training accuracy: {}, test accuracy: {}",
@@ -201,8 +201,8 @@ pub fn cross_validate<T, M, H, E, K, F, S>(
     fit_estimator: F,
     x: &M,
     y: &M::RowVector,
-    parameters: H,
-    cv: K,
+    parameters: &H,
+    cv: &K,
     score: S,
 ) -> Result<CrossValidationResult<T>, Failed>
 where
@@ -281,6 +281,7 @@ mod tests {

     use super::*;
     use crate::linalg::naive::dense_matrix::*;
+    use crate::metrics::{accuracy, mean_absolute_error};
     use crate::model_selection::kfold::KFold;
     use crate::neighbors::knn_regressor::KNNRegressor;

@@ -362,8 +363,15 @@
             ..KFold::default()
         };

-        let results =
-            cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap();
+        let results = cross_validate(
+            BiasedEstimator::fit,
+            &x,
+            &y,
+            &NoParameters {},
+            &cv,
+            &accuracy,
+        )
+        .unwrap();

         assert_eq!(0.4, results.mean_test_score());
         assert_eq!(0.4, results.mean_train_score());
@@ -404,8 +412,8 @@
             KNNRegressor::fit,
             &x,
             &y,
-            Default::default(),
-            cv,
+            &Default::default(),
+            &cv,
             &mean_absolute_error,
         )
         .unwrap();
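Two closing reviewer notes. First, the mod.rs hunks above do not declare `mod hyper_tuning;`, so the new module still needs to be wired into `model_selection` and re-exported. Second, `grid_search` keeps the candidate with the highest `mean_test_score()`, which fits higher-is-better metrics such as `accuracy` or `r2`; with an error metric like `mean_absolute_error` it would select the worst candidate unless the score is negated. A sketch of such a wrapper, assuming `smartcore::metrics::mean_absolute_error` as used elsewhere in this patch and `Vec<f64>` targets:

```rust
// Sketch: grid_search maximizes the score, so a lower-is-better error metric
// must be negated before being passed as the `score` argument.
use smartcore::metrics::mean_absolute_error;

fn neg_mae(y_true: &Vec<f64>, y_pred: &Vec<f64>) -> f64 {
    -mean_absolute_error(y_true, y_pred)
}

fn main() {
    let y_true = vec![3.0, -0.5, 2.0, 7.0];
    let y_pred = vec![2.5, 0.0, 2.0, 8.0];
    // MAE is 0.5, so the negated score is -0.5; larger is now better.
    assert_eq!(neg_mae(&y_true, &y_pred), -0.5);
}
```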