feat: documents metric functions
This commit is contained in:
@@ -1,12 +1,33 @@
|
||||
//! # Accuracy score
|
||||
//!
|
||||
//! Calculates accuracy of predictions \\(\hat{y}\\) when compared to true labels \\(y\\)
|
||||
//!
|
||||
//! \\[ accuracy(y, \hat{y}) = \frac{1}{n_{samples}} \sum_{i=1}^{n_{samples}} 1(y_i = \hat{y_i}) \\]
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::accuracy::Accuracy;
|
||||
//! let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
||||
//! let y_true: Vec<f64> = vec![0., 1., 2., 3.];
|
||||
//!
|
||||
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Accuracy metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Accuracy {}
|
||||
|
||||
impl Accuracy {
|
||||
/// Function that calculated accuracy score.
|
||||
/// * `y_true` - cround truth (correct) labels
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
@@ -1,3 +1,23 @@
|
||||
//! # Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||
//! Computes the area under the receiver operating characteristic (ROC) curve that is equal to the probability that a classifier will rank a
|
||||
//! randomly chosen positive instance higher than a randomly chosen negative one.
|
||||
//!
|
||||
//! SmartCore calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::metrics::auc::AUC;
|
||||
//!
|
||||
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
//! let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
|
||||
//!
|
||||
//! let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["Areas beneath the relative operating characteristics (ROC) and relative operating levels (ROL) curves: Statistical significance and interpretation", Mason S. J., Graham N. E.](http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.458.8392)
|
||||
//! * [Wikipedia article on ROC AUC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve)
|
||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
+32
-6
@@ -1,3 +1,22 @@
|
||||
//! # F-measure
|
||||
//!
|
||||
//! Harmonic mean of the precision and recall.
|
||||
//!
|
||||
//! \\[f1 = (1 + \beta^2)\frac{precision \times recall}{\beta^2 \times precision + recall}\\]
|
||||
//!
|
||||
//! where \\(\beta \\) is a positive real factor, where \\(\beta \\) is chosen such that recall is considered \\(\beta \\) times as important as precision.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::f1::F1;
|
||||
//! let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
//! let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
//!
|
||||
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
@@ -5,11 +24,18 @@ use crate::math::num::RealNumber;
|
||||
use crate::metrics::precision::Precision;
|
||||
use crate::metrics::recall::Recall;
|
||||
|
||||
/// F-measure
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct F1 {}
|
||||
pub struct F1<T: RealNumber> {
|
||||
/// a positive real factor
|
||||
pub beta: T,
|
||||
}
|
||||
|
||||
impl F1 {
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
impl<T: RealNumber> F1<T> {
|
||||
/// Computes F1 score
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
"The vector sizes don't match: {} != {}",
|
||||
@@ -17,7 +43,7 @@ impl F1 {
|
||||
y_pred.len()
|
||||
);
|
||||
}
|
||||
let beta2 = T::one();
|
||||
let beta2 = self.beta * self.beta;
|
||||
|
||||
let p = Precision {}.get_score(y_true, y_pred);
|
||||
let r = Recall {}.get_score(y_true, y_pred);
|
||||
@@ -35,8 +61,8 @@ mod tests {
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score1: f64 = F1 {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = F1 {}.get_score(&y_true, &y_true);
|
||||
let score1: f64 = F1 { beta: 1.0 }.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = F1 { beta: 1.0 }.get_score(&y_true, &y_true);
|
||||
|
||||
assert!((score1 - 0.57142857).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
//! # Mean Absolute Error
|
||||
//!
|
||||
//! MAE measures the average magnitude of the errors in a set of predictions, without considering their direction.
|
||||
//!
|
||||
//! \\[mse(y, \hat{y}) = \frac{1}{n_{samples}} \sum_{i=1}^{n_{samples}} \lvert y_i - \hat{y_i} \rvert \\]
|
||||
//!
|
||||
//! where \\(\hat{y}\\) are predictions and \\(y\\) are true target values.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
|
||||
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
//!
|
||||
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// Mean Absolute Error
|
||||
pub struct MeanAbsoluteError {}
|
||||
|
||||
impl MeanAbsoluteError {
|
||||
/// Computes mean absolute error
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
//! # Mean Squared Error
|
||||
//!
|
||||
//! MSE measures the average magnitude of the errors in a set of predictions, without considering their direction.
|
||||
//!
|
||||
//! \\[mse(y, \hat{y}) = \frac{1}{n_{samples}} \sum_{i=1}^{n_{samples}} (y_i - \hat{y_i})^2 \\]
|
||||
//!
|
||||
//! where \\(\hat{y}\\) are predictions and \\(y\\) are true target values.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::mean_squared_error::MeanSquareError;
|
||||
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
//!
|
||||
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// Mean Squared Error
|
||||
pub struct MeanSquareError {}
|
||||
|
||||
impl MeanSquareError {
|
||||
/// Computes mean squared error
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
+98
-4
@@ -1,83 +1,177 @@
|
||||
//! # Metric functions
|
||||
//!
|
||||
//! One way to build machine learning models is to use a constructive feedback loop through model evaluation.
|
||||
//! In a feedback loop you build your model first, then you get feedback from metrics, improve it and repeat until your model achieve desirable performance.
|
||||
//! Evaluation metrics helps to explain the performance of a model and compare models based on an objective criterion.
|
||||
//!
|
||||
//! Choosing the right metric is crucial while evaluating machine learning models. In SmartCore you will find metrics for these classes of ML models:
|
||||
//!
|
||||
//! * [Classification metrics](struct.ClassificationMetrics.html)
|
||||
//! * [Regression metrics](struct.RegressionMetrics.html)
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linear::logistic_regression::LogisticRegression;
|
||||
//! use smartcore::metrics::*;
|
||||
//!
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y);
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//!
|
||||
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
|
||||
//! // or
|
||||
//! let acc = accuracy(&y, &y_hat);
|
||||
//! ```
|
||||
|
||||
/// Accuracy score.
|
||||
pub mod accuracy;
|
||||
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
pub mod auc;
|
||||
/// F1 score, also known as balanced F-score or F-measure.
|
||||
pub mod f1;
|
||||
/// Mean absolute error regression loss.
|
||||
pub mod mean_absolute_error;
|
||||
/// Mean squared error regression loss.
|
||||
pub mod mean_squared_error;
|
||||
/// Computes the precision.
|
||||
pub mod precision;
|
||||
/// Coefficient of determination (R2).
|
||||
pub mod r2;
|
||||
/// Computes the recall.
|
||||
pub mod recall;
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Use these metrics to compare classification models.
|
||||
pub struct ClassificationMetrics {}
|
||||
|
||||
/// Metrics for regression models.
|
||||
pub struct RegressionMetrics {}
|
||||
|
||||
impl ClassificationMetrics {
|
||||
/// Accuracy score, see [accuracy](accuracy/index.html).
|
||||
pub fn accuracy() -> accuracy::Accuracy {
|
||||
accuracy::Accuracy {}
|
||||
}
|
||||
|
||||
/// Recall, see [recall](recall/index.html).
|
||||
pub fn recall() -> recall::Recall {
|
||||
recall::Recall {}
|
||||
}
|
||||
|
||||
/// Precision, see [precision](precision/index.html).
|
||||
pub fn precision() -> precision::Precision {
|
||||
precision::Precision {}
|
||||
}
|
||||
|
||||
pub fn f1() -> f1::F1 {
|
||||
f1::F1 {}
|
||||
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
|
||||
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
|
||||
f1::F1 { beta: beta }
|
||||
}
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
pub fn roc_auc_score() -> auc::AUC {
|
||||
auc::AUC {}
|
||||
}
|
||||
}
|
||||
|
||||
impl RegressionMetrics {
|
||||
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
|
||||
mean_squared_error::MeanSquareError {}
|
||||
}
|
||||
|
||||
/// Mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
|
||||
mean_absolute_error::MeanAbsoluteError {}
|
||||
}
|
||||
|
||||
/// Coefficient of determination (R2), see [R2](r2/index.html).
|
||||
pub fn r2() -> r2::R2 {
|
||||
r2::R2 {}
|
||||
}
|
||||
}
|
||||
|
||||
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn accuracy<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated recall score, see [recall](recall/index.html)
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn recall<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::recall().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated precision score, see [precision](precision/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn precision<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::precision().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::f1().get_score(y_true, y_pred)
|
||||
/// Computes F1 score, see [F1](f1/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V, beta: T) -> T {
|
||||
ClassificationMetrics::f1(beta).get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// AUC score, see [AUC](auc/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
pub fn roc_auc_score<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
|
||||
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
|
||||
}
|
||||
|
||||
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_squared_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes R2 score, see [R2](r2/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::r2().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
//! # Precision score
|
||||
//!
|
||||
//! How many predicted items are relevant?
|
||||
//!
|
||||
//! \\[precision = \frac{tp}{tp + fp}\\]
|
||||
//!
|
||||
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::precision::Precision;
|
||||
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
|
||||
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
//!
|
||||
//! let score: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Precision metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Precision {}
|
||||
|
||||
impl Precision {
|
||||
/// Calculated precision score
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
//! Coefficient of Determination (R2)
|
||||
//!
|
||||
//! Coefficient of determination, denoted R2 is the proportion of the variance in the dependent variable that can be explained be explanatory (independent) variable(s).
|
||||
//!
|
||||
//! \\[R^2(y, \hat{y}) = 1 - \frac{\sum_{i=1}^{n}(y_i - \hat{y_i})^2}{\sum_{i=1}^{n}(y_i - \bar{y})^2} \\]
|
||||
//!
|
||||
//! where \\(\hat{y}\\) are predictions, \\(y\\) are true target values, \\(\bar{y}\\) is the mean of the observed data
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
|
||||
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
//!
|
||||
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Coefficient of Determination (R2)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct R2 {}
|
||||
|
||||
impl R2 {
|
||||
/// Computes R2 score
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
@@ -1,12 +1,35 @@
|
||||
//! # Recall score
|
||||
//!
|
||||
//! How many relevant items are selected?
|
||||
//!
|
||||
//! \\[recall = \frac{tp}{tp + fn}\\]
|
||||
//!
|
||||
//! where tp (true positive) - correct result, fn (false negative) - missing result
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::recall::Recall;
|
||||
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
|
||||
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
//!
|
||||
//! let score: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Recall metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Recall {}
|
||||
|
||||
impl Recall {
|
||||
/// Calculated recall score
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
panic!(
|
||||
|
||||
Reference in New Issue
Block a user