fix: formatting
This commit is contained in:
@@ -155,7 +155,7 @@ pub fn cross_val_predict<T, M, H, E, K, F>(
|
|||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: H,
|
parameters: H,
|
||||||
cv: K
|
cv: K,
|
||||||
) -> Result<M::RowVector, Failed>
|
) -> Result<M::RowVector, Failed>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: RealNumber,
|
||||||
@@ -163,14 +163,14 @@ where
|
|||||||
H: Clone,
|
H: Clone,
|
||||||
E: Predictor<M, M::RowVector>,
|
E: Predictor<M, M::RowVector>,
|
||||||
K: BaseKFold,
|
K: BaseKFold,
|
||||||
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
|
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
|
||||||
{
|
{
|
||||||
let mut y_hat = M::RowVector::zeros(y.len());
|
let mut y_hat = M::RowVector::zeros(y.len());
|
||||||
|
|
||||||
for (train_idx, test_idx) in cv.split(x) {
|
for (train_idx, test_idx) in cv.split(x) {
|
||||||
let train_x = x.take(&train_idx, 0);
|
let train_x = x.take(&train_idx, 0);
|
||||||
let train_y = y.take(&train_idx);
|
let train_y = y.take(&train_idx);
|
||||||
let test_x = x.take(&test_idx, 0);
|
let test_x = x.take(&test_idx, 0);
|
||||||
|
|
||||||
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
|
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
|
||||||
|
|
||||||
@@ -348,16 +348,8 @@ mod tests {
|
|||||||
..KFold::default()
|
..KFold::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let y_hat = cross_val_predict(
|
let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();
|
||||||
KNNRegressor::fit,
|
|
||||||
&x,
|
|
||||||
&y,
|
|
||||||
Default::default(),
|
|
||||||
cv
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user