feat: integrates with nalgebra

This commit is contained in:
Volodymyr Orlov
2020-04-06 19:16:37 -07:00
parent eb0c36223f
commit b068295dac
6 changed files with 66 additions and 18 deletions
+6 -4
View File
@@ -272,7 +272,7 @@ impl<T: FloatExt, M: Matrix<T>> LogisticRegression<T, M> {
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use ndarray::{arr1, arr2};
use ndarray::{arr1, arr2, Array1};
#[test]
fn multiclass_objective_f() {
@@ -443,13 +443,15 @@ mod tests {
[4.9, 2.4, 3.3, 1.0],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4]]);
let y = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
let y: Array1<f64> = arr1(&[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]);
let lr = LogisticRegression::fit(&x, &y);
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x);
let error: f64 = y.into_iter().zip(y_hat.into_iter()).map(|(&a, &b)| (a - b).abs()).sum();
assert_eq!(y_hat, arr1(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
assert!(error <= 1.0);
}