feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -36,8 +36,8 @@
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
||||
//! ]);
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y);
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
use std::iter::Sum;
|
||||
use std::ops::AddAssign;
|
||||
@@ -395,6 +395,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::ensemble::random_forest_regressor::*;
|
||||
use crate::linear::logistic_regression::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
use ndarray::{arr1, arr2, Array1, Array2};
|
||||
|
||||
#[test]
|
||||
@@ -736,9 +737,9 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
]);
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x);
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
let error: f64 = y
|
||||
.into_iter()
|
||||
@@ -774,10 +775,6 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
]);
|
||||
|
||||
let expected_y: Vec<f64> = vec![
|
||||
85., 88., 88., 89., 97., 98., 99., 99., 102., 104., 109., 110., 113., 114., 115., 116.,
|
||||
];
|
||||
|
||||
let y_hat = RandomForestRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
@@ -789,10 +786,10 @@ mod tests {
|
||||
m: Option::None,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.unwrap()
|
||||
.predict(&x)
|
||||
.unwrap();
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - expected_y[i]).abs() < 1.0);
|
||||
}
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user