Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// DBSCAN grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: Vec<D>,
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: Vec<usize>,
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: Vec<T>,
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: Vec<KNNAlgorithmName>,
|
||||
}
|
||||
|
||||
/// DBSCAN grid search iterator
|
||||
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
|
||||
current_distance: usize,
|
||||
current_min_samples: usize,
|
||||
current_eps: usize,
|
||||
current_algorithm: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
type IntoIter = DBSCANSearchParametersIterator<T, D>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DBSCANSearchParametersIterator {
|
||||
dbscan_search_parameters: self,
|
||||
current_distance: 0,
|
||||
current_min_samples: 0,
|
||||
current_eps: 0,
|
||||
current_algorithm: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_distance == self.dbscan_search_parameters.distance.len()
|
||||
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
|
||||
&& self.current_eps == self.dbscan_search_parameters.eps.len()
|
||||
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DBSCANParameters {
|
||||
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
|
||||
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
|
||||
eps: self.dbscan_search_parameters.eps[self.current_eps],
|
||||
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
|
||||
};
|
||||
|
||||
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
|
||||
self.current_distance += 1;
|
||||
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples += 1;
|
||||
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps += 1;
|
||||
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps = 0;
|
||||
self.current_algorithm += 1;
|
||||
} else {
|
||||
self.current_distance += 1;
|
||||
self.current_min_samples += 1;
|
||||
self.current_eps += 1;
|
||||
self.current_algorithm += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
let default_params = DBSCANParameters::default();
|
||||
|
||||
DBSCANSearchParameters {
|
||||
distance: vec![default_params.distance],
|
||||
min_samples: vec![default_params.min_samples],
|
||||
eps: vec![default_params.eps],
|
||||
algorithm: vec![default_params.algorithm],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cluster_labels.len() == other.cluster_labels.len()
|
||||
@@ -268,6 +365,29 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = DBSCANSearchParameters {
|
||||
min_samples: vec![10, 100],
|
||||
eps: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 2.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
|
||||
Reference in New Issue
Block a user