From 05dfffad5ce2f1dc42f00752ef00c6149c833c74 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 21 Sep 2022 16:15:26 -0700 Subject: [PATCH] add seed param to search params (#168) --- src/cluster/kmeans.rs | 13 +++++++++++++ src/svm/svc.rs | 14 ++++++++++++++ src/tree/decision_tree_classifier.rs | 20 ++++++++++++++++++++ src/tree/decision_tree_regressor.rs | 14 ++++++++++++++ 4 files changed, 61 insertions(+) diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index fee1425..404f7b0 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -145,6 +145,9 @@ pub struct KMeansSearchParameters { pub k: Vec, /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: Vec, + /// Determines random number generation for centroid initialization. + /// Use an int to make the randomness deterministic + pub seed: Vec>, } /// KMeans grid search iterator @@ -152,6 +155,7 @@ pub struct KMeansSearchParametersIterator { kmeans_search_parameters: KMeansSearchParameters, current_k: usize, current_max_iter: usize, + current_seed: usize, } impl IntoIterator for KMeansSearchParameters { @@ -163,6 +167,7 @@ impl IntoIterator for KMeansSearchParameters { kmeans_search_parameters: self, current_k: 0, current_max_iter: 0, + current_seed: 0, } } } @@ -173,6 +178,7 @@ impl Iterator for KMeansSearchParametersIterator { fn next(&mut self) -> Option { if self.current_k == self.kmeans_search_parameters.k.len() && self.current_max_iter == self.kmeans_search_parameters.max_iter.len() + && self.current_seed == self.kmeans_search_parameters.seed.len() { return None; } @@ -180,6 +186,7 @@ impl Iterator for KMeansSearchParametersIterator { let next = KMeansParameters { k: self.kmeans_search_parameters.k[self.current_k], max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter], + seed: self.kmeans_search_parameters.seed[self.current_seed], }; if self.current_k + 1 < self.kmeans_search_parameters.k.len() { @@ -187,9 +194,14 @@ impl Iterator for KMeansSearchParametersIterator { } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() { self.current_k = 0; self.current_max_iter += 1; + } else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() { + self.current_k = 0; + self.current_max_iter = 0; + self.current_seed += 1; } else { self.current_k += 1; self.current_max_iter += 1; + self.current_seed += 1; } Some(next) @@ -203,6 +215,7 @@ impl Default for KMeansSearchParameters { KMeansSearchParameters { k: vec![default_params.k], max_iter: vec![default_params.max_iter], + seed: vec![default_params.seed], } } } diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 94c6d9e..d390866 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -119,6 +119,8 @@ pub struct SVCSearchParameters, K: Kernel, /// Unused parameter. m: PhantomData, + /// Controls the pseudo random number generation for shuffling the data for probability estimates + seed: Vec>, } /// SVC grid search iterator @@ -128,6 +130,7 @@ pub struct SVCSearchParametersIterator, K: Kernel, K: Kernel> IntoIterator @@ -143,6 +146,7 @@ impl, K: Kernel> IntoIterator current_c: 0, current_tol: 0, current_kernel: 0, + current_seed: 0, } } } @@ -157,6 +161,7 @@ impl, K: Kernel> Iterator && self.current_c == self.svc_search_parameters.c.len() && self.current_tol == self.svc_search_parameters.tol.len() && self.current_kernel == self.svc_search_parameters.kernel.len() + && self.current_seed == self.svc_search_parameters.kernel.len() { return None; } @@ -167,6 +172,7 @@ impl, K: Kernel> Iterator tol: self.svc_search_parameters.tol[self.current_tol], kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), m: PhantomData, + seed: self.svc_search_parameters.seed[self.current_seed], }; if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { @@ -183,11 +189,18 @@ impl, K: Kernel> Iterator self.current_c = 0; self.current_tol = 0; self.current_kernel += 1; + } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { + self.current_epoch = 0; + self.current_c = 0; + self.current_tol = 0; + self.current_kernel = 0; + self.current_seed += 1; } else { self.current_epoch += 1; self.current_c += 1; self.current_tol += 1; self.current_kernel += 1; + self.current_seed += 1; } Some(next) @@ -204,6 +217,7 @@ impl> Default for SVCSearchParameters, + #[cfg_attr(feature = "serde", serde(default))] /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub max_depth: Vec>, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_leaf: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_split: Vec, + #[cfg_attr(feature = "serde", serde(default))] + /// Controls the randomness of the estimator + pub seed: Vec>, } /// DecisionTreeClassifier grid search iterator @@ -226,6 +233,7 @@ pub struct DecisionTreeClassifierSearchParametersIterator { current_max_depth: usize, current_min_samples_leaf: usize, current_min_samples_split: usize, + current_seed: usize, } impl IntoIterator for DecisionTreeClassifierSearchParameters { @@ -239,6 +247,7 @@ impl IntoIterator for DecisionTreeClassifierSearchParameters { current_max_depth: 0, current_min_samples_leaf: 0, current_min_samples_split: 0, + current_seed: 0, } } } @@ -267,6 +276,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { .decision_tree_classifier_search_parameters .min_samples_split .len() + && self.current_seed == self.decision_tree_classifier_search_parameters.seed.len() { return None; } @@ -283,6 +293,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { min_samples_split: self .decision_tree_classifier_search_parameters .min_samples_split[self.current_min_samples_split], + seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed], }; if self.current_criterion + 1 @@ -319,11 +330,19 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { self.current_max_depth = 0; self.current_min_samples_leaf = 0; self.current_min_samples_split += 1; + } else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_seed += 1; } else { self.current_criterion += 1; self.current_max_depth += 1; self.current_min_samples_leaf += 1; self.current_min_samples_split += 1; + self.current_seed += 1; } Some(next) @@ -339,6 +358,7 @@ impl Default for DecisionTreeClassifierSearchParameters { max_depth: vec![default_params.max_depth], min_samples_leaf: vec![default_params.min_samples_leaf], min_samples_split: vec![default_params.min_samples_split], + seed: vec![default_params.seed], } } } diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 7d88c40..12bb9c9 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -148,6 +148,8 @@ pub struct DecisionTreeRegressorSearchParameters { pub min_samples_leaf: Vec, /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) pub min_samples_split: Vec, + /// Controls the randomness of the estimator + pub seed: Vec>, } /// DecisionTreeRegressor grid search iterator @@ -156,6 +158,7 @@ pub struct DecisionTreeRegressorSearchParametersIterator { current_max_depth: usize, current_min_samples_leaf: usize, current_min_samples_split: usize, + current_seed: usize, } impl IntoIterator for DecisionTreeRegressorSearchParameters { @@ -168,6 +171,7 @@ impl IntoIterator for DecisionTreeRegressorSearchParameters { current_max_depth: 0, current_min_samples_leaf: 0, current_min_samples_split: 0, + current_seed: 0, } } } @@ -191,6 +195,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { .decision_tree_regressor_search_parameters .min_samples_split .len() + && self.current_seed == self.decision_tree_regressor_search_parameters.seed.len() { return None; } @@ -204,6 +209,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { min_samples_split: self .decision_tree_regressor_search_parameters .min_samples_split[self.current_min_samples_split], + seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed], }; if self.current_max_depth + 1 @@ -230,10 +236,17 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { self.current_max_depth = 0; self.current_min_samples_leaf = 0; self.current_min_samples_split += 1; + } else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_seed += 1; } else { self.current_max_depth += 1; self.current_min_samples_leaf += 1; self.current_min_samples_split += 1; + self.current_seed += 1; } Some(next) @@ -248,6 +261,7 @@ impl Default for DecisionTreeRegressorSearchParameters { max_depth: vec![default_params.max_depth], min_samples_leaf: vec![default_params.min_samples_leaf], min_samples_split: vec![default_params.min_samples_split], + seed: vec![default_params.seed], } } }