fix: SVS: post-review changes
This commit is contained in:
+55
-45
@@ -462,13 +462,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
range
|
||||
}
|
||||
|
||||
fn smo(
|
||||
&mut self,
|
||||
idx_1: Option<usize>,
|
||||
idx_2: Option<usize>,
|
||||
tol: T,
|
||||
cache: &mut Cache<T, M, K>,
|
||||
) -> bool {
|
||||
fn select_pair(&mut self, idx_1: Option<usize>, idx_2: Option<usize>, cache: &mut Cache<T, M, K>) -> Option<(usize, usize, T)> {
|
||||
let mut idx_1 = idx_1;
|
||||
let mut idx_2 = idx_2;
|
||||
|
||||
@@ -532,51 +526,67 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if idx_1.is_none() || idx_2.is_none() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let idx_1 = idx_1.unwrap();
|
||||
let idx_2 = idx_2.unwrap();
|
||||
|
||||
if k_v_12.is_none() {
|
||||
k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x));
|
||||
}
|
||||
|
||||
let k_v_12 = k_v_12.unwrap();
|
||||
|
||||
let mut curv = self.sv[idx_1].k + self.sv[idx_2].k - T::two() * k_v_12;
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
let mut step = (self.sv[idx_2].grad - self.sv[idx_1].grad) / curv;
|
||||
|
||||
if step >= T::zero() {
|
||||
let mut ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmin;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_2].cmax - self.sv[idx_2].alpha;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
None
|
||||
} else {
|
||||
let mut ostep = self.sv[idx_2].cmin - self.sv[idx_2].alpha;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmax;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
|
||||
let idx_1 = idx_1.unwrap();
|
||||
let idx_2 = idx_2.unwrap();
|
||||
|
||||
if k_v_12.is_none() {
|
||||
k_v_12 = Some(self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x));
|
||||
}
|
||||
|
||||
Some((idx_1, idx_2, k_v_12.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
self.update(idx_1, idx_2, step, cache);
|
||||
fn smo(
|
||||
&mut self,
|
||||
idx_1: Option<usize>,
|
||||
idx_2: Option<usize>,
|
||||
tol: T,
|
||||
cache: &mut Cache<T, M, K>,
|
||||
) -> bool {
|
||||
|
||||
match self.select_pair(idx_1, idx_2, cache) {
|
||||
Some((idx_1, idx_2, k_v_12)) => {
|
||||
let mut curv = self.sv[idx_1].k + self.sv[idx_2].k - T::two() * k_v_12;
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
return self.gmax - self.gmin > tol;
|
||||
let mut step = (self.sv[idx_2].grad - self.sv[idx_1].grad) / curv;
|
||||
|
||||
if step >= T::zero() {
|
||||
let mut ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmin;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_2].cmax - self.sv[idx_2].alpha;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
} else {
|
||||
let mut ostep = self.sv[idx_2].cmin - self.sv[idx_2].alpha;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmax;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
}
|
||||
}
|
||||
|
||||
self.update(idx_1, idx_2, step, cache);
|
||||
|
||||
return self.gmax - self.gmin > tol;
|
||||
},
|
||||
None => false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fn update(&mut self, v1: usize, v2: usize, step: T, cache: &mut Cache<T, M, K>) {
|
||||
|
||||
Reference in New Issue
Block a user