Lmm/add seeds in more algorithms (#164)

* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
morenol
2022-09-21 15:35:22 -04:00
committed by GitHub
parent 48514d1b15
commit 3a44161406
14 changed files with 139 additions and 64 deletions
+7 -4
View File
@@ -41,7 +41,7 @@
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
//!
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
@@ -107,8 +107,8 @@ use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use rand::seq::SliceRandom;
use rand::thread_rng;
pub(crate) mod kfold;
@@ -130,11 +130,13 @@ pub trait BaseKFold {
/// * `y` - target values, should be of size _N_
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
/// * `shuffle`, - whether or not to shuffle the data before splitting
/// * `seed` - Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
x: &M,
y: &M::RowVector,
test_size: f32,
shuffle: bool,
seed: Option<u64>,
) -> (M, M, M::RowVector, M::RowVector) {
if x.shape().0 != y.len() {
panic!(
@@ -143,6 +145,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
y.len()
);
}
let mut rng = get_rng_impl(seed);
if test_size <= 0. || test_size > 1.0 {
panic!("test_size should be between 0 and 1");
@@ -159,7 +162,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
let mut indices: Vec<usize> = (0..n).collect();
if shuffle {
indices.shuffle(&mut thread_rng());
indices.shuffle(&mut rng);
}
let x_train = x.take(&indices[n_test..n], 0);
@@ -292,7 +295,7 @@ mod tests {
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
let y = vec![0f64; n];
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
assert!(
x_train.shape().0 > (n as f64 * 0.65) as usize