feat: lasso documentation
This commit is contained in:
+8
-21
@@ -105,18 +105,15 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
|
||||
|
||||
let mut w = optimizer.optimize(
|
||||
&scaled_x,
|
||||
y,
|
||||
parameters.alpha,
|
||||
parameters.max_iter,
|
||||
parameters.tol,
|
||||
)?;
|
||||
let mut w =
|
||||
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||
w.set(j, 0, w.get(j, 0) / *col_std_j);
|
||||
@@ -133,8 +130,7 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
} else {
|
||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||
|
||||
let w =
|
||||
optimizer.optimize(x, y, parameters.alpha, parameters.max_iter, parameters.tol)?;
|
||||
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
(w, y.mean())
|
||||
};
|
||||
@@ -215,18 +211,9 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat = Lasso::fit(
|
||||
&x,
|
||||
&y,
|
||||
LassoParameters {
|
||||
alpha: 0.1,
|
||||
normalize: true,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user