feat: BernoulliNB (#31)
* feat: BernoulliNB * Move preprocessing to a trait in linalg/stats.rs
This commit is contained in:
+2
-1
@@ -63,7 +63,7 @@ use evd::EVDDecomposableMatrix;
|
||||
use high_order::HighOrderOperations;
|
||||
use lu::LUDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
use stats::MatrixStats;
|
||||
use stats::{MatrixPreprocessing, MatrixStats};
|
||||
use svd::SVDDecomposableMatrix;
|
||||
|
||||
/// Column or row vector
|
||||
@@ -619,6 +619,7 @@ pub trait Matrix<T: RealNumber>:
|
||||
+ LUDecomposableMatrix<T>
|
||||
+ CholeskyDecomposableMatrix<T>
|
||||
+ MatrixStats<T>
|
||||
+ MatrixPreprocessing<T>
|
||||
+ HighOrderOperations<T>
|
||||
+ PartialEq
|
||||
+ Display
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::MatrixStats;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
pub use crate::linalg::{BaseMatrix, BaseVector};
|
||||
@@ -478,6 +478,7 @@ impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
|
||||
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::MatrixStats;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix as SmartCoreMatrix;
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
@@ -554,6 +554,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
MatrixPreprocessing<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
HighOrderOperations<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
|
||||
@@ -54,7 +54,7 @@ use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::MatrixStats;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
@@ -503,6 +503,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
MatrixPreprocessing<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
HighOrderOperations<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
|
||||
@@ -104,6 +104,47 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines baseline implementations for various matrix processing functions
|
||||
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
/// a.binarize_mut(0.);
|
||||
///
|
||||
/// assert_eq!(a, expected);
|
||||
/// ```
|
||||
|
||||
fn binarize_mut(&mut self, threshold: T) {
|
||||
let (nrows, ncols) = self.shape();
|
||||
for row in 0..nrows {
|
||||
for col in 0..ncols {
|
||||
if self.get(row, col) > threshold {
|
||||
self.set(row, col, T::one());
|
||||
} else {
|
||||
self.set(row, col, T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns new matrix where elements are binarized according to a given threshold.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
///
|
||||
/// assert_eq!(a.binarize(0.), expected);
|
||||
/// ```
|
||||
fn binarize(&self, threshold: T) -> Self {
|
||||
let mut m = self.clone();
|
||||
m.binarize_mut(threshold);
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
Reference in New Issue
Block a user