refactor: Try to follow similar pattern to other APIs (#180)

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-10-01 16:44:08 -05:00
parent ad2e6c2900
commit 473cdfc44d
3 changed files with 164 additions and 57 deletions
+162 -55
View File
@@ -1,5 +1,5 @@
use crate::{
api::Predictor,
api::{Predictor, SupervisedEstimator},
error::{Failed, FailedError},
linalg::Matrix,
math::num::RealNumber,
@@ -7,74 +7,169 @@ use crate::{
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
/// Grid search result: a parameter set together with its cross-validation result.
#[derive(Clone, Debug)]
pub struct GridSearchResult<T: RealNumber, I: Clone> {
/// Cross-validation result stored alongside `parameters`.
pub cross_validation_result: CrossValidationResult<T>,
/// The parameter value (one item of the searched parameter iterator) this result belongs to.
pub parameters: I,
}
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
/// * `fit_estimator` - a `fit` function of an estimator
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
/// * `y` - target values, should be of size _N_
/// * `parameter_search` - an iterator for parameters that will be tested.
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
pub fn grid_search<T, M, I, E, K, F, S>(
fit_estimator: F,
x: &M,
y: &M::RowVector,
parameter_search: I,
cv: K,
score: S,
) -> Result<GridSearchResult<T, I::Item>, Failed>
where
/// Parameters for GridSearchCV
#[derive(Debug)]
pub struct GridSearchCVParameters<
T: RealNumber,
M: Matrix<T>,
I: Iterator,
I::Item: Clone,
C: Clone,
I: Iterator<Item = C>,
E: Predictor<M, M::RowVector>,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
K: BaseKFold,
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
S: Fn(&M::RowVector, &M::RowVector) -> T,
{
let mut best_result: Option<CrossValidationResult<T>> = None;
let mut best_parameters = None;
> {
_phantom: std::marker::PhantomData<(T, M)>,
for parameters in parameter_search {
let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
if best_result.is_none()
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
parameters_search: I,
estimator: F,
score: S,
cv: K,
}
impl<T, M, C, I, E, F, K, S> GridSearchCVParameters<T, M, C, I, E, F, K, S>
where
    T: RealNumber,
    M: Matrix<T>,
    C: Clone,
    I: Iterator<Item = C>,
    E: Predictor<M, M::RowVector>,
    F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
    K: BaseKFold,
    S: Fn(&M::RowVector, &M::RowVector) -> T,
{
    /// Build a `GridSearchCVParameters` from its four components.
    ///
    /// * `parameters_search` - iterator yielding the candidate parameter values (`C`).
    /// * `estimator` - function that takes `(x, y, parameters)` and returns a fitted predictor or `Failed`.
    /// * `score` - function that takes two row vectors and returns a value of type `T`.
    /// * `cv` - splitting strategy implementing `BaseKFold`.
    pub fn new(parameters_search: I, estimator: F, score: S, cv: K) -> Self {
        // `_phantom` only ties the otherwise-unused `T` and `M` to the type.
        Self {
            cv,
            score,
            estimator,
            parameters_search,
            _phantom: std::marker::PhantomData,
        }
    }
}
/// Exhaustive search over specified parameter values for an estimator.
///
/// Stores a predictor of type `E`, a cross-validation result and a parameter
/// value of type `C`.
#[derive(Debug)]
pub struct GridSearchCV<T, M, C, E>
where
    T: RealNumber,
    M: Matrix<T>,
    C: Clone,
    E: Predictor<M, M::RowVector>,
{
    // Zero-sized marker that keeps `T` and `M` as type parameters of the struct.
    _phantom: std::marker::PhantomData<(T, M)>,
    // Predictor that `predict` calls are forwarded to.
    predictor: E,
    /// Cross-validation result kept with this search.
    pub cross_validation_result: CrossValidationResult<T>,
    /// Parameter value kept with this search.
    pub best_parameter: C,
}
impl<T: RealNumber, M: Matrix<T>, E: Predictor<M, M::RowVector>, C: Clone>
GridSearchCV<T, M, C, E>
{
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
/// * `y` - target values, should be of size _N_
/// * `gs_parameters` - GridSearchCVParameters struct
///
/// Every item yielded by `gs_parameters.parameters_search` is passed to `cross_validate`;
/// the item with the strictly greatest `mean_test_score()` is kept, so on a tie the
/// earliest candidate wins. The estimator function is then called once more on `x`, `y`
/// with a clone of the winning parameters and the resulting predictor is stored.
///
/// # Errors
/// * Any `Failed` returned by `cross_validate` or by the estimator function is propagated.
/// * `FailedError::FindFailed` when the parameter iterator yields no items.
pub fn fit<
I: Iterator<Item = C>,
K: BaseKFold,
F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
S: Fn(&M::RowVector, &M::RowVector) -> T,
>(
x: &M,
y: &M::RowVector,
gs_parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
) -> Result<Self, Failed> {
// Best candidate seen so far; both stay `None` until the first candidate is scored.
let mut best_result: Option<CrossValidationResult<T>> = None;
let mut best_parameters = None;
// Move the individual fields out of `gs_parameters` so they can be used independently.
let parameters_search = gs_parameters.parameters_search;
let estimator = gs_parameters.estimator;
let cv = gs_parameters.cv;
let score = gs_parameters.score;
for parameters in parameters_search {
let result = cross_validate(&estimator, x, y, &parameters, &cv, &score)?;
// Strict `>`: a later candidate replaces the current best only when it scores higher.
if best_result.is_none()
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
{
best_parameters = Some(parameters);
best_result = Some(result);
}
}
if let (Some(best_parameter), Some(cross_validation_result)) =
(best_parameters, best_result)
{
// NOTE(review): the next two assignments look like removed-side residue of the diff.
// `parameters` is only bound inside the `for` loop above, and `best_parameters` has
// already been moved into the tuple being matched. Confirm and drop.
best_parameters = Some(parameters);
best_result = Some(result);
// Call the estimator function again on `x`, `y` with the winning parameters;
// the clone is needed because `best_parameter` is also stored in the returned value.
let predictor = estimator(x, y, best_parameter.clone())?;
Ok(Self {
_phantom: gs_parameters._phantom,
predictor,
cross_validation_result,
best_parameter,
})
} else {
// Reached only when the loop above never ran, i.e. no candidate was supplied.
Err(Failed::because(
FailedError::FindFailed,
"there were no parameter sets found",
))
}
}
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
Ok(GridSearchResult {
cross_validation_result,
parameters,
})
} else {
Err(Failed::because(
FailedError::FindFailed,
"there were no parameter sets found",
))
/// Return grid search cross validation results
///
/// Borrows the stored `cross_validation_result` field; nothing is recomputed.
pub fn cv_results(&self) -> &CrossValidationResult<T> {
&self.cross_validation_result
}
/// Return best parameters found
///
/// Borrows the stored `best_parameter` field; callers that need ownership must clone it.
pub fn best_parameters(&self) -> &C {
&self.best_parameter
}
/// Call predict on the estimator with the best found parameters
///
/// * `x` - matrix passed unchanged to the stored predictor's `predict`.
///
/// Returns exactly what the stored predictor returns, including any `Failed` error.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predictor.predict(x)
}
}
/// `SupervisedEstimator` implementation for `GridSearchCV`, taking a
/// `GridSearchCVParameters` value as the parameters argument.
impl<T, M, C, I, E, F, K, S>
    SupervisedEstimator<M, M::RowVector, GridSearchCVParameters<T, M, C, I, E, F, K, S>>
    for GridSearchCV<T, M, C, E>
where
    T: RealNumber,
    M: Matrix<T>,
    C: Clone,
    I: Iterator<Item = C>,
    E: Predictor<M, M::RowVector>,
    F: Fn(&M, &M::RowVector, C) -> Result<E, Failed>,
    K: BaseKFold,
    S: Fn(&M::RowVector, &M::RowVector) -> T,
{
    /// Forward `x`, `y` and `parameters` unchanged to `GridSearchCV::fit` and return its result.
    fn fit(
        x: &M,
        y: &M::RowVector,
        parameters: GridSearchCVParameters<T, M, C, I, E, F, K, S>,
    ) -> Result<Self, Failed> {
        // Pure delegation: the arguments and the result pass through untouched.
        GridSearchCV::fit(x, y, parameters)
    }
}
// `Predictor` implementation for `GridSearchCV`; the single method forwards to `self.predict`.
impl<T: RealNumber, M: Matrix<T>, C: Clone, E: Predictor<M, M::RowVector>>
Predictor<M, M::RowVector> for GridSearchCV<T, M, C, E>
{
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
// NOTE(review): `GridSearchCV` also declares an inherent `predict` with this signature.
// Rust method lookup prefers inherent methods over trait methods, so this call should
// reach the inherent method rather than recurse into this one. Confirm before
// renaming or removing the inherent method.
self.predict(x)
}
}
#[cfg(test)]
mod tests {
use crate::{
linalg::naive::dense_matrix::DenseMatrix,
linear::logistic_regression::{LogisticRegression, LogisticRegressionSearchParameters},
metrics::accuracy,
model_selection::{hyper_tuning::grid_search, KFold},
model_selection::{
hyper_tuning::grid_search::{self, GridSearchCVParameters},
KFold,
},
};
use grid_search::GridSearchCV;
#[test]
fn test_grid_search() {
@@ -114,16 +209,28 @@ mod tests {
..Default::default()
};
let results = grid_search(
LogisticRegression::fit,
let grid_search = GridSearchCV::fit(
&x,
&y,
parameters.into_iter(),
cv,
&accuracy,
GridSearchCVParameters {
estimator: LogisticRegression::fit,
score: accuracy,
cv,
parameters_search: parameters.into_iter(),
_phantom: Default::default(),
},
)
.unwrap();
let best_parameters = grid_search.best_parameters();
assert!([0., 1.].contains(&results.parameters.alpha));
assert!([1.].contains(&best_parameters.alpha));
let cv_results = grid_search.cv_results();
assert_eq!(cv_results.mean_test_score(), 0.9);
let x = DenseMatrix::from_2d_array(&[&[5., 3., 1., 0.]]);
let result = grid_search.predict(&x).unwrap();
assert_eq!(result, vec![0.]);
}
}
+1 -1
View File
@@ -1,2 +1,2 @@
mod grid_search;
pub use grid_search::{grid_search, GridSearchResult};
pub use grid_search::{GridSearchCV, GridSearchCVParameters};
+1 -1
View File
@@ -113,7 +113,7 @@ use rand::seq::SliceRandom;
pub(crate) mod hyper_tuning;
pub(crate) mod kfold;
pub use hyper_tuning::{grid_search, GridSearchResult};
pub use hyper_tuning::{GridSearchCV, GridSearchCVParameters};
pub use kfold::{KFold, KFoldIter};
/// An interface for the K-Folds cross-validator