feat: refactors matrix decomposition routines

This commit is contained in:
Volodymyr Orlov
2020-03-12 17:32:27 -07:00
parent 7b3fa982be
commit cb4323f26e
11 changed files with 1381 additions and 1256 deletions
+32 -34
View File
@@ -1,10 +1,12 @@
use std::ops::Range;
use crate::linalg::{Matrix};
use crate::linalg::svd::SVD;
use crate::linalg::evd::EVD;
use crate::linalg::BaseMatrix;
use crate::linalg::Matrix;
use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
{
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
@@ -32,19 +34,7 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
fn set(&mut self, row: usize, col: usize, x: f64) {
self[[row, col]] = x;
}
fn svd(&self) -> SVD<Self>{
panic!("svd method is not implemented for ndarray");
}
fn evd_mut(self, symmetric: bool) -> EVD<Self>{
panic!("evd method is not implemented for ndarray");
}
fn qr_solve_mut(&mut self, b: Self) -> Self {
panic!("qr_solve_mut method is not implemented for ndarray");
}
}
fn eye(size: usize) -> Self {
Array::eye(size)
@@ -286,6 +276,14 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
}
impl SVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
impl EVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
impl QRDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
#[cfg(test)]
mod tests {
use super::*;
@@ -359,7 +357,7 @@ mod tests {
[4., 5., 6.]]);
a.div_element_mut(1, 1, 5.);
assert_eq!(Matrix::get(&a, 1, 1), 1.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
}
@@ -370,7 +368,7 @@ mod tests {
[4., 5., 6.]]);
a.mul_element_mut(1, 1, 5.);
assert_eq!(Matrix::get(&a, 1, 1), 25.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
}
@@ -381,7 +379,7 @@ mod tests {
[4., 5., 6.]]);
a.add_element_mut(1, 1, 5.);
assert_eq!(Matrix::get(&a, 1, 1), 10.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
}
@@ -392,7 +390,7 @@ mod tests {
[4., 5., 6.]]);
a.sub_element_mut(1, 1, 5.);
assert_eq!(Matrix::get(&a, 1, 1), 0.);
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
}
@@ -431,7 +429,7 @@ mod tests {
result.set(1, 1, 10.);
assert_eq!(result, expected);
assert_eq!(10., Matrix::get(&result, 1, 1));
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
}
#[test]
@@ -447,7 +445,7 @@ mod tests {
let expected = arr2(&[
[22., 28.],
[49., 64.]]);
let result = Matrix::dot(&a, &b);
let result = BaseMatrix::dot(&a, &b);
assert_eq!(result, expected);
}
@@ -470,7 +468,7 @@ mod tests {
&[
[2., 3.],
[5., 6.]]);
let result = Matrix::slice(&a, 0..2, 1..3);
let result = BaseMatrix::slice(&a, 0..2, 1..3);
assert_eq!(result, expected);
}
@@ -510,12 +508,12 @@ mod tests {
#[test]
fn reshape() {
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
let m_2_by_3 = Matrix::reshape(&m_orig, 2, 3);
let m_result = Matrix::reshape(&m_2_by_3, 1, 6);
assert_eq!(Matrix::shape(&m_2_by_3), (2, 3));
assert_eq!(Matrix::get(&m_2_by_3, 1, 1), 5.);
assert_eq!(Matrix::get(&m_result, 0, 1), 2.);
assert_eq!(Matrix::get(&m_result, 0, 3), 4.);
let m_2_by_3 = BaseMatrix::reshape(&m_orig, 2, 3);
let m_result = BaseMatrix::reshape(&m_2_by_3, 1, 6);
assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3));
assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.);
assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.);
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
}
#[test]
@@ -544,9 +542,9 @@ mod tests {
fn softmax_mut(){
let mut prob = arr2(&[[1., 2., 3.]]);
prob.softmax_mut();
assert!((Matrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
assert!((Matrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
assert!((Matrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
}
#[test]
@@ -599,7 +597,7 @@ mod tests {
let a = arr2(&[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]]);
let res: Array2<f64> = Matrix::eye(3);
let res: Array2<f64> = BaseMatrix::eye(3);
assert_eq!(res, a);
}
}