Complete grid search params (#166)
* grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup
This commit is contained in:
@@ -109,6 +109,103 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// DBSCAN grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
|
/// a function that defines a distance between each pair of point in training data.
|
||||||
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
|
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||||
|
pub distance: Vec<D>,
|
||||||
|
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||||
|
pub min_samples: Vec<usize>,
|
||||||
|
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||||
|
pub eps: Vec<T>,
|
||||||
|
/// KNN algorithm to use.
|
||||||
|
pub algorithm: Vec<KNNAlgorithmName>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DBSCAN grid search iterator
|
||||||
|
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
|
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
|
||||||
|
current_distance: usize,
|
||||||
|
current_min_samples: usize,
|
||||||
|
current_eps: usize,
|
||||||
|
current_algorithm: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
|
||||||
|
type Item = DBSCANParameters<T, D>;
|
||||||
|
type IntoIter = DBSCANSearchParametersIterator<T, D>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
DBSCANSearchParametersIterator {
|
||||||
|
dbscan_search_parameters: self,
|
||||||
|
current_distance: 0,
|
||||||
|
current_min_samples: 0,
|
||||||
|
current_eps: 0,
|
||||||
|
current_algorithm: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
|
||||||
|
type Item = DBSCANParameters<T, D>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_distance == self.dbscan_search_parameters.distance.len()
|
||||||
|
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
|
||||||
|
&& self.current_eps == self.dbscan_search_parameters.eps.len()
|
||||||
|
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = DBSCANParameters {
|
||||||
|
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
|
||||||
|
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
|
||||||
|
eps: self.dbscan_search_parameters.eps[self.current_eps],
|
||||||
|
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
|
||||||
|
self.current_distance += 1;
|
||||||
|
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
|
||||||
|
self.current_distance = 0;
|
||||||
|
self.current_min_samples += 1;
|
||||||
|
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
|
||||||
|
self.current_distance = 0;
|
||||||
|
self.current_min_samples = 0;
|
||||||
|
self.current_eps += 1;
|
||||||
|
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
|
||||||
|
self.current_distance = 0;
|
||||||
|
self.current_min_samples = 0;
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_algorithm += 1;
|
||||||
|
} else {
|
||||||
|
self.current_distance += 1;
|
||||||
|
self.current_min_samples += 1;
|
||||||
|
self.current_eps += 1;
|
||||||
|
self.current_algorithm += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = DBSCANParameters::default();
|
||||||
|
|
||||||
|
DBSCANSearchParameters {
|
||||||
|
distance: vec![default_params.distance],
|
||||||
|
min_samples: vec![default_params.min_samples],
|
||||||
|
eps: vec![default_params.eps],
|
||||||
|
algorithm: vec![default_params.algorithm],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.cluster_labels.len() == other.cluster_labels.len()
|
self.cluster_labels.len() == other.cluster_labels.len()
|
||||||
@@ -268,6 +365,29 @@ mod tests {
|
|||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use crate::math::distance::euclidian::Euclidian;
|
use crate::math::distance::euclidian::Euclidian;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = DBSCANSearchParameters {
|
||||||
|
min_samples: vec![10, 100],
|
||||||
|
eps: vec![1., 2.],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.min_samples, 10);
|
||||||
|
assert_eq!(next.eps, 1.);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.min_samples, 100);
|
||||||
|
assert_eq!(next.eps, 1.);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.min_samples, 10);
|
||||||
|
assert_eq!(next.eps, 2.);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.min_samples, 100);
|
||||||
|
assert_eq!(next.eps, 2.);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_dbscan() {
|
fn fit_predict_dbscan() {
|
||||||
|
|||||||
@@ -132,6 +132,76 @@ impl Default for KMeansParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// KMeans grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct KMeansSearchParameters {
|
||||||
|
/// Number of clusters.
|
||||||
|
pub k: Vec<usize>,
|
||||||
|
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||||
|
pub max_iter: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// KMeans grid search iterator
|
||||||
|
pub struct KMeansSearchParametersIterator {
|
||||||
|
kmeans_search_parameters: KMeansSearchParameters,
|
||||||
|
current_k: usize,
|
||||||
|
current_max_iter: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for KMeansSearchParameters {
|
||||||
|
type Item = KMeansParameters;
|
||||||
|
type IntoIter = KMeansSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
KMeansSearchParametersIterator {
|
||||||
|
kmeans_search_parameters: self,
|
||||||
|
current_k: 0,
|
||||||
|
current_max_iter: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for KMeansSearchParametersIterator {
|
||||||
|
type Item = KMeansParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||||
|
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = KMeansParameters {
|
||||||
|
k: self.kmeans_search_parameters.k[self.current_k],
|
||||||
|
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||||
|
self.current_k += 1;
|
||||||
|
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||||
|
self.current_k = 0;
|
||||||
|
self.current_max_iter += 1;
|
||||||
|
} else {
|
||||||
|
self.current_k += 1;
|
||||||
|
self.current_max_iter += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for KMeansSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = KMeansParameters::default();
|
||||||
|
|
||||||
|
KMeansSearchParameters {
|
||||||
|
k: vec![default_params.k],
|
||||||
|
max_iter: vec![default_params.max_iter],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||||
KMeans::fit(x, parameters)
|
KMeans::fit(x, parameters)
|
||||||
@@ -313,6 +383,29 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = KMeansSearchParameters {
|
||||||
|
k: vec![2, 4],
|
||||||
|
max_iter: vec![10, 100],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.k, 2);
|
||||||
|
assert_eq!(next.max_iter, 10);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.k, 4);
|
||||||
|
assert_eq!(next.max_iter, 10);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.k, 2);
|
||||||
|
assert_eq!(next.max_iter, 100);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.k, 4);
|
||||||
|
assert_eq!(next.max_iter, 100);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|||||||
@@ -116,6 +116,81 @@ impl Default for PCAParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// PCA grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PCASearchParameters {
|
||||||
|
/// Number of components to keep.
|
||||||
|
pub n_components: Vec<usize>,
|
||||||
|
/// By default, covariance matrix is used to compute principal components.
|
||||||
|
/// Enable this flag if you want to use correlation matrix instead.
|
||||||
|
pub use_correlation_matrix: Vec<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PCA grid search iterator
|
||||||
|
pub struct PCASearchParametersIterator {
|
||||||
|
pca_search_parameters: PCASearchParameters,
|
||||||
|
current_k: usize,
|
||||||
|
current_use_correlation_matrix: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for PCASearchParameters {
|
||||||
|
type Item = PCAParameters;
|
||||||
|
type IntoIter = PCASearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
PCASearchParametersIterator {
|
||||||
|
pca_search_parameters: self,
|
||||||
|
current_k: 0,
|
||||||
|
current_use_correlation_matrix: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for PCASearchParametersIterator {
|
||||||
|
type Item = PCAParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_k == self.pca_search_parameters.n_components.len()
|
||||||
|
&& self.current_use_correlation_matrix
|
||||||
|
== self.pca_search_parameters.use_correlation_matrix.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = PCAParameters {
|
||||||
|
n_components: self.pca_search_parameters.n_components[self.current_k],
|
||||||
|
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
|
||||||
|
[self.current_use_correlation_matrix],
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
|
||||||
|
self.current_k += 1;
|
||||||
|
} else if self.current_use_correlation_matrix + 1
|
||||||
|
< self.pca_search_parameters.use_correlation_matrix.len()
|
||||||
|
{
|
||||||
|
self.current_k = 0;
|
||||||
|
self.current_use_correlation_matrix += 1;
|
||||||
|
} else {
|
||||||
|
self.current_k += 1;
|
||||||
|
self.current_use_correlation_matrix += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PCASearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = PCAParameters::default();
|
||||||
|
|
||||||
|
PCASearchParameters {
|
||||||
|
n_components: vec![default_params.n_components],
|
||||||
|
use_correlation_matrix: vec![default_params.use_correlation_matrix],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||||
PCA::fit(x, parameters)
|
PCA::fit(x, parameters)
|
||||||
@@ -271,6 +346,29 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = PCASearchParameters {
|
||||||
|
n_components: vec![2, 4],
|
||||||
|
use_correlation_matrix: vec![true, false],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 2);
|
||||||
|
assert_eq!(next.use_correlation_matrix, true);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 4);
|
||||||
|
assert_eq!(next.use_correlation_matrix, true);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 2);
|
||||||
|
assert_eq!(next.use_correlation_matrix, false);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 4);
|
||||||
|
assert_eq!(next.use_correlation_matrix, false);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
fn us_arrests_data() -> DenseMatrix<f64> {
|
fn us_arrests_data() -> DenseMatrix<f64> {
|
||||||
DenseMatrix::from_2d_array(&[
|
DenseMatrix::from_2d_array(&[
|
||||||
&[13.2, 236.0, 58.0, 21.2],
|
&[13.2, 236.0, 58.0, 21.2],
|
||||||
|
|||||||
@@ -90,6 +90,60 @@ impl SVDParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// SVD grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SVDSearchParameters {
|
||||||
|
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||||
|
pub n_components: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SVD grid search iterator
|
||||||
|
pub struct SVDSearchParametersIterator {
|
||||||
|
svd_search_parameters: SVDSearchParameters,
|
||||||
|
current_n_components: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for SVDSearchParameters {
|
||||||
|
type Item = SVDParameters;
|
||||||
|
type IntoIter = SVDSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
SVDSearchParametersIterator {
|
||||||
|
svd_search_parameters: self,
|
||||||
|
current_n_components: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for SVDSearchParametersIterator {
|
||||||
|
type Item = SVDParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_n_components == self.svd_search_parameters.n_components.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = SVDParameters {
|
||||||
|
n_components: self.svd_search_parameters.n_components[self.current_n_components],
|
||||||
|
};
|
||||||
|
|
||||||
|
self.current_n_components += 1;
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for SVDSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = SVDParameters::default();
|
||||||
|
|
||||||
|
SVDSearchParameters {
|
||||||
|
n_components: vec![default_params.n_components],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||||
SVD::fit(x, parameters)
|
SVD::fit(x, parameters)
|
||||||
@@ -153,6 +207,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = SVDSearchParameters {
|
||||||
|
n_components: vec![10, 100],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 10);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_components, 100);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svd_decompose() {
|
fn svd_decompose() {
|
||||||
|
|||||||
@@ -193,6 +193,226 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestCla
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// RandomForestClassifier grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RandomForestClassifierSearchParameters {
|
||||||
|
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub criterion: Vec<SplitCriterion>,
|
||||||
|
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub max_depth: Vec<Option<u16>>,
|
||||||
|
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_leaf: Vec<usize>,
|
||||||
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_split: Vec<usize>,
|
||||||
|
/// The number of trees in the forest.
|
||||||
|
pub n_trees: Vec<u16>,
|
||||||
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
|
pub m: Vec<Option<usize>>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: Vec<bool>,
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub seed: Vec<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RandomForestClassifier grid search iterator
|
||||||
|
pub struct RandomForestClassifierSearchParametersIterator {
|
||||||
|
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
|
||||||
|
current_criterion: usize,
|
||||||
|
current_max_depth: usize,
|
||||||
|
current_min_samples_leaf: usize,
|
||||||
|
current_min_samples_split: usize,
|
||||||
|
current_n_trees: usize,
|
||||||
|
current_m: usize,
|
||||||
|
current_keep_samples: usize,
|
||||||
|
current_seed: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for RandomForestClassifierSearchParameters {
|
||||||
|
type Item = RandomForestClassifierParameters;
|
||||||
|
type IntoIter = RandomForestClassifierSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
RandomForestClassifierSearchParametersIterator {
|
||||||
|
random_forest_classifier_search_parameters: self,
|
||||||
|
current_criterion: 0,
|
||||||
|
current_max_depth: 0,
|
||||||
|
current_min_samples_leaf: 0,
|
||||||
|
current_min_samples_split: 0,
|
||||||
|
current_n_trees: 0,
|
||||||
|
current_m: 0,
|
||||||
|
current_keep_samples: 0,
|
||||||
|
current_seed: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for RandomForestClassifierSearchParametersIterator {
|
||||||
|
type Item = RandomForestClassifierParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_criterion
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.criterion
|
||||||
|
.len()
|
||||||
|
&& self.current_max_depth
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_leaf
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_split
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
&& self.current_n_trees
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.n_trees
|
||||||
|
.len()
|
||||||
|
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
|
||||||
|
&& self.current_keep_samples
|
||||||
|
== self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.keep_samples
|
||||||
|
.len()
|
||||||
|
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = RandomForestClassifierParameters {
|
||||||
|
criterion: self.random_forest_classifier_search_parameters.criterion
|
||||||
|
[self.current_criterion]
|
||||||
|
.clone(),
|
||||||
|
max_depth: self.random_forest_classifier_search_parameters.max_depth
|
||||||
|
[self.current_max_depth],
|
||||||
|
min_samples_leaf: self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_leaf[self.current_min_samples_leaf],
|
||||||
|
min_samples_split: self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_split[self.current_min_samples_split],
|
||||||
|
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
|
||||||
|
m: self.random_forest_classifier_search_parameters.m[self.current_m],
|
||||||
|
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
|
||||||
|
[self.current_keep_samples],
|
||||||
|
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_criterion + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.criterion
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion += 1;
|
||||||
|
} else if self.current_max_depth + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
} else if self.current_min_samples_leaf + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
} else if self.current_min_samples_split + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
} else if self.current_n_trees + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.n_trees
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_n_trees += 1;
|
||||||
|
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_n_trees = 0;
|
||||||
|
self.current_m += 1;
|
||||||
|
} else if self.current_keep_samples + 1
|
||||||
|
< self
|
||||||
|
.random_forest_classifier_search_parameters
|
||||||
|
.keep_samples
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_n_trees = 0;
|
||||||
|
self.current_m = 0;
|
||||||
|
self.current_keep_samples += 1;
|
||||||
|
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split = 0;
|
||||||
|
self.current_n_trees = 0;
|
||||||
|
self.current_m = 0;
|
||||||
|
self.current_keep_samples = 0;
|
||||||
|
self.current_seed += 1;
|
||||||
|
} else {
|
||||||
|
self.current_criterion += 1;
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
self.current_n_trees += 1;
|
||||||
|
self.current_m += 1;
|
||||||
|
self.current_keep_samples += 1;
|
||||||
|
self.current_seed += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for RandomForestClassifierSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = RandomForestClassifierParameters::default();
|
||||||
|
|
||||||
|
RandomForestClassifierSearchParameters {
|
||||||
|
criterion: vec![default_params.criterion],
|
||||||
|
max_depth: vec![default_params.max_depth],
|
||||||
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
n_trees: vec![default_params.n_trees],
|
||||||
|
m: vec![default_params.m],
|
||||||
|
keep_samples: vec![default_params.keep_samples],
|
||||||
|
seed: vec![default_params.seed],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||||
/// Build a forest of trees from the training set.
|
/// Build a forest of trees from the training set.
|
||||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
@@ -346,6 +566,29 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = RandomForestClassifierSearchParameters {
|
||||||
|
n_trees: vec![10, 100],
|
||||||
|
m: vec![None, Some(1)],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_trees, 10);
|
||||||
|
assert_eq!(next.m, None);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_trees, 100);
|
||||||
|
assert_eq!(next.m, None);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_trees, 10);
|
||||||
|
assert_eq!(next.m, Some(1));
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.n_trees, 100);
|
||||||
|
assert_eq!(next.m, Some(1));
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
|
|||||||
@@ -176,6 +176,191 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestReg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// RandomForestRegressor grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RandomForestRegressorSearchParameters {
|
||||||
|
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub max_depth: Vec<Option<u16>>,
|
||||||
|
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_leaf: Vec<usize>,
|
||||||
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_split: Vec<usize>,
|
||||||
|
/// The number of trees in the forest.
|
||||||
|
pub n_trees: Vec<usize>,
|
||||||
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
|
pub m: Vec<Option<usize>>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: Vec<bool>,
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub seed: Vec<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// RandomForestRegressor grid search iterator
|
||||||
|
pub struct RandomForestRegressorSearchParametersIterator {
|
||||||
|
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
|
||||||
|
current_max_depth: usize,
|
||||||
|
current_min_samples_leaf: usize,
|
||||||
|
current_min_samples_split: usize,
|
||||||
|
current_n_trees: usize,
|
||||||
|
current_m: usize,
|
||||||
|
current_keep_samples: usize,
|
||||||
|
current_seed: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for RandomForestRegressorSearchParameters {
|
||||||
|
type Item = RandomForestRegressorParameters;
|
||||||
|
type IntoIter = RandomForestRegressorSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
RandomForestRegressorSearchParametersIterator {
|
||||||
|
random_forest_regressor_search_parameters: self,
|
||||||
|
current_max_depth: 0,
|
||||||
|
current_min_samples_leaf: 0,
|
||||||
|
current_min_samples_split: 0,
|
||||||
|
current_n_trees: 0,
|
||||||
|
current_m: 0,
|
||||||
|
current_keep_samples: 0,
|
||||||
|
current_seed: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for RandomForestRegressorSearchParametersIterator {
    type Item = RandomForestRegressorParameters;

    // Walks the cartesian product of the candidate lists, odometer-style:
    // `max_depth` varies fastest, `seed` slowest.
    fn next(&mut self) -> Option<Self::Item> {
        // Exhausted: the terminal `else` branch below pushed every cursor to
        // its list's length, so all comparisons hold and we stop.
        if self.current_max_depth
            == self
                .random_forest_regressor_search_parameters
                .max_depth
                .len()
            && self.current_min_samples_leaf
                == self
                    .random_forest_regressor_search_parameters
                    .min_samples_leaf
                    .len()
            && self.current_min_samples_split
                == self
                    .random_forest_regressor_search_parameters
                    .min_samples_split
                    .len()
            && self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
            && self.current_m == self.random_forest_regressor_search_parameters.m.len()
            && self.current_keep_samples
                == self
                    .random_forest_regressor_search_parameters
                    .keep_samples
                    .len()
            && self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
        {
            return None;
        }

        // Materialize the parameter set at the current cursor positions.
        // NOTE(review): this indexing panics if any candidate vector is empty —
        // TODO confirm callers always supply non-empty vectors (Default does).
        let next = RandomForestRegressorParameters {
            max_depth: self.random_forest_regressor_search_parameters.max_depth
                [self.current_max_depth],
            min_samples_leaf: self
                .random_forest_regressor_search_parameters
                .min_samples_leaf[self.current_min_samples_leaf],
            min_samples_split: self
                .random_forest_regressor_search_parameters
                .min_samples_split[self.current_min_samples_split],
            n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
            m: self.random_forest_regressor_search_parameters.m[self.current_m],
            keep_samples: self.random_forest_regressor_search_parameters.keep_samples
                [self.current_keep_samples],
            seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
        };

        // Advance the odometer: bump the fastest cursor that still has room,
        // resetting every faster cursor back to 0.
        if self.current_max_depth + 1
            < self
                .random_forest_regressor_search_parameters
                .max_depth
                .len()
        {
            self.current_max_depth += 1;
        } else if self.current_min_samples_leaf + 1
            < self
                .random_forest_regressor_search_parameters
                .min_samples_leaf
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf += 1;
        } else if self.current_min_samples_split + 1
            < self
                .random_forest_regressor_search_parameters
                .min_samples_split
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split += 1;
        } else if self.current_n_trees + 1
            < self.random_forest_regressor_search_parameters.n_trees.len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees += 1;
        } else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m += 1;
        } else if self.current_keep_samples + 1
            < self
                .random_forest_regressor_search_parameters
                .keep_samples
                .len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m = 0;
            self.current_keep_samples += 1;
        } else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
        {
            self.current_max_depth = 0;
            self.current_min_samples_leaf = 0;
            self.current_min_samples_split = 0;
            self.current_n_trees = 0;
            self.current_m = 0;
            self.current_keep_samples = 0;
            self.current_seed += 1;
        } else {
            // Last combination: push every cursor past the end so the
            // exhaustion check above fires on the following call.
            self.current_max_depth += 1;
            self.current_min_samples_leaf += 1;
            self.current_min_samples_split += 1;
            self.current_n_trees += 1;
            self.current_m += 1;
            self.current_keep_samples += 1;
            self.current_seed += 1;
        }

        Some(next)
    }
}
|
||||||
|
|
||||||
|
impl Default for RandomForestRegressorSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = RandomForestRegressorParameters::default();
|
||||||
|
|
||||||
|
RandomForestRegressorSearchParameters {
|
||||||
|
max_depth: vec![default_params.max_depth],
|
||||||
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
n_trees: vec![default_params.n_trees],
|
||||||
|
m: vec![default_params.m],
|
||||||
|
keep_samples: vec![default_params.keep_samples],
|
||||||
|
seed: vec![default_params.seed],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||||
/// Build a forest of trees from the training set.
|
/// Build a forest of trees from the training set.
|
||||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
@@ -302,6 +487,29 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
|
#[test]
fn search_parameters() {
    let grid = RandomForestRegressorSearchParameters {
        n_trees: vec![10, 100],
        m: vec![None, Some(1)],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();
    // `n_trees` is the faster-varying dimension, so the grid is walked in the
    // order (10, None), (100, None), (10, Some(1)), (100, Some(1)).
    for &expected_m in &[None, Some(1)] {
        for &expected_n_trees in &[10, 100] {
            let next = candidates.next().unwrap();
            assert_eq!(next.n_trees, expected_n_trees);
            assert_eq!(next.m, expected_m);
        }
    }
    assert!(candidates.next().is_none());
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
|
|||||||
+14
-17
@@ -129,7 +129,7 @@ pub struct LassoSearchParameters<T: RealNumber> {
|
|||||||
|
|
||||||
/// Lasso grid search iterator
|
/// Lasso grid search iterator
|
||||||
pub struct LassoSearchParametersIterator<T: RealNumber> {
|
pub struct LassoSearchParametersIterator<T: RealNumber> {
|
||||||
lasso_regression_search_parameters: LassoSearchParameters<T>,
|
lasso_search_parameters: LassoSearchParameters<T>,
|
||||||
current_alpha: usize,
|
current_alpha: usize,
|
||||||
current_normalize: usize,
|
current_normalize: usize,
|
||||||
current_tol: usize,
|
current_tol: usize,
|
||||||
@@ -142,7 +142,7 @@ impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
|
|||||||
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
LassoSearchParametersIterator {
|
LassoSearchParametersIterator {
|
||||||
lasso_regression_search_parameters: self,
|
lasso_search_parameters: self,
|
||||||
current_alpha: 0,
|
current_alpha: 0,
|
||||||
current_normalize: 0,
|
current_normalize: 0,
|
||||||
current_tol: 0,
|
current_tol: 0,
|
||||||
@@ -155,34 +155,31 @@ impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
|
|||||||
type Item = LassoParameters<T>;
|
type Item = LassoParameters<T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
if self.current_alpha == self.lasso_search_parameters.alpha.len()
|
||||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let next = LassoParameters {
|
let next = LassoParameters {
|
||||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
|
||||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||||
self.current_alpha += 1;
|
self.current_alpha += 1;
|
||||||
} else if self.current_normalize + 1
|
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
|
||||||
< self.lasso_regression_search_parameters.normalize.len()
|
|
||||||
{
|
|
||||||
self.current_alpha = 0;
|
self.current_alpha = 0;
|
||||||
self.current_normalize += 1;
|
self.current_normalize += 1;
|
||||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
|
||||||
self.current_alpha = 0;
|
self.current_alpha = 0;
|
||||||
self.current_normalize = 0;
|
self.current_normalize = 0;
|
||||||
self.current_tol += 1;
|
self.current_tol += 1;
|
||||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
|
||||||
{
|
|
||||||
self.current_alpha = 0;
|
self.current_alpha = 0;
|
||||||
self.current_normalize = 0;
|
self.current_normalize = 0;
|
||||||
self.current_tol = 0;
|
self.current_tol = 0;
|
||||||
|
|||||||
@@ -114,4 +114,4 @@ mod tests {
|
|||||||
|
|
||||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -281,7 +281,6 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
use crate::metrics::{accuracy, mean_absolute_error};
|
use crate::metrics::{accuracy, mean_absolute_error};
|
||||||
use crate::model_selection::kfold::KFold;
|
use crate::model_selection::kfold::KFold;
|
||||||
use crate::neighbors::knn_regressor::KNNRegressor;
|
use crate::neighbors::knn_regressor::KNNRegressor;
|
||||||
|
|||||||
@@ -150,6 +150,88 @@ impl<T: RealNumber> Default for BernoulliNBParameters<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// BernoulliNB grid search parameters. Each field lists the candidate
/// values to try for the corresponding `BernoulliNBParameters` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BernoulliNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
    /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
    pub binarize: Vec<Option<T>>,
}
|
||||||
|
|
||||||
|
/// BernoulliNB grid search iterator
pub struct BernoulliNBSearchParametersIterator<T: RealNumber> {
    // the parameter grid being iterated over
    bernoulli_nb_search_parameters: BernoulliNBSearchParameters<T>,
    // cursor into `alpha` (fastest-varying dimension — see the Iterator impl)
    current_alpha: usize,
    // cursor into `priors`
    current_priors: usize,
    // cursor into `binarize` (slowest-varying dimension)
    current_binarize: usize,
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> IntoIterator for BernoulliNBSearchParameters<T> {
|
||||||
|
type Item = BernoulliNBParameters<T>;
|
||||||
|
type IntoIter = BernoulliNBSearchParametersIterator<T>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
BernoulliNBSearchParametersIterator {
|
||||||
|
bernoulli_nb_search_parameters: self,
|
||||||
|
current_alpha: 0,
|
||||||
|
current_priors: 0,
|
||||||
|
current_binarize: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Iterator for BernoulliNBSearchParametersIterator<T> {
    type Item = BernoulliNBParameters<T>;

    // Walks the cartesian product of the candidate lists, odometer-style:
    // `alpha` varies fastest, then `priors`, then `binarize`.
    fn next(&mut self) -> Option<Self::Item> {
        // Exhausted: the terminal `else` branch below pushed every cursor
        // to its list's length.
        if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len()
            && self.current_priors == self.bernoulli_nb_search_parameters.priors.len()
            && self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len()
        {
            return None;
        }

        // NOTE(review): this indexing panics if any candidate vector is empty —
        // TODO confirm callers always supply non-empty vectors (Default does).
        let next = BernoulliNBParameters {
            alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha],
            priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(),
            binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize],
        };

        // Advance: bump the fastest cursor with room left, resetting faster ones.
        if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() {
            self.current_alpha = 0;
            self.current_priors += 1;
        } else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() {
            self.current_alpha = 0;
            self.current_priors = 0;
            self.current_binarize += 1;
        } else {
            // Last combination: push every cursor past the end so the next
            // call hits the exhaustion check.
            self.current_alpha += 1;
            self.current_priors += 1;
            self.current_binarize += 1;
        }

        Some(next)
    }
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for BernoulliNBSearchParameters<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = BernoulliNBParameters::default();
|
||||||
|
|
||||||
|
BernoulliNBSearchParameters {
|
||||||
|
alpha: vec![default_params.alpha],
|
||||||
|
priors: vec![default_params.priors],
|
||||||
|
binarize: vec![default_params.binarize],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> BernoulliNBDistribution<T> {
|
impl<T: RealNumber> BernoulliNBDistribution<T> {
|
||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
@@ -347,6 +429,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
fn search_parameters() {
    let grid = BernoulliNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();
    // candidate alphas come back in declaration order, then the grid is spent
    assert_eq!(candidates.next().unwrap().alpha, 1.);
    assert_eq!(candidates.next().unwrap().alpha, 2.);
    assert!(candidates.next().is_none());
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_bernoulli_naive_bayes() {
|
fn run_bernoulli_naive_bayes() {
|
||||||
|
|||||||
@@ -261,6 +261,60 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// CategoricalNB grid search parameters. Each field lists the candidate
/// values to try for the corresponding `CategoricalNBParameters` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct CategoricalNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
}
|
||||||
|
|
||||||
|
/// CategoricalNB grid search iterator
pub struct CategoricalNBSearchParametersIterator<T: RealNumber> {
    // the parameter grid being iterated over
    categorical_nb_search_parameters: CategoricalNBSearchParameters<T>,
    // cursor into `alpha`
    current_alpha: usize,
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> IntoIterator for CategoricalNBSearchParameters<T> {
|
||||||
|
type Item = CategoricalNBParameters<T>;
|
||||||
|
type IntoIter = CategoricalNBSearchParametersIterator<T>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
CategoricalNBSearchParametersIterator {
|
||||||
|
categorical_nb_search_parameters: self,
|
||||||
|
current_alpha: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Iterator for CategoricalNBSearchParametersIterator<T> {
|
||||||
|
type Item = CategoricalNBParameters<T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = CategoricalNBParameters {
|
||||||
|
alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha],
|
||||||
|
};
|
||||||
|
|
||||||
|
self.current_alpha += 1;
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for CategoricalNBSearchParameters<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = CategoricalNBParameters::default();
|
||||||
|
|
||||||
|
CategoricalNBSearchParameters {
|
||||||
|
alpha: vec![default_params.alpha],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
@@ -351,6 +405,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
fn search_parameters() {
    let grid = CategoricalNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();
    // candidate alphas come back in declaration order, then the grid is spent
    assert_eq!(candidates.next().unwrap().alpha, 1.);
    assert_eq!(candidates.next().unwrap().alpha, 2.);
    assert!(candidates.next().is_none());
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_categorical_naive_bayes() {
|
fn run_categorical_naive_bayes() {
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
|||||||
|
|
||||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct GaussianNBParameters<T: RealNumber> {
|
pub struct GaussianNBParameters<T: RealNumber> {
|
||||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||||
pub priors: Option<Vec<T>>,
|
pub priors: Option<Vec<T>>,
|
||||||
@@ -90,6 +90,66 @@ impl<T: RealNumber> GaussianNBParameters<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for GaussianNBParameters<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self { priors: None }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GaussianNB grid search parameters. Each field lists the candidate
/// values to try for the corresponding `GaussianNBParameters` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct GaussianNBSearchParameters<T: RealNumber> {
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
}
|
||||||
|
|
||||||
|
/// GaussianNB grid search iterator
pub struct GaussianNBSearchParametersIterator<T: RealNumber> {
    // the parameter grid being iterated over
    gaussian_nb_search_parameters: GaussianNBSearchParameters<T>,
    // cursor into `priors`
    current_priors: usize,
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> IntoIterator for GaussianNBSearchParameters<T> {
|
||||||
|
type Item = GaussianNBParameters<T>;
|
||||||
|
type IntoIter = GaussianNBSearchParametersIterator<T>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
GaussianNBSearchParametersIterator {
|
||||||
|
gaussian_nb_search_parameters: self,
|
||||||
|
current_priors: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Iterator for GaussianNBSearchParametersIterator<T> {
|
||||||
|
type Item = GaussianNBParameters<T>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_priors == self.gaussian_nb_search_parameters.priors.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = GaussianNBParameters {
|
||||||
|
priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.current_priors += 1;
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for GaussianNBSearchParameters<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = GaussianNBParameters::default();
|
||||||
|
|
||||||
|
GaussianNBSearchParameters {
|
||||||
|
priors: vec![default_params.priors],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> GaussianNBDistribution<T> {
|
impl<T: RealNumber> GaussianNBDistribution<T> {
|
||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
@@ -260,6 +320,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
fn search_parameters() {
    let grid = GaussianNBSearchParameters {
        priors: vec![Some(vec![1.]), Some(vec![2.])],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();
    // candidates come back in declaration order, then the grid is spent
    assert_eq!(candidates.next().unwrap().priors, Some(vec![1.]));
    assert_eq!(candidates.next().unwrap().priors, Some(vec![2.]));
    assert!(candidates.next().is_none());
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_gaussian_naive_bayes() {
|
fn run_gaussian_naive_bayes() {
|
||||||
|
|||||||
@@ -114,6 +114,76 @@ impl<T: RealNumber> Default for MultinomialNBParameters<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// MultinomialNB grid search parameters. Each field lists the candidate
/// values to try for the corresponding `MultinomialNBParameters` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct MultinomialNBSearchParameters<T: RealNumber> {
    /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
    pub alpha: Vec<T>,
    /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
    pub priors: Vec<Option<Vec<T>>>,
}
|
||||||
|
|
||||||
|
/// MultinomialNB grid search iterator
pub struct MultinomialNBSearchParametersIterator<T: RealNumber> {
    // the parameter grid being iterated over
    multinomial_nb_search_parameters: MultinomialNBSearchParameters<T>,
    // cursor into `alpha` (fastest-varying dimension — see the Iterator impl)
    current_alpha: usize,
    // cursor into `priors` (slowest-varying dimension)
    current_priors: usize,
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> IntoIterator for MultinomialNBSearchParameters<T> {
|
||||||
|
type Item = MultinomialNBParameters<T>;
|
||||||
|
type IntoIter = MultinomialNBSearchParametersIterator<T>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
MultinomialNBSearchParametersIterator {
|
||||||
|
multinomial_nb_search_parameters: self,
|
||||||
|
current_alpha: 0,
|
||||||
|
current_priors: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Iterator for MultinomialNBSearchParametersIterator<T> {
    type Item = MultinomialNBParameters<T>;

    // Walks the cartesian product of the candidate lists, odometer-style:
    // `alpha` varies fastest, then `priors`.
    fn next(&mut self) -> Option<Self::Item> {
        // Exhausted: the terminal `else` branch below pushed both cursors
        // to their list's length.
        if self.current_alpha == self.multinomial_nb_search_parameters.alpha.len()
            && self.current_priors == self.multinomial_nb_search_parameters.priors.len()
        {
            return None;
        }

        // NOTE(review): this indexing panics if any candidate vector is empty —
        // TODO confirm callers always supply non-empty vectors (Default does).
        let next = MultinomialNBParameters {
            alpha: self.multinomial_nb_search_parameters.alpha[self.current_alpha],
            priors: self.multinomial_nb_search_parameters.priors[self.current_priors].clone(),
        };

        // Advance: bump the fastest cursor with room left, resetting faster ones.
        if self.current_alpha + 1 < self.multinomial_nb_search_parameters.alpha.len() {
            self.current_alpha += 1;
        } else if self.current_priors + 1 < self.multinomial_nb_search_parameters.priors.len() {
            self.current_alpha = 0;
            self.current_priors += 1;
        } else {
            // Last combination: push both cursors past the end so the next
            // call hits the exhaustion check.
            self.current_alpha += 1;
            self.current_priors += 1;
        }

        Some(next)
    }
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for MultinomialNBSearchParameters<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = MultinomialNBParameters::default();
|
||||||
|
|
||||||
|
MultinomialNBSearchParameters {
|
||||||
|
alpha: vec![default_params.alpha],
|
||||||
|
priors: vec![default_params.priors],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> MultinomialNBDistribution<T> {
|
impl<T: RealNumber> MultinomialNBDistribution<T> {
|
||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
@@ -297,6 +367,20 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
fn search_parameters() {
    let grid = MultinomialNBSearchParameters {
        alpha: vec![1., 2.],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();
    // candidate alphas come back in declaration order, then the grid is spent
    assert_eq!(candidates.next().unwrap().alpha, 1.);
    assert_eq!(candidates.next().unwrap().alpha, 2.);
    assert!(candidates.next().is_none());
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_multinomial_naive_bayes() {
|
fn run_multinomial_naive_bayes() {
|
||||||
|
|||||||
+5
-5
@@ -33,7 +33,7 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Defines a kernel function
|
/// Defines a kernel function
|
||||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>> {
|
pub trait Kernel<T: RealNumber, V: BaseVector<T>>: Clone {
|
||||||
/// Apply kernel function to x_i and x_j
|
/// Apply kernel function to x_i and x_j
|
||||||
fn apply(&self, x_i: &V, x_j: &V) -> T;
|
fn apply(&self, x_i: &V, x_j: &V) -> T;
|
||||||
}
|
}
|
||||||
@@ -95,12 +95,12 @@ impl Kernels {
|
|||||||
|
|
||||||
/// Linear Kernel
|
/// Linear Kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct LinearKernel {}
|
pub struct LinearKernel {}
|
||||||
|
|
||||||
/// Radial basis function (Gaussian) kernel
|
/// Radial basis function (Gaussian) kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct RBFKernel<T: RealNumber> {
|
pub struct RBFKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
@@ -108,7 +108,7 @@ pub struct RBFKernel<T: RealNumber> {
|
|||||||
|
|
||||||
/// Polynomial kernel
|
/// Polynomial kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct PolynomialKernel<T: RealNumber> {
|
pub struct PolynomialKernel<T: RealNumber> {
|
||||||
/// degree of the polynomial
|
/// degree of the polynomial
|
||||||
pub degree: T,
|
pub degree: T,
|
||||||
@@ -120,7 +120,7 @@ pub struct PolynomialKernel<T: RealNumber> {
|
|||||||
|
|
||||||
/// Sigmoid (hyperbolic tangent) kernel
|
/// Sigmoid (hyperbolic tangent) kernel
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub struct SigmoidKernel<T: RealNumber> {
|
pub struct SigmoidKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
|
|||||||
+121
@@ -102,6 +102,109 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// SVC grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
|
/// Number of epochs.
|
||||||
|
pub epoch: Vec<usize>,
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub c: Vec<T>,
|
||||||
|
/// Tolerance for stopping epoch.
|
||||||
|
pub tol: Vec<T>,
|
||||||
|
/// The kernel function.
|
||||||
|
pub kernel: Vec<K>,
|
||||||
|
/// Unused parameter.
|
||||||
|
m: PhantomData<M>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SVC grid search iterator
|
||||||
|
pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
|
svc_search_parameters: SVCSearchParameters<T, M, K>,
|
||||||
|
current_epoch: usize,
|
||||||
|
current_c: usize,
|
||||||
|
current_tol: usize,
|
||||||
|
current_kernel: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||||
|
for SVCSearchParameters<T, M, K>
|
||||||
|
{
|
||||||
|
type Item = SVCParameters<T, M, K>;
|
||||||
|
type IntoIter = SVCSearchParametersIterator<T, M, K>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
SVCSearchParametersIterator {
|
||||||
|
svc_search_parameters: self,
|
||||||
|
current_epoch: 0,
|
||||||
|
current_c: 0,
|
||||||
|
current_tol: 0,
|
||||||
|
current_kernel: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||||
|
for SVCSearchParametersIterator<T, M, K>
|
||||||
|
{
|
||||||
|
type Item = SVCParameters<T, M, K>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_epoch == self.svc_search_parameters.epoch.len()
|
||||||
|
&& self.current_c == self.svc_search_parameters.c.len()
|
||||||
|
&& self.current_tol == self.svc_search_parameters.tol.len()
|
||||||
|
&& self.current_kernel == self.svc_search_parameters.kernel.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = SVCParameters::<T, M, K> {
|
||||||
|
epoch: self.svc_search_parameters.epoch[self.current_epoch],
|
||||||
|
c: self.svc_search_parameters.c[self.current_c],
|
||||||
|
tol: self.svc_search_parameters.tol[self.current_tol],
|
||||||
|
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
|
||||||
|
m: PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
|
||||||
|
self.current_epoch += 1;
|
||||||
|
} else if self.current_c + 1 < self.svc_search_parameters.c.len() {
|
||||||
|
self.current_epoch = 0;
|
||||||
|
self.current_c += 1;
|
||||||
|
} else if self.current_tol + 1 < self.svc_search_parameters.tol.len() {
|
||||||
|
self.current_epoch = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol += 1;
|
||||||
|
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
|
||||||
|
self.current_epoch = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
} else {
|
||||||
|
self.current_epoch += 1;
|
||||||
|
self.current_c += 1;
|
||||||
|
self.current_tol += 1;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKernel> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params: SVCParameters<T, M, LinearKernel> = SVCParameters::default();
|
||||||
|
|
||||||
|
SVCSearchParameters {
|
||||||
|
epoch: vec![default_params.epoch],
|
||||||
|
c: vec![default_params.c],
|
||||||
|
tol: vec![default_params.tol],
|
||||||
|
kernel: vec![default_params.kernel],
|
||||||
|
m: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -737,6 +840,24 @@ mod tests {
|
|||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||||
|
SVCSearchParameters {
|
||||||
|
epoch: vec![10, 100],
|
||||||
|
kernel: vec![LinearKernel {}],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.epoch, 10);
|
||||||
|
assert_eq!(next.kernel, LinearKernel {});
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.epoch, 100);
|
||||||
|
assert_eq!(next.kernel, LinearKernel {});
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svc_fit_predict() {
|
fn svc_fit_predict() {
|
||||||
|
|||||||
+121
@@ -94,6 +94,109 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// SVR grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SVRSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
|
/// Epsilon in the epsilon-SVR model.
|
||||||
|
pub eps: Vec<T>,
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub c: Vec<T>,
|
||||||
|
/// Tolerance for stopping eps.
|
||||||
|
pub tol: Vec<T>,
|
||||||
|
/// The kernel function.
|
||||||
|
pub kernel: Vec<K>,
|
||||||
|
/// Unused parameter.
|
||||||
|
m: PhantomData<M>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SVR grid search iterator
|
||||||
|
pub struct SVRSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
|
svr_search_parameters: SVRSearchParameters<T, M, K>,
|
||||||
|
current_eps: usize,
|
||||||
|
current_c: usize,
|
||||||
|
current_tol: usize,
|
||||||
|
current_kernel: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||||
|
for SVRSearchParameters<T, M, K>
|
||||||
|
{
|
||||||
|
type Item = SVRParameters<T, M, K>;
|
||||||
|
type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
SVRSearchParametersIterator {
|
||||||
|
svr_search_parameters: self,
|
||||||
|
current_eps: 0,
|
||||||
|
current_c: 0,
|
||||||
|
current_tol: 0,
|
||||||
|
current_kernel: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||||
|
for SVRSearchParametersIterator<T, M, K>
|
||||||
|
{
|
||||||
|
type Item = SVRParameters<T, M, K>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||||
|
&& self.current_c == self.svr_search_parameters.c.len()
|
||||||
|
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||||
|
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = SVRParameters::<T, M, K> {
|
||||||
|
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||||
|
c: self.svr_search_parameters.c[self.current_c],
|
||||||
|
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||||
|
kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
||||||
|
m: PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||||
|
self.current_eps += 1;
|
||||||
|
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c += 1;
|
||||||
|
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol += 1;
|
||||||
|
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||||
|
self.current_eps = 0;
|
||||||
|
self.current_c = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
} else {
|
||||||
|
self.current_eps += 1;
|
||||||
|
self.current_c += 1;
|
||||||
|
self.current_tol += 1;
|
||||||
|
self.current_kernel += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
||||||
|
|
||||||
|
SVRSearchParameters {
|
||||||
|
eps: vec![default_params.eps],
|
||||||
|
c: vec![default_params.c],
|
||||||
|
tol: vec![default_params.tol],
|
||||||
|
kernel: vec![default_params.kernel],
|
||||||
|
m: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -536,6 +639,24 @@ mod tests {
|
|||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||||
|
SVRSearchParameters {
|
||||||
|
eps: vec![0., 1.],
|
||||||
|
kernel: vec![LinearKernel {}],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.eps, 0.);
|
||||||
|
assert_eq!(next.kernel, LinearKernel {});
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.eps, 1.);
|
||||||
|
assert_eq!(next.kernel, LinearKernel {});
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svr_fit_predict() {
|
fn svr_fit_predict() {
|
||||||
|
|||||||
@@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// DecisionTreeClassifier grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DecisionTreeClassifierSearchParameters {
|
||||||
|
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub criterion: Vec<SplitCriterion>,
|
||||||
|
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub max_depth: Vec<Option<u16>>,
|
||||||
|
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_leaf: Vec<usize>,
|
||||||
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
|
pub min_samples_split: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DecisionTreeClassifier grid search iterator
|
||||||
|
pub struct DecisionTreeClassifierSearchParametersIterator {
|
||||||
|
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
|
||||||
|
current_criterion: usize,
|
||||||
|
current_max_depth: usize,
|
||||||
|
current_min_samples_leaf: usize,
|
||||||
|
current_min_samples_split: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
||||||
|
type Item = DecisionTreeClassifierParameters;
|
||||||
|
type IntoIter = DecisionTreeClassifierSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
DecisionTreeClassifierSearchParametersIterator {
|
||||||
|
decision_tree_classifier_search_parameters: self,
|
||||||
|
current_criterion: 0,
|
||||||
|
current_max_depth: 0,
|
||||||
|
current_min_samples_leaf: 0,
|
||||||
|
current_min_samples_split: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
||||||
|
type Item = DecisionTreeClassifierParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_criterion
|
||||||
|
== self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.criterion
|
||||||
|
.len()
|
||||||
|
&& self.current_max_depth
|
||||||
|
== self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_leaf
|
||||||
|
== self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_split
|
||||||
|
== self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = DecisionTreeClassifierParameters {
|
||||||
|
criterion: self.decision_tree_classifier_search_parameters.criterion
|
||||||
|
[self.current_criterion]
|
||||||
|
.clone(),
|
||||||
|
max_depth: self.decision_tree_classifier_search_parameters.max_depth
|
||||||
|
[self.current_max_depth],
|
||||||
|
min_samples_leaf: self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_leaf[self.current_min_samples_leaf],
|
||||||
|
min_samples_split: self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_split[self.current_min_samples_split],
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_criterion + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.criterion
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion += 1;
|
||||||
|
} else if self.current_max_depth + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
} else if self.current_min_samples_leaf + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
} else if self.current_min_samples_split + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_classifier_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_criterion = 0;
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
} else {
|
||||||
|
self.current_criterion += 1;
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DecisionTreeClassifierSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = DecisionTreeClassifierParameters::default();
|
||||||
|
|
||||||
|
DecisionTreeClassifierSearchParameters {
|
||||||
|
criterion: vec![default_params.criterion],
|
||||||
|
max_depth: vec![default_params.max_depth],
|
||||||
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> Node<T> {
|
impl<T: RealNumber> Node<T> {
|
||||||
fn new(index: usize, output: usize) -> Self {
|
fn new(index: usize, output: usize) -> Self {
|
||||||
Node {
|
Node {
|
||||||
@@ -651,6 +789,29 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = DecisionTreeClassifierSearchParameters {
|
||||||
|
max_depth: vec![Some(10), Some(100)],
|
||||||
|
min_samples_split: vec![1, 2],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(10));
|
||||||
|
assert_eq!(next.min_samples_split, 1);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(100));
|
||||||
|
assert_eq!(next.min_samples_split, 1);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(10));
|
||||||
|
assert_eq!(next.min_samples_split, 2);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(100));
|
||||||
|
assert_eq!(next.min_samples_split, 2);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn gini_impurity() {
|
fn gini_impurity() {
|
||||||
|
|||||||
@@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// DecisionTreeRegressor grid search parameters
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DecisionTreeRegressorSearchParameters {
|
||||||
|
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||||
|
pub max_depth: Vec<Option<u16>>,
|
||||||
|
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||||
|
pub min_samples_leaf: Vec<usize>,
|
||||||
|
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||||
|
pub min_samples_split: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DecisionTreeRegressor grid search iterator
|
||||||
|
pub struct DecisionTreeRegressorSearchParametersIterator {
|
||||||
|
decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters,
|
||||||
|
current_max_depth: usize,
|
||||||
|
current_min_samples_leaf: usize,
|
||||||
|
current_min_samples_split: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
||||||
|
type Item = DecisionTreeRegressorParameters;
|
||||||
|
type IntoIter = DecisionTreeRegressorSearchParametersIterator;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
DecisionTreeRegressorSearchParametersIterator {
|
||||||
|
decision_tree_regressor_search_parameters: self,
|
||||||
|
current_max_depth: 0,
|
||||||
|
current_min_samples_leaf: 0,
|
||||||
|
current_min_samples_split: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
||||||
|
type Item = DecisionTreeRegressorParameters;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.current_max_depth
|
||||||
|
== self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_leaf
|
||||||
|
== self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
&& self.current_min_samples_split
|
||||||
|
== self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next = DecisionTreeRegressorParameters {
|
||||||
|
max_depth: self.decision_tree_regressor_search_parameters.max_depth
|
||||||
|
[self.current_max_depth],
|
||||||
|
min_samples_leaf: self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_leaf[self.current_min_samples_leaf],
|
||||||
|
min_samples_split: self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_split[self.current_min_samples_split],
|
||||||
|
};
|
||||||
|
|
||||||
|
if self.current_max_depth + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.max_depth
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
} else if self.current_min_samples_leaf + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_leaf
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
} else if self.current_min_samples_split + 1
|
||||||
|
< self
|
||||||
|
.decision_tree_regressor_search_parameters
|
||||||
|
.min_samples_split
|
||||||
|
.len()
|
||||||
|
{
|
||||||
|
self.current_max_depth = 0;
|
||||||
|
self.current_min_samples_leaf = 0;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
} else {
|
||||||
|
self.current_max_depth += 1;
|
||||||
|
self.current_min_samples_leaf += 1;
|
||||||
|
self.current_min_samples_split += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DecisionTreeRegressorSearchParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
let default_params = DecisionTreeRegressorParameters::default();
|
||||||
|
|
||||||
|
DecisionTreeRegressorSearchParameters {
|
||||||
|
max_depth: vec![default_params.max_depth],
|
||||||
|
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||||
|
min_samples_split: vec![default_params.min_samples_split],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> Node<T> {
|
impl<T: RealNumber> Node<T> {
|
||||||
fn new(index: usize, output: T) -> Self {
|
fn new(index: usize, output: T) -> Self {
|
||||||
Node {
|
Node {
|
||||||
@@ -517,6 +631,29 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn search_parameters() {
|
||||||
|
let parameters = DecisionTreeRegressorSearchParameters {
|
||||||
|
max_depth: vec![Some(10), Some(100)],
|
||||||
|
min_samples_split: vec![1, 2],
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let mut iter = parameters.into_iter();
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(10));
|
||||||
|
assert_eq!(next.min_samples_split, 1);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(100));
|
||||||
|
assert_eq!(next.min_samples_split, 1);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(10));
|
||||||
|
assert_eq!(next.min_samples_split, 2);
|
||||||
|
let next = iter.next().unwrap();
|
||||||
|
assert_eq!(next.max_depth, Some(100));
|
||||||
|
assert_eq!(next.min_samples_split, 2);
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
|
|||||||
Reference in New Issue
Block a user