fix: formatting

This commit is contained in:
Volodymyr Orlov
2020-12-22 17:44:44 -08:00
parent f685f575e0
commit 74f0d9e6fb
+7 -15
View File
@@ -155,7 +155,7 @@ pub fn cross_val_predict<T, M, H, E, K, F>(
x: &M,
y: &M::RowVector,
parameters: H,
cv: K
cv: K,
) -> Result<M::RowVector, Failed>
where
T: RealNumber,
@@ -163,14 +163,14 @@ where
H: Clone,
E: Predictor<M, M::RowVector>,
K: BaseKFold,
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
{
let mut y_hat = M::RowVector::zeros(y.len());
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
{
let mut y_hat = M::RowVector::zeros(y.len());
for (train_idx, test_idx) in cv.split(x) {
let train_x = x.take(&train_idx, 0);
let train_y = y.take(&train_idx);
let test_x = x.take(&test_idx, 0);
let test_x = x.take(&test_idx, 0);
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
@@ -348,16 +348,8 @@ mod tests {
..KFold::default()
};
let y_hat = cross_val_predict(
KNNRegressor::fit,
&x,
&y,
Default::default(),
cv
)
.unwrap();
let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
}
}