feat: refactors matrix decomposition routines
This commit is contained in:
@@ -243,7 +243,7 @@ impl<M: Matrix> LogisticRegression<M> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use ndarray::{arr1, arr2, Array};
|
use ndarray::{arr1, arr2, Array};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ impl<M: Matrix> PCA<M> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
fn us_arrests_data() -> DenseMatrix {
|
fn us_arrests_data() -> DenseMatrix {
|
||||||
DenseMatrix::from_array(&[
|
DenseMatrix::from_array(&[
|
||||||
|
|||||||
+760
-3
@@ -1,13 +1,14 @@
|
|||||||
use crate::linalg::{Matrix};
|
use num::complex::Complex;
|
||||||
|
use crate::linalg::BaseMatrix;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct EVD<M: Matrix> {
|
pub struct EVD<M: BaseMatrix> {
|
||||||
pub d: Vec<f64>,
|
pub d: Vec<f64>,
|
||||||
pub e: Vec<f64>,
|
pub e: Vec<f64>,
|
||||||
pub V: M
|
pub V: M
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: Matrix> EVD<M> {
|
impl<M: BaseMatrix> EVD<M> {
|
||||||
pub fn new(V: M, d: Vec<f64>, e: Vec<f64>) -> EVD<M> {
|
pub fn new(V: M, d: Vec<f64>, e: Vec<f64>) -> EVD<M> {
|
||||||
EVD {
|
EVD {
|
||||||
d: d,
|
d: d,
|
||||||
@@ -17,6 +18,762 @@ impl<M: Matrix> EVD<M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait EVDDecomposableMatrix: BaseMatrix {
|
||||||
|
|
||||||
|
fn evd(&self, symmetric: bool) -> EVD<Self>{
|
||||||
|
self.clone().evd_mut(symmetric)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn evd_mut(mut self, symmetric: bool) -> EVD<Self>{
|
||||||
|
let(nrows, ncols) = self.shape();
|
||||||
|
if ncols != nrows {
|
||||||
|
panic!("Matrix is not square: {} x {}", nrows, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = nrows;
|
||||||
|
let mut d = vec![0f64; n];
|
||||||
|
let mut e = vec![0f64; n];
|
||||||
|
|
||||||
|
let mut V;
|
||||||
|
if symmetric {
|
||||||
|
V = self;
|
||||||
|
// Tridiagonalize.
|
||||||
|
tred2(&mut V, &mut d, &mut e);
|
||||||
|
// Diagonalize.
|
||||||
|
tql2(&mut V, &mut d, &mut e);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
let scale = balance(&mut self);
|
||||||
|
|
||||||
|
let perm = elmhes(&mut self);
|
||||||
|
|
||||||
|
V = Self::eye(n);
|
||||||
|
|
||||||
|
eltran(&self, &mut V, &perm);
|
||||||
|
|
||||||
|
hqr2(&mut self, &mut V, &mut d, &mut e);
|
||||||
|
balbak(&mut V, &scale);
|
||||||
|
sort(&mut d, &mut e, &mut V);
|
||||||
|
}
|
||||||
|
|
||||||
|
EVD {
|
||||||
|
V: V,
|
||||||
|
d: d,
|
||||||
|
e: e
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tred2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
|
||||||
|
|
||||||
|
let (n, _) = V.shape();
|
||||||
|
for i in 0..n {
|
||||||
|
d[i] = V.get(n - 1, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Householder reduction to tridiagonal form.
|
||||||
|
for i in (1..n).rev() {
|
||||||
|
// Scale to avoid under/overflow.
|
||||||
|
let mut scale = 0f64;
|
||||||
|
let mut h = 0f64;
|
||||||
|
for k in 0..i {
|
||||||
|
scale = scale + d[k].abs();
|
||||||
|
}
|
||||||
|
if scale == 0f64 {
|
||||||
|
e[i] = d[i - 1];
|
||||||
|
for j in 0..i {
|
||||||
|
d[j] = V.get(i - 1, j);
|
||||||
|
V.set(i, j, 0.0);
|
||||||
|
V.set(j, i, 0.0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Generate Householder vector.
|
||||||
|
for k in 0..i {
|
||||||
|
d[k] /= scale;
|
||||||
|
h += d[k] * d[k];
|
||||||
|
}
|
||||||
|
let mut f = d[i - 1];
|
||||||
|
let mut g = h.sqrt();
|
||||||
|
if f > 0f64 {
|
||||||
|
g = -g;
|
||||||
|
}
|
||||||
|
e[i] = scale * g;
|
||||||
|
h = h - f * g;
|
||||||
|
d[i - 1] = f - g;
|
||||||
|
for j in 0..i {
|
||||||
|
e[j] = 0f64;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply similarity transformation to remaining columns.
|
||||||
|
for j in 0..i {
|
||||||
|
f = d[j];
|
||||||
|
V.set(j, i, f);
|
||||||
|
g = e[j] + V.get(j, j) * f;
|
||||||
|
for k in j + 1..=i - 1 {
|
||||||
|
g += V.get(k, j) * d[k];
|
||||||
|
e[k] += V.get(k, j) * f;
|
||||||
|
}
|
||||||
|
e[j] = g;
|
||||||
|
}
|
||||||
|
f = 0.0;
|
||||||
|
for j in 0..i {
|
||||||
|
e[j] /= h;
|
||||||
|
f += e[j] * d[j];
|
||||||
|
}
|
||||||
|
let hh = f / (h + h);
|
||||||
|
for j in 0..i {
|
||||||
|
e[j] -= hh * d[j];
|
||||||
|
}
|
||||||
|
for j in 0..i {
|
||||||
|
f = d[j];
|
||||||
|
g = e[j];
|
||||||
|
for k in j..=i-1 {
|
||||||
|
V.sub_element_mut(k, j, f * e[k] + g * d[k]);
|
||||||
|
}
|
||||||
|
d[j] = V.get(i - 1, j);
|
||||||
|
V.set(i, j, 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d[i] = h;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate transformations.
|
||||||
|
for i in 0..n-1 {
|
||||||
|
V.set(n - 1, i, V.get(i, i));
|
||||||
|
V.set(i, i, 1.0);
|
||||||
|
let h = d[i + 1];
|
||||||
|
if h != 0f64 {
|
||||||
|
for k in 0..=i {
|
||||||
|
d[k] = V.get(k, i + 1) / h;
|
||||||
|
}
|
||||||
|
for j in 0..=i {
|
||||||
|
let mut g = 0f64;
|
||||||
|
for k in 0..=i {
|
||||||
|
g += V.get(k, i + 1) * V.get(k, j);
|
||||||
|
}
|
||||||
|
for k in 0..=i {
|
||||||
|
V.sub_element_mut(k, j, g * d[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k in 0..=i {
|
||||||
|
V.set(k, i + 1, 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
d[j] = V.get(n - 1, j);
|
||||||
|
V.set(n - 1, j, 0.0);
|
||||||
|
}
|
||||||
|
V.set(n - 1, n - 1, 1.0);
|
||||||
|
e[0] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tql2<M: BaseMatrix>(V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
|
||||||
|
let (n, _) = V.shape();
|
||||||
|
for i in 1..n {
|
||||||
|
e[i - 1] = e[i];
|
||||||
|
}
|
||||||
|
e[n - 1] = 0f64;
|
||||||
|
|
||||||
|
let mut f = 0f64;
|
||||||
|
let mut tst1 = 0f64;
|
||||||
|
for l in 0..n {
|
||||||
|
// Find small subdiagonal element
|
||||||
|
tst1 = f64::max(tst1, d[l].abs() + e[l].abs());
|
||||||
|
|
||||||
|
let mut m = l;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if m < n {
|
||||||
|
if e[m].abs() <= tst1 * std::f64::EPSILON {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
m += 1;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If m == l, d[l] is an eigenvalue,
|
||||||
|
// otherwise, iterate.
|
||||||
|
if m > l {
|
||||||
|
let mut iter = 0;
|
||||||
|
loop {
|
||||||
|
iter += 1;
|
||||||
|
if iter >= 30 {
|
||||||
|
panic!("Too many iterations");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute implicit shift
|
||||||
|
let mut g = d[l];
|
||||||
|
let mut p = (d[l + 1] - g) / (2.0 * e[l]);
|
||||||
|
let mut r = p.hypot(1.0);
|
||||||
|
if p < 0f64 {
|
||||||
|
r = -r;
|
||||||
|
}
|
||||||
|
d[l] = e[l] / (p + r);
|
||||||
|
d[l + 1] = e[l] * (p + r);
|
||||||
|
let dl1 = d[l + 1];
|
||||||
|
let mut h = g - d[l];
|
||||||
|
for i in l+2..n {
|
||||||
|
d[i] -= h;
|
||||||
|
}
|
||||||
|
f = f + h;
|
||||||
|
|
||||||
|
// Implicit QL transformation.
|
||||||
|
p = d[m];
|
||||||
|
let mut c = 1.0;
|
||||||
|
let mut c2 = c;
|
||||||
|
let mut c3 = c;
|
||||||
|
let el1 = e[l + 1];
|
||||||
|
let mut s = 0.0;
|
||||||
|
let mut s2 = 0.0;
|
||||||
|
for i in (l..m).rev() {
|
||||||
|
c3 = c2;
|
||||||
|
c2 = c;
|
||||||
|
s2 = s;
|
||||||
|
g = c * e[i];
|
||||||
|
h = c * p;
|
||||||
|
r = p.hypot(e[i]);
|
||||||
|
e[i + 1] = s * r;
|
||||||
|
s = e[i] / r;
|
||||||
|
c = p / r;
|
||||||
|
p = c * d[i] - s * g;
|
||||||
|
d[i + 1] = h + s * (c * g + s * d[i]);
|
||||||
|
|
||||||
|
// Accumulate transformation.
|
||||||
|
for k in 0..n {
|
||||||
|
h = V.get(k, i + 1);
|
||||||
|
V.set(k, i + 1, s * V.get(k, i) + c * h);
|
||||||
|
V.set(k, i, c * V.get(k, i) - s * h);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p = -s * s2 * c3 * el1 * e[l] / dl1;
|
||||||
|
e[l] = s * p;
|
||||||
|
d[l] = c * p;
|
||||||
|
|
||||||
|
// Check for convergence.
|
||||||
|
if e[l].abs() <= tst1 * std::f64::EPSILON {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d[l] = d[l] + f;
|
||||||
|
e[l] = 0f64;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort eigenvalues and corresponding vectors.
|
||||||
|
for i in 0..n-1 {
|
||||||
|
let mut k = i;
|
||||||
|
let mut p = d[i];
|
||||||
|
for j in i + 1..n {
|
||||||
|
if d[j] > p {
|
||||||
|
k = j;
|
||||||
|
p = d[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if k != i {
|
||||||
|
d[k] = d[i];
|
||||||
|
d[i] = p;
|
||||||
|
for j in 0..n {
|
||||||
|
p = V.get(j, i);
|
||||||
|
V.set(j, i, V.get(j, k));
|
||||||
|
V.set(j, k, p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn balance<M: BaseMatrix>(A: &mut M) -> Vec<f64> {
|
||||||
|
let radix = 2f64;
|
||||||
|
let sqrdx = radix * radix;
|
||||||
|
|
||||||
|
let (n, _) = A.shape();
|
||||||
|
|
||||||
|
let mut scale = vec![1f64; n];
|
||||||
|
|
||||||
|
let mut done = false;
|
||||||
|
while !done {
|
||||||
|
done = true;
|
||||||
|
for i in 0..n {
|
||||||
|
let mut r = 0f64;
|
||||||
|
let mut c = 0f64;
|
||||||
|
for j in 0..n {
|
||||||
|
if j != i {
|
||||||
|
c += A.get(j, i).abs();
|
||||||
|
r += A.get(i, j).abs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c != 0f64 && r != 0f64 {
|
||||||
|
let mut g = r / radix;
|
||||||
|
let mut f = 1.0;
|
||||||
|
let s = c + r;
|
||||||
|
while c < g {
|
||||||
|
f *= radix;
|
||||||
|
c *= sqrdx;
|
||||||
|
}
|
||||||
|
g = r * radix;
|
||||||
|
while c > g {
|
||||||
|
f /= radix;
|
||||||
|
c /= sqrdx;
|
||||||
|
}
|
||||||
|
if (c + r) / f < 0.95 * s {
|
||||||
|
done = false;
|
||||||
|
g = 1.0 / f;
|
||||||
|
scale[i] *= f;
|
||||||
|
for j in 0..n {
|
||||||
|
A.mul_element_mut(i, j, g);
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
A.mul_element_mut(j, i, f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn elmhes<M: BaseMatrix>(A: &mut M) -> Vec<usize> {
|
||||||
|
let (n, _) = A.shape();
|
||||||
|
let mut perm = vec![0; n];
|
||||||
|
|
||||||
|
for m in 1..n-1 {
|
||||||
|
let mut x = 0f64;
|
||||||
|
let mut i = m;
|
||||||
|
for j in m..n {
|
||||||
|
if A.get(j, m - 1).abs() > x.abs() {
|
||||||
|
x = A.get(j, m - 1);
|
||||||
|
i = j;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
perm[m] = i;
|
||||||
|
if i != m {
|
||||||
|
for j in (m-1)..n {
|
||||||
|
let swap = A.get(i, j);
|
||||||
|
A.set(i, j, A.get(m, j));
|
||||||
|
A.set(m, j, swap);
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
let swap = A.get(j, i);
|
||||||
|
A.set(j, i, A.get(j, m));
|
||||||
|
A.set(j, m, swap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if x != 0f64 {
|
||||||
|
for i in (m + 1)..n {
|
||||||
|
let mut y = A.get(i, m - 1);
|
||||||
|
if y != 0f64 {
|
||||||
|
y /= x;
|
||||||
|
A.set(i, m - 1, y);
|
||||||
|
for j in m..n {
|
||||||
|
A.sub_element_mut(i, j, y * A.get(m, j));
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
A.add_element_mut(j, m, y * A.get(j, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return perm;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn eltran<M: BaseMatrix>(A: &M, V: &mut M, perm: &Vec<usize>) {
|
||||||
|
let (n, _) = A.shape();
|
||||||
|
for mp in (1..n - 1).rev() {
|
||||||
|
for k in mp + 1..n {
|
||||||
|
V.set(k, mp, A.get(k, mp - 1));
|
||||||
|
}
|
||||||
|
let i = perm[mp];
|
||||||
|
if i != mp {
|
||||||
|
for j in mp..n {
|
||||||
|
V.set(mp, j, V.get(i, j));
|
||||||
|
V.set(i, j, 0.0);
|
||||||
|
}
|
||||||
|
V.set(i, mp, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hqr2<M: BaseMatrix>(A: &mut M, V: &mut M, d: &mut Vec<f64>, e: &mut Vec<f64>) {
|
||||||
|
let (n, _) = A.shape();
|
||||||
|
let mut z = 0f64;
|
||||||
|
let mut s = 0f64;
|
||||||
|
let mut r = 0f64;
|
||||||
|
let mut q = 0f64;
|
||||||
|
let mut p = 0f64;
|
||||||
|
let mut anorm = 0f64;
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
for j in i32::max(i as i32 - 1, 0)..n as i32 {
|
||||||
|
anorm += A.get(i, j as usize).abs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut nn = n - 1;
|
||||||
|
let mut t = 0.0;
|
||||||
|
'outer: loop {
|
||||||
|
let mut its = 0;
|
||||||
|
loop {
|
||||||
|
let mut l = nn;
|
||||||
|
while l > 0 {
|
||||||
|
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
|
||||||
|
if s == 0.0 {
|
||||||
|
s = anorm;
|
||||||
|
}
|
||||||
|
if A.get(l, l - 1).abs() <= std::f64::EPSILON * s {
|
||||||
|
A.set(l, l - 1, 0.0);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
l -= 1;
|
||||||
|
}
|
||||||
|
let mut x = A.get(nn, nn);
|
||||||
|
if l == nn {
|
||||||
|
d[nn] = x + t;
|
||||||
|
A.set(nn, nn, x + t);
|
||||||
|
if nn == 0 {
|
||||||
|
break 'outer;
|
||||||
|
} else {
|
||||||
|
nn -= 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut y = A.get(nn - 1, nn - 1);
|
||||||
|
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
|
||||||
|
if l == nn - 1 {
|
||||||
|
p = 0.5 * (y - x);
|
||||||
|
q = p * p + w;
|
||||||
|
z = q.abs().sqrt();
|
||||||
|
x += t;
|
||||||
|
A.set(nn, nn, x );
|
||||||
|
A.set(nn - 1, nn - 1, y + t);
|
||||||
|
if q >= 0.0 {
|
||||||
|
z = p + z.copysign(p);
|
||||||
|
d[nn - 1] = x + z;
|
||||||
|
d[nn] = x + z;
|
||||||
|
if z != 0.0 {
|
||||||
|
d[nn] = x - w / z;
|
||||||
|
}
|
||||||
|
x = A.get(nn, nn - 1);
|
||||||
|
s = x.abs() + z.abs();
|
||||||
|
p = x / s;
|
||||||
|
q = z / s;
|
||||||
|
r = (p * p + q * q).sqrt();
|
||||||
|
p /= r;
|
||||||
|
q /= r;
|
||||||
|
for j in nn-1..n {
|
||||||
|
z = A.get(nn - 1, j);
|
||||||
|
A.set(nn - 1, j, q * z + p * A.get(nn, j));
|
||||||
|
A.set(nn, j, q * A.get(nn, j) - p * z);
|
||||||
|
}
|
||||||
|
for i in 0..=nn {
|
||||||
|
z = A.get(i, nn - 1);
|
||||||
|
A.set(i, nn - 1, q * z + p * A.get(i, nn));
|
||||||
|
A.set(i, nn, q * A.get(i, nn) - p * z);
|
||||||
|
}
|
||||||
|
for i in 0..n {
|
||||||
|
z = V.get(i, nn - 1);
|
||||||
|
V.set(i, nn - 1, q * z + p * V.get(i, nn));
|
||||||
|
V.set(i, nn, q * V.get(i, nn) - p * z);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
d[nn] = x + p;
|
||||||
|
e[nn] = -z;
|
||||||
|
d[nn - 1] = d[nn];
|
||||||
|
e[nn - 1] = -e[nn];
|
||||||
|
}
|
||||||
|
|
||||||
|
if nn <= 1 {
|
||||||
|
break 'outer;
|
||||||
|
} else {
|
||||||
|
nn -= 2;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if its == 30 {
|
||||||
|
panic!("Too many iterations in hqr");
|
||||||
|
}
|
||||||
|
if its == 10 || its == 20 {
|
||||||
|
t += x;
|
||||||
|
for i in 0..nn+1 {
|
||||||
|
A.sub_element_mut(i, i, x);
|
||||||
|
}
|
||||||
|
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
|
||||||
|
y = 0.75 * s;
|
||||||
|
x = 0.75 * s;
|
||||||
|
w = -0.4375 * s * s;
|
||||||
|
}
|
||||||
|
its += 1;
|
||||||
|
let mut m = nn - 2;
|
||||||
|
while m >= l {
|
||||||
|
z = A.get(m, m);
|
||||||
|
r = x - z;
|
||||||
|
s = y - z;
|
||||||
|
p = (r * s - w) / A.get(m + 1, m) + A.get(m, m + 1);
|
||||||
|
q = A.get(m + 1, m + 1) - z - r - s;
|
||||||
|
r = A.get(m + 2, m + 1);
|
||||||
|
s = p.abs() + q.abs() + r.abs();
|
||||||
|
p /= s;
|
||||||
|
q /= s;
|
||||||
|
r /= s;
|
||||||
|
if m == l {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
|
||||||
|
let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
|
||||||
|
if u <= std::f64::EPSILON * v {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
m -= 1;
|
||||||
|
}
|
||||||
|
for i in m..nn-1 {
|
||||||
|
A.set(i + 2, i , 0.0);
|
||||||
|
if i != m {
|
||||||
|
A.set(i + 2, i - 1, 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k in m..nn {
|
||||||
|
if k != m {
|
||||||
|
p = A.get(k, k - 1);
|
||||||
|
q = A.get(k + 1, k - 1);
|
||||||
|
r = 0.0;
|
||||||
|
if k + 1 != nn {
|
||||||
|
r = A.get(k + 2, k - 1);
|
||||||
|
}
|
||||||
|
x = p.abs() + q.abs() +r.abs();
|
||||||
|
if x != 0.0 {
|
||||||
|
p /= x;
|
||||||
|
q /= x;
|
||||||
|
r /= x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let s = (p * p + q * q + r * r).sqrt().copysign(p);
|
||||||
|
if s != 0.0 {
|
||||||
|
if k == m {
|
||||||
|
if l != m {
|
||||||
|
A.set(k, k - 1, -A.get(k, k - 1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
A.set(k, k - 1, -s * x);
|
||||||
|
}
|
||||||
|
p += s;
|
||||||
|
x = p / s;
|
||||||
|
y = q / s;
|
||||||
|
z = r / s;
|
||||||
|
q /= p;
|
||||||
|
r /= p;
|
||||||
|
for j in k..n {
|
||||||
|
p = A.get(k, j) + q * A.get(k + 1, j);
|
||||||
|
if k + 1 != nn {
|
||||||
|
p += r * A.get(k + 2, j);
|
||||||
|
A.sub_element_mut(k + 2, j, p * z);
|
||||||
|
}
|
||||||
|
A.sub_element_mut(k + 1, j, p * y);
|
||||||
|
A.sub_element_mut(k, j, p * x);
|
||||||
|
}
|
||||||
|
let mmin;
|
||||||
|
if nn < k + 3 {
|
||||||
|
mmin = nn;
|
||||||
|
} else {
|
||||||
|
mmin = k + 3;
|
||||||
|
}
|
||||||
|
for i in 0..mmin+1 {
|
||||||
|
p = x * A.get(i, k) + y * A.get(i, k + 1);
|
||||||
|
if k + 1 != nn {
|
||||||
|
p += z * A.get(i, k + 2);
|
||||||
|
A.sub_element_mut(i, k + 2, p * r);
|
||||||
|
}
|
||||||
|
A.sub_element_mut(i, k + 1, p * q);
|
||||||
|
A.sub_element_mut(i, k, p);
|
||||||
|
}
|
||||||
|
for i in 0..n {
|
||||||
|
p = x * V.get(i, k) + y * V.get(i, k + 1);
|
||||||
|
if k + 1 != nn {
|
||||||
|
p += z * V.get(i, k + 2);
|
||||||
|
V.sub_element_mut(i, k + 2, p * r);
|
||||||
|
}
|
||||||
|
V.sub_element_mut(i, k + 1, p * q);
|
||||||
|
V.sub_element_mut(i, k, p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if l + 1 >= nn {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if anorm != 0f64 {
|
||||||
|
for nn in (0..n).rev() {
|
||||||
|
p = d[nn];
|
||||||
|
q = e[nn];
|
||||||
|
let na = nn.wrapping_sub(1);
|
||||||
|
if q == 0f64 {
|
||||||
|
let mut m = nn;
|
||||||
|
A.set(nn, nn, 1.0);
|
||||||
|
if nn > 0 {
|
||||||
|
let mut i = nn - 1;
|
||||||
|
loop {
|
||||||
|
let w = A.get(i, i) - p;
|
||||||
|
r = 0.0;
|
||||||
|
for j in m..=nn {
|
||||||
|
r += A.get(i, j) * A.get(j, nn);
|
||||||
|
}
|
||||||
|
if e[i] < 0.0 {
|
||||||
|
z = w;
|
||||||
|
s = r;
|
||||||
|
} else {
|
||||||
|
m = i;
|
||||||
|
|
||||||
|
if e[i] == 0.0 {
|
||||||
|
t = w;
|
||||||
|
if t == 0.0 {
|
||||||
|
t = std::f64::EPSILON * anorm;
|
||||||
|
}
|
||||||
|
A.set(i, nn, -r / t);
|
||||||
|
} else {
|
||||||
|
let x = A.get(i, i + 1);
|
||||||
|
let y = A.get(i + 1, i);
|
||||||
|
q = (d[i] - p).powf(2f64) + e[i].powf(2f64);
|
||||||
|
t = (x * s - z * r) / q;
|
||||||
|
A.set(i, nn, t);
|
||||||
|
if x.abs() > z.abs() {
|
||||||
|
A.set(i + 1, nn, (-r - w * t) / x);
|
||||||
|
} else {
|
||||||
|
A.set(i + 1, nn, (-s - y * t) / z);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t = A.get(i, nn).abs();
|
||||||
|
if std::f64::EPSILON * t * t > 1f64 {
|
||||||
|
for j in i..=nn {
|
||||||
|
A.div_element_mut(j, nn, t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i == 0{
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
i -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if q < 0f64 {
|
||||||
|
let mut m = na;
|
||||||
|
if A.get(nn, na).abs() > A.get(na, nn).abs() {
|
||||||
|
A.set(na, na, q / A.get(nn, na));
|
||||||
|
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
|
||||||
|
} else {
|
||||||
|
let temp = Complex::new(0.0, -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q);
|
||||||
|
A.set(na, na, temp.re);
|
||||||
|
A.set(na, nn, temp.im);
|
||||||
|
}
|
||||||
|
A.set(nn, na, 0.0);
|
||||||
|
A.set(nn, nn, 1.0);
|
||||||
|
if nn >= 2 {
|
||||||
|
for i in (0..nn - 1).rev() {
|
||||||
|
let w = A.get(i, i) - p;
|
||||||
|
let mut ra = 0f64;
|
||||||
|
let mut sa = 0f64;
|
||||||
|
for j in m..=nn {
|
||||||
|
ra += A.get(i, j) * A.get(j, na);
|
||||||
|
sa += A.get(i, j) * A.get(j, nn);
|
||||||
|
}
|
||||||
|
if e[i] < 0.0 {
|
||||||
|
z = w;
|
||||||
|
r = ra;
|
||||||
|
s = sa;
|
||||||
|
} else {
|
||||||
|
m = i;
|
||||||
|
if e[i] == 0.0 {
|
||||||
|
let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
|
||||||
|
A.set(i, na, temp.re);
|
||||||
|
A.set(i, nn, temp.im);
|
||||||
|
} else {
|
||||||
|
let x = A.get(i, i + 1);
|
||||||
|
let y = A.get(i + 1, i);
|
||||||
|
let mut vr = (d[i] - p).powf(2f64) + (e[i]).powf(2.0) - q * q;
|
||||||
|
let vi = 2.0 * q * (d[i] - p);
|
||||||
|
if vr == 0.0 && vi == 0.0 {
|
||||||
|
vr = std::f64::EPSILON * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs());
|
||||||
|
}
|
||||||
|
let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi);
|
||||||
|
A.set(i, na, temp.re);
|
||||||
|
A.set(i, nn, temp.im);
|
||||||
|
if x.abs() > z.abs() + q.abs() {
|
||||||
|
A.set(i + 1, na, (-ra - w * A.get(i, na) + q * A.get(i, nn)) / x);
|
||||||
|
A.set(i + 1, nn, (-sa - w * A.get(i, nn) - q * A.get(i, na)) / x);
|
||||||
|
} else {
|
||||||
|
let temp = Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn)) / Complex::new(z, q);
|
||||||
|
A.set(i + 1, na, temp.re);
|
||||||
|
A.set(i + 1, nn, temp.im);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t = f64::max(A.get(i, na).abs(), A.get(i, nn).abs());
|
||||||
|
if std::f64::EPSILON * t * t > 1f64 {
|
||||||
|
for j in i..=nn {
|
||||||
|
A.div_element_mut(j, na, t);
|
||||||
|
A.div_element_mut(j, nn, t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in (0..n).rev() {
|
||||||
|
for i in 0..n {
|
||||||
|
z = 0f64;
|
||||||
|
for k in 0..=j {
|
||||||
|
z += V.get(i, k) * A.get(k, j);
|
||||||
|
}
|
||||||
|
V.set(i, j, z);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn balbak<M: BaseMatrix>(V: &mut M, scale: &Vec<f64>) {
|
||||||
|
let (n, _) = V.shape();
|
||||||
|
for i in 0..n {
|
||||||
|
for j in 0..n {
|
||||||
|
V.mul_element_mut(i, j, scale[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sort<M: BaseMatrix>(d: &mut Vec<f64>, e: &mut Vec<f64>, V: &mut M) {
|
||||||
|
let n = d.len();
|
||||||
|
let mut temp = vec![0f64; n];
|
||||||
|
for j in 1..n {
|
||||||
|
let real = d[j];
|
||||||
|
let img = e[j];
|
||||||
|
for k in 0..n {
|
||||||
|
temp[k] = V.get(k, j);
|
||||||
|
}
|
||||||
|
let mut i = j as i32 - 1;
|
||||||
|
while i >= 0 {
|
||||||
|
if d[i as usize] >= d[j] {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
d[i as usize + 1] = d[i as usize];
|
||||||
|
e[i as usize + 1] = e[i as usize];
|
||||||
|
for k in 0..n {
|
||||||
|
V.set(k, i as usize + 1, V.get(k, i as usize));
|
||||||
|
}
|
||||||
|
i -= 1;
|
||||||
|
}
|
||||||
|
d[i as usize + 1] = real;
|
||||||
|
e[i as usize + 1] = img;
|
||||||
|
for k in 0..n {
|
||||||
|
V.set(k, i as usize + 1, temp[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
+7
-24
@@ -1,14 +1,16 @@
|
|||||||
pub mod naive;
|
pub mod naive;
|
||||||
|
pub mod qr;
|
||||||
pub mod svd;
|
pub mod svd;
|
||||||
pub mod evd;
|
pub mod evd;
|
||||||
pub mod ndarray_bindings;
|
pub mod ndarray_bindings;
|
||||||
|
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use svd::SVD;
|
use svd::SVDDecomposableMatrix;
|
||||||
use evd::EVD;
|
use evd::EVDDecomposableMatrix;
|
||||||
|
use qr::QRDecomposableMatrix;
|
||||||
|
|
||||||
pub trait Matrix: Clone + Debug {
|
pub trait BaseMatrix: Clone + Debug {
|
||||||
|
|
||||||
type RowVector: Clone + Debug;
|
type RowVector: Clone + Debug;
|
||||||
|
|
||||||
@@ -24,27 +26,6 @@ pub trait Matrix: Clone + Debug {
|
|||||||
|
|
||||||
fn set(&mut self, row: usize, col: usize, x: f64);
|
fn set(&mut self, row: usize, col: usize, x: f64);
|
||||||
|
|
||||||
fn qr_solve_mut(&mut self, b: Self) -> Self;
|
|
||||||
|
|
||||||
fn svd(&self) -> SVD<Self>;
|
|
||||||
|
|
||||||
fn svd_solve_mut(&mut self, b: Self) -> Self {
|
|
||||||
self.svd_solve(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn svd_solve(&self, b: Self) -> Self {
|
|
||||||
|
|
||||||
let svd = self.svd();
|
|
||||||
svd.solve(b)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
fn evd(&self, symmetric: bool) -> EVD<Self>{
|
|
||||||
self.clone().evd_mut(symmetric)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn evd_mut(self, symmetric: bool) -> EVD<Self>;
|
|
||||||
|
|
||||||
fn eye(size: usize) -> Self;
|
fn eye(size: usize) -> Self;
|
||||||
|
|
||||||
fn zeros(nrows: usize, ncols: usize) -> Self;
|
fn zeros(nrows: usize, ncols: usize) -> Self;
|
||||||
@@ -193,6 +174,8 @@ pub trait Matrix: Clone + Debug {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait Matrix: BaseMatrix + SVDDecomposableMatrix + EVDDecomposableMatrix + QRDecomposableMatrix {}
|
||||||
|
|
||||||
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> {
|
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> {
|
||||||
RowIter{
|
RowIter{
|
||||||
m: m,
|
m: m,
|
||||||
|
|||||||
+10
-1179
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,12 @@
|
|||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use crate::linalg::{Matrix};
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::linalg::svd::SVD;
|
use crate::linalg::Matrix;
|
||||||
use crate::linalg::evd::EVD;
|
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||||
|
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||||
|
use crate::linalg::qr::QRDecomposableMatrix;
|
||||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
||||||
|
|
||||||
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||||
{
|
{
|
||||||
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
|
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
|
||||||
|
|
||||||
@@ -34,18 +36,6 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
|||||||
self[[row, col]] = x;
|
self[[row, col]] = x;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn svd(&self) -> SVD<Self>{
|
|
||||||
panic!("svd method is not implemented for ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn evd_mut(self, symmetric: bool) -> EVD<Self>{
|
|
||||||
panic!("evd method is not implemented for ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn qr_solve_mut(&mut self, b: Self) -> Self {
|
|
||||||
panic!("qr_solve_mut method is not implemented for ndarray");
|
|
||||||
}
|
|
||||||
|
|
||||||
fn eye(size: usize) -> Self {
|
fn eye(size: usize) -> Self {
|
||||||
Array::eye(size)
|
Array::eye(size)
|
||||||
}
|
}
|
||||||
@@ -286,6 +276,14 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||||
|
|
||||||
|
impl EVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||||
|
|
||||||
|
impl QRDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||||
|
|
||||||
|
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -359,7 +357,7 @@ mod tests {
|
|||||||
[4., 5., 6.]]);
|
[4., 5., 6.]]);
|
||||||
a.div_element_mut(1, 1, 5.);
|
a.div_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(Matrix::get(&a, 1, 1), 1.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +368,7 @@ mod tests {
|
|||||||
[4., 5., 6.]]);
|
[4., 5., 6.]]);
|
||||||
a.mul_element_mut(1, 1, 5.);
|
a.mul_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(Matrix::get(&a, 1, 1), 25.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +379,7 @@ mod tests {
|
|||||||
[4., 5., 6.]]);
|
[4., 5., 6.]]);
|
||||||
a.add_element_mut(1, 1, 5.);
|
a.add_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(Matrix::get(&a, 1, 1), 10.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,7 +390,7 @@ mod tests {
|
|||||||
[4., 5., 6.]]);
|
[4., 5., 6.]]);
|
||||||
a.sub_element_mut(1, 1, 5.);
|
a.sub_element_mut(1, 1, 5.);
|
||||||
|
|
||||||
assert_eq!(Matrix::get(&a, 1, 1), 0.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -431,7 +429,7 @@ mod tests {
|
|||||||
result.set(1, 1, 10.);
|
result.set(1, 1, 10.);
|
||||||
|
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
assert_eq!(10., Matrix::get(&result, 1, 1));
|
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -447,7 +445,7 @@ mod tests {
|
|||||||
let expected = arr2(&[
|
let expected = arr2(&[
|
||||||
[22., 28.],
|
[22., 28.],
|
||||||
[49., 64.]]);
|
[49., 64.]]);
|
||||||
let result = Matrix::dot(&a, &b);
|
let result = BaseMatrix::dot(&a, &b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -470,7 +468,7 @@ mod tests {
|
|||||||
&[
|
&[
|
||||||
[2., 3.],
|
[2., 3.],
|
||||||
[5., 6.]]);
|
[5., 6.]]);
|
||||||
let result = Matrix::slice(&a, 0..2, 1..3);
|
let result = BaseMatrix::slice(&a, 0..2, 1..3);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,12 +508,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn reshape() {
|
fn reshape() {
|
||||||
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
||||||
let m_2_by_3 = Matrix::reshape(&m_orig, 2, 3);
|
let m_2_by_3 = BaseMatrix::reshape(&m_orig, 2, 3);
|
||||||
let m_result = Matrix::reshape(&m_2_by_3, 1, 6);
|
let m_result = BaseMatrix::reshape(&m_2_by_3, 1, 6);
|
||||||
assert_eq!(Matrix::shape(&m_2_by_3), (2, 3));
|
assert_eq!(BaseMatrix::shape(&m_2_by_3), (2, 3));
|
||||||
assert_eq!(Matrix::get(&m_2_by_3, 1, 1), 5.);
|
assert_eq!(BaseMatrix::get(&m_2_by_3, 1, 1), 5.);
|
||||||
assert_eq!(Matrix::get(&m_result, 0, 1), 2.);
|
assert_eq!(BaseMatrix::get(&m_result, 0, 1), 2.);
|
||||||
assert_eq!(Matrix::get(&m_result, 0, 3), 4.);
|
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -544,9 +542,9 @@ mod tests {
|
|||||||
fn softmax_mut(){
|
fn softmax_mut(){
|
||||||
let mut prob = arr2(&[[1., 2., 3.]]);
|
let mut prob = arr2(&[[1., 2., 3.]]);
|
||||||
prob.softmax_mut();
|
prob.softmax_mut();
|
||||||
assert!((Matrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
|
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
|
||||||
assert!((Matrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
|
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
|
||||||
assert!((Matrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -599,7 +597,7 @@ mod tests {
|
|||||||
let a = arr2(&[[1., 0., 0.],
|
let a = arr2(&[[1., 0., 0.],
|
||||||
[0., 1., 0.],
|
[0., 1., 0.],
|
||||||
[0., 0., 1.]]);
|
[0., 0., 1.]]);
|
||||||
let res: Array2<f64> = Matrix::eye(3);
|
let res: Array2<f64> = BaseMatrix::eye(3);
|
||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
use crate::linalg::BaseMatrix;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct QR<M: BaseMatrix> {
|
||||||
|
QR: M,
|
||||||
|
tau: Vec<f64>,
|
||||||
|
singular: bool
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: BaseMatrix> QR<M> {
|
||||||
|
pub fn new(QR: M, tau: Vec<f64>) -> QR<M> {
|
||||||
|
|
||||||
|
let mut singular = false;
|
||||||
|
for j in 0..tau.len() {
|
||||||
|
if tau[j] == 0. {
|
||||||
|
singular = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
QR {
|
||||||
|
QR: QR,
|
||||||
|
tau: tau,
|
||||||
|
singular: singular
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn R(&self) -> M {
|
||||||
|
let (_, n) = self.QR.shape();
|
||||||
|
let mut R = M::zeros(n, n);
|
||||||
|
for i in 0..n {
|
||||||
|
R.set(i, i, self.tau[i]);
|
||||||
|
for j in i+1..n {
|
||||||
|
R.set(i, j, self.QR.get(i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return R;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn Q(&self) -> M {
|
||||||
|
let (m, n) = self.QR.shape();
|
||||||
|
let mut Q = M::zeros(m, n);
|
||||||
|
let mut k = n - 1;
|
||||||
|
loop {
|
||||||
|
Q.set(k, k, 1.0);
|
||||||
|
for j in k..n {
|
||||||
|
if self.QR.get(k, k) != 0f64 {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for i in k..m {
|
||||||
|
s += self.QR.get(i, k) * Q.get(i, j);
|
||||||
|
}
|
||||||
|
s = -s / self.QR.get(k, k);
|
||||||
|
for i in k..m {
|
||||||
|
Q.add_element_mut(i, j, s * self.QR.get(i, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if k == 0 {
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
k -= 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Q;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn solve(&self, mut b: M) -> M {
|
||||||
|
|
||||||
|
let (m, n) = self.QR.shape();
|
||||||
|
let (b_nrows, b_ncols) = b.shape();
|
||||||
|
|
||||||
|
if b_nrows != m {
|
||||||
|
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_nrows, b_ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.singular {
|
||||||
|
panic!("Matrix is rank deficient.");
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in 0..n {
|
||||||
|
for j in 0..b_ncols {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for i in k..m {
|
||||||
|
s += self.QR.get(i, k) * b.get(i, j);
|
||||||
|
}
|
||||||
|
s = -s / self.QR.get(k, k);
|
||||||
|
for i in k..m {
|
||||||
|
b.add_element_mut(i, j, s * self.QR.get(i, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in (0..n).rev() {
|
||||||
|
for j in 0..b_ncols {
|
||||||
|
b.set(k, j, b.get(k, j) / self.tau[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..k {
|
||||||
|
for j in 0..b_ncols {
|
||||||
|
b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait QRDecomposableMatrix: BaseMatrix {
|
||||||
|
|
||||||
|
fn qr(&self) -> QR<Self> {
|
||||||
|
self.clone().qr_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qr_mut(mut self) -> QR<Self> {
|
||||||
|
|
||||||
|
let (m, n) = self.shape();
|
||||||
|
|
||||||
|
let mut r_diagonal: Vec<f64> = vec![0f64; n];
|
||||||
|
|
||||||
|
for k in 0..n {
|
||||||
|
let mut nrm = 0f64;
|
||||||
|
for i in k..m {
|
||||||
|
nrm = nrm.hypot(self.get(i, k));
|
||||||
|
}
|
||||||
|
|
||||||
|
if nrm.abs() > std::f64::EPSILON {
|
||||||
|
|
||||||
|
if self.get(k, k) < 0f64 {
|
||||||
|
nrm = -nrm;
|
||||||
|
}
|
||||||
|
for i in k..m {
|
||||||
|
self.div_element_mut(i, k, nrm);
|
||||||
|
}
|
||||||
|
self.add_element_mut(k, k, 1f64);
|
||||||
|
|
||||||
|
for j in k+1..n {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for i in k..m {
|
||||||
|
s += self.get(i, k) * self.get(i, j);
|
||||||
|
}
|
||||||
|
s = -s / self.get(k, k);
|
||||||
|
for i in k..m {
|
||||||
|
self.add_element_mut(i, j, s * self.get(i, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r_diagonal[k] = -nrm;
|
||||||
|
}
|
||||||
|
|
||||||
|
QR::new(self, r_diagonal)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qr_solve_mut(self, b: Self) -> Self {
|
||||||
|
|
||||||
|
self.qr_mut().solve(b)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decompose() {
|
||||||
|
|
||||||
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
let q = DenseMatrix::from_array(&[
|
||||||
|
&[-0.7448, 0.2436, 0.6212],
|
||||||
|
&[-0.331, -0.9432, -0.027],
|
||||||
|
&[-0.5793, 0.2257, -0.7832]]);
|
||||||
|
let r = DenseMatrix::from_array(&[
|
||||||
|
&[-1.2083, -0.6373, -1.0842],
|
||||||
|
&[0.0, -0.3064, 0.0682],
|
||||||
|
&[0.0, 0.0, -0.1999]]);
|
||||||
|
let qr = a.qr();
|
||||||
|
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
||||||
|
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn qr_solve_mut() {
|
||||||
|
|
||||||
|
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
|
||||||
|
let expected_w = DenseMatrix::from_array(&[
|
||||||
|
&[-0.2027027, -1.2837838],
|
||||||
|
&[0.8783784, 2.2297297],
|
||||||
|
&[0.4729730, 0.6621622]
|
||||||
|
]);
|
||||||
|
let w = a.qr_solve_mut(b);
|
||||||
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
|
}
|
||||||
|
}
|
||||||
+361
-3
@@ -1,7 +1,7 @@
|
|||||||
use crate::linalg::{Matrix};
|
use crate::linalg::BaseMatrix;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SVD<M: Matrix> {
|
pub struct SVD<M: SVDDecomposableMatrix> {
|
||||||
pub U: M,
|
pub U: M,
|
||||||
pub V: M,
|
pub V: M,
|
||||||
pub s: Vec<f64>,
|
pub s: Vec<f64>,
|
||||||
@@ -11,7 +11,365 @@ pub struct SVD<M: Matrix> {
|
|||||||
tol: f64
|
tol: f64
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M: Matrix> SVD<M> {
|
pub trait SVDDecomposableMatrix: BaseMatrix {
|
||||||
|
|
||||||
|
fn svd_solve_mut(self, b: Self) -> Self {
|
||||||
|
self.svd_mut().solve(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn svd_solve(&self, b: Self) -> Self {
|
||||||
|
self.svd().solve(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn svd(&self) -> SVD<Self> {
|
||||||
|
self.clone().svd_mut()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn svd_mut(self) -> SVD<Self> {
|
||||||
|
|
||||||
|
let mut U = self;
|
||||||
|
|
||||||
|
let (m, n) = U.shape();
|
||||||
|
|
||||||
|
let (mut l, mut nm) = (0usize, 0usize);
|
||||||
|
let (mut anorm, mut g, mut scale) = (0f64, 0f64, 0f64);
|
||||||
|
|
||||||
|
let mut v = Self::zeros(n, n);
|
||||||
|
let mut w = vec![0f64; n];
|
||||||
|
let mut rv1 = vec![0f64; n];
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
l = i + 2;
|
||||||
|
rv1[i] = scale * g;
|
||||||
|
g = 0f64;
|
||||||
|
let mut s = 0f64;
|
||||||
|
scale = 0f64;
|
||||||
|
|
||||||
|
if i < m {
|
||||||
|
for k in i..m {
|
||||||
|
scale += U.get(k, i).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if scale.abs() > std::f64::EPSILON {
|
||||||
|
|
||||||
|
for k in i..m {
|
||||||
|
U.div_element_mut(k, i, scale);
|
||||||
|
s += U.get(k, i) * U.get(k, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut f = U.get(i, i);
|
||||||
|
g = -s.sqrt().copysign(f);
|
||||||
|
let h = f * g - s;
|
||||||
|
U.set(i, i, f - g);
|
||||||
|
for j in l - 1..n {
|
||||||
|
s = 0f64;
|
||||||
|
for k in i..m {
|
||||||
|
s += U.get(k, i) * U.get(k, j);
|
||||||
|
}
|
||||||
|
f = s / h;
|
||||||
|
for k in i..m {
|
||||||
|
U.add_element_mut(k, j, f * U.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for k in i..m {
|
||||||
|
U.mul_element_mut(k, i, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w[i] = scale * g;
|
||||||
|
g = 0f64;
|
||||||
|
let mut s = 0f64;
|
||||||
|
scale = 0f64;
|
||||||
|
|
||||||
|
if i + 1 <= m && i + 1 != n {
|
||||||
|
for k in l - 1..n {
|
||||||
|
scale += U.get(i, k).abs();
|
||||||
|
}
|
||||||
|
|
||||||
|
if scale.abs() > std::f64::EPSILON {
|
||||||
|
for k in l - 1..n {
|
||||||
|
U.div_element_mut(i, k, scale);
|
||||||
|
s += U.get(i, k) * U.get(i, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
let f = U.get(i, l - 1);
|
||||||
|
g = -s.sqrt().copysign(f);
|
||||||
|
let h = f * g - s;
|
||||||
|
U.set(i, l - 1, f - g);
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
rv1[k] = U.get(i, k) / h;
|
||||||
|
}
|
||||||
|
|
||||||
|
for j in l - 1..m {
|
||||||
|
s = 0f64;
|
||||||
|
for k in l - 1..n {
|
||||||
|
s += U.get(j, k) * U.get(i, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
U.add_element_mut(j, k, s * rv1[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in l - 1..n {
|
||||||
|
U.mul_element_mut(i, k, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
anorm = f64::max(anorm, w[i].abs() + rv1[i].abs());
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in (0..n).rev() {
|
||||||
|
if i < n - 1 {
|
||||||
|
if g != 0.0 {
|
||||||
|
for j in l..n {
|
||||||
|
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
|
||||||
|
}
|
||||||
|
for j in l..n {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for k in l..n {
|
||||||
|
s += U.get(i, k) * v.get(k, j);
|
||||||
|
}
|
||||||
|
for k in l..n {
|
||||||
|
v.add_element_mut(k, j, s * v.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in l..n {
|
||||||
|
v.set(i, j, 0f64);
|
||||||
|
v.set(j, i, 0f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v.set(i, i, 1.0);
|
||||||
|
g = rv1[i];
|
||||||
|
l = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in (0..usize::min(m, n)).rev() {
|
||||||
|
l = i + 1;
|
||||||
|
g = w[i];
|
||||||
|
for j in l..n {
|
||||||
|
U.set(i, j, 0f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.abs() > std::f64::EPSILON {
|
||||||
|
g = 1f64 / g;
|
||||||
|
for j in l..n {
|
||||||
|
let mut s = 0f64;
|
||||||
|
for k in l..m {
|
||||||
|
s += U.get(k, i) * U.get(k, j);
|
||||||
|
}
|
||||||
|
let f = (s / U.get(i, i)) * g;
|
||||||
|
for k in i..m {
|
||||||
|
U.add_element_mut(k, j, f * U.get(k, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in i..m {
|
||||||
|
U.mul_element_mut(j, i, g);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j in i..m {
|
||||||
|
U.set(j, i, 0f64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
U.add_element_mut(i, i, 1f64);
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in (0..n).rev() {
|
||||||
|
for iteration in 0..30 {
|
||||||
|
let mut flag = true;
|
||||||
|
l = k;
|
||||||
|
while l != 0 {
|
||||||
|
if l == 0 || rv1[l].abs() <= std::f64::EPSILON * anorm {
|
||||||
|
flag = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
nm = l - 1;
|
||||||
|
if w[nm].abs() <= std::f64::EPSILON * anorm {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
l -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if flag {
|
||||||
|
let mut c = 0.0;
|
||||||
|
let mut s = 1.0;
|
||||||
|
for i in l..k+1 {
|
||||||
|
let f = s * rv1[i];
|
||||||
|
rv1[i] = c * rv1[i];
|
||||||
|
if f.abs() <= std::f64::EPSILON * anorm {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
g = w[i];
|
||||||
|
let mut h = f.hypot(g);
|
||||||
|
w[i] = h;
|
||||||
|
h = 1.0 / h;
|
||||||
|
c = g * h;
|
||||||
|
s = -f * h;
|
||||||
|
for j in 0..m {
|
||||||
|
let y = U.get(j, nm);
|
||||||
|
let z = U.get(j, i);
|
||||||
|
U.set(j, nm, y * c + z * s);
|
||||||
|
U.set(j, i, z * c - y * s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let z = w[k];
|
||||||
|
if l == k {
|
||||||
|
if z < 0f64 {
|
||||||
|
w[k] = -z;
|
||||||
|
for j in 0..n {
|
||||||
|
v.set(j, k, -v.get(j, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if iteration == 29 {
|
||||||
|
panic!("no convergence in 30 iterations");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut x = w[l];
|
||||||
|
nm = k - 1;
|
||||||
|
let mut y = w[nm];
|
||||||
|
g = rv1[nm];
|
||||||
|
let mut h = rv1[k];
|
||||||
|
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2.0 * h * y);
|
||||||
|
g = f.hypot(1.0);
|
||||||
|
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
|
||||||
|
let mut c = 1f64;
|
||||||
|
let mut s = 1f64;
|
||||||
|
|
||||||
|
for j in l..=nm {
|
||||||
|
let i = j + 1;
|
||||||
|
g = rv1[i];
|
||||||
|
y = w[i];
|
||||||
|
h = s * g;
|
||||||
|
g = c * g;
|
||||||
|
let mut z = f.hypot(h);
|
||||||
|
rv1[j] = z;
|
||||||
|
c = f / z;
|
||||||
|
s = h / z;
|
||||||
|
f = x * c + g * s;
|
||||||
|
g = g * c - x * s;
|
||||||
|
h = y * s;
|
||||||
|
y *= c;
|
||||||
|
|
||||||
|
for jj in 0..n {
|
||||||
|
x = v.get(jj, j);
|
||||||
|
z = v.get(jj, i);
|
||||||
|
v.set(jj, j, x * c + z * s);
|
||||||
|
v.set(jj, i, z * c - x * s);
|
||||||
|
}
|
||||||
|
|
||||||
|
z = f.hypot(h);
|
||||||
|
w[j] = z;
|
||||||
|
if z.abs() > std::f64::EPSILON {
|
||||||
|
z = 1.0 / z;
|
||||||
|
c = f * z;
|
||||||
|
s = h * z;
|
||||||
|
}
|
||||||
|
|
||||||
|
f = c * g + s * y;
|
||||||
|
x = c * y - s * g;
|
||||||
|
for jj in 0..m {
|
||||||
|
y = U.get(jj, j);
|
||||||
|
z = U.get(jj, i);
|
||||||
|
U.set(jj, j, y * c + z * s);
|
||||||
|
U.set(jj, i, z * c - y * s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rv1[l] = 0.0;
|
||||||
|
rv1[k] = f;
|
||||||
|
w[k] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut inc = 1usize;
|
||||||
|
let mut su = vec![0f64; m];
|
||||||
|
let mut sv = vec![0f64; n];
|
||||||
|
|
||||||
|
loop {
|
||||||
|
inc *= 3;
|
||||||
|
inc += 1;
|
||||||
|
if inc > n {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
inc /= 3;
|
||||||
|
for i in inc..n {
|
||||||
|
let sw = w[i];
|
||||||
|
for k in 0..m {
|
||||||
|
su[k] = U.get(k, i);
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
sv[k] = v.get(k, i);
|
||||||
|
}
|
||||||
|
let mut j = i;
|
||||||
|
while w[j - inc] < sw {
|
||||||
|
w[j] = w[j - inc];
|
||||||
|
for k in 0..m {
|
||||||
|
U.set(k, j, U.get(k, j - inc));
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
v.set(k, j, v.get(k, j - inc));
|
||||||
|
}
|
||||||
|
j -= inc;
|
||||||
|
if j < inc {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w[j] = sw;
|
||||||
|
for k in 0..m {
|
||||||
|
U.set(k, j, su[k]);
|
||||||
|
}
|
||||||
|
for k in 0..n {
|
||||||
|
v.set(k, j, sv[k]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
if inc <= 1 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for k in 0..n {
|
||||||
|
let mut s = 0.;
|
||||||
|
for i in 0..m {
|
||||||
|
if U.get(i, k) < 0. {
|
||||||
|
s += 1.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
if v.get(j, k) < 0. {
|
||||||
|
s += 1.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s > (m + n) as f64 / 2. {
|
||||||
|
for i in 0..m {
|
||||||
|
U.set(i, k, -U.get(i, k));
|
||||||
|
}
|
||||||
|
for j in 0..n {
|
||||||
|
v.set(j, k, -v.get(j, k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SVD::new(U, v, w)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<M: SVDDecomposableMatrix> SVD<M> {
|
||||||
pub fn new(U: M, V: M, s: Vec<f64>) -> SVD<M> {
|
pub fn new(U: M, V: M, s: Vec<f64>) -> SVD<M> {
|
||||||
let m = U.shape().0;
|
let m = U.shape().0;
|
||||||
let n = V.shape().0;
|
let n = V.shape().0;
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ impl FirstOrderOptimizer for GradientDescent
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
|
|||||||
@@ -226,7 +226,7 @@ impl FirstOrderOptimizer for LBFGS {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
use crate::math::EPSILON;
|
use crate::math::EPSILON;
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
|
|||||||
Reference in New Issue
Block a user