add fit_intercept to LASSO (#344)

* add fit_intercept to LASSO
* lasso: intercept=None if fit_intercept is false
* update CHANGELOG.md to reflect lasso changes
* lasso: minor
This commit is contained in:
Georeth Chow
2025-11-29 10:46:14 +08:00
committed by GitHub
parent 2bf5f7a1a5
commit 18de2aa244
4 changed files with 125 additions and 54 deletions
+5
View File
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
## [0.4.0] - 2023-04-05
## Added
+2
View File
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
+112 -53
View File
@@ -53,6 +53,9 @@ pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -86,6 +89,12 @@ impl LassoParameters {
self.max_iter = max_iter;
self
}
/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
}
impl Default for LassoParameters {
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
normalize: true,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
}
}
}
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
coefficients: None,
intercept: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub fit_intercept: Vec<bool>,
}
/// Lasso grid search iterator
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
current_fit_intercept: usize,
}
impl IntoIterator for LassoSearchParameters {
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
current_fit_intercept: 0,
}
}
}
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{
return None;
}
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
};
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
self.current_fit_intercept += 1;
}
Some(next)
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
}
}
}
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w[j] /= *col_std_j;
}
let mut b = TX::zero();
let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w[i] * *col_mean_i;
}
b = TX::from_f64(y.mean_by()).unwrap() - b;
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
} else {
None
};
(X::from_column(&w), b)
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
(
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
};
Ok(Lasso {
intercept: Some(b),
intercept: b,
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
@@ -378,30 +416,28 @@ mod tests {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
let mut iter = parameters.clone().into_iter();
for current_fit_intercept in 0..parameters.fit_intercept.len() {
for current_max_iter in 0..parameters.max_iter.len() {
for current_alpha in 0..parameters.alpha.len() {
let next = iter.next().unwrap();
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(
next.fit_intercept,
parameters.fit_intercept[current_fit_intercept]
);
}
}
}
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -427,6 +463,17 @@ mod tests {
114.2, 115.7, 116.9,
];
(x, y)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();
let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
@@ -441,6 +488,7 @@ mod tests {
normalize: false,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
},
)
.and_then(|lr| lr.predict(&x))
@@ -479,35 +527,46 @@ mod tests {
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
let (x, y) = get_example_x_y();
let fit_result = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: false,
tol: 1e-8,
max_iter: 1000,
fit_intercept: false,
},
)
.unwrap();
let w = fit_result.coefficients().iterator(0).copied().collect();
// by sklearn LassoLars. coordinate descent doesn't converge well
let expected_w = vec![
0.18335684,
0.02106526,
0.00703214,
-1.35952542,
0.09295222,
0.,
];
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
assert_eq!(fit_result.intercept, None);
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
// let (x, y) = get_lasso_sample_x_y();
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
+6 -1
View File
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
lambda: T,
max_iter: usize,
tol: T,
fit_intercept: bool,
) -> Result<Vec<T>, Failed> {
let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap();
@@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let mu = T::two();
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};
let mut max_ls_iter = 100;
let mut pitr = 0;