diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index d676460..bea438e 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -174,7 +174,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { /// /// Brute force algorithm, used only for comparison and testing /// - #[cfg(feature = "fp_bench")] + #[allow(dead_code)] pub fn closest_pair_brute(&self) -> PairwiseDistance { use itertools::Itertools; let m = self.samples.shape().0; diff --git a/src/metrics/auc.rs b/src/metrics/auc.rs index a94f3a3..e8d02b2 100644 --- a/src/metrics/auc.rs +++ b/src/metrics/auc.rs @@ -26,8 +26,8 @@ use std::marker::PhantomData; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1}; -use crate::numbers::basenum::Number; +use crate::linalg::basic::arrays::{Array1, ArrayView1}; +use crate::numbers::floatnum::FloatNumber; use crate::metrics::Metrics; @@ -38,14 +38,14 @@ pub struct AUC { _phantom: PhantomData, } -impl Metrics for AUC { +impl Metrics for AUC { /// create a typed object to call AUC functions fn new() -> Self { Self { _phantom: PhantomData, } } - fn new_with(_parameter: T) -> Self { + fn new_with(_parameter: f64) -> Self { Self { _phantom: PhantomData, } @@ -53,11 +53,7 @@ impl Metrics for AUC { /// AUC score. /// * `y_true` - ground truth (correct) labels. /// * `y_pred_prob` - probability estimates, as returned by a classifier. - fn get_score( - &self, - y_true: &dyn ArrayView1, - y_pred_prob: &dyn ArrayView1, - ) -> f64 { + fn get_score(&self, y_true: &dyn ArrayView1, y_pred_prob: &dyn ArrayView1) -> f64 { let mut pos = T::zero(); let mut neg = T::zero(); @@ -76,9 +72,10 @@ impl Metrics for AUC { } } - let y_pred = y_pred_prob.clone(); - - let label_idx = y_pred.argsort(); + let y_pred: Vec = + Array1::::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape()); + // TODO: try to use `crate::algorithm::sort::quick_sort` here + let label_idx: Vec = y_pred.argsort(); let mut rank = vec![0f64; n]; let mut i = 0; @@ -108,7 +105,7 @@ impl Metrics for AUC { let pos = pos.to_f64().unwrap(); let neg = neg.to_f64().unwrap(); - T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap() + (auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg) } } diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 503391c..25cffa3 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -55,7 +55,7 @@ pub mod accuracy; // TODO: reimplement AUC // /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. -// pub mod auc; +pub mod auc; /// Compute the homogeneity, completeness and V-Measure scores. pub mod cluster_hcv; pub(crate) mod cluster_helpers;