grid search (#154)

* grid search draft
* hyperparam search for linear estimators
This commit is contained in:
Montana Low
2022-09-19 02:31:56 -07:00
committed by GitHub
parent 2e5f88fad8
commit 4685fc73e0
7 changed files with 649 additions and 11 deletions
+16 -8
View File
@@ -91,8 +91,8 @@
//!
//! let results = cross_validate(LogisticRegression::fit, //estimator
//! &x, &y, //data
//! Default::default(), //hyperparameters
//! cv, //cross validation split
//! &Default::default(), //hyperparameters
//! &cv, //cross validation split
//! &accuracy).unwrap(); //metric
//!
//! println!("Training accuracy: {}, test accuracy: {}",
@@ -201,8 +201,8 @@ pub fn cross_validate<T, M, H, E, K, F, S>(
fit_estimator: F,
x: &M,
y: &M::RowVector,
parameters: H,
cv: K,
parameters: &H,
cv: &K,
score: S,
) -> Result<CrossValidationResult<T>, Failed>
where
@@ -281,6 +281,7 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::{accuracy, mean_absolute_error};
use crate::model_selection::kfold::KFold;
use crate::neighbors::knn_regressor::KNNRegressor;
@@ -362,8 +363,15 @@ mod tests {
..KFold::default()
};
let results =
cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap();
let results = cross_validate(
BiasedEstimator::fit,
&x,
&y,
&NoParameters {},
&cv,
&accuracy,
)
.unwrap();
assert_eq!(0.4, results.mean_test_score());
assert_eq!(0.4, results.mean_train_score());
@@ -404,8 +412,8 @@ mod tests {
KNNRegressor::fit,
&x,
&y,
Default::default(),
cv,
&Default::default(),
&cv,
&mean_absolute_error,
)
.unwrap();