add seed param to search params (#168)

This commit is contained in:
Montana Low
2022-09-21 16:15:26 -07:00
committed by GitHub
parent 3a44161406
commit 403d3f2348
4 changed files with 61 additions and 0 deletions
+14
View File
@@ -119,6 +119,8 @@ pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowV
pub kernel: Vec<K>,
/// Unused parameter.
m: PhantomData<M>,
/// Controls the pseudo random number generation for shuffling the data for probability estimates
seed: Vec<Option<u64>>,
}
/// SVC grid search iterator
@@ -128,6 +130,7 @@ pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T,
current_c: usize,
current_tol: usize,
current_kernel: usize,
current_seed: usize,
}
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
@@ -143,6 +146,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
current_c: 0,
current_tol: 0,
current_kernel: 0,
current_seed: 0,
}
}
}
@@ -157,6 +161,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
&& self.current_c == self.svc_search_parameters.c.len()
&& self.current_tol == self.svc_search_parameters.tol.len()
&& self.current_kernel == self.svc_search_parameters.kernel.len()
&& self.current_seed == self.svc_search_parameters.kernel.len()
{
return None;
}
@@ -167,6 +172,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
tol: self.svc_search_parameters.tol[self.current_tol],
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
m: PhantomData,
seed: self.svc_search_parameters.seed[self.current_seed],
};
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
@@ -183,11 +189,18 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
self.current_c = 0;
self.current_tol = 0;
self.current_kernel += 1;
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
self.current_epoch = 0;
self.current_c = 0;
self.current_tol = 0;
self.current_kernel = 0;
self.current_seed += 1;
} else {
self.current_epoch += 1;
self.current_c += 1;
self.current_tol += 1;
self.current_kernel += 1;
self.current_seed += 1;
}
Some(next)
@@ -204,6 +217,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKe
tol: vec![default_params.tol],
kernel: vec![default_params.kernel],
m: PhantomData,
seed: vec![default_params.seed],
}
}
}