From 55e11585812b6bd25c16f0693f42cf04a5304838 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 21 Sep 2022 12:34:21 -0700 Subject: [PATCH] Complete grid search params (#166) * grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup --- src/cluster/dbscan.rs | 120 +++++++++++ src/cluster/kmeans.rs | 93 +++++++++ src/decomposition/pca.rs | 98 +++++++++ src/decomposition/svd.rs | 68 +++++++ src/ensemble/random_forest_classifier.rs | 243 +++++++++++++++++++++++ src/ensemble/random_forest_regressor.rs | 208 +++++++++++++++++++ src/linear/lasso.rs | 31 ++- src/model_selection/hyper_tuning.rs | 2 +- src/model_selection/mod.rs | 1 - src/naive_bayes/bernoulli.rs | 96 +++++++++ src/naive_bayes/categorical.rs | 68 +++++++ src/naive_bayes/gaussian.rs | 76 ++++++- src/naive_bayes/multinomial.rs | 84 ++++++++ src/svm/mod.rs | 10 +- src/svm/svc.rs | 121 +++++++++++ src/svm/svr.rs | 121 +++++++++++ src/tree/decision_tree_classifier.rs | 161 +++++++++++++++ src/tree/decision_tree_regressor.rs | 137 +++++++++++++ 18 files changed, 1713 insertions(+), 25 deletions(-) diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index 7f2baef..621d017 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -109,6 +109,103 @@ impl, T>> DBSCANParameters { } } +/// DBSCAN grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct DBSCANSearchParameters, T>> { + /// a function that defines a distance between each pair of point in training data. + /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. + /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. + pub distance: Vec, + /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. 
+ pub min_samples: Vec, + /// The maximum distance between two samples for one to be considered as in the neighborhood of the other. + pub eps: Vec, + /// KNN algorithm to use. + pub algorithm: Vec, +} + +/// DBSCAN grid search iterator +pub struct DBSCANSearchParametersIterator, T>> { + dbscan_search_parameters: DBSCANSearchParameters, + current_distance: usize, + current_min_samples: usize, + current_eps: usize, + current_algorithm: usize, +} + +impl, T>> IntoIterator for DBSCANSearchParameters { + type Item = DBSCANParameters; + type IntoIter = DBSCANSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DBSCANSearchParametersIterator { + dbscan_search_parameters: self, + current_distance: 0, + current_min_samples: 0, + current_eps: 0, + current_algorithm: 0, + } + } +} + +impl, T>> Iterator for DBSCANSearchParametersIterator { + type Item = DBSCANParameters; + + fn next(&mut self) -> Option { + if self.current_distance == self.dbscan_search_parameters.distance.len() + && self.current_min_samples == self.dbscan_search_parameters.min_samples.len() + && self.current_eps == self.dbscan_search_parameters.eps.len() + && self.current_algorithm == self.dbscan_search_parameters.algorithm.len() + { + return None; + } + + let next = DBSCANParameters { + distance: self.dbscan_search_parameters.distance[self.current_distance].clone(), + min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples], + eps: self.dbscan_search_parameters.eps[self.current_eps], + algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(), + }; + + if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() { + self.current_distance += 1; + } else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() { + self.current_distance = 0; + self.current_min_samples += 1; + } else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() { + self.current_distance = 0; + self.current_min_samples 
= 0; + self.current_eps += 1; + } else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() { + self.current_distance = 0; + self.current_min_samples = 0; + self.current_eps = 0; + self.current_algorithm += 1; + } else { + self.current_distance += 1; + self.current_min_samples += 1; + self.current_eps += 1; + self.current_algorithm += 1; + } + + Some(next) + } +} + +impl Default for DBSCANSearchParameters { + fn default() -> Self { + let default_params = DBSCANParameters::default(); + + DBSCANSearchParameters { + distance: vec![default_params.distance], + min_samples: vec![default_params.min_samples], + eps: vec![default_params.eps], + algorithm: vec![default_params.algorithm], + } + } +} + impl, T>> PartialEq for DBSCAN { fn eq(&self, other: &Self) -> bool { self.cluster_labels.len() == other.cluster_labels.len() @@ -268,6 +365,29 @@ mod tests { #[cfg(feature = "serde")] use crate::math::distance::euclidian::Euclidian; + #[test] + fn search_parameters() { + let parameters = DBSCANSearchParameters { + min_samples: vec![10, 100], + eps: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 10); + assert_eq!(next.eps, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 100); + assert_eq!(next.eps, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 10); + assert_eq!(next.eps, 2.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 100); + assert_eq!(next.eps, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_dbscan() { diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 05af680..8ecbb2e 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -132,6 +132,76 @@ impl Default for KMeansParameters { } } +/// KMeans grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, 
Deserialize))] +#[derive(Debug, Clone)] +pub struct KMeansSearchParameters { + /// Number of clusters. + pub k: Vec, + /// Maximum number of iterations of the k-means algorithm for a single run. + pub max_iter: Vec, +} + +/// KMeans grid search iterator +pub struct KMeansSearchParametersIterator { + kmeans_search_parameters: KMeansSearchParameters, + current_k: usize, + current_max_iter: usize, +} + +impl IntoIterator for KMeansSearchParameters { + type Item = KMeansParameters; + type IntoIter = KMeansSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + KMeansSearchParametersIterator { + kmeans_search_parameters: self, + current_k: 0, + current_max_iter: 0, + } + } +} + +impl Iterator for KMeansSearchParametersIterator { + type Item = KMeansParameters; + + fn next(&mut self) -> Option { + if self.current_k == self.kmeans_search_parameters.k.len() + && self.current_max_iter == self.kmeans_search_parameters.max_iter.len() + { + return None; + } + + let next = KMeansParameters { + k: self.kmeans_search_parameters.k[self.current_k], + max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter], + }; + + if self.current_k + 1 < self.kmeans_search_parameters.k.len() { + self.current_k += 1; + } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() { + self.current_k = 0; + self.current_max_iter += 1; + } else { + self.current_k += 1; + self.current_max_iter += 1; + } + + Some(next) + } +} + +impl Default for KMeansSearchParameters { + fn default() -> Self { + let default_params = KMeansParameters::default(); + + KMeansSearchParameters { + k: vec![default_params.k], + max_iter: vec![default_params.max_iter], + } + } +} + impl> UnsupervisedEstimator for KMeans { fn fit(x: &M, parameters: KMeansParameters) -> Result { KMeans::fit(x, parameters) @@ -313,6 +383,29 @@ mod tests { ); } + #[test] + fn search_parameters() { + let parameters = KMeansSearchParameters { + k: vec![2, 4], + max_iter: vec![10, 100], + 
..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.k, 2); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.k, 4); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.k, 2); + assert_eq!(next.max_iter, 100); + let next = iter.next().unwrap(); + assert_eq!(next.k, 4); + assert_eq!(next.max_iter, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs index 9aebae2..296926a 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -116,6 +116,81 @@ impl Default for PCAParameters { } } +/// PCA grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct PCASearchParameters { + /// Number of components to keep. + pub n_components: Vec, + /// By default, covariance matrix is used to compute principal components. + /// Enable this flag if you want to use correlation matrix instead. 
+ pub use_correlation_matrix: Vec, +} + +/// PCA grid search iterator +pub struct PCASearchParametersIterator { + pca_search_parameters: PCASearchParameters, + current_k: usize, + current_use_correlation_matrix: usize, +} + +impl IntoIterator for PCASearchParameters { + type Item = PCAParameters; + type IntoIter = PCASearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + PCASearchParametersIterator { + pca_search_parameters: self, + current_k: 0, + current_use_correlation_matrix: 0, + } + } +} + +impl Iterator for PCASearchParametersIterator { + type Item = PCAParameters; + + fn next(&mut self) -> Option { + if self.current_k == self.pca_search_parameters.n_components.len() + && self.current_use_correlation_matrix + == self.pca_search_parameters.use_correlation_matrix.len() + { + return None; + } + + let next = PCAParameters { + n_components: self.pca_search_parameters.n_components[self.current_k], + use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix + [self.current_use_correlation_matrix], + }; + + if self.current_k + 1 < self.pca_search_parameters.n_components.len() { + self.current_k += 1; + } else if self.current_use_correlation_matrix + 1 + < self.pca_search_parameters.use_correlation_matrix.len() + { + self.current_k = 0; + self.current_use_correlation_matrix += 1; + } else { + self.current_k += 1; + self.current_use_correlation_matrix += 1; + } + + Some(next) + } +} + +impl Default for PCASearchParameters { + fn default() -> Self { + let default_params = PCAParameters::default(); + + PCASearchParameters { + n_components: vec![default_params.n_components], + use_correlation_matrix: vec![default_params.use_correlation_matrix], + } + } +} + impl> UnsupervisedEstimator for PCA { fn fit(x: &M, parameters: PCAParameters) -> Result { PCA::fit(x, parameters) @@ -271,6 +346,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[test] + fn search_parameters() { + let parameters = PCASearchParameters { + 
n_components: vec![2, 4], + use_correlation_matrix: vec![true, false], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 2); + assert_eq!(next.use_correlation_matrix, true); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 4); + assert_eq!(next.use_correlation_matrix, true); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 2); + assert_eq!(next.use_correlation_matrix, false); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 4); + assert_eq!(next.use_correlation_matrix, false); + assert!(iter.next().is_none()); + } + fn us_arrests_data() -> DenseMatrix { DenseMatrix::from_2d_array(&[ &[13.2, 236.0, 58.0, 21.2], diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index 3807760..3001fd9 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -90,6 +90,60 @@ impl SVDParameters { } } +/// SVD grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVDSearchParameters { + /// Number of components to keep. 
+ pub n_components: Vec, +} + +/// SVD grid search iterator +pub struct SVDSearchParametersIterator { + svd_search_parameters: SVDSearchParameters, + current_n_components: usize, +} + +impl IntoIterator for SVDSearchParameters { + type Item = SVDParameters; + type IntoIter = SVDSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVDSearchParametersIterator { + svd_search_parameters: self, + current_n_components: 0, + } + } +} + +impl Iterator for SVDSearchParametersIterator { + type Item = SVDParameters; + + fn next(&mut self) -> Option { + if self.current_n_components == self.svd_search_parameters.n_components.len() { + return None; + } + + let next = SVDParameters { + n_components: self.svd_search_parameters.n_components[self.current_n_components], + }; + + self.current_n_components += 1; + + Some(next) + } +} + +impl Default for SVDSearchParameters { + fn default() -> Self { + let default_params = SVDParameters::default(); + + SVDSearchParameters { + n_components: vec![default_params.n_components], + } + } +} + impl> UnsupervisedEstimator for SVD { fn fit(x: &M, parameters: SVDParameters) -> Result { SVD::fit(x, parameters) @@ -153,6 +207,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[test] + fn search_parameters() { + let parameters = SVDSearchParameters { + n_components: vec![10, 100], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 10); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svd_decompose() { diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 247b502..a4d6e75 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -193,6 +193,226 @@ impl> Predictor for RandomForestCla 
} } +/// RandomForestClassifier grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct RandomForestClassifierSearchParameters { + /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub criterion: Vec, + /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_split: Vec, + /// The number of trees in the forest. + pub n_trees: Vec, + /// Number of random sample of predictors to use as split candidates. + pub m: Vec>, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: Vec, + /// Seed used for bootstrap sampling and feature selection for each tree. 
+ pub seed: Vec, +} + +/// RandomForestClassifier grid search iterator +pub struct RandomForestClassifierSearchParametersIterator { + random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters, + current_criterion: usize, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, + current_n_trees: usize, + current_m: usize, + current_keep_samples: usize, + current_seed: usize, +} + +impl IntoIterator for RandomForestClassifierSearchParameters { + type Item = RandomForestClassifierParameters; + type IntoIter = RandomForestClassifierSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + RandomForestClassifierSearchParametersIterator { + random_forest_classifier_search_parameters: self, + current_criterion: 0, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + current_n_trees: 0, + current_m: 0, + current_keep_samples: 0, + current_seed: 0, + } + } +} + +impl Iterator for RandomForestClassifierSearchParametersIterator { + type Item = RandomForestClassifierParameters; + + fn next(&mut self) -> Option { + if self.current_criterion + == self + .random_forest_classifier_search_parameters + .criterion + .len() + && self.current_max_depth + == self + .random_forest_classifier_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .random_forest_classifier_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .random_forest_classifier_search_parameters + .min_samples_split + .len() + && self.current_n_trees + == self + .random_forest_classifier_search_parameters + .n_trees + .len() + && self.current_m == self.random_forest_classifier_search_parameters.m.len() + && self.current_keep_samples + == self + .random_forest_classifier_search_parameters + .keep_samples + .len() + && self.current_seed == self.random_forest_classifier_search_parameters.seed.len() + { + return None; + } + + let 
next = RandomForestClassifierParameters { + criterion: self.random_forest_classifier_search_parameters.criterion + [self.current_criterion] + .clone(), + max_depth: self.random_forest_classifier_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .random_forest_classifier_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .random_forest_classifier_search_parameters + .min_samples_split[self.current_min_samples_split], + n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees], + m: self.random_forest_classifier_search_parameters.m[self.current_m], + keep_samples: self.random_forest_classifier_search_parameters.keep_samples + [self.current_keep_samples], + seed: self.random_forest_classifier_search_parameters.seed[self.current_seed], + }; + + if self.current_criterion + 1 + < self + .random_forest_classifier_search_parameters + .criterion + .len() + { + self.current_criterion += 1; + } else if self.current_max_depth + 1 + < self + .random_forest_classifier_search_parameters + .max_depth + .len() + { + self.current_criterion = 0; + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .random_forest_classifier_search_parameters + .min_samples_leaf + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .random_forest_classifier_search_parameters + .min_samples_split + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else if self.current_n_trees + 1 + < self + .random_forest_classifier_search_parameters + .n_trees + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees += 1; + } else if self.current_m + 1 < 
self.random_forest_classifier_search_parameters.m.len() { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m += 1; + } else if self.current_keep_samples + 1 + < self + .random_forest_classifier_search_parameters + .keep_samples + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples += 1; + } else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples = 0; + self.current_seed += 1; + } else { + self.current_criterion += 1; + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + self.current_n_trees += 1; + self.current_m += 1; + self.current_keep_samples += 1; + self.current_seed += 1; + } + + Some(next) + } +} + +impl Default for RandomForestClassifierSearchParameters { + fn default() -> Self { + let default_params = RandomForestClassifierParameters::default(); + + RandomForestClassifierSearchParameters { + criterion: vec![default_params.criterion], + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + n_trees: vec![default_params.n_trees], + m: vec![default_params.m], + keep_samples: vec![default_params.keep_samples], + seed: vec![default_params.seed], + } + } +} + impl RandomForestClassifier { /// Build a forest of trees from the training set. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. 
@@ -346,6 +566,29 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::*; + #[test] + fn search_parameters() { + let parameters = RandomForestClassifierSearchParameters { + n_trees: vec![10, 100], + m: vec![None, Some(1)], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, Some(1)); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, Some(1)); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 08a7dcc..ec78137 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -176,6 +176,191 @@ impl> Predictor for RandomForestReg } } +/// RandomForestRegressor grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct RandomForestRegressorSearchParameters { + /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_split: Vec, + /// The number of trees in the forest. + pub n_trees: Vec, + /// Number of random sample of predictors to use as split candidates. 
+ pub m: Vec>, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: Vec, + /// Seed used for bootstrap sampling and feature selection for each tree. + pub seed: Vec, +} + +/// RandomForestRegressor grid search iterator +pub struct RandomForestRegressorSearchParametersIterator { + random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, + current_n_trees: usize, + current_m: usize, + current_keep_samples: usize, + current_seed: usize, +} + +impl IntoIterator for RandomForestRegressorSearchParameters { + type Item = RandomForestRegressorParameters; + type IntoIter = RandomForestRegressorSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + RandomForestRegressorSearchParametersIterator { + random_forest_regressor_search_parameters: self, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + current_n_trees: 0, + current_m: 0, + current_keep_samples: 0, + current_seed: 0, + } + } +} + +impl Iterator for RandomForestRegressorSearchParametersIterator { + type Item = RandomForestRegressorParameters; + + fn next(&mut self) -> Option { + if self.current_max_depth + == self + .random_forest_regressor_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .random_forest_regressor_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .random_forest_regressor_search_parameters + .min_samples_split + .len() + && self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len() + && self.current_m == self.random_forest_regressor_search_parameters.m.len() + && self.current_keep_samples + == self + .random_forest_regressor_search_parameters + .keep_samples + .len() + && self.current_seed == self.random_forest_regressor_search_parameters.seed.len() + { + return 
None; + } + + let next = RandomForestRegressorParameters { + max_depth: self.random_forest_regressor_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .random_forest_regressor_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .random_forest_regressor_search_parameters + .min_samples_split[self.current_min_samples_split], + n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees], + m: self.random_forest_regressor_search_parameters.m[self.current_m], + keep_samples: self.random_forest_regressor_search_parameters.keep_samples + [self.current_keep_samples], + seed: self.random_forest_regressor_search_parameters.seed[self.current_seed], + }; + + if self.current_max_depth + 1 + < self + .random_forest_regressor_search_parameters + .max_depth + .len() + { + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .random_forest_regressor_search_parameters + .min_samples_leaf + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .random_forest_regressor_search_parameters + .min_samples_split + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else if self.current_n_trees + 1 + < self.random_forest_regressor_search_parameters.n_trees.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees += 1; + } else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m += 1; + } else if self.current_keep_samples + 1 + < self + .random_forest_regressor_search_parameters + .keep_samples + .len() + { + self.current_max_depth = 0; + 
self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples += 1; + } else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples = 0; + self.current_seed += 1; + } else { + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + self.current_n_trees += 1; + self.current_m += 1; + self.current_keep_samples += 1; + self.current_seed += 1; + } + + Some(next) + } +} + +impl Default for RandomForestRegressorSearchParameters { + fn default() -> Self { + let default_params = RandomForestRegressorParameters::default(); + + RandomForestRegressorSearchParameters { + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + n_trees: vec![default_params.n_trees], + m: vec![default_params.m], + keep_samples: vec![default_params.keep_samples], + seed: vec![default_params.seed], + } + } +} + impl RandomForestRegressor { /// Build a forest of trees from the training set. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. 
@@ -302,6 +487,29 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::mean_absolute_error; + #[test] + fn search_parameters() { + let parameters = RandomForestRegressorSearchParameters { + n_trees: vec![10, 100], + m: vec![None, Some(1)], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, Some(1)); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, Some(1)); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() { diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 7e80a8b..aae7e50 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -129,7 +129,7 @@ pub struct LassoSearchParameters { /// Lasso grid search iterator pub struct LassoSearchParametersIterator { - lasso_regression_search_parameters: LassoSearchParameters, + lasso_search_parameters: LassoSearchParameters, current_alpha: usize, current_normalize: usize, current_tol: usize, @@ -142,7 +142,7 @@ impl IntoIterator for LassoSearchParameters { fn into_iter(self) -> Self::IntoIter { LassoSearchParametersIterator { - lasso_regression_search_parameters: self, + lasso_search_parameters: self, current_alpha: 0, current_normalize: 0, current_tol: 0, @@ -155,34 +155,31 @@ impl Iterator for LassoSearchParametersIterator { type Item = LassoParameters; fn next(&mut self) -> Option { - if self.current_alpha == self.lasso_regression_search_parameters.alpha.len() - && self.current_normalize == self.lasso_regression_search_parameters.normalize.len() - && self.current_tol == self.lasso_regression_search_parameters.tol.len() - && self.current_max_iter == 
self.lasso_regression_search_parameters.max_iter.len() + if self.current_alpha == self.lasso_search_parameters.alpha.len() + && self.current_normalize == self.lasso_search_parameters.normalize.len() + && self.current_tol == self.lasso_search_parameters.tol.len() + && self.current_max_iter == self.lasso_search_parameters.max_iter.len() { return None; } let next = LassoParameters { - alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha], - normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize], - tol: self.lasso_regression_search_parameters.tol[self.current_tol], - max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter], + alpha: self.lasso_search_parameters.alpha[self.current_alpha], + normalize: self.lasso_search_parameters.normalize[self.current_normalize], + tol: self.lasso_search_parameters.tol[self.current_tol], + max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter], }; - if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() { + if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() { self.current_alpha += 1; - } else if self.current_normalize + 1 - < self.lasso_regression_search_parameters.normalize.len() - { + } else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() { self.current_alpha = 0; self.current_normalize += 1; - } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() { + } else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() { self.current_alpha = 0; self.current_normalize = 0; self.current_tol += 1; - } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len() - { + } else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() { self.current_alpha = 0; self.current_normalize = 0; self.current_tol = 0; diff --git a/src/model_selection/hyper_tuning.rs b/src/model_selection/hyper_tuning.rs index 
3093fbd..cb69da1 100644 --- a/src/model_selection/hyper_tuning.rs +++ b/src/model_selection/hyper_tuning.rs @@ -114,4 +114,4 @@ mod tests { assert!([0., 1.].contains(&results.parameters.alpha)); } -} \ No newline at end of file +} diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 68f0635..6f737d6 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -281,7 +281,6 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; - use crate::metrics::{accuracy, mean_absolute_error}; use crate::model_selection::kfold::KFold; use crate::neighbors::knn_regressor::KNNRegressor; diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 95c4d36..29c6c84 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -150,6 +150,88 @@ impl Default for BernoulliNBParameters { } } +/// BernoulliNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct BernoulliNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). + pub alpha: Vec, + /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data + pub priors: Vec>>, + /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors. 
+ pub binarize: Vec>, +} + +/// BernoulliNB grid search iterator +pub struct BernoulliNBSearchParametersIterator { + bernoulli_nb_search_parameters: BernoulliNBSearchParameters, + current_alpha: usize, + current_priors: usize, + current_binarize: usize, +} + +impl IntoIterator for BernoulliNBSearchParameters { + type Item = BernoulliNBParameters; + type IntoIter = BernoulliNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + BernoulliNBSearchParametersIterator { + bernoulli_nb_search_parameters: self, + current_alpha: 0, + current_priors: 0, + current_binarize: 0, + } + } +} + +impl Iterator for BernoulliNBSearchParametersIterator { + type Item = BernoulliNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len() + && self.current_priors == self.bernoulli_nb_search_parameters.priors.len() + && self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len() + { + return None; + } + + let next = BernoulliNBParameters { + alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha], + priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(), + binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize], + }; + + if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() { + self.current_alpha = 0; + self.current_priors += 1; + } else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() { + self.current_alpha = 0; + self.current_priors = 0; + self.current_binarize += 1; + } else { + self.current_alpha += 1; + self.current_priors += 1; + self.current_binarize += 1; + } + + Some(next) + } +} + +impl Default for BernoulliNBSearchParameters { + fn default() -> Self { + let default_params = BernoulliNBParameters::default(); + + BernoulliNBSearchParameters { + alpha: 
vec![default_params.alpha], + priors: vec![default_params.priors], + binarize: vec![default_params.binarize], + } + } +} + impl BernoulliNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. @@ -347,6 +429,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = BernoulliNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_bernoulli_naive_bayes() { diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 8706702..7855688 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -261,6 +261,60 @@ impl Default for CategoricalNBParameters { } } +/// CategoricalNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct CategoricalNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). 
+ pub alpha: Vec, +} + +/// CategoricalNB grid search iterator +pub struct CategoricalNBSearchParametersIterator { + categorical_nb_search_parameters: CategoricalNBSearchParameters, + current_alpha: usize, +} + +impl IntoIterator for CategoricalNBSearchParameters { + type Item = CategoricalNBParameters; + type IntoIter = CategoricalNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + CategoricalNBSearchParametersIterator { + categorical_nb_search_parameters: self, + current_alpha: 0, + } + } +} + +impl Iterator for CategoricalNBSearchParametersIterator { + type Item = CategoricalNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() { + return None; + } + + let next = CategoricalNBParameters { + alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha], + }; + + self.current_alpha += 1; + + Some(next) + } +} + +impl Default for CategoricalNBSearchParameters { + fn default() -> Self { + let default_params = CategoricalNBParameters::default(); + + CategoricalNBSearchParameters { + alpha: vec![default_params.alpha], + } + } +} + /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. 
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, PartialEq)] @@ -351,6 +405,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = CategoricalNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_categorical_naive_bayes() { diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index bd23919..24bbdd3 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -76,7 +76,7 @@ impl> NBDistribution for GaussianNBDistributio /// `GaussianNB` parameters. Use `Default::default()` for default values. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] pub struct GaussianNBParameters { /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, @@ -90,6 +90,66 @@ impl GaussianNBParameters { } } +impl Default for GaussianNBParameters { + fn default() -> Self { + Self { priors: None } + } +} + +/// GaussianNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct GaussianNBSearchParameters { + /// Prior probabilities of the classes. 
If specified the priors are not adjusted according to the data + pub priors: Vec>>, +} + +/// GaussianNB grid search iterator +pub struct GaussianNBSearchParametersIterator { + gaussian_nb_search_parameters: GaussianNBSearchParameters, + current_priors: usize, +} + +impl IntoIterator for GaussianNBSearchParameters { + type Item = GaussianNBParameters; + type IntoIter = GaussianNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + GaussianNBSearchParametersIterator { + gaussian_nb_search_parameters: self, + current_priors: 0, + } + } +} + +impl Iterator for GaussianNBSearchParametersIterator { + type Item = GaussianNBParameters; + + fn next(&mut self) -> Option { + if self.current_priors == self.gaussian_nb_search_parameters.priors.len() { + return None; + } + + let next = GaussianNBParameters { + priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(), + }; + + self.current_priors += 1; + + Some(next) + } +} + +impl Default for GaussianNBSearchParameters { + fn default() -> Self { + let default_params = GaussianNBParameters::default(); + + GaussianNBSearchParameters { + priors: vec![default_params.priors], + } + } +} + impl GaussianNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. 
@@ -260,6 +320,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = GaussianNBSearchParameters { + priors: vec![Some(vec![1.]), Some(vec![2.])], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.priors, Some(vec![1.])); + let next = iter.next().unwrap(); + assert_eq!(next.priors, Some(vec![2.])); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_gaussian_naive_bayes() { diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index f42b99e..6e846c1 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -114,6 +114,76 @@ impl Default for MultinomialNBParameters { } } +/// MultinomialNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct MultinomialNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). + pub alpha: Vec, + /// Prior probabilities of the classes. 
If specified the priors are not adjusted according to the data + pub priors: Vec>>, +} + +/// MultinomialNB grid search iterator +pub struct MultinomialNBSearchParametersIterator { + multinomial_nb_search_parameters: MultinomialNBSearchParameters, + current_alpha: usize, + current_priors: usize, +} + +impl IntoIterator for MultinomialNBSearchParameters { + type Item = MultinomialNBParameters; + type IntoIter = MultinomialNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + MultinomialNBSearchParametersIterator { + multinomial_nb_search_parameters: self, + current_alpha: 0, + current_priors: 0, + } + } +} + +impl Iterator for MultinomialNBSearchParametersIterator { + type Item = MultinomialNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.multinomial_nb_search_parameters.alpha.len() + && self.current_priors == self.multinomial_nb_search_parameters.priors.len() + { + return None; + } + + let next = MultinomialNBParameters { + alpha: self.multinomial_nb_search_parameters.alpha[self.current_alpha], + priors: self.multinomial_nb_search_parameters.priors[self.current_priors].clone(), + }; + + if self.current_alpha + 1 < self.multinomial_nb_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_priors + 1 < self.multinomial_nb_search_parameters.priors.len() { + self.current_alpha = 0; + self.current_priors += 1; + } else { + self.current_alpha += 1; + self.current_priors += 1; + } + + Some(next) + } +} + +impl Default for MultinomialNBSearchParameters { + fn default() -> Self { + let default_params = MultinomialNBParameters::default(); + + MultinomialNBSearchParameters { + alpha: vec![default_params.alpha], + priors: vec![default_params.priors], + } + } +} + impl MultinomialNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. 
@@ -297,6 +367,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = MultinomialNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_multinomial_naive_bayes() { diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 55df584..4c71b3f 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -33,7 +33,7 @@ use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Defines a kernel function -pub trait Kernel> { +pub trait Kernel>: Clone { /// Apply kernel function to x_i and x_j fn apply(&self, x_i: &V, x_j: &V) -> T; } @@ -95,12 +95,12 @@ impl Kernels { /// Linear Kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct LinearKernel {} /// Radial basis function (Gaussian) kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct RBFKernel { /// kernel coefficient pub gamma: T, @@ -108,7 +108,7 @@ pub struct RBFKernel { /// Polynomial kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PolynomialKernel { /// degree of the polynomial pub degree: T, @@ -120,7 +120,7 @@ pub struct PolynomialKernel { /// Sigmoid (hyperbolic tangent) kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct SigmoidKernel { /// kernel coefficient pub gamma: T, diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 87fb743..46b0b68 100644 --- 
a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -102,6 +102,109 @@ pub struct SVCParameters, K: Kernel m: PhantomData, } +/// SVC grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVCSearchParameters, K: Kernel> { + /// Number of epochs. + pub epoch: Vec, + /// Regularization parameter. + pub c: Vec, + /// Tolerance for stopping epoch. + pub tol: Vec, + /// The kernel function. + pub kernel: Vec, + /// Unused parameter. + m: PhantomData, +} + +/// SVC grid search iterator +pub struct SVCSearchParametersIterator, K: Kernel> { + svc_search_parameters: SVCSearchParameters, + current_epoch: usize, + current_c: usize, + current_tol: usize, + current_kernel: usize, +} + +impl, K: Kernel> IntoIterator + for SVCSearchParameters +{ + type Item = SVCParameters; + type IntoIter = SVCSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVCSearchParametersIterator { + svc_search_parameters: self, + current_epoch: 0, + current_c: 0, + current_tol: 0, + current_kernel: 0, + } + } +} + +impl, K: Kernel> Iterator + for SVCSearchParametersIterator +{ + type Item = SVCParameters; + + fn next(&mut self) -> Option { + if self.current_epoch == self.svc_search_parameters.epoch.len() + && self.current_c == self.svc_search_parameters.c.len() + && self.current_tol == self.svc_search_parameters.tol.len() + && self.current_kernel == self.svc_search_parameters.kernel.len() + { + return None; + } + + let next = SVCParameters:: { + epoch: self.svc_search_parameters.epoch[self.current_epoch], + c: self.svc_search_parameters.c[self.current_c], + tol: self.svc_search_parameters.tol[self.current_tol], + kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), + m: PhantomData, + }; + + if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { + self.current_epoch += 1; + } else if self.current_c + 1 < self.svc_search_parameters.c.len() { + self.current_epoch = 0; + self.current_c += 1; + } 
else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { + self.current_epoch = 0; + self.current_c = 0; + self.current_tol += 1; + } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { + self.current_epoch = 0; + self.current_c = 0; + self.current_tol = 0; + self.current_kernel += 1; + } else { + self.current_epoch += 1; + self.current_c += 1; + self.current_tol += 1; + self.current_kernel += 1; + } + + Some(next) + } +} + +impl> Default for SVCSearchParameters { + fn default() -> Self { + let default_params: SVCParameters = SVCParameters::default(); + + SVCSearchParameters { + epoch: vec![default_params.epoch], + c: vec![default_params.c], + tol: vec![default_params.tol], + kernel: vec![default_params.kernel], + m: PhantomData, + } + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] #[cfg_attr( @@ -737,6 +840,24 @@ mod tests { #[cfg(feature = "serde")] use crate::svm::*; + #[test] + fn search_parameters() { + let parameters: SVCSearchParameters, LinearKernel> = + SVCSearchParameters { + epoch: vec![10, 100], + kernel: vec![LinearKernel {}], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.epoch, 10); + assert_eq!(next.kernel, LinearKernel {}); + let next = iter.next().unwrap(); + assert_eq!(next.epoch, 100); + assert_eq!(next.kernel, LinearKernel {}); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svc_fit_predict() { diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 18c73d1..25326d4 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -94,6 +94,109 @@ pub struct SVRParameters, K: Kernel m: PhantomData, } +/// SVR grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVRSearchParameters, K: Kernel> { + /// Epsilon in the epsilon-SVR model. 
+ pub eps: Vec, + /// Regularization parameter. + pub c: Vec, + /// Tolerance for stopping eps. + pub tol: Vec, + /// The kernel function. + pub kernel: Vec, + /// Unused parameter. + m: PhantomData, +} + +/// SVR grid search iterator +pub struct SVRSearchParametersIterator, K: Kernel> { + svr_search_parameters: SVRSearchParameters, + current_eps: usize, + current_c: usize, + current_tol: usize, + current_kernel: usize, +} + +impl, K: Kernel> IntoIterator + for SVRSearchParameters +{ + type Item = SVRParameters; + type IntoIter = SVRSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVRSearchParametersIterator { + svr_search_parameters: self, + current_eps: 0, + current_c: 0, + current_tol: 0, + current_kernel: 0, + } + } +} + +impl, K: Kernel> Iterator + for SVRSearchParametersIterator +{ + type Item = SVRParameters; + + fn next(&mut self) -> Option { + if self.current_eps == self.svr_search_parameters.eps.len() + && self.current_c == self.svr_search_parameters.c.len() + && self.current_tol == self.svr_search_parameters.tol.len() + && self.current_kernel == self.svr_search_parameters.kernel.len() + { + return None; + } + + let next = SVRParameters:: { + eps: self.svr_search_parameters.eps[self.current_eps], + c: self.svr_search_parameters.c[self.current_c], + tol: self.svr_search_parameters.tol[self.current_tol], + kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(), + m: PhantomData, + }; + + if self.current_eps + 1 < self.svr_search_parameters.eps.len() { + self.current_eps += 1; + } else if self.current_c + 1 < self.svr_search_parameters.c.len() { + self.current_eps = 0; + self.current_c += 1; + } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() { + self.current_eps = 0; + self.current_c = 0; + self.current_tol += 1; + } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() { + self.current_eps = 0; + self.current_c = 0; + self.current_tol = 0; + self.current_kernel += 1; + } else { + 
self.current_eps += 1; + self.current_c += 1; + self.current_tol += 1; + self.current_kernel += 1; + } + + Some(next) + } +} + +impl> Default for SVRSearchParameters { + fn default() -> Self { + let default_params: SVRParameters = SVRParameters::default(); + + SVRSearchParameters { + eps: vec![default_params.eps], + c: vec![default_params.c], + tol: vec![default_params.tol], + kernel: vec![default_params.kernel], + m: PhantomData, + } + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] #[cfg_attr( @@ -536,6 +639,24 @@ mod tests { #[cfg(feature = "serde")] use crate::svm::*; + #[test] + fn search_parameters() { + let parameters: SVRSearchParameters, LinearKernel> = + SVRSearchParameters { + eps: vec![0., 1.], + kernel: vec![LinearKernel {}], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.eps, 0.); + assert_eq!(next.kernel, LinearKernel {}); + let next = iter.next().unwrap(); + assert_eq!(next.eps, 1.); + assert_eq!(next.kernel, LinearKernel {}); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svr_fit_predict() { diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 35889e4..a1699af 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters { } } +/// DecisionTreeClassifier grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct DecisionTreeClassifierSearchParameters { + /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub criterion: Vec, + /// Tree max depth. 
See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_split: Vec, +} + +/// DecisionTreeClassifier grid search iterator +pub struct DecisionTreeClassifierSearchParametersIterator { + decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters, + current_criterion: usize, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, +} + +impl IntoIterator for DecisionTreeClassifierSearchParameters { + type Item = DecisionTreeClassifierParameters; + type IntoIter = DecisionTreeClassifierSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DecisionTreeClassifierSearchParametersIterator { + decision_tree_classifier_search_parameters: self, + current_criterion: 0, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + } + } +} + +impl Iterator for DecisionTreeClassifierSearchParametersIterator { + type Item = DecisionTreeClassifierParameters; + + fn next(&mut self) -> Option { + if self.current_criterion + == self + .decision_tree_classifier_search_parameters + .criterion + .len() + && self.current_max_depth + == self + .decision_tree_classifier_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .decision_tree_classifier_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .decision_tree_classifier_search_parameters + .min_samples_split + .len() + { + return None; + } + + let next = DecisionTreeClassifierParameters { + criterion: self.decision_tree_classifier_search_parameters.criterion + 
[self.current_criterion] + .clone(), + max_depth: self.decision_tree_classifier_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .decision_tree_classifier_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .decision_tree_classifier_search_parameters + .min_samples_split[self.current_min_samples_split], + }; + + if self.current_criterion + 1 + < self + .decision_tree_classifier_search_parameters + .criterion + .len() + { + self.current_criterion += 1; + } else if self.current_max_depth + 1 + < self + .decision_tree_classifier_search_parameters + .max_depth + .len() + { + self.current_criterion = 0; + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .decision_tree_classifier_search_parameters + .min_samples_leaf + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .decision_tree_classifier_search_parameters + .min_samples_split + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else { + self.current_criterion += 1; + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + } + + Some(next) + } +} + +impl Default for DecisionTreeClassifierSearchParameters { + fn default() -> Self { + let default_params = DecisionTreeClassifierParameters::default(); + + DecisionTreeClassifierSearchParameters { + criterion: vec![default_params.criterion], + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + } + } +} + impl Node { fn new(index: usize, output: usize) -> Self { Node { @@ -651,6 +789,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn 
search_parameters() { + let parameters = DecisionTreeClassifierSearchParameters { + max_depth: vec![Some(10), Some(100)], + min_samples_split: vec![1, 2], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 2); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 2); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn gini_impurity() { diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 25f5e7e..f48de33 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters { } } +/// DecisionTreeRegressor grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct DecisionTreeRegressorSearchParameters { + /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. 
See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_split: Vec, +} + +/// DecisionTreeRegressor grid search iterator +pub struct DecisionTreeRegressorSearchParametersIterator { + decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, +} + +impl IntoIterator for DecisionTreeRegressorSearchParameters { + type Item = DecisionTreeRegressorParameters; + type IntoIter = DecisionTreeRegressorSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DecisionTreeRegressorSearchParametersIterator { + decision_tree_regressor_search_parameters: self, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + } + } +} + +impl Iterator for DecisionTreeRegressorSearchParametersIterator { + type Item = DecisionTreeRegressorParameters; + + fn next(&mut self) -> Option { + if self.current_max_depth + == self + .decision_tree_regressor_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .decision_tree_regressor_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .decision_tree_regressor_search_parameters + .min_samples_split + .len() + { + return None; + } + + let next = DecisionTreeRegressorParameters { + max_depth: self.decision_tree_regressor_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .decision_tree_regressor_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .decision_tree_regressor_search_parameters + .min_samples_split[self.current_min_samples_split], + }; + + if self.current_max_depth + 1 + < self + .decision_tree_regressor_search_parameters + .max_depth + .len() + { + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .decision_tree_regressor_search_parameters + 
.min_samples_leaf + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .decision_tree_regressor_search_parameters + .min_samples_split + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else { + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + } + + Some(next) + } +} + +impl Default for DecisionTreeRegressorSearchParameters { + fn default() -> Self { + let default_params = DecisionTreeRegressorParameters::default(); + + DecisionTreeRegressorSearchParameters { + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + } + } +} + impl Node { fn new(index: usize, output: T) -> Self { Node { @@ -517,6 +631,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = DecisionTreeRegressorSearchParameters { + max_depth: vec![Some(10), Some(100)], + min_samples_split: vec![1, 2], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 2); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 2); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() {