Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by GitHub
parent 69d8be35de
commit 48514d1b15
18 changed files with 1713 additions and 25 deletions
+120
View File
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
}
}
/// DBSCAN grid search parameters.
///
/// Every field lists the candidate values to try for the field of the same name in
/// `DBSCANParameters`. Iterating over this struct yields one `DBSCANParameters` per
/// combination of candidates.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// Candidate functions that define a distance between each pair of points in training data.
/// Each function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: Vec<D>,
/// Candidate values for the number of samples (or total weight) in a neighborhood for a point
/// to be considered as a core point.
pub min_samples: Vec<usize>,
/// Candidate values for the maximum distance between two samples for one to be considered
/// as in the neighborhood of the other.
pub eps: Vec<T>,
/// Candidate KNN algorithms to use.
pub algorithm: Vec<KNNAlgorithmName>,
}
/// DBSCAN grid search iterator.
///
/// Owns the search grid together with one cursor per candidate list; the cursors
/// select which candidate of each list goes into the next `DBSCANParameters`.
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
// The grid being traversed.
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
// Index into `dbscan_search_parameters.distance`.
current_distance: usize,
// Index into `dbscan_search_parameters.min_samples`.
current_min_samples: usize,
// Index into `dbscan_search_parameters.eps`.
current_eps: usize,
// Index into `dbscan_search_parameters.algorithm`.
current_algorithm: usize,
}
impl<T, D> IntoIterator for DBSCANSearchParameters<T, D>
where
    T: RealNumber,
    D: Distance<Vec<T>, T>,
{
    type Item = DBSCANParameters<T, D>;
    type IntoIter = DBSCANSearchParametersIterator<T, D>;

    /// Consumes the search grid and returns an iterator whose cursors all point
    /// at the first candidate of their respective lists.
    fn into_iter(self) -> Self::IntoIter {
        // Every cursor starts at the beginning of its candidate list.
        let first = 0;
        DBSCANSearchParametersIterator {
            current_distance: first,
            current_min_samples: first,
            current_eps: first,
            current_algorithm: first,
            dbscan_search_parameters: self,
        }
    }
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
    type Item = DBSCANParameters<T, D>;

    /// Yields the next parameter combination of the grid, or `None` once every
    /// combination has been produced.
    ///
    /// `distance` varies fastest, followed by `min_samples`, then `eps`, and
    /// `algorithm` varies slowest. If any candidate list is empty the grid contains
    /// no combinations, so `None` is returned instead of indexing out of bounds.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.dbscan_search_parameters;

        // Checked lookups cover both end conditions: after the last combination every
        // cursor equals its list length, and an empty list has no element at index 0.
        // Nothing is mutated before these lookups, so a `None` here is sticky.
        let next = DBSCANParameters {
            distance: grid.distance.get(self.current_distance)?.clone(),
            min_samples: *grid.min_samples.get(self.current_min_samples)?,
            eps: *grid.eps.get(self.current_eps)?,
            algorithm: grid.algorithm.get(self.current_algorithm)?.clone(),
        };

        // Advance the cursors like an odometer: bump the fastest-moving cursor that
        // still has room and reset every cursor that moves faster than it.
        if self.current_distance + 1 < grid.distance.len() {
            self.current_distance += 1;
        } else if self.current_min_samples + 1 < grid.min_samples.len() {
            self.current_distance = 0;
            self.current_min_samples += 1;
        } else if self.current_eps + 1 < grid.eps.len() {
            self.current_distance = 0;
            self.current_min_samples = 0;
            self.current_eps += 1;
        } else if self.current_algorithm + 1 < grid.algorithm.len() {
            self.current_distance = 0;
            self.current_min_samples = 0;
            self.current_eps = 0;
            self.current_algorithm += 1;
        } else {
            // That was the last combination: park every cursor one past the end so
            // that all subsequent calls return `None`.
            self.current_distance += 1;
            self.current_min_samples += 1;
            self.current_eps += 1;
            self.current_algorithm += 1;
        }

        Some(next)
    }
}
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
fn default() -> Self {
let default_params = DBSCANParameters::default();
DBSCANSearchParameters {
distance: vec![default_params.distance],
min_samples: vec![default_params.min_samples],
eps: vec![default_params.eps],
algorithm: vec![default_params.algorithm],
}
}
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
fn eq(&self, other: &Self) -> bool {
self.cluster_labels.len() == other.cluster_labels.len()
@@ -268,6 +365,29 @@ mod tests {
#[cfg(feature = "serde")]
use crate::math::distance::euclidian::Euclidian;
#[test]
fn search_parameters() {
    let grid = DBSCANSearchParameters {
        min_samples: vec![10, 100],
        eps: vec![1., 2.],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();

    // `min_samples` cycles fully before `eps` moves on to its next value.
    let expected = [(10, 1.), (100, 1.), (10, 2.), (100, 2.)];
    for &(min_samples, eps) in expected.iter() {
        let candidate = candidates.next().unwrap();
        assert_eq!(candidate.min_samples, min_samples);
        assert_eq!(candidate.eps, eps);
    }

    // The grid holds exactly four combinations.
    assert!(candidates.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_dbscan() {