Fix metrics::auc (#212)

* Fix metrics::auc
This commit is contained in:
Lorenzo
2022-11-01 12:50:46 +00:00
committed by morenol
parent a16927aa16
commit 4d36b7f34f
3 changed files with 12 additions and 15 deletions
+1 -1
View File
@@ -174,7 +174,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
/// ///
/// Brute force algorithm, used only for comparison and testing /// Brute force algorithm, used only for comparison and testing
/// ///
#[cfg(feature = "fp_bench")] #[allow(dead_code)]
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> { pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
use itertools::Itertools; use itertools::Itertools;
let m = self.samples.shape().0; let m = self.samples.shape().0;
+10 -13
View File
@@ -26,8 +26,8 @@ use std::marker::PhantomData;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1}; use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics; use crate::metrics::Metrics;
@@ -38,14 +38,14 @@ pub struct AUC<T> {
_phantom: PhantomData<T>, _phantom: PhantomData<T>,
} }
impl<T: Number + Ord> Metrics<T> for AUC<T> { impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
/// create a typed object to call AUC functions /// create a typed object to call AUC functions
fn new() -> Self { fn new() -> Self {
Self { Self {
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
fn new_with(_parameter: T) -> Self { fn new_with(_parameter: f64) -> Self {
Self { Self {
_phantom: PhantomData, _phantom: PhantomData,
} }
@@ -53,11 +53,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
/// AUC score. /// AUC score.
/// * `y_true` - ground truth (correct) labels. /// * `y_true` - ground truth (correct) labels.
/// * `y_pred_prob` - probability estimates, as returned by a classifier. /// * `y_pred_prob` - probability estimates, as returned by a classifier.
fn get_score( fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
&self,
y_true: &dyn ArrayView1<T>,
y_pred_prob: &dyn ArrayView1<T>,
) -> f64 {
let mut pos = T::zero(); let mut pos = T::zero();
let mut neg = T::zero(); let mut neg = T::zero();
@@ -76,9 +72,10 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
} }
} }
let y_pred = y_pred_prob.clone(); let y_pred: Vec<T> =
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
let label_idx = y_pred.argsort(); // TODO: try to use `crate::algorithm::sort::quick_sort` here
let label_idx: Vec<usize> = y_pred.argsort();
let mut rank = vec![0f64; n]; let mut rank = vec![0f64; n];
let mut i = 0; let mut i = 0;
@@ -108,7 +105,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
let pos = pos.to_f64().unwrap(); let pos = pos.to_f64().unwrap();
let neg = neg.to_f64().unwrap(); let neg = neg.to_f64().unwrap();
T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap() (auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
} }
} }
+1 -1
View File
@@ -55,7 +55,7 @@
pub mod accuracy; pub mod accuracy;
// TODO: reimplement AUC // TODO: reimplement AUC
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. // /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
// pub mod auc; pub mod auc;
/// Compute the homogeneity, completeness and V-Measure scores. /// Compute the homogeneity, completeness and V-Measure scores.
pub mod cluster_hcv; pub mod cluster_hcv;
pub(crate) mod cluster_helpers; pub(crate) mod cluster_helpers;