From 14113b4152245679ef53e0f2af7deb2ee005875b Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Mon, 8 Jun 2020 14:47:59 -0700 Subject: [PATCH] feat: adds F1 and roc_auc_score --- src/algorithm/sort/quick_sort.rs | 15 ++++-- src/linalg/mod.rs | 2 + src/linalg/naive/dense_matrix.rs | 5 ++ src/linalg/nalgebra_bindings.rs | 10 ++++ src/linalg/ndarray_bindings.rs | 10 ++++ src/metrics/auc.rs | 81 ++++++++++++++++++++++++++++ src/metrics/f1.rs | 44 +++++++++++++++ src/metrics/mod.rs | 2 + src/tree/decision_tree_classifier.rs | 2 +- src/tree/decision_tree_regressor.rs | 2 +- 10 files changed, 167 insertions(+), 6 deletions(-) create mode 100644 src/metrics/auc.rs create mode 100644 src/metrics/f1.rs diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs index 79bc0e7..e160ed2 100644 --- a/src/algorithm/sort/quick_sort.rs +++ b/src/algorithm/sort/quick_sort.rs @@ -1,11 +1,18 @@ use num_traits::Float; pub trait QuickArgSort { - fn quick_argsort(&mut self) -> Vec; + fn quick_argsort_mut(&mut self) -> Vec; + + fn quick_argsort(&self) -> Vec; } impl QuickArgSort for Vec { - fn quick_argsort(&mut self) -> Vec { + fn quick_argsort(&self) -> Vec { + let mut v = self.clone(); + v.quick_argsort_mut() + } + + fn quick_argsort_mut(&mut self) -> Vec { let stack_size = 64; let mut jstack = -1; let mut l = 0; @@ -108,10 +115,10 @@ mod tests { #[test] fn with_capacity() { - let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; + let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort()); - let mut arr2 = vec![ + let arr2 = vec![ 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4, ]; diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 0b5734b..412cf19 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -22,6 +22,8 @@ pub trait BaseVector: Clone + Debug { fn set(&mut self, i: usize, x: T); fn len(&self) -> usize; + + fn to_vec(&self) -> Vec; } pub trait BaseMatrix: Clone + Debug { diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 4541de5..f807140 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -27,6 +27,11 @@ impl BaseVector for Vec { fn len(&self) -> usize { self.len() } + + fn to_vec(&self) -> Vec { + let v = self.clone(); + v + } } #[derive(Debug, Clone)] diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index 851d878..5d5b875 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -22,6 +22,10 @@ impl BaseVector for MatrixMN { fn len(&self) -> usize { self.len() } + + fn to_vec(&self) -> Vec { + self.row(0).iter().map(|v| *v).collect() + } } impl @@ -384,6 +388,12 @@ mod tests { assert_eq!(5., BaseVector::get(&v, 1)); } + #[test] + fn vec_to_vec() { + let v = RowDVector::from_vec(vec![1., 2., 3.]); + assert_eq!(vec![1., 2., 3.], v.to_vec()); + } + #[test] fn get_set_dynamic() { let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index fc80e0e..dbbd8f9 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -27,6 +27,10 @@ impl BaseVector for ArrayBase, Ix1> { fn len(&self) -> usize { self.len() } + + fn to_vec(&self) -> Vec { + self.to_owned().to_vec() + } } impl @@ -351,6 +355,12 @@ mod tests { assert_eq!(3, v.len()); } + #[test] + fn vec_to_vec() { + let v = arr1(&[1., 2., 3.]); + assert_eq!(vec![1., 2., 3.], v.to_vec()); + } + #[test] fn from_to_row_vec() { let vec = arr1(&[1., 2., 3.]); diff --git a/src/metrics/auc.rs b/src/metrics/auc.rs new file mode 100644 index 0000000..cf34e2c --- /dev/null +++ b/src/metrics/auc.rs @@ -0,0 +1,81 @@ +#![allow(non_snake_case)] + +use serde::{Deserialize, Serialize}; + +use crate::algorithm::sort::quick_sort::QuickArgSort; +use crate::linalg::BaseVector; +use crate::math::num::FloatExt; + +#[derive(Serialize, Deserialize, Debug)] +pub struct AUC {} + +impl AUC { + pub fn get_score>(&self, y_true: &V, y_pred_prob: &V) -> T { + let mut pos = T::zero(); + let mut neg = T::zero(); + + let n = y_true.len(); + + for i in 0..n { + if y_true.get(i) == T::zero() { + neg = neg + T::one(); + } else if y_true.get(i) == T::one() { + pos = pos + T::one(); + } else { + panic!( + "AUC is only for binary classification. Invalid label: {}", + y_true.get(i) + ); + } + } + + let mut y_pred = y_pred_prob.to_vec(); + + let label_idx = y_pred.quick_argsort_mut(); + + let mut rank = vec![T::zero(); n]; + let mut i = 0; + while i < n { + if i == n - 1 || y_pred[i] != y_pred[i + 1] { + rank[i] = T::from_usize(i + 1).unwrap(); + } else { + let mut j = i + 1; + while j < n && y_pred[j] == y_pred[i] { + j += 1; + } + let r = T::from_usize(i + 1 + j).unwrap() / T::two(); + for k in i..j { + rank[k] = r; + } + i = j - 1; + } + i += 1; + } + + let mut auc = T::zero(); + for i in 0..n { + if y_true.get(label_idx[i]) == T::one() { + auc = auc + rank[i]; + } + } + + (auc - (pos * (pos + T::one()) / T::two())) / (pos * neg) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn auc() { + let y_true: Vec = vec![0., 0., 1., 1.]; + let y_pred: Vec = vec![0.1, 0.4, 0.35, 0.8]; + + let score1: f64 = AUC {}.get_score(&y_true, &y_pred); + let score2: f64 = AUC {}.get_score(&y_true, &y_true); + + assert!((score1 - 0.75).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); + } +} diff --git a/src/metrics/f1.rs b/src/metrics/f1.rs new file mode 100644 index 0000000..1313322 --- /dev/null +++ b/src/metrics/f1.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +use crate::linalg::BaseVector; +use crate::math::num::FloatExt; +use crate::metrics::precision::Precision; +use crate::metrics::recall::Recall; + +#[derive(Serialize, Deserialize, Debug)] +pub struct F1 {} + +impl F1 { + pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { + if y_true.len() != y_prod.len() { + panic!( + "The vector sizes don't match: {} != {}", + y_true.len(), + y_prod.len() + ); + } + let beta2 = T::one(); + + let p = Precision {}.get_score(y_true, y_prod); + let r = Recall {}.get_score(y_true, y_prod); + + (T::one() + beta2) * (p * r) / (beta2 * p + r) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn f1() { + let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; + let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; + + let score1: f64 = F1 {}.get_score(&y_pred, &y_true); + let score2: f64 = F1 {}.get_score(&y_true, &y_true); + + assert!((score1 - 0.57142857).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); + } +} diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 478edf6..30cc6ce 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -1,4 +1,6 @@ pub mod accuracy; +pub mod auc; +pub mod f1; pub mod precision; pub mod recall; diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 7e1d72c..51db52f 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -239,7 +239,7 @@ impl DecisionTreeClassifier { let mut order: Vec> = Vec::new(); for i in 0..num_attributes { - order.push(x.get_col_as_vec(i).quick_argsort()); + order.push(x.get_col_as_vec(i).quick_argsort_mut()); } let mut tree = DecisionTreeClassifier { diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index d9677e9..d0802ee 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -164,7 +164,7 @@ impl DecisionTreeRegressor { let mut order: Vec> = Vec::new(); for i in 0..num_attributes { - order.push(x.get_col_as_vec(i).quick_argsort()); + order.push(x.get_col_as_vec(i).quick_argsort_mut()); } let mut tree = DecisionTreeRegressor {