Merge pull request #35 from smartcorelib/lasso

LASSO
This commit is contained in:
VolodymyrOrlov
2020-12-02 17:34:54 -08:00
committed by GitHub
9 changed files with 819 additions and 3 deletions
+58 -2
View File
@@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::high_order::HighOrderOperations;
use crate::linalg::lu::LUDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use crate::linalg::stats::MatrixStats;
@@ -444,6 +445,38 @@ impl<T: RealNumber> LUDecomposableMatrix<T> for DenseMatrix<T> {}
impl<T: RealNumber> CholeskyDecomposableMatrix<T> for DenseMatrix<T> {}
impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
fn ab(&self, a_transpose: bool, b: &Self, b_transpose: bool) -> Self {
if !a_transpose && !b_transpose {
self.matmul(b)
} else {
let (d1, d2, d3, d4) = match (a_transpose, b_transpose) {
(true, false) => (self.nrows, self.ncols, b.ncols, b.nrows),
(false, true) => (self.ncols, self.nrows, b.nrows, b.ncols),
_ => (self.nrows, self.ncols, b.nrows, b.ncols),
};
if d1 != d4 {
panic!("Can not multiply {}x{} by {}x{} matrices", d2, d1, d4, d3);
}
let mut result = Self::zeros(d2, d3);
for r in 0..d2 {
for c in 0..d3 {
let mut s = T::zero();
for i in 0..d1 {
match (a_transpose, b_transpose) {
(true, false) => s += self.get(i, r) * b.get(i, c),
(false, true) => s += self.get(r, i) * b.get(c, i),
_ => s += self.get(i, r) * b.get(c, i),
}
}
result.set(r, c, s);
}
}
result
}
}
}
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
@@ -625,8 +658,8 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
}
fn dot(&self, other: &Self) -> T {
if self.nrows != 1 && other.nrows != 1 {
panic!("A and B should both be 1-dimentional vectors.");
if (self.nrows != 1 && other.nrows != 1) && (self.ncols != 1 && other.ncols != 1) {
panic!("A and B should both be either a row or a column vector.");
}
if self.nrows * self.ncols != other.nrows * other.ncols {
panic!("A and B should have the same size");
@@ -1120,6 +1153,29 @@ mod tests {
assert_eq!(result, expected);
}
#[test]
fn ab() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
let c = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
assert_eq!(
a.ab(false, &b, false),
DenseMatrix::from_2d_array(&[&[46., 52.], &[109., 124.]])
);
assert_eq!(
c.ab(true, &b, false),
DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]])
);
assert_eq!(
b.ab(false, &c, true),
DenseMatrix::from_2d_array(&[&[17., 39., 61.], &[23., 53., 83.,], &[29., 67., 105.]])
);
assert_eq!(
a.ab(true, &b, true),
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
);
}
#[test]
fn dot() {
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);