feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+9 -12
View File
@@ -36,8 +36,8 @@
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
//! ]);
//!
//! let lr = LogisticRegression::fit(&x, &y);
//! let y_hat = lr.predict(&x);
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! let y_hat = lr.predict(&x).unwrap();
//! ```
use std::iter::Sum;
use std::ops::AddAssign;
@@ -395,6 +395,7 @@ mod tests {
use super::*;
use crate::ensemble::random_forest_regressor::*;
use crate::linear::logistic_regression::*;
use crate::metrics::mean_absolute_error;
use ndarray::{arr1, arr2, Array1, Array2};
#[test]
@@ -736,9 +737,9 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]);
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y).unwrap();
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x).unwrap();
let error: f64 = y
.into_iter()
@@ -774,10 +775,6 @@ mod tests {
114.2, 115.7, 116.9,
]);
let expected_y: Vec<f64> = vec![
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
];
let y_hat = RandomForestRegressor::fit(
&x,
&y,
@@ -789,10 +786,10 @@ mod tests {
m: Option::None,
},
)
.predict(&x);
.unwrap()
.predict(&x)
.unwrap();
for i in 0..y_hat.len() {
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
}
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}
}