fix: ridge regression, code refactoring

This commit is contained in:
Volodymyr Orlov
2020-11-11 15:59:04 -08:00
parent 7a4fe114d8
commit c42fccdc22
3 changed files with 63 additions and 24 deletions
+2 -2
View File
@@ -154,8 +154,8 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
} }
/// Get estimates regression coefficients /// Get estimates regression coefficients
pub fn coefficients(&self) -> M { pub fn coefficients(&self) -> &M {
self.coefficients.clone() &self.coefficients
} }
/// Get estimate of intercept /// Get estimate of intercept
+55 -20
View File
@@ -68,7 +68,8 @@ use crate::optimization::FunctionOrder;
/// Logistic Regression /// Logistic Regression
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> { pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
weights: M, coefficients: M,
intercept: M,
classes: Vec<T>, classes: Vec<T>,
num_attributes: usize, num_attributes: usize,
num_classes: usize, num_classes: usize,
@@ -109,7 +110,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
} }
} }
return self.weights == other.weights; return self.coefficients == other.coefficients && self.intercept == other.intercept;
} }
} }
} }
@@ -246,9 +247,11 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
let weights = result.x;
Ok(LogisticRegression { Ok(LogisticRegression {
weights: result.x, coefficients: weights.slice(0..1, 0..num_attributes),
intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
@@ -268,7 +271,8 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let weights = result.x.reshape(k, num_attributes + 1); let weights = result.x.reshape(k, num_attributes + 1);
Ok(LogisticRegression { Ok(LogisticRegression {
weights: weights, coefficients: weights.slice(0..k, 0..num_attributes),
intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
@@ -283,21 +287,26 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
if self.num_classes == 2 { if self.num_classes == 2 {
let (nrows, _) = x.shape(); let (nrows, _) = x.shape();
let x_and_bias = x.h_stack(&M::ones(nrows, 1)); let y_hat: Vec<T> = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0);
let y_hat: Vec<T> = x_and_bias let intercept = self.intercept.get(0, 0);
.matmul(&self.weights.transpose())
.get_col_as_vec(0);
for i in 0..n { for i in 0..n {
result.set( result.set(
0, 0,
i, i,
self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }], self.classes[if (y_hat[i] + intercept).sigmoid() > T::half() {
1
} else {
0
}],
); );
} }
} else { } else {
let (nrows, _) = x.shape(); let mut y_hat = x.matmul(&self.coefficients.transpose());
let x_and_bias = x.h_stack(&M::ones(nrows, 1)); for r in 0..n {
let y_hat = x_and_bias.matmul(&self.weights.transpose()); for c in 0..self.num_classes {
y_hat.set(r, c, y_hat.get(r, c) + self.intercept.get(c, 0));
}
}
let class_idxs = y_hat.argmax(); let class_idxs = y_hat.argmax();
for i in 0..n { for i in 0..n {
result.set(0, i, self.classes[class_idxs[i]]); result.set(0, i, self.classes[class_idxs[i]]);
@@ -307,17 +316,13 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
} }
/// Get estimates regression coefficients /// Get estimates regression coefficients
pub fn coefficients(&self) -> M { pub fn coefficients(&self) -> &M {
self.weights &self.coefficients
.slice(0..self.num_classes, 0..self.num_attributes)
} }
/// Get estimate of intercept /// Get estimate of intercept
pub fn intercept(&self) -> M { pub fn intercept(&self) -> &M {
self.weights.slice( &self.intercept
0..self.num_classes,
self.num_attributes..self.num_attributes + 1,
)
} }
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> { fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
@@ -336,7 +341,9 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::dataset::generator::make_blobs;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::accuracy;
#[test] #[test]
fn multiclass_objective_f() { fn multiclass_objective_f() {
@@ -466,6 +473,34 @@ mod tests {
); );
} }
/// End-to-end check for the multiclass path: fit on a small generated
/// three-cluster dataset and require better than 90% accuracy when
/// predicting the same samples the model was trained on.
#[test]
fn lr_fit_predict_multiclass() {
    let n_samples = 15;
    let n_features = 4;

    // Three centers forces the multinomial (k > 2) branch of `fit`/`predict`.
    let dataset = make_blobs(n_samples, n_features, 3);
    let features = DenseMatrix::from_vec(n_samples, n_features, &dataset.data);
    let labels = dataset.target;

    let model = LogisticRegression::fit(&features, &labels).unwrap();
    let predicted = model.predict(&features).unwrap();

    assert!(accuracy(&predicted, &labels) > 0.9);
}
/// End-to-end check for the binary path: fit on a small generated
/// two-cluster dataset and require better than 90% accuracy when
/// predicting the same samples the model was trained on.
#[test]
fn lr_fit_predict_binary() {
    let n_samples = 20;
    let n_features = 4;

    // Two centers exercises the binary (sigmoid + intercept) branch.
    let dataset = make_blobs(n_samples, n_features, 2);
    let features = DenseMatrix::from_vec(n_samples, n_features, &dataset.data);
    let labels = dataset.target;

    let model = LogisticRegression::fit(&features, &labels).unwrap();
    let predicted = model.predict(&features).unwrap();

    assert!(accuracy(&predicted, &labels) > 0.9);
}
#[test] #[test]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
+6 -2
View File
@@ -134,6 +134,10 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
))); )));
} }
if y.len() != n {
return Err(Failed::fit(&format!("Number of rows in X should = len(y)")));
}
let y_column = M::from_row_vector(y.clone()).transpose(); let y_column = M::from_row_vector(y.clone()).transpose();
let (w, b) = if parameters.normalize { let (w, b) = if parameters.normalize {
@@ -216,8 +220,8 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
} }
/// Get estimates regression coefficients /// Get estimates regression coefficients
pub fn coefficients(&self) -> M { pub fn coefficients(&self) -> &M {
self.coefficients.clone() &self.coefficients
} }
/// Get estimate of intercept /// Get estimate of intercept