From 7b3fa982bef15404267f5b8b2eb4d89b3789aa0c Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Fri, 6 Mar 2020 09:13:54 -0800 Subject: [PATCH] feat: adds PCA --- Cargo.toml | 5 +- src/decomposition/mod.rs | 1 + src/decomposition/pca.rs | 368 ++++++++++++++ src/lib.rs | 1 + src/linalg/evd.rs | 112 +++++ src/linalg/mod.rs | 22 +- src/linalg/naive/dense_matrix.rs | 830 ++++++++++++++++++++++++++++++- src/linalg/ndarray_bindings.rs | 97 +++- src/linalg/svd.rs | 6 +- 9 files changed, 1422 insertions(+), 20 deletions(-) create mode 100644 src/decomposition/mod.rs create mode 100644 src/decomposition/pca.rs create mode 100644 src/linalg/evd.rs diff --git a/Cargo.toml b/Cargo.toml index 5054d5f..59857cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,9 @@ edition = "2018" [dependencies] ndarray = "0.13" -num-traits = "0.2" -rand = "0.7.2" +num-traits = "0.2.11" +num = "0.2.1" +rand = "0.7.3" [dev-dependencies] ndarray = "0.13" diff --git a/src/decomposition/mod.rs b/src/decomposition/mod.rs new file mode 100644 index 0000000..9b82c9f --- /dev/null +++ b/src/decomposition/mod.rs @@ -0,0 +1 @@ +pub mod pca; \ No newline at end of file diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs new file mode 100644 index 0000000..cbcaf8e --- /dev/null +++ b/src/decomposition/pca.rs @@ -0,0 +1,368 @@ +use crate::linalg::{Matrix}; + +#[derive(Debug)] +pub struct PCA { + eigenvectors: M, + eigenvalues: Vec, + projection: M, + mu: Vec, + pmu: Vec +} + +#[derive(Debug, Clone)] +pub struct PCAParameters { + use_correlation_matrix: bool +} + +impl Default for PCAParameters { + fn default() -> Self { + PCAParameters { + use_correlation_matrix: false + } + } +} + +impl PCA { + + pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA { + + let (m, n) = data.shape(); + + let mu = data.column_mean(); + + let mut x = data.clone(); + + for c in 0..n { + for r in 0..m { + x.sub_element_mut(r, c, mu[c]); + } + } + + let mut eigenvalues; + let mut eigenvectors; + + if m > n && !parameters.use_correlation_matrix{ + + let svd = x.svd(); + eigenvalues = svd.s; + for i in 0..eigenvalues.len() { + eigenvalues[i] *= eigenvalues[i]; + } + + eigenvectors = svd.V; + } else { + let mut cov = M::zeros(n, n); + + for k in 0..m { + for i in 0..n { + for j in 0..=i { + cov.add_element_mut(i, j, x.get(k, i) * x.get(k, j)); + } + } + } + + for i in 0..n { + for j in 0..=i { + cov.div_element_mut(i, j, m as f64); + cov.set(j, i, cov.get(i, j)); + } + } + + if parameters.use_correlation_matrix { + let mut sd = vec![0f64; n]; + for i in 0..n { + sd[i] = cov.get(i, i).sqrt(); + } + + for i in 0..n { + for j in 0..=i { + cov.div_element_mut(i, j, sd[i] * sd[j]); + cov.set(j, i, cov.get(i, j)); + } + } + + let evd = cov.evd(true); + + eigenvalues = evd.d; + + eigenvectors = evd.V; + + for i in 0..n { + for j in 0..n { + eigenvectors.div_element_mut(i, j, sd[i]); + } + } + } else { + + let evd = cov.evd(true); + + eigenvalues = evd.d; + + eigenvectors = evd.V; + + } + } + + let mut projection = M::zeros(n_components, n); + for i in 0..n { + for j in 0..n_components { + projection.set(j, i, eigenvectors.get(i, j)); + } + } + + let mut pmu = vec![0f64; n_components]; + for k in 0..n { + for i in 0..n_components { + pmu[i] += projection.get(i, k) * mu[k]; + } + } + + PCA { + eigenvectors: eigenvectors, + eigenvalues: eigenvalues, + projection: projection.transpose(), + mu: mu, + pmu: pmu + } + } + + pub fn transform(&self, x: &M) -> M { + let (nrows, ncols) = x.shape(); + let (_, n_components) = self.projection.shape(); + if ncols != self.mu.len() { + panic!("Invalid input vector size: {}, expected: {}", ncols, self.mu.len()); + } + + let mut x_transformed = x.dot(&self.projection); + for r in 0..nrows { + for c in 0..n_components { + x_transformed.sub_element_mut(r, c, self.pmu[c]); + } + } + x_transformed + } + +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + fn us_arrests_data() -> DenseMatrix { + DenseMatrix::from_array(&[ + &[13.2, 236.0, 58.0, 21.2], + &[10.0, 263.0, 48.0, 44.5], + &[8.1, 294.0, 80.0, 31.0], + &[8.8, 190.0, 50.0, 19.5], + &[9.0, 276.0, 91.0, 40.6], + &[7.9, 204.0, 78.0, 38.7], + &[3.3, 110.0, 77.0, 11.1], + &[5.9, 238.0, 72.0, 15.8], + &[15.4, 335.0, 80.0, 31.9], + &[17.4, 211.0, 60.0, 25.8], + &[5.3, 46.0, 83.0, 20.2], + &[2.6, 120.0, 54.0, 14.2], + &[10.4, 249.0, 83.0, 24.0], + &[7.2, 113.0, 65.0, 21.0], + &[2.2, 56.0, 57.0, 11.3], + &[6.0, 115.0, 66.0, 18.0], + &[9.7, 109.0, 52.0, 16.3], + &[15.4, 249.0, 66.0, 22.2], + &[2.1, 83.0, 51.0, 7.8], + &[11.3, 300.0, 67.0, 27.8], + &[4.4, 149.0, 85.0, 16.3], + &[12.1, 255.0, 74.0, 35.1], + &[2.7, 72.0, 66.0, 14.9], + &[16.1, 259.0, 44.0, 17.1], + &[9.0, 178.0, 70.0, 28.2], + &[6.0, 109.0, 53.0, 16.4], + &[4.3, 102.0, 62.0, 16.5], + &[12.2, 252.0, 81.0, 46.0], + &[2.1, 57.0, 56.0, 9.5], + &[7.4, 159.0, 89.0, 18.8], + &[11.4, 285.0, 70.0, 32.1], + &[11.1, 254.0, 86.0, 26.1], + &[13.0, 337.0, 45.0, 16.1], + &[0.8, 45.0, 44.0, 7.3], + &[7.3, 120.0, 75.0, 21.4], + &[6.6, 151.0, 68.0, 20.0], + &[4.9, 159.0, 67.0, 29.3], + &[6.3, 106.0, 72.0, 14.9], + &[3.4, 174.0, 87.0, 8.3], + &[14.4, 279.0, 48.0, 22.5], + &[3.8, 86.0, 45.0, 12.8], + &[13.2, 188.0, 59.0, 26.9], + &[12.7, 201.0, 80.0, 25.5], + &[3.2, 120.0, 80.0, 22.9], + &[2.2, 48.0, 32.0, 11.2], + &[8.5, 156.0, 63.0, 20.7], + &[4.0, 145.0, 73.0, 26.2], + &[5.7, 81.0, 39.0, 9.3], + &[2.6, 53.0, 66.0, 10.8], + &[6.8, 161.0, 60.0, 15.6]]) + } + + #[test] + fn decompose_covariance() { + + let us_arrests = us_arrests_data(); + + let expected_eigenvectors = DenseMatrix::from_array(&[ + &[-0.0417043206282872, -0.0448216562696701, -0.0798906594208108, -0.994921731246978], + &[-0.995221281426497, -0.058760027857223, 0.0675697350838043, 0.0389382976351601], + &[-0.0463357461197108, 0.97685747990989, 0.200546287353866, -0.0581691430589319], + &[-0.075155500585547, 0.200718066450337, -0.974080592182491, 0.0723250196376097] + ]); + + let expected_projection = DenseMatrix::from_array(&[ + &[-64.8022, -11.448, 2.4949, -2.4079], + &[-92.8275, -17.9829, -20.1266, 4.094], + &[-124.0682, 8.8304, 1.6874, 4.3537], + &[-18.34, -16.7039, -0.2102, 0.521], + &[-107.423, 22.5201, -6.7459, 2.8118], + &[-34.976, 13.7196, -12.2794, 1.7215], + &[60.8873, 12.9325, 8.4207, 0.6999], + &[-66.731, 1.3538, 11.281, 3.728], + &[-165.2444, 6.2747, 2.9979, -1.2477], + &[-40.5352, -7.2902, -3.6095, -7.3437], + &[123.5361, 24.2912, -3.7244, -3.4728], + &[51.797, -9.4692, 1.5201, 3.3478], + &[-78.9921, 12.8971, 5.8833, -0.3676], + &[57.551, 2.8463, -3.7382, -1.6494], + &[115.5868, -3.3421, 0.654, 0.8695], + &[55.7897, 3.1572, -0.3844, -0.6528], + &[62.3832, -10.6733, -2.2371, -3.8762], + &[-78.2776, -4.2949, 3.8279, -4.4836], + &[89.261, -11.4878, 4.6924, 2.1162], + &[-129.3301, -5.007, 2.3472, 1.9283], + &[21.2663, 19.4502, 7.5071, 1.0348], + &[-85.4515, 5.9046, -6.4643, -0.499], + &[98.9548, 5.2096, -0.0066, 0.7319], + &[-86.8564, -27.4284, 5.0034, -3.8798], + &[-7.9863, 5.2756, -5.5006, -0.6794], + &[62.4836, -9.5105, -1.8384, -0.2459], + &[69.0965, -0.2112, -0.468, 0.6566], + &[-83.6136, 15.1022, -15.8887, -0.3342], + &[114.7774, -4.7346, 2.2824, 0.9359], + &[10.8157, 23.1373, 6.3102, -1.6124], + &[-114.8682, -0.3365, -2.2613, 1.3812], + &[-84.2942, 15.924, 4.7213, -0.892], + &[-164.3255, -31.0966, 11.6962, 2.1112], + &[127.4956, -16.135, 1.3118, 2.301], + &[50.0868, 12.2793, -1.6573, -2.0291], + &[19.6937, 3.3701, 0.4531, 0.1803], + &[11.1502, 3.8661, -8.13, 2.914], + &[64.6891, 8.9115, 3.2065, -1.8749], + &[-3.064, 18.374, 17.47, 2.3083], + &[-107.2811, -23.5361, 2.0328, -1.2517], + &[86.1067, -16.5979, -1.3144, 1.2523], + &[-17.5063, -6.5066, -6.1001, -3.9229], + &[-31.2911, 12.985, 0.3934, -4.242], + &[49.9134, 17.6485, -1.7882, 1.8677], + &[124.7145, -27.3136, -4.8028, 2.005], + &[14.8174, -1.7526, -1.0454, -1.1738], + &[25.0758, 9.968, -4.7811, 2.6911], + &[91.5446, -22.9529, 0.402, -0.7369], + &[118.1763, 5.5076, 2.7113, -0.205], + &[10.4345, -5.9245, 3.7944, 0.5179] + ]); + + let expected_eigenvalues: Vec = vec![343544.6277001563, 9897.625949808047, 2063.519887011604, 302.04806302399646]; + + let pca = PCA::new(&us_arrests, 4, Default::default()); + + assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4)); + + for i in 0..pca.eigenvalues.len() { + assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs()); + } + + let us_arrests_t = pca.transform(&us_arrests); + + assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4)); + + } + + #[test] + fn decompose_correlation() { + + let us_arrests = us_arrests_data(); + + let expected_eigenvectors = DenseMatrix::from_array(&[ + &[0.124288601688222, -0.0969866877028367, 0.0791404742697482, -0.150572299008293], + &[0.00706888610512014, -0.00227861130898090, 0.00325028101296307, 0.00901099154845273], + &[0.0194141494466002, 0.060910660326921, 0.0263806464184195, -0.0093429458365566], + &[0.0586084532558777, 0.0180450999787168, -0.0881962972508558, -0.0096011588898465] + ]); + + let expected_projection = DenseMatrix::from_array(&[ + &[0.9856, -1.1334, 0.4443, -0.1563], + &[1.9501, -1.0732, -2.04, 0.4386], + &[1.7632, 0.746, -0.0548, 0.8347], + &[-0.1414, -1.1198, -0.1146, 0.1828], + &[2.524, 1.5429, -0.5986, 0.342], + &[1.5146, 0.9876, -1.095, -0.0015], + &[-1.3586, 1.0889, 0.6433, 0.1185], + &[0.0477, 0.3254, 0.7186, 0.882], + &[3.013, -0.0392, 0.5768, 0.0963], + &[1.6393, -1.2789, 0.3425, -1.0768], + &[-0.9127, 1.5705, -0.0508, -0.9028], + &[-1.6398, -0.211, -0.2598, 0.4991], + &[1.3789, 0.6818, 0.6775, 0.122], + &[-0.5055, 0.1516, -0.2281, -0.4247], + &[-2.2536, 0.1041, -0.1646, -0.0176], + &[-0.7969, 0.2702, -0.0256, -0.2065], + &[-0.7509, -0.9584, 0.0284, -0.6706], + &[1.5648, -0.8711, 0.7835, -0.4547], + &[-2.3968, -0.3764, 0.0657, 0.3305], + &[1.7634, -0.4277, 0.1573, 0.5591], + &[-0.4862, 1.4745, 0.6095, 0.1796], + &[2.1084, 0.1554, -0.3849, -0.1024], + &[-1.6927, 0.6323, -0.1531, -0.0673], + &[0.9965, -2.3938, 0.7408, -0.2155], + &[0.6968, 0.2634, -0.3774, -0.2258], + &[-1.1855, -0.5369, -0.2469, -0.1237], + &[-1.2656, 0.194, -0.1756, -0.0159], + &[2.8744, 0.7756, -1.1634, -0.3145], + &[-2.3839, 0.0181, -0.0369, 0.0331], + &[0.1816, 1.4495, 0.7645, -0.2434], + &[1.98, -0.1428, -0.1837, 0.3395], + &[1.6826, 0.8232, 0.6431, 0.0135], + &[1.1234, -2.228, 0.8636, 0.9544], + &[-2.9922, -0.5991, -0.3013, 0.254], + &[-0.226, 0.7422, 0.0311, -0.4739], + &[-0.3118, 0.2879, 0.0153, -0.0103], + &[0.0591, 0.5414, -0.9398, 0.2378], + &[-0.8884, 0.5711, 0.4006, -0.3591], + &[-0.8638, 1.492, 1.3699, 0.6136], + &[1.3207, -1.9334, 0.3005, 0.1315], + &[-1.9878, -0.8233, -0.3893, 0.1096], + &[0.9997, -0.8603, -0.1881, -0.6529], + &[1.3551, 0.4125, 0.4921, -0.6432], + &[-0.5506, 1.4715, -0.2937, 0.0823], + &[-2.8014, -1.4023, -0.8413, 0.1449], + &[-0.0963, -0.1997, -0.0117, -0.2114], + &[-0.2169, 0.9701, -0.6249, 0.2208], + &[-2.1086, -1.4248, -0.1048, -0.1319], + &[-2.0797, 0.6113, 0.1389, -0.1841], + &[-0.6294, -0.321, 0.2407, 0.1667] + ]); + + let expected_eigenvalues: Vec = vec![2.480241579149493, 0.9897651525398419, 0.35656318058083064, 0.1734300877298357]; + + let pca = PCA::new(&us_arrests, 4, PCAParameters{use_correlation_matrix: true}); + + assert!(pca.eigenvectors.abs().approximate_eq(&expected_eigenvectors.abs(), 1e-4)); + + for i in 0..pca.eigenvalues.len() { + assert_eq!(pca.eigenvalues[i].abs(), expected_eigenvalues[i].abs()); + } + + let us_arrests_t = pca.transform(&us_arrests); + + assert!(us_arrests_t.abs().approximate_eq(&expected_projection.abs(), 1e-4)); + + } + +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 6dac826..63eeea7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ pub mod classification; pub mod regression; pub mod cluster; +pub mod decomposition; pub mod linalg; pub mod math; pub mod error; diff --git a/src/linalg/evd.rs b/src/linalg/evd.rs new file mode 100644 index 0000000..52916a3 --- /dev/null +++ b/src/linalg/evd.rs @@ -0,0 +1,112 @@ +use crate::linalg::{Matrix}; + +#[derive(Debug, Clone)] +pub struct EVD { + pub d: Vec, + pub e: Vec, + pub V: M +} + +impl EVD { + pub fn new(V: M, d: Vec, e: Vec) -> EVD { + EVD { + d: d, + e: e, + V: V + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn decompose_symmetric() { + + let A = DenseMatrix::from_array(&[ + &[0.9000, 0.4000, 0.7000], + &[0.4000, 0.5000, 0.3000], + &[0.7000, 0.3000, 0.8000]]); + + let eigen_values = vec![1.7498382, 0.3165784, 0.1335834]; + + let eigen_vectors = DenseMatrix::from_array(&[ + &[0.6881997, -0.07121225, 0.7220180], + &[0.3700456, 0.89044952, -0.2648886], + &[0.6240573, -0.44947578, -0.6391588] + ]); + + let evd = A.evd(true); + + assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); + for i in 0..eigen_values.len() { + assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4); + } + for i in 0..eigen_values.len() { + assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); + } + + } + + #[test] + fn decompose_asymmetric() { + + let A = DenseMatrix::from_array(&[ + &[0.9000, 0.4000, 0.7000], + &[0.4000, 0.5000, 0.3000], + &[0.8000, 0.3000, 0.8000]]); + + let eigen_values = vec![1.79171122, 0.31908143, 0.08920735]; + + let eigen_vectors = DenseMatrix::from_array(&[ + &[0.7178958, 0.05322098, 0.6812010], + &[0.3837711, -0.84702111, -0.1494582], + &[0.6952105, 0.43984484, -0.7036135] + ]); + + let evd = A.evd(false); + + assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); + for i in 0..eigen_values.len() { + assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4); + } + for i in 0..eigen_values.len() { + assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); + } + + } + + #[test] + fn decompose_complex() { + + let A = DenseMatrix::from_array(&[ + &[3.0, -2.0, 1.0, 1.0], + &[4.0, -1.0, 1.0, 1.0], + &[1.0, 1.0, 3.0, -2.0], + &[1.0, 1.0, 4.0, -1.0]]); + + let eigen_values_d = vec![0.0, 2.0, 2.0, 0.0]; + let eigen_values_e = vec![2.2361, 0.9999, -0.9999, -2.2361]; + + let eigen_vectors = DenseMatrix::from_array(&[ + &[-0.9159, -0.1378, 0.3816, -0.0806], + &[-0.6707, 0.1059, 0.901, 0.6289], + &[0.9159, -0.1378, 0.3816, 0.0806], + &[0.6707, 0.1059, 0.901, -0.6289] + ]); + + let evd = A.evd(false); + + assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4)); + for i in 0..eigen_values_d.len() { + assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4); + } + for i in 0..eigen_values_e.len() { + assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4); + } + + } + +} \ No newline at end of file diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 89fcc2c..163036e 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -1,10 +1,12 @@ pub mod naive; pub mod svd; +pub mod evd; pub mod ndarray_bindings; use std::ops::Range; use std::fmt::Debug; use svd::SVD; +use evd::EVD; pub trait Matrix: Clone + Debug { @@ -37,6 +39,14 @@ pub trait Matrix: Clone + Debug { } + fn evd(&self, symmetric: bool) -> EVD{ + self.clone().evd_mut(symmetric) + } + + fn evd_mut(self, symmetric: bool) -> EVD; + + fn eye(size: usize) -> Self; + fn zeros(nrows: usize, ncols: usize) -> Self; fn ones(nrows: usize, ncols: usize) -> Self; @@ -51,7 +61,7 @@ pub trait Matrix: Clone + Debug { fn h_stack(&self, other: &Self) -> Self; - fn dot(&self, other: &Self) -> Self; + fn dot(&self, other: &Self) -> Self; fn vector_dot(&self, other: &Self) -> f64; @@ -67,6 +77,14 @@ pub trait Matrix: Clone + Debug { fn div_mut(&mut self, other: &Self) -> &Self; + fn div_element_mut(&mut self, row: usize, col: usize, x: f64); + + fn mul_element_mut(&mut self, row: usize, col: usize, x: f64); + + fn add_element_mut(&mut self, row: usize, col: usize, x: f64); + + fn sub_element_mut(&mut self, row: usize, col: usize, x: f64); + fn add(&self, other: &Self) -> Self { let mut r = self.clone(); r.add_mut(other); @@ -133,6 +151,8 @@ pub trait Matrix: Clone + Debug { fn norm(&self, p:f64) -> f64; + fn column_mean(&self) -> Vec; + fn negative_mut(&mut self); fn negative(&self) -> Self { diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 9f56dc3..98e7540 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -1,6 +1,10 @@ +extern crate num; use std::ops::Range; +use std::fmt; +use num::complex::Complex; use crate::linalg::{Matrix}; use crate::linalg::svd::SVD; +use crate::linalg::evd::EVD; use crate::math; use rand::prelude::*; @@ -13,6 +17,16 @@ pub struct DenseMatrix { } +impl fmt::Display for DenseMatrix { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut rows: Vec> = Vec::new(); + for r in 0..self.nrows { + rows.push(self.get_row_as_vec(r).iter().map(|x| (x * 1e4).round() / 1e4 ).collect()); + } + write!(f, "{:?}", rows) + } +} + impl DenseMatrix { fn new(nrows: usize, ncols: usize, values: Vec) -> DenseMatrix { @@ -67,23 +81,717 @@ impl DenseMatrix { pub fn get_raw_values(&self) -> &Vec { &self.values + } + + fn tred2(&mut self, d: &mut Vec, e: &mut Vec) { + + let n = self.nrows; + for i in 0..n { + d[i] = self.get(n - 1, i); + } + + // Householder reduction to tridiagonal form. + for i in (1..n).rev() { + // Scale to avoid under/overflow. + let mut scale = 0f64; + let mut h = 0f64; + for k in 0..i { + scale = scale + d[k].abs(); + } + if scale == 0f64 { + e[i] = d[i - 1]; + for j in 0..i { + d[j] = self.get(i - 1, j); + self.set(i, j, 0.0); + self.set(j, i, 0.0); + } + } else { + // Generate Householder vector. + for k in 0..i { + d[k] /= scale; + h += d[k] * d[k]; + } + let mut f = d[i - 1]; + let mut g = h.sqrt(); + if f > 0f64 { + g = -g; + } + e[i] = scale * g; + h = h - f * g; + d[i - 1] = f - g; + for j in 0..i { + e[j] = 0f64; + } + + // Apply similarity transformation to remaining columns. + for j in 0..i { + f = d[j]; + self.set(j, i, f); + g = e[j] + self.get(j, j) * f; + for k in j + 1..=i - 1 { + g += self.get(k, j) * d[k]; + e[k] += self.get(k, j) * f; + } + e[j] = g; + } + f = 0.0; + for j in 0..i { + e[j] /= h; + f += e[j] * d[j]; + } + let hh = f / (h + h); + for j in 0..i { + e[j] -= hh * d[j]; + } + for j in 0..i { + f = d[j]; + g = e[j]; + for k in j..=i-1 { + self.sub_element_mut(k, j, f * e[k] + g * d[k]); + } + d[j] = self.get(i - 1, j); + self.set(i, j, 0.0); + } + } + d[i] = h; + } + + // Accumulate transformations. + for i in 0..n-1 { + self.set(n - 1, i, self.get(i, i)); + self.set(i, i, 1.0); + let h = d[i + 1]; + if h != 0f64 { + for k in 0..=i { + d[k] = self.get(k, i + 1) / h; + } + for j in 0..=i { + let mut g = 0f64; + for k in 0..=i { + g += self.get(k, i + 1) * self.get(k, j); + } + for k in 0..=i { + self.sub_element_mut(k, j, g * d[k]); + } + } + } + for k in 0..=i { + self.set(k, i + 1, 0.0); + } + } + for j in 0..n { + d[j] = self.get(n - 1, j); + self.set(n - 1, j, 0.0); + } + self.set(n - 1, n - 1, 1.0); + e[0] = 0.0; } - fn div_element_mut(&mut self, row: usize, col: usize, x: f64) { - self.values[col*self.nrows + row] /= x; + fn tql2(&mut self, d: &mut Vec, e: &mut Vec) { + let n = self.nrows; + for i in 1..n { + e[i - 1] = e[i]; + } + e[n - 1] = 0f64; + + let mut f = 0f64; + let mut tst1 = 0f64; + for l in 0..n { + // Find small subdiagonal element + tst1 = f64::max(tst1, d[l].abs() + e[l].abs()); + + let mut m = l; + + loop { + if m < n { + if e[m].abs() <= tst1 * std::f64::EPSILON { + break; + } + m += 1; + } else { + break; + } + } + + // If m == l, d[l] is an eigenvalue, + // otherwise, iterate. + if m > l { + let mut iter = 0; + loop { + iter += 1; + if iter >= 30 { + panic!("Too many iterations"); + } + + // Compute implicit shift + let mut g = d[l]; + let mut p = (d[l + 1] - g) / (2.0 * e[l]); + let mut r = p.hypot(1.0); + if p < 0f64 { + r = -r; + } + d[l] = e[l] / (p + r); + d[l + 1] = e[l] * (p + r); + let dl1 = d[l + 1]; + let mut h = g - d[l]; + for i in l+2..n { + d[i] -= h; + } + f = f + h; + + // Implicit QL transformation. + p = d[m]; + let mut c = 1.0; + let mut c2 = c; + let mut c3 = c; + let el1 = e[l + 1]; + let mut s = 0.0; + let mut s2 = 0.0; + for i in (l..m).rev() { + c3 = c2; + c2 = c; + s2 = s; + g = c * e[i]; + h = c * p; + r = p.hypot(e[i]); + e[i + 1] = s * r; + s = e[i] / r; + c = p / r; + p = c * d[i] - s * g; + d[i + 1] = h + s * (c * g + s * d[i]); + + // Accumulate transformation. + for k in 0..n { + h = self.get(k, i + 1); + self.set(k, i + 1, s * self.get(k, i) + c * h); + self.set(k, i, c * self.get(k, i) - s * h); + } + } + p = -s * s2 * c3 * el1 * e[l] / dl1; + e[l] = s * p; + d[l] = c * p; + + // Check for convergence. + if e[l].abs() <= tst1 * std::f64::EPSILON { + break; + } + } + } + d[l] = d[l] + f; + e[l] = 0f64; + } + + // Sort eigenvalues and corresponding vectors. + for i in 0..n-1 { + let mut k = i; + let mut p = d[i]; + for j in i + 1..n { + if d[j] > p { + k = j; + p = d[j]; + } + } + if k != i { + d[k] = d[i]; + d[i] = p; + for j in 0..n { + p = self.get(j, i); + self.set(j, i, self.get(j, k)); + self.set(j, k, p); + } + } + } } - fn mul_element_mut(&mut self, row: usize, col: usize, x: f64) { - self.values[col*self.nrows + row] *= x; + fn balance(A: &mut Self) -> Vec { + let radix = 2f64; + let sqrdx = radix * radix; + + let n = A.nrows; + + let mut scale = vec![1f64; n]; + + let mut done = false; + while !done { + done = true; + for i in 0..n { + let mut r = 0f64; + let mut c = 0f64; + for j in 0..n { + if j != i { + c += A.get(j, i).abs(); + r += A.get(i, j).abs(); + } + } + if c != 0f64 && r != 0f64 { + let mut g = r / radix; + let mut f = 1.0; + let s = c + r; + while c < g { + f *= radix; + c *= sqrdx; + } + g = r * radix; + while c > g { + f /= radix; + c /= sqrdx; + } + if (c + r) / f < 0.95 * s { + done = false; + g = 1.0 / f; + scale[i] *= f; + for j in 0..n { + A.mul_element_mut(i, j, g); + } + for j in 0..n { + A.mul_element_mut(j, i, f); + } + } + } + } + } + + return scale; } - fn add_element_mut(&mut self, row: usize, col: usize, x: f64) { - self.values[col*self.nrows + row] += x + fn elmhes(A: &mut Self) -> Vec { + let n = A.nrows; + let mut perm = vec![0; n]; + + for m in 1..n-1 { + let mut x = 0f64; + let mut i = m; + for j in m..n { + if A.get(j, m - 1).abs() > x.abs() { + x = A.get(j, m - 1); + i = j; + } + } + perm[m] = i; + if i != m { + for j in (m-1)..n { + let swap = A.get(i, j); + A.set(i, j, A.get(m, j)); + A.set(m, j, swap); + } + for j in 0..n { + let swap = A.get(j, i); + A.set(j, i, A.get(j, m)); + A.set(j, m, swap); + } + } + if x != 0f64 { + for i in (m + 1)..n { + let mut y = A.get(i, m - 1); + if y != 0f64 { + y /= x; + A.set(i, m - 1, y); + for j in m..n { + A.sub_element_mut(i, j, y * A.get(m, j)); + } + for j in 0..n { + A.add_element_mut(j, m, y * A.get(j, i)); + } + } + } + } + } + + return perm; + } + + fn eltran(A: &Self, V: &mut Self, perm: &Vec) { + let n = A.nrows; + for mp in (1..n - 1).rev() { + for k in mp + 1..n { + V.set(k, mp, A.get(k, mp - 1)); + } + let i = perm[mp]; + if i != mp { + for j in mp..n { + V.set(mp, j, V.get(i, j)); + V.set(i, j, 0.0); + } + V.set(i, mp, 1.0); + } + } } - fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) { - self.values[col*self.nrows + row] -= x; - } + fn hqr2(A: &mut Self, V: &mut Self, d: &mut Vec, e: &mut Vec) { + let n = A.nrows; + let mut z = 0f64; + let mut s = 0f64; + let mut r = 0f64; + let mut q = 0f64; + let mut p = 0f64; + let mut anorm = 0f64; + + for i in 0..n { + for j in i32::max(i as i32 - 1, 0)..n as i32 { + anorm += A.get(i, j as usize).abs(); + } + } + + let mut nn = n - 1; + let mut t = 0.0; + 'outer: loop { + let mut its = 0; + loop { + let mut l = nn; + while l > 0 { + s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs(); + if s == 0.0 { + s = anorm; + } + if A.get(l, l - 1).abs() <= std::f64::EPSILON * s { + A.set(l, l - 1, 0.0); + break; + } + l -= 1; + } + let mut x = A.get(nn, nn); + if l == nn { + d[nn] = x + t; + A.set(nn, nn, x + t); + if nn == 0 { + break 'outer; + } else { + nn -= 1; + } + } else { + let mut y = A.get(nn - 1, nn - 1); + let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn); + if l == nn - 1 { + p = 0.5 * (y - x); + q = p * p + w; + z = q.abs().sqrt(); + x += t; + A.set(nn, nn, x ); + A.set(nn - 1, nn - 1, y + t); + if q >= 0.0 { + z = p + z.copysign(p); + d[nn - 1] = x + z; + d[nn] = x + z; + if z != 0.0 { + d[nn] = x - w / z; + } + x = A.get(nn, nn - 1); + s = x.abs() + z.abs(); + p = x / s; + q = z / s; + r = (p * p + q * q).sqrt(); + p /= r; + q /= r; + for j in nn-1..n { + z = A.get(nn - 1, j); + A.set(nn - 1, j, q * z + p * A.get(nn, j)); + A.set(nn, j, q * A.get(nn, j) - p * z); + } + for i in 0..=nn { + z = A.get(i, nn - 1); + A.set(i, nn - 1, q * z + p * A.get(i, nn)); + A.set(i, nn, q * A.get(i, nn) - p * z); + } + for i in 0..n { + z = V.get(i, nn - 1); + V.set(i, nn - 1, q * z + p * V.get(i, nn)); + V.set(i, nn, q * V.get(i, nn) - p * z); + } + } else { + d[nn] = x + p; + e[nn] = -z; + d[nn - 1] = d[nn]; + e[nn - 1] = -e[nn]; + } + + if nn <= 1 { + break 'outer; + } else { + nn -= 2; + } + } else { + if its == 30 { + panic!("Too many iterations in hqr"); + } + if its == 10 || its == 20 { + t += x; + for i in 0..nn+1 { + A.sub_element_mut(i, i, x); + } + s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs(); + y = 0.75 * s; + x = 0.75 * s; + w = -0.4375 * s * s; + } + its += 1; + let mut m = nn - 2; + while m >= l { + z = A.get(m, m); + r = x - z; + s = y - z; + p = (r * s - w) / A.get(m + 1, m) + A.get(m, m + 1); + q = A.get(m + 1, m + 1) - z - r - s; + r = A.get(m + 2, m + 1); + s = p.abs() + q.abs() + r.abs(); + p /= s; + q /= s; + r /= s; + if m == l { + break; + } + let u = A.get(m, m - 1).abs() * (q.abs() + r.abs()); + let v = p.abs() * (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs()); + if u <= std::f64::EPSILON * v { + break; + } + m -= 1; + } + for i in m..nn-1 { + A.set(i + 2, i , 0.0); + if i != m { + A.set(i + 2, i - 1, 0.0); + } + } + for k in m..nn { + if k != m { + p = A.get(k, k - 1); + q = A.get(k + 1, k - 1); + r = 0.0; + if k + 1 != nn { + r = A.get(k + 2, k - 1); + } + x = p.abs() + q.abs() +r.abs(); + if x != 0.0 { + p /= x; + q /= x; + r /= x; + } + } + let s = (p * p + q * q + r * r).sqrt().copysign(p); + if s != 0.0 { + if k == m { + if l != m { + A.set(k, k - 1, -A.get(k, k - 1)); + } + } else { + A.set(k, k - 1, -s * x); + } + p += s; + x = p / s; + y = q / s; + z = r / s; + q /= p; + r /= p; + for j in k..n { + p = A.get(k, j) + q * A.get(k + 1, j); + if k + 1 != nn { + p += r * A.get(k + 2, j); + A.sub_element_mut(k + 2, j, p * z); + } + A.sub_element_mut(k + 1, j, p * y); + A.sub_element_mut(k, j, p * x); + } + let mmin; + if nn < k + 3 { + mmin = nn; + } else { + mmin = k + 3; + } + for i in 0..mmin+1 { + p = x * A.get(i, k) + y * A.get(i, k + 1); + if k + 1 != nn { + p += z * A.get(i, k + 2); + A.sub_element_mut(i, k + 2, p * r); + } + A.sub_element_mut(i, k + 1, p * q); + A.sub_element_mut(i, k, p); + } + for i in 0..n { + p = x * V.get(i, k) + y * V.get(i, k + 1); + if k + 1 != nn { + p += z * V.get(i, k + 2); + V.sub_element_mut(i, k + 2, p * r); + } + V.sub_element_mut(i, k + 1, p * q); + V.sub_element_mut(i, k, p); + } + } + } + } + } + if l + 1 >= nn { + break; + } + }; + } + + if anorm != 0f64 { + for nn in (0..n).rev() { + p = d[nn]; + q = e[nn]; + let na = nn.wrapping_sub(1); + if q == 0f64 { + let mut m = nn; + A.set(nn, nn, 1.0); + if nn > 0 { + let mut i = nn - 1; + loop { + let w = A.get(i, i) - p; + r = 0.0; + for j in m..=nn { + r += A.get(i, j) * A.get(j, nn); + } + if e[i] < 0.0 { + z = w; + s = r; + } else { + m = i; + + if e[i] == 0.0 { + t = w; + if t == 0.0 { + t = std::f64::EPSILON * anorm; + } + A.set(i, nn, -r / t); + } else { + let x = A.get(i, i + 1); + let y = A.get(i + 1, i); + q = (d[i] - p).powf(2f64) + e[i].powf(2f64); + t = (x * s - z * r) / q; + A.set(i, nn, t); + if x.abs() > z.abs() { + A.set(i + 1, nn, (-r - w * t) / x); + } else { + A.set(i + 1, nn, (-s - y * t) / z); + } + } + t = A.get(i, nn).abs(); + if std::f64::EPSILON * t * t > 1f64 { + for j in i..=nn { + A.div_element_mut(j, nn, t); + } + } + } + if i == 0{ + break; + } else { + i -= 1; + } + } + } + } else if q < 0f64 { + let mut m = na; + if A.get(nn, na).abs() > A.get(na, nn).abs() { + A.set(na, na, q / A.get(nn, na)); + A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na)); + } else { + let temp = Complex::new(0.0, -A.get(na, nn)) / Complex::new(A.get(na, na) - p, q); + A.set(na, na, temp.re); + A.set(na, nn, temp.im); + } + A.set(nn, na, 0.0); + A.set(nn, nn, 1.0); + if nn >= 2 { + for i in (0..nn - 1).rev() { + let w = A.get(i, i) - p; + let mut ra = 0f64; + let mut sa = 0f64; + for j in m..=nn { + ra += A.get(i, j) * A.get(j, na); + sa += A.get(i, j) * A.get(j, nn); + } + if e[i] < 0.0 { + z = w; + r = ra; + s = sa; + } else { + m = i; + if e[i] == 0.0 { + let temp = Complex::new(-ra, -sa) / Complex::new(w, q); + A.set(i, na, temp.re); + A.set(i, nn, temp.im); + } else { + let x = A.get(i, i + 1); + let y = A.get(i + 1, i); + let mut vr = (d[i] - p).powf(2f64) + (e[i]).powf(2.0) - q * q; + let vi = 2.0 * q * (d[i] - p); + if vr == 0.0 && vi == 0.0 { + vr = std::f64::EPSILON * anorm * (w.abs() + q.abs() + x.abs() + y.abs() + z.abs()); + } + let temp = Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra) / Complex::new(vr, vi); + A.set(i, na, temp.re); + A.set(i, nn, temp.im); + if x.abs() > z.abs() + q.abs() { + A.set(i + 1, na, (-ra - w * A.get(i, na) + q * A.get(i, nn)) / x); + A.set(i + 1, nn, (-sa - w * A.get(i, nn) - q * A.get(i, na)) / x); + } else { + let temp = Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn)) / Complex::new(z, q); + A.set(i + 1, na, temp.re); + A.set(i + 1, nn, temp.im); + } + } + } + t = f64::max(A.get(i, na).abs(), A.get(i, nn).abs()); + if std::f64::EPSILON * t * t > 1f64 { + for j in i..=nn { + A.div_element_mut(j, na, t); + A.div_element_mut(j, nn, t); + } + } + } + } + } + } + + for j in (0..n).rev() { + for i in 0..n { + z = 0f64; + for k in 0..=j { + z += V.get(i, k) * A.get(k, j); + } + V.set(i, j, z); + } + } + } + } + + fn balbak(V: &mut Self, scale: &Vec) { + let n = V.nrows; + for i in 0..n { + for j in 0..n { + V.mul_element_mut(i, j, scale[i]); + } + } + } + + fn sort(d: &mut Vec, e: &mut Vec, V: &mut Self) { + let n = d.len(); + let mut temp = vec![0f64; n]; + for j in 1..n { + let real = d[j]; + let img = e[j]; + for k in 0..n { + temp[k] = V.get(k, j); + } + let mut i = j as i32 - 1; + while i >= 0 { + if d[i as usize] >= d[j] { + break; + } + d[i as usize + 1] = d[i as usize]; + e[i as usize + 1] = e[i as usize]; + for k in 0..n { + V.set(k, i as usize + 1, V.get(k, i as usize)); + } + i -= 1; + } + d[i as usize + 1] = real; + e[i as usize + 1] = img; + for k in 0..n { + V.set(k, i as usize + 1, temp[k]); + } + } + } } @@ -160,6 +868,16 @@ impl Matrix for DenseMatrix { DenseMatrix::fill(nrows, ncols, 1f64) } + fn eye(size: usize) -> Self { + let mut matrix = Self::zeros(size, size); + + for i in 0..size { + matrix.set(i, i, 1.0); + } + + return matrix; + } + fn to_raw_vector(&self) -> Vec{ let mut v = vec![0.; self.nrows * self.ncols]; @@ -229,7 +947,7 @@ impl Matrix for DenseMatrix { } result - } + } fn vector_dot(&self, other: &Self) -> f64 { if (self.nrows != 1 || self.nrows != 1) && (other.nrows != 1 || other.ncols != 1) { @@ -681,6 +1399,44 @@ impl Matrix for DenseMatrix { SVD::new(U, v, w) + } + + fn evd_mut(mut self, symmetric: bool) -> EVD{ + if self.ncols != self.nrows { + panic!("Matrix is not square: {} x {}", self.nrows, self.ncols); + } + + let n = self.nrows; + let mut d = vec![0f64; n]; + let mut e = vec![0f64; n]; + + let mut V; + if symmetric { + V = self; + // Tridiagonalize. + V.tred2(&mut d, &mut e); + // Diagonalize. + V.tql2(&mut d, &mut e); + + } else { + let scale = Self::balance(&mut self); + + let perm = Self::elmhes(&mut self); + + V = Self::eye(n); + + Self::eltran(&self, &mut V, &perm); + + Self::hqr2(&mut self, &mut V, &mut d, &mut e); + Self::balbak(&mut V, &scale); + Self::sort(&mut d, &mut e, &mut V); + } + + EVD { + V: V, + d: d, + e: e + } } fn approximate_eq(&self, other: &Self, error: f64) -> bool { @@ -755,6 +1511,22 @@ impl Matrix for DenseMatrix { self } + fn div_element_mut(&mut self, row: usize, col: usize, x: f64) { + self.values[col*self.nrows + row] /= x; + } + + fn mul_element_mut(&mut self, row: usize, col: usize, x: f64) { + self.values[col*self.nrows + row] *= x; + } + + fn add_element_mut(&mut self, row: usize, col: usize, x: f64) { + self.values[col*self.nrows + row] += x + } + + fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) { + self.values[col*self.nrows + row] -= x; + } + fn generate_positive_definite(nrows: usize, ncols: usize) -> Self { let m = DenseMatrix::rand(nrows, ncols); m.dot(&m.transpose()) @@ -815,6 +1587,22 @@ impl Matrix for DenseMatrix { } } + fn column_mean(&self) -> Vec { + let mut mean = vec![0f64; self.ncols]; + + for r in 0..self.nrows { + for c in 0..self.ncols { + mean[c] += self.get(r, c); + } + } + + for i in 0..mean.len() { + mean[i] /= self.nrows as f64; + } + + mean + } + fn add_scalar_mut(&mut self, scalar: f64) -> &Self { for i in 0..self.values.len() { self.values[i] += scalar; @@ -1138,6 +1926,26 @@ mod tests { assert!((prob.get(0, 0) - 0.09).abs() < 0.01); assert!((prob.get(0, 1) - 0.24).abs() < 0.01); assert!((prob.get(0, 2) - 0.66).abs() < 0.01); - } + } + + #[test] + fn col_mean(){ + let a = DenseMatrix::from_array(&[ + &[1., 2., 3.], + &[4., 5., 6.], + &[7., 8., 9.]]); + let res = a.column_mean(); + assert_eq!(res, vec![4., 5., 6.]); + } + + #[test] + fn eye(){ + let a = DenseMatrix::from_array(&[ + &[1., 0., 0.], + &[0., 1., 0.], + &[0., 0., 1.]]); + let res = DenseMatrix::eye(3); + assert_eq!(res, a); + } } diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index f0bb767..7b3aad2 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -1,6 +1,7 @@ use std::ops::Range; use crate::linalg::{Matrix}; use crate::linalg::svd::SVD; +use crate::linalg::evd::EVD; use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s}; impl Matrix for ArrayBase, Ix2> @@ -37,10 +38,18 @@ impl Matrix for ArrayBase, Ix2> panic!("svd method is not implemented for ndarray"); } + fn evd_mut(self, symmetric: bool) -> EVD{ + panic!("evd method is not implemented for ndarray"); + } + fn qr_solve_mut(&mut self, b: Self) -> Self { panic!("qr_solve_mut method is not implemented for ndarray"); } + fn eye(size: usize) -> Self { + Array::eye(size) + } + fn zeros(nrows: usize, ncols: usize) -> Self { Array::zeros((nrows, ncols)) } @@ -58,7 +67,7 @@ impl Matrix for ArrayBase, Ix2> } fn shape(&self) -> (usize, usize) { - (self.rows(), self.cols()) + (self.nrows(), self.ncols()) } fn v_stack(&self, other: &Self) -> Self { @@ -71,7 +80,7 @@ impl Matrix for ArrayBase, Ix2> fn dot(&self, other: &Self) -> Self { self.dot(other) - } + } fn vector_dot(&self, other: &Self) -> f64 { self.dot(&other.view().reversed_axes())[[0, 0]] @@ -172,6 +181,26 @@ impl Matrix for ArrayBase, Ix2> } } + fn column_mean(&self) -> Vec { + self.mean_axis(Axis(0)).unwrap().to_vec() + } + + fn div_element_mut(&mut self, row: usize, col: usize, x: f64){ + self[[row, col]] /= x; + } + + fn mul_element_mut(&mut self, row: usize, col: usize, x: f64){ + self[[row, col]] *= x; + } + + fn add_element_mut(&mut self, row: usize, col: usize, x: f64){ + self[[row, col]] += x; + } + + fn sub_element_mut(&mut self, row: usize, col: usize, x: f64){ + self[[row, col]] -= x; + } + fn negative_mut(&mut self){ *self *= -1.; } @@ -323,6 +352,50 @@ mod tests { } + #[test] + fn div_element_mut() { + + let mut a = arr2(&[[ 1., 2., 3.], + [4., 5., 6.]]); + a.div_element_mut(1, 1, 5.); + + assert_eq!(Matrix::get(&a, 1, 1), 1.); + + } + + #[test] + fn mul_element_mut() { + + let mut a = arr2(&[[ 1., 2., 3.], + [4., 5., 6.]]); + a.mul_element_mut(1, 1, 5.); + + assert_eq!(Matrix::get(&a, 1, 1), 25.); + + } + + #[test] + fn add_element_mut() { + + let mut a = arr2(&[[ 1., 2., 3.], + [4., 5., 6.]]); + a.add_element_mut(1, 1, 5.); + + assert_eq!(Matrix::get(&a, 1, 1), 10.); + + } + + #[test] + fn sub_element_mut() { + + let mut a = arr2(&[[ 1., 2., 3.], + [4., 5., 6.]]); + a.sub_element_mut(1, 1, 5.); + + assert_eq!(Matrix::get(&a, 1, 1), 0.); + + } + #[test] fn vstack_hstack() { @@ -376,7 +449,7 @@ mod tests { [49., 64.]]); let result = Matrix::dot(&a, &b); assert_eq!(result, expected); - } + } #[test] fn vector_dot() { @@ -511,4 +584,22 @@ mod tests { let res = a.get_col_as_vec(1); assert_eq!(res, vec![2., 5., 8.]); } + + #[test] + fn col_mean(){ + let a = arr2(&[[1., 2., 3.], + [4., 5., 6.], + [7., 8., 9.]]); + let res = a.column_mean(); + assert_eq!(res, vec![4., 5., 6.]); + } + + #[test] + fn eye(){ + let a = arr2(&[[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]]); + let res: Array2 = Matrix::eye(3); + assert_eq!(res, a); + } } \ No newline at end of file diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index ce4b79c..9dcea86 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -2,9 +2,9 @@ use crate::linalg::{Matrix}; #[derive(Debug, Clone)] pub struct SVD { - U: M, - V: M, - s: Vec, + pub U: M, + pub V: M, + pub s: Vec, full: bool, m: usize, n: usize,