Adds SVD solver, code refactoring
This commit is contained in:
@@ -26,10 +26,14 @@ impl<M: Matrix> LinearRegression<M> {
|
||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||
}
|
||||
|
||||
let b = y.v_stack(&M::ones(1, 1));
|
||||
// let b = y.v_stack(&M::ones(1, 1));
|
||||
let b = y.clone();
|
||||
let mut a = x.h_stack(&M::ones(x_nrows, 1));
|
||||
|
||||
let w = a.qr_solve_mut(b);
|
||||
let w = match solver {
|
||||
LinearRegressionSolver::QR => a.qr_solve_mut(b),
|
||||
LinearRegressionSolver::SVD => a.svd_solve_mut(b)
|
||||
};
|
||||
|
||||
let wights = w.slice(0..num_attributes, 0..1);
|
||||
|
||||
@@ -45,7 +49,7 @@ impl<M: Matrix> LinearRegression<M> {
|
||||
impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
||||
|
||||
|
||||
fn predict(&self, x: M) -> M {
|
||||
fn predict(&self, x: &M) -> M {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.dot(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
@@ -81,10 +85,13 @@ mod tests {
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
||||
|
||||
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||
|
||||
assert!(y.approximate_eq(&y_hat_qr, 5.));
|
||||
assert!(y.approximate_eq(&y_hat_svd, 5.));
|
||||
|
||||
let y_hat = lr.predict(x);
|
||||
|
||||
assert!(y.approximate_eq(&y_hat, 5.));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user