Solve conflic with num-traits (#130)

* Solve conflic with num-traits

* Fix clippy warnings

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-05-05 10:39:18 -04:00
committed by GitHub
parent 12c102d02b
commit 820201e920
23 changed files with 58 additions and 65 deletions
+3 -4
View File
@@ -87,8 +87,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
if bn != rn {
return Err(Failed::because(
FailedError::SolutionFailed,
&"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R."
.to_string(),
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
));
}
@@ -128,7 +127,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if m != n {
return Err(Failed::because(
FailedError::DecompositionFailed,
&"Can\'t do Cholesky decomposition on a non-square matrix".to_string(),
"Can\'t do Cholesky decomposition on a non-square matrix",
));
}
@@ -148,7 +147,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if d < T::zero() {
return Err(Failed::because(
FailedError::DecompositionFailed,
&"The matrix is not positive definite.".to_string(),
"The matrix is not positive definite.",
));
}
+7 -12
View File
@@ -97,7 +97,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for (i, d_i) in d.iter_mut().enumerate().take(n) {
*d_i = V.get(n - 1, i);
@@ -195,7 +195,7 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
e[0] = T::zero();
}
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for i in 1..n {
e[i - 1] = e[i];
@@ -419,7 +419,7 @@ fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
}
}
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = A.shape();
let mut z = T::zero();
let mut s = T::zero();
@@ -471,7 +471,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
A.set(nn, nn, x);
A.set(nn - 1, nn - 1, y + t);
if q >= T::zero() {
z = p + z.copysign(p);
z = p + RealNumber::copysign(z, p);
d[nn - 1] = x + z;
d[nn] = x + z;
if z != T::zero() {
@@ -570,7 +570,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
r /= x;
}
}
let s = (p * p + q * q + r * r).sqrt().copysign(p);
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
if s != T::zero() {
if k == m {
if l != m {
@@ -594,12 +594,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
A.sub_element_mut(k + 1, j, p * y);
A.sub_element_mut(k, j, p * x);
}
let mmin;
if nn < k + 3 {
mmin = nn;
} else {
mmin = k + 3;
}
let mmin = if nn < k + 3 { nn } else { k + 3 };
for i in 0..mmin + 1 {
p = x * A.get(i, k) + y * A.get(i, k + 1);
if k + 1 != nn {
@@ -783,7 +778,7 @@ fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
}
}
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) {
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
let n = d.len();
let mut temp = vec![T::zero(); n];
for j in 1..n {
+3 -3
View File
@@ -46,13 +46,13 @@ use crate::math::num::RealNumber;
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
LU: M,
pivot: Vec<usize>,
pivot_sign: i8,
_pivot_sign: i8,
singular: bool,
phantom: PhantomData<T>,
}
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
let (_, n) = LU.shape();
let mut singular = false;
@@ -66,7 +66,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
LU {
LU,
pivot,
pivot_sign,
_pivot_sign,
singular,
phantom: PhantomData,
}
+4 -5
View File
@@ -689,12 +689,11 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<T>> {
let res;
if self.pos < self.max_pos {
res = Some(self.m.get_row_as_vec(self.pos))
let res = if self.pos < self.max_pos {
Some(self.m.get_row_as_vec(self.pos))
} else {
res = None
}
None
};
self.pos += 1;
res
}
-1
View File
@@ -523,7 +523,6 @@ impl<T: RealNumber> PartialEq for DenseMatrix<T> {
true
}
}
impl<T: RealNumber> From<DenseMatrix<T>> for Vec<T> {
fn from(dense_matrix: DenseMatrix<T>) -> Vec<T> {
dense_matrix.values
+6 -6
View File
@@ -47,7 +47,7 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
pub V: M,
/// Singular values of the original matrix
pub s: Vec<T>,
full: bool,
_full: bool,
m: usize,
n: usize,
tol: T,
@@ -116,7 +116,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
let mut f = U.get(i, i);
g = -s.sqrt().copysign(f);
g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s;
U.set(i, i, f - g);
for j in l - 1..n {
@@ -152,7 +152,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
let f = U.get(i, l - 1);
g = -s.sqrt().copysign(f);
g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s;
U.set(i, l - 1, f - g);
@@ -299,7 +299,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
g = f.hypot(T::one());
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
let mut c = T::one();
let mut s = T::one();
@@ -428,13 +428,13 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0;
let n = V.shape().0;
let full = s.len() == m.min(n);
let _full = s.len() == m.min(n);
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD {
U,
V,
s,
full,
_full,
m,
n,
tol,