fix: ridge regression, code refactoring
This commit is contained in:
@@ -134,6 +134,10 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
)));
|
||||
}
|
||||
|
||||
if y.len() != n {
|
||||
return Err(Failed::fit(&format!("Number of rows in X should = len(y)")));
|
||||
}
|
||||
|
||||
let y_column = M::from_row_vector(y.clone()).transpose();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
@@ -216,8 +220,8 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> M {
|
||||
self.coefficients.clone()
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
|
||||
Reference in New Issue
Block a user