feat: puts ndarray and nalgebra bindings behind feature flags
This commit is contained in:
@@ -368,6 +368,7 @@ impl<T: FloatExt + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linear::linear_regression::*;
|
||||
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
||||
|
||||
#[test]
|
||||
@@ -674,4 +675,49 @@ mod tests {
|
||||
assert_eq!(res.len(), 7);
|
||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
let x = DMatrix::from_row_slice(
|
||||
16,
|
||||
6,
|
||||
&[
|
||||
234.289, 235.6, 159.0, 107.608, 1947., 60.323, 259.426, 232.5, 145.6, 108.632,
|
||||
1948., 61.122, 258.054, 368.2, 161.6, 109.773, 1949., 60.171, 284.599, 335.1,
|
||||
165.0, 110.929, 1950., 61.187, 328.975, 209.9, 309.9, 112.075, 1951., 63.221,
|
||||
346.999, 193.2, 359.4, 113.270, 1952., 63.639, 365.385, 187.0, 354.7, 115.094,
|
||||
1953., 64.989, 363.112, 357.8, 335.0, 116.219, 1954., 63.761, 397.469, 290.4,
|
||||
304.8, 117.388, 1955., 66.019, 419.180, 282.2, 285.7, 118.734, 1956., 67.857,
|
||||
442.769, 293.6, 279.8, 120.445, 1957., 68.169, 444.546, 468.1, 263.7, 121.950,
|
||||
1958., 66.513, 482.704, 381.3, 255.2, 123.366, 1959., 68.655, 502.601, 393.1,
|
||||
251.4, 125.368, 1960., 69.564, 518.173, 480.6, 257.2, 127.852, 1961., 69.331,
|
||||
554.894, 400.7, 282.7, 130.081, 1962., 70.551,
|
||||
],
|
||||
);
|
||||
|
||||
let y: RowDVector<f64> = RowDVector::from_vec(vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
]);
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::QR,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
|
||||
assert!(y
|
||||
.iter()
|
||||
.zip(y_hat_qr.iter())
|
||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||
assert!(y
|
||||
.iter()
|
||||
.zip(y_hat_svd.iter())
|
||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user