add fit_intercept to LASSO (#344)
* add fit_intercept to LASSO * lasso: intercept=None if fit_intercept is false * update CHANGELOG.md to reflect lasso changes * lasso: minor
This commit is contained in:
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
- WARNING: Breaking changes!
|
||||||
|
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and the `intercept` field in `Lasso` will be set to `None`.
|
||||||
|
|
||||||
|
|
||||||
## [0.4.0] - 2023-04-05
|
## [0.4.0] - 2023-04-05
|
||||||
|
|
||||||
## Added
|
## Added
|
||||||
|
|||||||
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
|
|||||||
+109
-50
@@ -53,6 +53,9 @@ pub struct LassoParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fit_intercept: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -86,6 +89,12 @@ impl LassoParameters {
|
|||||||
self.max_iter = max_iter;
|
self.max_iter = max_iter;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
|
||||||
|
self.fit_intercept = fit_intercept;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LassoParameters {
|
impl Default for LassoParameters {
|
||||||
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
|
|||||||
normalize: true,
|
normalize: true,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
coefficients: Option::None,
|
coefficients: None,
|
||||||
intercept: Option::None,
|
intercept: None,
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
}
|
}
|
||||||
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: Vec<usize>,
|
pub max_iter: Vec<usize>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fit_intercept: Vec<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lasso grid search iterator
|
/// Lasso grid search iterator
|
||||||
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
|
|||||||
current_normalize: usize,
|
current_normalize: usize,
|
||||||
current_tol: usize,
|
current_tol: usize,
|
||||||
current_max_iter: usize,
|
current_max_iter: usize,
|
||||||
|
current_fit_intercept: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for LassoSearchParameters {
|
impl IntoIterator for LassoSearchParameters {
|
||||||
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
|
|||||||
current_normalize: 0,
|
current_normalize: 0,
|
||||||
current_tol: 0,
|
current_tol: 0,
|
||||||
current_max_iter: 0,
|
current_max_iter: 0,
|
||||||
|
current_fit_intercept: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||||
|
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||||
|
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||||
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
self.current_normalize = 0;
|
self.current_normalize = 0;
|
||||||
self.current_tol = 0;
|
self.current_tol = 0;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
|
||||||
|
{
|
||||||
|
self.current_alpha = 0;
|
||||||
|
self.current_normalize = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_max_iter = 0;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_alpha += 1;
|
self.current_alpha += 1;
|
||||||
self.current_normalize += 1;
|
self.current_normalize += 1;
|
||||||
self.current_tol += 1;
|
self.current_tol += 1;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
|
|||||||
normalize: vec![default_params.normalize],
|
normalize: vec![default_params.normalize],
|
||||||
tol: vec![default_params.tol],
|
tol: vec![default_params.tol],
|
||||||
max_iter: vec![default_params.max_iter],
|
max_iter: vec![default_params.max_iter],
|
||||||
|
fit_intercept: vec![default_params.fit_intercept],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||||
w[j] /= *col_std_j;
|
w[j] /= *col_std_j;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut b = TX::zero();
|
let b = if parameters.fit_intercept {
|
||||||
|
let mut xw_mean = TX::zero();
|
||||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||||
b += w[i] * *col_mean_i;
|
xw_mean += w[i] * *col_mean_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
b = TX::from_f64(y.mean_by()).unwrap() - b;
|
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
(X::from_column(&w), b)
|
(X::from_column(&w), b)
|
||||||
} else {
|
} else {
|
||||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||||
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
|
(
|
||||||
|
X::from_column(&w),
|
||||||
|
if parameters.fit_intercept {
|
||||||
|
Some(TX::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Lasso {
|
Ok(Lasso {
|
||||||
intercept: Some(b),
|
intercept: b,
|
||||||
coefficients: Some(w),
|
coefficients: Some(w),
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
@@ -378,30 +416,28 @@ mod tests {
|
|||||||
let parameters = LassoSearchParameters {
|
let parameters = LassoSearchParameters {
|
||||||
alpha: vec![0., 1.],
|
alpha: vec![0., 1.],
|
||||||
max_iter: vec![10, 100],
|
max_iter: vec![10, 100],
|
||||||
|
fit_intercept: vec![false, true],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut iter = parameters.into_iter();
|
|
||||||
|
let mut iter = parameters.clone().into_iter();
|
||||||
|
for current_fit_intercept in 0..parameters.fit_intercept.len() {
|
||||||
|
for current_max_iter in 0..parameters.max_iter.len() {
|
||||||
|
for current_alpha in 0..parameters.alpha.len() {
|
||||||
let next = iter.next().unwrap();
|
let next = iter.next().unwrap();
|
||||||
assert_eq!(next.alpha, 0.);
|
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
|
||||||
assert_eq!(next.max_iter, 10);
|
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
|
||||||
let next = iter.next().unwrap();
|
assert_eq!(
|
||||||
assert_eq!(next.alpha, 1.);
|
next.fit_intercept,
|
||||||
assert_eq!(next.max_iter, 10);
|
parameters.fit_intercept[current_fit_intercept]
|
||||||
let next = iter.next().unwrap();
|
);
|
||||||
assert_eq!(next.alpha, 0.);
|
}
|
||||||
assert_eq!(next.max_iter, 100);
|
}
|
||||||
let next = iter.next().unwrap();
|
}
|
||||||
assert_eq!(next.alpha, 1.);
|
|
||||||
assert_eq!(next.max_iter, 100);
|
|
||||||
assert!(iter.next().is_none());
|
assert!(iter.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
|
||||||
)]
|
|
||||||
#[test]
|
|
||||||
fn lasso_fit_predict() {
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -427,6 +463,17 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
(x, y)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn lasso_fit_predict() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
|
||||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -441,6 +488,7 @@ mod tests {
|
|||||||
normalize: false,
|
normalize: false,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
@@ -479,35 +527,46 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn test_fit_intercept() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
let fit_result = Lasso::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LassoParameters {
|
||||||
|
alpha: 0.1,
|
||||||
|
normalize: false,
|
||||||
|
tol: 1e-8,
|
||||||
|
max_iter: 1000,
|
||||||
|
fit_intercept: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let w = fit_result.coefficients().iterator(0).copied().collect();
|
||||||
|
// Expected coefficients come from sklearn's LassoLars; coordinate descent doesn't converge well.
|
||||||
|
let expected_w = vec![
|
||||||
|
0.18335684,
|
||||||
|
0.02106526,
|
||||||
|
0.00703214,
|
||||||
|
-1.35952542,
|
||||||
|
0.09295222,
|
||||||
|
0.,
|
||||||
|
];
|
||||||
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
|
||||||
|
assert_eq!(fit_result.intercept, None);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
// #[test]
|
// #[test]
|
||||||
// #[cfg(feature = "serde")]
|
// #[cfg(feature = "serde")]
|
||||||
// fn serde() {
|
// fn serde() {
|
||||||
// let x = DenseMatrix::from_2d_array(&[
|
// let (x, y) = get_example_x_y();
|
||||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
|
||||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
|
||||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
|
||||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
|
||||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
|
||||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
|
||||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
|
||||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
|
||||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
|
||||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
|
||||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
|
||||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
|
||||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
|
||||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
|
||||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
|
||||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
|
||||||
// ]);
|
|
||||||
|
|
||||||
// let y = vec![
|
|
||||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
|
||||||
// 114.2, 115.7, 116.9,
|
|
||||||
// ];
|
|
||||||
|
|
||||||
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
lambda: T,
|
lambda: T,
|
||||||
max_iter: usize,
|
max_iter: usize,
|
||||||
tol: T,
|
tol: T,
|
||||||
|
fit_intercept: bool,
|
||||||
) -> Result<Vec<T>, Failed> {
|
) -> Result<Vec<T>, Failed> {
|
||||||
let (n, p) = x.shape();
|
let (n, p) = x.shape();
|
||||||
let p_f64 = T::from_usize(p).unwrap();
|
let p_f64 = T::from_usize(p).unwrap();
|
||||||
@@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
let mu = T::two();
|
let mu = T::two();
|
||||||
|
|
||||||
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
||||||
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
|
let y = if fit_intercept {
|
||||||
|
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
y.to_owned()
|
||||||
|
};
|
||||||
|
|
||||||
let mut max_ls_iter = 100;
|
let mut max_ls_iter = 100;
|
||||||
let mut pitr = 0;
|
let mut pitr = 0;
|
||||||
|
|||||||
Reference in New Issue
Block a user