fix: ridge regression, code refactoring

This commit is contained in:
Volodymyr Orlov
2020-11-11 15:59:04 -08:00
parent 7a4fe114d8
commit c42fccdc22
3 changed files with 63 additions and 24 deletions
+6 -2
View File
@@ -134,6 +134,10 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
)));
}
if y.len() != n {
return Err(Failed::fit(&format!("Number of rows in X should = len(y)")));
}
let y_column = M::from_row_vector(y.clone()).transpose();
let (w, b) = if parameters.normalize {
@@ -216,8 +220,8 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> M {
self.coefficients.clone()
pub fn coefficients(&self) -> &M {
&self.coefficients
}
/// Get estimate of intercept