add seed param to search params (#168)

This commit is contained in:
Montana Low
2022-09-21 16:15:26 -07:00
committed by GitHub
parent 3a44161406
commit 403d3f2348
4 changed files with 61 additions and 0 deletions
+20
View File
@@ -209,14 +209,21 @@ impl Default for DecisionTreeClassifierParameters {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: Vec<SplitCriterion>,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the randomness of the estimator
pub seed: Vec<Option<u64>>,
}
/// DecisionTreeClassifier grid search iterator
@@ -226,6 +233,7 @@ pub struct DecisionTreeClassifierSearchParametersIterator {
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_seed: usize,
}
impl IntoIterator for DecisionTreeClassifierSearchParameters {
@@ -239,6 +247,7 @@ impl IntoIterator for DecisionTreeClassifierSearchParameters {
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_seed: 0,
}
}
}
@@ -267,6 +276,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
.decision_tree_classifier_search_parameters
.min_samples_split
.len()
&& self.current_seed == self.decision_tree_classifier_search_parameters.seed.len()
{
return None;
}
@@ -283,6 +293,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
min_samples_split: self
.decision_tree_classifier_search_parameters
.min_samples_split[self.current_min_samples_split],
seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed],
};
if self.current_criterion + 1
@@ -319,11 +330,19 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_seed += 1;
} else {
self.current_criterion += 1;
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_seed += 1;
}
Some(next)
@@ -339,6 +358,7 @@ impl Default for DecisionTreeClassifierSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
seed: vec![default_params.seed],
}
}
}
+14
View File
@@ -148,6 +148,8 @@ pub struct DecisionTreeRegressorSearchParameters {
pub min_samples_leaf: Vec<usize>,
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: Vec<usize>,
/// Controls the randomness of the estimator
pub seed: Vec<Option<u64>>,
}
/// DecisionTreeRegressor grid search iterator
@@ -156,6 +158,7 @@ pub struct DecisionTreeRegressorSearchParametersIterator {
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_seed: usize,
}
impl IntoIterator for DecisionTreeRegressorSearchParameters {
@@ -168,6 +171,7 @@ impl IntoIterator for DecisionTreeRegressorSearchParameters {
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_seed: 0,
}
}
}
@@ -191,6 +195,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
.decision_tree_regressor_search_parameters
.min_samples_split
.len()
&& self.current_seed == self.decision_tree_regressor_search_parameters.seed.len()
{
return None;
}
@@ -204,6 +209,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
min_samples_split: self
.decision_tree_regressor_search_parameters
.min_samples_split[self.current_min_samples_split],
seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed],
};
if self.current_max_depth + 1
@@ -230,10 +236,17 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_seed += 1;
} else {
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_seed += 1;
}
Some(next)
@@ -248,6 +261,7 @@ impl Default for DecisionTreeRegressorSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
seed: vec![default_params.seed],
}
}
}