feat: Make SerDe optional

This commit is contained in:
Ben Cross
2021-01-17 21:35:03 +00:00
parent eb769493e7
commit e0d46f430b
44 changed files with 206 additions and 126 deletions
+1 -2
View File
@@ -25,8 +25,7 @@ num-traits = "0.2.12"
num = "0.3.0"
rand = "0.7.3"
rand_distr = "0.3.0"
serde = { version = "1.0.115", features = ["derive"] }
serde_derive = "1.0.115"
serde = { version = "1.0.115", features = ["derive"], optional = true }
[dev-dependencies]
criterion = "0.3"
+6 -4
View File
@@ -24,7 +24,7 @@
//! ```
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
@@ -32,7 +32,8 @@ use crate::math::distance::Distance;
use crate::math::num::RealNumber;
/// Implements Cover Tree algorithm
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
base: F,
inv_log_base: F,
@@ -56,7 +57,8 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
}
}
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<F: RealNumber> {
idx: usize,
max_dist: F,
@@ -65,7 +67,7 @@ struct Node<F: RealNumber> {
scale: i64,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug)]
struct DistanceSet<F: RealNumber> {
idx: usize,
dist: Vec<F>,
+3 -2
View File
@@ -22,7 +22,7 @@
//!
//! ```
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
@@ -32,7 +32,8 @@ use crate::math::distance::Distance;
use crate::math::num::RealNumber;
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
distance: D,
data: Vec<T>,
+5 -3
View File
@@ -35,7 +35,7 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search
@@ -45,7 +45,8 @@ pub mod linear_search;
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
@@ -53,7 +54,8 @@ pub enum KNNAlgorithmName {
CoverTree,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>),
+3 -2
View File
@@ -43,7 +43,7 @@
use std::fmt::Debug;
use std::iter::Sum;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, UnsupervisedEstimator};
@@ -55,7 +55,8 @@ use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::which_max;
/// DBSCAN clustering algorithm
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
cluster_labels: Vec<i16>,
num_classes: usize,
+3 -2
View File
@@ -56,7 +56,7 @@ use rand::Rng;
use std::fmt::Debug;
use std::iter::Sum;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
@@ -66,7 +66,8 @@ use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
/// K-Means clustering algorithm
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KMeans<T: RealNumber> {
k: usize,
y: Vec<usize>,
+3 -2
View File
@@ -47,7 +47,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
@@ -55,7 +55,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// Principal components analysis algorithm
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct PCA<T: RealNumber, M: Matrix<T>> {
eigenvectors: M,
eigenvalues: Vec<T>,
+3 -2
View File
@@ -46,7 +46,7 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
@@ -54,7 +54,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// SVD
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct SVD<T: RealNumber, M: Matrix<T>> {
components: M,
phantom: PhantomData<T>,
+5 -3
View File
@@ -49,7 +49,7 @@ use std::default::Default;
use std::fmt::Debug;
use rand::Rng;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -61,7 +61,8 @@ use crate::tree::decision_tree_classifier::{
/// Parameters of the Random Forest algorithm.
/// Some parameters here are passed directly into base estimator.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierParameters {
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: SplitCriterion,
@@ -78,7 +79,8 @@ pub struct RandomForestClassifierParameters {
}
/// Random Forest Classifier
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestClassifier<T: RealNumber> {
parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>,
+5 -3
View File
@@ -47,7 +47,7 @@ use std::default::Default;
use std::fmt::Debug;
use rand::Rng;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -57,7 +57,8 @@ use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Random Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct RandomForestRegressorParameters {
@@ -74,7 +75,8 @@ pub struct RandomForestRegressorParameters {
}
/// Random Forest Regressor
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestRegressor<T: RealNumber> {
parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor<T>>,
+5 -3
View File
@@ -2,10 +2,11 @@
use std::error::Error;
use std::fmt;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// Generic error to be raised when something goes wrong.
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Failed {
err: FailedError,
msg: String,
@@ -13,7 +14,8 @@ pub struct Failed {
/// Type of error
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug)]
pub enum FailedError {
/// Can't fit algorithm to data
FitFailed = 1,
+6 -1
View File
@@ -1,11 +1,14 @@
#![allow(clippy::ptr_arg)]
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")] use std::marker::PhantomData;
use std::ops::Range;
#[cfg(feature = "serde")]
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
#[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
@@ -349,6 +352,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
}
}
#[cfg(feature = "serde")]
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
@@ -434,6 +438,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
}
}
#[cfg(feature = "serde")]
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
+5 -3
View File
@@ -56,7 +56,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -67,7 +67,8 @@ use crate::math::num::RealNumber;
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
/// Elastic net parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters<T: RealNumber> {
/// Regularization parameter.
pub alpha: T,
@@ -84,7 +85,8 @@ pub struct ElasticNetParameters<T: RealNumber> {
}
/// Elastic net
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
+5 -3
View File
@@ -24,7 +24,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -34,7 +34,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
use crate::math::num::RealNumber;
/// Lasso regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function.
pub alpha: T,
@@ -47,7 +48,8 @@ pub struct LassoParameters<T: RealNumber> {
pub max_iter: usize,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Lasso regressor
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
coefficients: M,
+7 -4
View File
@@ -62,14 +62,15 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -79,14 +80,16 @@ pub enum LinearRegressionSolverName {
}
/// Linear Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionParameters {
/// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName,
}
/// Linear Regression
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
+5 -3
View File
@@ -56,7 +56,7 @@ use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -68,11 +68,13 @@ use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;
/// Logistic Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionParameters {}
/// Logistic Regression
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: M,
+7 -4
View File
@@ -58,7 +58,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -66,7 +66,8 @@ use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName {
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
@@ -76,7 +77,8 @@ pub enum RidgeRegressionSolverName {
}
/// Ridge Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: RidgeRegressionSolverName,
@@ -88,7 +90,8 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
}
/// Ridge regression
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
+3 -2
View File
@@ -18,14 +18,15 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian {}
impl Euclidian {
+3 -2
View File
@@ -19,14 +19,15 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Hamming {}
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
+3 -2
View File
@@ -44,7 +44,7 @@
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
@@ -52,7 +52,8 @@ use super::Distance;
use crate::linalg::Matrix;
/// Mahalanobis distance.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
/// covariance matrix of the dataset
pub sigma: M,
+3 -2
View File
@@ -17,14 +17,15 @@
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Manhattan distance
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan {}
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
+3 -2
View File
@@ -21,14 +21,15 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Defines the Minkowski distance of order `p`
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Minkowski {
/// order, integer
pub p: u16,
+3 -2
View File
@@ -16,13 +16,14 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Accuracy metric.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Accuracy {}
impl Accuracy {
+3 -2
View File
@@ -20,14 +20,15 @@
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
#![allow(non_snake_case)]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct AUC {}
impl AUC {
+3 -2
View File
@@ -1,10 +1,11 @@
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::metrics::cluster_helpers::*;
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Homogeneity, completeness and V-Measure scores.
pub struct HCVScore {}
+3 -2
View File
@@ -18,7 +18,7 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
@@ -26,7 +26,8 @@ use crate::metrics::precision::Precision;
use crate::metrics::recall::Recall;
/// F-measure
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct F1<T: RealNumber> {
/// a positive real factor
pub beta: T,
+3 -2
View File
@@ -18,12 +18,13 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Absolute Error
pub struct MeanAbsoluteError {}
+3 -2
View File
@@ -18,12 +18,13 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Squared Error
pub struct MeanSquareError {}
+3 -2
View File
@@ -18,13 +18,14 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Precision metric.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Precision {}
impl Precision {
+3 -2
View File
@@ -18,13 +18,14 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Coefficient of Determination (R2)
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct R2 {}
impl R2 {
+3 -2
View File
@@ -18,13 +18,14 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Recall metric.
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Recall {}
impl Recall {
+7 -4
View File
@@ -42,10 +42,11 @@ use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for Bearnoulli features
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
struct BernoulliNBDistribution<T: RealNumber> {
/// class labels known to the classifier
class_labels: Vec<T>,
@@ -77,7 +78,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
}
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BernoulliNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T,
@@ -202,7 +204,8 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
}
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
binarize: Option<T>,
+7 -4
View File
@@ -36,10 +36,11 @@ use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for categorical features
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct CategoricalNBDistribution<T: RealNumber> {
class_labels: Vec<T>,
class_priors: Vec<T>,
@@ -216,7 +217,8 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
}
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct CategoricalNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T,
@@ -237,7 +239,8 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
}
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
}
+7 -4
View File
@@ -30,10 +30,11 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for categorical features
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
struct GaussianNBDistribution<T: RealNumber> {
/// class labels known to the classifier
class_labels: Vec<T>,
@@ -75,7 +76,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
}
/// `GaussianNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone)]
pub struct GaussianNBParameters<T: RealNumber> {
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
pub priors: Option<Vec<T>>,
@@ -178,7 +180,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
}
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
}
+3 -2
View File
@@ -39,7 +39,7 @@ use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
/// Distribution used in the Naive Bayes classifier.
@@ -55,7 +55,8 @@ pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
}
/// Base struct for the Naive Bayes classifier.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
distribution: D,
_phantom_t: PhantomData<T>,
+7 -4
View File
@@ -42,10 +42,11 @@ use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for Multinomial features
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
struct MultinomialNBDistribution<T: RealNumber> {
/// class labels known to the classifier
class_labels: Vec<T>,
@@ -73,7 +74,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
}
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct MultinomialNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T,
@@ -189,7 +191,8 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
}
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
}
+5 -3
View File
@@ -33,7 +33,7 @@
//!
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, SupervisedEstimator};
@@ -45,7 +45,8 @@ use crate::math::num::RealNumber;
use crate::neighbors::KNNWeightFunction;
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
@@ -62,7 +63,8 @@ pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
}
/// K Nearest Neighbors Classifier
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
classes: Vec<T>,
y: Vec<usize>,
+5 -3
View File
@@ -36,7 +36,7 @@
//!
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, SupervisedEstimator};
@@ -48,7 +48,8 @@ use crate::math::num::RealNumber;
use crate::neighbors::KNNWeightFunction;
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
@@ -65,7 +66,8 @@ pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
}
/// K Nearest Neighbors Regressor
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
y: Vec<T>,
knn_algorithm: KNNAlgorithm<T, D>,
+3 -2
View File
@@ -33,7 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
/// K Nearest Neighbors Classifier
pub mod knn_classifier;
@@ -48,7 +48,8 @@ pub mod knn_regressor;
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
/// Weight function that is used to determine estimated value.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum KNNWeightFunction {
/// All k nearest points are weighted equally
Uniform,
+9 -5
View File
@@ -26,7 +26,7 @@
pub mod svc;
pub mod svr;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
@@ -93,18 +93,21 @@ impl Kernels {
}
/// Linear Kernel
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearKernel {}
/// Radial basis function (Gaussian) kernel
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RBFKernel<T: RealNumber> {
/// kernel coefficient
pub gamma: T,
}
/// Polynomial kernel
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct PolynomialKernel<T: RealNumber> {
/// degree of the polynomial
pub degree: T,
@@ -115,7 +118,8 @@ pub struct PolynomialKernel<T: RealNumber> {
}
/// Sigmoid (hyperbolic tangent) kernel
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SigmoidKernel<T: RealNumber> {
/// kernel coefficient
pub gamma: T,
+9 -6
View File
@@ -76,7 +76,7 @@ use std::marker::PhantomData;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -85,7 +85,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::svm::{Kernel, Kernels, LinearKernel};
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVC Parameters
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
/// Number of epochs.
@@ -100,11 +101,12 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(bound(
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
#[cfg_attr(feature = "serde", serde(bound(
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
))]
)))]
/// Support Vector Classifier
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
classes: Vec<T>,
@@ -114,7 +116,8 @@ pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
b: T,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
index: usize,
x: V,
+9 -6
View File
@@ -68,7 +68,7 @@ use std::cell::{Ref, RefCell};
use std::fmt::Debug;
use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
@@ -77,7 +77,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::svm::{Kernel, Kernels, LinearKernel};
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVR Parameters
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
/// Epsilon in the epsilon-SVR model.
@@ -92,11 +93,12 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(bound(
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
#[cfg_attr(feature = "serde", serde(bound(
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
))]
)))]
/// Epsilon-Support Vector Regression
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
@@ -106,7 +108,8 @@ pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
b: T,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
index: usize,
x: V,
+9 -5
View File
@@ -68,7 +68,7 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::api::{Predictor, SupervisedEstimator};
@@ -76,7 +76,8 @@ use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Decision Tree
pub struct DecisionTreeClassifierParameters {
/// Split criteria to use when building a tree.
@@ -90,7 +91,8 @@ pub struct DecisionTreeClassifierParameters {
}
/// Decision Tree
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeClassifier<T: RealNumber> {
nodes: Vec<Node<T>>,
parameters: DecisionTreeClassifierParameters,
@@ -100,7 +102,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
}
/// The function to measure the quality of a split.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum SplitCriterion {
/// [Gini index](../decision_tree_classifier/index.html)
Gini,
@@ -110,7 +113,8 @@ pub enum SplitCriterion {
ClassificationError,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<T: RealNumber> {
index: usize,
output: usize,
+7 -4
View File
@@ -63,7 +63,7 @@ use std::default::Default;
use std::fmt::Debug;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")] use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::api::{Predictor, SupervisedEstimator};
@@ -71,7 +71,8 @@ use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Regression Tree
pub struct DecisionTreeRegressorParameters {
/// The maximum depth of the tree.
@@ -83,14 +84,16 @@ pub struct DecisionTreeRegressorParameters {
}
/// Regression Tree
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeRegressor<T: RealNumber> {
nodes: Vec<Node<T>>,
parameters: DecisionTreeRegressorParameters,
depth: u16,
}
#[derive(Serialize, Deserialize, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<T: RealNumber> {
index: usize,
output: T,