make default params available to serde (#167)

* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
This commit is contained in:
Montana Low
2022-09-21 19:48:31 -07:00
committed by GitHub
parent 403d3f2348
commit 764309e313
22 changed files with 175 additions and 18 deletions
+9 -10
View File
@@ -71,19 +71,21 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq)]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
QR,
#[default]
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
}
/// Linear Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Default, Clone)]
pub struct LinearRegressionParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName,
}
@@ -105,18 +107,11 @@ impl LinearRegressionParameters {
}
}
impl Default for LinearRegressionParameters {
fn default() -> Self {
LinearRegressionParameters {
solver: LinearRegressionSolverName::SVD,
}
}
}
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
@@ -353,5 +348,9 @@ mod tests {
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
let default = LinearRegressionParameters::default();
let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
assert_eq!(parameters.solver, default.solver);
}
}