feat: refactoring, adds Result to most public API
This commit is contained in:
+24
-18
@@ -37,9 +37,9 @@
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let pca = PCA::new(&iris, 2, Default::default()); // Reduce number of features to 2
|
||||
//! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
//! let iris_reduced = pca.transform(&iris);
|
||||
//! let iris_reduced = pca.transform(&iris).unwrap();
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
@@ -49,6 +49,7 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -100,7 +101,11 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn new(data: &M, n_components: usize, parameters: PCAParameters) -> PCA<T, M> {
|
||||
pub fn fit(
|
||||
data: &M,
|
||||
n_components: usize,
|
||||
parameters: PCAParameters,
|
||||
) -> Result<PCA<T, M>, Failed> {
|
||||
let (m, n) = data.shape();
|
||||
|
||||
let mu = data.column_mean();
|
||||
@@ -117,7 +122,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
let mut eigenvectors;
|
||||
|
||||
if m > n && !parameters.use_correlation_matrix {
|
||||
let svd = x.svd();
|
||||
let svd = x.svd()?;
|
||||
eigenvalues = svd.s;
|
||||
for i in 0..eigenvalues.len() {
|
||||
eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
|
||||
@@ -155,7 +160,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
let evd = cov.evd(true);
|
||||
let evd = cov.evd(true)?;
|
||||
|
||||
eigenvalues = evd.d;
|
||||
|
||||
@@ -167,7 +172,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let evd = cov.evd(true);
|
||||
let evd = cov.evd(true)?;
|
||||
|
||||
eigenvalues = evd.d;
|
||||
|
||||
@@ -189,26 +194,26 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
PCA {
|
||||
Ok(PCA {
|
||||
eigenvectors: eigenvectors,
|
||||
eigenvalues: eigenvalues,
|
||||
projection: projection.transpose(),
|
||||
mu: mu,
|
||||
pmu: pmu,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Run dimensionality reduction for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn transform(&self, x: &M) -> M {
|
||||
pub fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
let (nrows, ncols) = x.shape();
|
||||
let (_, n_components) = self.projection.shape();
|
||||
if ncols != self.mu.len() {
|
||||
panic!(
|
||||
return Err(Failed::transform(&format!(
|
||||
"Invalid input vector size: {}, expected: {}",
|
||||
ncols,
|
||||
self.mu.len()
|
||||
);
|
||||
)));
|
||||
}
|
||||
|
||||
let mut x_transformed = x.matmul(&self.projection);
|
||||
@@ -217,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
x_transformed.sub_element_mut(r, c, self.pmu[c]);
|
||||
}
|
||||
}
|
||||
x_transformed
|
||||
Ok(x_transformed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,7 +377,7 @@ mod tests {
|
||||
302.04806302399646,
|
||||
];
|
||||
|
||||
let pca = PCA::new(&us_arrests, 4, Default::default());
|
||||
let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap();
|
||||
|
||||
assert!(pca
|
||||
.eigenvectors
|
||||
@@ -383,7 +388,7 @@ mod tests {
|
||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
}
|
||||
|
||||
let us_arrests_t = pca.transform(&us_arrests);
|
||||
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||
|
||||
assert!(us_arrests_t
|
||||
.abs()
|
||||
@@ -481,13 +486,14 @@ mod tests {
|
||||
0.1734300877298357,
|
||||
];
|
||||
|
||||
let pca = PCA::new(
|
||||
let pca = PCA::fit(
|
||||
&us_arrests,
|
||||
4,
|
||||
PCAParameters {
|
||||
use_correlation_matrix: true,
|
||||
},
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(pca
|
||||
.eigenvectors
|
||||
@@ -498,7 +504,7 @@ mod tests {
|
||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
}
|
||||
|
||||
let us_arrests_t = pca.transform(&us_arrests);
|
||||
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||
|
||||
assert!(us_arrests_t
|
||||
.abs()
|
||||
@@ -530,7 +536,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let pca = PCA::new(&iris, 4, Default::default());
|
||||
let pca = PCA::fit(&iris, 4, Default::default()).unwrap();
|
||||
|
||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user