fix: improves SVD
This commit is contained in:
@@ -53,6 +53,19 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
|||||||
tol: T,
|
tol: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||||
|
/// Diagonal matrix with singular values
|
||||||
|
pub fn S(&self) -> M {
|
||||||
|
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
|
||||||
|
|
||||||
|
for i in 0..self.s.len() {
|
||||||
|
s.set(i, i, self.s[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait that implements SVD decomposition routine for any matrix.
|
/// Trait that implements SVD decomposition routine for any matrix.
|
||||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||||
/// Solves Ax = b. Overrides original matrix in the process.
|
/// Solves Ax = b. Overrides original matrix in the process.
|
||||||
@@ -711,4 +724,19 @@ mod tests {
|
|||||||
let w = a.svd_solve_mut(b).unwrap();
|
let w = a.svd_solve_mut(b).unwrap();
|
||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decompose_restore() {
|
||||||
|
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||||
|
let svd = a.svd().unwrap();
|
||||||
|
let u: &DenseMatrix<f32> = &svd.U; //U
|
||||||
|
let v: &DenseMatrix<f32> = &svd.V; // V
|
||||||
|
let s: &DenseMatrix<f32> = &svd.S(); // Sigma
|
||||||
|
|
||||||
|
let a_hat = u.matmul(s).matmul(&v.transpose());
|
||||||
|
|
||||||
|
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
|
||||||
|
assert!((a - a_hat).abs() < 1e-3)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user