diff --git a/src/svm/svc.rs b/src/svm/svc.rs index b0bf8b0..829f729 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -468,83 +468,81 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, idx_2: Option, cache: &mut Cache, ) -> Option<(usize, usize, T)> { - let mut idx_1 = idx_1; - let mut idx_2 = idx_2; - - let mut k_v_12: Option = None; - - if idx_1.is_none() && idx_2.is_none() { - self.find_min_max_gradient(); - if self.gmax > -self.gmin { - idx_2 = Some(self.svmax); - } else { - idx_1 = Some(self.svmin); - } - } - - if idx_2.is_none() { - let idx_1 = &self.sv[idx_1.unwrap()]; - let km = idx_1.k; - let gm = idx_1.grad; - let mut best = T::zero(); - for i in 0..self.sv.len() { - let v = &self.sv[i]; - let z = v.grad - gm; - let k = cache.get(idx_1, &v); - let mut curv = km + v.k - T::two() * k; - if curv <= T::zero() { - curv = self.tau; - } - let mu = z / curv; - if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) { - let gain = z * mu; - if gain > best { - best = gain; - idx_2 = Some(i); - k_v_12 = Some(k); + + match (idx_1, idx_2) { + (None, None) => { + if self.gmax > -self.gmin { + self.select_pair(None, Some(self.svmax), cache) + } else { + self.select_pair(Some(self.svmin), None, cache) + } + }, + (Some(idx_1), None) => { + let sv1 = &self.sv[idx_1]; + let mut idx_2 = None; + let mut k_v_12 = None; + let km = sv1.k; + let gm = sv1.grad; + let mut best = T::zero(); + for i in 0..self.sv.len() { + let v = &self.sv[i]; + let z = v.grad - gm; + let k = cache.get(sv1, &v); + let mut curv = km + v.k - T::two() * k; + if curv <= T::zero() { + curv = self.tau; + } + let mu = z / curv; + if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) { + let gain = z * mu; + if gain > best { + best = gain; + idx_2 = Some(i); + k_v_12 = Some(k); + } } } - } - } - if idx_1.is_none() { - let idx_2 = &self.sv[idx_2.unwrap()]; - let km = idx_2.k; - let gm = idx_2.grad; - let mut best = T::zero(); - for i in 0..self.sv.len() { - let v = &self.sv[i]; - let z = gm - v.grad; - let k = cache.get(idx_2, v); - let mut curv = km + v.k - T::two() * k; - if curv <= T::zero() { - curv = self.tau; - } + idx_2.map(|idx_2| { + (idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x))) + }) + }, + (None, Some(idx_2)) => { + let mut idx_1 = None; + let sv2 = &self.sv[idx_2]; + let mut k_v_12 = None; + let km = sv2.k; + let gm = sv2.grad; + let mut best = T::zero(); + for i in 0..self.sv.len() { + let v = &self.sv[i]; + let z = gm - v.grad; + let k = cache.get(sv2, v); + let mut curv = km + v.k - T::two() * k; + if curv <= T::zero() { + curv = self.tau; + } - let mu = z / curv; - if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) { - let gain = z * mu; - if gain > best { - best = gain; - idx_1 = Some(i); - k_v_12 = Some(k); + let mu = z / curv; + if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) { + let gain = z * mu; + if gain > best { + best = gain; + idx_1 = Some(i); + k_v_12 = Some(k); + } } } + + idx_1.map(|idx_1| { + (idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x))) + }) + }, + (Some(idx_1), Some(idx_2)) => { + Some((idx_1, idx_2, self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x))) } } - - if idx_1.is_none() || idx_2.is_none() { - None - } else { - let idx_1 = idx_1.unwrap(); - let idx_2 = idx_2.unwrap(); - - if k_v_12.is_none() { - k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)); - } - - Some((idx_1, idx_2, k_v_12.unwrap())) - } + } fn smo(