diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 62da396..493b130 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -181,6 +181,7 @@ impl RandomForestRegressor { mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::metrics::mean_absolute_error; #[test] fn fit_longley() { @@ -224,9 +225,7 @@ mod tests { ) .predict(&x); - for i in 0..y_hat.len() { - assert!((y_hat[i] - expected_y[i]).abs() < 1.0); - } + assert!(mean_absolute_error(&y, &y_hat) < 1.0); } #[test]