From 47abbbe8b63d4e48ef4d36418e03ea033e9f8f72 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Mon, 26 Oct 2020 16:00:31 -0700 Subject: [PATCH] fix: SVS: post-review changes --- src/svm/svc.rs | 100 +++++++++++++++++++++++++++---------------------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index a3fbb8a..6e79177 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -462,13 +462,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, range } - fn smo( - &mut self, - idx_1: Option, - idx_2: Option, - tol: T, - cache: &mut Cache, - ) -> bool { + fn select_pair(&mut self, idx_1: Option, idx_2: Option, cache: &mut Cache) -> Option<(usize, usize, T)> { let mut idx_1 = idx_1; let mut idx_2 = idx_2; @@ -532,51 +526,67 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, } } } - } + } if idx_1.is_none() || idx_2.is_none() { - return false; - } - - let idx_1 = idx_1.unwrap(); - let idx_2 = idx_2.unwrap(); - - if k_v_12.is_none() { - k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)); - } - - let k_v_12 = k_v_12.unwrap(); - - let mut curv = self.sv[idx_1].k + self.sv[idx_2].k - T::two() * k_v_12; - if curv <= T::zero() { - curv = self.tau; - } - - let mut step = (self.sv[idx_2].grad - self.sv[idx_1].grad) / curv; - - if step >= T::zero() { - let mut ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmin; - if ostep < step { - step = ostep; - } - ostep = self.sv[idx_2].cmax - self.sv[idx_2].alpha; - if ostep < step { - step = ostep; - } + None } else { - let mut ostep = self.sv[idx_2].cmin - self.sv[idx_2].alpha; - if ostep > step { - step = ostep; - } - ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmax; - if ostep > step { - step = ostep; + + let idx_1 = idx_1.unwrap(); + let idx_2 = idx_2.unwrap(); + + if k_v_12.is_none() { + k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)); } + + Some((idx_1, idx_2, k_v_12.unwrap())) } + } - self.update(idx_1, idx_2, step, cache); + fn smo( + &mut self, + idx_1: Option, + idx_2: Option, + tol: T, + cache: &mut Cache, + ) -> bool { + + match self.select_pair(idx_1, idx_2, cache) { + Some((idx_1, idx_2, k_v_12)) => { + let mut curv = self.sv[idx_1].k + self.sv[idx_2].k - T::two() * k_v_12; + if curv <= T::zero() { + curv = self.tau; + } - return self.gmax - self.gmin > tol; + let mut step = (self.sv[idx_2].grad - self.sv[idx_1].grad) / curv; + + if step >= T::zero() { + let mut ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmin; + if ostep < step { + step = ostep; + } + ostep = self.sv[idx_2].cmax - self.sv[idx_2].alpha; + if ostep < step { + step = ostep; + } + } else { + let mut ostep = self.sv[idx_2].cmin - self.sv[idx_2].alpha; + if ostep > step { + step = ostep; + } + ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmax; + if ostep > step { + step = ostep; + } + } + + self.update(idx_1, idx_2, step, cache); + + return self.gmax - self.gmin > tol; + }, + None => false + } + } fn update(&mut self, v1: usize, v2: usize, step: T, cache: &mut Cache) {