diff --git a/src/lib.rs b/src/lib.rs index 1c2424f..e41b127 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,4 +8,5 @@ pub mod linalg; pub mod math; pub mod algorithm; pub mod common; -pub mod optimization; \ No newline at end of file +pub mod optimization; +pub mod metrics; \ No newline at end of file diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index dff103a..2cf109d 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -16,9 +16,18 @@ use evd::EVDDecomposableMatrix; use qr::QRDecomposableMatrix; use lu::LUDecomposableMatrix; +pub trait BaseVector<T: FloatExt>: Clone + Debug { + + fn get(&self, i: usize) -> T; + + fn set(&mut self, i: usize, x: T); + + fn len(&self) -> usize; +} + pub trait BaseMatrix<T: FloatExt>: Clone + Debug { - type RowVector: Clone + Debug; + type RowVector: BaseVector<T> + Clone + Debug; fn from_row_vector(vec: Self::RowVector) -> Self; diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 062db5f..5f90b27 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -9,13 +9,26 @@ use serde::ser::{Serializer, SerializeStruct}; use serde::de::{Deserializer, Visitor, SeqAccess, MapAccess}; use crate::linalg::Matrix; -pub use crate::linalg::BaseMatrix; +pub use crate::linalg::{BaseMatrix, BaseVector}; use crate::linalg::svd::SVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::math::num::FloatExt; +impl<T: FloatExt> BaseVector<T> for Vec<T> { + fn get(&self, i: usize) -> T { + self[i] + } + fn set(&mut self, i: usize, x: T){ + self[i] = x + } + + fn len(&self) -> usize { + self.len() + } +} + #[derive(Debug, Clone)] pub struct DenseMatrix<T: FloatExt> { diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index ba8665b..9b6b4bd 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -4,13 +4,26 @@ use std::iter::Sum; use nalgebra::{MatrixMN, DMatrix, Matrix,
Scalar, Dynamic, U1, VecStorage}; use crate::math::num::FloatExt; -use crate::linalg::BaseMatrix; +use crate::linalg::{BaseMatrix, BaseVector}; use crate::linalg::Matrix as SmartCoreMatrix; use crate::linalg::svd::SVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; +impl BaseVector for MatrixMN { + fn get(&self, i: usize) -> T { + *self.get((0, i)).unwrap() + } + fn set(&mut self, i: usize, x: T){ + *self.get_mut((0, i)).unwrap() = x; + } + + fn len(&self) -> usize{ + self.len() + } +} + impl BaseMatrix for Matrix> { type RowVector = MatrixMN; @@ -340,6 +353,24 @@ mod tests { use super::*; use nalgebra::{Matrix2x3, DMatrix, RowDVector}; + #[test] + fn vec_len() { + let v = RowDVector::from_vec(vec!(1., 2., 3.)); + assert_eq!(3, v.len()); + } + + #[test] + fn get_set_vector() { + let mut v = RowDVector::from_vec(vec!(1., 2., 3., 4.)); + + let expected = RowDVector::from_vec(vec!(1., 5., 3., 4.)); + + v.set(1, 5.); + + assert_eq!(v, expected); + assert_eq!(5., BaseVector::get(&v, 1)); + } + #[test] fn get_set_dynamic() { let mut m = DMatrix::from_row_slice( @@ -355,7 +386,7 @@ mod tests { assert_eq!(m, expected); assert_eq!(10., BaseMatrix::get(&m, 1, 1)); - } + } #[test] fn zeros() { diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 2fc88d9..9c439b6 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -9,13 +9,25 @@ use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s}; use ndarray::ScalarOperand; use crate::math::num::FloatExt; -use crate::linalg::BaseMatrix; +use crate::linalg::{BaseMatrix, BaseVector}; use crate::linalg::Matrix; use crate::linalg::svd::SVDDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; +impl BaseVector for ArrayBase, Ix1> { + fn get(&self, i: usize) -> T { 
+ self[i] + } + fn set(&mut self, i: usize, x: T){ + self[i] = x; + } + + fn len(&self) -> usize{ + self.len() + } +} impl BaseMatrix for ArrayBase, Ix2> { @@ -308,6 +320,23 @@ mod tests { use super::*; use ndarray::{arr1, arr2, Array2}; + #[test] + fn vec_get_set() { + let mut result = arr1(&[1., 2., 3.]); + let expected = arr1(&[1., 5., 3.]); + + result.set(1, 5.); + + assert_eq!(result, expected); + assert_eq!(5., BaseVector::get(&result, 1)); + } + + #[test] + fn vec_len() { + let v = arr1(&[1., 2., 3.]); + assert_eq!(3, v.len()); + } + #[test] fn from_to_row_vec() { @@ -449,7 +478,7 @@ mod tests { assert_eq!(result, expected); assert_eq!(10., BaseMatrix::get(&result, 1, 1)); - } + } #[test] fn dot() { diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index f78c310..a635815 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -273,6 +273,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; use ndarray::{arr1, arr2, Array1}; + use crate::metrics::*; #[test] fn multiclass_objective_f() { @@ -447,7 +448,7 @@ mod tests { let lr = LogisticRegression::fit(&x, &y); - let y_hat = lr.predict(&x); + let y_hat = lr.predict(&x); let error: f64 = y.into_iter().zip(y_hat.into_iter()).map(|(&a, &b)| (a - b).abs()).sum(); diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs new file mode 100644 index 0000000..fe8d95f --- /dev/null +++ b/src/metrics/accuracy.rs @@ -0,0 +1,45 @@ +use serde::{Serialize, Deserialize}; + +use crate::math::num::FloatExt; +use crate::linalg::BaseVector; + +#[derive(Serialize, Deserialize, Debug)] +pub struct Accuracy{} + +impl Accuracy { + pub fn get_score>(&self, y_true: &V, y_prod: &V) -> T { + if y_true.len() != y_prod.len() { + panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); + } + + let n = y_true.len(); + + let mut positive = 0; + for i in 0..n { + if y_true.get(i) == y_prod.get(i) { + positive += 1; + } + } + + 
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap() + } + +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn accuracy() { + let y_pred: Vec<f64> = vec![0., 2., 1., 3.]; + let y_true: Vec<f64> = vec![0., 1., 2., 3.]; + + let score1: f64 = Accuracy{}.get_score(&y_pred, &y_true); + let score2: f64 = Accuracy{}.get_score(&y_true, &y_true); + + assert!((score1 - 0.5).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); + } + +} \ No newline at end of file diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs new file mode 100644 index 0000000..f486ade --- /dev/null +++ b/src/metrics/mod.rs @@ -0,0 +1,34 @@ +pub mod accuracy; +pub mod recall; +pub mod precision; + +use crate::math::num::FloatExt; +use crate::linalg::BaseVector; + +pub struct ClassificationMetrics{} + +impl ClassificationMetrics { + pub fn accuracy() -> accuracy::Accuracy{ + accuracy::Accuracy {} + } + + pub fn recall() -> recall::Recall{ + recall::Recall {} + } + + pub fn precision() -> precision::Precision{ + precision::Precision {} + } +} + +pub fn accuracy<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ + ClassificationMetrics::accuracy().get_score(y_true, y_prod) +} + +pub fn recall<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ + ClassificationMetrics::recall().get_score(y_true, y_prod) +} + +pub fn precision<T: FloatExt, V: BaseVector<T>>(y_true: &V, y_prod: &V) -> T{ + ClassificationMetrics::precision().get_score(y_true, y_prod) +} \ No newline at end of file diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs new file mode 100644 index 0000000..fecab51 --- /dev/null +++ b/src/metrics/precision.rs @@ -0,0 +1,57 @@ +use serde::{Serialize, Deserialize}; + +use crate::math::num::FloatExt; +use crate::linalg::BaseVector; + +#[derive(Serialize, Deserialize, Debug)] +pub struct Precision{} + +impl Precision { + pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T { + if y_true.len() != y_prod.len() { + panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); + } + + let mut tp = 0; + let mut p = 0; + let n =
y_true.len(); + for i in 0..n { + if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { + panic!("Precision can only be applied to binary classification: {}", y_true.get(i)); + } + + if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { + panic!("Precision can only be applied to binary classification: {}", y_prod.get(i)); + } + + if y_prod.get(i) == T::one() { + p += 1; + + if y_true.get(i) == T::one() { + tp += 1; + } + } + } + + T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + } + +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn precision() { + let y_true: Vec<f64> = vec![0., 1., 1., 0.]; + let y_pred: Vec<f64> = vec![0., 0., 1., 1.]; + + let score1: f64 = Precision{}.get_score(&y_pred, &y_true); + let score2: f64 = Precision{}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.5).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); + } + +} \ No newline at end of file diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs new file mode 100644 index 0000000..61ee792 --- /dev/null +++ b/src/metrics/recall.rs @@ -0,0 +1,57 @@ +use serde::{Serialize, Deserialize}; + +use crate::math::num::FloatExt; +use crate::linalg::BaseVector; + +#[derive(Serialize, Deserialize, Debug)] +pub struct Recall{} + +impl Recall { + pub fn get_score<T: FloatExt, V: BaseVector<T>>(&self, y_true: &V, y_prod: &V) -> T { + if y_true.len() != y_prod.len() { + panic!("The vector sizes don't match: {} != {}", y_true.len(), y_prod.len()); + } + + let mut tp = 0; + let mut p = 0; + let n = y_true.len(); + for i in 0..n { + if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { + panic!("Recall can only be applied to binary classification: {}", y_true.get(i)); + } + + if y_prod.get(i) != T::zero() && y_prod.get(i) != T::one() { + panic!("Recall can only be applied to binary classification: {}", y_prod.get(i)); + } + + if y_true.get(i) == T::one() { + p += 1; + + if y_prod.get(i) == T::one() { + tp += 1; + } + } + } + + T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + } + +} +
+#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn recall() { + let y_true: Vec<f64> = vec![0., 1., 1., 0.]; + let y_pred: Vec<f64> = vec![0., 0., 1., 1.]; + + let score1: f64 = Recall{}.get_score(&y_pred, &y_true); + let score2: f64 = Recall{}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.5).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); + } + +} \ No newline at end of file