Fix metrics::auc (#212)

* Fix metrics::auc
This commit is contained in:
Lorenzo
2022-11-01 12:50:46 +00:00
committed by morenol
parent a16927aa16
commit 4d36b7f34f
3 changed files with 12 additions and 15 deletions
+1 -1
View File
@@ -174,7 +174,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
///
/// Brute force algorithm, used only for comparison and testing
///
#[cfg(feature = "fp_bench")]
#[allow(dead_code)]
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
use itertools::Itertools;
let m = self.samples.shape().0;
+10 -13
View File
@@ -26,8 +26,8 @@ use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
@@ -38,14 +38,14 @@ pub struct AUC<T> {
_phantom: PhantomData<T>,
}
impl<T: Number + Ord> Metrics<T> for AUC<T> {
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
/// create a typed object to call AUC functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: T) -> Self {
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
@@ -53,11 +53,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
/// AUC score.
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
fn get_score(
&self,
y_true: &dyn ArrayView1<T>,
y_pred_prob: &dyn ArrayView1<T>,
) -> f64 {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
let mut pos = T::zero();
let mut neg = T::zero();
@@ -76,9 +72,10 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
}
}
let y_pred = y_pred_prob.clone();
let label_idx = y_pred.argsort();
let y_pred: Vec<T> =
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
// TODO: try to use `crate::algorithm::sort::quick_sort` here
let label_idx: Vec<usize> = y_pred.argsort();
let mut rank = vec![0f64; n];
let mut i = 0;
@@ -108,7 +105,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
let pos = pos.to_f64().unwrap();
let neg = neg.to_f64().unwrap();
T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap()
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
}
}
+1 -1
View File
@@ -55,7 +55,7 @@
pub mod accuracy;
// TODO: reimplement AUC
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
// pub mod auc;
pub mod auc;
/// Compute the homogeneity, completeness and V-Measure scores.
pub mod cluster_hcv;
pub(crate) mod cluster_helpers;