diff --git a/src/error/mod.rs b/src/error/mod.rs index 320b991..c411e87 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -24,6 +24,8 @@ pub enum FailedError { FindFailed, /// Can't decompose a matrix DecompositionFailed, + /// Can't solve for x + SolutionFailed, } impl Failed { @@ -87,6 +89,7 @@ impl fmt::Display for FailedError { FailedError::TransformFailed => "Transform failed", FailedError::FindFailed => "Find failed", FailedError::DecompositionFailed => "Decomposition failed", + FailedError::SolutionFailed => "Can't find solution", }; write!(f, "{}", failed_err_str) } diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs new file mode 100644 index 0000000..e55d6bb --- /dev/null +++ b/src/linalg/cholesky.rs @@ -0,0 +1,206 @@ +//! # Cholesky Decomposition +//! +//! every positive definite matrix \\(A \in R^{n \times n}\\) can be factored as +//! +//! \\[A = R^TR\\] +//! +//! where \\(R\\) is upper triangular matrix with positive diagonal elements +//! +//! Example: +//! ``` +//! use smartcore::linalg::naive::dense_matrix::*; +//! use crate::smartcore::linalg::cholesky::*; +//! +//! let A = DenseMatrix::from_2d_array(&[ +//! &[25., 15., -5.], +//! &[15., 18., 0.], +//! &[-5., 0., 11.] +//! ]); +//! +//! let cholesky = A.cholesky().unwrap(); +//! let lower_triangular: DenseMatrix = cholesky.L(); +//! let upper_triangular: DenseMatrix = cholesky.U(); +//! ``` +//! +//! ## References: +//! * ["No bullshit guide to linear algebra", Ivan Savov, 2016, 7.6 Matrix decompositions](https://minireference.com/) +//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., 2.9 Cholesky Decomposition](http://numerical.recipes/) +//! +//! +//! +#![allow(non_snake_case)] + +use std::fmt::Debug; +use std::marker::PhantomData; + +use crate::error::{Failed, FailedError}; +use crate::linalg::BaseMatrix; +use crate::math::num::RealNumber; + +#[derive(Debug, Clone)] +/// Results of Cholesky decomposition. +pub struct Cholesky> { + R: M, + t: PhantomData, +} + +impl> Cholesky { + pub(crate) fn new(R: M) -> Cholesky { + Cholesky { + R: R, + t: PhantomData, + } + } + + /// Get lower triangular matrix. + pub fn L(&self) -> M { + let (n, _) = self.R.shape(); + let mut R = M::zeros(n, n); + + for i in 0..n { + for j in 0..n { + if j <= i { + R.set(i, j, self.R.get(i, j)); + } + } + } + R + } + + /// Get upper triangular matrix. + pub fn U(&self) -> M { + let (n, _) = self.R.shape(); + let mut R = M::zeros(n, n); + + for i in 0..n { + for j in 0..n { + if j <= i { + R.set(j, i, self.R.get(i, j)); + } + } + } + R + } + + /// Solves Ax = b + pub(crate) fn solve(&self, mut b: M) -> Result { + let (bn, m) = b.shape(); + let (rn, _) = self.R.shape(); + + if bn != rn { + return Err(Failed::because( + FailedError::SolutionFailed, + &format!("Can't solve Ax = b for x. Number of rows in b != number of rows in R."), + )); + } + + for k in 0..bn { + for j in 0..m { + for i in 0..k { + b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i)); + } + b.div_element_mut(k, j, self.R.get(k, k)); + } + } + + for k in (0..bn).rev() { + for j in 0..m { + for i in k + 1..bn { + b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k)); + } + b.div_element_mut(k, j, self.R.get(k, k)); + } + } + Ok(b) + } +} + +/// Trait that implements Cholesky decomposition routine for any matrix. +pub trait CholeskyDecomposableMatrix: BaseMatrix { + /// Compute the Cholesky decomposition of a matrix. + fn cholesky(&self) -> Result, Failed> { + self.clone().cholesky_mut() + } + + /// Compute the Cholesky decomposition of a matrix. The input matrix + /// will be used for factorization. + fn cholesky_mut(mut self) -> Result, Failed> { + let (m, n) = self.shape(); + + if m != n { + return Err(Failed::because( + FailedError::DecompositionFailed, + &format!("Can't do Cholesky decomposition on a non-square matrix"), + )); + } + + for j in 0..n { + let mut d = T::zero(); + for k in 0..j { + let mut s = T::zero(); + for i in 0..k { + s += self.get(k, i) * self.get(j, i); + } + s = (self.get(j, k) - s) / self.get(k, k); + self.set(j, k, s); + d = d + s * s; + } + d = self.get(j, j) - d; + + if d < T::zero() { + return Err(Failed::because( + FailedError::DecompositionFailed, + &format!("The matrix is not positive definite."), + )); + } + + self.set(j, j, d.sqrt()); + } + + Ok(Cholesky::new(self)) + } + + /// Solves Ax = b + fn cholesky_solve_mut(self, b: Self) -> Result { + self.cholesky_mut().and_then(|qr| qr.solve(b)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::*; + + #[test] + fn cholesky_decompose() { + let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); + let l = + DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]); + let u = + DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]); + let cholesky = a.cholesky().unwrap(); + + assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4)); + assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4)); + assert!(cholesky + .L() + .matmul(&cholesky.U()) + .abs() + .approximate_eq(&a.abs(), 1e-4)); + } + + #[test] + fn cholesky_solve_mut() { + let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); + let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]); + let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]); + + let cholesky = a.cholesky().unwrap(); + + assert!(cholesky + .solve(b.transpose()) + .unwrap() + .transpose() + .approximate_eq(&expected, 1e-4)); + } +} diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 29c7a89..fb12909 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -33,6 +33,7 @@ //! let u: DenseMatrix = svd.U; //! ``` +pub mod cholesky; /// The matrix is represented in terms of its eigenvalues and eigenvectors. pub mod evd; /// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix. @@ -55,6 +56,7 @@ use std::marker::PhantomData; use std::ops::Range; use crate::math::num::RealNumber; +use cholesky::CholeskyDecomposableMatrix; use evd::EVDDecomposableMatrix; use lu::LUDecomposableMatrix; use qr::QRDecomposableMatrix; @@ -507,6 +509,7 @@ pub trait Matrix: + EVDDecomposableMatrix + QRDecomposableMatrix + LUDecomposableMatrix + + CholeskyDecomposableMatrix + PartialEq + Display { diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index cf29061..d3d6353 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -8,6 +8,7 @@ use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor}; use serde::ser::{SerializeStruct, Serializer}; use serde::{Deserialize, Serialize}; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; @@ -442,6 +443,8 @@ impl QRDecomposableMatrix for DenseMatrix {} impl LUDecomposableMatrix for DenseMatrix {} +impl CholeskyDecomposableMatrix for DenseMatrix {} + impl Matrix for DenseMatrix {} impl PartialEq for DenseMatrix { diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index 3596899..e0b885b 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -42,6 +42,7 @@ use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign}; use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1}; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; @@ -544,6 +545,11 @@ impl + CholeskyDecomposableMatrix for Matrix> +{ +} + impl SmartCoreMatrix for Matrix> { diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 9f911f5..00c9745 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -49,6 +49,7 @@ use std::ops::SubAssign; use ndarray::ScalarOperand; use ndarray::{s, stack, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr}; +use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::evd::EVDDecomposableMatrix; use crate::linalg::lu::LUDecomposableMatrix; use crate::linalg::qr::QRDecomposableMatrix; @@ -494,6 +495,11 @@ impl + CholeskyDecomposableMatrix for ArrayBase, Ix2> +{ +} + impl Matrix for ArrayBase, Ix2> {