diff --git a/Cargo.toml b/Cargo.toml index 47283cc..1c2e0a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2018" [dependencies] ndarray = "0.13" +nalgebra = "0.20.0" num-traits = "0.2.11" num = "0.2.1" rand = "0.7.3" diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 5b28f6a..9a64a0f 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -3,6 +3,7 @@ pub mod qr; pub mod svd; pub mod evd; pub mod ndarray_bindings; +pub mod nalgebra_bindings; use std::ops::Range; use std::fmt::{Debug, Display}; diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 7f1c9b0..a22c335 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -213,6 +213,9 @@ impl &Self{ + for v in self.iter_mut(){ + *v = v.abs() + } self } @@ -631,4 +634,12 @@ mod tests { assert!(a.approximate_eq(&(&noise + &a), 1e-4)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); } + + #[test] + fn abs_mut() { + let mut a = arr2(&[[1., -2.], [3., -4.]]); + let expected = arr2(&[[1., 2.], [3., 4.]]); + a.abs_mut(); + assert_eq!(a, expected); + } } \ No newline at end of file diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index a1d5e2c..957e0bf 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -27,9 +27,10 @@ impl> PartialEq for LinearRegression { impl> LinearRegression { - pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression{ + pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression{ - let b = y.transpose(); + let y_m = M::from_row_vector(y.clone()); + let b = y_m.transpose(); let (x_nrows, num_attributes) = x.shape(); let (y_nrows, _) = b.shape(); @@ -63,13 +64,46 @@ impl> LinearRegression { } #[cfg(test)] -mod tests { - use super::*; - use crate::linalg::naive::dense_matrix::*; +mod tests { + use super::*; + use nalgebra::{DMatrix, RowDVector}; + use crate::linalg::naive::dense_matrix::*; #[test] fn ols_fit_predict() { + let x = DMatrix::from_row_slice(16, 6, &[ + 234.289, 235.6, 159.0, 107.608, 1947., 60.323, + 259.426, 232.5, 145.6, 108.632, 1948., 61.122, + 258.054, 368.2, 161.6, 109.773, 1949., 60.171, + 284.599, 335.1, 165.0, 110.929, 1950., 61.187, + 328.975, 209.9, 309.9, 112.075, 1951., 63.221, + 346.999, 193.2, 359.4, 113.270, 1952., 63.639, + 365.385, 187.0, 354.7, 115.094, 1953., 64.989, + 363.112, 357.8, 335.0, 116.219, 1954., 63.761, + 397.469, 290.4, 304.8, 117.388, 1955., 66.019, + 419.180, 282.2, 285.7, 118.734, 1956., 67.857, + 442.769, 293.6, 279.8, 120.445, 1957., 68.169, + 444.546, 468.1, 263.7, 121.950, 1958., 66.513, + 482.704, 381.3, 255.2, 123.366, 1959., 68.655, + 502.601, 393.1, 251.4, 125.368, 1960., 69.564, + 518.173, 480.6, 257.2, 127.852, 1961., 69.331, + 554.894, 400.7, 282.7, 130.081, 1962., 70.551]); + + let y: RowDVector = RowDVector::from_vec(vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9)); + + let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x); + + let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x); + + assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0)); + assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0)); + + } + + #[test] + fn ols_fit_predict_nalgebra() { + let x = DenseMatrix::from_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], @@ -88,15 +122,14 @@ mod tests { &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); - let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]); + let y: Vec = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9); - let y_hat_qr = DenseMatrix::from_row_vector(LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x)); - - let y_hat_svd = DenseMatrix::from_row_vector(LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x)); - - assert!(y.approximate_eq(&y_hat_qr, 5.)); - assert!(y.approximate_eq(&y_hat_svd, 5.)); + let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x); + let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x); + + assert!(y.iter().zip(y_hat_qr.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0)); + assert!(y.iter().zip(y_hat_svd.iter()).all(|(&a, &b)| (a - b).abs() <= 5.0)); } @@ -120,7 +153,7 @@ mod tests { &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], &[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); - let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]); + let y = vec!(83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9); let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR); diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index bd1d9bc..f78c310 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -272,7 +272,7 @@ impl> LogisticRegression { mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; - use ndarray::{arr1, arr2}; + use ndarray::{arr1, arr2, Array1}; #[test] fn multiclass_objective_f() { @@ -443,13 +443,15 @@ mod tests { [4.9, 2.4, 3.3, 1.0], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4]]); - let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]); + let y: Array1 = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]); let lr = LogisticRegression::fit(&x, &y); - let y_hat = lr.predict(&x); + let y_hat = lr.predict(&x); + + let error: f64 = y.into_iter().zip(y_hat.into_iter()).map(|(&a, &b)| (a - b).abs()).sum(); - assert_eq!(y_hat, arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])); + assert!(error <= 1.0); } diff --git a/src/math/num.rs b/src/math/num.rs index 7980853..5c1b60f 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -2,7 +2,7 @@ use std::fmt::{Debug, Display}; use num_traits::{Float, FromPrimitive}; use rand::prelude::*; -pub trait FloatExt: Float + FromPrimitive + Debug + Display { +pub trait FloatExt: Float + FromPrimitive + Debug + Display + Copy { fn copysign(self, sign: Self) -> Self;