feat: + cluster metrics

This commit is contained in:
Volodymyr Orlov
2020-09-22 20:23:51 -07:00
parent 0803532e79
commit 750015b861
15 changed files with 477 additions and 16 deletions
+39
View File
@@ -54,6 +54,8 @@
pub mod accuracy;
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
pub mod auc;
pub mod cluster_hcv;
pub(crate) mod cluster_helpers;
/// F1 score, also known as balanced F-score or F-measure.
pub mod f1;
/// Mean absolute error regression loss.
@@ -76,6 +78,9 @@ pub struct ClassificationMetrics {}
/// Metrics for regression models.
pub struct RegressionMetrics {}
/// Cluster metrics.
pub struct ClusterMetrics {}
impl ClassificationMetrics {
/// Accuracy score, see [accuracy](accuracy/index.html).
pub fn accuracy() -> accuracy::Accuracy {
@@ -120,6 +125,13 @@ impl RegressionMetrics {
}
}
impl ClusterMetrics {
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
pub fn hcv_score() -> cluster_hcv::HCVScore {
cluster_hcv::HCVScore {}
}
}
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
/// * `y_true` - cround truth (correct) labels
/// * `y_pred` - predicted labels, as returned by a classifier.
@@ -175,3 +187,30 @@ pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred:
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::r2().get_score(y_true, y_pred)
}
/// Computes R2 score, see [R2](r2/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn homogeneity_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.0
}
/// Computes R2 score, see [R2](r2/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn completeness_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.1
}
/// Computes R2 score, see [R2](r2/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn v_measure_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.2
}