Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by GitHub
parent 69d8be35de
commit 48514d1b15
18 changed files with 1713 additions and 25 deletions
+161
View File
@@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters {
}
}
/// DecisionTreeClassifier grid search parameters.
///
/// Each field holds the list of candidate values to try for the same-named
/// field of `DecisionTreeClassifierParameters`. Converting this value into an
/// iterator yields one `DecisionTreeClassifierParameters` per combination of
/// the listed candidates.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierSearchParameters {
/// Candidate split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: Vec<SplitCriterion>,
/// Candidate values for the tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
/// Candidate values for the minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
/// Candidate values for the minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
}
/// DecisionTreeClassifier grid search iterator.
///
/// Owns the search grid together with one cursor per hyperparameter. Each
/// cursor is an index into the candidate list of the same name and selects the
/// value used for the next combination that is produced.
pub struct DecisionTreeClassifierSearchParametersIterator {
// The grid being walked; it is moved in by `into_iter`.
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
// Index into `criterion`.
current_criterion: usize,
// Index into `max_depth`.
current_max_depth: usize,
// Index into `min_samples_leaf`.
current_min_samples_leaf: usize,
// Index into `min_samples_split`.
current_min_samples_split: usize,
}
impl IntoIterator for DecisionTreeClassifierSearchParameters {
    type Item = DecisionTreeClassifierParameters;
    type IntoIter = DecisionTreeClassifierSearchParametersIterator;

    /// Consumes the search grid and returns an iterator over its parameter
    /// combinations, with every cursor positioned on the first candidate.
    fn into_iter(self) -> Self::IntoIter {
        let search_space = self;
        DecisionTreeClassifierSearchParametersIterator {
            current_min_samples_split: 0,
            current_min_samples_leaf: 0,
            current_max_depth: 0,
            current_criterion: 0,
            decision_tree_classifier_search_parameters: search_space,
        }
    }
}
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
    type Item = DecisionTreeClassifierParameters;

    /// Returns the next combination of candidate values, or `None` once the
    /// grid is exhausted.
    ///
    /// Combinations are produced in odometer order: `criterion` advances
    /// fastest, then `max_depth`, then `min_samples_leaf`, and
    /// `min_samples_split` slowest. If any candidate list is empty the grid
    /// contains no combinations, so `None` is returned immediately.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.decision_tree_classifier_search_parameters;

        // Checked lookups double as the termination test: once the grid is
        // exhausted every cursor equals its list length, and an empty list is
        // out of range from the start. Either way `?` ends the iteration
        // instead of panicking on an out-of-bounds index.
        let next = DecisionTreeClassifierParameters {
            criterion: grid.criterion.get(self.current_criterion)?.clone(),
            max_depth: *grid.max_depth.get(self.current_max_depth)?,
            min_samples_leaf: *grid.min_samples_leaf.get(self.current_min_samples_leaf)?,
            min_samples_split: *grid
                .min_samples_split
                .get(self.current_min_samples_split)?,
        };

        // Advance like an odometer: bump the first cursor that still has a
        // successor and reset every faster-moving cursor before it.
        if self.current_criterion + 1 < grid.criterion.len() {
            self.current_criterion += 1;
        } else if self.current_max_depth + 1 < grid.max_depth.len() {
            self.current_criterion = 0;
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1 < grid.min_samples_leaf.len() {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1 < grid.min_samples_split.len() {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else {
            // That was the last combination: park every cursor one past the
            // end so all later calls return `None`.
            self.current_criterion = grid.criterion.len();
            self.current_max_depth = grid.max_depth.len();
            self.current_min_samples_leaf = grid.min_samples_leaf.len();
            self.current_min_samples_split = grid.min_samples_split.len();
        }

        Some(next)
    }
}
impl Default for DecisionTreeClassifierSearchParameters {
fn default() -> Self {
let default_params = DecisionTreeClassifierParameters::default();
DecisionTreeClassifierSearchParameters {
criterion: vec![default_params.criterion],
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
}
}
}
impl<T: RealNumber> Node<T> {
fn new(index: usize, output: usize) -> Self {
Node {
@@ -651,6 +789,29 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidates for each of two hyperparameters; the rest stay at their
    // single default value, so the grid has four combinations.
    let grid = DecisionTreeClassifierSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };

    // `max_depth` is expected to cycle faster than `min_samples_split`.
    let expected: [(Option<u16>, usize); 4] = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];

    let mut combinations = grid.into_iter();
    for &(depth, split) in expected.iter() {
        let candidate = combinations.next().unwrap();
        assert_eq!(candidate.max_depth, depth);
        assert_eq!(candidate.min_samples_split, split);
    }
    assert!(combinations.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn gini_impurity() {
+137
View File
@@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters {
}
}
/// DecisionTreeRegressor grid search parameters.
///
/// Each field holds the list of candidate values to try for the same-named
/// field of `DecisionTreeRegressorParameters`. Converting this value into an
/// iterator yields one `DecisionTreeRegressorParameters` per combination of
/// the listed candidates.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressorSearchParameters {
/// Candidate values for the tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Vec<Option<u16>>,
/// Candidate values for the minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: Vec<usize>,
/// Candidate values for the minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: Vec<usize>,
}
/// DecisionTreeRegressor grid search iterator.
///
/// Owns the search grid together with one cursor per hyperparameter. Each
/// cursor is an index into the candidate list of the same name and selects the
/// value used for the next combination that is produced.
pub struct DecisionTreeRegressorSearchParametersIterator {
// The grid being walked; it is moved in by `into_iter`.
decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters,
// Index into `max_depth`.
current_max_depth: usize,
// Index into `min_samples_leaf`.
current_min_samples_leaf: usize,
// Index into `min_samples_split`.
current_min_samples_split: usize,
}
impl IntoIterator for DecisionTreeRegressorSearchParameters {
    type Item = DecisionTreeRegressorParameters;
    type IntoIter = DecisionTreeRegressorSearchParametersIterator;

    /// Consumes the search grid and returns an iterator over its parameter
    /// combinations, with every cursor positioned on the first candidate.
    fn into_iter(self) -> Self::IntoIter {
        let search_space = self;
        DecisionTreeRegressorSearchParametersIterator {
            current_min_samples_split: 0,
            current_min_samples_leaf: 0,
            current_max_depth: 0,
            decision_tree_regressor_search_parameters: search_space,
        }
    }
}
impl Iterator for DecisionTreeRegressorSearchParametersIterator {
    type Item = DecisionTreeRegressorParameters;

    /// Returns the next combination of candidate values, or `None` once the
    /// grid is exhausted.
    ///
    /// Combinations are produced in odometer order: `max_depth` advances
    /// fastest, then `min_samples_leaf`, and `min_samples_split` slowest. If
    /// any candidate list is empty the grid contains no combinations, so
    /// `None` is returned immediately.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.decision_tree_regressor_search_parameters;

        // Checked lookups double as the termination test: once the grid is
        // exhausted every cursor equals its list length, and an empty list is
        // out of range from the start. Either way `?` ends the iteration
        // instead of panicking on an out-of-bounds index.
        let next = DecisionTreeRegressorParameters {
            max_depth: *grid.max_depth.get(self.current_max_depth)?,
            min_samples_leaf: *grid.min_samples_leaf.get(self.current_min_samples_leaf)?,
            min_samples_split: *grid
                .min_samples_split
                .get(self.current_min_samples_split)?,
        };

        // Advance like an odometer: bump the first cursor that still has a
        // successor and reset every faster-moving cursor before it.
        if self.current_max_depth + 1 < grid.max_depth.len() {
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1 < grid.min_samples_leaf.len() {
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1 < grid.min_samples_split.len() {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else {
            // That was the last combination: park every cursor one past the
            // end so all later calls return `None`.
            self.current_max_depth = grid.max_depth.len();
            self.current_min_samples_leaf = grid.min_samples_leaf.len();
            self.current_min_samples_split = grid.min_samples_split.len();
        }

        Some(next)
    }
}
impl Default for DecisionTreeRegressorSearchParameters {
fn default() -> Self {
let default_params = DecisionTreeRegressorParameters::default();
DecisionTreeRegressorSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
}
}
}
impl<T: RealNumber> Node<T> {
fn new(index: usize, output: T) -> Self {
Node {
@@ -517,6 +631,29 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidates for each of two hyperparameters; the rest stay at their
    // single default value, so the grid has four combinations.
    let grid = DecisionTreeRegressorSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };

    // `max_depth` is expected to cycle faster than `min_samples_split`.
    let expected: [(Option<u16>, usize); 4] = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];

    let mut combinations = grid.into_iter();
    for &(depth, split) in expected.iter() {
        let candidate = combinations.next().unwrap();
        assert_eq!(candidate.max_depth, depth);
        assert_eq!(candidate.min_samples_split, split);
    }
    assert!(combinations.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_longley() {