diff --git a/src/math/num.rs b/src/math/num.rs index 7199949..c454b9d 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -46,8 +46,11 @@ pub trait RealNumber: self * self } - /// Raw transmutation to u64 + /// Raw transmutation to u32 fn to_f32_bits(self) -> u32; + + /// Raw transmutation to u64 + fn to_f64_bits(self) -> u64; } impl RealNumber for f64 { @@ -89,6 +92,10 @@ impl RealNumber for f64 { fn to_f32_bits(self) -> u32 { self.to_bits() as u32 } + + fn to_f64_bits(self) -> u64 { + self.to_bits() + } } impl RealNumber for f32 { @@ -130,6 +137,10 @@ impl RealNumber for f32 { fn to_f32_bits(self) -> u32 { self.to_bits() } + + fn to_f64_bits(self) -> u64 { + self.to_bits() as u64 + } } #[cfg(test)] diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index a0171aa..a2bad30 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -18,6 +18,8 @@ //! //! //! +use std::collections::HashSet; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -42,34 +44,33 @@ impl Precision { ); } + let mut classes = HashSet::new(); + for i in 0..y_true.len() { + classes.insert(y_true.get(i).to_f64_bits()); + } + let classes = classes.len(); + let mut tp = 0; - let mut p = 0; - let n = y_true.len(); - for i in 0..n { - if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { - panic!( - "Precision can only be applied to binary classification: {}", - y_true.get(i) - ); - } - - if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { - panic!( - "Precision can only be applied to binary classification: {}", - y_pred.get(i) - ); - } - - if y_pred.get(i) == T::one() { - p += 1; - - if y_true.get(i) == T::one() { + let mut fp = 0; + for i in 0..y_true.len() { + if y_pred.get(i) == y_true.get(i) { + if classes == 2 { + if y_true.get(i) == T::one() { + tp += 1; + } + } else { tp += 1; } + } else if classes == 2 { + if y_true.get(i) == T::one() { + fp += 1; + } + } else { + fp += 1; } } - T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap()) } } @@ -88,5 +89,24 @@ mod tests { assert!((score1 - 0.5).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); + + let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; + let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; + + let score3: f64 = Precision {}.get_score(&y_pred, &y_true); + assert!((score3 - 0.5).abs() < 1e-8); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn precision_multiclass() { + let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.]; + + let score1: f64 = Precision {}.get_score(&y_pred, &y_true); + let score2: f64 = Precision {}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.333333333).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); } } diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 18863ae..48ddeeb 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -18,6 +18,9 @@ //! //! //! +use std::collections::HashSet; +use std::convert::TryInto; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -42,34 +45,32 @@ impl Recall { ); } + let mut classes = HashSet::new(); + for i in 0..y_true.len() { + classes.insert(y_true.get(i).to_f64_bits()); + } + let classes: i64 = classes.len().try_into().unwrap(); + let mut tp = 0; - let mut p = 0; - let n = y_true.len(); - for i in 0..n { - if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { - panic!( - "Recall can only be applied to binary classification: {}", - y_true.get(i) - ); - } - - if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { - panic!( - "Recall can only be applied to binary classification: {}", - y_pred.get(i) - ); - } - - if y_true.get(i) == T::one() { - p += 1; - - if y_pred.get(i) == T::one() { + let mut fne = 0; + for i in 0..y_true.len() { + if y_pred.get(i) == y_true.get(i) { + if classes == 2 { + if y_true.get(i) == T::one() { + tp += 1; + } + } else { tp += 1; } + } else if classes == 2 { + if y_true.get(i) != T::one() { + fne += 1; + } + } else { + fne += 1; } } - - T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap()) } } @@ -88,5 +89,24 @@ mod tests { assert!((score1 - 0.5).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); + + let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; + let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; + + let score3: f64 = Recall {}.get_score(&y_pred, &y_true); + assert!((score3 - 0.66666666).abs() < 1e-8); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn recall_multiclass() { + let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.]; + + let score1: f64 = Recall {}.get_score(&y_pred, &y_true); + let score2: f64 = Recall {}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.333333333).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); } }