Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests * feat: add seed parameter to multiple algorithms * Update changelog Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
This commit is contained in:
+17
-5
@@ -84,6 +84,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -100,6 +101,8 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
pub kernel: K,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||
seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// SVC grid search parameters
|
||||
@@ -279,8 +282,15 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
|
||||
tol: self.tol,
|
||||
kernel,
|
||||
m: PhantomData,
|
||||
seed: self.seed,
|
||||
}
|
||||
}
|
||||
|
||||
/// Seed for the pseudo random number generator.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
||||
@@ -291,6 +301,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel>
|
||||
tol: T::from_f64(1e-3).unwrap(),
|
||||
kernel: Kernels::linear(),
|
||||
m: PhantomData,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let good_enough = T::from_i32(1000).unwrap();
|
||||
|
||||
for _ in 0..self.parameters.epoch {
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
@@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let mut cp = 0;
|
||||
let mut cn = 0;
|
||||
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
if self.y.get(i) == T::one() && cp < few {
|
||||
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
||||
cp += 1;
|
||||
@@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
self.recalculate_minmax_grad = true;
|
||||
}
|
||||
|
||||
fn permutate(n: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(self.parameters.seed);
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
@@ -893,7 +904,8 @@ mod tests {
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(Kernels::linear()),
|
||||
.with_kernel(Kernels::linear())
|
||||
.with_seed(Some(100)),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user