diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index cf27779..8866ba9 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -53,6 +53,19 @@ pub struct SVD> { tol: T, } +impl> SVD { + /// Diagonal matrix with singular values + pub fn S(&self) -> M { + let mut s = M::zeros(self.U.shape().1, self.V.shape().0); + + for i in 0..self.s.len() { + s.set(i, i, self.s[i]); + } + + s + } +} + /// Trait that implements SVD decomposition routine for any matrix. pub trait SVDDecomposableMatrix: BaseMatrix { /// Solves Ax = b. Overrides original matrix in the process. @@ -711,4 +724,19 @@ mod tests { let w = a.svd_solve_mut(b).unwrap(); assert!(w.approximate_eq(&expected_w, 1e-2)); } + + #[test] + fn decompose_restore() { + let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]); + let svd = a.svd().unwrap(); + let u: &DenseMatrix = &svd.U; //U + let v: &DenseMatrix = &svd.V; // V + let s: &DenseMatrix = &svd.S(); // Sigma + + let a_hat = u.matmul(s).matmul(&v.transpose()); + + for (a, a_hat) in a.iter().zip(a_hat.iter()) { + assert!((a - a_hat).abs() < 1e-3) + } + } }