fix: improves SVD

This commit is contained in:
Volodymyr Orlov
2020-09-24 14:42:28 -07:00
parent a19398fd70
commit 3e541e0b8f
+28
View File
@@ -53,6 +53,19 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
tol: T,
}
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
/// Diagonal matrix with singular values
pub fn S(&self) -> M {
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
for i in 0..self.s.len() {
s.set(i, i, self.s[i]);
}
s
}
}
/// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Solves Ax = b. Overrides original matrix in the process.
@@ -711,4 +724,19 @@ mod tests {
let w = a.svd_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
}
#[test]
fn decompose_restore() {
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
let svd = a.svd().unwrap();
let u: &DenseMatrix<f32> = &svd.U; //U
let v: &DenseMatrix<f32> = &svd.V; // V
let s: &DenseMatrix<f32> = &svd.S(); // Sigma
let a_hat = u.matmul(s).matmul(&v.transpose());
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
assert!((a - a_hat).abs() < 1e-3)
}
}
}