Merge pull request #72 from smartcorelib/lr_reg
feat: adds l2 regularization penalty to the Logistic Regression
This commit is contained in:
@@ -54,7 +54,6 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -79,9 +78,11 @@ pub enum LogisticRegressionSolverName {
|
|||||||
/// Logistic Regression parameters
|
/// Logistic Regression parameters
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct LogisticRegressionParameters {
|
pub struct LogisticRegressionParameters<T: RealNumber> {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub solver: LogisticRegressionSolverName,
|
pub solver: LogisticRegressionSolverName,
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Logistic Regression
|
/// Logistic Regression
|
||||||
@@ -113,21 +114,27 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
|
|||||||
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
phantom: PhantomData<&'a T>,
|
alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LogisticRegressionParameters {
|
impl<T: RealNumber> LogisticRegressionParameters<T> {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
|
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
|
||||||
self.solver = solver;
|
self.solver = solver;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||||
|
self.alpha = alpha;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LogisticRegressionParameters {
|
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
LogisticRegressionParameters {
|
LogisticRegressionParameters {
|
||||||
solver: LogisticRegressionSolverName::LBFGS,
|
solver: LogisticRegressionSolverName::LBFGS,
|
||||||
|
alpha: T::zero(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -156,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
{
|
{
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
let (n, _) = self.x.shape();
|
let (n, p) = self.x.shape();
|
||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||||
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
|
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
let mut w_squared = T::zero();
|
||||||
|
for i in 0..p {
|
||||||
|
let w = w_bias.get(0, i);
|
||||||
|
w_squared += w * w;
|
||||||
|
}
|
||||||
|
f += T::half() * self.alpha * w_squared;
|
||||||
|
}
|
||||||
|
|
||||||
f
|
f
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +196,13 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
}
|
}
|
||||||
g.set(0, p, g.get(0, p) - dyi);
|
g.set(0, p, g.get(0, p) - dyi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
for i in 0..p {
|
||||||
|
let w = w_bias.get(0, i);
|
||||||
|
g.set(0, i, g.get(0, i) + self.alpha * w);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,7 +210,7 @@ struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
|||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
k: usize,
|
k: usize,
|
||||||
phantom: PhantomData<&'a T>,
|
alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||||
@@ -209,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
f -= prob.get(0, self.y[i]).ln();
|
f -= prob.get(0, self.y[i]).ln();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
let mut w_squared = T::zero();
|
||||||
|
for i in 0..self.k {
|
||||||
|
for j in 0..p {
|
||||||
|
let wi = w_bias.get(0, i * (p + 1) + j);
|
||||||
|
w_squared += wi * wi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f += T::half() * self.alpha * w_squared;
|
||||||
|
}
|
||||||
|
|
||||||
f
|
f
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
for i in 0..self.k {
|
||||||
|
for j in 0..p {
|
||||||
|
let pos = i * (p + 1);
|
||||||
|
let wi = w.get(0, pos + j);
|
||||||
|
g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters>
|
impl<T: RealNumber, M: Matrix<T>>
|
||||||
|
SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters<T>>
|
||||||
for LogisticRegression<T, M>
|
for LogisticRegression<T, M>
|
||||||
{
|
{
|
||||||
fn fit(
|
fn fit(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: LogisticRegressionParameters,
|
parameters: LogisticRegressionParameters<T>,
|
||||||
) -> Result<Self, Failed> {
|
) -> Result<Self, Failed> {
|
||||||
LogisticRegression::fit(x, y, parameters)
|
LogisticRegression::fit(x, y, parameters)
|
||||||
}
|
}
|
||||||
@@ -268,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
pub fn fit(
|
pub fn fit(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
_parameters: LogisticRegressionParameters,
|
parameters: LogisticRegressionParameters<T>,
|
||||||
) -> Result<LogisticRegression<T, M>, Failed> {
|
) -> Result<LogisticRegression<T, M>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
@@ -302,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x,
|
x,
|
||||||
y: yi,
|
y: yi,
|
||||||
phantom: PhantomData,
|
alpha: parameters.alpha,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -324,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
x,
|
x,
|
||||||
y: yi,
|
y: yi,
|
||||||
k,
|
k,
|
||||||
phantom: PhantomData,
|
alpha: parameters.alpha,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -431,9 +476,9 @@ mod tests {
|
|||||||
|
|
||||||
let objective = MultiClassObjectiveFunction {
|
let objective = MultiClassObjectiveFunction {
|
||||||
x: &x,
|
x: &x,
|
||||||
y,
|
y: y.clone(),
|
||||||
k: 3,
|
k: 3,
|
||||||
phantom: PhantomData,
|
alpha: 0.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
||||||
@@ -454,6 +499,24 @@ mod tests {
|
|||||||
]));
|
]));
|
||||||
|
|
||||||
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
||||||
|
|
||||||
|
let objective_reg = MultiClassObjectiveFunction {
|
||||||
|
x: &x,
|
||||||
|
y: y.clone(),
|
||||||
|
k: 3,
|
||||||
|
alpha: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[
|
||||||
|
1., 2., 3., 4., 5., 6., 7., 8., 9.,
|
||||||
|
]));
|
||||||
|
assert!((f - 487.5052).abs() < 1e-4);
|
||||||
|
|
||||||
|
objective_reg.df(
|
||||||
|
&mut g,
|
||||||
|
&DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
|
||||||
|
);
|
||||||
|
assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -480,8 +543,8 @@ mod tests {
|
|||||||
|
|
||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x: &x,
|
x: &x,
|
||||||
y,
|
y: y.clone(),
|
||||||
phantom: PhantomData,
|
alpha: 0.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
||||||
@@ -496,6 +559,20 @@ mod tests {
|
|||||||
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
|
||||||
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
||||||
|
|
||||||
|
let objective_reg = BinaryObjectiveFunction {
|
||||||
|
x: &x,
|
||||||
|
y: y.clone(),
|
||||||
|
alpha: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
assert!((f - 62.2699).abs() < 1e-4);
|
||||||
|
|
||||||
|
objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
assert!((g.get(0, 0) - 27.0511).abs() < 1e-4);
|
||||||
|
assert!((g.get(0, 1) - 12.239).abs() < 1e-4);
|
||||||
|
assert!((g.get(0, 2) - 3.8693).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -547,6 +624,15 @@ mod tests {
|
|||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||||
|
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -561,6 +647,15 @@ mod tests {
|
|||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||||
|
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -622,6 +717,12 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(1.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
@@ -632,5 +733,6 @@ mod tests {
|
|||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
assert!(error <= 1.0);
|
assert!(error <= 1.0);
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user