diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index cbdef77..a23c15a 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -68,10 +68,21 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +/// Solver options for Logistic regression. Right now only LBFGS solver is supported. +pub enum LogisticRegressionSolverName { + /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html) + LBFGS, +} + /// Logistic Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct LogisticRegressionParameters {} +pub struct LogisticRegressionParameters { + /// Solver to use for estimation of regression coefficients. + pub solver: LogisticRegressionSolverName, +} /// Logistic Regression #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -105,9 +116,19 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix> { phantom: PhantomData<&'a T>, } +impl LogisticRegressionParameters { + /// Solver to use for estimation of regression coefficients. + pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self { + self.solver = solver; + self + } +} + impl Default for LogisticRegressionParameters { fn default() -> Self { - LogisticRegressionParameters {} + LogisticRegressionParameters { + solver: LogisticRegressionSolverName::LBFGS, + } } }