feat: refactoring, adds Result to most public API
This commit is contained in:
+8
-7
@@ -21,7 +21,7 @@
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//!
|
||||
//! let evd = A.evd(true);
|
||||
//! let evd = A.evd(true).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
@@ -34,6 +34,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use num::complex::Complex;
|
||||
@@ -54,14 +55,14 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
|
||||
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Compute the eigen decomposition of a square matrix.
|
||||
/// * `symmetric` - whether the matrix is symmetric
|
||||
fn evd(&self, symmetric: bool) -> EVD<T, Self> {
|
||||
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||
self.clone().evd_mut(symmetric)
|
||||
}
|
||||
|
||||
/// Compute the eigen decomposition of a square matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
/// * `symmetric` - whether the matrix is symmetric
|
||||
fn evd_mut(mut self, symmetric: bool) -> EVD<T, Self> {
|
||||
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||
let (nrows, ncols) = self.shape();
|
||||
if ncols != nrows {
|
||||
panic!("Matrix is not square: {} x {}", nrows, ncols);
|
||||
@@ -92,7 +93,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
sort(&mut d, &mut e, &mut V);
|
||||
}
|
||||
|
||||
EVD { V: V, d: d, e: e }
|
||||
Ok(EVD { V: V, d: d, e: e })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -845,7 +846,7 @@ mod tests {
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
|
||||
let evd = A.evd(true);
|
||||
let evd = A.evd(true).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values.len() {
|
||||
@@ -872,7 +873,7 @@ mod tests {
|
||||
&[0.6952105, 0.43984484, -0.7036135],
|
||||
]);
|
||||
|
||||
let evd = A.evd(false);
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values.len() {
|
||||
@@ -902,7 +903,7 @@ mod tests {
|
||||
&[0.6707, 0.1059, 0.901, -0.6289],
|
||||
]);
|
||||
|
||||
let evd = A.evd(false);
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values_d.len() {
|
||||
|
||||
+13
-13
@@ -20,7 +20,7 @@
|
||||
//! &[5., 6., 0.]
|
||||
//! ]);
|
||||
//!
|
||||
//! let lu = A.lu();
|
||||
//! let lu = A.lu().unwrap();
|
||||
//! let lower: DenseMatrix<f64> = lu.L();
|
||||
//! let upper: DenseMatrix<f64> = lu.U();
|
||||
//! ```
|
||||
@@ -36,6 +36,7 @@
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -121,7 +122,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
/// Returns matrix inverse
|
||||
pub fn inverse(&self) -> M {
|
||||
pub fn inverse(&self) -> Result<M, Failed> {
|
||||
let (m, n) = self.LU.shape();
|
||||
|
||||
if m != n {
|
||||
@@ -134,11 +135,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
inv.set(i, i, T::one());
|
||||
}
|
||||
|
||||
inv = self.solve(inv);
|
||||
return inv;
|
||||
self.solve(inv)
|
||||
}
|
||||
|
||||
fn solve(&self, mut b: M) -> M {
|
||||
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let (m, n) = self.LU.shape();
|
||||
let (b_m, b_n) = b.shape();
|
||||
|
||||
@@ -187,20 +187,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
b
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that implements LU decomposition routine for any matrix.
|
||||
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Compute the LU decomposition of a square matrix.
|
||||
fn lu(&self) -> LU<T, Self> {
|
||||
fn lu(&self) -> Result<LU<T, Self>, Failed> {
|
||||
self.clone().lu_mut()
|
||||
}
|
||||
|
||||
/// Compute the LU decomposition of a square matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn lu_mut(mut self) -> LU<T, Self> {
|
||||
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
|
||||
let (m, n) = self.shape();
|
||||
|
||||
let mut piv = vec![0; m];
|
||||
@@ -252,12 +252,12 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
LU::new(self, piv, pivsign)
|
||||
Ok(LU::new(self, piv, pivsign))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn lu_solve_mut(self, b: Self) -> Self {
|
||||
self.lu_mut().solve(b)
|
||||
fn lu_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.lu_mut().and_then(|lu| lu.solve(b))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,7 +275,7 @@ mod tests {
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||
let expected_pivot =
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
let lu = a.lu();
|
||||
let lu = a.lu().unwrap();
|
||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
@@ -286,7 +286,7 @@ mod tests {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let expected =
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
let a_inv = a.lu().inverse();
|
||||
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
println!("{}", a_inv);
|
||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||
}
|
||||
|
||||
+1
-1
@@ -26,7 +26,7 @@
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = A.svd();
|
||||
//! let svd = A.svd().unwrap();
|
||||
//!
|
||||
//! let s: Vec<f64> = svd.s;
|
||||
//! let v: DenseMatrix<f64> = svd.V;
|
||||
|
||||
@@ -34,8 +34,8 @@
|
||||
//! 116.9,
|
||||
//! ]);
|
||||
//!
|
||||
//! let lr = LinearRegression::fit(&x, &y, Default::default());
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//! let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
use std::iter::Sum;
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||
@@ -777,9 +777,12 @@ mod tests {
|
||||
solver: LinearRegressionSolverName::QR,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(y
|
||||
.iter()
|
||||
|
||||
@@ -36,8 +36,8 @@
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
||||
//! ]);
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y);
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
use std::iter::Sum;
|
||||
use std::ops::AddAssign;
|
||||
@@ -395,6 +395,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::ensemble::random_forest_regressor::*;
|
||||
use crate::linear::logistic_regression::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
use ndarray::{arr1, arr2, Array1, Array2};
|
||||
|
||||
#[test]
|
||||
@@ -736,9 +737,9 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
]);
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x);
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
let error: f64 = y
|
||||
.into_iter()
|
||||
@@ -774,10 +775,6 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
]);
|
||||
|
||||
let expected_y: Vec<f64> = vec![
|
||||
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
||||
];
|
||||
|
||||
let y_hat = RandomForestRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
@@ -789,10 +786,10 @@ mod tests {
|
||||
m: Option::None,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.unwrap()
|
||||
.predict(&x)
|
||||
.unwrap();
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
||||
}
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
+14
-14
@@ -15,9 +15,9 @@
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//!
|
||||
//! let lu = A.qr();
|
||||
//! let orthogonal: DenseMatrix<f64> = lu.Q();
|
||||
//! let triangular: DenseMatrix<f64> = lu.R();
|
||||
//! let qr = A.qr().unwrap();
|
||||
//! let orthogonal: DenseMatrix<f64> = qr.Q();
|
||||
//! let triangular: DenseMatrix<f64> = qr.R();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -28,10 +28,10 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Results of QR decomposition.
|
||||
@@ -99,7 +99,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
return Q;
|
||||
}
|
||||
|
||||
fn solve(&self, mut b: M) -> M {
|
||||
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let (m, n) = self.QR.shape();
|
||||
let (b_nrows, b_ncols) = b.shape();
|
||||
|
||||
@@ -139,20 +139,20 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
b
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that implements QR decomposition routine for any matrix.
|
||||
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Compute the QR decomposition of a matrix.
|
||||
fn qr(&self) -> QR<T, Self> {
|
||||
fn qr(&self) -> Result<QR<T, Self>, Failed> {
|
||||
self.clone().qr_mut()
|
||||
}
|
||||
|
||||
/// Compute the QR decomposition of a matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn qr_mut(mut self) -> QR<T, Self> {
|
||||
fn qr_mut(mut self) -> Result<QR<T, Self>, Failed> {
|
||||
let (m, n) = self.shape();
|
||||
|
||||
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
||||
@@ -186,12 +186,12 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
r_diagonal[k] = -nrm;
|
||||
}
|
||||
|
||||
QR::new(self, r_diagonal)
|
||||
Ok(QR::new(self, r_diagonal))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn qr_solve_mut(self, b: Self) -> Self {
|
||||
self.qr_mut().solve(b)
|
||||
fn qr_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.qr_mut().and_then(|qr| qr.solve(b))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ mod tests {
|
||||
&[0.0, -0.3064, 0.0682],
|
||||
&[0.0, 0.0, -0.1999],
|
||||
]);
|
||||
let qr = a.qr();
|
||||
let qr = a.qr().unwrap();
|
||||
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||
}
|
||||
@@ -227,7 +227,7 @@ mod tests {
|
||||
&[0.8783784, 2.2297297],
|
||||
&[0.4729730, 0.6621622],
|
||||
]);
|
||||
let w = a.qr_solve_mut(b);
|
||||
let w = a.qr_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
}
|
||||
|
||||
+14
-13
@@ -19,7 +19,7 @@
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = A.svd();
|
||||
//! let svd = A.svd().unwrap();
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
//! let v: DenseMatrix<f64> = svd.V;
|
||||
//! let s: Vec<f64> = svd.s;
|
||||
@@ -33,6 +33,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use std::fmt::Debug;
|
||||
@@ -55,23 +56,23 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
||||
/// Trait that implements SVD decomposition routine for any matrix.
|
||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Solves Ax = b. Overrides original matrix in the process.
|
||||
fn svd_solve_mut(self, b: Self) -> Self {
|
||||
self.svd_mut().solve(b)
|
||||
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.svd_mut().and_then(|svd| svd.solve(b))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn svd_solve(&self, b: Self) -> Self {
|
||||
self.svd().solve(b)
|
||||
fn svd_solve(&self, b: Self) -> Result<Self, Failed> {
|
||||
self.svd().and_then(|svd| svd.solve(b))
|
||||
}
|
||||
|
||||
/// Compute the SVD decomposition of a matrix.
|
||||
fn svd(&self) -> SVD<T, Self> {
|
||||
fn svd(&self) -> Result<SVD<T, Self>, Failed> {
|
||||
self.clone().svd_mut()
|
||||
}
|
||||
|
||||
/// Compute the SVD decomposition of a matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn svd_mut(self) -> SVD<T, Self> {
|
||||
fn svd_mut(self) -> Result<SVD<T, Self>, Failed> {
|
||||
let mut U = self;
|
||||
|
||||
let (m, n) = U.shape();
|
||||
@@ -406,7 +407,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
SVD::new(U, v, w)
|
||||
Ok(SVD::new(U, v, w))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -427,7 +428,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn solve(&self, mut b: M) -> M {
|
||||
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let p = b.shape().1;
|
||||
|
||||
if self.U.shape().0 != b.shape().0 {
|
||||
@@ -460,7 +461,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
b
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,7 +492,7 @@ mod tests {
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
@@ -692,7 +693,7 @@ mod tests {
|
||||
],
|
||||
]);
|
||||
|
||||
let svd = A.svd();
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
@@ -707,7 +708,7 @@ mod tests {
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let expected_w =
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||
let w = a.svd_solve_mut(b);
|
||||
let w = a.svd_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user