Fix LASSO (first two of #342) (#343)

* Fix LASSO (#342)
* change loss function in doc to match code
* allow `n == p` case
* lasso add test_full_rank_x

---------

Co-authored-by: Zhou Xiaozhou <zxz@jiweifund.com>
This commit is contained in:
Georeth Chow
2025-11-28 11:15:43 +08:00
committed by GitHub
parent 0caa8306ff
commit 2bf5f7a1a5
+33 -2
View File
@@ -9,7 +9,7 @@
//! //!
//! Lasso coefficient estimates solve the problem: //! Lasso coefficient estimates solve the problem:
//! //!
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\] //! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//! //!
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy, //! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
//! but is able to solve them with high accuracy with relatively small additional computational cost. //! but is able to solve them with high accuracy with relatively small additional computational cost.
@@ -246,7 +246,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> { pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
let (n, p) = x.shape(); let (n, p) = x.shape();
if n <= p { if n < p {
return Err(Failed::fit( return Err(Failed::fit(
"Number of rows in X should be >= number of columns in X", "Number of rows in X should be >= number of columns in X",
)); ));
@@ -369,6 +369,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix; use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
@@ -448,6 +449,36 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 2.0); assert!(mean_absolute_error(&y_hat, &y) < 2.0);
} }
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_full_rank_x() {
// x: randn(3,3) * 10, demean, then round to 2 decimal points
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
let param = LassoParameters::default()
.with_normalize(false)
.with_alpha(200.0);
let x = DenseMatrix::from_2d_array(&[
&[-8.9, -2.24, 8.89],
&[-4.02, 8.89, 12.33],
&[12.92, -6.65, -21.22],
])
.unwrap();
let y = vec![-116.12, -75.41, 191.53];
let w = Lasso::fit(&x, &y, param)
.unwrap()
.coefficients()
.iterator(0)
.copied()
.collect();
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}
// TODO: serialization for the new DenseMatrix needs to be implemented // TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test] // #[test]