fix: formatting
This commit is contained in:
@@ -155,7 +155,7 @@ pub fn cross_val_predict<T, M, H, E, K, F>(
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: H,
|
parameters: H,
|
||||||
cv: K
|
cv: K,
|
||||||
) -> Result<M::RowVector, Failed>
|
) -> Result<M::RowVector, Failed>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: RealNumber,
|
||||||
@@ -163,7 +163,7 @@ where
|
|||||||
H: Clone,
|
H: Clone,
|
||||||
E: Predictor<M, M::RowVector>,
|
E: Predictor<M, M::RowVector>,
|
||||||
K: BaseKFold,
|
K: BaseKFold,
|
||||||
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
|
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
|
||||||
{
|
{
|
||||||
let mut y_hat = M::RowVector::zeros(y.len());
|
let mut y_hat = M::RowVector::zeros(y.len());
|
||||||
|
|
||||||
@@ -348,16 +348,8 @@ mod tests {
|
|||||||
..KFold::default()
|
..KFold::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let y_hat = cross_val_predict(
|
let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();
|
||||||
KNNRegressor::fit,
|
|
||||||
&x,
|
|
||||||
&y,
|
|
||||||
Default::default(),
|
|
||||||
cv
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user