feat: refactors matrix decomposition routines
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
use std::ops::Range;
|
||||
use crate::linalg::{Matrix};
|
||||
use crate::linalg::svd::SVD;
|
||||
use crate::linalg::evd::EVD;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
||||
|
||||
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
{
|
||||
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
|
||||
|
||||
@@ -32,19 +34,7 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, x: f64) {
|
||||
self[[row, col]] = x;
|
||||
}
|
||||
|
||||
fn svd(&self) -> SVD<Self>{
|
||||
panic!("svd method is not implemented for ndarray");
|
||||
}
|
||||
|
||||
fn evd_mut(self, symmetric: bool) -> EVD<Self>{
|
||||
panic!("evd method is not implemented for ndarray");
|
||||
}
|
||||
|
||||
fn qr_solve_mut(&mut self, b: Self) -> Self {
|
||||
panic!("qr_solve_mut method is not implemented for ndarray");
|
||||
}
|
||||
}
|
||||
|
||||
fn eye(size: usize) -> Self {
|
||||
Array::eye(size)
|
||||
@@ -286,6 +276,14 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
|
||||
}
|
||||
|
||||
impl SVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
|
||||
impl EVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
|
||||
impl QRDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
|
||||
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -359,7 +357,7 @@ mod tests {
|
||||
[4., 5., 6.]]);
|
||||
a.div_element_mut(1, 1, 5.);
|
||||
|
||||
assert_eq!(Matrix::get(&a, 1, 1), 1.);
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
||||
|
||||
}
|
||||
|
||||
@@ -370,7 +368,7 @@ mod tests {
|
||||
[4., 5., 6.]]);
|
||||
a.mul_element_mut(1, 1, 5.);
|
||||
|
||||
assert_eq!(Matrix::get(&a, 1, 1), 25.);
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
||||
|
||||
}
|
||||
|
||||
@@ -381,7 +379,7 @@ mod tests {
|
||||
[4., 5., 6.]]);
|
||||
a.add_element_mut(1, 1, 5.);
|
||||
|
||||
assert_eq!(Matrix::get(&a, 1, 1), 10.);
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
||||
|
||||
}
|
||||
|
||||
@@ -392,7 +390,7 @@ mod tests {
|
||||
[4., 5., 6.]]);
|
||||
a.sub_element_mut(1, 1, 5.);
|
||||
|
||||
assert_eq!(Matrix::get(&a, 1, 1), 0.);
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
||||
|
||||
}
|
||||
|
||||
@@ -431,7 +429,7 @@ mod tests {
|
||||
result.set(1, 1, 10.);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
assert_eq!(10., Matrix::get(&result, 1, 1));
|
||||
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -447,7 +445,7 @@ mod tests {
|
||||
let expected = arr2(&[
|
||||
[22., 28.],
|
||||
[49., 64.]]);
|
||||
let result = Matrix::dot(&a, &b);
|
||||
let result = BaseMatrix::dot(&a, &b);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
@@ -470,7 +468,7 @@ mod tests {
|
||||
&[
|
||||
[2., 3.],
|
||||
[5., 6.]]);
|
||||
let result = Matrix::slice(&a, 0..2, 1..3);
|
||||
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
@@ -510,12 +508,12 @@ mod tests {
|
||||
#[test]
|
||||
fn reshape() {
|
||||
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
||||
let m_2_by_3 = Matrix::reshape(&m_orig, 2, 3);
|
||||
let m_result = Matrix::reshape(&m_2_by_3, 1, 6);
|
||||
assert_eq!(Matrix::shape(&m_2_by_3), (2, 3));
|
||||
assert_eq!(Matrix::get(&m_2_by_3, 1, 1), 5.);
|
||||
assert_eq!(Matrix::get(&m_result, 0, 1), 2.);
|
||||
assert_eq!(Matrix::get(&m_result, 0, 3), 4.);
|
||||
let m_2_by_3 = BaseMatrix::reshape(&m_orig, 2, 3);
|
||||
let m_result = BaseMatrix::reshape(&m_2_by_3, 1, 6);
|
||||
assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3));
|
||||
assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.);
|
||||
assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.);
|
||||
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -544,9 +542,9 @@ mod tests {
|
||||
fn softmax_mut(){
|
||||
let mut prob = arr2(&[[1., 2., 3.]]);
|
||||
prob.softmax_mut();
|
||||
assert!((Matrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
|
||||
assert!((Matrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
|
||||
assert!((Matrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
|
||||
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
|
||||
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -599,7 +597,7 @@ mod tests {
|
||||
let a = arr2(&[[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]]);
|
||||
let res: Array2<f64> = Matrix::eye(3);
|
||||
let res: Array2<f64> = BaseMatrix::eye(3);
|
||||
assert_eq!(res, a);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user