feat: refactoring, adds Result to most public API
This commit is contained in:
+13
-13
@@ -20,7 +20,7 @@
|
||||
//! &[5., 6., 0.]
|
||||
//! ]);
|
||||
//!
|
||||
//! let lu = A.lu();
|
||||
//! let lu = A.lu().unwrap();
|
||||
//! let lower: DenseMatrix<f64> = lu.L();
|
||||
//! let upper: DenseMatrix<f64> = lu.U();
|
||||
//! ```
|
||||
@@ -36,6 +36,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
/// Returns matrix inverse
|
||||
pub fn inverse(&self) -> M {
|
||||
pub fn inverse(&self) -> Result<M, Failed> {
|
||||
let (m, n) = self.LU.shape();
|
||||
|
||||
if m != n {
|
||||
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
inv.set(i, i, T::one());
|
||||
}
|
||||
|
||||
inv = self.solve(inv);
|
||||
return inv;
|
||||
self.solve(inv)
|
||||
}
|
||||
|
||||
fn solve(&self, mut b: M) -> M {
|
||||
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let (m, n) = self.LU.shape();
|
||||
let (b_m, b_n) = b.shape();
|
||||
|
||||
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
b
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that implements LU decomposition routine for any matrix.
|
||||
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Compute the LU decomposition of a square matrix.
|
||||
fn lu(&self) -> LU<T, Self> {
|
||||
fn lu(&self) -> Result<LU<T, Self>, Failed> {
|
||||
self.clone().lu_mut()
|
||||
}
|
||||
|
||||
/// Compute the LU decomposition of a square matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn lu_mut(mut self) -> LU<T, Self> {
|
||||
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
|
||||
let (m, n) = self.shape();
|
||||
|
||||
let mut piv = vec![0; m];
|
||||
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
LU::new(self, piv, pivsign)
|
||||
Ok(LU::new(self, piv, pivsign))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn lu_solve_mut(self, b: Self) -> Self {
|
||||
self.lu_mut().solve(b)
|
||||
fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.lu_mut().and_then(|lu| lu.solve(b))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||
let expected_pivot =
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
let lu = a.lu();
|
||||
let lu = a.lu().unwrap();
|
||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
@@ -286,7 +286,7 @@ mod tests {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected =
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
let a_inv = a.lu().inverse();
|
||||
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
println!("{}", a_inv);
|
||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user