feat: refactor, add Result to most public APIs

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+8 -7
View File
@@ -21,7 +21,7 @@
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! let evd = A.evd(true);
//! let evd = A.evd(true).unwrap();
//! let eigenvectors: DenseMatrix<f64> = evd.V;
//! let eigenvalues: Vec<f64> = evd.d;
//! ```
@@ -34,6 +34,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use num::complex::Complex;
@@ -54,14 +55,14 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the eigen decomposition of a square matrix.
/// * `symmetric` - whether the matrix is symmetric
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
self.clone().evd_mut(symmetric)
}
/// Compute the eigen decomposition of a square matrix. The input matrix
/// will be used for factorization.
/// * `symmetric` - whether the matrix is symmetric
fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self> {
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
let (nrows, ncols) = self.shape();
if ncols != nrows {
panic!("Matrix is not square: {} x {}", nrows, ncols);
@@ -92,7 +93,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
sort(&mut d, &mut e, &mut V);
}
EVD { V: V, d: d, e: e }
Ok(EVD { V: V, d: d, e: e })
}
}
@@ -845,7 +846,7 @@ mod tests {
&[0.6240573, -0.44947578, -0.6391588],
]);
let evd = A.evd(true);
let evd = A.evd(true).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
@@ -872,7 +873,7 @@ mod tests {
&[0.6952105, 0.43984484, -0.7036135],
]);
let evd = A.evd(false);
let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
@@ -902,7 +903,7 @@ mod tests {
&[0.6707, 0.1059, 0.901, -0.6289],
]);
let evd = A.evd(false);
let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values_d.len() {
+13 -13
View File
@@ -20,7 +20,7 @@
//! &[5., 6., 0.]
//! ]);
//!
//! let lu = A.lu();
//! let lu = A.lu().unwrap();
//! let lower: DenseMatrix<f64> = lu.L();
//! let upper: DenseMatrix<f64> = lu.U();
//! ```
@@ -36,6 +36,7 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
}
/// Returns matrix inverse
pub fn inverse(&self) -> M {
pub fn inverse(&self) -> Result<M, Failed> {
let (m, n) = self.LU.shape();
if m != n {
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
inv.set(i, i, T::one());
}
inv = self.solve(inv);
return inv;
self.solve(inv)
}
fn solve(&self, mut b: M) -> M {
fn solve(&self, mut b: M) -> Result<M, Failed> {
let (m, n) = self.LU.shape();
let (b_m, b_n) = b.shape();
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
}
}
b
Ok(b)
}
}
/// Trait that implements LU decomposition routine for any matrix.
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the LU decomposition of a square matrix.
fn lu(&self) -> LU<T, Self> {
fn lu(&self) -> Result<LU<T, Self>, Failed> {
self.clone().lu_mut()
}
/// Compute the LU decomposition of a square matrix. The input matrix
/// will be used for factorization.
fn lu_mut(mut self) -> LU<T, Self> {
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
let (m, n) = self.shape();
let mut piv = vec![0; m];
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
LU::new(self, piv, pivsign)
Ok(LU::new(self, piv, pivsign))
}
/// Solves Ax = b
fn lu_solve_mut(self, b: Self) -> Self {
self.lu_mut().solve(b)
fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.lu_mut().and_then(|lu| lu.solve(b))
}
}
@@ -275,7 +275,7 @@ mod tests {
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
let expected_pivot =
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
let lu = a.lu();
let lu = a.lu().unwrap();
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
@@ -286,7 +286,7 @@ mod tests {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let expected =
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
let a_inv = a.lu().inverse();
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
println!("{}", a_inv);
assert!(a_inv.approximate_eq(&expected, 1e-4));
}
+1 -1
View File
@@ -26,7 +26,7 @@
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! let svd = A.svd();
//! let svd = A.svd().unwrap();
//!
//! let s: Vec<f64> = svd.s;
//! let v: DenseMatrix<f64> = svd.V;
+7 -4
View File
@@ -34,8 +34,8 @@
//! 116.9,
//! ]);
//!
//! let lr = LinearRegression::fit(&x, &y, Default::default());
//! let y_hat = lr.predict(&x);
//! let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
//! let y_hat = lr.predict(&x).unwrap();
//! ```
use std::iter::Sum;
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
@@ -777,9 +777,12 @@ mod tests {
solver: LinearRegressionSolverName::QR,
},
)
.predict(&x);
.and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(y
.iter()
+9 -12
View File
@@ -36,8 +36,8 @@
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
//! ]);
//!
//! let lr = LogisticRegression::fit(&x, &y);
//! let y_hat = lr.predict(&x);
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! let y_hat = lr.predict(&x).unwrap();
//! ```
use std::iter::Sum;
use std::ops::AddAssign;
@@ -395,6 +395,7 @@ mod tests {
use super::*;
use crate::ensemble::random_forest_regressor::*;
use crate::linear::logistic_regression::*;
use crate::metrics::mean_absolute_error;
use ndarray::{arr1, arr2, Array1, Array2};
#[test]
@@ -736,9 +737,9 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]);
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y).unwrap();
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x).unwrap();
let error: f64 = y
.into_iter()
@@ -774,10 +775,6 @@ mod tests {
114.2, 115.7, 116.9,
]);
let expected_y: Vec<f64> = vec![
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
];
let y_hat = RandomForestRegressor::fit(
&x,
&y,
@@ -789,10 +786,10 @@ mod tests {
m: Option::None,
},
)
.predict(&x);
.unwrap()
.predict(&x)
.unwrap();
for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
}
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}
}
+14 -14
View File
@@ -15,9 +15,9 @@
//! &[0.7, 0.3, 0.8]
//! ]);
//!
//! let lu = A.qr();
//! let orthogonal: DenseMatrix<f64> = lu.Q();
//! let triangular: DenseMatrix<f64> = lu.R();
//! let qr = A.qr().unwrap();
//! let orthogonal: DenseMatrix<f64> = qr.Q();
//! let triangular: DenseMatrix<f64> = qr.R();
//! ```
//!
//! ## References:
@@ -28,10 +28,10 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use std::fmt::Debug;
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
#[derive(Debug, Clone)]
/// Results of QR decomposition.
@@ -99,7 +99,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
return Q;
}
fn solve(&self, mut b: M) -> M {
fn solve(&self, mut b: M) -> Result<M, Failed> {
let (m, n) = self.QR.shape();
let (b_nrows, b_ncols) = b.shape();
@@ -139,20 +139,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
}
}
b
Ok(b)
}
}
/// Trait that implements QR decomposition routine for any matrix.
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the QR decomposition of a matrix.
fn qr(&self) -> QR<T, Self> {
fn qr(&self) -> Result<QR<T, Self>, Failed> {
self.clone().qr_mut()
}
/// Compute the QR decomposition of a matrix. The input matrix
/// will be used for factorization.
fn qr_mut(mut self) -> QR<T, Self> {
fn qr_mut(mut self) -> Result<QR<T, Self>, Failed> {
let (m, n) = self.shape();
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
@@ -186,12 +186,12 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
r_diagonal[k] = -nrm;
}
QR::new(self, r_diagonal)
Ok(QR::new(self, r_diagonal))
}
/// Solves Ax = b
fn qr_solve_mut(self, b: Self) -> Self {
self.qr_mut().solve(b)
fn qr_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.qr_mut().and_then(|qr| qr.solve(b))
}
}
@@ -213,7 +213,7 @@ mod tests {
&[0.0, -0.3064, 0.0682],
&[0.0, 0.0, -0.1999],
]);
let qr = a.qr();
let qr = a.qr().unwrap();
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
}
@@ -227,7 +227,7 @@ mod tests {
&[0.8783784, 2.2297297],
&[0.4729730, 0.6621622],
]);
let w = a.qr_solve_mut(b);
let w = a.qr_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
}
}
+14 -13
View File
@@ -19,7 +19,7 @@
//! &[0.7, 0.3, 0.8]
//! ]);
//!
//! let svd = A.svd();
//! let svd = A.svd().unwrap();
//! let u: DenseMatrix<f64> = svd.U;
//! let v: DenseMatrix<f64> = svd.V;
//! let s: Vec<f64> = svd.s;
@@ -33,6 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
/// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Solves Ax = b. Overrides original matrix in the process.
fn svd_solve_mut(self, b: Self) -> Self {
self.svd_mut().solve(b)
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.svd_mut().and_then(|svd| svd.solve(b))
}
/// Solves Ax = b
fn svd_solve(&self, b: Self) -> Self {
self.svd().solve(b)
fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
self.svd().and_then(|svd| svd.solve(b))
}
/// Compute the SVD decomposition of a matrix.
fn svd(&self) -> SVD<T, Self> {
fn svd(&self) -> Result<SVD<T, Self>, Failed> {
self.clone().svd_mut()
}
/// Compute the SVD decomposition of a matrix. The input matrix
/// will be used for factorization.
fn svd_mut(self) -> SVD<T, Self> {
fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
let mut U = self;
let (m, n) = U.shape();
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
SVD::new(U, v, w)
Ok(SVD::new(U, v, w))
}
}
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
}
}
pub(crate) fn solve(&self, mut b: M) -> M {
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
let p = b.shape().1;
if self.U.shape().0 != b.shape().0 {
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
}
}
b
Ok(b)
}
}
@@ -491,7 +492,7 @@ mod tests {
&[0.6240573, -0.44947578, -0.6391588],
]);
let svd = A.svd();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -692,7 +693,7 @@ mod tests {
],
]);
let svd = A.svd();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
@@ -707,7 +708,7 @@ mod tests {
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let expected_w =
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
let w = a.svd_solve_mut(b);
let w = a.svd_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
}
}