fix: clippy, documentation and formatting
This commit is contained in:
@@ -1,30 +1,13 @@
|
||||
//! # KFold
|
||||
//!
|
||||
//! In statistics and machine learning we usually split our data into multiple subsets: training data and testing data (and sometimes validation data),
|
||||
//! and fit our model on the train data, in order to make predictions on the test data. We do that to avoid overfitting or underfitting the model to our data.
|
||||
//! Overfitting is bad because the model we trained fits the training data too well and can’t make any inferences on new data.
|
||||
//! Underfitting is bad because the model is undertrained and does not fit the training data well.
|
||||
//! Splitting data into multiple subsets helps to find the right combination of hyperparameters, estimate model performance and choose the right model for
|
||||
//! your data.
|
||||
//!
|
||||
//! In SmartCore you can split your data into training and test datasets using the `train_test_split` function.
|
||||
//! Defines k-fold cross validator.
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::model_selection::BaseKFold;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
/// An interface for the K-Folds cross-validator
|
||||
pub trait BaseKFold {
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
type Output: Iterator<Item = (Vec<usize>, Vec<usize>)>;
|
||||
/// Return a tuple containing the training set indices for that split and
|
||||
/// the testing set indices for that split.
|
||||
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Self::Output;
|
||||
/// Returns the number of splits
|
||||
fn n_splits(&self) -> usize;
|
||||
}
|
||||
|
||||
/// K-Folds cross-validator
|
||||
pub struct KFold {
|
||||
/// Number of folds. Must be at least 2.
|
||||
@@ -101,12 +84,12 @@ impl KFold {
|
||||
}
|
||||
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
pub struct BaseKFoldIter {
|
||||
pub struct KFoldIter {
|
||||
indices: Vec<usize>,
|
||||
test_indices: Vec<Vec<bool>>,
|
||||
}
|
||||
|
||||
impl Iterator for BaseKFoldIter {
|
||||
impl Iterator for KFoldIter {
|
||||
type Item = (Vec<usize>, Vec<usize>);
|
||||
|
||||
fn next(&mut self) -> Option<(Vec<usize>, Vec<usize>)> {
|
||||
@@ -133,7 +116,7 @@ impl Iterator for BaseKFoldIter {
|
||||
|
||||
/// Abstract class for all KFold functionalities
|
||||
impl BaseKFold for KFold {
|
||||
type Output = BaseKFoldIter;
|
||||
type Output = KFoldIter;
|
||||
|
||||
fn n_splits(&self) -> usize {
|
||||
self.n_splits
|
||||
@@ -148,7 +131,7 @@ impl BaseKFold for KFold {
|
||||
let mut test_indices = self.test_masks(x);
|
||||
test_indices.reverse();
|
||||
|
||||
BaseKFoldIter {
|
||||
KFoldIter {
|
||||
indices,
|
||||
test_indices,
|
||||
}
|
||||
|
||||
@@ -14,15 +14,27 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::model_selection::kfold::BaseKFold;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
pub mod kfold;
|
||||
pub(crate) mod kfold;
|
||||
|
||||
pub use kfold::{KFold, KFoldIter};
|
||||
|
||||
/// An interface for the K-Folds cross-validator
|
||||
pub trait BaseKFold {
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
type Output: Iterator<Item = (Vec<usize>, Vec<usize>)>;
|
||||
/// Return a tuple containing the training set indices for that split and
|
||||
/// the testing set indices for that split.
|
||||
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Self::Output;
|
||||
/// Returns the number of splits
|
||||
fn n_splits(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Splits data into 2 disjoint datasets.
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _M_
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
||||
/// * `shuffle` - whether or not to shuffle the data before splitting
|
||||
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
@@ -65,22 +77,33 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
(x_train, x_test, y_train, y_test)
|
||||
}
|
||||
|
||||
/// Cross validation results.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CrossValidationResult<T: RealNumber> {
|
||||
/// Vector with test scores on each cv split
|
||||
pub test_score: Vec<T>,
|
||||
/// Vector with training scores on each cv split
|
||||
pub train_score: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> CrossValidationResult<T> {
|
||||
/// Average test score
|
||||
pub fn mean_test_score(&self) -> T {
|
||||
self.test_score.sum() / T::from_usize(self.test_score.len()).unwrap()
|
||||
}
|
||||
|
||||
/// Average training score
|
||||
pub fn mean_train_score(&self) -> T {
|
||||
self.train_score.sum() / T::from_usize(self.train_score.len()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate an estimator by cross-validation using given metric.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
|
||||
pub fn cross_validate<T, M, H, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
@@ -302,7 +325,6 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_some_classifier() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -334,8 +356,15 @@ mod tests {
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results =
|
||||
cross_validate(DecisionTreeClassifier::fit, &x, &y, Default::default(), cv, &accuracy).unwrap();
|
||||
let results = cross_validate(
|
||||
DecisionTreeClassifier::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{}", results.mean_test_score());
|
||||
println!("{}", results.mean_train_score());
|
||||
|
||||
Reference in New Issue
Block a user