add fit_intercept to LASSO (#344)

* add fit_intercept to LASSO
* lasso: intercept=None if fit_intercept is false
* update CHANGELOG.md to reflect lasso changes
* lasso: minor
This commit is contained in:
Georeth Chow
2025-11-29 10:46:14 +08:00
committed by GitHub
parent 2bf5f7a1a5
commit 18de2aa244
4 changed files with 125 additions and 54 deletions
+5
View File
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
## [0.4.0] - 2023-04-05 ## [0.4.0] - 2023-04-05
## Added ## Added
+2
View File
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma, l1_reg * gamma,
parameters.max_iter, parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(), TX::from_f64(parameters.tol).unwrap(),
true,
)?; )?;
for i in 0..p { for i in 0..p {
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma, l1_reg * gamma,
parameters.max_iter, parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(), TX::from_f64(parameters.tol).unwrap(),
true,
)?; )?;
for i in 0..p { for i in 0..p {
+112 -53
View File
@@ -53,6 +53,9 @@ pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))] #[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations /// The maximum number of iterations
pub max_iter: usize, pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
} }
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -86,6 +89,12 @@ impl LassoParameters {
self.max_iter = max_iter; self.max_iter = max_iter;
self self
} }
/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
} }
impl Default for LassoParameters { impl Default for LassoParameters {
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
normalize: true, normalize: true,
tol: 1e-4, tol: 1e-4,
max_iter: 1000, max_iter: 1000,
fit_intercept: true,
} }
} }
} }
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
coefficients: Option::None, coefficients: None,
intercept: Option::None, intercept: None,
_phantom_ty: PhantomData, _phantom_ty: PhantomData,
_phantom_y: PhantomData, _phantom_y: PhantomData,
} }
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))] #[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations /// The maximum number of iterations
pub max_iter: Vec<usize>, pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: Vec<bool>,
} }
/// Lasso grid search iterator /// Lasso grid search iterator
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
current_normalize: usize, current_normalize: usize,
current_tol: usize, current_tol: usize,
current_max_iter: usize, current_max_iter: usize,
current_fit_intercept: usize,
} }
impl IntoIterator for LassoSearchParameters { impl IntoIterator for LassoSearchParameters {
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
current_normalize: 0, current_normalize: 0,
current_tol: 0, current_tol: 0,
current_max_iter: 0, current_max_iter: 0,
current_fit_intercept: 0,
} }
} }
} }
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
&& self.current_normalize == self.lasso_search_parameters.normalize.len() && self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len() && self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len() && self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{ {
return None; return None;
} }
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
normalize: self.lasso_search_parameters.normalize[self.current_normalize], normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol], tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter], max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
}; };
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() { if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
self.current_normalize = 0; self.current_normalize = 0;
self.current_tol = 0; self.current_tol = 0;
self.current_max_iter += 1; self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else { } else {
self.current_alpha += 1; self.current_alpha += 1;
self.current_normalize += 1; self.current_normalize += 1;
self.current_tol += 1; self.current_tol += 1;
self.current_max_iter += 1; self.current_max_iter += 1;
self.current_fit_intercept += 1;
} }
Some(next) Some(next)
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
normalize: vec![default_params.normalize], normalize: vec![default_params.normalize],
tol: vec![default_params.tol], tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter], max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
} }
} }
} }
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg, l1_reg,
parameters.max_iter, parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(), TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?; )?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) { for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w[j] /= *col_std_j; w[j] /= *col_std_j;
} }
let mut b = TX::zero(); let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) { Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
b += w[i] * *col_mean_i; } else {
} None
};
b = TX::from_f64(y.mean_by()).unwrap() - b;
(X::from_column(&w), b) (X::from_column(&w), b)
} else { } else {
let mut optimizer = InteriorPointOptimizer::new(x, p); let mut optimizer = InteriorPointOptimizer::new(x, p);
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg, l1_reg,
parameters.max_iter, parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(), TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?; )?;
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap()) (
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
}; };
Ok(Lasso { Ok(Lasso {
intercept: Some(b), intercept: b,
coefficients: Some(w), coefficients: Some(w),
_phantom_ty: PhantomData, _phantom_ty: PhantomData,
_phantom_y: PhantomData, _phantom_y: PhantomData,
@@ -378,30 +416,28 @@ mod tests {
let parameters = LassoSearchParameters { let parameters = LassoSearchParameters {
alpha: vec![0., 1.], alpha: vec![0., 1.],
max_iter: vec![10, 100], max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default() ..Default::default()
}; };
let mut iter = parameters.into_iter();
let next = iter.next().unwrap(); let mut iter = parameters.clone().into_iter();
assert_eq!(next.alpha, 0.); for current_fit_intercept in 0..parameters.fit_intercept.len() {
assert_eq!(next.max_iter, 10); for current_max_iter in 0..parameters.max_iter.len() {
let next = iter.next().unwrap(); for current_alpha in 0..parameters.alpha.len() {
assert_eq!(next.alpha, 1.); let next = iter.next().unwrap();
assert_eq!(next.max_iter, 10); assert_eq!(next.alpha, parameters.alpha[current_alpha]);
let next = iter.next().unwrap(); assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(next.alpha, 0.); assert_eq!(
assert_eq!(next.max_iter, 100); next.fit_intercept,
let next = iter.next().unwrap(); parameters.fit_intercept[current_fit_intercept]
assert_eq!(next.alpha, 1.); );
assert_eq!(next.max_iter, 100); }
}
}
assert!(iter.next().is_none()); assert!(iter.next().is_none());
} }
#[cfg_attr( fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -427,6 +463,17 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
(x, y)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();
let y_hat = Lasso::fit(&x, &y, Default::default()) let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
.unwrap(); .unwrap();
@@ -441,6 +488,7 @@ mod tests {
normalize: false, normalize: false,
tol: 1e-4, tol: 1e-4,
max_iter: 1000, max_iter: 1000,
fit_intercept: true,
}, },
) )
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
@@ -479,35 +527,46 @@ mod tests {
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4 assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
} }
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
let (x, y) = get_example_x_y();
let fit_result = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: false,
tol: 1e-8,
max_iter: 1000,
fit_intercept: false,
},
)
.unwrap();
let w = fit_result.coefficients().iterator(0).copied().collect();
// Expected coefficients computed with sklearn's LassoLars; coordinate descent does not converge well on this problem.
let expected_w = vec![
0.18335684,
0.02106526,
0.00703214,
-1.35952542,
0.09295222,
0.,
];
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
assert_eq!(fit_result.intercept, None);
}
// TODO: serialization for the new DenseMatrix needs to be implemented // TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test] // #[test]
// #[cfg(feature = "serde")] // #[cfg(feature = "serde")]
// fn serde() { // fn serde() {
// let x = DenseMatrix::from_2d_array(&[ // let (x, y) = get_example_x_y();
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap(); // let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> = // let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
+6 -1
View File
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
lambda: T, lambda: T,
max_iter: usize, max_iter: usize,
tol: T, tol: T,
fit_intercept: bool,
) -> Result<Vec<T>, Failed> { ) -> Result<Vec<T>, Failed> {
let (n, p) = x.shape(); let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap(); let p_f64 = T::from_usize(p).unwrap();
@@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let mu = T::two(); let mu = T::two();
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose(); // let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap()); let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};
let mut max_ls_iter = 100; let mut max_ls_iter = 100;
let mut pitr = 0; let mut pitr = 0;