Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests
* feat: add seed parameter to multiple algorithms
* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+12
-11
@@ -2,23 +2,24 @@ name: CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main, development ]
|
branches: [main, development]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ development ]
|
branches: [development]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
tests:
|
tests:
|
||||||
runs-on: "${{ matrix.platform.os }}-latest"
|
runs-on: "${{ matrix.platform.os }}-latest"
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
platform: [
|
platform:
|
||||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
[
|
||||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||||
]
|
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||||
|
]
|
||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
@@ -40,7 +41,7 @@ jobs:
|
|||||||
default: true
|
default: true
|
||||||
- name: Install test runner for wasm
|
- name: Install test runner for wasm
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||||
- name: Stable Build
|
- name: Stable Build
|
||||||
uses: actions-rs/cargo@v1
|
uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## Added
|
||||||
|
- Seeds to multiple algorithms that depend on random number generation.
|
||||||
|
- Added feature `js` to use WASM in browser
|
||||||
|
|
||||||
|
## BREAKING CHANGE
|
||||||
|
- Added a new parameter to `train_test_split` to define the seed.
|
||||||
|
|
||||||
|
## [0.2.1] - 2022-05-10
|
||||||
|
|
||||||
## Added
|
## Added
|
||||||
- L2 regularization penalty to the Logistic Regression
|
- L2 regularization penalty to the Logistic Regression
|
||||||
- Getters for the naive bayes structs
|
- Getters for the naive bayes structs
|
||||||
|
|||||||
+7
-3
@@ -16,21 +16,25 @@ categories = ["science"]
|
|||||||
default = ["datasets"]
|
default = ["datasets"]
|
||||||
ndarray-bindings = ["ndarray"]
|
ndarray-bindings = ["ndarray"]
|
||||||
nalgebra-bindings = ["nalgebra"]
|
nalgebra-bindings = ["nalgebra"]
|
||||||
datasets = ["rand_distr"]
|
datasets = ["rand_distr", "std"]
|
||||||
fp_bench = ["itertools"]
|
fp_bench = ["itertools"]
|
||||||
|
std = ["rand/std", "rand/std_rng"]
|
||||||
|
# wasm32 only
|
||||||
|
js = ["getrandom/js"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = { version = "0.15", optional = true }
|
ndarray = { version = "0.15", optional = true }
|
||||||
nalgebra = { version = "0.31", optional = true }
|
nalgebra = { version = "0.31", optional = true }
|
||||||
num-traits = "0.2"
|
num-traits = "0.2"
|
||||||
num = "0.4"
|
num = "0.4"
|
||||||
rand = "0.8"
|
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
|
||||||
rand_distr = { version = "0.4", optional = true }
|
rand_distr = { version = "0.4", optional = true }
|
||||||
serde = { version = "1", features = ["derive"], optional = true }
|
serde = { version = "1", features = ["derive"], optional = true }
|
||||||
itertools = { version = "0.10.3", optional = true }
|
itertools = { version = "0.10.3", optional = true }
|
||||||
|
cfg-if = "1.0.0"
|
||||||
|
|
||||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||||
getrandom = { version = "0.2", features = ["js"] }
|
getrandom = { version = "0.2", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
smartcore = { path = ".", features = ["fp_bench"] }
|
smartcore = { path = ".", features = ["fp_bench"] }
|
||||||
|
|||||||
@@ -52,10 +52,10 @@
|
|||||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||||
|
|
||||||
use rand::Rng;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
|
||||||
|
use ::rand::Rng;
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@@ -65,6 +65,7 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::distance::euclidian::*;
|
use crate::math::distance::euclidian::*;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
|
|
||||||
/// K-Means clustering algorithm
|
/// K-Means clustering algorithm
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -108,6 +109,9 @@ pub struct KMeansParameters {
|
|||||||
pub k: usize,
|
pub k: usize,
|
||||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
|
/// Determines random number generation for centroid initialization.
|
||||||
|
/// Use `Some(seed)` to make the randomness deterministic.
|
||||||
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KMeansParameters {
|
impl KMeansParameters {
|
||||||
@@ -128,6 +132,7 @@ impl Default for KMeansParameters {
|
|||||||
KMeansParameters {
|
KMeansParameters {
|
||||||
k: 2,
|
k: 2,
|
||||||
max_iter: 100,
|
max_iter: 100,
|
||||||
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -238,7 +243,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
let (n, d) = data.shape();
|
let (n, d) = data.shape();
|
||||||
|
|
||||||
let mut distortion = T::max_value();
|
let mut distortion = T::max_value();
|
||||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||||
let mut size = vec![0; parameters.k];
|
let mut size = vec![0; parameters.k];
|
||||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||||
|
|
||||||
@@ -311,8 +316,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
Ok(result.to_row_vector())
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = get_rng_impl(seed);
|
||||||
let (n, m) = data.shape();
|
let (n, m) = data.shape();
|
||||||
let mut y = vec![0; n];
|
let mut y = vec![0; n];
|
||||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||||
|
|||||||
@@ -45,8 +45,8 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use rand::rngs::StdRng;
|
use rand::Rng;
|
||||||
use rand::{Rng, SeedableRng};
|
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
@@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
|||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||||
};
|
};
|
||||||
@@ -441,7 +442,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||||
@@ -462,9 +463,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
|
seed: Some(parameters.seed),
|
||||||
};
|
};
|
||||||
let tree =
|
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,8 +43,8 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use rand::rngs::StdRng;
|
use rand::Rng;
|
||||||
use rand::{Rng, SeedableRng};
|
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
@@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
|||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
use crate::tree::decision_tree_regressor::{
|
use crate::tree::decision_tree_regressor::{
|
||||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||||
};
|
};
|
||||||
@@ -376,7 +377,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
.m
|
.m
|
||||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||||
|
|
||||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||||
|
|
||||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||||
@@ -393,9 +394,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
|
seed: Some(parameters.seed),
|
||||||
};
|
};
|
||||||
let tree =
|
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,3 +101,5 @@ pub mod readers;
|
|||||||
pub mod svm;
|
pub mod svm;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
|
||||||
|
pub(crate) mod rand;
|
||||||
|
|||||||
+4
-2
@@ -9,6 +9,8 @@ use std::iter::{Product, Sum};
|
|||||||
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
|
|
||||||
/// Defines real number
|
/// Defines real number
|
||||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||||
pub trait RealNumber:
|
pub trait RealNumber:
|
||||||
@@ -79,7 +81,7 @@ impl RealNumber for f64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f64 {
|
fn rand() -> f64 {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = get_rng_impl(None);
|
||||||
rng.gen()
|
rng.gen()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +126,7 @@ impl RealNumber for f32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f32 {
|
fn rand() -> f32 {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = get_rng_impl(None);
|
||||||
rng.gen()
|
rng.gen()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,8 @@
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::model_selection::BaseKFold;
|
use crate::model_selection::BaseKFold;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rand::thread_rng;
|
|
||||||
|
|
||||||
/// K-Folds cross-validator
|
/// K-Folds cross-validator
|
||||||
pub struct KFold {
|
pub struct KFold {
|
||||||
@@ -14,6 +14,9 @@ pub struct KFold {
|
|||||||
pub n_splits: usize, // cannot exceed std::usize::MAX
|
pub n_splits: usize, // cannot exceed std::usize::MAX
|
||||||
/// Whether to shuffle the data before splitting into batches
|
/// Whether to shuffle the data before splitting into batches
|
||||||
pub shuffle: bool,
|
pub shuffle: bool,
|
||||||
|
/// When shuffle is True, seed affects the ordering of the indices.
|
||||||
|
/// This controls the randomness of each fold.
|
||||||
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KFold {
|
impl KFold {
|
||||||
@@ -23,8 +26,10 @@ impl KFold {
|
|||||||
|
|
||||||
// initialise indices
|
// initialise indices
|
||||||
let mut indices: Vec<usize> = (0..n_samples).collect();
|
let mut indices: Vec<usize> = (0..n_samples).collect();
|
||||||
|
let mut rng = get_rng_impl(self.seed);
|
||||||
|
|
||||||
if self.shuffle {
|
if self.shuffle {
|
||||||
indices.shuffle(&mut thread_rng());
|
indices.shuffle(&mut rng);
|
||||||
}
|
}
|
||||||
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
|
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
|
||||||
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
|
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
|
||||||
@@ -66,6 +71,7 @@ impl Default for KFold {
|
|||||||
KFold {
|
KFold {
|
||||||
n_splits: 3,
|
n_splits: 3,
|
||||||
shuffle: true,
|
shuffle: true,
|
||||||
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,6 +87,12 @@ impl KFold {
|
|||||||
self.shuffle = shuffle;
|
self.shuffle = shuffle;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// When shuffle is true, `seed` affects the ordering of the indices.
|
||||||
|
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||||
|
self.seed = seed;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// An iterator over indices that split data into training and test set.
|
/// An iterator over indices that split data into training and test set.
|
||||||
@@ -150,6 +162,7 @@ mod tests {
|
|||||||
let k = KFold {
|
let k = KFold {
|
||||||
n_splits: 3,
|
n_splits: 3,
|
||||||
shuffle: false,
|
shuffle: false,
|
||||||
|
seed: None,
|
||||||
};
|
};
|
||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
|
||||||
let test_indices = k.test_indices(&x);
|
let test_indices = k.test_indices(&x);
|
||||||
@@ -165,6 +178,7 @@ mod tests {
|
|||||||
let k = KFold {
|
let k = KFold {
|
||||||
n_splits: 3,
|
n_splits: 3,
|
||||||
shuffle: false,
|
shuffle: false,
|
||||||
|
seed: None,
|
||||||
};
|
};
|
||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
|
||||||
let test_indices = k.test_indices(&x);
|
let test_indices = k.test_indices(&x);
|
||||||
@@ -180,6 +194,7 @@ mod tests {
|
|||||||
let k = KFold {
|
let k = KFold {
|
||||||
n_splits: 2,
|
n_splits: 2,
|
||||||
shuffle: false,
|
shuffle: false,
|
||||||
|
seed: None,
|
||||||
};
|
};
|
||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||||
let test_masks = k.test_masks(&x);
|
let test_masks = k.test_masks(&x);
|
||||||
@@ -206,6 +221,7 @@ mod tests {
|
|||||||
let k = KFold {
|
let k = KFold {
|
||||||
n_splits: 2,
|
n_splits: 2,
|
||||||
shuffle: false,
|
shuffle: false,
|
||||||
|
seed: None,
|
||||||
};
|
};
|
||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||||
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
||||||
@@ -238,6 +254,7 @@ mod tests {
|
|||||||
let k = KFold {
|
let k = KFold {
|
||||||
n_splits: 3,
|
n_splits: 3,
|
||||||
shuffle: false,
|
shuffle: false,
|
||||||
|
seed: None,
|
||||||
};
|
};
|
||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||||
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
||||||
|
|||||||
@@ -41,7 +41,7 @@
|
|||||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
//! ];
|
//! ];
|
||||||
//!
|
//!
|
||||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||||
//!
|
//!
|
||||||
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
|
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
|
||||||
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
|
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
|
||||||
@@ -107,8 +107,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rand::thread_rng;
|
|
||||||
|
|
||||||
pub(crate) mod kfold;
|
pub(crate) mod kfold;
|
||||||
|
|
||||||
@@ -130,11 +130,13 @@ pub trait BaseKFold {
|
|||||||
/// * `y` - target values, should be of size _N_
|
/// * `y` - target values, should be of size _N_
|
||||||
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
||||||
/// * `shuffle`, - whether or not to shuffle the data before splitting
|
/// * `shuffle`, - whether or not to shuffle the data before splitting
|
||||||
|
/// * `seed` - Controls the shuffling applied to the data before applying the split. Pass `Some(seed)` for reproducible output across multiple function calls
|
||||||
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
test_size: f32,
|
test_size: f32,
|
||||||
shuffle: bool,
|
shuffle: bool,
|
||||||
|
seed: Option<u64>,
|
||||||
) -> (M, M, M::RowVector, M::RowVector) {
|
) -> (M, M, M::RowVector, M::RowVector) {
|
||||||
if x.shape().0 != y.len() {
|
if x.shape().0 != y.len() {
|
||||||
panic!(
|
panic!(
|
||||||
@@ -143,6 +145,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
|||||||
y.len()
|
y.len()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
let mut rng = get_rng_impl(seed);
|
||||||
|
|
||||||
if test_size <= 0. || test_size > 1.0 {
|
if test_size <= 0. || test_size > 1.0 {
|
||||||
panic!("test_size should be between 0 and 1");
|
panic!("test_size should be between 0 and 1");
|
||||||
@@ -159,7 +162,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
|||||||
let mut indices: Vec<usize> = (0..n).collect();
|
let mut indices: Vec<usize> = (0..n).collect();
|
||||||
|
|
||||||
if shuffle {
|
if shuffle {
|
||||||
indices.shuffle(&mut thread_rng());
|
indices.shuffle(&mut rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
let x_train = x.take(&indices[n_test..n], 0);
|
let x_train = x.take(&indices[n_test..n], 0);
|
||||||
@@ -292,7 +295,7 @@ mod tests {
|
|||||||
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
|
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
|
||||||
let y = vec![0f64; n];
|
let y = vec![0f64; n];
|
||||||
|
|
||||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
x_train.shape().0 > (n as f64 * 0.65) as usize
|
x_train.shape().0 > (n as f64 * 0.65) as usize
|
||||||
|
|||||||
+21
@@ -0,0 +1,21 @@
|
|||||||
|
use ::rand::SeedableRng;
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
use rand::rngs::SmallRng as RngImpl;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
use rand::rngs::StdRng as RngImpl;
|
||||||
|
|
||||||
|
pub(crate) fn get_rng_impl(seed: Option<u64>) -> RngImpl {
|
||||||
|
match seed {
|
||||||
|
Some(seed) => RngImpl::seed_from_u64(seed),
|
||||||
|
None => {
|
||||||
|
cfg_if::cfg_if! {
|
||||||
|
if #[cfg(feature = "std")] {
|
||||||
|
use rand::RngCore;
|
||||||
|
RngImpl::seed_from_u64(rand::thread_rng().next_u64())
|
||||||
|
} else {
|
||||||
|
panic!("seed number needed for non-std build");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
+17
-5
@@ -84,6 +84,7 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -100,6 +101,8 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
pub kernel: K,
|
pub kernel: K,
|
||||||
/// Unused parameter.
|
/// Unused parameter.
|
||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
|
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||||
|
seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SVC grid search parameters
|
/// SVC grid search parameters
|
||||||
@@ -279,8 +282,15 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
|
|||||||
tol: self.tol,
|
tol: self.tol,
|
||||||
kernel,
|
kernel,
|
||||||
m: PhantomData,
|
m: PhantomData,
|
||||||
|
seed: self.seed,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Seed for the pseudo random number generator.
|
||||||
|
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||||
|
self.seed = seed;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
||||||
@@ -291,6 +301,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel>
|
|||||||
tol: T::from_f64(1e-3).unwrap(),
|
tol: T::from_f64(1e-3).unwrap(),
|
||||||
kernel: Kernels::linear(),
|
kernel: Kernels::linear(),
|
||||||
m: PhantomData,
|
m: PhantomData,
|
||||||
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
let good_enough = T::from_i32(1000).unwrap();
|
let good_enough = T::from_i32(1000).unwrap();
|
||||||
|
|
||||||
for _ in 0..self.parameters.epoch {
|
for _ in 0..self.parameters.epoch {
|
||||||
for i in Self::permutate(n) {
|
for i in self.permutate(n) {
|
||||||
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
||||||
loop {
|
loop {
|
||||||
self.reprocess(tol, &mut cache);
|
self.reprocess(tol, &mut cache);
|
||||||
@@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
let mut cp = 0;
|
let mut cp = 0;
|
||||||
let mut cn = 0;
|
let mut cn = 0;
|
||||||
|
|
||||||
for i in Self::permutate(n) {
|
for i in self.permutate(n) {
|
||||||
if self.y.get(i) == T::one() && cp < few {
|
if self.y.get(i) == T::one() && cp < few {
|
||||||
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
||||||
cp += 1;
|
cp += 1;
|
||||||
@@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
self.recalculate_minmax_grad = true;
|
self.recalculate_minmax_grad = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn permutate(n: usize) -> Vec<usize> {
|
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = get_rng_impl(self.parameters.seed);
|
||||||
let mut range: Vec<usize> = (0..n).collect();
|
let mut range: Vec<usize> = (0..n).collect();
|
||||||
range.shuffle(&mut rng);
|
range.shuffle(&mut rng);
|
||||||
range
|
range
|
||||||
@@ -893,7 +904,8 @@ mod tests {
|
|||||||
&y,
|
&y,
|
||||||
SVCParameters::default()
|
SVCParameters::default()
|
||||||
.with_c(200.0)
|
.with_c(200.0)
|
||||||
.with_kernel(Kernels::linear()),
|
.with_kernel(Kernels::linear())
|
||||||
|
.with_seed(Some(100)),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
|||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters {
|
|||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
/// The minimum number of samples required to split an internal node.
|
/// The minimum number of samples required to split an internal node.
|
||||||
pub min_samples_split: usize,
|
pub min_samples_split: usize,
|
||||||
|
/// Controls the randomness of the estimator
|
||||||
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decision Tree
|
/// Decision Tree
|
||||||
@@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters {
|
|||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -467,14 +471,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeClassifier::fit_weak_learner(
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
x,
|
|
||||||
y,
|
|
||||||
samples,
|
|
||||||
num_attributes,
|
|
||||||
parameters,
|
|
||||||
&mut rand::thread_rng(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||||
@@ -483,7 +480,6 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
rng: &mut impl Rng,
|
|
||||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
@@ -497,6 +493,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mut rng = get_rng_impl(parameters.seed);
|
||||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||||
|
|
||||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||||
@@ -531,13 +528,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
|
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||||
visitor_queue.push_back(visitor);
|
visitor_queue.push_back(visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||||
None => break,
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -874,7 +871,8 @@ mod tests {
|
|||||||
criterion: SplitCriterion::Entropy,
|
criterion: SplitCriterion::Entropy,
|
||||||
max_depth: Some(3),
|
max_depth: Some(3),
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2
|
min_samples_split: 2,
|
||||||
|
seed: None
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
|||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::rand::get_rng_impl;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters {
|
|||||||
pub min_samples_leaf: usize,
|
pub min_samples_leaf: usize,
|
||||||
/// The minimum number of samples required to split an internal node.
|
/// The minimum number of samples required to split an internal node.
|
||||||
pub min_samples_split: usize,
|
pub min_samples_split: usize,
|
||||||
|
/// Controls the randomness of the estimator
|
||||||
|
pub seed: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Regression Tree
|
/// Regression Tree
|
||||||
@@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters {
|
|||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
|
seed: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -357,14 +361,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeRegressor::fit_weak_learner(
|
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
x,
|
|
||||||
y,
|
|
||||||
samples,
|
|
||||||
num_attributes,
|
|
||||||
parameters,
|
|
||||||
&mut rand::thread_rng(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||||
@@ -373,7 +370,6 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
rng: &mut impl Rng,
|
|
||||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
@@ -381,6 +377,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
|
|
||||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||||
|
let mut rng = get_rng_impl(parameters.seed);
|
||||||
|
|
||||||
let mut n = 0;
|
let mut n = 0;
|
||||||
let mut sum = T::zero();
|
let mut sum = T::zero();
|
||||||
@@ -407,13 +404,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
|
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||||
visitor_queue.push_back(visitor);
|
visitor_queue.push_back(visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||||
None => break,
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -699,6 +696,7 @@ mod tests {
|
|||||||
max_depth: Option::None,
|
max_depth: Option::None,
|
||||||
min_samples_leaf: 2,
|
min_samples_leaf: 2,
|
||||||
min_samples_split: 6,
|
min_samples_split: 6,
|
||||||
|
seed: None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|t| t.predict(&x))
|
.and_then(|t| t.predict(&x))
|
||||||
@@ -719,6 +717,7 @@ mod tests {
|
|||||||
max_depth: Option::None,
|
max_depth: Option::None,
|
||||||
min_samples_leaf: 1,
|
min_samples_leaf: 1,
|
||||||
min_samples_split: 3,
|
min_samples_split: 3,
|
||||||
|
seed: None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|t| t.predict(&x))
|
.and_then(|t| t.predict(&x))
|
||||||
|
|||||||
Reference in New Issue
Block a user