make default params available to serde (#167)

* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
This commit is contained in:
Montana Low
2022-09-21 19:48:31 -07:00
committed by GitHub
parent 403d3f2348
commit 764309e313
22 changed files with 175 additions and 18 deletions
+10
View File
@@ -71,16 +71,21 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: T,
#[cfg_attr(feature = "serde", serde(default))]
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: T,
#[cfg_attr(feature = "serde", serde(default))]
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: T,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
}
@@ -139,16 +144,21 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
+8
View File
@@ -38,13 +38,17 @@ use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the strength of the penalty to the loss function.
pub alpha: T,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: T,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
}
@@ -116,13 +120,17 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoSearchParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the strength of the penalty to the loss function.
pub alpha: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
+9 -10
View File
@@ -71,19 +71,21 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
QR,
#[default]
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
}
/// Linear Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Default, Clone)]
pub struct LinearRegressionParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName,
}
@@ -105,18 +107,11 @@ impl LinearRegressionParameters {
}
}
impl Default for LinearRegressionParameters {
fn default() -> Self {
LinearRegressionParameters {
solver: LinearRegressionSolverName::SVD,
}
}
}
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
@@ -353,5 +348,9 @@ mod tests {
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
let default = LinearRegressionParameters::default();
let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
assert_eq!(parameters.solver, default.solver);
}
}
+11 -1
View File
@@ -75,12 +75,20 @@ pub enum LogisticRegressionSolverName {
LBFGS,
}
impl Default for LogisticRegressionSolverName {
fn default() -> Self {
LogisticRegressionSolverName::LBFGS
}
}
/// Logistic Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: LogisticRegressionSolverName,
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: T,
}
@@ -89,8 +97,10 @@ pub struct LogisticRegressionParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LogisticRegressionSearchParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LogisticRegressionSolverName>,
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<T>,
}
@@ -204,7 +214,7 @@ impl<T: RealNumber> LogisticRegressionParameters<T> {
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
fn default() -> Self {
LogisticRegressionParameters {
solver: LogisticRegressionSolverName::LBFGS,
solver: LogisticRegressionSolverName::default(),
alpha: T::zero(),
}
}
+10 -1
View File
@@ -77,6 +77,12 @@ pub enum RidgeRegressionSolverName {
SVD,
}
impl Default for RidgeRegressionSolverName {
fn default() -> Self {
RidgeRegressionSolverName::Cholesky
}
}
/// Ridge Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
@@ -94,10 +100,13 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<RidgeRegressionSolverName>,
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
@@ -204,7 +213,7 @@ impl<T: RealNumber> RidgeRegressionParameters<T> {
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
fn default() -> Self {
RidgeRegressionParameters {
solver: RidgeRegressionSolverName::Cholesky,
solver: RidgeRegressionSolverName::default(),
alpha: T::one(),
normalize: true,
}