fix: clippy, documentation and formatting

This commit is contained in:
Volodymyr Orlov
2020-12-22 16:35:28 -08:00
parent a2be9e117f
commit 9b221979da
7 changed files with 80 additions and 62 deletions
+36 -7
View File
@@ -14,15 +14,27 @@ use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::model_selection::kfold::BaseKFold;
use rand::seq::SliceRandom;
use rand::thread_rng;
pub mod kfold;
pub(crate) mod kfold;
pub use kfold::{KFold, KFoldIter};
/// An interface for the K-Folds cross-validator
pub trait BaseKFold {
/// An iterator over indices that split data into training and test set.
type Output: Iterator<Item = (Vec<usize>, Vec<usize>)>;
/// Return a tuple containing the the training set indices for that split and
/// the testing set indices for that split.
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Self::Output;
/// Returns the number of splits
fn n_splits(&self) -> usize;
}
/// Splits data into 2 disjoint datasets.
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
/// * `y` - target values, should be of size _M_
/// * `y` - target values, should be of size _N_
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
/// * `shuffle`, - whether or not to shuffle the data before splitting
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
@@ -65,22 +77,33 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
(x_train, x_test, y_train, y_test)
}
/// Cross validation results.
#[derive(Clone, Debug)]
pub struct CrossValidationResult<T: RealNumber> {
/// Vector with test scores on each cv split
pub test_score: Vec<T>,
/// Vector with training scores on each cv split
pub train_score: Vec<T>,
}
impl<T: RealNumber> CrossValidationResult<T> {
/// Average test score
pub fn mean_test_score(&self) -> T {
self.test_score.sum() / T::from_usize(self.test_score.len()).unwrap()
}
/// Average training score
pub fn mean_train_score(&self) -> T {
self.train_score.sum() / T::from_usize(self.train_score.len()).unwrap()
}
}
/// Evaluate an estimator by cross-validation using given metric.
/// * `fit_estimator` - a `fit` function of an estimator
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
/// * `y` - target values, should be of size _N_
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
pub fn cross_validate<T, M, H, E, K, F, S>(
fit_estimator: F,
x: &M,
@@ -302,7 +325,6 @@ mod tests {
#[test]
fn test_some_classifier() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
@@ -334,8 +356,15 @@ mod tests {
..KFold::default()
};
let results =
cross_validate(DecisionTreeClassifier::fit, &x, &y, Default::default(), cv, &accuracy).unwrap();
let results = cross_validate(
DecisionTreeClassifier::fit,
&x,
&y,
Default::default(),
cv,
&accuracy,
)
.unwrap();
println!("{}", results.mean_test_score());
println!("{}", results.mean_train_score());