diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 3f3c7eb..d98a0ab 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -24,6 +24,8 @@ //! pub mod svc; pub mod svr; +/// search parameters +pub mod search; use core::fmt::Debug; use std::marker::PhantomData; diff --git a/src/svm/search/mod.rs b/src/svm/search/mod.rs index e69de29..0d67cc4 100644 --- a/src/svm/search/mod.rs +++ b/src/svm/search/mod.rs @@ -0,0 +1,4 @@ +/// SVC search parameters +pub mod svc_params; +/// SVC search parameters +pub mod svr_params; \ No newline at end of file diff --git a/src/svm/search/svc_params.rs b/src/svm/search/svc_params.rs index 6f1de6a..e8c836c 100644 --- a/src/svm/search/svc_params.rs +++ b/src/svm/search/svc_params.rs @@ -1,184 +1,184 @@ -/// SVC grid search parameters -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] -pub struct SVCSearchParameters< - TX: Number + RealNumber, - TY: Number + Ord, - X: Array2, - Y: Array1, - K: Kernel, -> { - #[cfg_attr(feature = "serde", serde(default))] - /// Number of epochs. - pub epoch: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Regularization parameter. - pub c: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Tolerance for stopping epoch. - pub tol: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// The kernel function. - pub kernel: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Unused parameter. - m: PhantomData<(X, Y, TY)>, - #[cfg_attr(feature = "serde", serde(default))] - /// Controls the pseudo random number generation for shuffling the data for probability estimates - seed: Vec>, -} +// /// SVC grid search parameters +// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +// #[derive(Debug, Clone)] +// pub struct SVCSearchParameters< +// TX: Number + RealNumber, +// TY: Number + Ord, +// X: Array2, +// Y: Array1, +// K: Kernel, +// > { +// #[cfg_attr(feature = "serde", serde(default))] +// /// Number of epochs. +// pub epoch: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Regularization parameter. +// pub c: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Tolerance for stopping epoch. +// pub tol: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// The kernel function. +// pub kernel: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Unused parameter. +// m: PhantomData<(X, Y, TY)>, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Controls the pseudo random number generation for shuffling the data for probability estimates +// seed: Vec>, +// } -/// SVC grid search iterator -pub struct SVCSearchParametersIterator< - TX: Number + RealNumber, - TY: Number + Ord, - X: Array2, - Y: Array1, - K: Kernel, -> { - svc_search_parameters: SVCSearchParameters, - current_epoch: usize, - current_c: usize, - current_tol: usize, - current_kernel: usize, - current_seed: usize, -} +// /// SVC grid search iterator +// pub struct SVCSearchParametersIterator< +// TX: Number + RealNumber, +// TY: Number + Ord, +// X: Array2, +// Y: Array1, +// K: Kernel, +// > { +// svc_search_parameters: SVCSearchParameters, +// current_epoch: usize, +// current_c: usize, +// current_tol: usize, +// current_kernel: usize, +// current_seed: usize, +// } -impl, Y: Array1, K: Kernel> - IntoIterator for SVCSearchParameters -{ - type Item = SVCParameters<'a, TX, TY, X, Y>; - type IntoIter = SVCSearchParametersIterator; +// impl, Y: Array1, K: Kernel> +// IntoIterator for SVCSearchParameters +// { +// type Item = SVCParameters<'a, TX, TY, X, Y>; +// type IntoIter = SVCSearchParametersIterator; - fn into_iter(self) -> Self::IntoIter { - SVCSearchParametersIterator { - svc_search_parameters: self, - current_epoch: 0, - current_c: 0, - current_tol: 0, - current_kernel: 0, - current_seed: 0, - } - } -} +// fn into_iter(self) -> Self::IntoIter { +// SVCSearchParametersIterator { +// svc_search_parameters: self, +// current_epoch: 0, +// current_c: 0, +// current_tol: 0, +// current_kernel: 0, +// current_seed: 0, +// } +// } +// } -impl, Y: Array1, K: Kernel> - Iterator for SVCSearchParametersIterator -{ - type Item = SVCParameters; +// impl, Y: Array1, K: Kernel> +// Iterator for SVCSearchParametersIterator +// { +// type Item = SVCParameters; - fn next(&mut self) -> Option { - if self.current_epoch == self.svc_search_parameters.epoch.len() - && self.current_c == self.svc_search_parameters.c.len() - && self.current_tol == self.svc_search_parameters.tol.len() - && self.current_kernel == self.svc_search_parameters.kernel.len() - && self.current_seed == self.svc_search_parameters.seed.len() - { - return None; - } +// fn next(&mut self) -> Option { +// if self.current_epoch == self.svc_search_parameters.epoch.len() +// && self.current_c == self.svc_search_parameters.c.len() +// && self.current_tol == self.svc_search_parameters.tol.len() +// && self.current_kernel == self.svc_search_parameters.kernel.len() +// && self.current_seed == self.svc_search_parameters.seed.len() +// { +// return None; +// } - let next = SVCParameters { - epoch: self.svc_search_parameters.epoch[self.current_epoch], - c: self.svc_search_parameters.c[self.current_c], - tol: self.svc_search_parameters.tol[self.current_tol], - kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), - m: PhantomData, - seed: self.svc_search_parameters.seed[self.current_seed], - }; +// let next = SVCParameters { +// epoch: self.svc_search_parameters.epoch[self.current_epoch], +// c: self.svc_search_parameters.c[self.current_c], +// tol: self.svc_search_parameters.tol[self.current_tol], +// kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), +// m: PhantomData, +// seed: self.svc_search_parameters.seed[self.current_seed], +// }; - if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { - self.current_epoch += 1; - } else if self.current_c + 1 < self.svc_search_parameters.c.len() { - self.current_epoch = 0; - self.current_c += 1; - } else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol += 1; - } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol = 0; - self.current_kernel += 1; - } else if self.current_seed + 1 < self.svc_search_parameters.seed.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol = 0; - self.current_kernel = 0; - self.current_seed += 1; - } else { - self.current_epoch += 1; - self.current_c += 1; - self.current_tol += 1; - self.current_kernel += 1; - self.current_seed += 1; - } +// if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { +// self.current_epoch += 1; +// } else if self.current_c + 1 < self.svc_search_parameters.c.len() { +// self.current_epoch = 0; +// self.current_c += 1; +// } else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol += 1; +// } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol = 0; +// self.current_kernel += 1; +// } else if self.current_seed + 1 < self.svc_search_parameters.seed.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol = 0; +// self.current_kernel = 0; +// self.current_seed += 1; +// } else { +// self.current_epoch += 1; +// self.current_c += 1; +// self.current_tol += 1; +// self.current_kernel += 1; +// self.current_seed += 1; +// } - Some(next) - } -} +// Some(next) +// } +// } -impl, Y: Array1, K: Kernel> Default - for SVCSearchParameters -{ - fn default() -> Self { - let default_params: SVCParameters = SVCParameters::default(); +// impl, Y: Array1, K: Kernel> Default +// for SVCSearchParameters +// { +// fn default() -> Self { +// let default_params: SVCParameters = SVCParameters::default(); - SVCSearchParameters { - epoch: vec![default_params.epoch], - c: vec![default_params.c], - tol: vec![default_params.tol], - kernel: vec![default_params.kernel], - m: PhantomData, - seed: vec![default_params.seed], - } - } -} +// SVCSearchParameters { +// epoch: vec![default_params.epoch], +// c: vec![default_params.c], +// tol: vec![default_params.tol], +// kernel: vec![default_params.kernel], +// m: PhantomData, +// seed: vec![default_params.seed], +// } +// } +// } -#[cfg(test)] -mod tests { - use num::ToPrimitive; +// #[cfg(test)] +// mod tests { +// use num::ToPrimitive; - use super::*; - use crate::linalg::basic::matrix::DenseMatrix; - use crate::metrics::accuracy; - #[cfg(feature = "serde")] - use crate::svm::*; +// use super::*; +// use crate::linalg::basic::matrix::DenseMatrix; +// use crate::metrics::accuracy; +// #[cfg(feature = "serde")] +// use crate::svm::*; - #[test] - fn search_parameters() { - let parameters: SVCSearchParameters, LinearKernel> = - SVCSearchParameters { - epoch: vec![10, 100], - kernel: vec![LinearKernel {}], - ..Default::default() - }; - let mut iter = parameters.into_iter(); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 10); - assert_eq!(next.kernel, LinearKernel {}); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 100); - assert_eq!(next.kernel, LinearKernel {}); - assert!(iter.next().is_none()); - } +// #[test] +// fn search_parameters() { +// let parameters: SVCSearchParameters, LinearKernel> = +// SVCSearchParameters { +// epoch: vec![10, 100], +// kernel: vec![LinearKernel {}], +// ..Default::default() +// }; +// let mut iter = parameters.into_iter(); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 10); +// assert_eq!(next.kernel, LinearKernel {}); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 100); +// assert_eq!(next.kernel, LinearKernel {}); +// assert!(iter.next().is_none()); +// } - #[test] - fn search_parameters() { - let parameters: SVCSearchParameters, LinearKernel> = - SVCSearchParameters { - epoch: vec![10, 100], - kernel: vec![LinearKernel {}], - ..Default::default() - }; - let mut iter = parameters.into_iter(); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 10); - assert_eq!(next.kernel, LinearKernel {}); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 100); - assert_eq!(next.kernel, LinearKernel {}); - assert!(iter.next().is_none()); - } -} +// #[test] +// fn search_parameters() { +// let parameters: SVCSearchParameters, LinearKernel> = +// SVCSearchParameters { +// epoch: vec![10, 100], +// kernel: vec![LinearKernel {}], +// ..Default::default() +// }; +// let mut iter = parameters.into_iter(); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 10); +// assert_eq!(next.kernel, LinearKernel {}); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 100); +// assert_eq!(next.kernel, LinearKernel {}); +// assert!(iter.next().is_none()); +// } +// }