refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
api::Predictor,
|
api::{Predictor, SupervisedEstimator},
|
||||||
error::{Failed, FailedError},
|
error::{Failed, FailedError},
|
||||||
linalg::Matrix,
|
linalg::Matrix,
|
||||||
math::num::RealNumber,
|
math::num::RealNumber,
|
||||||
@@ -7,45 +7,85 @@ use crate::{
|
|||||||
|
|
||||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||||
|
|
||||||
/// grid search results.
|
/// Parameters for GridSearchCV
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Debug)]
|
||||||
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
pub struct GridSearchCVParameters<
|
||||||
/// Vector with test scores on each cv split
|
|
||||||
pub cross_validation_result: CrossValidationResult<T>,
|
|
||||||
/// Vector with training scores on each cv split
|
|
||||||
pub parameters: I,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
|
|
||||||
/// * `fit_estimator` - a `fit` function of an estimator
|
|
||||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
|
||||||
/// * `y` - target values, should be of size _N_
|
|
||||||
/// * `parameter_search` - an iterator for parameters that will be tested.
|
|
||||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
|
||||||
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
|
|
||||||
pub fn grid_search<T, M, I, E, K, F, S>(
|
|
||||||
fit_estimator: F,
|
|
||||||
x: &M,
|
|
||||||
y: &M::RowVector,
|
|
||||||
parameter_search: I,
|
|
||||||
cv: K,
|
|
||||||
score: S,
|
|
||||||
) -> Result<GridSearchResult<T, I::Item>, Failed>
|
|
||||||
where
|
|
||||||
T: RealNumber,
|
T: RealNumber,
|
||||||
M: Matrix<T>,
|
M: Matrix<T>,
|
||||||
I: Iterator,
|
C: Clone,
|
||||||
I::Item: Clone,
|
I: Iterator<Item = C>,
|
||||||
E: Predictor<M, M::RowVector>,
|
E: Predictor<M, M::RowVector>,
|
||||||
|
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||||
K: BaseKFold,
|
K: BaseKFold,
|
||||||
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
|
|
||||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||||
|
> {
|
||||||
|
_phantom: std::marker::PhantomData<(T, M)>,
|
||||||
|
|
||||||
|
parameters_search: I,
|
||||||
|
estimator: F,
|
||||||
|
score: S,
|
||||||
|
cv: K,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
T: RealNumber,
|
||||||
|
M: Matrix<T>,
|
||||||
|
C: Clone,
|
||||||
|
I: Iterator<Item = C>,
|
||||||
|
E: Predictor<M, M::RowVector>,
|
||||||
|
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||||
|
K: BaseKFold,
|
||||||
|
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||||
|
> GridSearchCVParameters<T, M, C, I, E, F, K, S>
|
||||||
{
|
{
|
||||||
|
/// Create new GridSearchCVParameters
|
||||||
|
pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
|
||||||
|
GridSearchCVParameters {
|
||||||
|
_phantom: std::marker::PhantomData,
|
||||||
|
parameters_search,
|
||||||
|
estimator,
|
||||||
|
score,
|
||||||
|
cv,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Exhaustive search over specified parameter values for an estimator.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct GridSearchCV<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>> {
|
||||||
|
_phantom: std::marker::PhantomData<(T, M)>,
|
||||||
|
predictor: E,
|
||||||
|
/// Cross validation results.
|
||||||
|
pub cross_validation_result: CrossValidationResult<T>,
|
||||||
|
/// best parameter
|
||||||
|
pub best_parameter: C,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
|
||||||
|
GridSearchCV<T, M, C, E>
|
||||||
|
{
|
||||||
|
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
|
||||||
|
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||||
|
/// * `y` - target values, should be of size _N_
|
||||||
|
/// * `gs_parameters` - GridSearchCVParameters struct
|
||||||
|
pub fn fit<
|
||||||
|
I: Iterator<Item = C>,
|
||||||
|
K: BaseKFold,
|
||||||
|
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||||
|
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||||
|
>(
|
||||||
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
|
||||||
|
) -> Result<Self, Failed> {
|
||||||
let mut best_result: Option<CrossValidationResult<T>> = None;
|
let mut best_result: Option<CrossValidationResult<T>> = None;
|
||||||
let mut best_parameters = None;
|
let mut best_parameters = None;
|
||||||
|
let parameters_search = gs_parameters.parameters_search;
|
||||||
|
let estimator = gs_parameters.estimator;
|
||||||
|
let cv = gs_parameters.cv;
|
||||||
|
let score = gs_parameters.score;
|
||||||
|
|
||||||
for parameters in parameter_search {
|
for parameters in parameters_search {
|
||||||
let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
|
let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
|
||||||
if best_result.is_none()
|
if best_result.is_none()
|
||||||
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
|
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
|
||||||
{
|
{
|
||||||
@@ -54,10 +94,15 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
|
if let (Some(best_parameter), Some(cross_validation_result)) =
|
||||||
Ok(GridSearchResult {
|
(best_parameters, best_result)
|
||||||
|
{
|
||||||
|
let predictor = estimator(x, y, best_parameter.clone())?;
|
||||||
|
Ok(Self {
|
||||||
|
_phantom: gs_parameters._phantom,
|
||||||
|
predictor,
|
||||||
cross_validation_result,
|
cross_validation_result,
|
||||||
parameters,
|
best_parameter,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(Failed::because(
|
Err(Failed::because(
|
||||||
@@ -65,16 +110,66 @@ where
|
|||||||
"there were no parameter sets found",
|
"there were no parameter sets found",
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return grid search cross validation results
|
||||||
|
pub fn cv_results(&self) -> &CrossValidationResult<T> {
|
||||||
|
&self.cross_validation_result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return best parameters found
|
||||||
|
pub fn best_parameters(&self) -> &C {
|
||||||
|
&self.best_parameter
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call predict on the estimator with the best found parameters
|
||||||
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
self.predictor.predict(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
T: RealNumber,
|
||||||
|
M: Matrix<T>,
|
||||||
|
C: Clone,
|
||||||
|
I: Iterator<Item = C>,
|
||||||
|
E: Predictor<M, M::RowVector>,
|
||||||
|
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
|
||||||
|
K: BaseKFold,
|
||||||
|
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||||
|
> SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
|
||||||
|
for GridSearchCV<T, M, C, E>
|
||||||
|
{
|
||||||
|
fn fit(
|
||||||
|
x: &M,
|
||||||
|
y: &M::RowVector,
|
||||||
|
parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
|
||||||
|
) -> Result<Self, Failed> {
|
||||||
|
GridSearchCV::fit(x, y, parameters)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
|
||||||
|
Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
|
||||||
|
{
|
||||||
|
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
self.predict(x)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
linalg::naive::dense_matrix::DenseMatrix,
|
linalg::naive::dense_matrix::DenseMatrix,
|
||||||
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
|
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
|
||||||
metrics::accuracy,
|
metrics::accuracy,
|
||||||
model_selection::{hyper_tuning::grid_search, KFold},
|
model_selection::{
|
||||||
|
hyper_tuning::grid_search::{self, GridSearchCVParameters},
|
||||||
|
KFold,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
use grid_search::GridSearchCV;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_grid_search() {
|
fn test_grid_search() {
|
||||||
@@ -114,16 +209,28 @@ mod tests {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let results = grid_search(
|
let grid_search = GridSearchCV::fit(
|
||||||
LogisticRegression::fit,
|
|
||||||
&x,
|
&x,
|
||||||
&y,
|
&y,
|
||||||
parameters.into_iter(),
|
GridSearchCVParameters {
|
||||||
|
estimator: LogisticRegression::fit,
|
||||||
|
score: accuracy,
|
||||||
cv,
|
cv,
|
||||||
&accuracy,
|
parameters_search: parameters.into_iter(),
|
||||||
|
_phantom: Default::default(),
|
||||||
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
let best_parameters = grid_search.best_parameters();
|
||||||
|
|
||||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
assert!([1.].contains(&best_parameters.alpha));
|
||||||
|
|
||||||
|
let cv_results = grid_search.cv_results();
|
||||||
|
|
||||||
|
assert_eq!(cv_results.mean_test_score(), 0.9);
|
||||||
|
|
||||||
|
let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
|
||||||
|
let result = grid_search.predict(&x).unwrap();
|
||||||
|
assert_eq!(result, vec![0.]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
mod grid_search;
|
mod grid_search;
|
||||||
pub use grid_search::{grid_search, GridSearchResult};
|
pub use grid_search::{GridSearchCV, GridSearchCVParameters};
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ use rand::seq::SliceRandom;
|
|||||||
pub(crate) mod hyper_tuning;
|
pub(crate) mod hyper_tuning;
|
||||||
pub(crate) mod kfold;
|
pub(crate) mod kfold;
|
||||||
|
|
||||||
pub use hyper_tuning::{grid_search, GridSearchResult};
|
pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
|
||||||
pub use kfold::{KFold, KFoldIter};
|
pub use kfold::{KFold, KFoldIter};
|
||||||
|
|
||||||
/// An interface for the K-Folds cross-validator
|
/// An interface for the K-Folds cross-validator
|
||||||
|
|||||||
Reference in New Issue
Block a user