diff --git a/src/classification/logistic_regression.rs b/src/classification/logistic_regression.rs
index 5212973..a297f36 100644
--- a/src/classification/logistic_regression.rs
+++ b/src/classification/logistic_regression.rs
@@ -122,23 +122,24 @@ impl<'a, M: Matrix> ObjectiveFunction for MultiClassObjectiveFunction<'a, M>
 
 impl LogisticRegression {
 
-    pub fn fit(x: &M, y: &M) -> LogisticRegression{
+    pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression{
+        let y_m = M::from_row_vector(y.clone());
         let (x_nrows, num_attributes) = x.shape();
-        let (_, y_nrows) = y.shape();
+        let (_, y_nrows) = y_m.shape();
 
         if x_nrows != y_nrows {
             panic!("Number of rows of X doesn't match number of rows of Y");
         }
 
-        let classes = y.unique();
+        let classes = y_m.unique();
 
         let k = classes.len();
 
         let mut yi: Vec<usize> = vec![0; y_nrows];
 
         for i in 0..y_nrows {
-            let yc = y.get(0, i);
+            let yc = y_m.get(0, i);
             let j = classes.iter().position(|c| yc == *c).unwrap();
             yi[i] = classes.iter().position(|c| yc == *c).unwrap();
         }
@@ -190,19 +191,19 @@ impl LogisticRegression {
 
     }
 
-    pub fn predict(&self, x: &M) -> M {
+    pub fn predict(&self, x: &M) -> M::RowVector {
         if self.num_classes == 2 {
             let (nrows, _) = x.shape();
             let x_and_bias = x.v_stack(&M::ones(nrows, 1));
             let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
-            M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect())
+            M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
         } else {
             let (nrows, _) = x.shape();
             let x_and_bias = x.v_stack(&M::ones(nrows, 1));
             let y_hat = x_and_bias.dot(&self.weights.transpose());
             let class_idxs = y_hat.argmax();
-            M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect())
+            M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
         }
     }
 
@@ -235,9 +236,8 @@ impl LogisticRegression {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::linalg::naive::dense_matrix::DenseMatrix;
-    use crate::linalg::ndarray_bindings;
-    use ndarray::{arr2, Array};
+    use crate::linalg::naive::dense_matrix::DenseMatrix;
+    use ndarray::{arr1, arr2, Array};
 
     #[test]
     fn multiclass_objective_f() {
@@ -339,7 +339,7 @@ mod tests {
             &[10., -2.],
             &[ 8.,  2.],
             &[ 9.,  0.]]);
-        let y = DenseMatrix::vector_from_array(&[0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.]);
+        let y = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
 
         let lr = LogisticRegression::fit(&x, &y);
 
@@ -351,7 +351,7 @@ mod tests {
 
         let y_hat = lr.predict(&x);
 
-        assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
+        assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
 
     }
@@ -380,13 +380,13 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4]]);
-        let y = DenseMatrix::vector_from_array(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
+        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
 
         let lr = LogisticRegression::fit(&x, &y);
 
         let y_hat = lr.predict(&x);
 
-        assert_eq!(y_hat, DenseMatrix::vector_from_array(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
+        assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
 
     }
@@ -414,15 +414,13 @@ mod tests {
             [4.9, 2.4, 3.3, 1.0],
             [6.6, 2.9, 4.6, 1.3],
             [5.2, 2.7, 3.9, 1.4]]);
-        let y = Array::from_shape_vec((1, 20), vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).unwrap();
+        let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
 
-        let lr = LogisticRegression::fit(&x, &y);
+        let lr = LogisticRegression::fit(&x, &y);
 
-        println!("{:?}", lr);
+        let y_hat = lr.predict(&x);
 
-        let y_hat = lr.predict(&x).to_raw_vector();
-
-        assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
+        assert_eq!(y_hat, arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
 
     }
diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs
index f3eebfb..896f8d6 100644
--- a/src/linalg/mod.rs
+++ b/src/linalg/mod.rs
@@ -6,6 +6,12 @@ pub mod ndarray_bindings;
 
 pub trait Matrix: Clone + Debug {
 
+    type RowVector: Clone + Debug;
+
+    fn from_row_vector(vec: Self::RowVector) -> Self;
+
+    fn to_row_vector(self) -> Self::RowVector;
+
     fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self;
 
     fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self;
diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs
index 067844e..5165339 100644
--- a/src/linalg/naive/dense_matrix.rs
+++ b/src/linalg/naive/dense_matrix.rs
@@ -107,7 +107,17 @@ impl Into<Vec<f64>> for DenseMatrix {
     }
 }
 
-impl Matrix for DenseMatrix {
+impl Matrix for DenseMatrix {
+
+    type RowVector = Vec<f64>;
+
+    fn from_row_vector(vec: Self::RowVector) -> Self{
+        DenseMatrix::from_vec(1, vec.len(), vec)
+    }
+
+    fn to_row_vector(self) -> Self::RowVector{
+        self.to_raw_vector()
+    }
 
     fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
         DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
@@ -968,6 +978,15 @@ impl Matrix for DenseMatrix {
 mod tests {
     use super::*;
 
+    #[test]
+    fn from_to_row_vec() {
+
+        let vec = vec![ 1., 2., 3.];
+        assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.]));
+        assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
+
+    }
+
     #[test]
     fn qr_solve_mut() {
 
diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs
index 0a18ace..a3d2723 100644
--- a/src/linalg/ndarray_bindings.rs
+++ b/src/linalg/ndarray_bindings.rs
@@ -1,9 +1,20 @@
 use std::ops::Range;
 
 use crate::linalg::{Matrix};
-use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Axis, stack, s};
+use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
 
 impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {
 
+    type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
+
+    fn from_row_vector(vec: Self::RowVector) -> Self{
+        let vec_size = vec.len();
+        vec.into_shape((1, vec_size)).unwrap()
+    }
+
+    fn to_row_vector(self) -> Self::RowVector{
+        let vec_size = self.nrows() * self.ncols();
+        self.into_shape(vec_size).unwrap()
+    }
 
     fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
         Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
@@ -248,7 +259,16 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use ndarray::{arr2, Array2};
+    use ndarray::{arr1, arr2, Array2};
+
+    #[test]
+    fn from_to_row_vec() {
+
+        let vec = arr1(&[ 1., 2., 3.]);
+        assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
+        assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.]));
+
+    }
 
     #[test]
     fn add_mut() {