feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection * Move to a folder Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+61
-49
@@ -1,3 +1,12 @@
|
|||||||
|
use crate::{
|
||||||
|
api::Predictor,
|
||||||
|
error::{Failed, FailedError},
|
||||||
|
linalg::Matrix,
|
||||||
|
math::num::RealNumber,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||||
|
|
||||||
/// grid search results.
|
/// grid search results.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
||||||
@@ -60,58 +69,61 @@ where
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::linear::logistic_regression::{
|
use crate::{
|
||||||
LogisticRegression, LogisticRegressionSearchParameters,
|
linalg::naive::dense_matrix::DenseMatrix,
|
||||||
};
|
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
|
||||||
|
metrics::accuracy,
|
||||||
|
model_selection::{hyper_tuning::grid_search, KFold},
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_grid_search() {
|
fn test_grid_search() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
&[4.6, 3.1, 1.5, 0.2],
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
&[5.0, 3.6, 1.4, 0.2],
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
&[5.4, 3.9, 1.7, 0.4],
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
&[4.6, 3.4, 1.4, 0.3],
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
&[5.0, 3.4, 1.5, 0.2],
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
&[4.4, 2.9, 1.4, 0.2],
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
&[4.9, 3.1, 1.5, 0.1],
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
&[7.0, 3.2, 4.7, 1.4],
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
&[6.4, 3.2, 4.5, 1.5],
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
&[6.9, 3.1, 4.9, 1.5],
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
&[5.5, 2.3, 4.0, 1.3],
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
&[6.5, 2.8, 4.6, 1.5],
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
&[5.7, 2.8, 4.5, 1.3],
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
&[6.3, 3.3, 4.7, 1.6],
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4],
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
]);
|
]);
|
||||||
let y = vec![
|
let y = vec![
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
];
|
];
|
||||||
|
|
||||||
let cv = KFold {
|
let cv = KFold {
|
||||||
n_splits: 5,
|
n_splits: 5,
|
||||||
..KFold::default()
|
..KFold::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let parameters = LogisticRegressionSearchParameters {
|
let parameters = LogisticRegressionSearchParameters {
|
||||||
alpha: vec![0., 1.],
|
alpha: vec![0., 1.],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = grid_search(
|
let results = grid_search(
|
||||||
LogisticRegression::fit,
|
LogisticRegression::fit,
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
parameters.into_iter(),
|
parameters.into_iter(),
|
||||||
cv,
|
cv,
|
||||||
&accuracy,
|
&accuracy,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
mod grid_search;
|
||||||
|
pub use grid_search::{grid_search, GridSearchResult};
|
||||||
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::rand::get_rng_impl;
|
use crate::rand::get_rng_impl;
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
|
||||||
|
pub(crate) mod hyper_tuning;
|
||||||
pub(crate) mod kfold;
|
pub(crate) mod kfold;
|
||||||
|
|
||||||
|
pub use hyper_tuning::{grid_search, GridSearchResult};
|
||||||
pub use kfold::{KFold, KFoldIter};
|
pub use kfold::{KFold, KFoldIter};
|
||||||
|
|
||||||
/// An interface for the K-Folds cross-validator
|
/// An interface for the K-Folds cross-validator
|
||||||
|
|||||||
Reference in New Issue
Block a user