add seed param to search params (#168)
This commit is contained in:
@@ -145,6 +145,9 @@ pub struct KMeansSearchParameters {
|
|||||||
pub k: Vec<usize>,
|
pub k: Vec<usize>,
|
||||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||||
pub max_iter: Vec<usize>,
|
pub max_iter: Vec<usize>,
|
||||||
|
/// Determines random number generation for centroid initialization.
|
||||||
|
/// Use `Some(seed)` to make the randomness deterministic.
|
||||||
|
pub seed: Vec<Option<u64>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// KMeans grid search iterator
|
/// KMeans grid search iterator
|
||||||
@@ -152,6 +155,7 @@ pub struct KMeansSearchParametersIterator {
|
|||||||
kmeans_search_parameters: KMeansSearchParameters,
|
kmeans_search_parameters: KMeansSearchParameters,
|
||||||
current_k: usize,
|
current_k: usize,
|
||||||
current_max_iter: usize,
|
current_max_iter: usize,
|
||||||
|
current_seed: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for KMeansSearchParameters {
|
impl IntoIterator for KMeansSearchParameters {
|
||||||
@@ -163,6 +167,7 @@ impl IntoIterator for KMeansSearchParameters {
|
|||||||
kmeans_search_parameters: self,
|
kmeans_search_parameters: self,
|
||||||
current_k: 0,
|
current_k: 0,
|
||||||
current_max_iter: 0,
|
current_max_iter: 0,
|
||||||
|
current_seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -173,6 +178,7 @@ impl Iterator for KMeansSearchParametersIterator {
|
|||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||||
|
&& self.current_seed == self.kmeans_search_parameters.seed.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -180,6 +186,7 @@ impl Iterator for KMeansSearchParametersIterator {
|
|||||||
let next = KMeansParameters {
|
let next = KMeansParameters {
|
||||||
k: self.kmeans_search_parameters.k[self.current_k],
|
k: self.kmeans_search_parameters.k[self.current_k],
|
||||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||||
|
seed: self.kmeans_search_parameters.seed[self.current_seed],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||||
@@ -187,9 +194,14 @@ impl Iterator for KMeansSearchParametersIterator {
|
|||||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||||
self.current_k = 0;
|
self.current_k = 0;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
|
||||||
|
self.current_k = 0;
|
||||||
|
self.current_max_iter = 0;
|
||||||
|
self.current_seed += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_k += 1;
|
self.current_k += 1;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
self.current_seed += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -203,6 +215,7 @@ impl Default for KMeansSearchParameters {
|
|||||||
KMeansSearchParameters {
|
KMeansSearchParameters {
|
||||||
k: vec![default_params.k],
|
k: vec![default_params.k],
|
||||||
max_iter: vec![default_params.max_iter],
|
max_iter: vec![default_params.max_iter],
|
||||||
|
seed: vec![default_params.seed],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,6 +119,8 @@ pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowV
|
|||||||
pub kernel: Vec<K>,
|
pub kernel: Vec<K>,
|
||||||
/// Unused parameter.
|
/// Unused parameter.
|
||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
|
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||||
|
pub seed: Vec<Option<u64>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SVC grid search iterator
|
/// SVC grid search iterator
|
||||||
@@ -128,6 +130,7 @@ pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T,
|
|||||||
current_c: usize,
|
current_c: usize,
|
||||||
current_tol: usize,
|
current_tol: usize,
|
||||||
current_kernel: usize,
|
current_kernel: usize,
|
||||||
|
current_seed: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||||
@@ -143,6 +146,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
|||||||
current_c: 0,
|
current_c: 0,
|
||||||
current_tol: 0,
|
current_tol: 0,
|
||||||
current_kernel: 0,
|
current_kernel: 0,
|
||||||
|
current_seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -157,6 +161,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
|||||||
&& self.current_c == self.svc_search_parameters.c.len()
|
&& self.current_c == self.svc_search_parameters.c.len()
|
||||||
&& self.current_tol == self.svc_search_parameters.tol.len()
|
&& self.current_tol == self.svc_search_parameters.tol.len()
|
||||||
&& self.current_kernel == self.svc_search_parameters.kernel.len()
|
&& self.current_kernel == self.svc_search_parameters.kernel.len()
|
||||||
|
&& self.current_seed == self.svc_search_parameters.seed.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -167,6 +172,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
|||||||
tol: self.svc_search_parameters.tol[self.current_tol],
|
tol: self.svc_search_parameters.tol[self.current_tol],
|
||||||
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
|
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
|
||||||
m: PhantomData,
|
m: PhantomData,
|
||||||
|
seed: self.svc_search_parameters.seed[self.current_seed],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
|
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
|
||||||
@@ -183,11 +189,18 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
|||||||
self.current_c = 0;
|
self.current_c = 0;
|
||||||
self.current_tol = 0;
|
self.current_tol = 0;
|
||||||
self.current_kernel += 1;
|
self.current_kernel += 1;
|
||||||
|
} else if self.current_seed + 1 < self.svc_search_parameters.seed.len() { // seed is advanced last: this branch resets epoch, c, tol and kernel before moving to the next seed
|
||||||
|
self.current_epoch = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_kernel = 0;
|
||||||
|
self.current_seed += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_epoch += 1;
|
self.current_epoch += 1;
|
||||||
self.current_c += 1;
|
self.current_c += 1;
|
||||||
self.current_tol += 1;
|
self.current_tol += 1;
|
||||||
self.current_kernel += 1;
|
self.current_kernel += 1;
|
||||||
|
self.current_seed += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -204,6 +217,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKe
|
|||||||
tol: vec![default_params.tol],
|
tol: vec![default_params.tol],
|
||||||
kernel: vec![default_params.kernel],
|
kernel: vec![default_params.kernel],
|
||||||
m: PhantomData,
|
m: PhantomData,
|
||||||
|
seed: vec![default_params.seed],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,14 +209,21 @@ impl Default for DecisionTreeClassifierParameters {
|
|||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct DecisionTreeClassifierSearchParameters {
|
pub struct DecisionTreeClassifierSearchParameters {
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub criterion: Vec<SplitCriterion>,
|
pub criterion: Vec<SplitCriterion>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub max_depth: Vec<Option<u16>>,
|
pub max_depth: Vec<Option<u16>>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub min_samples_leaf: Vec<usize>,
|
pub min_samples_leaf: Vec<usize>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub min_samples_split: Vec<usize>,
|
pub min_samples_split: Vec<usize>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// Controls the randomness of the estimator
|
||||||
|
pub seed: Vec<Option<u64>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// DecisionTreeClassifier grid search iterator
|
/// DecisionTreeClassifier grid search iterator
|
||||||
@@ -226,6 +233,7 @@ pub struct DecisionTreeClassifierSearchParametersIterator {
|
|||||||
current_max_depth: usize,
|
current_max_depth: usize,
|
||||||
current_min_samples_leaf: usize,
|
current_min_samples_leaf: usize,
|
||||||
current_min_samples_split: usize,
|
current_min_samples_split: usize,
|
||||||
|
current_seed: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
||||||
@@ -239,6 +247,7 @@ impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
|||||||
current_max_depth: 0,
|
current_max_depth: 0,
|
||||||
current_min_samples_leaf: 0,
|
current_min_samples_leaf: 0,
|
||||||
current_min_samples_split: 0,
|
current_min_samples_split: 0,
|
||||||
|
current_seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,6 +276,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
|||||||
.decision_tree_classifier_search_parameters
|
.decision_tree_classifier_search_parameters
|
||||||
.min_samples_split
|
.min_samples_split
|
||||||
.len()
|
.len()
|
||||||
|
&& self.current_seed == self.decision_tree_classifier_search_parameters.seed.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -283,6 +293,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
|||||||
min_samples_split: self
|
min_samples_split: self
|
||||||
.decision_tree_classifier_search_parameters
|
.decision_tree_classifier_search_parameters
|
||||||
.min_samples_split[self.current_min_samples_split],
|
.min_samples_split[self.current_min_samples_split],
|
||||||
|
seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_criterion + 1
|
if self.current_criterion + 1
|
||||||
@@ -319,11 +330,19 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
|||||||
self.current_max_depth = 0;
|
self.current_max_depth = 0;
|
||||||
self.current_min_samples_leaf = 0;
|
self.current_min_samples_leaf = 0;
|
||||||
self.current_min_samples_split += 1;
|
self.current_min_samples_split += 1;
|
||||||
|
} else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_seed += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_criterion += 1;
|
self.current_criterion += 1;
|
||||||
self.current_max_depth += 1;
|
self.current_max_depth += 1;
|
||||||
self.current_min_samples_leaf += 1;
|
self.current_min_samples_leaf += 1;
|
||||||
self.current_min_samples_split += 1;
|
self.current_min_samples_split += 1;
|
||||||
|
self.current_seed += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -339,6 +358,7 @@ impl Default for DecisionTreeClassifierSearchParameters {
|
|||||||
max_depth: vec![default_params.max_depth],
|
max_depth: vec![default_params.max_depth],
|
||||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
min_samples_split: vec![default_params.min_samples_split],
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
seed: vec![default_params.seed],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -148,6 +148,8 @@ pub struct DecisionTreeRegressorSearchParameters {
|
|||||||
pub min_samples_leaf: Vec<usize>,
|
pub min_samples_leaf: Vec<usize>,
|
||||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||||
pub min_samples_split: Vec<usize>,
|
pub min_samples_split: Vec<usize>,
|
||||||
|
/// Controls the randomness of the estimator
|
||||||
|
pub seed: Vec<Option<u64>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// DecisionTreeRegressor grid search iterator
|
/// DecisionTreeRegressor grid search iterator
|
||||||
@@ -156,6 +158,7 @@ pub struct DecisionTreeRegressorSearchParametersIterator {
|
|||||||
current_max_depth: usize,
|
current_max_depth: usize,
|
||||||
current_min_samples_leaf: usize,
|
current_min_samples_leaf: usize,
|
||||||
current_min_samples_split: usize,
|
current_min_samples_split: usize,
|
||||||
|
current_seed: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
||||||
@@ -168,6 +171,7 @@ impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
|||||||
current_max_depth: 0,
|
current_max_depth: 0,
|
||||||
current_min_samples_leaf: 0,
|
current_min_samples_leaf: 0,
|
||||||
current_min_samples_split: 0,
|
current_min_samples_split: 0,
|
||||||
|
current_seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -191,6 +195,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
|||||||
.decision_tree_regressor_search_parameters
|
.decision_tree_regressor_search_parameters
|
||||||
.min_samples_split
|
.min_samples_split
|
||||||
.len()
|
.len()
|
||||||
|
&& self.current_seed == self.decision_tree_regressor_search_parameters.seed.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -204,6 +209,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
|||||||
min_samples_split: self
|
min_samples_split: self
|
||||||
.decision_tree_regressor_search_parameters
|
.decision_tree_regressor_search_parameters
|
||||||
.min_samples_split[self.current_min_samples_split],
|
.min_samples_split[self.current_min_samples_split],
|
||||||
|
seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_max_depth + 1
|
if self.current_max_depth + 1
|
||||||
@@ -230,10 +236,17 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
|||||||
self.current_max_depth = 0;
|
self.current_max_depth = 0;
|
||||||
self.current_min_samples_leaf = 0;
|
self.current_min_samples_leaf = 0;
|
||||||
self.current_min_samples_split += 1;
|
self.current_min_samples_split += 1;
|
||||||
|
} else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len()
|
||||||
|
{
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_seed += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_max_depth += 1;
|
self.current_max_depth += 1;
|
||||||
self.current_min_samples_leaf += 1;
|
self.current_min_samples_leaf += 1;
|
||||||
self.current_min_samples_split += 1;
|
self.current_min_samples_split += 1;
|
||||||
|
self.current_seed += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -248,6 +261,7 @@ impl Default for DecisionTreeRegressorSearchParameters {
|
|||||||
max_depth: vec![default_params.max_depth],
|
max_depth: vec![default_params.max_depth],
|
||||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
min_samples_split: vec![default_params.min_samples_split],
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
seed: vec![default_params.seed],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user