feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+14 -13
View File
@@ -19,7 +19,7 @@
//! &[0.7, 0.3, 0.8]
//! ]);
//!
//! let svd = A.svd();
//! let svd = A.svd().unwrap();
//! let u: DenseMatrix<f64> = svd.U;
//! let v: DenseMatrix<f64> = svd.V;
//! let s: Vec<f64> = svd.s;
@@ -33,6 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
/// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Solves Ax = b. Overrides original matrix in the process.
fn svd_solve_mut(self, b: Self) -> Self {
self.svd_mut().solve(b)
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.svd_mut().and_then(|svd| svd.solve(b))
}
/// Solves Ax = b
fn svd_solve(&self, b: Self) -> Self {
self.svd().solve(b)
fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
self.svd().and_then(|svd| svd.solve(b))
}
/// Compute the SVD decomposition of a matrix.
fn svd(&self) -> SVD<T, Self> {
fn svd(&self) -> Result<SVD<T, Self>, Failed> {
self.clone().svd_mut()
}
/// Compute the SVD decomposition of a matrix. The input matrix
/// will be used for factorization.
fn svd_mut(self) -> SVD<T, Self> {
fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
let mut U = self;
let (m, n) = U.shape();
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
SVD::new(U, v, w)
Ok(SVD::new(U, v, w))
}
}
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
}
}
pub(crate) fn solve(&self, mut b: M) -> M {
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
let p = b.shape().1;
if self.U.shape().0 != b.shape().0 {
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
}
}
b
Ok(b)
}
}
@@ -491,7 +492,7 @@ mod tests {
&[0.6240573, -0.44947578, -0.6391588],
]);
let svd = A.svd();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -692,7 +693,7 @@ mod tests {
],
]);
let svd = A.svd();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -707,7 +708,7 @@ mod tests {
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let expected_w =
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
let w = a.svd_solve_mut(b);
let w = a.svd_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
}
}