Adds SVD solver, code refactoring

This commit is contained in:
Volodymyr Orlov
2019-10-16 08:28:36 -07:00
parent 50744208a9
commit f4aec2b35e
9 changed files with 422 additions and 26 deletions
+3 -1
View File
@@ -2,12 +2,14 @@ use std::ops::Range;
pub mod naive;
pub trait Matrix: Into<Vec<f64>>{
pub trait Matrix: Into<Vec<f64>> + Clone{
fn get(&self, row: usize, col: usize) -> f64;
fn qr_solve_mut(&mut self, b: Self) -> Self;
fn svd_solve_mut(&mut self, b: Self) -> Self;
fn zeros(nrows: usize, ncols: usize) -> Self;
fn ones(nrows: usize, ncols: usize) -> Self;
+396 -6
View File
@@ -2,7 +2,7 @@ use std::ops::Range;
use crate::linalg::Matrix;
use crate::math;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct DenseMatrix {
ncols: usize,
@@ -63,6 +63,10 @@ impl DenseMatrix {
self.values[col*self.nrows + row] /= x;
}
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64) {
self.values[col*self.nrows + row] *= x;
}
fn add_element_mut(&mut self, row: usize, col: usize, x: f64) {
self.values[col*self.nrows + row] += x
}
@@ -87,7 +91,7 @@ impl PartialEq for DenseMatrix {
}
for i in 0..len {
if (self.values[i] - other.values[i]).abs() > math::SMALL_ERROR {
if (self.values[i] - other.values[i]).abs() > math::EPSILON {
return false;
}
}
@@ -195,6 +199,10 @@ impl Matrix for DenseMatrix {
let n = self.ncols;
let nrhs = b.ncols;
if self.nrows != b.nrows {
panic!("Dimensions do not agree. Self.nrows should equal b.nrows but is {}, {}", self.nrows, b.nrows);
}
let mut r_diagonal: Vec<f64> = vec![0f64; n];
for k in 0..n {
@@ -203,7 +211,7 @@ impl Matrix for DenseMatrix {
nrm = nrm.hypot(self.get(i, k));
}
if nrm > math::SMALL_ERROR {
if nrm.abs() > math::EPSILON {
if self.get(k, k) < 0f64 {
nrm = -nrm;
@@ -228,7 +236,7 @@ impl Matrix for DenseMatrix {
}
for j in 0..r_diagonal.len() {
if r_diagonal[j].abs() < math::SMALL_ERROR {
if r_diagonal[j].abs() < math::EPSILON {
panic!("Matrix is rank deficient.");
}
}
@@ -262,6 +270,378 @@ impl Matrix for DenseMatrix {
}
fn svd_solve_mut(&mut self, mut b: DenseMatrix) -> DenseMatrix {
if self.nrows != b.nrows {
panic!("Dimensions do not agree. Self.nrows should equal b.nrows but is {}, {}", self.nrows, b.nrows);
}
let m = self.nrows;
let n = self.ncols;
let (mut l, mut nm) = (0usize, 0usize);
let (mut anorm, mut g, mut scale) = (0f64, 0f64, 0f64);
let mut v = DenseMatrix::zeros(n, n);
let mut w = vec![0f64; n];
let mut rv1 = vec![0f64; n];
for i in 0..n {
l = i + 2;
rv1[i] = scale * g;
g = 0f64;
let mut s = 0f64;
scale = 0f64;
if i < m {
for k in i..m {
scale += self.get(k, i).abs();
}
if scale.abs() > math::EPSILON {
for k in i..m {
self.div_element_mut(k, i, scale);
s += self.get(k, i) * self.get(k, i);
}
let mut f = self.get(i, i);
g = -s.sqrt().copysign(f);
let h = f * g - s;
self.set(i, i, f - g);
for j in l - 1..n {
s = 0f64;
for k in i..m {
s += self.get(k, i) * self.get(k, j);
}
f = s / h;
for k in i..m {
self.add_element_mut(k, j, f * self.get(k, i));
}
}
for k in i..m {
self.mul_element_mut(k, i, scale);
}
}
}
w[i] = scale * g;
g = 0f64;
let mut s = 0f64;
scale = 0f64;
if i + 1 <= m && i + 1 != n {
for k in l - 1..n {
scale += self.get(i, k).abs();
}
if scale.abs() > math::EPSILON {
for k in l - 1..n {
self.div_element_mut(i, k, scale);
s += self.get(i, k) * self.get(i, k);
}
let f = self.get(i, l - 1);
g = -s.sqrt().copysign(f);
let h = f * g - s;
self.set(i, l - 1, f - g);
for k in l - 1..n {
rv1[k] = self.get(i, k) / h;
}
for j in l - 1..m {
s = 0f64;
for k in l - 1..n {
s += self.get(j, k) * self.get(i, k);
}
for k in l - 1..n {
self.add_element_mut(j, k, s * rv1[k]);
}
}
for k in l - 1..n {
self.mul_element_mut(i, k, scale);
}
}
}
anorm = f64::max(anorm, w[i].abs() + rv1[i].abs());
}
for i in (0..n).rev() {
if i < n - 1 {
if g != 0.0 {
for j in l..n {
v.set(j, i, (self.get(i, j) / self.get(i, l)) / g);
}
for j in l..n {
let mut s = 0f64;
for k in l..n {
s += self.get(i, k) * v.get(k, j);
}
for k in l..n {
v.add_element_mut(k, j, s * v.get(k, i));
}
}
}
for j in l..n {
v.set(i, j, 0f64);
v.set(j, i, 0f64);
}
}
v.set(i, i, 1.0);
g = rv1[i];
l = i;
}
for i in (0..usize::min(m, n)).rev() {
l = i + 1;
g = w[i];
for j in l..n {
self.set(i, j, 0f64);
}
if g.abs() > math::EPSILON {
g = 1f64 / g;
for j in l..n {
let mut s = 0f64;
for k in l..m {
s += self.get(k, i) * self.get(k, j);
}
let f = (s / self.get(i, i)) * g;
for k in i..m {
self.add_element_mut(k, j, f * self.get(k, i));
}
}
for j in i..m {
self.mul_element_mut(j, i, g);
}
} else {
for j in i..m {
self.set(j, i, 0f64);
}
}
self.add_element_mut(i, i, 1f64);
}
for k in (0..n).rev() {
for iteration in 0..30 {
let mut flag = true;
l = k;
while l != 0 {
if l == 0 || rv1[l].abs() <= math::EPSILON * anorm {
flag = false;
break;
}
nm = l - 1;
if w[nm].abs() <= math::EPSILON * anorm {
break;
}
l -= 1;
}
if flag {
let mut c = 0.0;
let mut s = 1.0;
for i in l..k+1 {
let f = s * rv1[i];
rv1[i] = c * rv1[i];
if f.abs() <= math::EPSILON * anorm {
break;
}
g = w[i];
let mut h = f.hypot(g);
w[i] = h;
h = 1.0 / h;
c = g * h;
s = -f * h;
for j in 0..m {
let y = self.get(j, nm);
let z = self.get(j, i);
self.set(j, nm, y * c + z * s);
self.set(j, i, z * c - y * s);
}
}
}
let z = w[k];
if l == k {
if z < 0f64 {
w[k] = -z;
for j in 0..n {
v.set(j, k, -v.get(j, k));
}
}
break;
}
if iteration == 29 {
panic!("no convergence in 30 iterations");
}
let mut x = w[l];
nm = k - 1;
let mut y = w[nm];
g = rv1[nm];
let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y);
g = f.hypot(1.0);
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
let mut c = 1f64;
let mut s = 1f64;
for j in l..=nm {
let i = j + 1;
g = rv1[i];
y = w[i];
h = s * g;
g = c * g;
let mut z = f.hypot(h);
rv1[j] = z;
c = f / z;
s = h / z;
f = x * c + g * s;
g = g * c - x * s;
h = y * s;
y *= c;
for jj in 0..n {
x = v.get(jj, j);
z = v.get(jj, i);
v.set(jj, j, x * c + z * s);
v.set(jj, i, z * c - x * s);
}
z = f.hypot(h);
w[j] = z;
if z.abs() > math::EPSILON {
z = 1.0 / z;
c = f * z;
s = h * z;
}
f = c * g + s * y;
x = c * y - s * g;
for jj in 0..m {
y = self.get(jj, j);
z = self.get(jj, i);
self.set(jj, j, y * c + z * s);
self.set(jj, i, z * c - y * s);
}
}
rv1[l] = 0.0;
rv1[k] = f;
w[k] = x;
}
}
let mut inc = 1usize;
let mut su = vec![0f64; m];
let mut sv = vec![0f64; n];
loop {
inc *= 3;
inc += 1;
if inc > n {
break;
}
}
loop {
inc /= 3;
for i in inc..n {
let sw = w[i];
for k in 0..m {
su[k] = self.get(k, i);
}
for k in 0..n {
sv[k] = v.get(k, i);
}
let mut j = i;
while w[j - inc] < sw {
w[j] = w[j - inc];
for k in 0..m {
self.set(k, j, self.get(k, j - inc));
}
for k in 0..n {
v.set(k, j, v.get(k, j - inc));
}
j -= inc;
if j < inc {
break;
}
}
w[j] = sw;
for k in 0..m {
self.set(k, j, su[k]);
}
for k in 0..n {
v.set(k, j, sv[k]);
}
}
if inc <= 1 {
break;
}
}
for k in 0..n {
let mut s = 0.;
for i in 0..m {
if self.get(i, k) < 0. {
s += 1.;
}
}
for j in 0..n {
if v.get(j, k) < 0. {
s += 1.;
}
}
if s > (m + n) as f64 / 2. {
for i in 0..m {
self.set(i, k, -self.get(i, k));
}
for j in 0..n {
v.set(j, k, -v.get(j, k));
}
}
}
let tol = 0.5 * ((m + n) as f64 + 1.).sqrt() * w[0] * math::EPSILON;
let p = b.ncols;
for k in 0..p {
let mut tmp = vec![0f64; v.nrows];
for j in 0..n {
let mut r = 0f64;
if w[j] > tol {
for i in 0..m {
r += self.get(i, j) * b.get(i, k);
}
r /= w[j];
}
tmp[j] = r;
}
for j in 0..n {
let mut r = 0.0;
for jj in 0..n {
r += v.get(j, jj) * tmp[jj];
}
b.set(j, k, r);
}
}
b
}
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
return false
@@ -304,9 +684,19 @@ mod tests {
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20270270270270263, 0.8783783783783784, 0.4729729729729729, -1.2837837837837829, 2.2297297297297303, 0.6621621621621613]);
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.qr_solve_mut(b);
assert_eq!(w, expected_w);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
#[test]
fn svd_solve_mut() {
let mut a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(3, 2, &[-0.20, 0.87, 0.47, -1.28, 2.22, 0.66]);
let w = a.svd_solve_mut(b);
assert!(w.approximate_eq(&expected_w, 1e-2));
}
#[test]