From 74f0d9e6fb574196cd84bc7d82169ad8a96cb910 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Tue, 22 Dec 2020 17:44:44 -0800 Subject: [PATCH] fix: formatting --- src/model_selection/mod.rs | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 7178da8..7776354 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -155,7 +155,7 @@ pub fn cross_val_predict( x: &M, y: &M::RowVector, parameters: H, - cv: K + cv: K, ) -> Result where T: RealNumber, @@ -163,14 +163,14 @@ where H: Clone, E: Predictor, K: BaseKFold, - F: Fn(&M, &M::RowVector, H) -> Result -{ - let mut y_hat = M::RowVector::zeros(y.len()); - + F: Fn(&M, &M::RowVector, H) -> Result, +{ + let mut y_hat = M::RowVector::zeros(y.len()); + for (train_idx, test_idx) in cv.split(x) { let train_x = x.take(&train_idx, 0); let train_y = y.take(&train_idx); - let test_x = x.take(&test_idx, 0); + let test_x = x.take(&test_idx, 0); let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?; @@ -348,16 +348,8 @@ mod tests { ..KFold::default() }; - let y_hat = cross_val_predict( - KNNRegressor::fit, - &x, - &y, - Default::default(), - cv - ) - .unwrap(); + let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap(); assert!(mean_absolute_error(&y, &y_hat) < 10.0); } - }