feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection * Move to a folder Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+61
-49
@@ -1,3 +1,12 @@
|
||||
use crate::{
|
||||
api::Predictor,
|
||||
error::{Failed, FailedError},
|
||||
linalg::Matrix,
|
||||
math::num::RealNumber,
|
||||
};
|
||||
|
||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||
|
||||
/// grid search results.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
||||
@@ -60,58 +69,61 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linear::logistic_regression::{
|
||||
LogisticRegression, LogisticRegressionSearchParameters,
|
||||
};
|
||||
use crate::{
|
||||
linalg::naive::dense_matrix::DenseMatrix,
|
||||
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
|
||||
metrics::accuracy,
|
||||
model_selection::{hyper_tuning::grid_search, KFold},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_grid_search() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
#[test]
|
||||
fn test_grid_search() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 5,
|
||||
..KFold::default()
|
||||
};
|
||||
let cv = KFold {
|
||||
n_splits: 5,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let parameters = LogisticRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let parameters = LogisticRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let results = grid_search(
|
||||
LogisticRegression::fit,
|
||||
&x,
|
||||
&y,
|
||||
parameters.into_iter(),
|
||||
cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
let results = grid_search(
|
||||
LogisticRegression::fit,
|
||||
&x,
|
||||
&y,
|
||||
parameters.into_iter(),
|
||||
cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||
}
|
||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
mod grid_search;
|
||||
pub use grid_search::{grid_search, GridSearchResult};
|
||||
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
pub(crate) mod hyper_tuning;
|
||||
pub(crate) mod kfold;
|
||||
|
||||
pub use hyper_tuning::{grid_search, GridSearchResult};
|
||||
pub use kfold::{KFold, KFoldIter};
|
||||
|
||||
/// An interface for the K-Folds cross-validator
|
||||
|
||||
Reference in New Issue
Block a user