@@ -23,10 +23,10 @@ jobs:
|
|||||||
command: cargo fmt -- --check
|
command: cargo fmt -- --check
|
||||||
- run:
|
- run:
|
||||||
name: Stable Build
|
name: Stable Build
|
||||||
command: cargo build --features "nalgebra-bindings ndarray-bindings"
|
command: cargo build --all-features
|
||||||
- run:
|
- run:
|
||||||
name: Test
|
name: Test
|
||||||
command: cargo test --features "nalgebra-bindings ndarray-bindings"
|
command: cargo test --all-features
|
||||||
- save_cache:
|
- save_cache:
|
||||||
key: project-cache
|
key: project-cache
|
||||||
paths:
|
paths:
|
||||||
|
|||||||
+1
-2
@@ -25,8 +25,7 @@ num-traits = "0.2.12"
|
|||||||
num = "0.3.0"
|
num = "0.3.0"
|
||||||
rand = "0.7.3"
|
rand = "0.7.3"
|
||||||
rand_distr = "0.3.0"
|
rand_distr = "0.3.0"
|
||||||
serde = { version = "1.0.115", features = ["derive"] }
|
serde = { version = "1.0.115", features = ["derive"], optional = true }
|
||||||
serde_derive = "1.0.115"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.3"
|
criterion = "0.3"
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Implements Cover Tree algorithm
|
/// Implements Cover Tree algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||||
base: F,
|
base: F,
|
||||||
inv_log_base: F,
|
inv_log_base: F,
|
||||||
@@ -56,7 +58,8 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<F: RealNumber> {
|
struct Node<F: RealNumber> {
|
||||||
idx: usize,
|
idx: usize,
|
||||||
max_dist: F,
|
max_dist: F,
|
||||||
@@ -65,7 +68,7 @@ struct Node<F: RealNumber> {
|
|||||||
scale: i64,
|
scale: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug)]
|
||||||
struct DistanceSet<F: RealNumber> {
|
struct DistanceSet<F: RealNumber> {
|
||||||
idx: usize,
|
idx: usize,
|
||||||
dist: Vec<F>,
|
dist: Vec<F>,
|
||||||
@@ -454,7 +457,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct SimpleDistance {}
|
struct SimpleDistance {}
|
||||||
|
|
||||||
impl Distance<i32, f64> for SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
@@ -500,6 +504,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::cmp::{Ordering, PartialOrd};
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
||||||
distance: D,
|
distance: D,
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
@@ -138,7 +140,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct SimpleDistance {}
|
struct SimpleDistance {}
|
||||||
|
|
||||||
impl Distance<i32, f64> for SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
|||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
@@ -45,7 +46,8 @@ pub mod linear_search;
|
|||||||
|
|
||||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum KNNAlgorithmName {
|
pub enum KNNAlgorithmName {
|
||||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||||
LinearSearch,
|
LinearSearch,
|
||||||
@@ -53,7 +55,8 @@ pub enum KNNAlgorithmName {
|
|||||||
CoverTree,
|
CoverTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||||
|
|||||||
@@ -43,6 +43,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -55,7 +56,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::tree::decision_tree_classifier::which_max;
|
use crate::tree::decision_tree_classifier::which_max;
|
||||||
|
|
||||||
/// DBSCAN clustering algorithm
|
/// DBSCAN clustering algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
cluster_labels: Vec<i16>,
|
cluster_labels: Vec<i16>,
|
||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
@@ -263,6 +265,7 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::math::distance::euclidian::Euclidian;
|
use crate::math::distance::euclidian::Euclidian;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -297,6 +300,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ use rand::Rng;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
@@ -66,7 +67,8 @@ use crate::math::distance::euclidian::*;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// K-Means clustering algorithm
|
/// K-Means clustering algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KMeans<T: RealNumber> {
|
pub struct KMeans<T: RealNumber> {
|
||||||
k: usize,
|
k: usize,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
@@ -345,6 +347,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -47,6 +47,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||||
@@ -55,7 +56,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Principal components analysis algorithm
|
/// Principal components analysis algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
||||||
eigenvectors: M,
|
eigenvectors: M,
|
||||||
eigenvalues: Vec<T>,
|
eigenvalues: Vec<T>,
|
||||||
@@ -565,6 +567,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let iris = DenseMatrix::from_2d_array(&[
|
let iris = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -46,6 +46,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||||
@@ -54,7 +55,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// SVD
|
/// SVD
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
||||||
components: M,
|
components: M,
|
||||||
phantom: PhantomData<T>,
|
phantom: PhantomData<T>,
|
||||||
@@ -226,6 +228,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let iris = DenseMatrix::from_2d_array(&[
|
let iris = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -61,7 +62,8 @@ use crate::tree::decision_tree_classifier::{
|
|||||||
|
|
||||||
/// Parameters of the Random Forest algorithm.
|
/// Parameters of the Random Forest algorithm.
|
||||||
/// Some parameters here are passed directly into base estimator.
|
/// Some parameters here are passed directly into base estimator.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RandomForestClassifierParameters {
|
pub struct RandomForestClassifierParameters {
|
||||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
@@ -78,7 +80,8 @@ pub struct RandomForestClassifierParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Classifier
|
/// Random Forest Classifier
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RandomForestClassifier<T: RealNumber> {
|
pub struct RandomForestClassifier<T: RealNumber> {
|
||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
trees: Vec<DecisionTreeClassifier<T>>,
|
trees: Vec<DecisionTreeClassifier<T>>,
|
||||||
@@ -322,6 +325,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -57,7 +58,8 @@ use crate::tree::decision_tree_regressor::{
|
|||||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of the Random Forest Regressor
|
/// Parameters of the Random Forest Regressor
|
||||||
/// Some parameters here are passed directly into base estimator.
|
/// Some parameters here are passed directly into base estimator.
|
||||||
pub struct RandomForestRegressorParameters {
|
pub struct RandomForestRegressorParameters {
|
||||||
@@ -74,7 +76,8 @@ pub struct RandomForestRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Regressor
|
/// Random Forest Regressor
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RandomForestRegressor<T: RealNumber> {
|
pub struct RandomForestRegressor<T: RealNumber> {
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
trees: Vec<DecisionTreeRegressor<T>>,
|
trees: Vec<DecisionTreeRegressor<T>>,
|
||||||
@@ -271,6 +274,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
|||||||
+5
-2
@@ -2,10 +2,12 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Generic error to be raised when something goes wrong.
|
/// Generic error to be raised when something goes wrong.
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Failed {
|
pub struct Failed {
|
||||||
err: FailedError,
|
err: FailedError,
|
||||||
msg: String,
|
msg: String,
|
||||||
@@ -13,7 +15,8 @@ pub struct Failed {
|
|||||||
|
|
||||||
/// Type of error
|
/// Type of error
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
pub enum FailedError {
|
pub enum FailedError {
|
||||||
/// Can't fit algorithm to data
|
/// Can't fit algorithm to data
|
||||||
FitFailed = 1,
|
FitFailed = 1,
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
#![allow(clippy::ptr_arg)]
|
#![allow(clippy::ptr_arg)]
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::ser::{SerializeStruct, Serializer};
|
use serde::ser::{SerializeStruct, Serializer};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||||
@@ -349,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
where
|
where
|
||||||
@@ -434,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
@@ -1306,6 +1312,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn to_from_json() {
|
fn to_from_json() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> =
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
@@ -1314,6 +1321,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn to_from_bincode() {
|
fn to_from_bincode() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> =
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
|
|||||||
@@ -56,6 +56,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -67,7 +68,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||||
|
|
||||||
/// Elastic net parameters
|
/// Elastic net parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct ElasticNetParameters<T: RealNumber> {
|
pub struct ElasticNetParameters<T: RealNumber> {
|
||||||
/// Regularization parameter.
|
/// Regularization parameter.
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -84,7 +86,8 @@ pub struct ElasticNetParameters<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Elastic net
|
/// Elastic net
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
@@ -398,6 +401,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
+6
-2
@@ -24,6 +24,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -34,7 +35,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Lasso regression parameters
|
/// Lasso regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LassoParameters<T: RealNumber> {
|
pub struct LassoParameters<T: RealNumber> {
|
||||||
/// Controls the strength of the penalty to the loss function.
|
/// Controls the strength of the penalty to the loss function.
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -47,7 +49,8 @@ pub struct LassoParameters<T: RealNumber> {
|
|||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Lasso regressor
|
/// Lasso regressor
|
||||||
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
@@ -272,6 +275,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -62,6 +62,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -69,7 +70,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||||
pub enum LinearRegressionSolverName {
|
pub enum LinearRegressionSolverName {
|
||||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||||
@@ -79,14 +81,16 @@ pub enum LinearRegressionSolverName {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Regression parameters
|
/// Linear Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LinearRegressionParameters {
|
pub struct LinearRegressionParameters {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub solver: LinearRegressionSolverName,
|
pub solver: LinearRegressionSolverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Regression
|
/// Linear Regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
@@ -247,6 +251,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ use std::cmp::Ordering;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -68,11 +69,13 @@ use crate::optimization::line_search::Backtracking;
|
|||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
/// Logistic Regression parameters
|
/// Logistic Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LogisticRegressionParameters {}
|
pub struct LogisticRegressionParameters {}
|
||||||
|
|
||||||
/// Logistic Regression
|
/// Logistic Regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: M,
|
intercept: M,
|
||||||
@@ -540,6 +543,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
|
|||||||
@@ -58,6 +58,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -66,7 +67,8 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||||
pub enum RidgeRegressionSolverName {
|
pub enum RidgeRegressionSolverName {
|
||||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||||
@@ -76,7 +78,8 @@ pub enum RidgeRegressionSolverName {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Ridge Regression parameters
|
/// Ridge Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub solver: RidgeRegressionSolverName,
|
pub solver: RidgeRegressionSolverName,
|
||||||
@@ -88,7 +91,8 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Ridge regression
|
/// Ridge regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
@@ -326,6 +330,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -25,7 +26,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Euclidian {}
|
pub struct Euclidian {}
|
||||||
|
|
||||||
impl Euclidian {
|
impl Euclidian {
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -26,7 +27,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Hamming {}
|
pub struct Hamming {}
|
||||||
|
|
||||||
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||||
|
|||||||
@@ -44,6 +44,7 @@
|
|||||||
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -52,7 +53,8 @@ use super::Distance;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
|
|
||||||
/// Mahalanobis distance.
|
/// Mahalanobis distance.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
||||||
/// covariance matrix of the dataset
|
/// covariance matrix of the dataset
|
||||||
pub sigma: M,
|
pub sigma: M,
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -24,7 +25,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Manhattan distance
|
/// Manhattan distance
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Manhattan {}
|
pub struct Manhattan {}
|
||||||
|
|
||||||
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -28,7 +29,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Defines the Minkowski distance of order `p`
|
/// Defines the Minkowski distance of order `p`
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Minkowski {
|
pub struct Minkowski {
|
||||||
/// order, integer
|
/// order, integer
|
||||||
pub p: u16,
|
pub p: u16,
|
||||||
|
|||||||
@@ -16,13 +16,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Accuracy metric.
|
/// Accuracy metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Accuracy {}
|
pub struct Accuracy {}
|
||||||
|
|
||||||
impl Accuracy {
|
impl Accuracy {
|
||||||
|
|||||||
+3
-1
@@ -20,6 +20,7 @@
|
|||||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -27,7 +28,8 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct AUC {}
|
pub struct AUC {}
|
||||||
|
|
||||||
impl AUC {
|
impl AUC {
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::metrics::cluster_helpers::*;
|
use crate::metrics::cluster_helpers::*;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Homogeneity, completeness and V-Measure scores.
|
/// Homogeneity, completeness and V-Measure scores.
|
||||||
pub struct HCVScore {}
|
pub struct HCVScore {}
|
||||||
|
|
||||||
|
|||||||
+3
-1
@@ -18,6 +18,7 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
@@ -26,7 +27,8 @@ use crate::metrics::precision::Precision;
|
|||||||
use crate::metrics::recall::Recall;
|
use crate::metrics::recall::Recall;
|
||||||
|
|
||||||
/// F-measure
|
/// F-measure
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct F1<T: RealNumber> {
|
pub struct F1<T: RealNumber> {
|
||||||
/// a positive real factor
|
/// a positive real factor
|
||||||
pub beta: T,
|
pub beta: T,
|
||||||
|
|||||||
@@ -18,12 +18,14 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Mean Absolute Error
|
/// Mean Absolute Error
|
||||||
pub struct MeanAbsoluteError {}
|
pub struct MeanAbsoluteError {}
|
||||||
|
|
||||||
|
|||||||
@@ -18,12 +18,14 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Mean Squared Error
|
/// Mean Squared Error
|
||||||
pub struct MeanSquareError {}
|
pub struct MeanSquareError {}
|
||||||
|
|
||||||
|
|||||||
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Precision metric.
|
/// Precision metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Precision {}
|
pub struct Precision {}
|
||||||
|
|
||||||
impl Precision {
|
impl Precision {
|
||||||
|
|||||||
+3
-1
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Coefficient of Determination (R2)
|
/// Coefficient of Determination (R2)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct R2 {}
|
pub struct R2 {}
|
||||||
|
|
||||||
impl R2 {
|
impl R2 {
|
||||||
|
|||||||
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Recall metric.
|
/// Recall metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Recall {}
|
pub struct Recall {}
|
||||||
|
|
||||||
impl Recall {
|
impl Recall {
|
||||||
|
|||||||
@@ -42,10 +42,12 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for Bearnoulli features
|
/// Naive Bayes classifier for Bearnoulli features
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct BernoulliNBDistribution<T: RealNumber> {
|
struct BernoulliNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
@@ -77,7 +79,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct BernoulliNBParameters<T: RealNumber> {
|
pub struct BernoulliNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -202,7 +205,8 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
||||||
binarize: Option<T>,
|
binarize: Option<T>,
|
||||||
@@ -347,6 +351,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[1., 1., 0., 0., 0., 0.],
|
&[1., 1., 0., 0., 0., 0.],
|
||||||
|
|||||||
@@ -36,10 +36,12 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for categorical features
|
/// Naive Bayes classifier for categorical features
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct CategoricalNBDistribution<T: RealNumber> {
|
struct CategoricalNBDistribution<T: RealNumber> {
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
@@ -216,7 +218,8 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct CategoricalNBParameters<T: RealNumber> {
|
pub struct CategoricalNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -237,7 +240,8 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -345,6 +349,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[3., 4., 0., 1.],
|
&[3., 4., 0., 1.],
|
||||||
|
|||||||
@@ -30,10 +30,12 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for categorical features
|
/// Naive Bayes classifier for categorical features
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct GaussianNBDistribution<T: RealNumber> {
|
struct GaussianNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
@@ -75,7 +77,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct GaussianNBParameters<T: RealNumber> {
|
pub struct GaussianNBParameters<T: RealNumber> {
|
||||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||||
pub priors: Option<Vec<T>>,
|
pub priors: Option<Vec<T>>,
|
||||||
@@ -178,7 +181,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -277,6 +281,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[-1., -1.],
|
&[-1., -1.],
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
@@ -55,7 +56,8 @@ pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Base struct for the Naive Bayes classifier.
|
/// Base struct for the Naive Bayes classifier.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
||||||
distribution: D,
|
distribution: D,
|
||||||
_phantom_t: PhantomData<T>,
|
_phantom_t: PhantomData<T>,
|
||||||
|
|||||||
@@ -42,10 +42,12 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for Multinomial features
|
/// Naive Bayes classifier for Multinomial features
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct MultinomialNBDistribution<T: RealNumber> {
|
struct MultinomialNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
@@ -73,7 +75,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct MultinomialNBParameters<T: RealNumber> {
|
pub struct MultinomialNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -189,7 +192,8 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -320,6 +324,7 @@ mod tests {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[1., 1., 0., 0., 0., 0.],
|
&[1., 1., 0., 0., 0., 0.],
|
||||||
|
|||||||
@@ -33,6 +33,7 @@
|
|||||||
//!
|
//!
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -45,7 +46,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::neighbors::KNNWeightFunction;
|
use crate::neighbors::KNNWeightFunction;
|
||||||
|
|
||||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
/// a function that defines a distance between each pair of point in training data.
|
/// a function that defines a distance between each pair of point in training data.
|
||||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
@@ -62,7 +64,8 @@ pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// K Nearest Neighbors Classifier
|
/// K Nearest Neighbors Classifier
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
@@ -277,6 +280,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x =
|
let x =
|
||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
|||||||
@@ -36,6 +36,7 @@
|
|||||||
//!
|
//!
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -48,7 +49,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::neighbors::KNNWeightFunction;
|
use crate::neighbors::KNNWeightFunction;
|
||||||
|
|
||||||
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
/// a function that defines a distance between each pair of point in training data.
|
/// a function that defines a distance between each pair of point in training data.
|
||||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
@@ -65,7 +67,8 @@ pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// K Nearest Neighbors Regressor
|
/// K Nearest Neighbors Regressor
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
y: Vec<T>,
|
y: Vec<T>,
|
||||||
knn_algorithm: KNNAlgorithm<T, D>,
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
@@ -266,6 +269,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x =
|
let x =
|
||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
|||||||
@@ -33,6 +33,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// K Nearest Neighbors Classifier
|
/// K Nearest Neighbors Classifier
|
||||||
@@ -48,7 +49,8 @@ pub mod knn_regressor;
|
|||||||
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
||||||
|
|
||||||
/// Weight function that is used to determine estimated value.
|
/// Weight function that is used to determine estimated value.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum KNNWeightFunction {
|
pub enum KNNWeightFunction {
|
||||||
/// All k nearest points are weighted equally
|
/// All k nearest points are weighted equally
|
||||||
Uniform,
|
Uniform,
|
||||||
|
|||||||
+9
-4
@@ -26,6 +26,7 @@
|
|||||||
pub mod svc;
|
pub mod svc;
|
||||||
pub mod svr;
|
pub mod svr;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
@@ -93,18 +94,21 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Kernel
|
/// Linear Kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LinearKernel {}
|
pub struct LinearKernel {}
|
||||||
|
|
||||||
/// Radial basis function (Gaussian) kernel
|
/// Radial basis function (Gaussian) kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RBFKernel<T: RealNumber> {
|
pub struct RBFKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Polynomial kernel
|
/// Polynomial kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct PolynomialKernel<T: RealNumber> {
|
pub struct PolynomialKernel<T: RealNumber> {
|
||||||
/// degree of the polynomial
|
/// degree of the polynomial
|
||||||
pub degree: T,
|
pub degree: T,
|
||||||
@@ -115,7 +119,8 @@ pub struct PolynomialKernel<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Sigmoid (hyperbolic tangent) kernel
|
/// Sigmoid (hyperbolic tangent) kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct SigmoidKernel<T: RealNumber> {
|
pub struct SigmoidKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
|
|||||||
+14
-5
@@ -76,6 +76,7 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -85,7 +86,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// SVC Parameters
|
/// SVC Parameters
|
||||||
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
/// Number of epochs.
|
/// Number of epochs.
|
||||||
@@ -100,11 +102,15 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[serde(bound(
|
#[derive(Debug)]
|
||||||
|
#[cfg_attr(
|
||||||
|
feature = "serde",
|
||||||
|
serde(bound(
|
||||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||||
))]
|
))
|
||||||
|
)]
|
||||||
/// Support Vector Classifier
|
/// Support Vector Classifier
|
||||||
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
@@ -114,7 +120,8 @@ pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
|||||||
b: T,
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||||
index: usize,
|
index: usize,
|
||||||
x: V,
|
x: V,
|
||||||
@@ -719,6 +726,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::accuracy;
|
use crate::metrics::accuracy;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -807,6 +815,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn svc_serde() {
|
fn svc_serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
+14
-5
@@ -68,6 +68,7 @@ use std::cell::{Ref, RefCell};
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -77,7 +78,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// SVR Parameters
|
/// SVR Parameters
|
||||||
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
/// Epsilon in the epsilon-SVR model.
|
/// Epsilon in the epsilon-SVR model.
|
||||||
@@ -92,11 +94,15 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[serde(bound(
|
#[derive(Debug)]
|
||||||
|
#[cfg_attr(
|
||||||
|
feature = "serde",
|
||||||
|
serde(bound(
|
||||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||||
))]
|
))
|
||||||
|
)]
|
||||||
|
|
||||||
/// Epsilon-Support Vector Regression
|
/// Epsilon-Support Vector Regression
|
||||||
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
@@ -106,7 +112,8 @@ pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
|||||||
b: T,
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||||
index: usize,
|
index: usize,
|
||||||
x: V,
|
x: V,
|
||||||
@@ -526,6 +533,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::mean_squared_error;
|
use crate::metrics::mean_squared_error;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -562,6 +570,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn svr_serde() {
|
fn svr_serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ use std::fmt::Debug;
|
|||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -76,7 +77,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of Decision Tree
|
/// Parameters of Decision Tree
|
||||||
pub struct DecisionTreeClassifierParameters {
|
pub struct DecisionTreeClassifierParameters {
|
||||||
/// Split criteria to use when building a tree.
|
/// Split criteria to use when building a tree.
|
||||||
@@ -90,7 +92,8 @@ pub struct DecisionTreeClassifierParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Decision Tree
|
/// Decision Tree
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
@@ -100,7 +103,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The function to measure the quality of a split.
|
/// The function to measure the quality of a split.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
/// [Gini index](../decision_tree_classifier/index.html)
|
/// [Gini index](../decision_tree_classifier/index.html)
|
||||||
Gini,
|
Gini,
|
||||||
@@ -110,7 +114,8 @@ pub enum SplitCriterion {
|
|||||||
ClassificationError,
|
ClassificationError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<T: RealNumber> {
|
struct Node<T: RealNumber> {
|
||||||
index: usize,
|
index: usize,
|
||||||
output: usize,
|
output: usize,
|
||||||
@@ -740,6 +745,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[1., 1., 1., 0.],
|
&[1., 1., 1., 0.],
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -71,7 +72,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of Regression Tree
|
/// Parameters of Regression Tree
|
||||||
pub struct DecisionTreeRegressorParameters {
|
pub struct DecisionTreeRegressorParameters {
|
||||||
/// The maximum depth of the tree.
|
/// The maximum depth of the tree.
|
||||||
@@ -83,14 +85,16 @@ pub struct DecisionTreeRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Regression Tree
|
/// Regression Tree
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
depth: u16,
|
depth: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<T: RealNumber> {
|
struct Node<T: RealNumber> {
|
||||||
index: usize,
|
index: usize,
|
||||||
output: T,
|
output: T,
|
||||||
@@ -577,6 +581,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
|||||||
Reference in New Issue
Block a user