feat: expose hyper tuning module in model_selection (#179)

* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Author: morenol
Date: 2022-10-01 12:47:56 -05:00
parent 9ea3133c27
commit ad2e6c2900
3 changed files with 65 additions and 49 deletions
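For users, the net effect is that grid_search and GridSearchResult become reachable through the public model_selection path. Below is a minimal sketch of the external call pattern this enables; it assumes the crate name smartcore and mirrors the pre-0.3 API used in the test file in this diff (DenseMatrix, KFold, accuracy), with a smaller interleaved dataset so 2-fold splits still see both classes.

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::logistic_regression::{
    LogisticRegression, LogisticRegressionSearchParameters,
};
use smartcore::metrics::accuracy;
use smartcore::model_selection::{grid_search, KFold};

fn main() {
    // Ten samples, classes interleaved so every fold's training split
    // contains both labels.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[4.9, 3.0, 1.4, 0.2],
        &[6.4, 3.2, 4.5, 1.5],
        &[4.7, 3.2, 1.3, 0.2],
        &[6.9, 3.1, 4.9, 1.5],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.5, 2.3, 4.0, 1.3],
        &[5.0, 3.6, 1.4, 0.2],
        &[6.5, 2.8, 4.6, 1.5],
    ]);
    let y = vec![0., 1., 0., 1., 0., 1., 0., 1., 0., 1.];

    // Two candidate regularization strengths, evaluated by 2-fold
    // cross-validation with accuracy as the score.
    let parameters = LogisticRegressionSearchParameters {
        alpha: vec![0., 1.],
        ..Default::default()
    };
    let cv = KFold {
        n_splits: 2,
        ..KFold::default()
    };
    let results = grid_search(
        LogisticRegression::fit,
        &x,
        &y,
        parameters.into_iter(),
        cv,
        &accuracy,
    )
    .unwrap();

    // GridSearchResult exposes the chosen parameter set.
    println!("selected alpha: {}", results.parameters.alpha);
}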
@@ -1,3 +1,12 @@
use crate::{
    api::Predictor,
    error::{Failed, FailedError},
    linalg::Matrix,
    math::num::RealNumber,
};
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};

/// Grid search results.
#[derive(Clone, Debug)]
pub struct GridSearchResult<T: RealNumber, I: Clone> {
@@ -60,58 +69,61 @@ where
#[cfg(test)]
mod tests {
    use crate::{
        linalg::naive::dense_matrix::DenseMatrix,
        linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
        metrics::accuracy,
        model_selection::{hyper_tuning::grid_search, KFold},
    };

    #[test]
    fn test_grid_search() {
        let x = DenseMatrix::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
            &[5.7, 2.8, 4.5, 1.3],
            &[6.3, 3.3, 4.7, 1.6],
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
        ]);
        let y = vec![
            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        ];
        let cv = KFold {
            n_splits: 5,
            ..KFold::default()
        };
        let parameters = LogisticRegressionSearchParameters {
            alpha: vec![0., 1.],
            ..Default::default()
        };
        let results = grid_search(
            LogisticRegression::fit,
            &x,
            &y,
            parameters.into_iter(),
            cv,
            &accuracy,
        )
        .unwrap();
        assert!([0., 1.].contains(&results.parameters.alpha));
    }
}
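The assertion reflects the contract exercised here: grid_search fits one model per candidate drawn from the parameter iterator, scores each with the supplied cross-validator and metric, and the returned GridSearchResult carries the selected parameter set, which must therefore be one of the two searched alpha values.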
@@ -0,0 +1,2 @@
mod grid_search;
pub use grid_search::{grid_search, GridSearchResult};
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use rand::seq::SliceRandom;
pub(crate) mod hyper_tuning;
pub(crate) mod kfold;
pub use hyper_tuning::{grid_search, GridSearchResult};
pub use kfold::{KFold, KFoldIter};
/// An interface for the K-Folds cross-validator
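Note the visibility pattern in this last hunk: hyper_tuning (like kfold) stays pub(crate), and only grid_search and GridSearchResult are re-exported with pub use. Callers therefore depend on the stable model_selection path rather than on the internal module layout, which is what lets the "Move to a folder" step in this commit happen without breaking downstream imports.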