From 61b404afea9b4453951525caf40bdb4383771c66 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Mon, 8 Jun 2020 15:02:51 -0700 Subject: [PATCH] fix: metric parameter name changed --- src/metrics/accuracy.rs | 8 ++++---- src/metrics/f1.rs | 10 +++++----- src/metrics/mod.rs | 28 ++++++++++++++++++++++------ src/metrics/precision.rs | 12 ++++++------ src/metrics/recall.rs | 12 ++++++------ 5 files changed, 43 insertions(+), 27 deletions(-) diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs index da6cd48..a1695d0 100644 --- a/src/metrics/accuracy.rs +++ b/src/metrics/accuracy.rs @@ -7,12 +7,12 @@ use crate::math::num::FloatExt; pub struct Accuracy {} impl Accuracy { - pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { - if y_true.len() != y_prod.len() { + pub fn get_score>(&self, y_true: &V, y_pred: &V) -> T { + if y_true.len() != y_pred.len() { panic!( "The vector sizes don't match: {} != {}", y_true.len(), - y_prod.len() + y_pred.len() ); } @@ -20,7 +20,7 @@ impl Accuracy { let mut positive = 0; for i in 0..n { - if y_true.get(i) == y_prod.get(i) { + if y_true.get(i) == y_pred.get(i) { positive += 1; } } diff --git a/src/metrics/f1.rs b/src/metrics/f1.rs index 1313322..a1af664 100644 --- a/src/metrics/f1.rs +++ b/src/metrics/f1.rs @@ -9,18 +9,18 @@ use crate::metrics::recall::Recall; pub struct F1 {} impl F1 { - pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { - if y_true.len() != y_prod.len() { + pub fn get_score>(&self, y_true: &V, y_pred: &V) -> T { + if y_true.len() != y_pred.len() { panic!( "The vector sizes don't match: {} != {}", y_true.len(), - y_prod.len() + y_pred.len() ); } let beta2 = T::one(); - let p = Precision {}.get_score(y_true, y_prod); - let r = Recall {}.get_score(y_true, y_prod); + let p = Precision {}.get_score(y_true, y_pred); + let r = Recall {}.get_score(y_true, y_pred); (T::one() + beta2) * (p * r) / (beta2 * p + r) } diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 30cc6ce..7a062f1 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -21,16 +21,32 @@ impl ClassificationMetrics { pub fn precision() -> precision::Precision { precision::Precision {} } + + pub fn f1() -> f1::F1 { + f1::F1 {} + } + + pub fn roc_auc_score() -> auc::AUC { + auc::AUC {} + } } -pub fn accuracy>(y_true: &V, y_prod: &V) -> T { - ClassificationMetrics::accuracy().get_score(y_true, y_prod) +pub fn accuracy>(y_true: &V, y_pred: &V) -> T { + ClassificationMetrics::accuracy().get_score(y_true, y_pred) } -pub fn recall>(y_true: &V, y_prod: &V) -> T { - ClassificationMetrics::recall().get_score(y_true, y_prod) +pub fn recall>(y_true: &V, y_pred: &V) -> T { + ClassificationMetrics::recall().get_score(y_true, y_pred) } -pub fn precision>(y_true: &V, y_prod: &V) -> T { - ClassificationMetrics::precision().get_score(y_true, y_prod) +pub fn precision>(y_true: &V, y_pred: &V) -> T { + ClassificationMetrics::precision().get_score(y_true, y_pred) +} + +pub fn f1>(y_true: &V, y_pred: &V) -> T { + ClassificationMetrics::f1().get_score(y_true, y_pred) +} + +pub fn roc_auc_score>(y_true: &V, y_pred_probabilities: &V) -> T { + ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities) } diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index 3a285da..b3e6c72 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -7,12 +7,12 @@ use crate::math::num::FloatExt; pub struct Precision {} impl Precision { - pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { - if y_true.len() != y_prod.len() { + pub fn get_score>(&self, y_true: &V, y_pred: &V) -> T { + if y_true.len() != y_pred.len() { panic!( "The vector sizes don't match: {} != {}", y_true.len(), - y_prod.len() + y_pred.len() ); } @@ -27,14 +27,14 @@ impl Precision { ); } - if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { + if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { panic!( "Precision can only be applied to binary classification: {}", - y_prod.get(i) + y_pred.get(i) ); } - if y_prod.get(i) == T::one() { + if y_pred.get(i) == T::one() { p += 1; if y_true.get(i) == T::one() { diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 1af9e45..14e91ee 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -7,12 +7,12 @@ use crate::math::num::FloatExt; pub struct Recall {} impl Recall { - pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { - if y_true.len() != y_prod.len() { + pub fn get_score>(&self, y_true: &V, y_pred: &V) -> T { + if y_true.len() != y_pred.len() { panic!( "The vector sizes don't match: {} != {}", y_true.len(), - y_prod.len() + y_pred.len() ); } @@ -27,17 +27,17 @@ impl Recall { ); } - if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { + if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { panic!( "Recall can only be applied to binary classification: {}", - y_prod.get(i) + y_pred.get(i) ); } if y_true.get(i) == T::one() { p += 1; - if y_prod.get(i) == T::one() { + if y_pred.get(i) == T::one() { tp += 1; } }