From 3e541e0b8f9447105ba7150541216e7956d99578 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Thu, 24 Sep 2020 14:42:28 -0700 Subject: [PATCH] fix: improves SVD --- src/linalg/svd.rs | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index cf27779..8866ba9 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -53,6 +53,19 @@ pub struct SVD> { tol: T, } +impl> SVD { + /// Diagonal matrix with singular values + pub fn S(&self) -> M { + let mut s = M::zeros(self.U.shape().1, self.V.shape().0); + + for i in 0..self.s.len() { + s.set(i, i, self.s[i]); + } + + s + } +} + /// Trait that implements SVD decomposition routine for any matrix. pub trait SVDDecomposableMatrix: BaseMatrix { /// Solves Ax = b. Overrides original matrix in the process. @@ -711,4 +724,19 @@ mod tests { let w = a.svd_solve_mut(b).unwrap(); assert!(w.approximate_eq(&expected_w, 1e-2)); } + + #[test] + fn decompose_restore() { + let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]); + let svd = a.svd().unwrap(); + let u: &DenseMatrix = &svd.U; //U + let v: &DenseMatrix = &svd.V; // V + let s: &DenseMatrix = &svd.S(); // Sigma + + let a_hat = u.matmul(s).matmul(&v.transpose()); + + for (a, a_hat) in a.iter().zip(a_hat.iter()) { + assert!((a - a_hat).abs() < 1e-3) + } + } }