@@ -23,10 +23,10 @@ jobs:
|
||||
command: cargo fmt -- --check
|
||||
- run:
|
||||
name: Stable Build
|
||||
command: cargo build --features "nalgebra-bindings ndarray-bindings"
|
||||
command: cargo build --all-features
|
||||
- run:
|
||||
name: Test
|
||||
command: cargo test --features "nalgebra-bindings ndarray-bindings"
|
||||
command: cargo test --all-features
|
||||
- save_cache:
|
||||
key: project-cache
|
||||
paths:
|
||||
|
||||
+1
-2
@@ -25,8 +25,7 @@ num-traits = "0.2.12"
|
||||
num = "0.3.0"
|
||||
rand = "0.7.3"
|
||||
rand_distr = "0.3.0"
|
||||
serde = { version = "1.0.115", features = ["derive"] }
|
||||
serde_derive = "1.0.115"
|
||||
serde = { version = "1.0.115", features = ["derive"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
//! ```
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Implements Cover Tree algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||
base: F,
|
||||
inv_log_base: F,
|
||||
@@ -56,7 +58,8 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<F: RealNumber> {
|
||||
idx: usize,
|
||||
max_dist: F,
|
||||
@@ -65,7 +68,7 @@ struct Node<F: RealNumber> {
|
||||
scale: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug)]
|
||||
struct DistanceSet<F: RealNumber> {
|
||||
idx: usize,
|
||||
dist: Vec<F>,
|
||||
@@ -454,7 +457,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
@@ -500,6 +504,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
//!
|
||||
//! ```
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
@@ -138,7 +140,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
|
||||
@@ -35,6 +35,7 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) mod bbd_tree;
|
||||
@@ -45,7 +46,8 @@ pub mod linear_search;
|
||||
|
||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KNNAlgorithmName {
|
||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||
LinearSearch,
|
||||
@@ -53,7 +55,8 @@ pub enum KNNAlgorithmName {
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
@@ -55,7 +56,8 @@ use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::which_max;
|
||||
|
||||
/// DBSCAN clustering algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
cluster_labels: Vec<i16>,
|
||||
num_classes: usize,
|
||||
@@ -263,6 +265,7 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
|
||||
#[test]
|
||||
@@ -297,6 +300,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
@@ -56,6 +56,7 @@ use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
@@ -66,7 +67,8 @@ use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KMeans<T: RealNumber> {
|
||||
k: usize,
|
||||
y: Vec<usize>,
|
||||
@@ -345,6 +347,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
@@ -47,6 +47,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
@@ -55,7 +56,8 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Principal components analysis algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
||||
eigenvectors: M,
|
||||
eigenvalues: Vec<T>,
|
||||
@@ -565,6 +567,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
@@ -46,6 +46,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
@@ -54,7 +55,8 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// SVD
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
||||
components: M,
|
||||
phantom: PhantomData<T>,
|
||||
@@ -226,6 +228,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
@@ -49,6 +49,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -61,7 +62,8 @@ use crate::tree::decision_tree_classifier::{
|
||||
|
||||
/// Parameters of the Random Forest algorithm.
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: SplitCriterion,
|
||||
@@ -78,7 +80,8 @@ pub struct RandomForestClassifierParameters {
|
||||
}
|
||||
|
||||
/// Random Forest Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestClassifier<T: RealNumber> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
@@ -322,6 +325,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
@@ -47,6 +47,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -57,7 +58,8 @@ use crate::tree::decision_tree_regressor::{
|
||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of the Random Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct RandomForestRegressorParameters {
|
||||
@@ -74,7 +76,8 @@ pub struct RandomForestRegressorParameters {
|
||||
}
|
||||
|
||||
/// Random Forest Regressor
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestRegressor<T: RealNumber> {
|
||||
parameters: RandomForestRegressorParameters,
|
||||
trees: Vec<DecisionTreeRegressor<T>>,
|
||||
@@ -271,6 +274,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
|
||||
+5
-2
@@ -2,10 +2,12 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Generic error to be raised when something goes wrong.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Failed {
|
||||
err: FailedError,
|
||||
msg: String,
|
||||
@@ -13,7 +15,8 @@ pub struct Failed {
|
||||
|
||||
/// Type of error
|
||||
#[non_exhaustive]
|
||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum FailedError {
|
||||
/// Can't fit algorithm to data
|
||||
FitFailed = 1,
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
#![allow(clippy::ptr_arg)]
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
#[cfg(feature = "serde")]
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::ser::{SerializeStruct, Serializer};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
@@ -349,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -434,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
@@ -1306,6 +1312,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn to_from_json() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> =
|
||||
@@ -1314,6 +1321,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn to_from_bincode() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> =
|
||||
|
||||
@@ -56,6 +56,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -67,7 +68,8 @@ use crate::math::num::RealNumber;
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
|
||||
/// Elastic net parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetParameters<T: RealNumber> {
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
@@ -84,7 +86,8 @@ pub struct ElasticNetParameters<T: RealNumber> {
|
||||
}
|
||||
|
||||
/// Elastic net
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
@@ -398,6 +401,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
+6
-2
@@ -24,6 +24,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -34,7 +35,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Lasso regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoParameters<T: RealNumber> {
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
@@ -47,7 +49,8 @@ pub struct LassoParameters<T: RealNumber> {
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Lasso regressor
|
||||
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
@@ -272,6 +275,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
@@ -62,6 +62,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -69,7 +70,8 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -79,14 +81,16 @@ pub enum LinearRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Linear Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
/// Linear Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
@@ -247,6 +251,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
@@ -56,6 +56,7 @@ use std::cmp::Ordering;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -68,11 +69,13 @@ use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters {}
|
||||
|
||||
/// Logistic Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: M,
|
||||
@@ -540,6 +543,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., -5.],
|
||||
|
||||
@@ -58,6 +58,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -66,7 +67,8 @@ use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
@@ -76,7 +78,8 @@ pub enum RidgeRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Ridge Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: RidgeRegressionSolverName,
|
||||
@@ -88,7 +91,8 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
}
|
||||
|
||||
/// Ridge regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
@@ -326,6 +330,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -25,7 +26,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Euclidian {}
|
||||
|
||||
impl Euclidian {
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -26,7 +27,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Hamming {}
|
||||
|
||||
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||
|
||||
@@ -44,6 +44,7 @@
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -52,7 +53,8 @@ use super::Distance;
|
||||
use crate::linalg::Matrix;
|
||||
|
||||
/// Mahalanobis distance.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
||||
/// covariance matrix of the dataset
|
||||
pub sigma: M,
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
//! ```
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -24,7 +25,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Manhattan distance
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Manhattan {}
|
||||
|
||||
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -28,7 +29,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Defines the Minkowski distance of order `p`
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Minkowski {
|
||||
/// order, integer
|
||||
pub p: u16,
|
||||
|
||||
@@ -16,13 +16,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Accuracy metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Accuracy {}
|
||||
|
||||
impl Accuracy {
|
||||
|
||||
+3
-1
@@ -20,6 +20,7 @@
|
||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
@@ -27,7 +28,8 @@ use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct AUC {}
|
||||
|
||||
impl AUC {
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::metrics::cluster_helpers::*;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Homogeneity, completeness and V-Measure scores.
|
||||
pub struct HCVScore {}
|
||||
|
||||
|
||||
+3
-1
@@ -18,6 +18,7 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
@@ -26,7 +27,8 @@ use crate::metrics::precision::Precision;
|
||||
use crate::metrics::recall::Recall;
|
||||
|
||||
/// F-measure
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct F1<T: RealNumber> {
|
||||
/// a positive real factor
|
||||
pub beta: T,
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Absolute Error
|
||||
pub struct MeanAbsoluteError {}
|
||||
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Squared Error
|
||||
pub struct MeanSquareError {}
|
||||
|
||||
|
||||
@@ -18,13 +18,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Precision metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Precision {}
|
||||
|
||||
impl Precision {
|
||||
|
||||
+3
-1
@@ -18,13 +18,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Coefficient of Determination (R2)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct R2 {}
|
||||
|
||||
impl R2 {
|
||||
|
||||
@@ -18,13 +18,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Recall metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Recall {}
|
||||
|
||||
impl Recall {
|
||||
|
||||
@@ -42,10 +42,12 @@ use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for Bearnoulli features
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct BernoulliNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
@@ -77,7 +79,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
|
||||
}
|
||||
|
||||
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BernoulliNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
@@ -202,7 +205,8 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
|
||||
}
|
||||
|
||||
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
||||
binarize: Option<T>,
|
||||
@@ -347,6 +351,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 1., 0., 0., 0., 0.],
|
||||
|
||||
@@ -36,10 +36,12 @@ use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for categorical features
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct CategoricalNBDistribution<T: RealNumber> {
|
||||
class_labels: Vec<T>,
|
||||
class_priors: Vec<T>,
|
||||
@@ -216,7 +218,8 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
||||
}
|
||||
|
||||
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoricalNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
@@ -237,7 +240,8 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
||||
}
|
||||
|
||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||
}
|
||||
@@ -345,6 +349,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[3., 4., 0., 1.],
|
||||
|
||||
@@ -30,10 +30,12 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for categorical features
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct GaussianNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
@@ -75,7 +77,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
||||
}
|
||||
|
||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct GaussianNBParameters<T: RealNumber> {
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
@@ -178,7 +181,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
||||
}
|
||||
|
||||
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
||||
}
|
||||
@@ -277,6 +281,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[-1., -1.],
|
||||
|
||||
@@ -39,6 +39,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -55,7 +56,8 @@ pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
|
||||
}
|
||||
|
||||
/// Base struct for the Naive Bayes classifier.
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
||||
distribution: D,
|
||||
_phantom_t: PhantomData<T>,
|
||||
|
||||
@@ -42,10 +42,12 @@ use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for Multinomial features
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct MultinomialNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
@@ -73,7 +75,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
|
||||
}
|
||||
|
||||
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultinomialNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
@@ -189,7 +192,8 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
|
||||
}
|
||||
|
||||
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
||||
}
|
||||
@@ -320,6 +324,7 @@ mod tests {
|
||||
));
|
||||
}
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 1., 0., 0., 0., 0.],
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
@@ -45,7 +46,8 @@ use crate::math::num::RealNumber;
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
@@ -62,7 +64,8 @@ pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
@@ -277,6 +280,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
@@ -48,7 +49,8 @@ use crate::math::num::RealNumber;
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
@@ -65,7 +67,8 @@ pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Regressor
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
y: Vec<T>,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
@@ -266,6 +269,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
@@ -48,7 +49,8 @@ pub mod knn_regressor;
|
||||
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
||||
|
||||
/// Weight function that is used to determine estimated value.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KNNWeightFunction {
|
||||
/// All k nearest points are weighted equally
|
||||
Uniform,
|
||||
|
||||
+9
-4
@@ -26,6 +26,7 @@
|
||||
pub mod svc;
|
||||
pub mod svr;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
@@ -93,18 +94,21 @@ impl Kernels {
|
||||
}
|
||||
|
||||
/// Linear Kernel
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearKernel {}
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RBFKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PolynomialKernel<T: RealNumber> {
|
||||
/// degree of the polynomial
|
||||
pub degree: T,
|
||||
@@ -115,7 +119,8 @@ pub struct PolynomialKernel<T: RealNumber> {
|
||||
}
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SigmoidKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
|
||||
+14
-5
@@ -76,6 +76,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -85,7 +86,8 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVC Parameters
|
||||
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Number of epochs.
|
||||
@@ -100,11 +102,15 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(bound(
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
feature = "serde",
|
||||
serde(bound(
|
||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
))]
|
||||
))
|
||||
)]
|
||||
/// Support Vector Classifier
|
||||
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
classes: Vec<T>,
|
||||
@@ -114,7 +120,8 @@ pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
b: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||
index: usize,
|
||||
x: V,
|
||||
@@ -719,6 +726,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::accuracy;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
|
||||
@@ -807,6 +815,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn svc_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
|
||||
+14
-5
@@ -68,6 +68,7 @@ use std::cell::{Ref, RefCell};
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
@@ -77,7 +78,8 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVR Parameters
|
||||
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
@@ -92,11 +94,15 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(bound(
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
feature = "serde",
|
||||
serde(bound(
|
||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
))]
|
||||
))
|
||||
)]
|
||||
|
||||
/// Epsilon-Support Vector Regression
|
||||
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
@@ -106,7 +112,8 @@ pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
b: T,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||
index: usize,
|
||||
x: V,
|
||||
@@ -526,6 +533,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_squared_error;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
|
||||
@@ -562,6 +570,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn svr_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
@@ -68,6 +68,7 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
@@ -76,7 +77,8 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of Decision Tree
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
/// Split criteria to use when building a tree.
|
||||
@@ -90,7 +92,8 @@ pub struct DecisionTreeClassifierParameters {
|
||||
}
|
||||
|
||||
/// Decision Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
@@ -100,7 +103,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
}
|
||||
|
||||
/// The function to measure the quality of a split.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
/// [Gini index](../decision_tree_classifier/index.html)
|
||||
Gini,
|
||||
@@ -110,7 +114,8 @@ pub enum SplitCriterion {
|
||||
ClassificationError,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
@@ -740,6 +745,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 1., 1., 0.],
|
||||
|
||||
@@ -63,6 +63,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
@@ -71,7 +72,8 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of Regression Tree
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
/// The maximum depth of the tree.
|
||||
@@ -83,14 +85,16 @@ pub struct DecisionTreeRegressorParameters {
|
||||
}
|
||||
|
||||
/// Regression Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
depth: u16,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: T,
|
||||
@@ -577,6 +581,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
|
||||
Reference in New Issue
Block a user