feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+13 -13
View File
@@ -20,7 +20,7 @@
//! &[5., 6., 0.]
//! ]);
//!
//! let lu = A.lu();
//! let lu = A.lu().unwrap();
//! let lower: DenseMatrix<f64> = lu.L();
//! let upper: DenseMatrix<f64> = lu.U();
//! ```
@@ -36,6 +36,7 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
}
/// Returns matrix inverse
pub fn inverse(&self) -> M {
pub fn inverse(&self) -> Result<M, Failed> {
let (m, n) = self.LU.shape();
if m != n {
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
inv.set(i, i, T::one());
}
inv = self.solve(inv);
return inv;
self.solve(inv)
}
fn solve(&self, mut b: M) -> M {
fn solve(&self, mut b: M) -> Result<M, Failed> {
let (m, n) = self.LU.shape();
let (b_m, b_n) = b.shape();
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
}
}
b
Ok(b)
}
}
/// Trait that implements LU decomposition routine for any matrix.
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the LU decomposition of a square matrix.
fn lu(&self) -> LU<T, Self> {
fn lu(&self) -> Result<LU<T, Self>, Failed> {
self.clone().lu_mut()
}
/// Compute the LU decomposition of a square matrix. The input matrix
/// will be used for factorization.
fn lu_mut(mut self) -> LU<T, Self> {
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
let (m, n) = self.shape();
let mut piv = vec![0; m];
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
LU::new(self, piv, pivsign)
Ok(LU::new(self, piv, pivsign))
}
/// Solves Ax = b
fn lu_solve_mut(self, b: Self) -> Self {
self.lu_mut().solve(b)
fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.lu_mut().and_then(|lu| lu.solve(b))
}
}
@@ -275,7 +275,7 @@ mod tests {
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
let expected_pivot =
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
let lu = a.lu();
let lu = a.lu().unwrap();
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
@@ -286,7 +286,7 @@ mod tests {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let expected =
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
let a_inv = a.lu().inverse();
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
println!("{}", a_inv);
assert!(a_inv.approximate_eq(&expected, 1e-4));
}