feat: adds 3 new regression metrics

This commit is contained in:
Volodymyr Orlov
2020-06-10 17:06:34 -07:00
parent 61b404afea
commit 2ece181e08
5 changed files with 185 additions and 0 deletions
+31
View File
@@ -1,7 +1,10 @@
pub mod accuracy;
pub mod auc;
pub mod f1;
pub mod mean_absolute_error;
pub mod mean_squared_error;
pub mod precision;
pub mod r2;
pub mod recall;
use crate::linalg::BaseVector;
@@ -9,6 +12,8 @@ use crate::math::num::FloatExt;
pub struct ClassificationMetrics {}
pub struct RegressionMetrics {}
impl ClassificationMetrics {
pub fn accuracy() -> accuracy::Accuracy {
accuracy::Accuracy {}
@@ -31,6 +36,20 @@ impl ClassificationMetrics {
}
}
impl RegressionMetrics {
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
mean_squared_error::MeanSquareError {}
}
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
mean_absolute_error::MeanAbsoluteError {}
}
pub fn r2() -> r2::R2 {
r2::R2 {}
}
}
pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
}
@@ -50,3 +69,15 @@ pub fn f1<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
pub fn roc_auc_score<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
}
pub fn mean_squared_error<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
}
pub fn mean_absolute_error<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
}
pub fn r2<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::r2().get_score(y_true, y_pred)
}