From a19398fd705cc2c07f3196bf6f62f04f6f2897e5 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Wed, 23 Sep 2020 16:41:39 -0700 Subject: [PATCH] fix: code and documentation cleanup --- src/linalg/lu.rs | 1 - src/math/distance/mahalanobis.rs | 2 -- src/metrics/cluster_hcv.rs | 8 ++++---- src/metrics/cluster_helpers.rs | 9 ++++++--- src/metrics/mean_absolute_error.rs | 2 -- src/metrics/mean_squared_error.rs | 2 -- src/metrics/mod.rs | 24 ++++++++++++++---------- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/linalg/lu.rs b/src/linalg/lu.rs index ee6429c..a4cc58d 100644 --- a/src/linalg/lu.rs +++ b/src/linalg/lu.rs @@ -287,7 +287,6 @@ mod tests { let expected = DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]); let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap(); - println!("{}", a_inv); assert!(a_inv.approximate_eq(&expected, 1e-4)); } } diff --git a/src/math/distance/mahalanobis.rs b/src/math/distance/mahalanobis.rs index 8f322da..6c205e5 100644 --- a/src/math/distance/mahalanobis.rs +++ b/src/math/distance/mahalanobis.rs @@ -108,8 +108,6 @@ impl> Distance, T> for Mahalanobis { ); } - println!("{}", self.sigmaInv); - let n = x.len(); let mut z = vec![T::zero(); n]; for i in 0..n { diff --git a/src/metrics/cluster_hcv.rs b/src/metrics/cluster_hcv.rs index e1b112e..bdefc8d 100644 --- a/src/metrics/cluster_hcv.rs +++ b/src/metrics/cluster_hcv.rs @@ -5,13 +5,13 @@ use crate::math::num::RealNumber; use crate::metrics::cluster_helpers::*; #[derive(Serialize, Deserialize, Debug)] -/// Mean Absolute Error +/// Homogeneity, completeness and V-Measure scores. pub struct HCVScore {} impl HCVScore { - /// Computes mean absolute error - /// * `y_true` - Ground truth (correct) target values. - /// * `y_pred` - Estimated target values. + /// Computes Homogeneity, completeness and V-Measure scores at once. + /// * `labels_true` - ground truth class labels to be used as a reference. + /// * `labels_pred` - cluster labels to evaluate. pub fn get_score>( &self, labels_true: &V, diff --git a/src/metrics/cluster_helpers.rs b/src/metrics/cluster_helpers.rs index 3086315..76cd643 100644 --- a/src/metrics/cluster_helpers.rs +++ b/src/metrics/cluster_helpers.rs @@ -106,14 +106,17 @@ mod tests { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; - println!("{:?}", contingency_matrix(&v1, &v2)); + assert_eq!( + vec!(vec!(1, 2), vec!(2, 0), vec!(1, 0), vec!(1, 0)), + contingency_matrix(&v1, &v2) + ); } #[test] fn entropy_test() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; - println!("{:?}", entropy(&v1)); + assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4); } #[test] @@ -122,6 +125,6 @@ mod tests { let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]; let s: f32 = mutual_info_score(&contingency_matrix(&v1, &v2)); - println!("{}", s); + assert!((0.3254 - s).abs() < 1e-4); } } diff --git a/src/metrics/mean_absolute_error.rs b/src/metrics/mean_absolute_error.rs index 55132cd..3e5099e 100644 --- a/src/metrics/mean_absolute_error.rs +++ b/src/metrics/mean_absolute_error.rs @@ -62,8 +62,6 @@ mod tests { let score1: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true); let score2: f64 = MeanAbsoluteError {}.get_score(&y_true, &y_true); - println!("{}", score1); - assert!((score1 - 0.5).abs() < 1e-8); assert!((score2 - 0.0).abs() < 1e-8); } diff --git a/src/metrics/mean_squared_error.rs b/src/metrics/mean_squared_error.rs index 2b4c5be..816cc70 100644 --- a/src/metrics/mean_squared_error.rs +++ b/src/metrics/mean_squared_error.rs @@ -62,8 +62,6 @@ mod tests { let score1: f64 = MeanSquareError {}.get_score(&y_pred, &y_true); let score2: f64 = MeanSquareError {}.get_score(&y_true, &y_true); - println!("{}", score1); - assert!((score1 - 0.375).abs() < 1e-8); assert!((score2 - 0.0).abs() < 1e-8); } diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 821b3f1..4fe199b 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -8,6 +8,7 @@ //! //! * [Classification metrics](struct.ClassificationMetrics.html) //! * [Regression metrics](struct.RegressionMetrics.html) +//! * [Clustering metrics](struct.ClusterMetrics.html) //! //! Example: //! ``` @@ -54,6 +55,7 @@ pub mod accuracy; /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. pub mod auc; +/// Compute the homogeneity, completeness and V-Measure scores. pub mod cluster_hcv; pub(crate) mod cluster_helpers; /// F1 score, also known as balanced F-score or F-measure. @@ -126,7 +128,7 @@ impl RegressionMetrics { } impl ClusterMetrics { - /// Mean squared error, see [mean squared error](mean_squared_error/index.html). + /// Homogeneity and completeness and V-Measure scores at once. pub fn hcv_score() -> cluster_hcv::HCVScore { cluster_hcv::HCVScore {} } @@ -188,27 +190,29 @@ pub fn r2>(y_true: &V, y_pred: &V) -> T { RegressionMetrics::r2().get_score(y_true, y_pred) } -/// Computes R2 score, see [R2](r2/index.html). -/// * `y_true` - Ground truth (correct) target values. -/// * `y_pred` - Estimated target values. +/// Homogeneity metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0). +/// A cluster result satisfies homogeneity if all of its clusters contain only data points which are members of a single class. +/// * `labels_true` - ground truth class labels to be used as a reference. +/// * `labels_pred` - cluster labels to evaluate. pub fn homogeneity_score>(labels_true: &V, labels_pred: &V) -> T { ClusterMetrics::hcv_score() .get_score(labels_true, labels_pred) .0 } -/// Computes R2 score, see [R2](r2/index.html). -/// * `y_true` - Ground truth (correct) target values. -/// * `y_pred` - Estimated target values. +/// +/// Completeness metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0). +/// * `labels_true` - ground truth class labels to be used as a reference. +/// * `labels_pred` - cluster labels to evaluate. pub fn completeness_score>(labels_true: &V, labels_pred: &V) -> T { ClusterMetrics::hcv_score() .get_score(labels_true, labels_pred) .1 } -/// Computes R2 score, see [R2](r2/index.html). -/// * `y_true` - Ground truth (correct) target values. -/// * `y_pred` - Estimated target values. +/// The harmonic mean between homogeneity and completeness. +/// * `labels_true` - ground truth class labels to be used as a reference. +/// * `labels_pred` - cluster labels to evaluate. pub fn v_measure_score>(labels_true: &V, labels_pred: &V) -> T { ClusterMetrics::hcv_score() .get_score(labels_true, labels_pred)