Merge potential next release v0.4 (#187) Breaking Changes
* First draft of the new n-dimensional arrays + NB use case * Improves default implementation of multiple Array methods * Refactors tree methods * Adds matrix decomposition routines * Adds matrix decomposition methods to ndarray and nalgebra bindings * Refactoring + linear regression now uses array2 * Ridge & Linear regression * LBFGS optimizer & logistic regression * LBFGS optimizer & logistic regression * Changes linear methods, metrics and model selection methods to new n-dimensional arrays * Switches KNN and clustering algorithms to new n-d array layer * Refactors distance metrics * Optimizes knn and clustering methods * Refactors metrics module * Switches decomposition methods to n-dimensional arrays * Linalg refactoring - cleanup rng merge (#172) * Remove legacy DenseMatrix and BaseMatrix implementation. Port the new Number, FloatNumber and Array implementation into module structure. * Exclude AUC metrics. Needs reimplementation * Improve developers walkthrough New traits system in place at `src/numbers` and `src/linalg` Co-authored-by: Lorenzo <tunedconsulting@gmail.com> * Provide SupervisedEstimator with a constructor to avoid explicit dynamical box allocation in 'cross_validate' and 'cross_validate_predict' as required by the use of 'dyn' as per Rust 2021 * Implement getters to use as_ref() in src/neighbors * Implement getters to use as_ref() in src/naive_bayes * Implement getters to use as_ref() in src/linear * Add Clone to src/naive_bayes * Change signature for cross_validate and other model_selection functions to abide to use of dyn in Rust 2021 * Implement ndarray-bindings. Remove FloatNumber from implementations * Drop nalgebra-bindings support (as decided in conf-call to go for ndarray) * Remove benches. Benches will have their own repo at smartcore-benches * Implement SVC * Implement SVC serialization. Move search parameters in dedicated module * Implement SVR. Definitely too slow * Fix compilation issues for wasm (#202) Co-authored-by: Luis Moreno <morenol@users.noreply.github.com> * Fix tests (#203) * Port linalg/traits/stats.rs * Improve methods naming * Improve Display for DenseMatrix Co-authored-by: Montana Low <montanalow@users.noreply.github.com> Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
This commit is contained in:
+143
-68
@@ -12,7 +12,7 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::logistic_regression::LogisticRegression;
|
||||
//! use smartcore::metrics::*;
|
||||
//!
|
||||
@@ -38,26 +38,29 @@
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! let y: Vec<i8> = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//!
|
||||
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
|
||||
//! let acc = ClassificationMetricsOrd::accuracy().get_score(&y, &y_hat);
|
||||
//! // or
|
||||
//! let acc = accuracy(&y, &y_hat);
|
||||
//! ```
|
||||
|
||||
/// Accuracy score.
|
||||
pub mod accuracy;
|
||||
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
pub mod auc;
|
||||
// TODO: reimplement AUC
|
||||
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
// pub mod auc;
|
||||
/// Compute the homogeneity, completeness and V-Measure scores.
|
||||
pub mod cluster_hcv;
|
||||
pub(crate) mod cluster_helpers;
|
||||
/// Multitude of distance metrics are defined here
|
||||
pub mod distance;
|
||||
/// F1 score, also known as balanced F-score or F-measure.
|
||||
pub mod f1;
|
||||
/// Mean absolute error regression loss.
|
||||
@@ -71,150 +74,222 @@ pub mod r2;
|
||||
/// Computes the recall.
|
||||
pub mod recall;
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A trait to be implemented by all metrics
|
||||
pub trait Metrics<T> {
|
||||
/// instantiate a new Metrics trait-object
|
||||
/// https://doc.rust-lang.org/error-index.html#E0038
|
||||
fn new() -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
/// used to instantiate metric with a paramenter
|
||||
fn new_with(_parameter: f64) -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
/// compute score realated to this metric
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64;
|
||||
}
|
||||
|
||||
/// Use these metrics to compare classification models.
|
||||
pub struct ClassificationMetrics {}
|
||||
pub struct ClassificationMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Use these metrics to compare classification models for
|
||||
/// numbers that require `Ord`.
|
||||
pub struct ClassificationMetricsOrd<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Metrics for regression models.
|
||||
pub struct RegressionMetrics {}
|
||||
pub struct RegressionMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Cluster metrics.
|
||||
pub struct ClusterMetrics {}
|
||||
|
||||
impl ClassificationMetrics {
|
||||
/// Accuracy score, see [accuracy](accuracy/index.html).
|
||||
pub fn accuracy() -> accuracy::Accuracy {
|
||||
accuracy::Accuracy {}
|
||||
}
|
||||
pub struct ClusterMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
|
||||
/// Recall, see [recall](recall/index.html).
|
||||
pub fn recall() -> recall::Recall {
|
||||
recall::Recall {}
|
||||
pub fn recall() -> recall::Recall<T> {
|
||||
recall::Recall::new()
|
||||
}
|
||||
|
||||
/// Precision, see [precision](precision/index.html).
|
||||
pub fn precision() -> precision::Precision {
|
||||
precision::Precision {}
|
||||
pub fn precision() -> precision::Precision<T> {
|
||||
precision::Precision::new()
|
||||
}
|
||||
|
||||
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
|
||||
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
|
||||
f1::F1 { beta }
|
||||
pub fn f1(beta: f64) -> f1::F1<T> {
|
||||
f1::F1::new_with(beta)
|
||||
}
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
pub fn roc_auc_score() -> auc::AUC {
|
||||
auc::AUC {}
|
||||
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
// pub fn roc_auc_score() -> auc::AUC<T> {
|
||||
// auc::AUC::<T>::new()
|
||||
// }
|
||||
}
|
||||
|
||||
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
|
||||
/// Accuracy score, see [accuracy](accuracy/index.html).
|
||||
pub fn accuracy() -> accuracy::Accuracy<T> {
|
||||
accuracy::Accuracy::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RegressionMetrics {
|
||||
impl<T: Number + FloatNumber> RegressionMetrics<T> {
|
||||
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
|
||||
mean_squared_error::MeanSquareError {}
|
||||
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError<T> {
|
||||
mean_squared_error::MeanSquareError::new()
|
||||
}
|
||||
|
||||
/// Mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
|
||||
mean_absolute_error::MeanAbsoluteError {}
|
||||
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError<T> {
|
||||
mean_absolute_error::MeanAbsoluteError::new()
|
||||
}
|
||||
|
||||
/// Coefficient of determination (R2), see [R2](r2/index.html).
|
||||
pub fn r2() -> r2::R2 {
|
||||
r2::R2 {}
|
||||
pub fn r2() -> r2::R2<T> {
|
||||
r2::R2::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClusterMetrics {
|
||||
impl<T: Number + Ord> ClusterMetrics<T> {
|
||||
/// Homogeneity and completeness and V-Measure scores at once.
|
||||
pub fn hcv_score() -> cluster_hcv::HCVScore {
|
||||
cluster_hcv::HCVScore {}
|
||||
pub fn hcv_score() -> cluster_hcv::HCVScore<T> {
|
||||
cluster_hcv::HCVScore::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn accuracy<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
|
||||
pub fn accuracy<T: Number + Ord, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
|
||||
let obj = ClassificationMetricsOrd::<T>::accuracy();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated recall score, see [recall](recall/index.html)
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn recall<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::recall().get_score(y_true, y_pred)
|
||||
pub fn recall<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::recall();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated precision score, see [precision](precision/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn precision<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::precision().get_score(y_true, y_pred)
|
||||
pub fn precision<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::precision();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes F1 score, see [F1](f1/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V, beta: T) -> T {
|
||||
ClassificationMetrics::f1(beta).get_score(y_true, y_pred)
|
||||
pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
beta: f64,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::f1(beta);
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// AUC score, see [AUC](auc/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
pub fn roc_auc_score<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
|
||||
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
|
||||
}
|
||||
// /// AUC score, see [AUC](auc/index.html).
|
||||
// /// * `y_true` - cround truth (correct) labels.
|
||||
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>(
|
||||
// y_true: &V,
|
||||
// y_pred_probabilities: &V,
|
||||
// ) -> T {
|
||||
// let obj = ClassificationMetrics::<T>::roc_auc_score();
|
||||
// obj.get_score(y_true, y_pred_probabilities)
|
||||
// }
|
||||
|
||||
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_squared_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
|
||||
pub fn mean_squared_error<T: Number + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
RegressionMetrics::<T>::mean_squared_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
|
||||
pub fn mean_absolute_error<T: Number + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
RegressionMetrics::<T>::mean_absolute_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes R2 score, see [R2](r2/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::r2().get_score(y_true, y_pred)
|
||||
pub fn r2<T: Number + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
|
||||
RegressionMetrics::<T>::r2().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Homogeneity metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
|
||||
/// A cluster result satisfies homogeneity if all of its clusters contain only data points which are members of a single class.
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn homogeneity_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.0
|
||||
pub fn homogeneity_score<
|
||||
T: Number + FloatNumber + RealNumber + Ord,
|
||||
V: ArrayView1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.homogeneity().unwrap()
|
||||
}
|
||||
|
||||
///
|
||||
/// Completeness metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn completeness_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.1
|
||||
pub fn completeness_score<
|
||||
T: Number + FloatNumber + RealNumber + Ord,
|
||||
V: ArrayView1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.completeness().unwrap()
|
||||
}
|
||||
|
||||
/// The harmonic mean between homogeneity and completeness.
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn v_measure_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.2
|
||||
pub fn v_measure_score<T: Number + FloatNumber + RealNumber + Ord, V: ArrayView1<T> + Array1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.v_measure().unwrap()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user