grid search (#154)

* grid search draft
* hyperparam search for linear estimators
This commit is contained in:
Montana Low
2022-09-19 02:31:56 -07:00
committed by GitHub
parent 2e5f88fad8
commit 4685fc73e0
7 changed files with 649 additions and 11 deletions
+69 -1
View File
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
}
}
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: LinearRegressionSearchParameters,
current_solver: usize,
}
impl IntoIterator for LinearRegressionSearchParameters {
type Item = LinearRegressionParameters;
type IntoIter = LinearRegressionSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: self,
current_solver: 0,
}
}
}
impl Iterator for LinearRegressionSearchParametersIterator {
type Item = LinearRegressionParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
return None;
}
let next = LinearRegressionParameters {
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
};
self.current_solver += 1;
Some(next)
}
}
impl Default for LinearRegressionSearchParameters {
fn default() -> Self {
let default_params = LinearRegressionParameters::default();
LinearRegressionSearchParameters {
solver: vec![default_params.solver],
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
@@ -200,6 +254,20 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
let parameters = LinearRegressionSearchParameters {
solver: vec![
LinearRegressionSolverName::QR,
LinearRegressionSolverName::SVD,
],
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ols_fit_predict() {