feat: extends interface of Matrix to support for broad range of types

This commit is contained in:
Volodymyr Orlov
2020-03-26 15:28:26 -07:00
parent 84ffd331cd
commit 02b85415d9
27 changed files with 1021 additions and 868 deletions
+24 -21
View File
@@ -1,20 +1,23 @@
#![allow(non_snake_case)]
use std::fmt::Debug;
use crate::math::num::FloatExt;
use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)]
pub struct QR<M: BaseMatrix> {
pub struct QR<T: FloatExt + Debug, M: BaseMatrix<T>> {
QR: M,
tau: Vec<f64>,
tau: Vec<T>,
singular: bool
}
impl<M: BaseMatrix> QR<M> {
pub fn new(QR: M, tau: Vec<f64>) -> QR<M> {
impl<T: FloatExt + Debug, M: BaseMatrix<T>> QR<T, M> {
pub fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
let mut singular = false;
for j in 0..tau.len() {
if tau[j] == 0. {
if tau[j] == T::zero() {
singular = true;
break;
}
@@ -44,12 +47,12 @@ impl<M: BaseMatrix> QR<M> {
let mut Q = M::zeros(m, n);
let mut k = n - 1;
loop {
Q.set(k, k, 1.0);
Q.set(k, k, T::one());
for j in k..n {
if self.QR.get(k, k) != 0f64 {
let mut s = 0f64;
if self.QR.get(k, k) != T::zero() {
let mut s = T::zero();
for i in k..m {
s += self.QR.get(i, k) * Q.get(i, j);
s = s + self.QR.get(i, k) * Q.get(i, j);
}
s = -s / self.QR.get(k, k);
for i in k..m {
@@ -81,9 +84,9 @@ impl<M: BaseMatrix> QR<M> {
for k in 0..n {
for j in 0..b_ncols {
let mut s = 0f64;
let mut s = T::zero();
for i in k..m {
s += self.QR.get(i, k) * b.get(i, j);
s = s + self.QR.get(i, k) * b.get(i, j);
}
s = -s / self.QR.get(k, k);
for i in k..m {
@@ -109,38 +112,38 @@ impl<M: BaseMatrix> QR<M> {
}
}
pub trait QRDecomposableMatrix: BaseMatrix {
pub trait QRDecomposableMatrix<T: FloatExt + Debug>: BaseMatrix<T> {
fn qr(&self) -> QR<Self> {
fn qr(&self) -> QR<T, Self> {
self.clone().qr_mut()
}
fn qr_mut(mut self) -> QR<Self> {
fn qr_mut(mut self) -> QR<T, Self> {
let (m, n) = self.shape();
let mut r_diagonal: Vec<f64> = vec![0f64; n];
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
for k in 0..n {
let mut nrm = 0f64;
let mut nrm = T::zero();
for i in k..m {
nrm = nrm.hypot(self.get(i, k));
}
if nrm.abs() > std::f64::EPSILON {
if nrm.abs() > T::epsilon() {
if self.get(k, k) < 0f64 {
if self.get(k, k) < T::zero() {
nrm = -nrm;
}
for i in k..m {
self.div_element_mut(i, k, nrm);
}
self.add_element_mut(k, k, 1f64);
self.add_element_mut(k, k, T::one());
for j in k+1..n {
let mut s = 0f64;
let mut s = T::zero();
for i in k..m {
s += self.get(i, k) * self.get(i, j);
s = s + self.get(i, k) * self.get(i, j);
}
s = -s / self.get(k, k);
for i in k..m {