grid search (#154)

* grid search draft
* hyperparam search for linear estimators
Montana Low
2022-09-19 02:31:56 -07:00
committed by morenol
parent 0f442e96c0
commit 1f2597be74
7 changed files with 649 additions and 11 deletions
+138
@@ -135,6 +135,121 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
}
}
/// ElasticNet grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters<T: RealNumber> {
/// Regularization parameter.
pub alpha: Vec<T>,
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: Vec<T>,
/// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
/// The tolerance for the optimization
pub tol: Vec<T>,
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
/// ElasticNet grid search iterator
pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
elastic_net_search_parameters: ElasticNetSearchParameters<T>,
current_alpha: usize,
current_l1_ratio: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}
impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
type Item = ElasticNetParameters<T>;
type IntoIter = ElasticNetSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
ElasticNetSearchParametersIterator {
elastic_net_search_parameters: self,
current_alpha: 0,
current_l1_ratio: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}
impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
type Item = ElasticNetParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.elastic_net_search_parameters.alpha.len()
&& self.current_l1_ratio == self.elastic_net_search_parameters.l1_ratio.len()
&& self.current_normalize == self.elastic_net_search_parameters.normalize.len()
&& self.current_tol == self.elastic_net_search_parameters.tol.len()
&& self.current_max_iter == self.elastic_net_search_parameters.max_iter.len()
{
return None;
}
let next = ElasticNetParameters {
alpha: self.elastic_net_search_parameters.alpha[self.current_alpha],
l1_ratio: self.elastic_net_search_parameters.l1_ratio[self.current_l1_ratio],
normalize: self.elastic_net_search_parameters.normalize[self.current_normalize],
tol: self.elastic_net_search_parameters.tol[self.current_tol],
max_iter: self.elastic_net_search_parameters.max_iter[self.current_max_iter],
};
if self.current_alpha + 1 < self.elastic_net_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_l1_ratio + 1 < self.elastic_net_search_parameters.l1_ratio.len() {
self.current_alpha = 0;
self.current_l1_ratio += 1;
} else if self.current_normalize + 1 < self.elastic_net_search_parameters.normalize.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.elastic_net_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.elastic_net_search_parameters.max_iter.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
self.current_alpha += 1;
self.current_l1_ratio += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
fn default() -> Self {
let default_params = ElasticNetParameters::default();
ElasticNetSearchParameters {
alpha: vec![default_params.alpha],
l1_ratio: vec![default_params.l1_ratio],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
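A minimal usage sketch, not part of the diff (the `grid` binding is hypothetical): the iterator walks the grid like an odometer, with `alpha` varying fastest and `max_iter` slowest, which is the order the `search_parameters` test below asserts.

let grid = ElasticNetSearchParameters {
    alpha: vec![0.1, 1.0],
    l1_ratio: vec![0.5],
    ..Default::default()
};
// The remaining fields keep their single default values, so this
// expands to exactly two ElasticNetParameters: alpha = 0.1, then 1.0.
for params in grid.into_iter() {
    println!("alpha = {}, l1_ratio = {}", params.alpha, params.l1_ratio);
}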
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
@@ -291,6 +406,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = ElasticNetSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn elasticnet_longley() {
+122
@@ -112,6 +112,105 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
}
}
/// Lasso grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoSearchParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function.
pub alpha: Vec<T>,
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
/// The tolerance for the optimization
pub tol: Vec<T>,
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
/// Lasso grid search iterator
pub struct LassoSearchParametersIterator<T: RealNumber> {
lasso_regression_search_parameters: LassoSearchParameters<T>,
current_alpha: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}
impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
type Item = LassoParameters<T>;
type IntoIter = LassoSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator {
lasso_regression_search_parameters: self,
current_alpha: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}
impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
type Item = LassoParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
{
return None;
}
let next = LassoParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
};
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for LassoSearchParameters<T> {
fn default() -> Self {
let default_params = LassoParameters::default();
LassoSearchParameters {
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
/// Fits Lasso regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -226,6 +325,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn lasso_fit_predict() {
+69 -1
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
}
}
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: LinearRegressionSearchParameters,
current_solver: usize,
}
impl IntoIterator for LinearRegressionSearchParameters {
type Item = LinearRegressionParameters;
type IntoIter = LinearRegressionSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: self,
current_solver: 0,
}
}
}
impl Iterator for LinearRegressionSearchParametersIterator {
type Item = LinearRegressionParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
return None;
}
let next = LinearRegressionParameters {
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
};
self.current_solver += 1;
Some(next)
}
}
impl Default for LinearRegressionSearchParameters {
fn default() -> Self {
let default_params = LinearRegressionParameters::default();
LinearRegressionSearchParameters {
solver: vec![default_params.solver],
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
@@ -200,6 +254,20 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
let parameters = LinearRegressionSearchParameters {
solver: vec![
LinearRegressionSolverName::QR,
LinearRegressionSolverName::SVD,
],
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ols_fit_predict() {
+87 -1
@@ -68,7 +68,7 @@ use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
pub enum LogisticRegressionSolverName {
/// Limited-memory Broyden-Fletcher-Goldfarb-Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
@@ -85,6 +85,77 @@ pub struct LogisticRegressionParameters<T: RealNumber> {
pub alpha: T,
}
/// Logistic Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionSearchParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LogisticRegressionSolverName>,
/// Regularization parameter.
pub alpha: Vec<T>,
}
/// Logistic Regression grid search iterator
pub struct LogisticRegressionSearchParametersIterator<T: RealNumber> {
logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
current_solver: usize,
current_alpha: usize,
}
impl<T: RealNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
type Item = LogisticRegressionParameters<T>;
type IntoIter = LogisticRegressionSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
LogisticRegressionSearchParametersIterator {
logistic_regression_search_parameters: self,
current_solver: 0,
current_alpha: 0,
}
}
}
impl<T: RealNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
type Item = LogisticRegressionParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
&& self.current_solver == self.logistic_regression_search_parameters.solver.len()
{
return None;
}
let next = LogisticRegressionParameters {
solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
};
if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
{
self.current_alpha = 0;
self.current_solver += 1;
} else {
self.current_alpha += 1;
self.current_solver += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for LogisticRegressionSearchParameters<T> {
fn default() -> Self {
let default_params = LogisticRegressionParameters::default();
LogisticRegressionSearchParameters {
solver: vec![default_params.solver],
alpha: vec![default_params.alpha],
}
}
}
/// Logistic Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
@@ -452,6 +523,21 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::accuracy;
#[test]
fn search_parameters() {
let parameters = LogisticRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().alpha, 0.);
assert_eq!(
iter.next().unwrap().solver,
LogisticRegressionSolverName::LBFGS
);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn multiclass_objective_f() {
+100 -1
@@ -68,7 +68,7 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName {
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
@@ -90,6 +90,90 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
pub normalize: bool,
}
/// Ridge Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<RidgeRegressionSolverName>,
/// Regularization parameter.
pub alpha: Vec<T>,
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
}
/// Ridge Regression grid search iterator
pub struct RidgeRegressionSearchParametersIterator<T: RealNumber> {
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
current_solver: usize,
current_alpha: usize,
current_normalize: usize,
}
impl<T: RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
type Item = RidgeRegressionParameters<T>;
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
RidgeRegressionSearchParametersIterator {
ridge_regression_search_parameters: self,
current_solver: 0,
current_alpha: 0,
current_normalize: 0,
}
}
}
impl<T: RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
type Item = RidgeRegressionParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
{
return None;
}
let next = RidgeRegressionParameters {
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
};
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
self.current_alpha = 0;
self.current_solver += 1;
} else if self.current_normalize + 1
< self.ridge_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_solver = 0;
self.current_normalize += 1;
} else {
self.current_alpha += 1;
self.current_solver += 1;
self.current_normalize += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for RidgeRegressionSearchParameters<T> {
fn default() -> Self {
let default_params = RidgeRegressionParameters::default();
RidgeRegressionSearchParameters {
solver: vec![default_params.solver],
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
}
}
}
/// Ridge regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
@@ -274,6 +358,21 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = RidgeRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().alpha, 0.);
assert_eq!(
iter.next().unwrap().solver,
RidgeRegressionSolverName::Cholesky
);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ridge_fit_predict() {
+117
@@ -0,0 +1,117 @@
use crate::api::Predictor;
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
/// Grid search results.
#[derive(Clone, Debug)]
pub struct GridSearchResult<T: RealNumber, I: Clone> {
/// Cross-validation scores of the best parameter set
pub cross_validation_result: CrossValidationResult<T>,
/// The parameter set that produced the best mean test score
pub parameters: I,
}
/// Searches for the best estimator by testing every parameter combination with cross-validation, using the given metric.
/// * `fit_estimator` - a `fit` function of an estimator
/// * `x` - features, matrix of size _NxM_ where _N_ is the number of samples and _M_ is the number of attributes.
/// * `y` - target values, should be of size _N_
/// * `parameter_search` - an iterator over the parameter sets to test.
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
pub fn grid_search<T, M, I, E, K, F, S>(
fit_estimator: F,
x: &M,
y: &M::RowVector,
parameter_search: I,
cv: K,
score: S,
) -> Result<GridSearchResult<T, I::Item>, Failed>
where
T: RealNumber,
M: Matrix<T>,
I: Iterator,
I::Item: Clone,
E: Predictor<M, M::RowVector>,
K: BaseKFold,
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
S: Fn(&M::RowVector, &M::RowVector) -> T,
{
let mut best_result: Option<CrossValidationResult<T>> = None;
let mut best_parameters = None;
for parameters in parameter_search {
let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
if best_result.is_none()
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
{
best_parameters = Some(parameters);
best_result = Some(result);
}
}
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
Ok(GridSearchResult {
cross_validation_result,
parameters,
})
} else {
Err(Failed::because(
FailedError::FindFailed,
"there were no parameter sets found",
))
}
}
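A usage sketch, not part of this commit, tying `grid_search` to the LinearRegressionSearchParameters added above; the `tune` helper is an assumption, and the snippet presumes it sits in the same module as `grid_search`. Note that `grid_search` keeps the candidate with the highest mean test score, so an error metric such as `mean_absolute_error` has to be negated:

use crate::linalg::naive::dense_matrix::*;
use crate::linear::linear_regression::{
    LinearRegression, LinearRegressionSearchParameters, LinearRegressionSolverName,
};
use crate::metrics::mean_absolute_error;
use crate::model_selection::kfold::KFold;

fn tune(x: &DenseMatrix<f64>, y: &Vec<f64>) {
    let parameters = LinearRegressionSearchParameters {
        solver: vec![
            LinearRegressionSolverName::QR,
            LinearRegressionSolverName::SVD,
        ],
    };
    // grid_search maximizes the score, so negate MAE to rank candidates correctly.
    let neg_mae = |yt: &Vec<f64>, yp: &Vec<f64>| -mean_absolute_error(yt, yp);
    let results = grid_search(
        LinearRegression::fit,
        x,
        y,
        parameters.into_iter(),
        KFold::default(),
        neg_mae,
    )
    .unwrap();
    println!(
        "best solver: {:?}, cv score: {}",
        results.parameters.solver,
        results.cross_validation_result.mean_test_score()
    );
}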
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linear::logistic_regression::{
LogisticRegression, LogisticRegressionSearchParameters,
};
use crate::metrics::accuracy;
use crate::model_selection::kfold::KFold;
#[test]
fn test_grid_search() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let cv = KFold {
n_splits: 5,
..KFold::default()
};
let parameters = LogisticRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let results = grid_search(
LogisticRegression::fit,
&x,
&y,
parameters.into_iter(),
cv,
&accuracy,
)
.unwrap();
assert!([0., 1.].contains(&results.parameters.alpha));
}
}
+16 -8
@@ -91,8 +91,8 @@
//!
//! let results = cross_validate(LogisticRegression::fit, //estimator
//! &x, &y, //data
//! &Default::default(), //hyperparameters
//! &cv, //cross validation split
//! &accuracy).unwrap(); //metric
//!
//! println!("Training accuracy: {}, test accuracy: {}",
@@ -201,8 +201,8 @@ pub fn cross_validate<T, M, H, E, K, F, S>(
fit_estimator: F,
x: &M,
y: &M::RowVector,
parameters: &H,
cv: &K,
score: S,
) -> Result<CrossValidationResult<T>, Failed>
where
@@ -281,6 +281,7 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::{accuracy, mean_absolute_error};
use crate::model_selection::kfold::KFold;
use crate::neighbors::knn_regressor::KNNRegressor;
@@ -362,8 +363,15 @@ mod tests {
..KFold::default()
};
let results = cross_validate(
BiasedEstimator::fit,
&x,
&y,
&NoParameters {},
&cv,
&accuracy,
)
.unwrap();
assert_eq!(0.4, results.mean_test_score());
assert_eq!(0.4, results.mean_train_score());
@@ -404,8 +412,8 @@ mod tests {
KNNRegressor::fit,
&x,
&y,
&Default::default(),
&cv,
&mean_absolute_error,
)
.unwrap();
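For context on the by-reference signatures above, a short hypothetical sketch (`x` and `y` stand for the classification data from the module docs): `parameters` and `cv` are now only borrowed, so a caller, like the loop inside `grid_search`, can reuse them across calls.

let cv = KFold {
    n_splits: 5,
    ..KFold::default()
};
let params = LogisticRegressionParameters::default();
// Both arguments are merely borrowed, so a second call compiles;
// previously `params` and `cv` were moved into the first call.
let first = cross_validate(LogisticRegression::fit, &x, &y, &params, &cv, &accuracy).unwrap();
let second = cross_validate(LogisticRegression::fit, &x, &y, &params, &cv, &accuracy).unwrap();
println!(
    "test scores: {} vs {}",
    first.mean_test_score(),
    second.mean_test_score()
);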