feat: BernoulliNB (#31)

* feat: BernoulliNB

* Move preprocessing to a trait in linalg/stats.rs
This commit is contained in:
morenol
2020-12-04 20:45:40 -04:00
committed by GitHub
parent 4720a3a4eb
commit f0b348dd6e
7 changed files with 367 additions and 4 deletions
+2 -1
View File
@@ -63,7 +63,7 @@ use evd::EVDDecomposableMatrix;
use high_order::HighOrderOperations;
use lu::LUDecomposableMatrix;
use qr::QRDecomposableMatrix;
use stats::MatrixStats;
use stats::{MatrixPreprocessing, MatrixStats};
use svd::SVDDecomposableMatrix;
/// Column or row vector
@@ -619,6 +619,7 @@ pub trait Matrix<T: RealNumber>:
+ LUDecomposableMatrix<T>
+ CholeskyDecomposableMatrix<T>
+ MatrixStats<T>
+ MatrixPreprocessing<T>
+ HighOrderOperations<T>
+ PartialEq
+ Display
+2 -1
View File
@@ -12,7 +12,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::high_order::HighOrderOperations;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::stats::MatrixStats;
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix;
pub use crate::linalg::{BaseMatrix, BaseVector};
@@ -478,6 +478,7 @@ impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
}
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
+6 -1
View File
@@ -47,7 +47,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::high_order::HighOrderOperations;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::stats::MatrixStats;
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix as SmartCoreMatrix;
use crate::linalg::{BaseMatrix, BaseVector};
@@ -554,6 +554,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
{
}
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
MatrixPreprocessing<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
}
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
HighOrderOperations<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
{
+6 -1
View File
@@ -54,7 +54,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::high_order::HighOrderOperations;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::stats::MatrixStats;
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::Matrix;
use crate::linalg::{BaseMatrix, BaseVector};
@@ -503,6 +503,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
{
}
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
MatrixPreprocessing<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
}
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
HighOrderOperations<T> for ArrayBase<OwnedRepr<T>, Ix2>
{
+41
View File
@@ -104,6 +104,47 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
}
}
/// Defines baseline implementations for various matrix processing functions
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
/// a.binarize_mut(0.);
///
/// assert_eq!(a, expected);
/// ```
fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
for col in 0..ncols {
if self.get(row, col) > threshold {
self.set(row, col, T::one());
} else {
self.set(row, col, T::zero());
}
}
}
}
/// Returns new matrix where elements are binarized according to a given threshold.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
///
/// assert_eq!(a.binarize(0.), expected);
/// ```
fn binarize(&self, threshold: T) -> Self {
let mut m = self.clone();
m.binarize_mut(threshold);
m
}
}
#[cfg(test)]
mod tests {
use super::*;