feat: lasso documentation

This commit is contained in:
Volodymyr Orlov
2020-12-13 13:35:14 -08:00
parent a27c29b736
commit cceb2f046d
4 changed files with 86 additions and 34 deletions
+8 -21
View File
@@ -105,18 +105,15 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
let mut w = optimizer.optimize(
&scaled_x,
y,
parameters.alpha,
parameters.max_iter,
parameters.tol,
)?;
let mut w =
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w.set(j, 0, w.get(j, 0) / *col_std_j);
@@ -133,8 +130,7 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
let w =
optimizer.optimize(x, y, parameters.alpha, parameters.max_iter, parameters.tol)?;
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
(w, y.mean())
};
@@ -215,18 +211,9 @@ mod tests {
114.2, 115.7, 116.9,
];
let y_hat = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: true,
tol: 1e-4,
max_iter: 1000,
},
)
.and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(mean_absolute_error(&y_hat, &y) < 2.0);