Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// DBSCAN grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: Vec<D>,
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: Vec<usize>,
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: Vec<T>,
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: Vec<KNNAlgorithmName>,
|
||||
}
|
||||
|
||||
/// DBSCAN grid search iterator
|
||||
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
|
||||
current_distance: usize,
|
||||
current_min_samples: usize,
|
||||
current_eps: usize,
|
||||
current_algorithm: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
type IntoIter = DBSCANSearchParametersIterator<T, D>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DBSCANSearchParametersIterator {
|
||||
dbscan_search_parameters: self,
|
||||
current_distance: 0,
|
||||
current_min_samples: 0,
|
||||
current_eps: 0,
|
||||
current_algorithm: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_distance == self.dbscan_search_parameters.distance.len()
|
||||
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
|
||||
&& self.current_eps == self.dbscan_search_parameters.eps.len()
|
||||
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DBSCANParameters {
|
||||
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
|
||||
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
|
||||
eps: self.dbscan_search_parameters.eps[self.current_eps],
|
||||
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
|
||||
};
|
||||
|
||||
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
|
||||
self.current_distance += 1;
|
||||
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples += 1;
|
||||
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps += 1;
|
||||
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps = 0;
|
||||
self.current_algorithm += 1;
|
||||
} else {
|
||||
self.current_distance += 1;
|
||||
self.current_min_samples += 1;
|
||||
self.current_eps += 1;
|
||||
self.current_algorithm += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
let default_params = DBSCANParameters::default();
|
||||
|
||||
DBSCANSearchParameters {
|
||||
distance: vec![default_params.distance],
|
||||
min_samples: vec![default_params.min_samples],
|
||||
eps: vec![default_params.eps],
|
||||
algorithm: vec![default_params.algorithm],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cluster_labels.len() == other.cluster_labels.len()
|
||||
@@ -268,6 +365,29 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = DBSCANSearchParameters {
        min_samples: vec![10, 100],
        eps: vec![1., 2.],
        ..Default::default()
    };

    // `min_samples` is the fastest-varying axis here (the single-element
    // defaults for the other fields never advance), then `eps`.
    let observed: Vec<(usize, f64)> = parameters
        .into_iter()
        .map(|p| (p.min_samples, p.eps))
        .collect();

    assert_eq!(observed, vec![(10, 1.), (100, 1.), (10, 2.), (100, 2.)]);
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
|
||||
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// KMeans grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KMeansSearchParameters {
|
||||
/// Number of clusters.
|
||||
pub k: Vec<usize>,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// KMeans grid search iterator
|
||||
pub struct KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: KMeansSearchParameters,
|
||||
current_k: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for KMeansSearchParameters {
|
||||
type Item = KMeansParameters;
|
||||
type IntoIter = KMeansSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for KMeansSearchParametersIterator {
|
||||
type Item = KMeansParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = KMeansParameters {
|
||||
k: self.kmeans_search_parameters.k[self.current_k],
|
||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KMeansSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = KMeansParameters::default();
|
||||
|
||||
KMeansSearchParameters {
|
||||
k: vec![default_params.k],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
KMeans::fit(x, parameters)
|
||||
@@ -313,6 +383,29 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = KMeansSearchParameters {
        k: vec![2, 4],
        max_iter: vec![10, 100],
        ..Default::default()
    };

    // `k` varies fastest, then `max_iter`: 2 x 2 = 4 combinations.
    let observed: Vec<(usize, usize)> = parameters
        .into_iter()
        .map(|p| (p.k, p.max_iter))
        .collect();

    assert_eq!(observed, vec![(2, 10), (4, 10), (2, 100), (4, 100)]);
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
@@ -116,6 +116,81 @@ impl Default for PCAParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// PCA grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PCASearchParameters {
|
||||
/// Number of components to keep.
|
||||
pub n_components: Vec<usize>,
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: Vec<bool>,
|
||||
}
|
||||
|
||||
/// PCA grid search iterator
|
||||
pub struct PCASearchParametersIterator {
|
||||
pca_search_parameters: PCASearchParameters,
|
||||
current_k: usize,
|
||||
current_use_correlation_matrix: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for PCASearchParameters {
|
||||
type Item = PCAParameters;
|
||||
type IntoIter = PCASearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
PCASearchParametersIterator {
|
||||
pca_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_use_correlation_matrix: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for PCASearchParametersIterator {
|
||||
type Item = PCAParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.pca_search_parameters.n_components.len()
|
||||
&& self.current_use_correlation_matrix
|
||||
== self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = PCAParameters {
|
||||
n_components: self.pca_search_parameters.n_components[self.current_k],
|
||||
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
|
||||
[self.current_use_correlation_matrix],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_use_correlation_matrix + 1
|
||||
< self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
self.current_k = 0;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PCASearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = PCAParameters::default();
|
||||
|
||||
PCASearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
use_correlation_matrix: vec![default_params.use_correlation_matrix],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
PCA::fit(x, parameters)
|
||||
@@ -271,6 +346,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = PCASearchParameters {
        n_components: vec![2, 4],
        use_correlation_matrix: vec![true, false],
        ..Default::default()
    };

    // `n_components` varies fastest, then the correlation-matrix flag.
    let observed: Vec<(usize, bool)> = parameters
        .into_iter()
        .map(|p| (p.n_components, p.use_correlation_matrix))
        .collect();

    assert_eq!(
        observed,
        vec![(2, true), (4, true), (2, false), (4, false)]
    );
}
|
||||
|
||||
fn us_arrests_data() -> DenseMatrix<f64> {
|
||||
DenseMatrix::from_2d_array(&[
|
||||
&[13.2, 236.0, 58.0, 21.2],
|
||||
|
||||
@@ -90,6 +90,60 @@ impl SVDParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// SVD grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVDSearchParameters {
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub n_components: Vec<usize>,
|
||||
}
|
||||
|
||||
/// SVD grid search iterator
|
||||
pub struct SVDSearchParametersIterator {
|
||||
svd_search_parameters: SVDSearchParameters,
|
||||
current_n_components: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for SVDSearchParameters {
|
||||
type Item = SVDParameters;
|
||||
type IntoIter = SVDSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVDSearchParametersIterator {
|
||||
svd_search_parameters: self,
|
||||
current_n_components: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for SVDSearchParametersIterator {
|
||||
type Item = SVDParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_n_components == self.svd_search_parameters.n_components.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVDParameters {
|
||||
n_components: self.svd_search_parameters.n_components[self.current_n_components],
|
||||
};
|
||||
|
||||
self.current_n_components += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SVDSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = SVDParameters::default();
|
||||
|
||||
SVDSearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
SVD::fit(x, parameters)
|
||||
@@ -153,6 +207,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = SVDSearchParameters {
        n_components: vec![10, 100],
        ..Default::default()
    };

    // One parameter set per candidate value, in order.
    let observed: Vec<usize> = parameters.into_iter().map(|p| p.n_components).collect();

    assert_eq!(observed, vec![10, 100]);
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svd_decompose() {
|
||||
|
||||
@@ -193,6 +193,226 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestCla
|
||||
}
|
||||
}
|
||||
|
||||
/// RandomForestClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierSearchParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<u16>,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestClassifier grid search iterator
|
||||
pub struct RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestClassifierSearchParameters {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
type IntoIter = RandomForestClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestClassifierSearchParametersIterator {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestClassifierParameters {
|
||||
criterion: self.random_forest_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.random_forest_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_classifier_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestClassifierParameters::default();
|
||||
|
||||
RandomForestClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -346,6 +566,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
|
||||
#[test]
fn search_parameters() {
    let parameters = RandomForestClassifierSearchParameters {
        n_trees: vec![10, 100],
        m: vec![None, Some(1)],
        ..Default::default()
    };

    // `n_trees` varies faster than `m`; all other axes have a single default value.
    let observed: Vec<(u16, Option<usize>)> = parameters
        .into_iter()
        .map(|p| (p.n_trees, p.m))
        .collect();

    assert_eq!(
        observed,
        vec![(10, None), (100, None), (10, Some(1)), (100, Some(1))]
    );
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
@@ -176,6 +176,191 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestReg
|
||||
}
|
||||
}
|
||||
|
||||
/// RandomForestRegressor grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestRegressorSearchParameters {
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<usize>,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestRegressor grid search iterator
|
||||
pub struct RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestRegressorSearchParameters {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
type IntoIter = RandomForestRegressorSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: self,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestRegressorSearchParametersIterator {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_max_depth
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
&& self.current_m == self.random_forest_regressor_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestRegressorParameters {
|
||||
max_depth: self.random_forest_regressor_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_regressor_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_regressor_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestRegressorSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestRegressorParameters::default();
|
||||
|
||||
RandomForestRegressorSearchParameters {
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -302,6 +487,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RandomForestRegressorSearchParameters {
|
||||
n_trees: vec![10, 100],
|
||||
m: vec![None, Some(1)],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, Some(1));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, Some(1));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
|
||||
+14
-17
@@ -129,7 +129,7 @@ pub struct LassoSearchParameters<T: RealNumber> {
|
||||
|
||||
/// Lasso grid search iterator
|
||||
pub struct LassoSearchParametersIterator<T: RealNumber> {
|
||||
lasso_regression_search_parameters: LassoSearchParameters<T>,
|
||||
lasso_search_parameters: LassoSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
@@ -142,7 +142,7 @@ impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LassoSearchParametersIterator {
|
||||
lasso_regression_search_parameters: self,
|
||||
lasso_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
@@ -155,34 +155,31 @@ impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
|
||||
type Item = LassoParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
||||
if self.current_alpha == self.lasso_search_parameters.alpha.len()
|
||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LassoParameters {
|
||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
||||
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.lasso_regression_search_parameters.normalize.len()
|
||||
{
|
||||
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
||||
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
|
||||
@@ -114,4 +114,4 @@ mod tests {
|
||||
|
||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,7 +281,6 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
use crate::metrics::{accuracy, mean_absolute_error};
|
||||
use crate::model_selection::kfold::KFold;
|
||||
use crate::neighbors::knn_regressor::KNNRegressor;
|
||||
|
||||
@@ -150,6 +150,88 @@ impl<T: RealNumber> Default for BernoulliNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// BernoulliNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BernoulliNBSearchParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
|
||||
pub binarize: Vec<Option<T>>,
|
||||
}
|
||||
|
||||
/// BernoulliNB grid search iterator
|
||||
pub struct BernoulliNBSearchParametersIterator<T: RealNumber> {
|
||||
bernoulli_nb_search_parameters: BernoulliNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_priors: usize,
|
||||
current_binarize: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for BernoulliNBSearchParameters<T> {
|
||||
type Item = BernoulliNBParameters<T>;
|
||||
type IntoIter = BernoulliNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
BernoulliNBSearchParametersIterator {
|
||||
bernoulli_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_priors: 0,
|
||||
current_binarize: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for BernoulliNBSearchParametersIterator<T> {
|
||||
type Item = BernoulliNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len()
|
||||
&& self.current_priors == self.bernoulli_nb_search_parameters.priors.len()
|
||||
&& self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = BernoulliNBParameters {
|
||||
alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha],
|
||||
priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors += 1;
|
||||
} else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors = 0;
|
||||
self.current_binarize += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_priors += 1;
|
||||
self.current_binarize += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for BernoulliNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = BernoulliNBParameters::default();
|
||||
|
||||
BernoulliNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
priors: vec![default_params.priors],
|
||||
binarize: vec![default_params.binarize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BernoulliNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -347,6 +429,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = BernoulliNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_bernoulli_naive_bayes() {
|
||||
|
||||
@@ -261,6 +261,60 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// CategoricalNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoricalNBSearchParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
}
|
||||
|
||||
/// CategoricalNB grid search iterator
|
||||
pub struct CategoricalNBSearchParametersIterator<T: RealNumber> {
|
||||
categorical_nb_search_parameters: CategoricalNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for CategoricalNBSearchParameters<T> {
|
||||
type Item = CategoricalNBParameters<T>;
|
||||
type IntoIter = CategoricalNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
CategoricalNBSearchParametersIterator {
|
||||
categorical_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for CategoricalNBSearchParametersIterator<T> {
|
||||
type Item = CategoricalNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = CategoricalNBParameters {
|
||||
alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha],
|
||||
};
|
||||
|
||||
self.current_alpha += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for CategoricalNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = CategoricalNBParameters::default();
|
||||
|
||||
CategoricalNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -351,6 +405,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = CategoricalNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_categorical_naive_bayes() {
|
||||
|
||||
@@ -76,7 +76,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
||||
|
||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GaussianNBParameters<T: RealNumber> {
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
@@ -90,6 +90,66 @@ impl<T: RealNumber> GaussianNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for GaussianNBParameters<T> {
|
||||
fn default() -> Self {
|
||||
Self { priors: None }
|
||||
}
|
||||
}
|
||||
|
||||
/// GaussianNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GaussianNBSearchParameters<T: RealNumber> {
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
}
|
||||
|
||||
/// GaussianNB grid search iterator
|
||||
pub struct GaussianNBSearchParametersIterator<T: RealNumber> {
|
||||
gaussian_nb_search_parameters: GaussianNBSearchParameters<T>,
|
||||
current_priors: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for GaussianNBSearchParameters<T> {
|
||||
type Item = GaussianNBParameters<T>;
|
||||
type IntoIter = GaussianNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
GaussianNBSearchParametersIterator {
|
||||
gaussian_nb_search_parameters: self,
|
||||
current_priors: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for GaussianNBSearchParametersIterator<T> {
|
||||
type Item = GaussianNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_priors == self.gaussian_nb_search_parameters.priors.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = GaussianNBParameters {
|
||||
priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
};
|
||||
|
||||
self.current_priors += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for GaussianNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = GaussianNBParameters::default();
|
||||
|
||||
GaussianNBSearchParameters {
|
||||
priors: vec![default_params.priors],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> GaussianNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -260,6 +320,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = GaussianNBSearchParameters {
|
||||
priors: vec![Some(vec![1.]), Some(vec![2.])],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.priors, Some(vec![1.]));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.priors, Some(vec![2.]));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_gaussian_naive_bayes() {
|
||||
|
||||
@@ -114,6 +114,76 @@ impl<T: RealNumber> Default for MultinomialNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// MultinomialNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultinomialNBSearchParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
}
|
||||
|
||||
/// MultinomialNB grid search iterator
|
||||
pub struct MultinomialNBSearchParametersIterator<T: RealNumber> {
|
||||
multinomial_nb_search_parameters: MultinomialNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_priors: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for MultinomialNBSearchParameters<T> {
|
||||
type Item = MultinomialNBParameters<T>;
|
||||
type IntoIter = MultinomialNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
MultinomialNBSearchParametersIterator {
|
||||
multinomial_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_priors: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for MultinomialNBSearchParametersIterator<T> {
|
||||
type Item = MultinomialNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.multinomial_nb_search_parameters.alpha.len()
|
||||
&& self.current_priors == self.multinomial_nb_search_parameters.priors.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = MultinomialNBParameters {
|
||||
alpha: self.multinomial_nb_search_parameters.alpha[self.current_alpha],
|
||||
priors: self.multinomial_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.multinomial_nb_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_priors + 1 < self.multinomial_nb_search_parameters.priors.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_priors += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for MultinomialNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = MultinomialNBParameters::default();
|
||||
|
||||
MultinomialNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
priors: vec![default_params.priors],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MultinomialNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -297,6 +367,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = MultinomialNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_multinomial_naive_bayes() {
|
||||
|
||||
+5
-5
@@ -33,7 +33,7 @@ use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Defines a kernel function
|
||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>> {
|
||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>>: Clone {
|
||||
/// Apply kernel function to x_i and x_j
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T;
|
||||
}
|
||||
@@ -95,12 +95,12 @@ impl Kernels {
|
||||
|
||||
/// Linear Kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LinearKernel {}
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RBFKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
@@ -108,7 +108,7 @@ pub struct RBFKernel<T: RealNumber> {
|
||||
|
||||
/// Polynomial kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PolynomialKernel<T: RealNumber> {
|
||||
/// degree of the polynomial
|
||||
pub degree: T,
|
||||
@@ -120,7 +120,7 @@ pub struct PolynomialKernel<T: RealNumber> {
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SigmoidKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
|
||||
+121
@@ -102,6 +102,109 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVC grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Number of epochs.
|
||||
pub epoch: Vec<usize>,
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
/// Tolerance for stopping epoch.
|
||||
pub tol: Vec<T>,
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<K>,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVC grid search iterator
|
||||
pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
svc_search_parameters: SVCSearchParameters<T, M, K>,
|
||||
current_epoch: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||
for SVCSearchParameters<T, M, K>
|
||||
{
|
||||
type Item = SVCParameters<T, M, K>;
|
||||
type IntoIter = SVCSearchParametersIterator<T, M, K>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVCSearchParametersIterator {
|
||||
svc_search_parameters: self,
|
||||
current_epoch: 0,
|
||||
current_c: 0,
|
||||
current_tol: 0,
|
||||
current_kernel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
for SVCSearchParametersIterator<T, M, K>
|
||||
{
|
||||
type Item = SVCParameters<T, M, K>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_epoch == self.svc_search_parameters.epoch.len()
|
||||
&& self.current_c == self.svc_search_parameters.c.len()
|
||||
&& self.current_tol == self.svc_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svc_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVCParameters::<T, M, K> {
|
||||
epoch: self.svc_search_parameters.epoch[self.current_epoch],
|
||||
c: self.svc_search_parameters.c[self.current_c],
|
||||
tol: self.svc_search_parameters.tol[self.current_tol],
|
||||
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
|
||||
m: PhantomData,
|
||||
};
|
||||
|
||||
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
|
||||
self.current_epoch += 1;
|
||||
} else if self.current_c + 1 < self.svc_search_parameters.c.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svc_search_parameters.tol.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else {
|
||||
self.current_epoch += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
let default_params: SVCParameters<T, M, LinearKernel> = SVCParameters::default();
|
||||
|
||||
SVCSearchParameters {
|
||||
epoch: vec![default_params.epoch],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel],
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
@@ -737,6 +840,24 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
SVCSearchParameters {
|
||||
epoch: vec![10, 100],
|
||||
kernel: vec![LinearKernel {}],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.epoch, 10);
|
||||
assert_eq!(next.kernel, LinearKernel {});
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.epoch, 100);
|
||||
assert_eq!(next.kernel, LinearKernel {});
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_predict() {
|
||||
|
||||
+121
@@ -94,6 +94,109 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVR grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVRSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: Vec<T>,
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
/// Tolerance for stopping eps.
|
||||
pub tol: Vec<T>,
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<K>,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVR grid search iterator
|
||||
pub struct SVRSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
svr_search_parameters: SVRSearchParameters<T, M, K>,
|
||||
current_eps: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||
for SVRSearchParameters<T, M, K>
|
||||
{
|
||||
type Item = SVRParameters<T, M, K>;
|
||||
type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVRSearchParametersIterator {
|
||||
svr_search_parameters: self,
|
||||
current_eps: 0,
|
||||
current_c: 0,
|
||||
current_tol: 0,
|
||||
current_kernel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
for SVRSearchParametersIterator<T, M, K>
|
||||
{
|
||||
type Item = SVRParameters<T, M, K>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
&& self.current_c == self.svr_search_parameters.c.len()
|
||||
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVRParameters::<T, M, K> {
|
||||
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
c: self.svr_search_parameters.c[self.current_c],
|
||||
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
||||
m: PhantomData,
|
||||
};
|
||||
|
||||
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
self.current_eps += 1;
|
||||
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else {
|
||||
self.current_eps += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
||||
|
||||
SVRSearchParameters {
|
||||
eps: vec![default_params.eps],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel],
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
@@ -536,6 +639,24 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
SVRSearchParameters {
|
||||
eps: vec![0., 1.],
|
||||
kernel: vec![LinearKernel {}],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.eps, 0.);
|
||||
assert_eq!(next.kernel, LinearKernel {});
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.eps, 1.);
|
||||
assert_eq!(next.kernel, LinearKernel {});
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svr_fit_predict() {
|
||||
|
||||
@@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecisionTreeClassifierSearchParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search iterator
|
||||
pub struct DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
type IntoIter = DecisionTreeClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DecisionTreeClassifierParameters {
|
||||
criterion: self.decision_tree_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.decision_tree_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = DecisionTreeClassifierParameters::default();
|
||||
|
||||
DecisionTreeClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: usize) -> Self {
|
||||
Node {
|
||||
@@ -651,6 +789,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
fn search_parameters() {
    // Two candidate values for two of the fields -> a 2x2 grid; the remaining
    // fields are filled from `Default` with one candidate each.
    let parameters = DecisionTreeClassifierSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };

    // Collect every produced combination and compare against the expected
    // odometer order: among the varied fields, `max_depth` cycles fastest.
    let produced: Vec<_> = parameters
        .into_iter()
        .map(|p| (p.max_depth, p.min_samples_split))
        .collect();

    assert_eq!(
        produced,
        vec![
            (Some(10), 1),
            (Some(100), 1),
            (Some(10), 2),
            (Some(100), 2),
        ]
    );
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
|
||||
@@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// DecisionTreeRegressor grid search parameters.
///
/// Each field holds the list of candidate values to try for the matching
/// `DecisionTreeRegressorParameters` field; the resulting search space is the
/// cartesian product of all the lists.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressorSearchParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Vec<Option<u16>>,
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: Vec<usize>,
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: Vec<usize>,
}
|
||||
|
||||
/// DecisionTreeRegressor grid search iterator.
///
/// Created via `DecisionTreeRegressorSearchParameters::into_iter`; walks the
/// cartesian product of the candidate value lists, odometer-style.
pub struct DecisionTreeRegressorSearchParametersIterator {
    // The grid definition being iterated over.
    decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters,
    // Cursor into `max_depth` (advances fastest).
    current_max_depth: usize,
    // Cursor into `min_samples_leaf`.
    current_min_samples_leaf: usize,
    // Cursor into `min_samples_split` (advances slowest).
    current_min_samples_split: usize,
}
|
||||
|
||||
impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
||||
type Item = DecisionTreeRegressorParameters;
|
||||
type IntoIter = DecisionTreeRegressorSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DecisionTreeRegressorSearchParametersIterator {
|
||||
decision_tree_regressor_search_parameters: self,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
||||
type Item = DecisionTreeRegressorParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_max_depth
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DecisionTreeRegressorParameters {
|
||||
max_depth: self.decision_tree_regressor_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
};
|
||||
|
||||
if self.current_max_depth + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else {
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeRegressorSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = DecisionTreeRegressorParameters::default();
|
||||
|
||||
DecisionTreeRegressorSearchParameters {
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: T) -> Self {
|
||||
Node {
|
||||
@@ -517,6 +631,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
fn search_parameters() {
    // Two candidate values for two of the fields -> a 2x2 grid; the remaining
    // field is filled from `Default` with one candidate.
    let parameters = DecisionTreeRegressorSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };

    // Collect every produced combination and compare against the expected
    // odometer order: `max_depth` cycles fastest.
    let produced: Vec<_> = parameters
        .into_iter()
        .map(|p| (p.max_depth, p.min_samples_split))
        .collect();

    assert_eq!(
        produced,
        vec![
            (Some(10), 1),
            (Some(100), 1),
            (Some(10), 2),
            (Some(100), 2),
        ]
    );
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
|
||||
Reference in New Issue
Block a user