feat: refactoring, adds Result to most public API
This commit is contained in:
+14
-13
@@ -19,7 +19,7 @@
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = A.svd();
|
||||
//! let svd = A.svd().unwrap();
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
//! let v: DenseMatrix<f64> = svd.V;
|
||||
//! let s: Vec<f64> = svd.s;
|
||||
@@ -33,6 +33,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use std::fmt::Debug;
|
||||
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
||||
/// Trait that implements SVD decomposition routine for any matrix.
|
||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Solves Ax = b. Overrides original matrix in the process.
|
||||
fn svd_solve_mut(self, b: Self) -> Self {
|
||||
self.svd_mut().solve(b)
|
||||
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.svd_mut().and_then(|svd| svd.solve(b))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn svd_solve(&self, b: Self) -> Self {
|
||||
self.svd().solve(b)
|
||||
fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
|
||||
self.svd().and_then(|svd| svd.solve(b))
|
||||
}
|
||||
|
||||
/// Compute the SVD decomposition of a matrix.
|
||||
fn svd(&self) -> SVD<T, Self> {
|
||||
fn svd(&self) -> Result<SVD<T, Self>, Failed> {
|
||||
self.clone().svd_mut()
|
||||
}
|
||||
|
||||
/// Compute the SVD decomposition of a matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn svd_mut(self) -> SVD<T, Self> {
|
||||
fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
|
||||
let mut U = self;
|
||||
|
||||
let (m, n) = U.shape();
|
||||
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
SVD::new(U, v, w)
|
||||
Ok(SVD::new(U, v, w))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn solve(&self, mut b: M) -> M {
|
||||
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let p = b.shape().1;
|
||||
|
||||
if self.U.shape().0 != b.shape().0 {
|
||||
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
b
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,7 +492,7 @@ mod tests {
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
@@ -692,7 +693,7 @@ mod tests {
|
||||
],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
@@ -707,7 +708,7 @@ mod tests {
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let expected_w =
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||
let w = a.svd_solve_mut(b);
|
||||
let w = a.svd_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user