Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// KMeans grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KMeansSearchParameters {
|
||||
/// Number of clusters.
|
||||
pub k: Vec<usize>,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// KMeans grid search iterator
|
||||
pub struct KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: KMeansSearchParameters,
|
||||
current_k: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for KMeansSearchParameters {
|
||||
type Item = KMeansParameters;
|
||||
type IntoIter = KMeansSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for KMeansSearchParametersIterator {
|
||||
type Item = KMeansParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = KMeansParameters {
|
||||
k: self.kmeans_search_parameters.k[self.current_k],
|
||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KMeansSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = KMeansParameters::default();
|
||||
|
||||
KMeansSearchParameters {
|
||||
k: vec![default_params.k],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
KMeans::fit(x, parameters)
|
||||
@@ -313,6 +383,29 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = KMeansSearchParameters {
|
||||
k: vec![2, 4],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
Reference in New Issue
Block a user