feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+19 -13
View File
@@ -47,9 +47,9 @@
//!
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
//! solver: LinearRegressionSolverName::QR, // or SVD
//! });
//! }).unwrap();
//!
//! let y_hat = lr.predict(&x);
//! let y_hat = lr.predict(&x).unwrap();
//! ```
//!
//! ## References:
@@ -64,6 +64,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
@@ -115,39 +116,41 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
x: &M,
y: &M::RowVector,
parameters: LinearRegressionParameters,
) -> LinearRegression<T, M> {
) -> Result<LinearRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone());
let b = y_m.transpose();
let (x_nrows, num_attributes) = x.shape();
let (y_nrows, _) = b.shape();
if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y");
return Err(Failed::fit(&format!(
"Number of rows of X doesn't match number of rows of Y"
)));
}
let a = x.h_stack(&M::ones(x_nrows, 1));
let w = match parameters.solver {
LinearRegressionSolverName::QR => a.qr_solve_mut(b),
LinearRegressionSolverName::SVD => a.svd_solve_mut(b),
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
};
let wights = w.slice(0..num_attributes, 0..1);
LinearRegression {
Ok(LinearRegression {
intercept: w.get(num_attributes, 0),
coefficients: wights,
solver: parameters.solver,
}
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> M::RowVector {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
y_hat.transpose().to_row_vector()
Ok(y_hat.transpose().to_row_vector())
}
/// Get estimates regression coefficients
@@ -199,9 +202,12 @@ mod tests {
solver: LinearRegressionSolverName::QR,
},
)
.predict(&x);
.and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(y
.iter()
@@ -239,7 +245,7 @@ mod tests {
114.2, 115.7, 116.9,
];
let lr = LinearRegression::fit(&x, &y, Default::default());
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();