grid search (#154)

* grid search draft
* hyperparam search for linear estimators
This commit is contained in:
Montana Low
2022-09-19 02:31:56 -07:00
committed by morenol
parent 0f442e96c0
commit 1f2597be74
7 changed files with 649 additions and 11 deletions
+87 -1
View File
@@ -68,7 +68,7 @@ use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
pub enum LogisticRegressionSolverName {
/// Limited-memory BroydenFletcherGoldfarbShanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
@@ -85,6 +85,77 @@ pub struct LogisticRegressionParameters<T: RealNumber> {
pub alpha: T,
}
/// Logistic Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionSearchParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LogisticRegressionSolverName>,
/// Regularization parameter.
pub alpha: Vec<T>,
}
/// Logistic Regression grid search iterator
pub struct LogisticRegressionSearchParametersIterator<T: RealNumber> {
logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
current_solver: usize,
current_alpha: usize,
}
impl<T: RealNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
type Item = LogisticRegressionParameters<T>;
type IntoIter = LogisticRegressionSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
LogisticRegressionSearchParametersIterator {
logistic_regression_search_parameters: self,
current_solver: 0,
current_alpha: 0,
}
}
}
impl<T: RealNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
type Item = LogisticRegressionParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
&& self.current_solver == self.logistic_regression_search_parameters.solver.len()
{
return None;
}
let next = LogisticRegressionParameters {
solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
};
if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
{
self.current_alpha = 0;
self.current_solver += 1;
} else {
self.current_alpha += 1;
self.current_solver += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for LogisticRegressionSearchParameters<T> {
fn default() -> Self {
let default_params = LogisticRegressionParameters::default();
LogisticRegressionSearchParameters {
solver: vec![default_params.solver],
alpha: vec![default_params.alpha],
}
}
}
/// Logistic Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
@@ -452,6 +523,21 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::accuracy;
#[test]
fn search_parameters() {
let parameters = LogisticRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().alpha, 0.);
assert_eq!(
iter.next().unwrap().solver,
LogisticRegressionSolverName::LBFGS
);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn multiclass_objective_f() {