Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by GitHub
parent 69d8be35de
commit 48514d1b15
18 changed files with 1713 additions and 25 deletions
+14 -17
View File
@@ -129,7 +129,7 @@ pub struct LassoSearchParameters<T: RealNumber> {
/// Lasso grid search iterator
pub struct LassoSearchParametersIterator<T: RealNumber> {
lasso_regression_search_parameters: LassoSearchParameters<T>,
lasso_search_parameters: LassoSearchParameters<T>,
current_alpha: usize,
current_normalize: usize,
current_tol: usize,
@@ -142,7 +142,7 @@ impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator {
lasso_regression_search_parameters: self,
lasso_search_parameters: self,
current_alpha: 0,
current_normalize: 0,
current_tol: 0,
@@ -155,34 +155,31 @@ impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
type Item = LassoParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
if self.current_alpha == self.lasso_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
{
return None;
}
let next = LassoParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
};
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
self.current_alpha = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;