fix: SVC: some more post-review refactoring

This commit is contained in:
Volodymyr Orlov
2020-10-26 16:27:26 -07:00
parent aa38fc8b70
commit bf8d0c081f
+64 -66
View File
@@ -468,83 +468,81 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
idx_2: Option<usize>, idx_2: Option<usize>,
cache: &mut Cache<T, M, K>, cache: &mut Cache<T, M, K>,
) -> Option<(usize, usize, T)> { ) -> Option<(usize, usize, T)> {
let mut idx_1 = idx_1;
let mut idx_2 = idx_2;
let mut k_v_12: Option<T> = None; match (idx_1, idx_2) {
(None, None) => {
if idx_1.is_none() && idx_2.is_none() { if self.gmax > -self.gmin {
self.find_min_max_gradient(); self.select_pair(None, Some(self.svmax), cache)
if self.gmax > -self.gmin { } else {
idx_2 = Some(self.svmax); self.select_pair(Some(self.svmin), None, cache)
} else {
idx_1 = Some(self.svmin);
}
}
if idx_2.is_none() {
let idx_1 = &self.sv[idx_1.unwrap()];
let km = idx_1.k;
let gm = idx_1.grad;
let mut best = T::zero();
for i in 0..self.sv.len() {
let v = &self.sv[i];
let z = v.grad - gm;
let k = cache.get(idx_1, &v);
let mut curv = km + v.k - T::two() * k;
if curv <= T::zero() {
curv = self.tau;
} }
let mu = z / curv; },
if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) { (Some(idx_1), None) => {
let gain = z * mu; let sv1 = &self.sv[idx_1];
if gain > best { let mut idx_2 = None;
best = gain; let mut k_v_12 = None;
idx_2 = Some(i); let km = sv1.k;
k_v_12 = Some(k); let gm = sv1.grad;
let mut best = T::zero();
for i in 0..self.sv.len() {
let v = &self.sv[i];
let z = v.grad - gm;
let k = cache.get(sv1, &v);
let mut curv = km + v.k - T::two() * k;
if curv <= T::zero() {
curv = self.tau;
}
let mu = z / curv;
if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin) {
let gain = z * mu;
if gain > best {
best = gain;
idx_2 = Some(i);
k_v_12 = Some(k);
}
} }
} }
}
}
if idx_1.is_none() { idx_2.map(|idx_2| {
let idx_2 = &self.sv[idx_2.unwrap()]; (idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
let km = idx_2.k; })
let gm = idx_2.grad; },
let mut best = T::zero(); (None, Some(idx_2)) => {
for i in 0..self.sv.len() { let mut idx_1 = None;
let v = &self.sv[i]; let sv2 = &self.sv[idx_2];
let z = gm - v.grad; let mut k_v_12 = None;
let k = cache.get(idx_2, v); let km = sv2.k;
let mut curv = km + v.k - T::two() * k; let gm = sv2.grad;
if curv <= T::zero() { let mut best = T::zero();
curv = self.tau; for i in 0..self.sv.len() {
} let v = &self.sv[i];
let z = gm - v.grad;
let k = cache.get(sv2, v);
let mut curv = km + v.k - T::two() * k;
if curv <= T::zero() {
curv = self.tau;
}
let mu = z / curv; let mu = z / curv;
if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) { if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax) {
let gain = z * mu; let gain = z * mu;
if gain > best { if gain > best {
best = gain; best = gain;
idx_1 = Some(i); idx_1 = Some(i);
k_v_12 = Some(k); k_v_12 = Some(k);
}
} }
} }
idx_1.map(|idx_1| {
(idx_1, idx_2, k_v_12.unwrap_or(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
})
},
(Some(idx_1), Some(idx_2)) => {
Some((idx_1, idx_2, self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)))
} }
} }
if idx_1.is_none() || idx_2.is_none() {
None
} else {
let idx_1 = idx_1.unwrap();
let idx_2 = idx_2.unwrap();
if k_v_12.is_none() {
k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x));
}
Some((idx_1, idx_2, k_v_12.unwrap()))
}
} }
fn smo( fn smo(