feat: expose hyper tuning module in model_selection (#179)

* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-10-01 12:47:56 -05:00
parent 9ea3133c27
commit ad2e6c2900
3 changed files with 65 additions and 49 deletions
@@ -1,3 +1,12 @@
use crate::{
api::Predictor,
error::{Failed, FailedError},
linalg::Matrix,
math::num::RealNumber,
};
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
/// grid search results.
#[derive(Clone, Debug)]
pub struct GridSearchResult<T: RealNumber, I: Clone> {
@@ -60,9 +69,12 @@ where
#[cfg(test)]
mod tests {
use crate::linear::logistic_regression::{
LogisticRegression, LogisticRegressionSearchParameters,
};
use crate::{
linalg::naive::dense_matrix::DenseMatrix,
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
metrics::accuracy,
model_selection::{hyper_tuning::grid_search, KFold},
};
#[test]
fn test_grid_search() {
+2
View File
@@ -0,0 +1,2 @@
mod grid_search;
pub use grid_search::{grid_search, GridSearchResult};
+2
View File
@@ -110,8 +110,10 @@ use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use rand::seq::SliceRandom;
pub(crate) mod hyper_tuning;
pub(crate) mod kfold;
pub use hyper_tuning::{grid_search, GridSearchResult};
pub use kfold::{KFold, KFoldIter};
/// An interface for the K-Folds cross-validator