grid search (#154)
* grid search draft * hyperparam search for linear estimators
This commit is contained in:
@@ -91,8 +91,8 @@
|
||||
//!
|
||||
//! let results = cross_validate(LogisticRegression::fit, //estimator
|
||||
//! &x, &y, //data
|
||||
//! Default::default(), //hyperparameters
|
||||
//! cv, //cross validation split
|
||||
//! &Default::default(), //hyperparameters
|
||||
//! &cv, //cross validation split
|
||||
//! &accuracy).unwrap(); //metric
|
||||
//!
|
||||
//! println!("Training accuracy: {}, test accuracy: {}",
|
||||
@@ -201,8 +201,8 @@ pub fn cross_validate<T, M, H, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: H,
|
||||
cv: K,
|
||||
parameters: &H,
|
||||
cv: &K,
|
||||
score: S,
|
||||
) -> Result<CrossValidationResult<T>, Failed>
|
||||
where
|
||||
@@ -281,6 +281,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
use crate::metrics::{accuracy, mean_absolute_error};
|
||||
use crate::model_selection::kfold::KFold;
|
||||
use crate::neighbors::knn_regressor::KNNRegressor;
|
||||
@@ -362,8 +363,15 @@ mod tests {
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results =
|
||||
cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap();
|
||||
let results = cross_validate(
|
||||
BiasedEstimator::fit,
|
||||
&x,
|
||||
&y,
|
||||
&NoParameters {},
|
||||
&cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(0.4, results.mean_test_score());
|
||||
assert_eq!(0.4, results.mean_train_score());
|
||||
@@ -404,8 +412,8 @@ mod tests {
|
||||
KNNRegressor::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&Default::default(),
|
||||
&cv,
|
||||
&mean_absolute_error,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user