Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests * feat: add seed parameter to multiple algorithms * Update changelog Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+12
-11
@@ -2,23 +2,24 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, development ]
|
||||
branches: [main, development]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
branches: [development]
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
]
|
||||
platform:
|
||||
[
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
]
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
@@ -40,7 +41,7 @@ jobs:
|
||||
default: true
|
||||
- name: Install test runner for wasm
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Stable Build
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
|
||||
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## Added
|
||||
- Seeds to multiple algorithms that depend on random number generation.
|
||||
- Added feature `js` to use WASM in browser
|
||||
|
||||
## BREAKING CHANGE
|
||||
- Added a new parameter to `train_test_split` to define the seed.
|
||||
|
||||
## [0.2.1] - 2022-05-10
|
||||
|
||||
## Added
|
||||
- L2 regularization penalty to the Logistic Regression
|
||||
- Getters for the naive bayes structs
|
||||
|
||||
+7
-3
@@ -16,21 +16,25 @@ categories = ["science"]
|
||||
default = ["datasets"]
|
||||
ndarray-bindings = ["ndarray"]
|
||||
nalgebra-bindings = ["nalgebra"]
|
||||
datasets = ["rand_distr"]
|
||||
datasets = ["rand_distr", "std"]
|
||||
fp_bench = ["itertools"]
|
||||
std = ["rand/std", "rand/std_rng"]
|
||||
# wasm32 only
|
||||
js = ["getrandom/js"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = { version = "0.15", optional = true }
|
||||
nalgebra = { version = "0.31", optional = true }
|
||||
num-traits = "0.2"
|
||||
num = "0.4"
|
||||
rand = "0.8"
|
||||
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
|
||||
rand_distr = { version = "0.4", optional = true }
|
||||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
itertools = { version = "0.10.3", optional = true }
|
||||
cfg-if = "1.0.0"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
getrandom = { version = "0.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
smartcore = { path = ".", features = ["fp_bench"] }
|
||||
|
||||
@@ -52,10 +52,10 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
use ::rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -65,6 +65,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -108,6 +109,9 @@ pub struct KMeansParameters {
|
||||
pub k: usize,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
@@ -128,6 +132,7 @@ impl Default for KMeansParameters {
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -238,7 +243,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
|
||||
@@ -311,8 +316,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(seed);
|
||||
let (n, m) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||
|
||||
@@ -45,8 +45,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
};
|
||||
@@ -441,7 +442,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
@@ -462,9 +463,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
|
||||
@@ -43,8 +43,8 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::tree::decision_tree_regressor::{
|
||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||
};
|
||||
@@ -376,7 +377,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
@@ -393,9 +394,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
|
||||
@@ -101,3 +101,5 @@ pub mod readers;
|
||||
pub mod svm;
|
||||
/// Supervised tree-based learning methods
|
||||
pub mod tree;
|
||||
|
||||
pub(crate) mod rand;
|
||||
|
||||
+4
-2
@@ -9,6 +9,8 @@ use std::iter::{Product, Sum};
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// Defines real number
|
||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
pub trait RealNumber:
|
||||
@@ -79,7 +81,7 @@ impl RealNumber for f64 {
|
||||
}
|
||||
|
||||
fn rand() -> f64 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rng = get_rng_impl(None);
|
||||
rng.gen()
|
||||
}
|
||||
|
||||
@@ -124,7 +126,7 @@ impl RealNumber for f32 {
|
||||
}
|
||||
|
||||
fn rand() -> f32 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rng = get_rng_impl(None);
|
||||
rng.gen()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::model_selection::BaseKFold;
|
||||
use crate::rand::get_rng_impl;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
/// K-Folds cross-validator
|
||||
pub struct KFold {
|
||||
@@ -14,6 +14,9 @@ pub struct KFold {
|
||||
pub n_splits: usize, // cannot exceed std::usize::MAX
|
||||
/// Whether to shuffle the data before splitting into batches
|
||||
pub shuffle: bool,
|
||||
/// When shuffle is True, seed affects the ordering of the indices.
|
||||
/// This controls the randomness of each fold.
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KFold {
|
||||
@@ -23,8 +26,10 @@ impl KFold {
|
||||
|
||||
// initialise indices
|
||||
let mut indices: Vec<usize> = (0..n_samples).collect();
|
||||
let mut rng = get_rng_impl(self.seed);
|
||||
|
||||
if self.shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
indices.shuffle(&mut rng);
|
||||
}
|
||||
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
|
||||
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
|
||||
@@ -66,6 +71,7 @@ impl Default for KFold {
|
||||
KFold {
|
||||
n_splits: 3,
|
||||
shuffle: true,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,6 +87,12 @@ impl KFold {
|
||||
self.shuffle = shuffle;
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the seed for the random number generator and returns the updated `KFold`.
/// When `shuffle` is true, `seed` affects the ordering of the indices.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
@@ -150,6 +162,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
@@ -165,6 +178,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
@@ -180,6 +194,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let test_masks = k.test_masks(&x);
|
||||
@@ -206,6 +221,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
||||
@@ -238,6 +254,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||
//!
|
||||
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
|
||||
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
|
||||
@@ -107,8 +107,8 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
pub(crate) mod kfold;
|
||||
|
||||
@@ -130,11 +130,13 @@ pub trait BaseKFold {
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
||||
/// * `shuffle`, - whether or not to shuffle the data before splitting
|
||||
/// * `seed` - Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls
|
||||
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
test_size: f32,
|
||||
shuffle: bool,
|
||||
seed: Option<u64>,
|
||||
) -> (M, M, M::RowVector, M::RowVector) {
|
||||
if x.shape().0 != y.len() {
|
||||
panic!(
|
||||
@@ -143,6 +145,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
y.len()
|
||||
);
|
||||
}
|
||||
let mut rng = get_rng_impl(seed);
|
||||
|
||||
if test_size <= 0. || test_size > 1.0 {
|
||||
panic!("test_size should be between 0 and 1");
|
||||
@@ -159,7 +162,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
let mut indices: Vec<usize> = (0..n).collect();
|
||||
|
||||
if shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
indices.shuffle(&mut rng);
|
||||
}
|
||||
|
||||
let x_train = x.take(&indices[n_test..n], 0);
|
||||
@@ -292,7 +295,7 @@ mod tests {
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
|
||||
let y = vec![0f64; n];
|
||||
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||
|
||||
assert!(
|
||||
x_train.shape().0 > (n as f64 * 0.65) as usize
|
||||
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
use ::rand::SeedableRng;
|
||||
#[cfg(not(feature = "std"))]
|
||||
use rand::rngs::SmallRng as RngImpl;
|
||||
#[cfg(feature = "std")]
|
||||
use rand::rngs::StdRng as RngImpl;
|
||||
|
||||
/// Builds the crate-wide random number generator, `RngImpl`
/// (`StdRng` when the `std` feature is enabled, `SmallRng` otherwise).
///
/// * `seed` - when `Some`, the generator is created with `seed_from_u64(seed)`.
///   When `None`, a build with the `std` feature takes a seed from
///   `rand::thread_rng()`; a build without `std` has no such source and panics.
///
/// # Panics
///
/// Panics with "seed number needed for non-std build" when `seed` is `None`
/// and the `std` feature is disabled.
pub(crate) fn get_rng_impl(seed: Option<u64>) -> RngImpl {
|
||||
match seed {
|
||||
Some(seed) => RngImpl::seed_from_u64(seed),
|
||||
None => {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "std")] {
|
||||
use rand::RngCore;
|
||||
RngImpl::seed_from_u64(rand::thread_rng().next_u64())
|
||||
} else {
|
||||
panic!("seed number needed for non-std build");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+17
-5
@@ -84,6 +84,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -100,6 +101,8 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
pub kernel: K,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||
seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// SVC grid search parameters
|
||||
@@ -279,8 +282,15 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
|
||||
tol: self.tol,
|
||||
kernel,
|
||||
m: PhantomData,
|
||||
seed: self.seed,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the seed for the pseudo random number generator and returns the updated parameters.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
||||
@@ -291,6 +301,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel>
|
||||
tol: T::from_f64(1e-3).unwrap(),
|
||||
kernel: Kernels::linear(),
|
||||
m: PhantomData,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let good_enough = T::from_i32(1000).unwrap();
|
||||
|
||||
for _ in 0..self.parameters.epoch {
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
@@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let mut cp = 0;
|
||||
let mut cn = 0;
|
||||
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
if self.y.get(i) == T::one() && cp < few {
|
||||
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
||||
cp += 1;
|
||||
@@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
self.recalculate_minmax_grad = true;
|
||||
}
|
||||
|
||||
fn permutate(n: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(self.parameters.seed);
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
@@ -893,7 +904,8 @@ mod tests {
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(Kernels::linear()),
|
||||
.with_kernel(Kernels::linear())
|
||||
.with_seed(Some(100)),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
@@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters {
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub min_samples_split: usize,
|
||||
/// Controls the randomness of the estimator
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// Decision Tree
|
||||
@@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -467,14 +471,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -483,7 +480,6 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
@@ -497,6 +493,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
)));
|
||||
}
|
||||
|
||||
let mut rng = get_rng_impl(parameters.seed);
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
@@ -531,13 +528,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -874,7 +871,8 @@ mod tests {
|
||||
criterion: SplitCriterion::Entropy,
|
||||
max_depth: Some(3),
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2
|
||||
min_samples_split: 2,
|
||||
seed: None
|
||||
}
|
||||
)
|
||||
.unwrap()
|
||||
|
||||
@@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters {
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub min_samples_split: usize,
|
||||
/// Controls the randomness of the estimator
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// Regression Tree
|
||||
@@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters {
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -357,14 +361,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -373,7 +370,6 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -381,6 +377,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
let (_, num_attributes) = x.shape();
|
||||
|
||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||
let mut rng = get_rng_impl(parameters.seed);
|
||||
|
||||
let mut n = 0;
|
||||
let mut sum = T::zero();
|
||||
@@ -407,13 +404,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -699,6 +696,7 @@ mod tests {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 2,
|
||||
min_samples_split: 6,
|
||||
seed: None,
|
||||
},
|
||||
)
|
||||
.and_then(|t| t.predict(&x))
|
||||
@@ -719,6 +717,7 @@ mod tests {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 3,
|
||||
seed: None,
|
||||
},
|
||||
)
|
||||
.and_then(|t| t.predict(&x))
|
||||
|
||||
Reference in New Issue
Block a user