Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecisionTreeClassifierSearchParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search iterator
|
||||
pub struct DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
type IntoIter = DecisionTreeClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DecisionTreeClassifierParameters {
|
||||
criterion: self.decision_tree_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.decision_tree_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = DecisionTreeClassifierParameters::default();
|
||||
|
||||
DecisionTreeClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: usize) -> Self {
|
||||
Node {
|
||||
@@ -651,6 +789,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = DecisionTreeClassifierSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };

    // Expected (max_depth, min_samples_split) pairs in iteration order:
    // max_depth cycles through its candidates before min_samples_split advances.
    let expected: [(Option<u16>, usize); 4] = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];

    let mut iter = parameters.into_iter();
    for &(max_depth, min_samples_split) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.max_depth, max_depth);
        assert_eq!(next.min_samples_split, min_samples_split);
    }

    // The grid has exactly four combinations.
    assert!(iter.next().is_none());
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
|
||||
Reference in New Issue
Block a user