Complete grid search params (#166)

* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
This commit is contained in:
Montana Low
2022-09-21 12:34:21 -07:00
committed by GitHub
parent 69d8be35de
commit 48514d1b15
18 changed files with 1713 additions and 25 deletions
+120
View File
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
} }
} }
/// DBSCAN grid search parameters
///
/// Each `Vec` field lists the candidate values for that parameter; `into_iter`
/// yields every combination (the Cartesian product of all lists).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
    /// a function that defines a distance between each pair of point in training data.
    /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
    /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
    pub distance: Vec<D>,
    /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
    pub min_samples: Vec<usize>,
    /// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
    pub eps: Vec<T>,
    /// KNN algorithm to use.
    pub algorithm: Vec<KNNAlgorithmName>,
}
/// DBSCAN grid search iterator
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
    /// The candidate grid being iterated over.
    dbscan_search_parameters: DBSCANSearchParameters<T, D>,
    /// Cursor into `distance` (fastest-varying dimension).
    current_distance: usize,
    /// Cursor into `min_samples`.
    current_min_samples: usize,
    /// Cursor into `eps`.
    current_eps: usize,
    /// Cursor into `algorithm` (slowest-varying dimension).
    current_algorithm: usize,
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
    type Item = DBSCANParameters<T, D>;
    type IntoIter = DBSCANSearchParametersIterator<T, D>;

    /// Wrap the grid in an iterator positioned at the first combination.
    fn into_iter(self) -> Self::IntoIter {
        DBSCANSearchParametersIterator {
            current_algorithm: 0,
            current_eps: 0,
            current_min_samples: 0,
            current_distance: 0,
            dbscan_search_parameters: self,
        }
    }
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
    type Item = DBSCANParameters<T, D>;

    /// Yields the next parameter combination in odometer order (`distance`
    /// varies fastest, `algorithm` slowest); `None` after the last one.
    ///
    /// NOTE(review): assumes every candidate list is non-empty — if only some
    /// of the `Vec`s are empty, the indexing below will panic.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.dbscan_search_parameters;
        // Every cursor parked past the end means the grid is exhausted.
        if self.current_distance == grid.distance.len()
            && self.current_min_samples == grid.min_samples.len()
            && self.current_eps == grid.eps.len()
            && self.current_algorithm == grid.algorithm.len()
        {
            return None;
        }
        let item = DBSCANParameters {
            distance: grid.distance[self.current_distance].clone(),
            min_samples: grid.min_samples[self.current_min_samples],
            eps: grid.eps[self.current_eps],
            algorithm: grid.algorithm[self.current_algorithm].clone(),
        };
        // Odometer advance: bump the first dimension with room left, resetting
        // all faster ones; when no dimension has room, park every cursor past
        // the end so the next call returns `None`.
        if self.current_distance + 1 < grid.distance.len() {
            self.current_distance += 1;
        } else if self.current_min_samples + 1 < grid.min_samples.len() {
            self.current_distance = 0;
            self.current_min_samples += 1;
        } else if self.current_eps + 1 < grid.eps.len() {
            self.current_distance = 0;
            self.current_min_samples = 0;
            self.current_eps += 1;
        } else if self.current_algorithm + 1 < grid.algorithm.len() {
            self.current_distance = 0;
            self.current_min_samples = 0;
            self.current_eps = 0;
            self.current_algorithm += 1;
        } else {
            self.current_distance += 1;
            self.current_min_samples += 1;
            self.current_eps += 1;
            self.current_algorithm += 1;
        }
        Some(item)
    }
}
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
fn default() -> Self {
let default_params = DBSCANParameters::default();
DBSCANSearchParameters {
distance: vec![default_params.distance],
min_samples: vec![default_params.min_samples],
eps: vec![default_params.eps],
algorithm: vec![default_params.algorithm],
}
}
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> { impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.cluster_labels.len() == other.cluster_labels.len() self.cluster_labels.len() == other.cluster_labels.len()
@@ -268,6 +365,29 @@ mod tests {
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use crate::math::distance::euclidian::Euclidian; use crate::math::distance::euclidian::Euclidian;
#[test]
fn search_parameters() {
    let parameters = DBSCANSearchParameters {
        min_samples: vec![10, 100],
        eps: vec![1., 2.],
        ..Default::default()
    };
    // `min_samples` varies faster than `eps`.
    let expected = [(10, 1.), (100, 1.), (10, 2.), (100, 2.)];
    let mut iter = parameters.into_iter();
    for &(min_samples, eps) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.min_samples, min_samples);
        assert_eq!(next.eps, eps);
    }
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_dbscan() { fn fit_predict_dbscan() {
+93
View File
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
} }
} }
/// KMeans grid search parameters
///
/// Each `Vec` field lists the candidate values to try; `into_iter` yields
/// every `(k, max_iter)` combination.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
    /// Number of clusters.
    pub k: Vec<usize>,
    /// Maximum number of iterations of the k-means algorithm for a single run.
    pub max_iter: Vec<usize>,
}
/// KMeans grid search iterator
pub struct KMeansSearchParametersIterator {
    /// The candidate grid being iterated over.
    kmeans_search_parameters: KMeansSearchParameters,
    /// Cursor into `k` (fastest-varying dimension).
    current_k: usize,
    /// Cursor into `max_iter` (slowest-varying dimension).
    current_max_iter: usize,
}
impl IntoIterator for KMeansSearchParameters {
    type Item = KMeansParameters;
    type IntoIter = KMeansSearchParametersIterator;

    /// Wrap the grid in an iterator positioned at the first `(k, max_iter)` combination.
    fn into_iter(self) -> Self::IntoIter {
        KMeansSearchParametersIterator {
            current_max_iter: 0,
            current_k: 0,
            kmeans_search_parameters: self,
        }
    }
}
impl Iterator for KMeansSearchParametersIterator {
    type Item = KMeansParameters;

    /// Yields the next `(k, max_iter)` combination in odometer order
    /// (`k` varies fastest); `None` after the last combination.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.kmeans_search_parameters;
        let (n_k, n_max_iter) = (grid.k.len(), grid.max_iter.len());
        // Both cursors parked past the end means the grid is exhausted.
        if self.current_k == n_k && self.current_max_iter == n_max_iter {
            return None;
        }
        let item = KMeansParameters {
            k: grid.k[self.current_k],
            max_iter: grid.max_iter[self.current_max_iter],
        };
        // Odometer advance; when every digit is at its last value, push both
        // cursors past the end so the next call returns `None`.
        if self.current_k + 1 < n_k {
            self.current_k += 1;
        } else if self.current_max_iter + 1 < n_max_iter {
            self.current_k = 0;
            self.current_max_iter += 1;
        } else {
            self.current_k += 1;
            self.current_max_iter += 1;
        }
        Some(item)
    }
}
impl Default for KMeansSearchParameters {
fn default() -> Self {
let default_params = KMeansParameters::default();
KMeansSearchParameters {
k: vec![default_params.k],
max_iter: vec![default_params.max_iter],
}
}
}
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> { impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> { fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
KMeans::fit(x, parameters) KMeans::fit(x, parameters)
@@ -313,6 +383,29 @@ mod tests {
); );
} }
#[test]
fn search_parameters() {
    let parameters = KMeansSearchParameters {
        k: vec![2, 4],
        max_iter: vec![10, 100],
        ..Default::default()
    };
    // `k` varies faster than `max_iter`.
    let expected = [(2, 10), (4, 10), (2, 100), (4, 100)];
    let mut iter = parameters.into_iter();
    for &(k, max_iter) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.k, k);
        assert_eq!(next.max_iter, max_iter);
    }
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
+98
View File
@@ -116,6 +116,81 @@ impl Default for PCAParameters {
} }
} }
/// PCA grid search parameters
///
/// Each `Vec` field lists the candidate values to try; `into_iter` yields
/// every combination.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct PCASearchParameters {
    /// Number of components to keep.
    pub n_components: Vec<usize>,
    /// By default, covariance matrix is used to compute principal components.
    /// Enable this flag if you want to use correlation matrix instead.
    pub use_correlation_matrix: Vec<bool>,
}
/// PCA grid search iterator
pub struct PCASearchParametersIterator {
    /// The candidate grid being iterated over.
    pca_search_parameters: PCASearchParameters,
    /// Cursor into `n_components` (fastest-varying dimension).
    /// NOTE(review): despite the name, this indexes `n_components`, not a `k` field.
    current_k: usize,
    /// Cursor into `use_correlation_matrix` (slowest-varying dimension).
    current_use_correlation_matrix: usize,
}
impl IntoIterator for PCASearchParameters {
    type Item = PCAParameters;
    type IntoIter = PCASearchParametersIterator;

    /// Wrap the grid in an iterator positioned at the first combination.
    fn into_iter(self) -> Self::IntoIter {
        PCASearchParametersIterator {
            current_use_correlation_matrix: 0,
            current_k: 0,
            pca_search_parameters: self,
        }
    }
}
impl Iterator for PCASearchParametersIterator {
    type Item = PCAParameters;

    /// Yields the next combination in odometer order (`n_components` varies
    /// fastest); `None` after the last combination.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.pca_search_parameters;
        let (n_comp, n_flags) = (grid.n_components.len(), grid.use_correlation_matrix.len());
        // Both cursors parked past the end means the grid is exhausted.
        if self.current_k == n_comp && self.current_use_correlation_matrix == n_flags {
            return None;
        }
        let item = PCAParameters {
            n_components: grid.n_components[self.current_k],
            use_correlation_matrix: grid.use_correlation_matrix
                [self.current_use_correlation_matrix],
        };
        // Odometer advance; when every digit is at its last value, push both
        // cursors past the end so the next call returns `None`.
        if self.current_k + 1 < n_comp {
            self.current_k += 1;
        } else if self.current_use_correlation_matrix + 1 < n_flags {
            self.current_k = 0;
            self.current_use_correlation_matrix += 1;
        } else {
            self.current_k += 1;
            self.current_use_correlation_matrix += 1;
        }
        Some(item)
    }
}
impl Default for PCASearchParameters {
fn default() -> Self {
let default_params = PCAParameters::default();
PCASearchParameters {
n_components: vec![default_params.n_components],
use_correlation_matrix: vec![default_params.use_correlation_matrix],
}
}
}
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> { impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> { fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
PCA::fit(x, parameters) PCA::fit(x, parameters)
@@ -271,6 +346,29 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
    // `n_components` varies fastest, `use_correlation_matrix` slowest.
    let parameters = PCASearchParameters {
        n_components: vec![2, 4],
        use_correlation_matrix: vec![true, false],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.n_components, 2);
    // `assert!` / `assert!(!..)` instead of `assert_eq!(.., true/false)`
    // (clippy::bool_assert_comparison).
    assert!(next.use_correlation_matrix);
    let next = iter.next().unwrap();
    assert_eq!(next.n_components, 4);
    assert!(next.use_correlation_matrix);
    let next = iter.next().unwrap();
    assert_eq!(next.n_components, 2);
    assert!(!next.use_correlation_matrix);
    let next = iter.next().unwrap();
    assert_eq!(next.n_components, 4);
    assert!(!next.use_correlation_matrix);
    assert!(iter.next().is_none());
}
fn us_arrests_data() -> DenseMatrix<f64> { fn us_arrests_data() -> DenseMatrix<f64> {
DenseMatrix::from_2d_array(&[ DenseMatrix::from_2d_array(&[
&[13.2, 236.0, 58.0, 21.2], &[13.2, 236.0, 58.0, 21.2],
+68
View File
@@ -90,6 +90,60 @@ impl SVDParameters {
} }
} }
/// SVD grid search parameters
///
/// Each entry of `n_components` is one candidate tried by the grid search.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVDSearchParameters {
    /// Number of components to keep.
    pub n_components: Vec<usize>,
}
/// SVD grid search iterator
pub struct SVDSearchParametersIterator {
    /// The candidate grid being iterated over.
    svd_search_parameters: SVDSearchParameters,
    /// Cursor into `n_components`.
    current_n_components: usize,
}
impl IntoIterator for SVDSearchParameters {
    type Item = SVDParameters;
    type IntoIter = SVDSearchParametersIterator;

    /// Wrap the grid in an iterator positioned at the first candidate.
    fn into_iter(self) -> Self::IntoIter {
        SVDSearchParametersIterator {
            current_n_components: 0,
            svd_search_parameters: self,
        }
    }
}
impl Iterator for SVDSearchParametersIterator {
    type Item = SVDParameters;

    /// Yields one `SVDParameters` per candidate `n_components`, in order.
    fn next(&mut self) -> Option<Self::Item> {
        // `get` turns the past-the-end position into `None` without an
        // explicit length check.
        let n_components = *self
            .svd_search_parameters
            .n_components
            .get(self.current_n_components)?;
        self.current_n_components += 1;
        Some(SVDParameters { n_components })
    }
}
impl Default for SVDSearchParameters {
    /// A one-point grid built from the default `SVDParameters`.
    fn default() -> Self {
        Self {
            n_components: vec![SVDParameters::default().n_components],
        }
    }
}
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> { impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> { fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
SVD::fit(x, parameters) SVD::fit(x, parameters)
@@ -153,6 +207,20 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
    let parameters = SVDSearchParameters {
        n_components: vec![10, 100],
        ..Default::default()
    };
    // The iterator yields each candidate once, in declaration order.
    let produced: Vec<usize> = parameters.into_iter().map(|p| p.n_components).collect();
    assert_eq!(produced, vec![10, 100]);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svd_decompose() { fn svd_decompose() {
+243
View File
@@ -193,6 +193,226 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestCla
} }
} }
/// RandomForestClassifier grid search parameters
///
/// Each `Vec` field lists the candidate values for that parameter; `into_iter`
/// yields every combination (the Cartesian product of all lists).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierSearchParameters {
    /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub criterion: Vec<SplitCriterion>,
    /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub max_depth: Vec<Option<u16>>,
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_leaf: Vec<usize>,
    /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_split: Vec<usize>,
    /// The number of trees in the forest.
    pub n_trees: Vec<u16>,
    /// Number of random sample of predictors to use as split candidates.
    pub m: Vec<Option<usize>>,
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: Vec<bool>,
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: Vec<u64>,
}
/// RandomForestClassifier grid search iterator
pub struct RandomForestClassifierSearchParametersIterator {
    /// The candidate grid being iterated over.
    random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
    /// Cursor into `criterion` (fastest-varying dimension).
    current_criterion: usize,
    /// Cursor into `max_depth`.
    current_max_depth: usize,
    /// Cursor into `min_samples_leaf`.
    current_min_samples_leaf: usize,
    /// Cursor into `min_samples_split`.
    current_min_samples_split: usize,
    /// Cursor into `n_trees`.
    current_n_trees: usize,
    /// Cursor into `m`.
    current_m: usize,
    /// Cursor into `keep_samples`.
    current_keep_samples: usize,
    /// Cursor into `seed` (slowest-varying dimension).
    current_seed: usize,
}
impl IntoIterator for RandomForestClassifierSearchParameters {
    type Item = RandomForestClassifierParameters;
    type IntoIter = RandomForestClassifierSearchParametersIterator;

    /// Wrap the grid in an iterator positioned at the first combination.
    fn into_iter(self) -> Self::IntoIter {
        RandomForestClassifierSearchParametersIterator {
            current_seed: 0,
            current_keep_samples: 0,
            current_m: 0,
            current_n_trees: 0,
            current_min_samples_split: 0,
            current_min_samples_leaf: 0,
            current_max_depth: 0,
            current_criterion: 0,
            random_forest_classifier_search_parameters: self,
        }
    }
}
impl Iterator for RandomForestClassifierSearchParametersIterator {
    type Item = RandomForestClassifierParameters;

    /// Walks the Cartesian product of all candidate lists in odometer order:
    /// `criterion` varies fastest, `seed` slowest. Returns `None` once every
    /// combination has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.random_forest_classifier_search_parameters;
        // Lengths and cursors, ordered fastest- to slowest-varying.
        let lens = [
            grid.criterion.len(),
            grid.max_depth.len(),
            grid.min_samples_leaf.len(),
            grid.min_samples_split.len(),
            grid.n_trees.len(),
            grid.m.len(),
            grid.keep_samples.len(),
            grid.seed.len(),
        ];
        let mut cursors = [
            self.current_criterion,
            self.current_max_depth,
            self.current_min_samples_leaf,
            self.current_min_samples_split,
            self.current_n_trees,
            self.current_m,
            self.current_keep_samples,
            self.current_seed,
        ];
        // Every cursor parked past the end means the grid is exhausted.
        if cursors == lens {
            return None;
        }
        let item = RandomForestClassifierParameters {
            criterion: grid.criterion[cursors[0]].clone(),
            max_depth: grid.max_depth[cursors[1]],
            min_samples_leaf: grid.min_samples_leaf[cursors[2]],
            min_samples_split: grid.min_samples_split[cursors[3]],
            n_trees: grid.n_trees[cursors[4]],
            m: grid.m[cursors[5]],
            keep_samples: grid.keep_samples[cursors[6]],
            seed: grid.seed[cursors[7]],
        };
        // Odometer advance: bump the first dimension with room left and reset
        // all faster ones; if no dimension has room, park every cursor past
        // the end so the next call returns `None`.
        match (0..lens.len()).find(|&d| cursors[d] + 1 < lens[d]) {
            Some(d) => {
                for c in cursors.iter_mut().take(d) {
                    *c = 0;
                }
                cursors[d] += 1;
            }
            None => {
                for c in cursors.iter_mut() {
                    *c += 1;
                }
            }
        }
        self.current_criterion = cursors[0];
        self.current_max_depth = cursors[1];
        self.current_min_samples_leaf = cursors[2];
        self.current_min_samples_split = cursors[3];
        self.current_n_trees = cursors[4];
        self.current_m = cursors[5];
        self.current_keep_samples = cursors[6];
        self.current_seed = cursors[7];
        Some(item)
    }
}
impl Default for RandomForestClassifierSearchParameters {
fn default() -> Self {
let default_params = RandomForestClassifierParameters::default();
RandomForestClassifierSearchParameters {
criterion: vec![default_params.criterion],
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<T: RealNumber> RandomForestClassifier<T> { impl<T: RealNumber> RandomForestClassifier<T> {
/// Build a forest of trees from the training set. /// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -346,6 +566,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::*; use crate::metrics::*;
#[test]
fn search_parameters() {
    let parameters = RandomForestClassifierSearchParameters {
        n_trees: vec![10, 100],
        m: vec![None, Some(1)],
        ..Default::default()
    };
    // `n_trees` varies faster than `m`.
    let expected = [(10, None), (100, None), (10, Some(1)), (100, Some(1))];
    let mut iter = parameters.into_iter();
    for &(n_trees, m) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, n_trees);
        assert_eq!(next.m, m);
    }
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
+208
View File
@@ -176,6 +176,191 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestReg
} }
} }
/// RandomForestRegressor grid search parameters
///
/// Each `Vec` field lists the candidate values for that parameter; `into_iter`
/// yields every combination (the Cartesian product of all lists).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorSearchParameters {
    /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub max_depth: Vec<Option<u16>>,
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_leaf: Vec<usize>,
    /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_split: Vec<usize>,
    /// The number of trees in the forest.
    pub n_trees: Vec<usize>,
    /// Number of random sample of predictors to use as split candidates.
    pub m: Vec<Option<usize>>,
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: Vec<bool>,
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: Vec<u64>,
}
/// RandomForestRegressor grid search iterator
pub struct RandomForestRegressorSearchParametersIterator {
    /// The candidate grid being iterated over.
    random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
    /// Cursor into `max_depth` (fastest-varying dimension).
    current_max_depth: usize,
    /// Cursor into `min_samples_leaf`.
    current_min_samples_leaf: usize,
    /// Cursor into `min_samples_split`.
    current_min_samples_split: usize,
    /// Cursor into `n_trees`.
    current_n_trees: usize,
    /// Cursor into `m`.
    current_m: usize,
    /// Cursor into `keep_samples`.
    current_keep_samples: usize,
    /// Cursor into `seed` (slowest-varying dimension).
    current_seed: usize,
}
impl IntoIterator for RandomForestRegressorSearchParameters {
    type Item = RandomForestRegressorParameters;
    type IntoIter = RandomForestRegressorSearchParametersIterator;

    /// Wrap the grid in an iterator positioned at the first combination.
    fn into_iter(self) -> Self::IntoIter {
        RandomForestRegressorSearchParametersIterator {
            current_seed: 0,
            current_keep_samples: 0,
            current_m: 0,
            current_n_trees: 0,
            current_min_samples_split: 0,
            current_min_samples_leaf: 0,
            current_max_depth: 0,
            random_forest_regressor_search_parameters: self,
        }
    }
}
impl Iterator for RandomForestRegressorSearchParametersIterator {
    type Item = RandomForestRegressorParameters;

    /// Walks the Cartesian product of all candidate lists in odometer order:
    /// `max_depth` varies fastest, `seed` slowest. Returns `None` once every
    /// combination has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let grid = &self.random_forest_regressor_search_parameters;
        // Lengths and cursors, ordered fastest- to slowest-varying.
        let lens = [
            grid.max_depth.len(),
            grid.min_samples_leaf.len(),
            grid.min_samples_split.len(),
            grid.n_trees.len(),
            grid.m.len(),
            grid.keep_samples.len(),
            grid.seed.len(),
        ];
        let mut cursors = [
            self.current_max_depth,
            self.current_min_samples_leaf,
            self.current_min_samples_split,
            self.current_n_trees,
            self.current_m,
            self.current_keep_samples,
            self.current_seed,
        ];
        // Every cursor parked past the end means the grid is exhausted.
        if cursors == lens {
            return None;
        }
        let item = RandomForestRegressorParameters {
            max_depth: grid.max_depth[cursors[0]],
            min_samples_leaf: grid.min_samples_leaf[cursors[1]],
            min_samples_split: grid.min_samples_split[cursors[2]],
            n_trees: grid.n_trees[cursors[3]],
            m: grid.m[cursors[4]],
            keep_samples: grid.keep_samples[cursors[5]],
            seed: grid.seed[cursors[6]],
        };
        // Odometer advance: bump the first dimension with room left and reset
        // all faster ones; if no dimension has room, park every cursor past
        // the end so the next call returns `None`.
        match (0..lens.len()).find(|&d| cursors[d] + 1 < lens[d]) {
            Some(d) => {
                for c in cursors.iter_mut().take(d) {
                    *c = 0;
                }
                cursors[d] += 1;
            }
            None => {
                for c in cursors.iter_mut() {
                    *c += 1;
                }
            }
        }
        self.current_max_depth = cursors[0];
        self.current_min_samples_leaf = cursors[1];
        self.current_min_samples_split = cursors[2];
        self.current_n_trees = cursors[3];
        self.current_m = cursors[4];
        self.current_keep_samples = cursors[5];
        self.current_seed = cursors[6];
        Some(item)
    }
}
impl Default for RandomForestRegressorSearchParameters {
fn default() -> Self {
let default_params = RandomForestRegressorParameters::default();
RandomForestRegressorSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<T: RealNumber> RandomForestRegressor<T> { impl<T: RealNumber> RandomForestRegressor<T> {
/// Build a forest of trees from the training set. /// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -302,6 +487,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
    let parameters = RandomForestRegressorSearchParameters {
        n_trees: vec![10, 100],
        m: vec![None, Some(1)],
        ..Default::default()
    };
    // `n_trees` varies faster than `m`.
    let expected = [(10, None), (100, None), (10, Some(1)), (100, Some(1))];
    let mut iter = parameters.into_iter();
    for &(n_trees, m) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.n_trees, n_trees);
        assert_eq!(next.m, m);
    }
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_longley() { fn fit_longley() {
+14 -17
View File
@@ -129,7 +129,7 @@ pub struct LassoSearchParameters<T: RealNumber> {
/// Lasso grid search iterator /// Lasso grid search iterator
pub struct LassoSearchParametersIterator<T: RealNumber> { pub struct LassoSearchParametersIterator<T: RealNumber> {
lasso_regression_search_parameters: LassoSearchParameters<T>, lasso_search_parameters: LassoSearchParameters<T>,
current_alpha: usize, current_alpha: usize,
current_normalize: usize, current_normalize: usize,
current_tol: usize, current_tol: usize,
@@ -142,7 +142,7 @@ impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator { LassoSearchParametersIterator {
lasso_regression_search_parameters: self, lasso_search_parameters: self,
current_alpha: 0, current_alpha: 0,
current_normalize: 0, current_normalize: 0,
current_tol: 0, current_tol: 0,
@@ -155,34 +155,31 @@ impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
type Item = LassoParameters<T>; type Item = LassoParameters<T>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len() if self.current_alpha == self.lasso_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len() && self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len() && self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len() && self.current_max_iter == self.lasso_search_parameters.max_iter.len()
{ {
return None; return None;
} }
let next = LassoParameters { let next = LassoParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha], alpha: self.lasso_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize], normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol], tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter], max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
}; };
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() { if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
self.current_alpha += 1; self.current_alpha += 1;
} else if self.current_normalize + 1 } else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0; self.current_alpha = 0;
self.current_normalize += 1; self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() { } else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
self.current_alpha = 0; self.current_alpha = 0;
self.current_normalize = 0; self.current_normalize = 0;
self.current_tol += 1; self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len() } else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
{
self.current_alpha = 0; self.current_alpha = 0;
self.current_normalize = 0; self.current_normalize = 0;
self.current_tol = 0; self.current_tol = 0;
+1 -1
View File
@@ -114,4 +114,4 @@ mod tests {
assert!([0., 1.].contains(&results.parameters.alpha)); assert!([0., 1.].contains(&results.parameters.alpha));
} }
} }
-1
View File
@@ -281,7 +281,6 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::{accuracy, mean_absolute_error}; use crate::metrics::{accuracy, mean_absolute_error};
use crate::model_selection::kfold::KFold; use crate::model_selection::kfold::KFold;
use crate::neighbors::knn_regressor::KNNRegressor; use crate::neighbors::knn_regressor::KNNRegressor;
+96
View File
@@ -150,6 +150,88 @@ impl<T: RealNumber> Default for BernoulliNBParameters<T> {
} }
} }
/// BernoulliNB grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BernoulliNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
    /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
    pub binarize: Vec<Option<T>>,
}

/// BernoulliNB grid search iterator
pub struct BernoulliNBSearchParametersIterator<T: RealNumber> {
    // Candidate values being iterated over.
    bernoulli_nb_search_parameters: BernoulliNBSearchParameters<T>,
    // Cursor into `alpha` (fastest-varying dimension).
    current_alpha: usize,
    // Cursor into `priors`.
    current_priors: usize,
    // Cursor into `binarize` (slowest-varying dimension).
    current_binarize: usize,
}

impl<T: RealNumber> IntoIterator for BernoulliNBSearchParameters<T> {
    type Item = BernoulliNBParameters<T>;
    type IntoIter = BernoulliNBSearchParametersIterator<T>;

    /// Creates an iterator over the cartesian product of all candidate values.
    fn into_iter(self) -> Self::IntoIter {
        BernoulliNBSearchParametersIterator {
            bernoulli_nb_search_parameters: self,
            current_alpha: 0,
            current_priors: 0,
            current_binarize: 0,
        }
    }
}

impl<T: RealNumber> Iterator for BernoulliNBSearchParametersIterator<T> {
    type Item = BernoulliNBParameters<T>;

    /// Yields the next parameter combination, odometer-style: `alpha` varies
    /// fastest, then `priors`, then `binarize`.
    fn next(&mut self) -> Option<Self::Item> {
        // Borrowing only this field leaves the cursor fields free for
        // mutation below (disjoint field borrows).
        let p = &self.bernoulli_nb_search_parameters;
        // A cursor can only reach the length of its vector once the whole
        // cartesian product has been emitted (all cursors are bumped past the
        // end together), or right away when a candidate vector is empty.
        // `||` (rather than `&&`) avoids an out-of-bounds panic when some
        // vector is empty but another is not: the product involving an empty
        // set is empty.
        if self.current_alpha == p.alpha.len()
            || self.current_priors == p.priors.len()
            || self.current_binarize == p.binarize.len()
        {
            return None;
        }
        let next = BernoulliNBParameters {
            alpha: p.alpha[self.current_alpha],
            // Cloned: we cannot move out of the vector the iterator retains.
            priors: p.priors[self.current_priors].clone(),
            binarize: p.binarize[self.current_binarize],
        };
        // Advance the odometer.
        if self.current_alpha + 1 < p.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_priors + 1 < p.priors.len() {
            self.current_alpha = 0;
            self.current_priors += 1;
        } else if self.current_binarize + 1 < p.binarize.len() {
            self.current_alpha = 0;
            self.current_priors = 0;
            self.current_binarize += 1;
        } else {
            // Last combination emitted: push every cursor past the end so the
            // next call terminates.
            self.current_alpha += 1;
            self.current_priors += 1;
            self.current_binarize += 1;
        }
        Some(next)
    }
}

impl<T: RealNumber> Default for BernoulliNBSearchParameters<T> {
    /// A one-element grid holding the default `BernoulliNBParameters`.
    fn default() -> Self {
        let default_params = BernoulliNBParameters::default();
        BernoulliNBSearchParameters {
            alpha: vec![default_params.alpha],
            priors: vec![default_params.priors],
            binarize: vec![default_params.binarize],
        }
    }
}
impl<T: RealNumber> BernoulliNBDistribution<T> { impl<T: RealNumber> BernoulliNBDistribution<T> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
@@ -347,6 +429,20 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidate values for `alpha`, defaults for everything else:
    // the grid therefore contains exactly two combinations.
    let parameters = BernoulliNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 1.);
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 2.);
    // Exhausted after the full cartesian product.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_bernoulli_naive_bayes() { fn run_bernoulli_naive_bayes() {
+68
View File
@@ -261,6 +261,60 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
} }
} }
/// CategoricalNB grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct CategoricalNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
}

/// CategoricalNB grid search iterator
pub struct CategoricalNBSearchParametersIterator<T: RealNumber> {
    // Candidate values being iterated over.
    categorical_nb_search_parameters: CategoricalNBSearchParameters<T>,
    // Cursor into `alpha`; equals `alpha.len()` once exhausted.
    current_alpha: usize,
}

impl<T: RealNumber> IntoIterator for CategoricalNBSearchParameters<T> {
    type Item = CategoricalNBParameters<T>;
    type IntoIter = CategoricalNBSearchParametersIterator<T>;

    /// Creates an iterator over every candidate `alpha` value.
    fn into_iter(self) -> Self::IntoIter {
        CategoricalNBSearchParametersIterator {
            categorical_nb_search_parameters: self,
            current_alpha: 0,
        }
    }
}

impl<T: RealNumber> Iterator for CategoricalNBSearchParametersIterator<T> {
    type Item = CategoricalNBParameters<T>;

    /// Yields one `CategoricalNBParameters` per candidate `alpha` value.
    fn next(&mut self) -> Option<Self::Item> {
        // Single dimension: stop as soon as the cursor reaches the end
        // (handles an empty `alpha` vector immediately).
        if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() {
            return None;
        }
        let next = CategoricalNBParameters {
            alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha],
        };
        self.current_alpha += 1;
        Some(next)
    }
}

impl<T: RealNumber> Default for CategoricalNBSearchParameters<T> {
    /// A one-element grid holding the default `CategoricalNBParameters`.
    fn default() -> Self {
        let default_params = CategoricalNBParameters::default();
        CategoricalNBSearchParameters {
            alpha: vec![default_params.alpha],
        }
    }
}
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
@@ -351,6 +405,20 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidate values for `alpha`: the grid has exactly two entries,
    // yielded in declaration order.
    let parameters = CategoricalNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 1.);
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 2.);
    // Exhausted after the full grid.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_categorical_naive_bayes() { fn run_categorical_naive_bayes() {
+75 -1
View File
@@ -76,7 +76,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
/// `GaussianNB` parameters. Use `Default::default()` for default values. /// `GaussianNB` parameters. Use `Default::default()` for default values.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone)] #[derive(Debug, Clone)]
pub struct GaussianNBParameters<T: RealNumber> { pub struct GaussianNBParameters<T: RealNumber> {
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
pub priors: Option<Vec<T>>, pub priors: Option<Vec<T>>,
@@ -90,6 +90,66 @@ impl<T: RealNumber> GaussianNBParameters<T> {
} }
} }
impl<T: RealNumber> Default for GaussianNBParameters<T> {
    /// Default parameters: priors are estimated from the training data.
    fn default() -> Self {
        Self { priors: None }
    }
}

/// GaussianNB grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct GaussianNBSearchParameters<T: RealNumber> {
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
}

/// GaussianNB grid search iterator
pub struct GaussianNBSearchParametersIterator<T: RealNumber> {
    // Candidate values being iterated over.
    gaussian_nb_search_parameters: GaussianNBSearchParameters<T>,
    // Cursor into `priors`; equals `priors.len()` once exhausted.
    current_priors: usize,
}

impl<T: RealNumber> IntoIterator for GaussianNBSearchParameters<T> {
    type Item = GaussianNBParameters<T>;
    type IntoIter = GaussianNBSearchParametersIterator<T>;

    /// Creates an iterator over every candidate `priors` value.
    fn into_iter(self) -> Self::IntoIter {
        GaussianNBSearchParametersIterator {
            gaussian_nb_search_parameters: self,
            current_priors: 0,
        }
    }
}

impl<T: RealNumber> Iterator for GaussianNBSearchParametersIterator<T> {
    type Item = GaussianNBParameters<T>;

    /// Yields one `GaussianNBParameters` per candidate `priors` value.
    fn next(&mut self) -> Option<Self::Item> {
        // Single dimension: stop as soon as the cursor reaches the end
        // (handles an empty `priors` vector immediately).
        if self.current_priors == self.gaussian_nb_search_parameters.priors.len() {
            return None;
        }
        let next = GaussianNBParameters {
            // Cloned: we cannot move out of the vector the iterator retains.
            priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(),
        };
        self.current_priors += 1;
        Some(next)
    }
}

impl<T: RealNumber> Default for GaussianNBSearchParameters<T> {
    /// A one-element grid holding the default `GaussianNBParameters`.
    fn default() -> Self {
        let default_params = GaussianNBParameters::default();
        GaussianNBSearchParameters {
            priors: vec![default_params.priors],
        }
    }
}
impl<T: RealNumber> GaussianNBDistribution<T> { impl<T: RealNumber> GaussianNBDistribution<T> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
@@ -260,6 +320,20 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidate priors: the grid has exactly two entries, yielded in
    // declaration order.
    let parameters = GaussianNBSearchParameters {
        priors: vec![Some(vec![1.]), Some(vec![2.])],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.priors, Some(vec![1.]));
    let next = iter.next().unwrap();
    assert_eq!(next.priors, Some(vec![2.]));
    // Exhausted after the full grid.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_gaussian_naive_bayes() { fn run_gaussian_naive_bayes() {
+84
View File
@@ -114,6 +114,76 @@ impl<T: RealNumber> Default for MultinomialNBParameters<T> {
} }
} }
/// MultinomialNB grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct MultinomialNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
}

/// MultinomialNB grid search iterator
pub struct MultinomialNBSearchParametersIterator<T: RealNumber> {
    // Candidate values being iterated over.
    multinomial_nb_search_parameters: MultinomialNBSearchParameters<T>,
    // Cursor into `alpha` (fastest-varying dimension).
    current_alpha: usize,
    // Cursor into `priors` (slowest-varying dimension).
    current_priors: usize,
}

impl<T: RealNumber> IntoIterator for MultinomialNBSearchParameters<T> {
    type Item = MultinomialNBParameters<T>;
    type IntoIter = MultinomialNBSearchParametersIterator<T>;

    /// Creates an iterator over the cartesian product of all candidate values.
    fn into_iter(self) -> Self::IntoIter {
        MultinomialNBSearchParametersIterator {
            multinomial_nb_search_parameters: self,
            current_alpha: 0,
            current_priors: 0,
        }
    }
}

impl<T: RealNumber> Iterator for MultinomialNBSearchParametersIterator<T> {
    type Item = MultinomialNBParameters<T>;

    /// Yields the next parameter combination, odometer-style: `alpha` varies
    /// fastest, then `priors`.
    fn next(&mut self) -> Option<Self::Item> {
        // Borrowing only this field leaves the cursor fields free for
        // mutation below (disjoint field borrows).
        let p = &self.multinomial_nb_search_parameters;
        // A cursor can only reach the length of its vector once the whole
        // cartesian product has been emitted (all cursors are bumped past the
        // end together), or right away when a candidate vector is empty.
        // `||` (rather than `&&`) avoids an out-of-bounds panic when some
        // vector is empty but another is not: the product involving an empty
        // set is empty.
        if self.current_alpha == p.alpha.len() || self.current_priors == p.priors.len() {
            return None;
        }
        let next = MultinomialNBParameters {
            alpha: p.alpha[self.current_alpha],
            // Cloned: we cannot move out of the vector the iterator retains.
            priors: p.priors[self.current_priors].clone(),
        };
        // Advance the odometer.
        if self.current_alpha + 1 < p.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_priors + 1 < p.priors.len() {
            self.current_alpha = 0;
            self.current_priors += 1;
        } else {
            // Last combination emitted: push every cursor past the end so the
            // next call terminates.
            self.current_alpha += 1;
            self.current_priors += 1;
        }
        Some(next)
    }
}

impl<T: RealNumber> Default for MultinomialNBSearchParameters<T> {
    /// A one-element grid holding the default `MultinomialNBParameters`.
    fn default() -> Self {
        let default_params = MultinomialNBParameters::default();
        MultinomialNBSearchParameters {
            alpha: vec![default_params.alpha],
            priors: vec![default_params.priors],
        }
    }
}
impl<T: RealNumber> MultinomialNBDistribution<T> { impl<T: RealNumber> MultinomialNBDistribution<T> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
@@ -297,6 +367,20 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidate values for `alpha`, defaults for `priors`: the grid
    // contains exactly two combinations.
    let parameters = MultinomialNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 1.);
    let next = iter.next().unwrap();
    assert_eq!(next.alpha, 2.);
    // Exhausted after the full cartesian product.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_multinomial_naive_bayes() { fn run_multinomial_naive_bayes() {
+5 -5
View File
@@ -33,7 +33,7 @@ use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Defines a kernel function /// Defines a kernel function
pub trait Kernel<T: RealNumber, V: BaseVector<T>> { pub trait Kernel<T: RealNumber, V: BaseVector<T>>: Clone {
/// Apply kernel function to x_i and x_j /// Apply kernel function to x_i and x_j
fn apply(&self, x_i: &V, x_j: &V) -> T; fn apply(&self, x_i: &V, x_j: &V) -> T;
} }
@@ -95,12 +95,12 @@ impl Kernels {
/// Linear Kernel /// Linear Kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct LinearKernel {} pub struct LinearKernel {}
/// Radial basis function (Gaussian) kernel /// Radial basis function (Gaussian) kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct RBFKernel<T: RealNumber> { pub struct RBFKernel<T: RealNumber> {
/// kernel coefficient /// kernel coefficient
pub gamma: T, pub gamma: T,
@@ -108,7 +108,7 @@ pub struct RBFKernel<T: RealNumber> {
/// Polynomial kernel /// Polynomial kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct PolynomialKernel<T: RealNumber> { pub struct PolynomialKernel<T: RealNumber> {
/// degree of the polynomial /// degree of the polynomial
pub degree: T, pub degree: T,
@@ -120,7 +120,7 @@ pub struct PolynomialKernel<T: RealNumber> {
/// Sigmoid (hyperbolic tangent) kernel /// Sigmoid (hyperbolic tangent) kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct SigmoidKernel<T: RealNumber> { pub struct SigmoidKernel<T: RealNumber> {
/// kernel coefficient /// kernel coefficient
pub gamma: T, pub gamma: T,
+121
View File
@@ -102,6 +102,109 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>, m: PhantomData<M>,
} }
/// SVC grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
    /// Number of epochs.
    pub epoch: Vec<usize>,
    /// Regularization parameter.
    pub c: Vec<T>,
    /// Tolerance for stopping epoch.
    pub tol: Vec<T>,
    /// The kernel function.
    pub kernel: Vec<K>,
    /// Unused parameter.
    m: PhantomData<M>,
}

/// SVC grid search iterator
pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
    // Candidate values being iterated over.
    svc_search_parameters: SVCSearchParameters<T, M, K>,
    // Cursor into `epoch` (fastest-varying dimension).
    current_epoch: usize,
    // Cursor into `c`.
    current_c: usize,
    // Cursor into `tol`.
    current_tol: usize,
    // Cursor into `kernel` (slowest-varying dimension).
    current_kernel: usize,
}

impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
    for SVCSearchParameters<T, M, K>
{
    type Item = SVCParameters<T, M, K>;
    type IntoIter = SVCSearchParametersIterator<T, M, K>;

    /// Creates an iterator over the cartesian product of all candidate values.
    fn into_iter(self) -> Self::IntoIter {
        SVCSearchParametersIterator {
            svc_search_parameters: self,
            current_epoch: 0,
            current_c: 0,
            current_tol: 0,
            current_kernel: 0,
        }
    }
}

impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
    for SVCSearchParametersIterator<T, M, K>
{
    type Item = SVCParameters<T, M, K>;

    /// Yields the next parameter combination, odometer-style: `epoch` varies
    /// fastest, then `c`, `tol` and finally `kernel`.
    fn next(&mut self) -> Option<Self::Item> {
        // Borrowing only this field leaves the cursor fields free for
        // mutation below (disjoint field borrows).
        let p = &self.svc_search_parameters;
        // A cursor can only reach the length of its vector once the whole
        // cartesian product has been emitted (all cursors are bumped past the
        // end together), or right away when a candidate vector is empty.
        // `||` (rather than `&&`) avoids an out-of-bounds panic when some
        // vector is empty but another is not: the product involving an empty
        // set is empty.
        if self.current_epoch == p.epoch.len()
            || self.current_c == p.c.len()
            || self.current_tol == p.tol.len()
            || self.current_kernel == p.kernel.len()
        {
            return None;
        }
        let next = SVCParameters::<T, M, K> {
            epoch: p.epoch[self.current_epoch],
            c: p.c[self.current_c],
            tol: p.tol[self.current_tol],
            // Cloned: we cannot move out of the vector the iterator retains.
            kernel: p.kernel[self.current_kernel].clone(),
            m: PhantomData,
        };
        // Advance the odometer.
        if self.current_epoch + 1 < p.epoch.len() {
            self.current_epoch += 1;
        } else if self.current_c + 1 < p.c.len() {
            self.current_epoch = 0;
            self.current_c += 1;
        } else if self.current_tol + 1 < p.tol.len() {
            self.current_epoch = 0;
            self.current_c = 0;
            self.current_tol += 1;
        } else if self.current_kernel + 1 < p.kernel.len() {
            self.current_epoch = 0;
            self.current_c = 0;
            self.current_tol = 0;
            self.current_kernel += 1;
        } else {
            // Last combination emitted: push every cursor past the end so the
            // next call terminates.
            self.current_epoch += 1;
            self.current_c += 1;
            self.current_tol += 1;
            self.current_kernel += 1;
        }
        Some(next)
    }
}

impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKernel> {
    /// A one-element grid holding the default `SVCParameters`.
    fn default() -> Self {
        let default_params: SVCParameters<T, M, LinearKernel> = SVCParameters::default();
        SVCSearchParameters {
            epoch: vec![default_params.epoch],
            c: vec![default_params.c],
            tol: vec![default_params.tol],
            kernel: vec![default_params.kernel],
            m: PhantomData,
        }
    }
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)] #[derive(Debug)]
#[cfg_attr( #[cfg_attr(
@@ -737,6 +840,24 @@ mod tests {
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use crate::svm::*; use crate::svm::*;
#[test]
fn search_parameters() {
    // Two candidate epochs, one kernel, defaults for `c` and `tol`: the grid
    // contains exactly two combinations.
    let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
        SVCSearchParameters {
            epoch: vec![10, 100],
            kernel: vec![LinearKernel {}],
            ..Default::default()
        };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.epoch, 10);
    assert_eq!(next.kernel, LinearKernel {});
    let next = iter.next().unwrap();
    assert_eq!(next.epoch, 100);
    assert_eq!(next.kernel, LinearKernel {});
    // Exhausted after the full cartesian product.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svc_fit_predict() { fn svc_fit_predict() {
+121
View File
@@ -94,6 +94,109 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>, m: PhantomData<M>,
} }
/// SVR grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVRSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
    /// Epsilon in the epsilon-SVR model.
    pub eps: Vec<T>,
    /// Regularization parameter.
    pub c: Vec<T>,
    /// Tolerance for stopping eps.
    pub tol: Vec<T>,
    /// The kernel function.
    pub kernel: Vec<K>,
    /// Unused parameter.
    m: PhantomData<M>,
}

/// SVR grid search iterator
pub struct SVRSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
    // Candidate values being iterated over.
    svr_search_parameters: SVRSearchParameters<T, M, K>,
    // Cursor into `eps` (fastest-varying dimension).
    current_eps: usize,
    // Cursor into `c`.
    current_c: usize,
    // Cursor into `tol`.
    current_tol: usize,
    // Cursor into `kernel` (slowest-varying dimension).
    current_kernel: usize,
}

impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
    for SVRSearchParameters<T, M, K>
{
    type Item = SVRParameters<T, M, K>;
    type IntoIter = SVRSearchParametersIterator<T, M, K>;

    /// Creates an iterator over the cartesian product of all candidate values.
    fn into_iter(self) -> Self::IntoIter {
        SVRSearchParametersIterator {
            svr_search_parameters: self,
            current_eps: 0,
            current_c: 0,
            current_tol: 0,
            current_kernel: 0,
        }
    }
}

impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
    for SVRSearchParametersIterator<T, M, K>
{
    type Item = SVRParameters<T, M, K>;

    /// Yields the next parameter combination, odometer-style: `eps` varies
    /// fastest, then `c`, `tol` and finally `kernel`.
    fn next(&mut self) -> Option<Self::Item> {
        // Borrowing only this field leaves the cursor fields free for
        // mutation below (disjoint field borrows).
        let p = &self.svr_search_parameters;
        // A cursor can only reach the length of its vector once the whole
        // cartesian product has been emitted (all cursors are bumped past the
        // end together), or right away when a candidate vector is empty.
        // `||` (rather than `&&`) avoids an out-of-bounds panic when some
        // vector is empty but another is not: the product involving an empty
        // set is empty.
        if self.current_eps == p.eps.len()
            || self.current_c == p.c.len()
            || self.current_tol == p.tol.len()
            || self.current_kernel == p.kernel.len()
        {
            return None;
        }
        let next = SVRParameters::<T, M, K> {
            eps: p.eps[self.current_eps],
            c: p.c[self.current_c],
            tol: p.tol[self.current_tol],
            // Cloned: we cannot move out of the vector the iterator retains.
            kernel: p.kernel[self.current_kernel].clone(),
            m: PhantomData,
        };
        // Advance the odometer.
        if self.current_eps + 1 < p.eps.len() {
            self.current_eps += 1;
        } else if self.current_c + 1 < p.c.len() {
            self.current_eps = 0;
            self.current_c += 1;
        } else if self.current_tol + 1 < p.tol.len() {
            self.current_eps = 0;
            self.current_c = 0;
            self.current_tol += 1;
        } else if self.current_kernel + 1 < p.kernel.len() {
            self.current_eps = 0;
            self.current_c = 0;
            self.current_tol = 0;
            self.current_kernel += 1;
        } else {
            // Last combination emitted: push every cursor past the end so the
            // next call terminates.
            self.current_eps += 1;
            self.current_c += 1;
            self.current_tol += 1;
            self.current_kernel += 1;
        }
        Some(next)
    }
}

impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
    /// A one-element grid holding the default `SVRParameters`.
    fn default() -> Self {
        let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
        SVRSearchParameters {
            eps: vec![default_params.eps],
            c: vec![default_params.c],
            tol: vec![default_params.tol],
            kernel: vec![default_params.kernel],
            m: PhantomData,
        }
    }
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)] #[derive(Debug)]
#[cfg_attr( #[cfg_attr(
@@ -536,6 +639,24 @@ mod tests {
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use crate::svm::*; use crate::svm::*;
#[test]
fn search_parameters() {
    // Two candidate eps values, one kernel, defaults for `c` and `tol`: the
    // grid contains exactly two combinations.
    let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
        SVRSearchParameters {
            eps: vec![0., 1.],
            kernel: vec![LinearKernel {}],
            ..Default::default()
        };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.eps, 0.);
    assert_eq!(next.kernel, LinearKernel {});
    let next = iter.next().unwrap();
    assert_eq!(next.eps, 1.);
    assert_eq!(next.kernel, LinearKernel {});
    // Exhausted after the full cartesian product.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svr_fit_predict() { fn svr_fit_predict() {
+161
View File
@@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters {
} }
} }
/// DecisionTreeClassifier grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierSearchParameters {
    /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub criterion: Vec<SplitCriterion>,
    /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub max_depth: Vec<Option<u16>>,
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_leaf: Vec<usize>,
    /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
    pub min_samples_split: Vec<usize>,
}

/// DecisionTreeClassifier grid search iterator
pub struct DecisionTreeClassifierSearchParametersIterator {
    // Candidate values being iterated over.
    decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
    // Cursor into `criterion` (fastest-varying dimension).
    current_criterion: usize,
    // Cursor into `max_depth`.
    current_max_depth: usize,
    // Cursor into `min_samples_leaf`.
    current_min_samples_leaf: usize,
    // Cursor into `min_samples_split` (slowest-varying dimension).
    current_min_samples_split: usize,
}

impl IntoIterator for DecisionTreeClassifierSearchParameters {
    type Item = DecisionTreeClassifierParameters;
    type IntoIter = DecisionTreeClassifierSearchParametersIterator;

    /// Creates an iterator over the cartesian product of all candidate values.
    fn into_iter(self) -> Self::IntoIter {
        DecisionTreeClassifierSearchParametersIterator {
            decision_tree_classifier_search_parameters: self,
            current_criterion: 0,
            current_max_depth: 0,
            current_min_samples_leaf: 0,
            current_min_samples_split: 0,
        }
    }
}

impl Iterator for DecisionTreeClassifierSearchParametersIterator {
    type Item = DecisionTreeClassifierParameters;

    /// Yields the next parameter combination, odometer-style: `criterion`
    /// varies fastest, then `max_depth`, `min_samples_leaf` and finally
    /// `min_samples_split`.
    fn next(&mut self) -> Option<Self::Item> {
        // Short alias; borrowing only this field leaves the cursor fields
        // free for mutation below (disjoint field borrows).
        let p = &self.decision_tree_classifier_search_parameters;
        // A cursor can only reach the length of its vector once the whole
        // cartesian product has been emitted (all cursors are bumped past the
        // end together), or right away when a candidate vector is empty.
        // `||` (rather than `&&`) avoids an out-of-bounds panic when some
        // vector is empty but another is not: the product involving an empty
        // set is empty.
        if self.current_criterion == p.criterion.len()
            || self.current_max_depth == p.max_depth.len()
            || self.current_min_samples_leaf == p.min_samples_leaf.len()
            || self.current_min_samples_split == p.min_samples_split.len()
        {
            return None;
        }
        let next = DecisionTreeClassifierParameters {
            // Cloned: we cannot move out of the vector the iterator retains.
            criterion: p.criterion[self.current_criterion].clone(),
            max_depth: p.max_depth[self.current_max_depth],
            min_samples_leaf: p.min_samples_leaf[self.current_min_samples_leaf],
            min_samples_split: p.min_samples_split[self.current_min_samples_split],
        };
        // Advance the odometer.
        if self.current_criterion + 1 < p.criterion.len() {
            self.current_criterion += 1;
        } else if self.current_max_depth + 1 < p.max_depth.len() {
            self.current_criterion = 0;
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1 < p.min_samples_leaf.len() {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1 < p.min_samples_split.len() {
            self.current_criterion = 0;
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else {
            // Last combination emitted: push every cursor past the end so the
            // next call terminates.
            self.current_criterion += 1;
            self.current_max_depth += 1;
            self.current_min_samples_leaf += 1;
            self.current_min_samples_split += 1;
        }
        Some(next)
    }
}

impl Default for DecisionTreeClassifierSearchParameters {
    /// A one-element grid holding the default `DecisionTreeClassifierParameters`.
    fn default() -> Self {
        let default_params = DecisionTreeClassifierParameters::default();
        DecisionTreeClassifierSearchParameters {
            criterion: vec![default_params.criterion],
            max_depth: vec![default_params.max_depth],
            min_samples_leaf: vec![default_params.min_samples_leaf],
            min_samples_split: vec![default_params.min_samples_split],
        }
    }
}
impl<T: RealNumber> Node<T> { impl<T: RealNumber> Node<T> {
fn new(index: usize, output: usize) -> Self { fn new(index: usize, output: usize) -> Self {
Node { Node {
@@ -651,6 +789,29 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two candidate max depths x two min_samples_split values: a 2x2 grid.
    // `max_depth` is a faster-varying dimension than `min_samples_split`,
    // so depths cycle within each split value.
    let parameters = DecisionTreeClassifierSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };
    let mut iter = parameters.into_iter();
    let next = iter.next().unwrap();
    assert_eq!(next.max_depth, Some(10));
    assert_eq!(next.min_samples_split, 1);
    let next = iter.next().unwrap();
    assert_eq!(next.max_depth, Some(100));
    assert_eq!(next.min_samples_split, 1);
    let next = iter.next().unwrap();
    assert_eq!(next.max_depth, Some(10));
    assert_eq!(next.min_samples_split, 2);
    let next = iter.next().unwrap();
    assert_eq!(next.max_depth, Some(100));
    assert_eq!(next.min_samples_split, 2);
    // Exhausted after all four combinations.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn gini_impurity() { fn gini_impurity() {
+137
View File
@@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters {
} }
} }
/// DecisionTreeRegressor grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressorSearchParameters {
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Vec<Option<u16>>,
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: Vec<usize>,
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: Vec<usize>,
}
/// DecisionTreeRegressor grid search iterator
pub struct DecisionTreeRegressorSearchParametersIterator {
decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters,
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
}
impl IntoIterator for DecisionTreeRegressorSearchParameters {
    type Item = DecisionTreeRegressorParameters;
    type IntoIter = DecisionTreeRegressorSearchParametersIterator;

    /// Begin a Cartesian-product traversal of the grid, with every axis
    /// cursor positioned at the first candidate value.
    fn into_iter(self) -> Self::IntoIter {
        DecisionTreeRegressorSearchParametersIterator {
            current_max_depth: 0,
            current_min_samples_leaf: 0,
            current_min_samples_split: 0,
            decision_tree_regressor_search_parameters: self,
        }
    }
}
impl Iterator for DecisionTreeRegressorSearchParametersIterator {
    type Item = DecisionTreeRegressorParameters;

    /// Yields the next hyperparameter combination from the Cartesian product
    /// of the grid, advancing `max_depth` fastest, then `min_samples_leaf`,
    /// then `min_samples_split` — the same odometer order as the original
    /// increment cascade.
    fn next(&mut self) -> Option<Self::Item> {
        let params = &self.decision_tree_regressor_search_parameters;

        // Stop when any axis is exhausted. The previous guard only returned
        // None when *all* cursors had reached their lengths simultaneously, so
        // a grid with one empty axis (and others non-empty) fell through and
        // panicked on an out-of-bounds index; for valid non-empty grids the
        // behavior is unchanged.
        if self.current_max_depth >= params.max_depth.len()
            || self.current_min_samples_leaf >= params.min_samples_leaf.len()
            || self.current_min_samples_split >= params.min_samples_split.len()
        {
            return None;
        }

        let next = DecisionTreeRegressorParameters {
            max_depth: params.max_depth[self.current_max_depth],
            min_samples_leaf: params.min_samples_leaf[self.current_min_samples_leaf],
            min_samples_split: params.min_samples_split[self.current_min_samples_split],
        };

        // Odometer increment: when a faster axis wraps, carry into the next
        // slower one. After the last combination `current_min_samples_split`
        // is left at its length, which the guard above turns into `None`.
        self.current_max_depth += 1;
        if self.current_max_depth == params.max_depth.len() {
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
            if self.current_min_samples_leaf == params.min_samples_leaf.len() {
                self.current_min_samples_leaf = 0;
                self.current_min_samples_split += 1;
            }
        }

        Some(next)
    }
}
impl Default for DecisionTreeRegressorSearchParameters {
    /// A degenerate one-point grid holding the library default for every
    /// hyperparameter.
    fn default() -> Self {
        let defaults = DecisionTreeRegressorParameters::default();
        Self {
            max_depth: vec![defaults.max_depth],
            min_samples_leaf: vec![defaults.min_samples_leaf],
            min_samples_split: vec![defaults.min_samples_split],
        }
    }
}
impl<T: RealNumber> Node<T> {
    fn new(index: usize, output: T) -> Self {
        Node {
@@ -517,6 +631,29 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[test]
fn search_parameters() {
    // Two max_depth candidates × two min_samples_split candidates = four grid
    // points; the iterator cycles max_depth fastest.
    let parameters = DecisionTreeRegressorSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };
    let expected = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];
    let mut iter = parameters.into_iter();
    for &(max_depth, min_samples_split) in expected.iter() {
        let next = iter.next().unwrap();
        assert_eq!(next.max_depth, max_depth);
        assert_eq!(next.min_samples_split, min_samples_split);
    }
    // The grid is exhausted after the full Cartesian product.
    assert!(iter.next().is_none());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_longley() {