Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -116,6 +116,81 @@ impl Default for PCAParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// PCA grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PCASearchParameters {
|
||||
/// Number of components to keep.
|
||||
pub n_components: Vec<usize>,
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: Vec<bool>,
|
||||
}
|
||||
|
||||
/// PCA grid search iterator
|
||||
pub struct PCASearchParametersIterator {
|
||||
pca_search_parameters: PCASearchParameters,
|
||||
current_k: usize,
|
||||
current_use_correlation_matrix: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for PCASearchParameters {
|
||||
type Item = PCAParameters;
|
||||
type IntoIter = PCASearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
PCASearchParametersIterator {
|
||||
pca_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_use_correlation_matrix: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for PCASearchParametersIterator {
|
||||
type Item = PCAParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.pca_search_parameters.n_components.len()
|
||||
&& self.current_use_correlation_matrix
|
||||
== self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = PCAParameters {
|
||||
n_components: self.pca_search_parameters.n_components[self.current_k],
|
||||
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
|
||||
[self.current_use_correlation_matrix],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_use_correlation_matrix + 1
|
||||
< self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
self.current_k = 0;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PCASearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = PCAParameters::default();
|
||||
|
||||
PCASearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
use_correlation_matrix: vec![default_params.use_correlation_matrix],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
PCA::fit(x, parameters)
|
||||
@@ -271,6 +346,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = PCASearchParameters {
|
||||
n_components: vec![2, 4],
|
||||
use_correlation_matrix: vec![true, false],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert_eq!(next.use_correlation_matrix, true);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert_eq!(next.use_correlation_matrix, true);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert_eq!(next.use_correlation_matrix, false);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert_eq!(next.use_correlation_matrix, false);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
fn us_arrests_data() -> DenseMatrix<f64> {
|
||||
DenseMatrix::from_2d_array(&[
|
||||
&[13.2, 236.0, 58.0, 21.2],
|
||||
|
||||
Reference in New Issue
Block a user