fix: cargo fmt
This commit is contained in:
+221
-74
@@ -12,17 +12,16 @@ pub struct SVD<T: FloatExt, M: SVDDecomposableMatrix<T>> {
|
||||
full: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
tol: T
|
||||
tol: T,
|
||||
}
|
||||
|
||||
pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
|
||||
fn svd_solve_mut(self, b: Self) -> Self {
|
||||
self.svd_mut().solve(b)
|
||||
}
|
||||
|
||||
fn svd_solve(&self, b: Self) -> Self {
|
||||
self.svd().solve(b)
|
||||
self.svd().solve(b)
|
||||
}
|
||||
|
||||
fn svd(&self) -> SVD<T, Self> {
|
||||
@@ -30,14 +29,13 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
}
|
||||
|
||||
fn svd_mut(self) -> SVD<T, Self> {
|
||||
let mut U = self;
|
||||
|
||||
let mut U = self;
|
||||
let (m, n) = U.shape();
|
||||
|
||||
let (m, n) = U.shape();
|
||||
|
||||
let (mut l, mut nm) = (0usize, 0usize);
|
||||
let (mut anorm, mut g, mut scale) = (T::zero(), T::zero(), T::zero());
|
||||
|
||||
let (mut anorm, mut g, mut scale) = (T::zero(), T::zero(), T::zero());
|
||||
|
||||
let mut v = Self::zeros(n, n);
|
||||
let mut w = vec![T::zero(); n];
|
||||
let mut rv1 = vec![T::zero(); n];
|
||||
@@ -55,7 +53,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
}
|
||||
|
||||
if scale.abs() > T::epsilon() {
|
||||
|
||||
for k in i..m {
|
||||
U.div_element_mut(k, i, scale);
|
||||
s = s + U.get(k, i) * U.get(k, i);
|
||||
@@ -98,7 +95,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
}
|
||||
|
||||
let f = U.get(i, l - 1);
|
||||
g = -s.sqrt().copysign(f);
|
||||
g = -s.sqrt().copysign(f);
|
||||
let h = f * g - s;
|
||||
U.set(i, l - 1, f - g);
|
||||
|
||||
@@ -123,7 +120,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
anorm = T::max(anorm, w[i].abs() + rv1[i].abs());
|
||||
}
|
||||
|
||||
@@ -186,11 +182,11 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
|
||||
for k in (0..n).rev() {
|
||||
for iteration in 0..30 {
|
||||
let mut flag = true;
|
||||
let mut flag = true;
|
||||
l = k;
|
||||
while l != 0 {
|
||||
while l != 0 {
|
||||
if l == 0 || rv1[l].abs() <= T::epsilon() * anorm {
|
||||
flag = false;
|
||||
flag = false;
|
||||
break;
|
||||
}
|
||||
nm = l - 1;
|
||||
@@ -203,7 +199,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
if flag {
|
||||
let mut c = T::zero();
|
||||
let mut s = T::one();
|
||||
for i in l..k+1 {
|
||||
for i in l..k + 1 {
|
||||
let f = s * rv1[i];
|
||||
rv1[i] = c * rv1[i];
|
||||
if f.abs() <= T::epsilon() * anorm {
|
||||
@@ -219,7 +215,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
let y = U.get(j, nm);
|
||||
let z = U.get(j, i);
|
||||
U.set(j, nm, y * c + z * s);
|
||||
U.set(j, i, z * c - y * s);
|
||||
U.set(j, i, z * c - y * s);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -295,11 +291,11 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
w[k] = x;
|
||||
}
|
||||
}
|
||||
|
||||
let mut inc = 1usize;
|
||||
|
||||
let mut inc = 1usize;
|
||||
let mut su = vec![T::zero(); m];
|
||||
let mut sv = vec![T::zero(); n];
|
||||
|
||||
|
||||
loop {
|
||||
inc *= 3;
|
||||
inc += 1;
|
||||
@@ -310,7 +306,7 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
|
||||
loop {
|
||||
inc /= 3;
|
||||
for i in inc..n {
|
||||
for i in inc..n {
|
||||
let sw = w[i];
|
||||
for k in 0..m {
|
||||
su[k] = U.get(k, i);
|
||||
@@ -339,7 +335,6 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
for k in 0..n {
|
||||
v.set(k, j, sv[k]);
|
||||
}
|
||||
|
||||
}
|
||||
if inc <= 1 {
|
||||
break;
|
||||
@@ -366,18 +361,17 @@ pub trait SVDDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
v.set(j, k, -v.get(j, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SVD::new(U, v, w)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
pub fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
|
||||
let m = U.shape().0;
|
||||
let n = V.shape().0;
|
||||
let full = s.len() == m.min(n);
|
||||
let n = V.shape().0;
|
||||
let full = s.len() == m.min(n);
|
||||
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
|
||||
SVD {
|
||||
U: U,
|
||||
@@ -386,7 +380,7 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
full: full,
|
||||
m: m,
|
||||
n: n,
|
||||
tol: tol
|
||||
tol: tol,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,7 +388,11 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
let p = b.shape().1;
|
||||
|
||||
if self.U.shape().0 != b.shape().0 {
|
||||
panic!("Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}", self.U.shape().0, b.shape().0);
|
||||
panic!(
|
||||
"Dimensions do not agree. U.nrows should equal b.nrows but is {}, {}",
|
||||
self.U.shape().0,
|
||||
b.shape().0
|
||||
);
|
||||
}
|
||||
|
||||
for k in 0..p {
|
||||
@@ -424,30 +422,30 @@ impl<T: FloatExt, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn decompose_symmetric() {
|
||||
|
||||
fn decompose_symmetric() {
|
||||
let A = DenseMatrix::from_array(&[
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.7000, 0.3000, 0.8000]]);
|
||||
&[0.7000, 0.3000, 0.8000],
|
||||
]);
|
||||
|
||||
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||
|
||||
let U = DenseMatrix::from_array(&[
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.639158]
|
||||
&[0.6240573, -0.44947578, -0.639158],
|
||||
]);
|
||||
|
||||
let V = DenseMatrix::from_array(&[
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.6391588]
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
@@ -457,43 +455,199 @@ mod tests {
|
||||
for i in 0..s.len() {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decompose_asymmetric() {
|
||||
|
||||
fn decompose_asymmetric() {
|
||||
let A = DenseMatrix::from_array(&[
|
||||
&[1.19720880, -1.8391378, 0.3019585, -1.1165701, -1.7210814, 0.4918882, -0.04247433],
|
||||
&[0.06605075, 1.0315583, 0.8294362, -0.3646043, -1.6038017, -0.9188110, -0.63760340],
|
||||
&[-1.02637715, 1.0747931, -0.8089055, -0.4726863, -0.2064826, -0.3325532, 0.17966051],
|
||||
&[-1.45817729, -0.8942353, 0.3459245, 1.5068363, -2.0180708, -0.3696350, -1.19575563],
|
||||
&[-0.07318103, -0.2783787, 1.2237598, 0.1995332, 0.2545336, -0.1392502, -1.88207227],
|
||||
&[0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985],
|
||||
&[0.12641406, -0.8710055, -0.2712301, 0.2296515, 1.1781535, -0.2158704, -0.27529472]
|
||||
&[
|
||||
1.19720880,
|
||||
-1.8391378,
|
||||
0.3019585,
|
||||
-1.1165701,
|
||||
-1.7210814,
|
||||
0.4918882,
|
||||
-0.04247433,
|
||||
],
|
||||
&[
|
||||
0.06605075,
|
||||
1.0315583,
|
||||
0.8294362,
|
||||
-0.3646043,
|
||||
-1.6038017,
|
||||
-0.9188110,
|
||||
-0.63760340,
|
||||
],
|
||||
&[
|
||||
-1.02637715,
|
||||
1.0747931,
|
||||
-0.8089055,
|
||||
-0.4726863,
|
||||
-0.2064826,
|
||||
-0.3325532,
|
||||
0.17966051,
|
||||
],
|
||||
&[
|
||||
-1.45817729,
|
||||
-0.8942353,
|
||||
0.3459245,
|
||||
1.5068363,
|
||||
-2.0180708,
|
||||
-0.3696350,
|
||||
-1.19575563,
|
||||
],
|
||||
&[
|
||||
-0.07318103,
|
||||
-0.2783787,
|
||||
1.2237598,
|
||||
0.1995332,
|
||||
0.2545336,
|
||||
-0.1392502,
|
||||
-1.88207227,
|
||||
],
|
||||
&[
|
||||
0.88248425, -0.9360321, 0.1393172, 0.1393281, -0.3277873, -0.5553013, 1.63805985,
|
||||
],
|
||||
&[
|
||||
0.12641406,
|
||||
-0.8710055,
|
||||
-0.2712301,
|
||||
0.2296515,
|
||||
1.1781535,
|
||||
-0.2158704,
|
||||
-0.27529472,
|
||||
],
|
||||
]);
|
||||
|
||||
let s: Vec<f64> = vec![3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515];
|
||||
let s: Vec<f64> = vec![
|
||||
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
|
||||
];
|
||||
|
||||
let U = DenseMatrix::from_array(&[
|
||||
&[-0.3082776, 0.77676231, 0.01330514, 0.23231424, -0.47682758, 0.13927109, 0.02640713],
|
||||
&[-0.4013477, -0.09112050, 0.48754440, 0.47371793, 0.40636608, 0.24600706, -0.37796295],
|
||||
&[0.0599719, -0.31406586, 0.45428229, -0.08071283, -0.38432597, 0.57320261, 0.45673993],
|
||||
&[-0.7694214, -0.12681435, -0.05536793, -0.62189972, -0.02075522, -0.01724911, -0.03681864],
|
||||
&[-0.3319069, -0.17984404, -0.54466777, 0.45335157, 0.19377726, 0.12333423, 0.55003852],
|
||||
&[0.1259351, 0.49087824, 0.16349687, -0.32080176, 0.64828744, 0.20643772, 0.38812467],
|
||||
&[0.1491884, 0.01768604, -0.47884363, -0.14108924, 0.03922507, 0.73034065, -0.43965505]
|
||||
&[
|
||||
-0.3082776,
|
||||
0.77676231,
|
||||
0.01330514,
|
||||
0.23231424,
|
||||
-0.47682758,
|
||||
0.13927109,
|
||||
0.02640713,
|
||||
],
|
||||
&[
|
||||
-0.4013477,
|
||||
-0.09112050,
|
||||
0.48754440,
|
||||
0.47371793,
|
||||
0.40636608,
|
||||
0.24600706,
|
||||
-0.37796295,
|
||||
],
|
||||
&[
|
||||
0.0599719,
|
||||
-0.31406586,
|
||||
0.45428229,
|
||||
-0.08071283,
|
||||
-0.38432597,
|
||||
0.57320261,
|
||||
0.45673993,
|
||||
],
|
||||
&[
|
||||
-0.7694214,
|
||||
-0.12681435,
|
||||
-0.05536793,
|
||||
-0.62189972,
|
||||
-0.02075522,
|
||||
-0.01724911,
|
||||
-0.03681864,
|
||||
],
|
||||
&[
|
||||
-0.3319069,
|
||||
-0.17984404,
|
||||
-0.54466777,
|
||||
0.45335157,
|
||||
0.19377726,
|
||||
0.12333423,
|
||||
0.55003852,
|
||||
],
|
||||
&[
|
||||
0.1259351,
|
||||
0.49087824,
|
||||
0.16349687,
|
||||
-0.32080176,
|
||||
0.64828744,
|
||||
0.20643772,
|
||||
0.38812467,
|
||||
],
|
||||
&[
|
||||
0.1491884,
|
||||
0.01768604,
|
||||
-0.47884363,
|
||||
-0.14108924,
|
||||
0.03922507,
|
||||
0.73034065,
|
||||
-0.43965505,
|
||||
],
|
||||
]);
|
||||
|
||||
let V = DenseMatrix::from_array(&[
|
||||
&[-0.2122609, -0.54650056, 0.08071332, -0.43239135, -0.2925067, 0.1414550, 0.59769207],
|
||||
&[-0.1943605, 0.63132116, -0.54059857, -0.37089970, -0.1363031, 0.2892641, 0.17774114],
|
||||
&[0.3031265, -0.06182488, 0.18579097, -0.38606409, -0.5364911, 0.2983466, -0.58642548],
|
||||
&[0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098],
|
||||
&[-0.7164205, 0.30694911, 0.58264743, -0.07458095, -0.1142140, -0.1311972, -0.13124764],
|
||||
&[-0.1103067, -0.10633600, 0.18257905, -0.03638501, 0.5722925, 0.7784398, -0.09153611],
|
||||
&[-0.5156083, -0.36573746, -0.47613340, 0.41342817, -0.2659765, 0.1654796, -0.32346758]
|
||||
]);
|
||||
&[
|
||||
-0.2122609,
|
||||
-0.54650056,
|
||||
0.08071332,
|
||||
-0.43239135,
|
||||
-0.2925067,
|
||||
0.1414550,
|
||||
0.59769207,
|
||||
],
|
||||
&[
|
||||
-0.1943605,
|
||||
0.63132116,
|
||||
-0.54059857,
|
||||
-0.37089970,
|
||||
-0.1363031,
|
||||
0.2892641,
|
||||
0.17774114,
|
||||
],
|
||||
&[
|
||||
0.3031265,
|
||||
-0.06182488,
|
||||
0.18579097,
|
||||
-0.38606409,
|
||||
-0.5364911,
|
||||
0.2983466,
|
||||
-0.58642548,
|
||||
],
|
||||
&[
|
||||
0.1844063, 0.24425278, 0.25923756, 0.59043765, -0.4435443, 0.3959057, 0.37019098,
|
||||
],
|
||||
&[
|
||||
-0.7164205,
|
||||
0.30694911,
|
||||
0.58264743,
|
||||
-0.07458095,
|
||||
-0.1142140,
|
||||
-0.1311972,
|
||||
-0.13124764,
|
||||
],
|
||||
&[
|
||||
-0.1103067,
|
||||
-0.10633600,
|
||||
0.18257905,
|
||||
-0.03638501,
|
||||
0.5722925,
|
||||
0.7784398,
|
||||
-0.09153611,
|
||||
],
|
||||
&[
|
||||
-0.5156083,
|
||||
-0.36573746,
|
||||
-0.47613340,
|
||||
0.41342817,
|
||||
-0.2659765,
|
||||
0.1654796,
|
||||
-0.32346758,
|
||||
],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
|
||||
@@ -502,21 +656,14 @@ mod tests {
|
||||
for i in 0..s.len() {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn solve() {
|
||||
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let expected_w = DenseMatrix::from_array(&[
|
||||
&[-0.20, -1.28],
|
||||
&[0.87, 2.22],
|
||||
&[0.47, 0.66]
|
||||
]);
|
||||
let w = a.svd_solve_mut(b);
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
fn solve() {
|
||||
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let expected_w = DenseMatrix::from_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||
let w = a.svd_solve_mut(b);
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user