add seed param to search params (#168)
This commit is contained in:
@@ -145,6 +145,9 @@ pub struct KMeansSearchParameters {
|
||||
pub k: Vec<usize>,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: Vec<usize>,
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Vec<Option<u64>>,
|
||||
}
|
||||
|
||||
/// KMeans grid search iterator
|
||||
@@ -152,6 +155,7 @@ pub struct KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: KMeansSearchParameters,
|
||||
current_k: usize,
|
||||
current_max_iter: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for KMeansSearchParameters {
|
||||
@@ -163,6 +167,7 @@ impl IntoIterator for KMeansSearchParameters {
|
||||
kmeans_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_max_iter: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -173,6 +178,7 @@ impl Iterator for KMeansSearchParametersIterator {
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||
&& self.current_seed == self.kmeans_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -180,6 +186,7 @@ impl Iterator for KMeansSearchParametersIterator {
|
||||
let next = KMeansParameters {
|
||||
k: self.kmeans_search_parameters.k[self.current_k],
|
||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||
seed: self.kmeans_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||
@@ -187,9 +194,14 @@ impl Iterator for KMeansSearchParametersIterator {
|
||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_max_iter += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
@@ -203,6 +215,7 @@ impl Default for KMeansSearchParameters {
|
||||
KMeansSearchParameters {
|
||||
k: vec![default_params.k],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user