feat: adds SVD

This commit is contained in:
Volodymyr Orlov
2020-02-28 09:21:00 -08:00
parent fe50509d3b
commit 619560a1cd
5 changed files with 269 additions and 119 deletions
+48 -85
View File
@@ -1,5 +1,6 @@
use std::ops::Range;
use crate::linalg::{Matrix};
use crate::linalg::svd::SVD;
use crate::math;
use rand::prelude::*;
@@ -338,14 +339,12 @@ impl Matrix for DenseMatrix {
}
fn svd_solve_mut(&mut self, mut b: DenseMatrix) -> DenseMatrix {
fn svd(&self) -> SVD<Self> {
if self.nrows != b.nrows {
panic!("Dimensions do not agree. Self.nrows should equal b.nrows but is {}, {}", self.nrows, b.nrows);
}
let mut U = self.clone();
let m = self.nrows;
let n = self.ncols;
let m = U.nrows;
let n = U.ncols;
let (mut l, mut nm) = (0usize, 0usize);
let (mut anorm, mut g, mut scale) = (0f64, 0f64, 0f64);
@@ -363,32 +362,32 @@ impl Matrix for DenseMatrix {
if i < m {
for k in i..m {
scale += self.get(k, i).abs();
scale += U.get(k, i).abs();
}
if scale.abs() > math::EPSILON {
for k in i..m {
self.div_element_mut(k, i, scale);
s += self.get(k, i) * self.get(k, i);
U.div_element_mut(k, i, scale);
s += U.get(k, i) * U.get(k, i);
}
let mut f = self.get(i, i);
let mut f = U.get(i, i);
g = -s.sqrt().copysign(f);
let h = f * g - s;
self.set(i, i, f - g);
U.set(i, i, f - g);
for j in l - 1..n {
s = 0f64;
for k in i..m {
s += self.get(k, i) * self.get(k, j);
s += U.get(k, i) * U.get(k, j);
}
f = s / h;
for k in i..m {
self.add_element_mut(k, j, f * self.get(k, i));
U.add_element_mut(k, j, f * U.get(k, i));
}
}
for k in i..m {
self.mul_element_mut(k, i, scale);
U.mul_element_mut(k, i, scale);
}
}
}
@@ -400,37 +399,37 @@ impl Matrix for DenseMatrix {
if i + 1 <= m && i + 1 != n {
for k in l - 1..n {
scale += self.get(i, k).abs();
scale += U.get(i, k).abs();
}
if scale.abs() > math::EPSILON {
for k in l - 1..n {
self.div_element_mut(i, k, scale);
s += self.get(i, k) * self.get(i, k);
U.div_element_mut(i, k, scale);
s += U.get(i, k) * U.get(i, k);
}
let f = self.get(i, l - 1);
let f = U.get(i, l - 1);
g = -s.sqrt().copysign(f);
let h = f * g - s;
self.set(i, l - 1, f - g);
U.set(i, l - 1, f - g);
for k in l - 1..n {
rv1[k] = self.get(i, k) / h;
rv1[k] = U.get(i, k) / h;
}
for j in l - 1..m {
s = 0f64;
for k in l - 1..n {
s += self.get(j, k) * self.get(i, k);
s += U.get(j, k) * U.get(i, k);
}
for k in l - 1..n {
self.add_element_mut(j, k, s * rv1[k]);
U.add_element_mut(j, k, s * rv1[k]);
}
}
for k in l - 1..n {
self.mul_element_mut(i, k, scale);
U.mul_element_mut(i, k, scale);
}
}
}
@@ -443,12 +442,12 @@ impl Matrix for DenseMatrix {
if i < n - 1 {
if g != 0.0 {
for j in l..n {
v.set(j, i, (self.get(i, j) / self.get(i, l)) / g);
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
}
for j in l..n {
let mut s = 0f64;
for k in l..n {
s += self.get(i, k) * v.get(k, j);
s += U.get(i, k) * v.get(k, j);
}
for k in l..n {
v.add_element_mut(k, j, s * v.get(k, i));
@@ -469,7 +468,7 @@ impl Matrix for DenseMatrix {
l = i + 1;
g = w[i];
for j in l..n {
self.set(i, j, 0f64);
U.set(i, j, 0f64);
}
if g.abs() > math::EPSILON {
@@ -477,23 +476,23 @@ impl Matrix for DenseMatrix {
for j in l..n {
let mut s = 0f64;
for k in l..m {
s += self.get(k, i) * self.get(k, j);
s += U.get(k, i) * U.get(k, j);
}
let f = (s / self.get(i, i)) * g;
let f = (s / U.get(i, i)) * g;
for k in i..m {
self.add_element_mut(k, j, f * self.get(k, i));
U.add_element_mut(k, j, f * U.get(k, i));
}
}
for j in i..m {
self.mul_element_mut(j, i, g);
U.mul_element_mut(j, i, g);
}
} else {
for j in i..m {
self.set(j, i, 0f64);
U.set(j, i, 0f64);
}
}
self.add_element_mut(i, i, 1f64);
U.add_element_mut(i, i, 1f64);
}
for k in (0..n).rev() {
@@ -528,10 +527,10 @@ impl Matrix for DenseMatrix {
c = g * h;
s = -f * h;
for j in 0..m {
let y = self.get(j, nm);
let z = self.get(j, i);
self.set(j, nm, y * c + z * s);
self.set(j, i, z * c - y * s);
let y = U.get(j, nm);
let z = U.get(j, i);
U.set(j, nm, y * c + z * s);
U.set(j, i, z * c - y * s);
}
}
}
@@ -595,10 +594,10 @@ impl Matrix for DenseMatrix {
f = c * g + s * y;
x = c * y - s * g;
for jj in 0..m {
y = self.get(jj, j);
z = self.get(jj, i);
self.set(jj, j, y * c + z * s);
self.set(jj, i, z * c - y * s);
y = U.get(jj, j);
z = U.get(jj, i);
U.set(jj, j, y * c + z * s);
U.set(jj, i, z * c - y * s);
}
}
@@ -625,7 +624,7 @@ impl Matrix for DenseMatrix {
for i in inc..n {
let sw = w[i];
for k in 0..m {
su[k] = self.get(k, i);
su[k] = U.get(k, i);
}
for k in 0..n {
sv[k] = v.get(k, i);
@@ -634,7 +633,7 @@ impl Matrix for DenseMatrix {
while w[j - inc] < sw {
w[j] = w[j - inc];
for k in 0..m {
self.set(k, j, self.get(k, j - inc));
U.set(k, j, U.get(k, j - inc));
}
for k in 0..n {
v.set(k, j, v.get(k, j - inc));
@@ -646,7 +645,7 @@ impl Matrix for DenseMatrix {
}
w[j] = sw;
for k in 0..m {
self.set(k, j, su[k]);
U.set(k, j, su[k]);
}
for k in 0..n {
v.set(k, j, sv[k]);
@@ -661,7 +660,7 @@ impl Matrix for DenseMatrix {
for k in 0..n {
let mut s = 0.;
for i in 0..m {
if self.get(i, k) < 0. {
if U.get(i, k) < 0. {
s += 1.;
}
}
@@ -672,43 +671,17 @@ impl Matrix for DenseMatrix {
}
if s > (m + n) as f64 / 2. {
for i in 0..m {
self.set(i, k, -self.get(i, k));
U.set(i, k, -U.get(i, k));
}
for j in 0..n {
v.set(j, k, -v.get(j, k));
}
}
}
}
let tol = 0.5 * ((m + n) as f64 + 1.).sqrt() * w[0] * math::EPSILON;
SVD::new(U, v, w)
let p = b.ncols;
for k in 0..p {
let mut tmp = vec![0f64; v.nrows];
for j in 0..n {
let mut r = 0f64;
if w[j] > tol {
for i in 0..m {
r += self.get(i, j) * b.get(i, k);
}
r /= w[j];
}
tmp[j] = r;
}
for j in 0..n {
let mut r = 0.0;
for jj in 0..n {
r += v.get(j, jj) * tmp[jj];
}
b.set(j, k, r);
}
}
b
}
}
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
@@ -1007,17 +980,7 @@ mod tests {
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.qr_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
#[test]
fn svd_solve_mut() {
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::new(3, 2, vec![-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.svd_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
}
#[test]
fn h_stack() {