feat: documents matrix methods

This commit is contained in:
Volodymyr Orlov
2020-09-06 18:27:11 -07:00
parent 1e3ed4c924
commit bbe810d164
25 changed files with 587 additions and 245 deletions
+44 -37
View File
@@ -69,17 +69,6 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
BaseMatrix::fill(nrows, ncols, T::one())
}
fn to_raw_vector(&self) -> Vec<T> {
let (nrows, ncols) = self.shape();
let mut result = vec![T::zero(); nrows * ncols];
for (i, row) in self.row_iter().enumerate() {
for (j, v) in row.iter().enumerate() {
result[i * ncols + j] = *v;
}
}
result
}
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
let mut m = DMatrix::zeros(nrows, ncols);
m.fill(value);
@@ -90,7 +79,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
self.shape()
}
fn v_stack(&self, other: &Self) -> Self {
fn h_stack(&self, other: &Self) -> Self {
let mut columns = Vec::new();
for r in 0..self.ncols() {
columns.push(self.column(r));
@@ -101,7 +90,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
Matrix::from_columns(&columns)
}
fn h_stack(&self, other: &Self) -> Self {
fn v_stack(&self, other: &Self) -> Self {
let mut rows = Vec::new();
for r in 0..self.nrows() {
rows.push(self.row(r));
@@ -112,11 +101,11 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
Matrix::from_rows(&rows)
}
fn dot(&self, other: &Self) -> Self {
fn matmul(&self, other: &Self) -> Self {
self * other
}
fn vector_dot(&self, other: &Self) -> T {
fn dot(&self, other: &Self) -> T {
self.dot(other)
}
@@ -250,7 +239,14 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
}
fn reshape(&self, nrows: usize, ncols: usize) -> Self {
DMatrix::from_row_slice(nrows, ncols, &self.to_raw_vector())
let (c_nrows, c_ncols) = self.shape();
let mut raw_v = vec![T::zero(); c_nrows * c_ncols];
for (i, row) in self.row_iter().enumerate() {
for (j, v) in row.iter().enumerate() {
raw_v[i * c_ncols + j] = *v;
}
}
DMatrix::from_row_slice(nrows, ncols, &raw_v)
}
fn copy_from(&mut self, other: &Self) {
@@ -272,6 +268,22 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
sum
}
fn max(&self) -> T {
let mut m = T::zero();
for v in self.iter() {
m = m.max(*v);
}
m
}
fn min(&self) -> T {
let mut m = T::zero();
for v in self.iter() {
m = m.min(*v);
}
m
}
fn max_diff(&self, other: &Self) -> T {
let mut max_diff = T::zero();
for r in 0..self.nrows() {
@@ -488,13 +500,6 @@ mod tests {
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
}
#[test]
fn to_raw_vector() {
let m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert_eq!(m.to_raw_vector(), vec!(1., 2., 3., 4., 5., 6.));
}
#[test]
fn element_add_sub_mul_div() {
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
@@ -518,25 +523,25 @@ mod tests {
let expected =
DMatrix::from_row_slice(3, 4, &[1., 2., 3., 7., 4., 5., 6., 8., 9., 10., 11., 12.]);
let result = m1.v_stack(&m2).h_stack(&m3);
let result = m1.h_stack(&m2).v_stack(&m3);
assert_eq!(result, expected);
}
#[test]
fn matmul() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]);
let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]);
let result = BaseMatrix::matmul(&a, &b);
assert_eq!(result, expected);
}
#[test]
fn dot() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
let b = DMatrix::from_row_slice(3, 2, &[1., 2., 3., 4., 5., 6.]);
let expected = DMatrix::from_row_slice(2, 2, &[22., 28., 49., 64.]);
let result = BaseMatrix::dot(&a, &b);
assert_eq!(result, expected);
}
#[test]
fn vector_dot() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
let b = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
assert_eq!(14., a.vector_dot(&b));
assert_eq!(14., a.dot(&b));
}
#[test]
@@ -632,9 +637,11 @@ mod tests {
}
#[test]
fn sum() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
assert_eq!(a.sum(), 6.);
fn min_max_sum() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
assert_eq!(21., a.sum());
assert_eq!(1., a.min());
assert_eq!(6., a.max());
}
#[test]