diff --git a/src/error/mod.rs b/src/error/mod.rs index 320b991..c411e87 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -24,6 +24,8 @@ pub enum FailedError { FindFailed, /// Can't decompose a matrix DecompositionFailed, + /// Can't solve for x + SolutionFailed, } impl Failed { @@ -87,6 +89,7 @@ impl fmt::Display for FailedError { FailedError::TransformFailed => "Transform failed", FailedError::FindFailed => "Find failed", FailedError::DecompositionFailed => "Decomposition failed", + FailedError::SolutionFailed => "Can't find solution", }; write!(f, "{}", failed_err_str) } diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs new file mode 100644 index 0000000..286e3f2 --- /dev/null +++ b/src/linalg/cholesky.rs @@ -0,0 +1,206 @@ +//! # Cholesky Decomposition +//! +//! every positive definite matrix \\(A \in R^{n \times n}\\) can be factored as +//! +//! \\[A = R^TR\\] +//! +//! where \\(R\\) is upper triangular matrix with positive diagonal elements +//! +//! Example: +//! ``` +//! use smartcore::linalg::naive::dense_matrix::*; +//! use crate::smartcore::linalg::cholesky::*; +//! +//! let A = DenseMatrix::from_2d_array(&[ +//! &[25., 15., -5.], +//! &[15., 18., 0.], +//! &[-5., 0., 11.] +//! ]); +//! +//! let cholesky = A.cholesky().unwrap(); +//! let lower_triangular: DenseMatrix = cholesky.L(); +//! let upper_triangular: DenseMatrix = cholesky.U(); +//! ``` +//! +//! ## References: +//! * ["No bullshit guide to linear algebra", Ivan Savov, 2016, 7.6 Matrix decompositions](https://minireference.com/) +//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., 2.9 Cholesky Decomposition](http://numerical.recipes/) +//! +//! +//! +#![allow(non_snake_case)] + +use std::fmt::Debug; +use std::marker::PhantomData; + +use crate::error::{Failed, FailedError}; +use crate::linalg::BaseMatrix; +use crate::math::num::RealNumber; + +#[derive(Debug, Clone)] +/// Results of Cholesky decomposition. +pub struct Cholesky> { + R: M, + t: PhantomData +} + +impl> Cholesky { + pub(crate) fn new(R: M) -> Cholesky { + Cholesky { + R: R, + t: PhantomData + } + } + + /// Get lower triangular matrix. + pub fn L(&self) -> M { + let (n, _) = self.R.shape(); + let mut R = M::zeros(n, n); + + for i in 0..n { + for j in 0..n { + if j <= i { + R.set(i, j, self.R.get(i, j)); + } + } + } + R + } + + /// Get upper triangular matrix. + pub fn U(&self) -> M { + let (n, _) = self.R.shape(); + let mut R = M::zeros(n, n); + + for i in 0..n { + for j in 0..n { + if j <= i { + R.set(j, i, self.R.get(i, j)); + } + } + } + R + } + + /// Solves Ax = b + pub(crate) fn solve(&self, mut b: M) -> Result { + + let (bn, m) = b.shape(); + let (rn, _) = self.R.shape(); + + if bn != rn { + return Err(Failed::because(FailedError::SolutionFailed, &format!( + "Can't solve Ax = b for x. Number of rows in b != number of rows in R." + ))); + } + + for k in 0..bn { + for j in 0..m { + for i in 0..k { + b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i)); + } + b.div_element_mut(k, j, self.R.get(k, k)); + } + } + + for k in (0..bn).rev() { + for j in 0..m { + for i in k + 1..bn { + b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k)); + } + b.div_element_mut(k, j, self.R.get(k, k)); + } + } + Ok(b) + } +} + +/// Trait that implements Cholesky decomposition routine for any matrix. +pub trait CholeskyDecomposableMatrix: BaseMatrix { + /// Compute the Cholesky decomposition of a matrix. + fn cholesky(&self) -> Result, Failed> { + self.clone().cholesky_mut() + } + + /// Compute the Cholesky decomposition of a matrix. The input matrix + /// will be used for factorization. + fn cholesky_mut(mut self) -> Result, Failed> { + let (m, n) = self.shape(); + + if m != n { + return Err(Failed::because(FailedError::DecompositionFailed, &format!( + "Can't do Cholesky decomposition on a non-square matrix" + ))); + } + + for j in 0..n { + let mut d = T::zero(); + for k in 0..j { + let mut s = T::zero(); + for i in 0..k { + s += self.get(k, i) * self.get(j, i); + } + s = (self.get(j, k) - s) / self.get(k, k); + self.set(j, k, s); + d = d + s * s; + } + d = self.get(j, j) - d; + + if d < T::zero() { + return Err(Failed::because(FailedError::DecompositionFailed, &format!( + "The matrix is not positive definite." + ))); + } + + self.set(j, j, d.sqrt()); + } + + Ok(Cholesky::new(self)) + } + + /// Solves Ax = b + fn cholesky_solve_mut(self, b: Self) -> Result { + self.cholesky_mut().and_then(|qr| qr.solve(b)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::*; + + #[test] + fn cholesky_decompose() { + let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); + let l = DenseMatrix::from_2d_array(&[ + &[5.0, 0.0, 0.0], + &[3.0, 3.0, 0.0], + &[-1.0, 1.0, 3.0], + ]); + let u = DenseMatrix::from_2d_array(&[ + &[5.0, 3.0, -1.0], + &[0.0, 3.0, 1.0], + &[0.0, 0.0, 3.0], + ]); + let cholesky = a.cholesky().unwrap(); + + assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4)); + assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4)); + assert!(cholesky.L().matmul(&cholesky.U()).abs().approximate_eq(&a.abs(), 1e-4)); + } + + #[test] + fn cholesky_solve_mut() { + let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); + let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]); + let expected = DenseMatrix::from_2d_array(&[ + &[1.0, 2.0, 3.0] + ]); + + let cholesky = a.cholesky().unwrap(); + + assert!(cholesky.solve(b.transpose()).unwrap().transpose().approximate_eq(&expected, 1e-4)); + + } + +} diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 29c7a89..61749b0 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -49,6 +49,7 @@ pub mod ndarray_bindings; pub mod qr; /// Singular value decomposition. pub mod svd; +pub mod cholesky; use std::fmt::{Debug, Display}; use std::marker::PhantomData; @@ -59,6 +60,7 @@ use evd::EVDDecomposableMatrix; use lu::LUDecomposableMatrix; use qr::QRDecomposableMatrix; use svd::SVDDecomposableMatrix; +use cholesky::CholeskyDecomposableMatrix; /// Column or row vector pub trait BaseVector: Clone + Debug { @@ -507,6 +509,7 @@ pub trait Matrix: + EVDDecomposableMatrix + QRDecomposableMatrix + LUDecomposableMatrix + + CholeskyDecomposableMatrix + PartialEq + Display { diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index cf29061..b5ecd90 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -12,6 +12,7 @@ use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::svd::SVDDecomposableMatrix; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::Matrix; pub use crate::linalg::{BaseMatrix, BaseVector}; use crate::math::num::RealNumber; @@ -442,6 +443,8 @@ impl QRDecomposableMatrix for DenseMatrix {} impl LUDecomposableMatrix for DenseMatrix {} +impl CholeskyDecomposableMatrix for DenseMatrix {} + impl Matrix for DenseMatrix {} impl PartialEq for DenseMatrix { diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index 3596899..a400a67 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -46,6 +46,7 @@ use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::svd::SVDDecomposableMatrix; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::Matrix as SmartCoreMatrix; use crate::linalg::{BaseMatrix, BaseVector}; use crate::math::num::RealNumber; @@ -544,6 +545,11 @@ impl + CholeskyDecomposableMatrix for Matrix> +{ +} + impl SmartCoreMatrix for Matrix> { diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 9f911f5..76749a7 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -53,6 +53,7 @@ use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; use crate::linalg::svd::SVDDecomposableMatrix; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::Matrix; use crate::linalg::{BaseMatrix, BaseVector}; use crate::math::num::RealNumber; @@ -494,6 +495,11 @@ impl + CholeskyDecomposableMatrix for ArrayBase, Ix2> +{ +} + impl Matrix for ArrayBase, Ix2> {