fix: cargo fmt
This commit is contained in:
+46
-52
@@ -3,22 +3,21 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
|
||||
LU: M,
|
||||
pub struct LU<T: FloatExt, M: BaseMatrix<T>> {
|
||||
LU: M,
|
||||
pivot: Vec<usize>,
|
||||
pivot_sign: i8,
|
||||
singular: bool,
|
||||
phantom: PhantomData<T>
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
pub fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
||||
|
||||
let (_, n) = LU.shape();
|
||||
let (_, n) = LU.shape();
|
||||
|
||||
let mut singular = false;
|
||||
for j in 0..n {
|
||||
@@ -33,7 +32,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
pivot: pivot,
|
||||
pivot_sign: pivot_sign,
|
||||
singular: singular,
|
||||
phantom: PhantomData
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,24 +62,24 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
for i in 0..n_rows {
|
||||
for j in 0..n_cols {
|
||||
if i <= j {
|
||||
U.set(i, j, self.LU.get(i, j));
|
||||
U.set(i, j, self.LU.get(i, j));
|
||||
} else {
|
||||
U.set(i, j, T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
U
|
||||
}
|
||||
|
||||
pub fn pivot(&self) -> M {
|
||||
let (_, n) = self.LU.shape();
|
||||
let mut piv = M::zeros(n, n);
|
||||
|
||||
|
||||
for i in 0..n {
|
||||
piv.set(i, self.pivot[i], T::one());
|
||||
}
|
||||
|
||||
|
||||
piv
|
||||
}
|
||||
|
||||
@@ -92,7 +91,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
let mut inv = M::zeros(n, n);
|
||||
|
||||
|
||||
for i in 0..n {
|
||||
inv.set(i, i, T::one());
|
||||
}
|
||||
@@ -106,7 +105,10 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
let (b_m, b_n) = b.shape();
|
||||
|
||||
if b_m != m {
|
||||
panic!("Row dimensions do not agree: A is {} x {}, but B is {} x {}", m, n, b_m, b_n);
|
||||
panic!(
|
||||
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
|
||||
m, n, b_m, b_n
|
||||
);
|
||||
}
|
||||
|
||||
if self.singular {
|
||||
@@ -120,9 +122,9 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
X.set(i, j, b.get(self.pivot[i], j));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for k in 0..n {
|
||||
for i in k+1..n {
|
||||
for i in k + 1..n {
|
||||
for j in 0..b_n {
|
||||
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
|
||||
}
|
||||
@@ -140,7 +142,7 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for j in 0..b_n {
|
||||
for i in 0..m {
|
||||
b.set(i, j, X.get(i, j));
|
||||
@@ -148,20 +150,16 @@ impl<T: FloatExt, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
b
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
|
||||
pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
fn lu(&self) -> LU<T, Self> {
|
||||
self.clone().lu_mut()
|
||||
}
|
||||
|
||||
fn lu_mut(mut self) -> LU<T, Self> {
|
||||
|
||||
let (m, n) = self.shape();
|
||||
let (m, n) = self.shape();
|
||||
|
||||
let mut piv = vec![0; m];
|
||||
for i in 0..m {
|
||||
@@ -172,7 +170,6 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
let mut LUcolj = vec![T::zero(); m];
|
||||
|
||||
for j in 0..n {
|
||||
|
||||
for i in 0..m {
|
||||
LUcolj[i] = self.get(i, j);
|
||||
}
|
||||
@@ -189,7 +186,7 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
}
|
||||
|
||||
let mut p = j;
|
||||
for i in j+1..m {
|
||||
for i in j + 1..m {
|
||||
if LUcolj[i].abs() > LUcolj[p].abs() {
|
||||
p = i;
|
||||
}
|
||||
@@ -205,50 +202,47 @@ pub trait LUDecomposableMatrix<T: FloatExt>: BaseMatrix<T> {
|
||||
piv[j] = k;
|
||||
pivsign = -pivsign;
|
||||
}
|
||||
|
||||
|
||||
if j < m && self.get(j, j) != T::zero() {
|
||||
for i in j+1..m {
|
||||
for i in j + 1..m {
|
||||
self.div_element_mut(i, j, self.get(j, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LU::new(self, piv, pivsign)
|
||||
|
||||
}
|
||||
|
||||
fn lu_solve_mut(self, b: Self) -> Self {
|
||||
|
||||
self.lu_mut().solve(b)
|
||||
|
||||
}
|
||||
self.lu_mut().solve(b)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn decompose() {
|
||||
|
||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected_L = DenseMatrix::from_array(&[&[1. , 0. , 0. ], &[0. , 1. , 0. ], &[0.2, 0.8, 1. ]]);
|
||||
let expected_U = DenseMatrix::from_array(&[&[ 5., 6., 0.], &[ 0., 1., 5.], &[ 0., 0., -1.]]);
|
||||
let expected_pivot = DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
let lu = a.lu();
|
||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected_L = DenseMatrix::from_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
|
||||
let expected_U = DenseMatrix::from_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||
let expected_pivot =
|
||||
DenseMatrix::from_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
let lu = a.lu();
|
||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inverse() {
|
||||
|
||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected = DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
let a_inv = a.lu().inverse();
|
||||
println!("{}", a_inv);
|
||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||
fn inverse() {
|
||||
let a = DenseMatrix::from_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected =
|
||||
DenseMatrix::from_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
let a_inv = a.lu().inverse();
|
||||
println!("{}", a_inv);
|
||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user