From 2ece181e08a33d254321011696a6e24413e94501 Mon Sep 17 00:00:00 2001
From: Volodymyr Orlov
Date: Wed, 10 Jun 2020 17:06:34 -0700
Subject: [PATCH] feat: adds 3 new regression metrics

---
Review notes (commentary only; placed after the `---` separator, so not part
of the commit message):

- NOTE(review): the generic parameter lists (`<T: FloatExt, V: BaseVector<T>>`)
  and the `Vec<f64>` annotations in the tests were reconstructed from the
  imported names and the `f64` score bindings, and blank lines were placed so
  that every hunk matches its `@@` line counts. Confirm against the
  repository before applying.
- Every `get_score` panics when `y_true` and `y_pred` differ in length.
- MeanAbsoluteError and MeanSquareError divide by `T::from_usize(n)` and R2
  divides by `ss_tot`; neither divisor is checked for zero, so empty input
  (and, for R2, a `y_true` whose values all equal their mean) is not guarded
  against.

 src/math/num.rs                    |  4 +++
 src/metrics/mean_absolute_error.rs | 46 ++++++++++++++++++++++++
 src/metrics/mean_squared_error.rs  | 46 ++++++++++++++++++++++++
 src/metrics/mod.rs                 | 31 ++++++++++++++++
 src/metrics/r2.rs                  | 58 ++++++++++++++++++++++++++++++
 5 files changed, 185 insertions(+)
 create mode 100644 src/metrics/mean_absolute_error.rs
 create mode 100644 src/metrics/mean_squared_error.rs
 create mode 100644 src/metrics/r2.rs

diff --git a/src/math/num.rs b/src/math/num.rs
index 25774de..4623779 100644
--- a/src/math/num.rs
+++ b/src/math/num.rs
@@ -14,6 +14,10 @@ pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy {
     fn two() -> Self;
 
     fn half() -> Self;
+
+    fn square(self) -> Self {
+        self * self
+    }
 }
 
 impl FloatExt for f64 {
diff --git a/src/metrics/mean_absolute_error.rs b/src/metrics/mean_absolute_error.rs
new file mode 100644
index 0000000..6448ea7
--- /dev/null
+++ b/src/metrics/mean_absolute_error.rs
@@ -0,0 +1,46 @@
+use serde::{Deserialize, Serialize};
+
+use crate::linalg::BaseVector;
+use crate::math::num::FloatExt;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct MeanAbsoluteError {}
+
+impl MeanAbsoluteError {
+    pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
+        if y_true.len() != y_pred.len() {
+            panic!(
+                "The vector sizes don't match: {} != {}",
+                y_true.len(),
+                y_pred.len()
+            );
+        }
+
+        let n = y_true.len();
+        let mut ras = T::zero();
+        for i in 0..n {
+            ras = ras + (y_true.get(i) - y_pred.get(i)).abs();
+        }
+
+        ras / T::from_usize(n).unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn mean_absolute_error() {
+        let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+        let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
+
+        let score1: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
+        let score2: f64 = MeanAbsoluteError {}.get_score(&y_true, &y_true);
+
+        println!("{}", score1);
+
+        assert!((score1 - 0.5).abs() < 1e-8);
+        assert!((score2 - 0.0).abs() < 1e-8);
+    }
+}
diff --git a/src/metrics/mean_squared_error.rs b/src/metrics/mean_squared_error.rs
new file mode 100644
index 0000000..37ecf73
--- /dev/null
+++ b/src/metrics/mean_squared_error.rs
@@ -0,0 +1,46 @@
+use serde::{Deserialize, Serialize};
+
+use crate::linalg::BaseVector;
+use crate::math::num::FloatExt;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct MeanSquareError {}
+
+impl MeanSquareError {
+    pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
+        if y_true.len() != y_pred.len() {
+            panic!(
+                "The vector sizes don't match: {} != {}",
+                y_true.len(),
+                y_pred.len()
+            );
+        }
+
+        let n = y_true.len();
+        let mut rss = T::zero();
+        for i in 0..n {
+            rss = rss + (y_true.get(i) - y_pred.get(i)).square();
+        }
+
+        rss / T::from_usize(n).unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn mean_squared_error() {
+        let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+        let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
+
+        let score1: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
+        let score2: f64 = MeanSquareError {}.get_score(&y_true, &y_true);
+
+        println!("{}", score1);
+
+        assert!((score1 - 0.375).abs() < 1e-8);
+        assert!((score2 - 0.0).abs() < 1e-8);
+    }
+}
diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs
index 7a062f1..2ae6464 100644
--- a/src/metrics/mod.rs
+++ b/src/metrics/mod.rs
@@ -1,7 +1,10 @@
 pub mod accuracy;
 pub mod auc;
 pub mod f1;
+pub mod mean_absolute_error;
+pub mod mean_squared_error;
 pub mod precision;
+pub mod r2;
 pub mod recall;
 
 use crate::linalg::BaseVector;
@@ -9,6 +12,8 @@ use crate::math::num::FloatExt;
 
 pub struct ClassificationMetrics {}
 
+pub struct RegressionMetrics {}
+
 impl ClassificationMetrics {
     pub fn accuracy() -> accuracy::Accuracy {
         accuracy::Accuracy {}
@@ -31,6 +36,20 @@ impl ClassificationMetrics {
     }
 }
 
+impl RegressionMetrics {
+    pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
+        mean_squared_error::MeanSquareError {}
+    }
+
+    pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
+        mean_absolute_error::MeanAbsoluteError {}
+    }
+
+    pub fn r2() -> r2::R2 {
+        r2::R2 {}
+    }
+}
+
 pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
     ClassificationMetrics::accuracy().get_score(y_true, y_pred)
 }
@@ -50,3 +69,15 @@ pub fn f1<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
 pub fn roc_auc_score<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
     ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
 }
+
+pub fn mean_squared_error<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
+    RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
+}
+
+pub fn mean_absolute_error<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
+    RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
+}
+
+pub fn r2<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
+    RegressionMetrics::r2().get_score(y_true, y_pred)
+}
diff --git a/src/metrics/r2.rs b/src/metrics/r2.rs
new file mode 100644
index 0000000..b823bc6
--- /dev/null
+++ b/src/metrics/r2.rs
@@ -0,0 +1,58 @@
+use serde::{Deserialize, Serialize};
+
+use crate::linalg::BaseVector;
+use crate::math::num::FloatExt;
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct R2 {}
+
+impl R2 {
+    pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
+        if y_true.len() != y_pred.len() {
+            panic!(
+                "The vector sizes don't match: {} != {}",
+                y_true.len(),
+                y_pred.len()
+            );
+        }
+
+        let n = y_true.len();
+
+        let mut mean = T::zero();
+
+        for i in 0..n {
+            mean = mean + y_true.get(i);
+        }
+
+        mean = mean / T::from_usize(n).unwrap();
+
+        let mut ss_tot = T::zero();
+        let mut ss_res = T::zero();
+
+        for i in 0..n {
+            let y_i = y_true.get(i);
+            let f_i = y_pred.get(i);
+            ss_tot = ss_tot + (y_i - mean).square();
+            ss_res = ss_res + (y_i - f_i).square();
+        }
+
+        T::one() - (ss_res / ss_tot)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn r2() {
+        let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+        let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
+
+        let score1: f64 = R2 {}.get_score(&y_true, &y_pred);
+        let score2: f64 = R2 {}.get_score(&y_true, &y_true);
+
+        assert!((score1 - 0.948608137).abs() < 1e-8);
+        assert!((score2 - 1.0).abs() < 1e-8);
+    }
+}