From 810a5c429b9df1aa383e1eaf607f7c4c1e0b7a3f Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Thu, 24 Dec 2020 18:36:23 -0800 Subject: [PATCH] feat: consolidates API --- src/api.rs | 43 +++++++++++++++ src/base.rs | 10 ---- src/cluster/dbscan.rs | 66 ++++++++++++++++-------- src/cluster/kmeans.rs | 65 +++++++++++++++-------- src/decomposition/pca.rs | 52 ++++++++++++------- src/decomposition/svd.rs | 40 +++++++++++--- src/ensemble/random_forest_classifier.rs | 15 +++++- src/ensemble/random_forest_regressor.rs | 15 +++++- src/lib.rs | 2 +- src/linear/elastic_net.rs | 10 +++- src/linear/lasso.rs | 10 +++- src/linear/linear_regression.rs | 14 ++++- src/linear/logistic_regression.rs | 14 ++++- src/linear/ridge_regression.rs | 14 ++++- src/model_selection/mod.rs | 2 +- src/naive_bayes/bernoulli.rs | 10 +++- src/naive_bayes/categorical.rs | 14 ++++- src/naive_bayes/gaussian.rs | 10 +++- src/naive_bayes/multinomial.rs | 14 ++++- src/neighbors/knn_classifier.rs | 14 ++++- src/neighbors/knn_regressor.rs | 14 ++++- src/svm/svc.rs | 10 +++- src/svm/svr.rs | 10 +++- src/tree/decision_tree_classifier.rs | 15 +++++- src/tree/decision_tree_regressor.rs | 15 +++++- 25 files changed, 400 insertions(+), 98 deletions(-) create mode 100644 src/api.rs delete mode 100644 src/base.rs diff --git a/src/api.rs b/src/api.rs new file mode 100644 index 0000000..c598e12 --- /dev/null +++ b/src/api.rs @@ -0,0 +1,43 @@ +//! # Common Interfaces and API +//! +//! This module provides interfaces and uniform API with simple conventions +//! that are used in other modules for supervised and unsupervised learning. + +use crate::error::Failed; + +/// An estimator for unsupervised learning, that provides method `fit` to learn from data +pub trait UnsupervisedEstimator { + /// Fit a model to a training dataset, estimate model's parameters. + /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. 
+ /// * `parameters` - hyperparameters of an algorithm + fn fit(x: &X, parameters: P) -> Result + where + Self: Sized, + P: Clone; +} + +/// An estimator for supervised learning, that provides method `fit` to learn from data and training values +pub trait SupervisedEstimator { + /// Fit a model to a training dataset, estimate model's parameters. + /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. + /// * `y` - target training values of size _N_. + /// * `parameters` - hyperparameters of an algorithm + fn fit(x: &X, y: &Y, parameters: P) -> Result + where + Self: Sized, + P: Clone; +} + +/// Implements method predict that estimates target value from new data +pub trait Predictor { + /// Estimate target values from new data. + /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. + fn predict(&self, x: &X) -> Result; +} + +/// Implements method transform that filters or modifies input data +pub trait Transformer { + /// Transform data by modifying or filtering it + /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. + fn transform(&self, x: &X) -> Result; +} diff --git a/src/base.rs b/src/base.rs deleted file mode 100644 index a2d4468..0000000 --- a/src/base.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! # Common Interfaces and methods -//! -//! This module consolidates interfaces and uniform basic API that is used elsewhere in the code. -//! -use crate::error::Failed; -//! -/// Implements method predict that offers a way to estimate target value from new data -pub trait Predictor { - fn predict(&self, x: &X) -> Result; -} diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index c572ccc..9aed2f0 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -15,8 +15,7 @@ //! let blobs = generator::make_blobs(100, 2, 3); //! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data); //! // Fit the algorithm and predict cluster labels -//!
let labels = DBSCAN::fit(&x, Distances::euclidian(), -//! DBSCANParameters::default().with_eps(3.0)). +//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)). //! and_then(|dbscan| dbscan.predict(&x)); //! //! println!("{:?}", labels); @@ -33,9 +32,11 @@ use std::iter::Sum; use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; +use crate::api::{Predictor, UnsupervisedEstimator}; use crate::error::Failed; use crate::linalg::{row_iter, Matrix}; -use crate::math::distance::Distance; +use crate::math::distance::euclidian::Euclidian; +use crate::math::distance::{Distance, Distances}; use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::which_max; @@ -50,7 +51,11 @@ pub struct DBSCAN, T>> { #[derive(Debug, Clone)] /// DBSCAN clustering algorithm parameters -pub struct DBSCANParameters { +pub struct DBSCANParameters, T>> { + /// a function that defines a distance between each pair of point in training data. + /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. + /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. + pub distance: D, /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. pub min_samples: usize, /// The maximum distance between two samples for one to be considered as in the neighborhood of the other. @@ -59,7 +64,18 @@ pub struct DBSCANParameters { pub algorithm: KNNAlgorithmName, } -impl DBSCANParameters { +impl, T>> DBSCANParameters { + /// a function that defines a distance between each pair of point in training data. + /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. + /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. 
+ pub fn with_distance, T>>(self, distance: DD) -> DBSCANParameters { + DBSCANParameters { + distance, + min_samples: self.min_samples, + eps: self.eps, + algorithm: self.algorithm, + } + } /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. pub fn with_min_samples(mut self, min_samples: usize) -> Self { self.min_samples = min_samples; @@ -86,9 +102,10 @@ impl, T>> PartialEq for DBSCAN { } } -impl Default for DBSCANParameters { +impl Default for DBSCANParameters { fn default() -> Self { DBSCANParameters { + distance: Distances::euclidian(), min_samples: 5, eps: T::half(), algorithm: KNNAlgorithmName::CoverTree, @@ -96,6 +113,22 @@ impl Default for DBSCANParameters { } } +impl, D: Distance, T>> + UnsupervisedEstimator> for DBSCAN +{ + fn fit(x: &M, parameters: DBSCANParameters) -> Result { + DBSCAN::fit(x, parameters) + } +} + +impl, D: Distance, T>> Predictor + for DBSCAN +{ + fn predict(&self, x: &M) -> Result { + self.predict(x) + } +} + impl, T>> DBSCAN { /// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features. 
/// * `data` - training instances to cluster @@ -103,8 +136,7 @@ impl, T>> DBSCAN { /// * `parameters` - cluster parameters pub fn fit>( x: &M, - distance: D, - parameters: DBSCANParameters, + parameters: DBSCANParameters, ) -> Result, Failed> { if parameters.min_samples < 1 { return Err(Failed::fit(&"Invalid minPts".to_string())); @@ -121,7 +153,9 @@ impl, T>> DBSCAN { let n = x.shape().0; let mut y = vec![unassigned; n]; - let algo = parameters.algorithm.fit(row_iter(x).collect(), distance)?; + let algo = parameters + .algorithm + .fit(row_iter(x).collect(), parameters.distance)?; for (i, e) in row_iter(x).enumerate() { if y[i] == unassigned { @@ -195,7 +229,6 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::math::distance::euclidian::Euclidian; - use crate::math::distance::Distances; #[test] fn fit_predict_dbscan() { @@ -215,16 +248,7 @@ mod tests { let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]; - let dbscan = DBSCAN::fit( - &x, - Distances::euclidian(), - DBSCANParameters { - min_samples: 5, - eps: 1.0, - algorithm: KNNAlgorithmName::CoverTree, - }, - ) - .unwrap(); + let dbscan = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(1.0)).unwrap(); let predicted_labels = dbscan.predict(&x).unwrap(); @@ -256,7 +280,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let dbscan = DBSCAN::fit(&x, Distances::euclidian(), Default::default()).unwrap(); + let dbscan = DBSCAN::fit(&x, Default::default()).unwrap(); let deserialized_dbscan: DBSCAN = serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap(); diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index bc5d673..44ce1e6 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -43,7 +43,7 @@ //! &[5.2, 2.7, 3.9, 1.4], //! ]); //! -//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters +//! 
let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters //! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction //! ``` //! @@ -59,6 +59,7 @@ use std::iter::Sum; use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::bbd_tree::BBDTree; +use crate::api::{Predictor, UnsupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::distance::euclidian::*; @@ -101,11 +102,18 @@ impl PartialEq for KMeans { #[derive(Debug, Clone)] /// K-Means clustering algorithm parameters pub struct KMeansParameters { + /// Number of clusters. + pub k: usize, /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: usize, } impl KMeansParameters { + /// Number of clusters. + pub fn with_k(mut self, k: usize) -> Self { + self.k = k; + self + } /// Maximum number of iterations of the k-means algorithm for a single run. pub fn with_max_iter(mut self, max_iter: usize) -> Self { self.max_iter = max_iter; @@ -115,24 +123,37 @@ impl KMeansParameters { impl Default for KMeansParameters { fn default() -> Self { - KMeansParameters { max_iter: 100 } + KMeansParameters { + k: 2, + max_iter: 100, + } + } +} + +impl> UnsupervisedEstimator for KMeans { + fn fit(x: &M, parameters: KMeansParameters) -> Result { + KMeans::fit(x, parameters) + } +} + +impl> Predictor for KMeans { + fn predict(&self, x: &M) -> Result { + self.predict(x) } } impl KMeans { /// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features. 
- /// * `data` - training instances to cluster - /// * `k` - number of clusters + /// * `data` - training instances to cluster /// * `parameters` - cluster parameters - pub fn fit>( - data: &M, - k: usize, - parameters: KMeansParameters, - ) -> Result, Failed> { + pub fn fit>(data: &M, parameters: KMeansParameters) -> Result, Failed> { let bbd = BBDTree::new(data); - if k < 2 { - return Err(Failed::fit(&format!("invalid number of clusters: {}", k))); + if parameters.k < 2 { + return Err(Failed::fit(&format!( + "invalid number of clusters: {}", + parameters.k + ))); } if parameters.max_iter == 0 { @@ -145,9 +166,9 @@ impl KMeans { let (n, d) = data.shape(); let mut distortion = T::max_value(); - let mut y = KMeans::kmeans_plus_plus(data, k); - let mut size = vec![0; k]; - let mut centroids = vec![vec![T::zero(); d]; k]; + let mut y = KMeans::kmeans_plus_plus(data, parameters.k); + let mut size = vec![0; parameters.k]; + let mut centroids = vec![vec![T::zero(); d]; parameters.k]; for i in 0..n { size[y[i]] += 1; @@ -159,16 +180,16 @@ impl KMeans { } } - for i in 0..k { + for i in 0..parameters.k { for j in 0..d { centroids[i][j] /= T::from(size[i]).unwrap(); } } - let mut sums = vec![vec![T::zero(); d]; k]; + let mut sums = vec![vec![T::zero(); d]; parameters.k]; for _ in 1..=parameters.max_iter { let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y); - for i in 0..k { + for i in 0..parameters.k { if size[i] > 0 { for j in 0..d { centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap(); @@ -184,7 +205,7 @@ impl KMeans { } Ok(KMeans { - k, + k: parameters.k, y, size, distortion, @@ -280,10 +301,10 @@ mod tests { fn invalid_k() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); - assert!(KMeans::fit(&x, 0, Default::default()).is_err()); + assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err()); assert_eq!( "Fit failed: invalid number of clusters: 1", - KMeans::fit(&x, 1, Default::default()) + 
KMeans::fit(&x, KMeansParameters::default().with_k(1)) .unwrap_err() .to_string() ); @@ -314,7 +335,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); + let kmeans = KMeans::fit(&x, Default::default()).unwrap(); let y = kmeans.predict(&x).unwrap(); @@ -348,7 +369,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); + let kmeans = KMeans::fit(&x, Default::default()).unwrap(); let deserialized_kmeans: KMeans = serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap(); diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs index 68220e3..189e6de 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -37,7 +37,7 @@ //! &[5.2, 2.7, 3.9, 1.4], //! ]); //! -//! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2 +//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2 //! //! let iris_reduced = pca.transform(&iris).unwrap(); //! @@ -49,6 +49,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; +use crate::api::{Transformer, UnsupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -83,12 +84,19 @@ impl> PartialEq for PCA { #[derive(Debug, Clone)] /// PCA parameters pub struct PCAParameters { + /// Number of components to keep. + pub n_components: usize, /// By default, covariance matrix is used to compute principal components. /// Enable this flag if you want to use correlation matrix instead. pub use_correlation_matrix: bool, } impl PCAParameters { + /// Number of components to keep. + pub fn with_n_components(mut self, n_components: usize) -> Self { + self.n_components = n_components; + self + } /// By default, covariance matrix is used to compute principal components. /// Enable this flag if you want to use correlation matrix instead. 
pub fn with_use_correlation_matrix(mut self, use_correlation_matrix: bool) -> Self { @@ -100,24 +108,33 @@ impl PCAParameters { impl Default for PCAParameters { fn default() -> Self { PCAParameters { + n_components: 2, use_correlation_matrix: false, } } } +impl> UnsupervisedEstimator for PCA { + fn fit(x: &M, parameters: PCAParameters) -> Result { + PCA::fit(x, parameters) + } +} + +impl> Transformer for PCA { + fn transform(&self, x: &M) -> Result { + self.transform(x) + } +} + impl> PCA { /// Fits PCA to your data. /// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `n_components` - number of components to keep. /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values. - pub fn fit( - data: &M, - n_components: usize, - parameters: PCAParameters, - ) -> Result, Failed> { + pub fn fit(data: &M, parameters: PCAParameters) -> Result, Failed> { let (m, n) = data.shape(); - if n_components > n { + if parameters.n_components > n { return Err(Failed::fit(&format!( "Number of components, n_components should be <= number of attributes ({})", n @@ -196,16 +213,16 @@ impl> PCA { } } - let mut projection = M::zeros(n_components, n); + let mut projection = M::zeros(parameters.n_components, n); for i in 0..n { - for j in 0..n_components { + for j in 0..parameters.n_components { projection.set(j, i, eigenvectors.get(i, j)); } } - let mut pmu = vec![T::zero(); n_components]; + let mut pmu = vec![T::zero(); parameters.n_components]; for (k, mu_k) in mu.iter().enumerate().take(n) { - for (i, pmu_i) in pmu.iter_mut().enumerate().take(n_components) { + for (i, pmu_i) in pmu.iter_mut().enumerate().take(parameters.n_components) { *pmu_i += projection.get(i, k) * (*mu_k); } } @@ -318,7 +335,7 @@ mod tests { &[0.0752, 0.2007], ]); - let pca = PCA::fit(&us_arrests, 2, Default::default()).unwrap(); + let pca = PCA::fit(&us_arrests, Default::default()).unwrap(); 
assert!(expected.approximate_eq(&pca.components().abs(), 0.4)); } @@ -414,7 +431,7 @@ mod tests { 302.04806302399646, ]; - let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap(); + let pca = PCA::fit(&us_arrests, PCAParameters::default().with_n_components(4)).unwrap(); assert!(pca .eigenvectors @@ -525,10 +542,9 @@ mod tests { let pca = PCA::fit( &us_arrests, - 4, - PCAParameters { - use_correlation_matrix: true, - }, + PCAParameters::default() + .with_n_components(4) + .with_use_correlation_matrix(true), ) .unwrap(); @@ -573,7 +589,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let pca = PCA::fit(&iris, 4, Default::default()).unwrap(); + let pca = PCA::fit(&iris, Default::default()).unwrap(); let deserialized_pca: PCA> = serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap(); diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index eea1969..d404ca7 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -34,7 +34,7 @@ //! &[5.2, 2.7, 3.9, 1.4], //! ]); //! -//! let svd = SVD::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2 +//! let svd = SVD::fit(&iris, SVDParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2 //! //! let iris_reduced = svd.transform(&iris).unwrap(); //! @@ -47,6 +47,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; +use crate::api::{Transformer, UnsupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -67,11 +68,34 @@ impl> PartialEq for SVD { #[derive(Debug, Clone)] /// SVD parameters -pub struct SVDParameters {} +pub struct SVDParameters { + /// Number of components to keep. + pub n_components: usize, +} impl Default for SVDParameters { fn default() -> Self { - SVDParameters {} + SVDParameters { n_components: 2 } + } +} + +impl SVDParameters { + /// Number of components to keep. 
+ pub fn with_n_components(mut self, n_components: usize) -> Self { + self.n_components = n_components; + self + } +} + +impl> UnsupervisedEstimator for SVD { + fn fit(x: &M, parameters: SVDParameters) -> Result { + SVD::fit(x, parameters) + } +} + +impl> Transformer for SVD { + fn transform(&self, x: &M) -> Result { + self.transform(x) } } @@ -80,10 +104,10 @@ impl> SVD { /// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation. /// * `n_components` - number of components to keep. /// * `parameters` - other parameters, use `Default::default()` to set parameters to default values. - pub fn fit(x: &M, n_components: usize, _: SVDParameters) -> Result, Failed> { + pub fn fit(x: &M, parameters: SVDParameters) -> Result, Failed> { let (_, p) = x.shape(); - if n_components >= p { + if parameters.n_components >= p { return Err(Failed::fit(&format!( "Number of components, n_components should be < number of attributes ({})", p @@ -92,7 +116,7 @@ impl> SVD { let svd = x.svd()?; - let components = svd.V.slice(0..p, 0..n_components); + let components = svd.V.slice(0..p, 0..parameters.n_components); Ok(SVD { components, @@ -189,7 +213,7 @@ mod tests { &[197.28420365, -11.66808306], &[293.43187394, 1.91163633], ]); - let svd = SVD::fit(&x, 2, Default::default()).unwrap(); + let svd = SVD::fit(&x, Default::default()).unwrap(); let x_transformed = svd.transform(&x).unwrap(); @@ -225,7 +249,7 @@ mod tests { &[5.2, 2.7, 3.9, 1.4], ]); - let svd = SVD::fit(&iris, 2, Default::default()).unwrap(); + let svd = SVD::fit(&iris, Default::default()).unwrap(); let deserialized_svd: SVD> = serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap(); diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 9f1ba72..49c4239 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -51,7 +51,7 @@ use std::fmt::Debug; use rand::Rng; use serde::{Deserialize, 
Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -151,6 +151,19 @@ impl Default for RandomForestClassifierParameters { } } +impl> + SupervisedEstimator + for RandomForestClassifier +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: RandomForestClassifierParameters, + ) -> Result { + RandomForestClassifier::fit(x, y, parameters) + } +} + impl> Predictor for RandomForestClassifier { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 6aa89d0..fdeb9fc 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -49,7 +49,7 @@ use std::fmt::Debug; use rand::Rng; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -135,6 +135,19 @@ impl PartialEq for RandomForestRegressor { } } +impl> + SupervisedEstimator + for RandomForestRegressor +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: RandomForestRegressorParameters, + ) -> Result { + RandomForestRegressor::fit(x, y, parameters) + } +} + impl> Predictor for RandomForestRegressor { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/lib.rs b/src/lib.rs index a1608c3..297fcc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ /// Various algorithms and helper methods that are used elsewhere in SmartCore pub mod algorithm; -pub(crate) mod base; +pub mod api; /// Algorithms for clustering of unlabeled data pub mod cluster; /// Various datasets diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs index 1ab933a..2833ff1 100644 --- a/src/linear/elastic_net.rs +++ b/src/linear/elastic_net.rs @@ -58,7 +58,7 @@ use std::fmt::Debug; use serde::{Deserialize, 
Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -139,6 +139,14 @@ impl> PartialEq for ElasticNet { } } +impl> SupervisedEstimator> + for ElasticNet +{ + fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters) -> Result { + ElasticNet::fit(x, y, parameters) + } +} + impl> Predictor for ElasticNet { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index e16a316..b99ecff 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -26,7 +26,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -95,6 +95,14 @@ impl> PartialEq for Lasso { } } +impl> SupervisedEstimator> + for Lasso +{ + fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters) -> Result { + Lasso::fit(x, y, parameters) + } +} + impl> Predictor for Lasso { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index 1855673..2ef03c1 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -64,7 +64,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -116,6 +116,18 @@ impl> PartialEq for LinearRegression { } } +impl> SupervisedEstimator + for LinearRegression +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: LinearRegressionParameters, + ) -> Result { + LinearRegression::fit(x, y, parameters) + } +} + impl> Predictor for LinearRegression { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/linear/logistic_regression.rs 
b/src/linear/logistic_regression.rs index ffb845c..a71ac45 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -58,7 +58,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -218,6 +218,18 @@ impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction } } +impl> SupervisedEstimator + for LogisticRegression +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: LogisticRegressionParameters, + ) -> Result { + LogisticRegression::fit(x, y, parameters) + } +} + impl> Predictor for LogisticRegression { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index f29898d..e9ed1ff 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -60,7 +60,7 @@ use std::fmt::Debug; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -130,6 +130,18 @@ impl> PartialEq for RidgeRegression { } } +impl> SupervisedEstimator> + for RidgeRegression +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: RidgeRegressionParameters, + ) -> Result { + RidgeRegression::fit(x, y, parameters) + } +} + impl> Predictor for RidgeRegression { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 7776354..18dfa35 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -9,7 +9,7 @@ //! //! In SmartCore you can split your data into training and test datasets using `train_test_split` function. 
-use crate::base::Predictor; +use crate::api::Predictor; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index c6cbfa8..388646f 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -33,7 +33,7 @@ //! ## References: //! //! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html) -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::row_iter; use crate::linalg::BaseVector; @@ -208,6 +208,14 @@ pub struct BernoulliNB> { binarize: Option, } +impl> SupervisedEstimator> + for BernoulliNB +{ + fn fit(x: &M, y: &M::RowVector, parameters: BernoulliNBParameters) -> Result { + BernoulliNB::fit(x, y, parameters) + } +} + impl> Predictor for BernoulliNB { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 667a270..c6f28bd 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -30,7 +30,7 @@ //! let nb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); //! let y_hat = nb.predict(&x).unwrap(); //! 
``` -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -242,6 +242,18 @@ pub struct CategoricalNB> { inner: BaseNaiveBayes>, } +impl> SupervisedEstimator> + for CategoricalNB +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: CategoricalNBParameters, + ) -> Result { + CategoricalNB::fit(x, y, parameters) + } +} + impl> Predictor for CategoricalNB { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index bc96420..2ac9892 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -22,7 +22,7 @@ //! let nb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); //! let y_hat = nb.predict(&x).unwrap(); //! ``` -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::row_iter; use crate::linalg::BaseVector; @@ -183,6 +183,14 @@ pub struct GaussianNB> { inner: BaseNaiveBayes>, } +impl> SupervisedEstimator> + for GaussianNB +{ + fn fit(x: &M, y: &M::RowVector, parameters: GaussianNBParameters) -> Result { + GaussianNB::fit(x, y, parameters) + } +} + impl> Predictor for GaussianNB { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index 237b606..4cae1f3 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -33,7 +33,7 @@ //! ## References: //! //! * ["Introduction to Information Retrieval", Manning C. 
D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html) -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::row_iter; use crate::linalg::BaseVector; @@ -194,6 +194,18 @@ pub struct MultinomialNB> { inner: BaseNaiveBayes>, } +impl> SupervisedEstimator> + for MultinomialNB +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: MultinomialNBParameters, + ) -> Result { + MultinomialNB::fit(x, y, parameters) + } +} + impl> Predictor for MultinomialNB { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index 6668539..97dd748 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -36,7 +36,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::{row_iter, Matrix}; use crate::math::distance::euclidian::Euclidian; @@ -139,6 +139,18 @@ impl, T>> PartialEq for KNNClassifier { } } +impl, D: Distance, T>> + SupervisedEstimator> for KNNClassifier +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: KNNClassifierParameters, + ) -> Result { + KNNClassifier::fit(x, y, parameters) + } +} + impl, D: Distance, T>> Predictor for KNNClassifier { diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 80971e5..4e73103 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -39,7 +39,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::{row_iter, BaseVector, Matrix}; use 
crate::math::distance::euclidian::Euclidian; @@ -133,6 +133,18 @@ impl, T>> PartialEq for KNNRegressor { } } +impl, D: Distance, T>> + SupervisedEstimator> for KNNRegressor +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: KNNRegressorParameters, + ) -> Result { + KNNRegressor::fit(x, y, parameters) + } +} + impl, D: Distance, T>> Predictor for KNNRegressor { diff --git a/src/svm/svc.rs b/src/svm/svc.rs index aee4d3f..095d555 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -78,7 +78,7 @@ use rand::seq::SliceRandom; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -185,6 +185,14 @@ impl> Default for SVCParameters } } +impl, K: Kernel> + SupervisedEstimator> for SVC +{ + fn fit(x: &M, y: &M::RowVector, parameters: SVCParameters) -> Result { + SVC::fit(x, y, parameters) + } +} + impl, K: Kernel> Predictor for SVC { diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 295ad78..9eb6046 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -70,7 +70,7 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; @@ -174,6 +174,14 @@ impl> Default for SVRParameters } } +impl, K: Kernel> + SupervisedEstimator> for SVR +{ + fn fit(x: &M, y: &M::RowVector, parameters: SVRParameters) -> Result { + SVR::fit(x, y, parameters) + } +} + impl, K: Kernel> Predictor for SVR { diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 50a855b..3a92c54 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -71,7 +71,7 @@ use rand::seq::SliceRandom; use serde::{Deserialize, Serialize}; use crate::algorithm::sort::quick_sort::QuickArgSort; -use crate::base::Predictor; +use 
crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -293,6 +293,19 @@ pub(in crate) fn which_max(x: &[usize]) -> usize { which } +impl> + SupervisedEstimator + for DecisionTreeClassifier +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: DecisionTreeClassifierParameters, + ) -> Result { + DecisionTreeClassifier::fit(x, y, parameters) + } +} + impl> Predictor for DecisionTreeClassifier { fn predict(&self, x: &M) -> Result { self.predict(x) diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 806e680..06ee507 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -66,7 +66,7 @@ use rand::seq::SliceRandom; use serde::{Deserialize, Serialize}; use crate::algorithm::sort::quick_sort::QuickArgSort; -use crate::base::Predictor; +use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; @@ -208,6 +208,19 @@ impl<'a, T: RealNumber, M: Matrix> NodeVisitor<'a, T, M> { } } +impl> + SupervisedEstimator + for DecisionTreeRegressor +{ + fn fit( + x: &M, + y: &M::RowVector, + parameters: DecisionTreeRegressorParameters, + ) -> Result { + DecisionTreeRegressor::fit(x, y, parameters) + } +} + impl> Predictor for DecisionTreeRegressor { fn predict(&self, x: &M) -> Result { self.predict(x)