* Fix LASSO (#342) * change loss function in doc to match code * allow `n == p` case * lasso add test_full_rank_x --------- Co-authored-by: Zhou Xiaozhou <zxz@jiweifund.com>
This commit is contained in:
+33
-2
@@ -9,7 +9,7 @@
|
||||
//!
|
||||
//! Lasso coefficient estimates solve the problem:
|
||||
//!
|
||||
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||
//!
|
||||
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
||||
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
||||
@@ -246,7 +246,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if n <= p {
|
||||
if n < p {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows in X should be >= number of columns in X",
|
||||
));
|
||||
@@ -369,6 +369,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
@@ -448,6 +449,36 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_full_rank_x() {
|
||||
// x: randn(3,3) * 10, demean, then round to 2 decimal points
|
||||
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
|
||||
let param = LassoParameters::default()
|
||||
.with_normalize(false)
|
||||
.with_alpha(200.0);
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[-8.9, -2.24, 8.89],
|
||||
&[-4.02, 8.89, 12.33],
|
||||
&[12.92, -6.65, -21.22],
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y = vec![-116.12, -75.41, 191.53];
|
||||
let w = Lasso::fit(&x, &y, param)
|
||||
.unwrap()
|
||||
.coefficients()
|
||||
.iterator(0)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
|
||||
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
||||
}
|
||||
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
|
||||
Reference in New Issue
Block a user