+10
-13
@@ -26,8 +26,8 @@ use std::marker::PhantomData;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
@@ -38,14 +38,14 @@ pub struct AUC<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
|
||||
/// create a typed object to call AUC functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: T) -> Self {
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
@@ -53,11 +53,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
/// AUC score.
|
||||
/// * `y_true` - ground truth (correct) labels.
|
||||
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
|
||||
fn get_score(
|
||||
&self,
|
||||
y_true: &dyn ArrayView1<T>,
|
||||
y_pred_prob: &dyn ArrayView1<T>,
|
||||
) -> f64 {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
|
||||
let mut pos = T::zero();
|
||||
let mut neg = T::zero();
|
||||
|
||||
@@ -76,9 +72,10 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
}
|
||||
}
|
||||
|
||||
let y_pred = y_pred_prob.clone();
|
||||
|
||||
let label_idx = y_pred.argsort();
|
||||
let y_pred: Vec<T> =
|
||||
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
|
||||
// TODO: try to use `crate::algorithm::sort::quick_sort` here
|
||||
let label_idx: Vec<usize> = y_pred.argsort();
|
||||
|
||||
let mut rank = vec![0f64; n];
|
||||
let mut i = 0;
|
||||
@@ -108,7 +105,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
let pos = pos.to_f64().unwrap();
|
||||
let neg = neg.to_f64().unwrap();
|
||||
|
||||
T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap()
|
||||
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user