Merge pull request #71 from smartcorelib/log_regression_solvers
feat: adds a new parameter to the logistic regression: solver
This commit is contained in:
@@ -68,10 +68,21 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||
pub enum LogisticRegressionSolverName {
|
||||
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||
LBFGS,
|
||||
}
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters {}
|
||||
pub struct LogisticRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LogisticRegressionSolverName,
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -105,9 +116,19 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl LogisticRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LogisticRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionParameters {}
|
||||
LogisticRegressionParameters {
|
||||
solver: LogisticRegressionSolverName::LBFGS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user