feat: expose hyper tuning module in model_selection (#179)

* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-10-01 12:47:56 -05:00
committed by GitHub
parent 9c59e37a0f
commit 3c62686d6e
3 changed files with 65 additions and 49 deletions
@@ -1,3 +1,12 @@
use crate::{
api::Predictor,
error::{Failed, FailedError},
linalg::Matrix,
math::num::RealNumber,
};
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
/// grid search results. /// grid search results.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct GridSearchResult<T: RealNumber, I: Clone> { pub struct GridSearchResult<T: RealNumber, I: Clone> {
@@ -60,58 +69,61 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::linear::logistic_regression::{ use crate::{
LogisticRegression, LogisticRegressionSearchParameters, linalg::naive::dense_matrix::DenseMatrix,
}; linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
metrics::accuracy,
model_selection::{hyper_tuning::grid_search, KFold},
};
#[test] #[test]
fn test_grid_search() { fn test_grid_search() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2], &[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2], &[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2], &[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4], &[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3], &[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2], &[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2], &[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1], &[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4], &[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5], &[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5], &[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3], &[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5], &[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3], &[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6], &[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4], &[5.2, 2.7, 3.9, 1.4],
]); ]);
let y = vec![ let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]; ];
let cv = KFold { let cv = KFold {
n_splits: 5, n_splits: 5,
..KFold::default() ..KFold::default()
}; };
let parameters = LogisticRegressionSearchParameters { let parameters = LogisticRegressionSearchParameters {
alpha: vec![0., 1.], alpha: vec![0., 1.],
..Default::default() ..Default::default()
}; };
let results = grid_search( let results = grid_search(
LogisticRegression::fit, LogisticRegression::fit,
&x, &x,
&y, &y,
parameters.into_iter(), parameters.into_iter(),
cv, cv,
&accuracy, &accuracy,
) )
.unwrap(); .unwrap();
assert!([0., 1.].contains(&results.parameters.alpha)); assert!([0., 1.].contains(&results.parameters.alpha));
} }
} }
+2
View File
@@ -0,0 +1,2 @@
mod grid_search;
pub use grid_search::{grid_search, GridSearchResult};
+2
View File
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
use crate::rand::get_rng_impl; use crate::rand::get_rng_impl;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
pub(crate) mod hyper_tuning;
pub(crate) mod kfold; pub(crate) mod kfold;
pub use hyper_tuning::{grid_search, GridSearchResult};
pub use kfold::{KFold, KFoldIter}; pub use kfold::{KFold, KFoldIter};
/// An interface for the K-Folds cross-validator /// An interface for the K-Folds cross-validator