fix: improves SVD
This commit is contained in:
@@ -53,6 +53,19 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
||||
tol: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
/// Diagonal matrix with singular values
|
||||
pub fn S(&self) -> M {
|
||||
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
|
||||
|
||||
for i in 0..self.s.len() {
|
||||
s.set(i, i, self.s[i]);
|
||||
}
|
||||
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that implements SVD decomposition routine for any matrix.
|
||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Solves Ax = b. Overrides original matrix in the process.
|
||||
@@ -711,4 +724,19 @@ mod tests {
|
||||
let w = a.svd_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompose_restore() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||
let svd = a.svd().unwrap();
|
||||
let u: &DenseMatrix<f32> = &svd.U; //U
|
||||
let v: &DenseMatrix<f32> = &svd.V; // V
|
||||
let s: &DenseMatrix<f32> = &svd.S(); // Sigma
|
||||
|
||||
let a_hat = u.matmul(s).matmul(&v.transpose());
|
||||
|
||||
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
|
||||
assert!((a - a_hat).abs() < 1e-3)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user