Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by GitHub
parent 69d8be35de
commit 48514d1b15
18 changed files with 1713 additions and 25 deletions
+96
View File
@@ -150,6 +150,88 @@ impl<T: RealNumber> Default for BernoulliNBParameters<T> {
}
}
/// BernoulliNB grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BernoulliNBSearchParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: Vec<T>,
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
pub priors: Vec<Option<Vec<T>>>,
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
pub binarize: Vec<Option<T>>,
}
/// BernoulliNB grid search iterator
pub struct BernoulliNBSearchParametersIterator<T: RealNumber> {
bernoulli_nb_search_parameters: BernoulliNBSearchParameters<T>,
current_alpha: usize,
current_priors: usize,
current_binarize: usize,
}
impl<T: RealNumber> IntoIterator for BernoulliNBSearchParameters<T> {
type Item = BernoulliNBParameters<T>;
type IntoIter = BernoulliNBSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
BernoulliNBSearchParametersIterator {
bernoulli_nb_search_parameters: self,
current_alpha: 0,
current_priors: 0,
current_binarize: 0,
}
}
}
impl<T: RealNumber> Iterator for BernoulliNBSearchParametersIterator<T> {
type Item = BernoulliNBParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len()
&& self.current_priors == self.bernoulli_nb_search_parameters.priors.len()
&& self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len()
{
return None;
}
let next = BernoulliNBParameters {
alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha],
priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(),
binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize],
};
if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() {
self.current_alpha = 0;
self.current_priors += 1;
} else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() {
self.current_alpha = 0;
self.current_priors = 0;
self.current_binarize += 1;
} else {
self.current_alpha += 1;
self.current_priors += 1;
self.current_binarize += 1;
}
Some(next)
}
}
impl<T: RealNumber> Default for BernoulliNBSearchParameters<T> {
fn default() -> Self {
let default_params = BernoulliNBParameters::default();
BernoulliNBSearchParameters {
alpha: vec![default_params.alpha],
priors: vec![default_params.priors],
binarize: vec![default_params.binarize],
}
}
}
impl<T: RealNumber> BernoulliNBDistribution<T> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data.
@@ -347,6 +429,20 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
let parameters = BernoulliNBSearchParameters {
alpha: vec![1., 2.],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 2.);
assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn run_bernoulli_naive_bayes() {