fix: metric parameter name changed
This commit is contained in:
@@ -7,12 +7,12 @@ use crate::math::num::FloatExt;
|
|||||||
pub struct Accuracy {}
|
pub struct Accuracy {}
|
||||||
|
|
||||||
impl Accuracy {
|
impl Accuracy {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_pred.len() {
|
||||||
panic!(
|
panic!(
|
||||||
"The vector sizes don't match: {} != {}",
|
"The vector sizes don't match: {} != {}",
|
||||||
y_true.len(),
|
y_true.len(),
|
||||||
y_prod.len()
|
y_pred.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ impl Accuracy {
|
|||||||
|
|
||||||
let mut positive = 0;
|
let mut positive = 0;
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
if y_true.get(i) == y_prod.get(i) {
|
if y_true.get(i) == y_pred.get(i) {
|
||||||
positive += 1;
|
positive += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+5
-5
@@ -9,18 +9,18 @@ use crate::metrics::recall::Recall;
|
|||||||
pub struct F1 {}
|
pub struct F1 {}
|
||||||
|
|
||||||
impl F1 {
|
impl F1 {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_pred.len() {
|
||||||
panic!(
|
panic!(
|
||||||
"The vector sizes don't match: {} != {}",
|
"The vector sizes don't match: {} != {}",
|
||||||
y_true.len(),
|
y_true.len(),
|
||||||
y_prod.len()
|
y_pred.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let beta2 = T::one();
|
let beta2 = T::one();
|
||||||
|
|
||||||
let p = Precision {}.get_score(y_true, y_prod);
|
let p = Precision {}.get_score(y_true, y_pred);
|
||||||
let r = Recall {}.get_score(y_true, y_prod);
|
let r = Recall {}.get_score(y_true, y_pred);
|
||||||
|
|
||||||
(T::one() + beta2) * (p * r) / (beta2 * p + r)
|
(T::one() + beta2) * (p * r) / (beta2 * p + r)
|
||||||
}
|
}
|
||||||
|
|||||||
+22
-6
@@ -21,16 +21,32 @@ impl ClassificationMetrics {
|
|||||||
pub fn precision() -> precision::Precision {
|
pub fn precision() -> precision::Precision {
|
||||||
precision::Precision {}
|
precision::Precision {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn f1() -> f1::F1 {
|
||||||
|
f1::F1 {}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
|
pub fn roc_auc_score() -> auc::AUC {
|
||||||
ClassificationMetrics::accuracy().get_score(y_true, y_prod)
|
auc::AUC {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn recall<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
|
pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||||
ClassificationMetrics::recall().get_score(y_true, y_prod)
|
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn precision<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T {
|
pub fn recall<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||||
ClassificationMetrics::precision().get_score(y_true, y_prod)
|
ClassificationMetrics::recall().get_score(y_true, y_pred)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn precision<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||||
|
ClassificationMetrics::precision().get_score(y_true, y_pred)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn f1<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||||
|
ClassificationMetrics::f1().get_score(y_true, y_pred)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn roc_auc_score<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
|
||||||
|
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ use crate::math::num::FloatExt;
|
|||||||
pub struct Precision {}
|
pub struct Precision {}
|
||||||
|
|
||||||
impl Precision {
|
impl Precision {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_pred.len() {
|
||||||
panic!(
|
panic!(
|
||||||
"The vector sizes don't match: {} != {}",
|
"The vector sizes don't match: {} != {}",
|
||||||
y_true.len(),
|
y_true.len(),
|
||||||
y_prod.len()
|
y_pred.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,14 +27,14 @@ impl Precision {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||||
panic!(
|
panic!(
|
||||||
"Precision can only be applied to binary classification: {}",
|
"Precision can only be applied to binary classification: {}",
|
||||||
y_prod.get(i)
|
y_pred.get(i)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) == T::one() {
|
if y_pred.get(i) == T::one() {
|
||||||
p += 1;
|
p += 1;
|
||||||
|
|
||||||
if y_true.get(i) == T::one() {
|
if y_true.get(i) == T::one() {
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ use crate::math::num::FloatExt;
|
|||||||
pub struct Recall {}
|
pub struct Recall {}
|
||||||
|
|
||||||
impl Recall {
|
impl Recall {
|
||||||
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T {
|
pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||||
if y_true.len() != y_prod.len() {
|
if y_true.len() != y_pred.len() {
|
||||||
panic!(
|
panic!(
|
||||||
"The vector sizes don't match: {} != {}",
|
"The vector sizes don't match: {} != {}",
|
||||||
y_true.len(),
|
y_true.len(),
|
||||||
y_prod.len()
|
y_pred.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -27,17 +27,17 @@ impl Recall {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() {
|
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||||
panic!(
|
panic!(
|
||||||
"Recall can only be applied to binary classification: {}",
|
"Recall can only be applied to binary classification: {}",
|
||||||
y_prod.get(i)
|
y_pred.get(i)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if y_true.get(i) == T::one() {
|
if y_true.get(i) == T::one() {
|
||||||
p += 1;
|
p += 1;
|
||||||
|
|
||||||
if y_prod.get(i) == T::one() {
|
if y_pred.get(i) == T::one() {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user