feat: + cross_val_predict
This commit is contained in:
+48
-57
@@ -125,7 +125,7 @@ where
|
||||
let mut test_score = Vec::with_capacity(k);
|
||||
let mut train_score = Vec::with_capacity(k);
|
||||
|
||||
for (test_idx, train_idx) in cv.split(x) {
|
||||
for (train_idx, test_idx) in cv.split(x) {
|
||||
let train_x = x.take(&train_idx, 0);
|
||||
let train_y = y.take(&train_idx);
|
||||
let test_x = x.take(&test_idx, 0);
|
||||
@@ -143,6 +143,46 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate cross-validated estimates for each input data point.
|
||||
/// The data is split according to the cv parameter. Each sample belongs to exactly one test set, and its prediction is computed with an estimator fitted on the corresponding training set.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
pub fn cross_val_predict<T, M, H, E, K, F>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: H,
|
||||
cv: K
|
||||
) -> Result<M::RowVector, Failed>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
H: Clone,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>
|
||||
{
|
||||
let mut y_hat = M::RowVector::zeros(y.len());
|
||||
|
||||
for (train_idx, test_idx) in cv.split(x) {
|
||||
let train_x = x.take(&train_idx, 0);
|
||||
let train_y = y.take(&train_idx);
|
||||
let test_x = x.take(&test_idx, 0);
|
||||
|
||||
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
|
||||
|
||||
let y_test_hat = estimator.predict(&test_x)?;
|
||||
for (i, &idx) in test_idx.iter().enumerate() {
|
||||
y_hat.set(idx, y_test_hat.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -278,10 +318,8 @@ mod tests {
|
||||
assert!(results.mean_train_score() < results.mean_test_score());
|
||||
}
|
||||
|
||||
use crate::tree::decision_tree_regressor::*;
|
||||
|
||||
#[test]
|
||||
fn test_some_regressor() {
|
||||
fn test_cross_val_predict_knn() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
@@ -305,68 +343,21 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let cv = KFold::default().with_n_splits(2);
|
||||
|
||||
let results = cross_validate(
|
||||
DecisionTreeRegressor::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&mean_absolute_error,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{}", results.mean_test_score());
|
||||
println!("{}", results.mean_train_score());
|
||||
}
|
||||
|
||||
use crate::tree::decision_tree_classifier::*;
|
||||
|
||||
#[test]
|
||||
fn test_some_classifier() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 2,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results = cross_validate(
|
||||
DecisionTreeClassifier::fit,
|
||||
let y_hat = cross_val_predict(
|
||||
KNNRegressor::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&accuracy,
|
||||
cv
|
||||
)
|
||||
.unwrap();
|
||||
.unwrap();
|
||||
|
||||
println!("{}", results.mean_test_score());
|
||||
println!("{}", results.mean_train_score());
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user