feat: simplifies LR API

This commit is contained in:
Volodymyr Orlov
2019-12-23 11:18:22 -08:00
parent c1d7c038a6
commit a4ff1cbe5f
4 changed files with 66 additions and 23 deletions
+18 -20
View File
@@ -122,23 +122,24 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
impl<M: Matrix> LogisticRegression<M> {
pub fn fit(x: &M, y: &M) -> LogisticRegression<M>{
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<M>{
let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape();
let (_, y_nrows) = y.shape();
let (_, y_nrows) = y_m.shape();
if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y");
}
let classes = y.unique();
let classes = y_m.unique();
let k = classes.len();
let mut yi: Vec<usize> = vec![0; y_nrows];
for i in 0..y_nrows {
let yc = y.get(0, i);
let yc = y_m.get(0, i);
let j = classes.iter().position(|c| yc == *c).unwrap();
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
}
@@ -190,19 +191,19 @@ impl<M: Matrix> LogisticRegression<M> {
}
/// Predicts a class label for every row of `x`; the returned labels are taken from `self.classes`.
///
/// NOTE(review): this listing is a rendered diff. Where two near-identical lines follow each
/// other, the first appears to be the pre-commit version and the second its replacement
/// (the return type changes from `M` to `M::RowVector`) — confirm against the repository.
pub fn predict(&self, x: &M) -> M {
pub fn predict(&self, x: &M) -> M::RowVector {
if self.num_classes == 2 {
// Two-class case: one raw score per sample, thresholded after `sigmoid()`.
let (nrows, _) = x.shape();
// NOTE(review): presumably appends a column of ones so the last weight acts as the
// intercept — confirm the semantics of `v_stack`.
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
// A sigmoid value above 0.5 selects `self.classes[1]`; anything else selects `self.classes[0]`.
// The second line of the pair additionally converts the result with `to_row_vector()`.
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect())
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
} else {
// Any other number of classes: scores are kept as a matrix and reduced with `argmax()`.
let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat = x_and_bias.dot(&self.weights.transpose());
// NOTE(review): presumably one winning class index per sample — confirm `argmax` axis.
let class_idxs = y_hat.argmax();
// Each index is mapped back to its label through `self.classes`.
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect())
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
}
}
@@ -235,9 +236,8 @@ impl<M: Matrix> LogisticRegression<M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::ndarray_bindings;
use ndarray::{arr2, Array};
use crate::linalg::naive::dense_matrix::DenseMatrix;
use ndarray::{arr1, arr2, Array};
#[test]
fn multiclass_objective_f() {
@@ -339,7 +339,7 @@ mod tests {
&[10., -2.],
&[ 8., 2.],
&[ 9., 0.]]);
let y = DenseMatrix::vector_from_array(&[0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]);
let y = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
@@ -351,7 +351,7 @@ mod tests {
let y_hat = lr.predict(&x);
assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
@@ -380,13 +380,13 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4]]);
let y = DenseMatrix::vector_from_array(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
let y =vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
let lr = LogisticRegression::fit(&x, &y);
let y_hat = lr.predict(&x);
assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
@@ -414,15 +414,13 @@ mod tests {
[4.9, 2.4, 3.3, 1.0],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4]]);
let y = Array::from_shape_vec((1, 20), vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).unwrap();
let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y);
println!("{:?}", lr);
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x).to_raw_vector();
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
assert_eq!(y_hat, arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
}