feat: consolidates API
This commit is contained in:
@@ -34,7 +34,7 @@
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = SVD::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
|
||||
//! let svd = SVD::fit(&iris, SVDParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
//! let iris_reduced = svd.transform(&iris).unwrap();
|
||||
//!
|
||||
@@ -47,6 +47,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -67,11 +68,34 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVD parameters
|
||||
pub struct SVDParameters {}
|
||||
pub struct SVDParameters {
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
}
|
||||
|
||||
impl Default for SVDParameters {
|
||||
fn default() -> Self {
|
||||
SVDParameters {}
|
||||
SVDParameters { n_components: 2 }
|
||||
}
|
||||
}
|
||||
|
||||
impl SVDParameters {
|
||||
/// Number of components to keep.
|
||||
pub fn with_n_components(mut self, n_components: usize) -> Self {
|
||||
self.n_components = n_components;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
SVD::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for SVD<T, M> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
self.transform(x)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,10 +104,10 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(x: &M, n_components: usize, _: SVDParameters) -> Result<SVD<T, M>, Failed> {
|
||||
pub fn fit(x: &M, parameters: SVDParameters) -> Result<SVD<T, M>, Failed> {
|
||||
let (_, p) = x.shape();
|
||||
|
||||
if n_components >= p {
|
||||
if parameters.n_components >= p {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of components, n_components should be < number of attributes ({})",
|
||||
p
|
||||
@@ -92,7 +116,7 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
|
||||
let svd = x.svd()?;
|
||||
|
||||
let components = svd.V.slice(0..p, 0..n_components);
|
||||
let components = svd.V.slice(0..p, 0..parameters.n_components);
|
||||
|
||||
Ok(SVD {
|
||||
components,
|
||||
@@ -189,7 +213,7 @@ mod tests {
|
||||
&[197.28420365, -11.66808306],
|
||||
&[293.43187394, 1.91163633],
|
||||
]);
|
||||
let svd = SVD::fit(&x, 2, Default::default()).unwrap();
|
||||
let svd = SVD::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let x_transformed = svd.transform(&x).unwrap();
|
||||
|
||||
@@ -225,7 +249,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let svd = SVD::fit(&iris, 2, Default::default()).unwrap();
|
||||
let svd = SVD::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
let deserialized_svd: SVD<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user