@@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::MatrixStats;
|
||||
@@ -444,6 +445,38 @@ impl<T: RealNumber> LUDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> CholeskyDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
|
||||
fn ab(&self, a_transpose: bool, b: &Self, b_transpose: bool) -> Self {
|
||||
if !a_transpose && !b_transpose {
|
||||
self.matmul(b)
|
||||
} else {
|
||||
let (d1, d2, d3, d4) = match (a_transpose, b_transpose) {
|
||||
(true, false) => (self.nrows, self.ncols, b.ncols, b.nrows),
|
||||
(false, true) => (self.ncols, self.nrows, b.nrows, b.ncols),
|
||||
_ => (self.nrows, self.ncols, b.nrows, b.ncols),
|
||||
};
|
||||
if d1 != d4 {
|
||||
panic!("Can not multiply {}x{} by {}x{} matrices", d2, d1, d4, d3);
|
||||
}
|
||||
let mut result = Self::zeros(d2, d3);
|
||||
for r in 0..d2 {
|
||||
for c in 0..d3 {
|
||||
let mut s = T::zero();
|
||||
for i in 0..d1 {
|
||||
match (a_transpose, b_transpose) {
|
||||
(true, false) => s += self.get(i, r) * b.get(i, c),
|
||||
(false, true) => s += self.get(r, i) * b.get(c, i),
|
||||
_ => s += self.get(i, r) * b.get(c, i),
|
||||
}
|
||||
}
|
||||
result.set(r, c, s);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
|
||||
@@ -625,8 +658,8 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> T {
|
||||
if self.nrows != 1 && other.nrows != 1 {
|
||||
panic!("A and B should both be 1-dimentional vectors.");
|
||||
if (self.nrows != 1 && other.nrows != 1) && (self.ncols != 1 && other.ncols != 1) {
|
||||
panic!("A and B should both be either a row or a column vector.");
|
||||
}
|
||||
if self.nrows * self.ncols != other.nrows * other.ncols {
|
||||
panic!("A and B should have the same size");
|
||||
@@ -1120,6 +1153,29 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ab() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let c = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
assert_eq!(
|
||||
a.ab(false, &b, false),
|
||||
DenseMatrix::from_2d_array(&[&[46., 52.], &[109., 124.]])
|
||||
);
|
||||
assert_eq!(
|
||||
c.ab(true, &b, false),
|
||||
DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]])
|
||||
);
|
||||
assert_eq!(
|
||||
b.ab(false, &c, true),
|
||||
DenseMatrix::from_2d_array(&[&[17., 39., 61.], &[23., 53., 83.,], &[29., 67., 105.]])
|
||||
);
|
||||
assert_eq!(
|
||||
a.ab(true, &b, true),
|
||||
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dot() {
|
||||
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
||||
|
||||
Reference in New Issue
Block a user