@@ -174,7 +174,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
#[cfg(feature = "fp_bench")]
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||
use itertools::Itertools;
|
||||
let m = self.samples.shape().0;
|
||||
|
||||
+10
-13
@@ -26,8 +26,8 @@ use std::marker::PhantomData;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
@@ -38,14 +38,14 @@ pub struct AUC<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
|
||||
/// create a typed object to call AUC functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: T) -> Self {
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
@@ -53,11 +53,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
/// AUC score.
|
||||
/// * `y_true` - ground truth (correct) labels.
|
||||
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
|
||||
fn get_score(
|
||||
&self,
|
||||
y_true: &dyn ArrayView1<T>,
|
||||
y_pred_prob: &dyn ArrayView1<T>,
|
||||
) -> f64 {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
|
||||
let mut pos = T::zero();
|
||||
let mut neg = T::zero();
|
||||
|
||||
@@ -76,9 +72,10 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
}
|
||||
}
|
||||
|
||||
let y_pred = y_pred_prob.clone();
|
||||
|
||||
let label_idx = y_pred.argsort();
|
||||
let y_pred: Vec<T> =
|
||||
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
|
||||
// TODO: try to use `crate::algorithm::sort::quick_sort` here
|
||||
let label_idx: Vec<usize> = y_pred.argsort();
|
||||
|
||||
let mut rank = vec![0f64; n];
|
||||
let mut i = 0;
|
||||
@@ -108,7 +105,7 @@ impl<T: Number + Ord> Metrics<T> for AUC<T> {
|
||||
let pos = pos.to_f64().unwrap();
|
||||
let neg = neg.to_f64().unwrap();
|
||||
|
||||
T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap()
|
||||
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@
|
||||
pub mod accuracy;
|
||||
// TODO: reimplement AUC
|
||||
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
// pub mod auc;
|
||||
pub mod auc;
|
||||
/// Compute the homogeneity, completeness and V-Measure scores.
|
||||
pub mod cluster_hcv;
|
||||
pub(crate) mod cluster_helpers;
|
||||
|
||||
Reference in New Issue
Block a user