grid search (#154)
* grid search draft * hyperparam search for linear estimators
This commit is contained in:
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionSearchParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LinearRegressionSolverName>,
|
||||
}
|
||||
|
||||
/// Linear Regression grid search iterator
|
||||
pub struct LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: LinearRegressionSearchParameters,
|
||||
current_solver: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LinearRegressionSearchParameters {
|
||||
type Item = LinearRegressionParameters;
|
||||
type IntoIter = LinearRegressionSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for LinearRegressionSearchParametersIterator {
|
||||
type Item = LinearRegressionParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LinearRegressionParameters {
|
||||
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
};
|
||||
|
||||
self.current_solver += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = LinearRegressionParameters::default();
|
||||
|
||||
LinearRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
@@ -200,6 +254,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LinearRegressionSearchParameters {
|
||||
solver: vec![
|
||||
LinearRegressionSolverName::QR,
|
||||
LinearRegressionSolverName::SVD,
|
||||
],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
|
||||
Reference in New Issue
Block a user