fix: formatting

This commit is contained in:
Volodymyr Orlov
2020-12-22 17:44:44 -08:00
parent f685f575e0
commit 74f0d9e6fb
+3 -11
View File
@@ -155,7 +155,7 @@ pub fn cross_val_predict<T, M, H, E, K, F>(
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: H, parameters: H,
cv: K cv: K,
) -> Result<M::RowVector, Failed> ) -> Result<M::RowVector, Failed>
where where
T: RealNumber, T: RealNumber,
@@ -163,7 +163,7 @@ where
H: Clone, H: Clone,
E: Predictor<M, M::RowVector>, E: Predictor<M, M::RowVector>,
K: BaseKFold, K: BaseKFold,
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed> F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
{ {
let mut y_hat = M::RowVector::zeros(y.len()); let mut y_hat = M::RowVector::zeros(y.len());
@@ -348,16 +348,8 @@ mod tests {
..KFold::default() ..KFold::default()
}; };
let y_hat = cross_val_predict( let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();
KNNRegressor::fit,
&x,
&y,
Default::default(),
cv
)
.unwrap();
assert!(mean_absolute_error(&y, &y_hat) < 10.0); assert!(mean_absolute_error(&y, &y_hat) < 10.0);
} }
} }