Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by morenol
parent cfa824d7db
commit 55e1158581
18 changed files with 1713 additions and 25 deletions
+93
View File
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
}
}
/// KMeans grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
/// Number of clusters.
pub k: Vec<usize>,
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: Vec<usize>,
}
/// KMeans grid search iterator
pub struct KMeansSearchParametersIterator {
kmeans_search_parameters: KMeansSearchParameters,
current_k: usize,
current_max_iter: usize,
}
impl IntoIterator for KMeansSearchParameters {
type Item = KMeansParameters;
type IntoIter = KMeansSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
KMeansSearchParametersIterator {
kmeans_search_parameters: self,
current_k: 0,
current_max_iter: 0,
}
}
}
impl Iterator for KMeansSearchParametersIterator {
type Item = KMeansParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_k == self.kmeans_search_parameters.k.len()
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
{
return None;
}
let next = KMeansParameters {
k: self.kmeans_search_parameters.k[self.current_k],
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
};
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
self.current_k += 1;
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
self.current_k = 0;
self.current_max_iter += 1;
} else {
self.current_k += 1;
self.current_max_iter += 1;
}
Some(next)
}
}
impl Default for KMeansSearchParameters {
fn default() -> Self {
let default_params = KMeansParameters::default();
KMeansSearchParameters {
k: vec![default_params.k],
max_iter: vec![default_params.max_iter],
}
}
}
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
KMeans::fit(x, parameters)
@@ -313,6 +383,29 @@ mod tests {
);
}
#[test]
fn search_parameters() {
let parameters = KMeansSearchParameters {
k: vec![2, 4],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_iris() {