feat: simplifies LR API
This commit is contained in:
@@ -122,23 +122,24 @@ impl<'a, M: Matrix> ObjectiveFunction<M> for MultiClassObjectiveFunction<'a, M>
|
|||||||
|
|
||||||
impl<M: Matrix> LogisticRegression<M> {
|
impl<M: Matrix> LogisticRegression<M> {
|
||||||
|
|
||||||
pub fn fit(x: &M, y: &M) -> LogisticRegression<M>{
|
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<M>{
|
||||||
|
|
||||||
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let (_, y_nrows) = y.shape();
|
let (_, y_nrows) = y_m.shape();
|
||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||||
}
|
}
|
||||||
|
|
||||||
let classes = y.unique();
|
let classes = y_m.unique();
|
||||||
|
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
|
|
||||||
let mut yi: Vec<usize> = vec![0; y_nrows];
|
let mut yi: Vec<usize> = vec![0; y_nrows];
|
||||||
|
|
||||||
for i in 0..y_nrows {
|
for i in 0..y_nrows {
|
||||||
let yc = y.get(0, i);
|
let yc = y_m.get(0, i);
|
||||||
let j = classes.iter().position(|c| yc == *c).unwrap();
|
let j = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||||
}
|
}
|
||||||
@@ -190,19 +191,19 @@ impl<M: Matrix> LogisticRegression<M> {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn predict(&self, x: &M) -> M {
|
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||||
if self.num_classes == 2 {
|
if self.num_classes == 2 {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||||
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect())
|
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
let (nrows, _) = x.shape();
|
let (nrows, _) = x.shape();
|
||||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||||
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
||||||
let class_idxs = y_hat.argmax();
|
let class_idxs = y_hat.argmax();
|
||||||
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect())
|
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,9 +236,8 @@ impl<M: Matrix> LogisticRegression<M> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::linalg::ndarray_bindings;
|
use ndarray::{arr1, arr2, Array};
|
||||||
use ndarray::{arr2, Array};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiclass_objective_f() {
|
fn multiclass_objective_f() {
|
||||||
@@ -339,7 +339,7 @@ mod tests {
|
|||||||
&[10., -2.],
|
&[10., -2.],
|
||||||
&[ 8., 2.],
|
&[ 8., 2.],
|
||||||
&[ 9., 0.]]);
|
&[ 9., 0.]]);
|
||||||
let y = DenseMatrix::vector_from_array(&[0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]);
|
let y = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
@@ -351,7 +351,7 @@ mod tests {
|
|||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
|
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -380,13 +380,13 @@ mod tests {
|
|||||||
&[4.9, 2.4, 3.3, 1.0],
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
&[6.6, 2.9, 4.6, 1.3],
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4]]);
|
||||||
let y = DenseMatrix::vector_from_array(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
|
let y =vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
let y_hat = lr.predict(&x);
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
|
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -414,15 +414,13 @@ mod tests {
|
|||||||
[4.9, 2.4, 3.3, 1.0],
|
[4.9, 2.4, 3.3, 1.0],
|
||||||
[6.6, 2.9, 4.6, 1.3],
|
[6.6, 2.9, 4.6, 1.3],
|
||||||
[5.2, 2.7, 3.9, 1.4]]);
|
[5.2, 2.7, 3.9, 1.4]]);
|
||||||
let y = Array::from_shape_vec((1, 20), vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).unwrap();
|
let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y);
|
let lr = LogisticRegression::fit(&x, &y);
|
||||||
|
|
||||||
println!("{:?}", lr);
|
let y_hat = lr.predict(&x);
|
||||||
|
|
||||||
let y_hat = lr.predict(&x).to_raw_vector();
|
assert_eq!(y_hat, arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
|
||||||
|
|
||||||
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,12 @@ pub mod ndarray_bindings;
|
|||||||
|
|
||||||
pub trait Matrix: Clone + Debug {
|
pub trait Matrix: Clone + Debug {
|
||||||
|
|
||||||
|
type RowVector: Clone + Debug;
|
||||||
|
|
||||||
|
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||||
|
|
||||||
|
fn to_row_vector(self) -> Self::RowVector;
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self;
|
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self;
|
||||||
|
|
||||||
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self;
|
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self;
|
||||||
|
|||||||
@@ -107,7 +107,17 @@ impl Into<Vec<f64>> for DenseMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Matrix for DenseMatrix {
|
impl Matrix for DenseMatrix {
|
||||||
|
|
||||||
|
type RowVector = Vec<f64>;
|
||||||
|
|
||||||
|
fn from_row_vector(vec: Self::RowVector) -> Self{
|
||||||
|
DenseMatrix::from_vec(1, vec.len(), vec)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_row_vector(self) -> Self::RowVector{
|
||||||
|
self.to_raw_vector()
|
||||||
|
}
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
|
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
|
||||||
DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
|
DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
|
||||||
@@ -968,6 +978,15 @@ impl Matrix for DenseMatrix {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_to_row_vec() {
|
||||||
|
|
||||||
|
let vec = vec![ 1., 2., 3.];
|
||||||
|
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.]));
|
||||||
|
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn qr_solve_mut() {
|
fn qr_solve_mut() {
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,20 @@
|
|||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use crate::linalg::{Matrix};
|
use crate::linalg::{Matrix};
|
||||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Axis, stack, s};
|
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
||||||
|
|
||||||
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||||
{
|
{
|
||||||
|
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
|
||||||
|
|
||||||
|
fn from_row_vector(vec: Self::RowVector) -> Self{
|
||||||
|
let vec_size = vec.len();
|
||||||
|
vec.into_shape((1, vec_size)).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_row_vector(self) -> Self::RowVector{
|
||||||
|
let vec_size = self.nrows() * self.ncols();
|
||||||
|
self.into_shape(vec_size).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
|
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
|
||||||
Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
|
Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
|
||||||
@@ -248,7 +259,16 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use ndarray::{arr2, Array2};
|
use ndarray::{arr1, arr2, Array2};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn from_to_row_vec() {
|
||||||
|
|
||||||
|
let vec = arr1(&[ 1., 2., 3.]);
|
||||||
|
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
|
||||||
|
assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.]));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn add_mut() {
|
fn add_mut() {
|
||||||
|
|||||||
Reference in New Issue
Block a user