refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
api::Predictor,
|
||||
api::{Predictor, SupervisedEstimator},
|
||||
error::{Failed, FailedError},
|
||||
linalg::Matrix,
|
||||
math::num::RealNumber,
|
||||
@@ -7,45 +7,85 @@ use crate::{
|
||||
|
||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||
|
||||
/// Result of a grid search: a parameter set together with its cross-validation result.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
||||
/// Cross-validation result recorded for `parameters`.
|
||||
pub cross_validation_result: CrossValidationResult<T>,
|
||||
/// Parameter set selected by the search (presumably the one with the highest mean test score; confirm against `grid_search`).
|
||||
pub parameters: I,
|
||||
}
|
||||
|
||||
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameter_search` - an iterator for parameters that will be tested.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
|
||||
pub fn grid_search<T, M, I, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameter_search: I,
|
||||
cv: K,
|
||||
score: S,
|
||||
) -> Result<GridSearchResult<T, I::Item>, Failed>
|
||||
where
|
||||
/// Parameters for GridSearchCV
|
||||
#[derive(Debug)]
|
||||
pub struct GridSearchCVParameters<
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
I: Iterator,
|
||||
I::Item: Clone,
|
||||
C: Clone,
|
||||
I: Iterator<Item = C>,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
> {
|
||||
_phantom: std::marker::PhantomData<(T, M)>,
|
||||
|
||||
parameters_search: I,
|
||||
estimator: F,
|
||||
score: S,
|
||||
cv: K,
|
||||
}
|
||||
|
||||
impl<
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
C: Clone,
|
||||
I: Iterator<Item = C>,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||
K: BaseKFold,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
> GridSearchCVParameters<T, M, C, I, E, F, K, S>
|
||||
{
|
||||
/// Create new GridSearchCVParameters
|
||||
pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
|
||||
GridSearchCVParameters {
|
||||
_phantom: std::marker::PhantomData,
|
||||
parameters_search,
|
||||
estimator,
|
||||
score,
|
||||
cv,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Exhaustive search over specified parameter values for an estimator.
|
||||
#[derive(Debug)]
|
||||
pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
|
||||
_phantom: std::marker::PhantomData<(T, M)>,
|
||||
predictor: E,
|
||||
/// Cross validation results.
|
||||
pub cross_validation_result: CrossValidationResult<T>,
|
||||
/// best parameter
|
||||
pub best_parameter: C,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
|
||||
GridSearchCV<T, M, C, E>
|
||||
{
|
||||
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `gs_parameters` - GridSearchCVParameters struct
|
||||
pub fn fit<
|
||||
I: Iterator<Item = C>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
|
||||
) -> Result<Self, Failed> {
|
||||
let mut best_result: Option<CrossValidationResult<T>> = None;
|
||||
let mut best_parameters = None;
|
||||
let parameters_search = gs_parameters.parameters_search;
|
||||
let estimator = gs_parameters.estimator;
|
||||
let cv = gs_parameters.cv;
|
||||
let score = gs_parameters.score;
|
||||
|
||||
for parameters in parameter_search {
|
||||
let result = cross_validate(&fit_estimator, x, y, ¶meters, &cv, &score)?;
|
||||
for parameters in parameters_search {
|
||||
let result = cross_validate(&estimator, x, y, ¶meters, &cv, &score)?;
|
||||
if best_result.is_none()
|
||||
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
|
||||
{
|
||||
@@ -54,10 +94,15 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
|
||||
Ok(GridSearchResult {
|
||||
if let (Some(best_parameter), Some(cross_validation_result)) =
|
||||
(best_parameters, best_result)
|
||||
{
|
||||
let predictor = estimator(x, y, best_parameter.clone())?;
|
||||
Ok(Self {
|
||||
_phantom: gs_parameters._phantom,
|
||||
predictor,
|
||||
cross_validation_result,
|
||||
parameters,
|
||||
best_parameter,
|
||||
})
|
||||
} else {
|
||||
Err(Failed::because(
|
||||
@@ -65,16 +110,66 @@ where
|
||||
"there were no parameter sets found",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Return grid search cross validation results
|
||||
pub fn cv_results(&self) -> &CrossValidationResult<T> {
|
||||
&self.cross_validation_result
|
||||
}
|
||||
|
||||
/// Return best parameters found
|
||||
pub fn best_parameters(&self) -> &C {
|
||||
&self.best_parameter
|
||||
}
|
||||
|
||||
/// Call predict on the estimator with the best found parameters
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predictor.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
C: Clone,
|
||||
I: Iterator<Item = C>,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||
K: BaseKFold,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
> SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
|
||||
for GridSearchCV<T, M, C, E>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
|
||||
) -> Result<Self, Failed> {
|
||||
GridSearchCV::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
|
||||
Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
linalg::naive::dense_matrix::DenseMatrix,
|
||||
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
|
||||
metrics::accuracy,
|
||||
model_selection::{hyper_tuning::grid_search, KFold},
|
||||
model_selection::{
|
||||
hyper_tuning::grid_search::{self, GridSearchCVParameters},
|
||||
KFold,
|
||||
},
|
||||
};
|
||||
use grid_search::GridSearchCV;
|
||||
|
||||
#[test]
|
||||
fn test_grid_search() {
|
||||
@@ -114,16 +209,28 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let results = grid_search(
|
||||
LogisticRegression::fit,
|
||||
let grid_search = GridSearchCV::fit(
|
||||
&x,
|
||||
&y,
|
||||
parameters.into_iter(),
|
||||
GridSearchCVParameters {
|
||||
estimator: LogisticRegression::fit,
|
||||
score: accuracy,
|
||||
cv,
|
||||
&accuracy,
|
||||
parameters_search: parameters.into_iter(),
|
||||
_phantom: Default::default(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
let best_parameters = grid_search.best_parameters();
|
||||
|
||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||
assert!([1.].contains(&best_parameters.alpha));
|
||||
|
||||
let cv_results = grid_search.cv_results();
|
||||
|
||||
assert_eq!(cv_results.mean_test_score(), 0.9);
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
|
||||
let result = grid_search.predict(&x).unwrap();
|
||||
assert_eq!(result, vec![0.]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
mod grid_search;
|
||||
pub use grid_search::{grid_search, GridSearchResult};
|
||||
pub use grid_search::{GridSearchCV, GridSearchCVParameters};
|
||||
|
||||
@@ -113,7 +113,7 @@ use rand::seq::SliceRandom;
|
||||
pub(crate) mod hyper_tuning;
|
||||
pub(crate) mod kfold;
|
||||
|
||||
pub use hyper_tuning::{grid_search, GridSearchResult};
|
||||
pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
|
||||
pub use kfold::{KFold, KFoldIter};
|
||||
|
||||
/// An interface for the K-Folds cross-validator
|
||||
|
||||
Reference in New Issue
Block a user