feat: puts ndarray and nalgebra bindings behind feature flags

This commit is contained in:
Volodymyr Orlov
2020-08-28 16:55:41 -07:00
parent 367ea62608
commit 68dca25f91
8 changed files with 173 additions and 125 deletions
+1 -47
View File
@@ -13,7 +13,7 @@ pub enum LinearRegressionSolverName {
#[derive(Serialize, Deserialize, Debug)]
pub struct LinearRegressionParameters {
solver: LinearRegressionSolverName,
pub solver: LinearRegressionSolverName,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -81,55 +81,9 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use nalgebra::{DMatrix, RowDVector};
#[test]
fn ols_fit_predict() {
let x = DMatrix::from_row_slice(
16,
6,
&[
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
],
);
let y: RowDVector<f64> = RowDVector::from_vec(vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
]);
let y_hat_qr = LinearRegression::fit(
&x,
&y,
LinearRegressionParameters {
solver: LinearRegressionSolverName::QR,
},
)
.predict(&x);
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
assert!(y
.iter()
.zip(y_hat_qr.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
assert!(y
.iter()
.zip(y_hat_svd.iter())
.all(|(&a, &b)| (a - b).abs() <= 5.0));
}
#[test]
fn ols_fit_predict_nalgebra() {
let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+24 -25
View File
@@ -265,7 +265,6 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use ndarray::{arr1, arr2, Array1};
#[test]
fn multiclass_objective_f() {
@@ -426,31 +425,31 @@ mod tests {
#[test]
fn lr_fit_predict_iris() {
let x = arr2(&[
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.0, 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5.0, 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5.0, 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[7.0, 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4.0, 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1.0],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y: Array1<f64> = arr1(&[
let y: Vec<f64> = vec![
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]);
];
let lr = LogisticRegression::fit(&x, &y);
@@ -459,7 +458,7 @@ mod tests {
let error: f64 = y
.into_iter()
.zip(y_hat.into_iter())
.map(|(&a, &b)| (a - b).abs())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(error <= 1.0);