Adds SVD solver, code refactoring

This commit is contained in:
Volodymyr Orlov
2019-10-16 08:28:36 -07:00
parent 50744208a9
commit f4aec2b35e
9 changed files with 422 additions and 26 deletions
+13 -6
View File
@@ -26,10 +26,14 @@ impl<M: Matrix> LinearRegression<M> {
panic!("Number of rows of X doesn't match number of rows of Y");
}
let b = y.v_stack(&M::ones(1, 1));
// let b = y.v_stack(&M::ones(1, 1));
let b = y.clone();
let mut a = x.h_stack(&M::ones(x_nrows, 1));
let w = a.qr_solve_mut(b);
let w = match solver {
LinearRegressionSolver::QR => a.qr_solve_mut(b),
LinearRegressionSolver::SVD => a.svd_solve_mut(b)
};
let wights = w.slice(0..num_attributes, 0..1);
@@ -45,7 +49,7 @@ impl<M: Matrix> LinearRegression<M> {
impl<M: Matrix> Regression<M> for LinearRegression<M> {
fn predict(&self, x: M) -> M {
fn predict(&self, x: &M) -> M {
let (nrows, _) = x.shape();
let mut y_hat = x.dot(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
@@ -81,10 +85,13 @@ mod tests {
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
assert!(y.approximate_eq(&y_hat_qr, 5.));
assert!(y.approximate_eq(&y_hat_svd, 5.));
let y_hat = lr.predict(x);
assert!(y.approximate_eq(&y_hat, 5.));
}
}