fix: svr, post-review changes
This commit is contained in:
+23
-12
@@ -41,6 +41,14 @@
|
|||||||
//!
|
//!
|
||||||
//! let y_hat = svr.predict(&x).unwrap();
|
//! let y_hat = svr.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## References:
|
||||||
|
//!
|
||||||
|
//! * ["Support Vector Machines" Kowalczyk A., 2017](https://www.svm-tutorial.com/2017/10/support-vector-machines-succinctly-released/)
|
||||||
|
//! * ["A Fast Algorithm for Training Support Vector Machines", Platt J.C., 1998](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-98-14.pdf)
|
||||||
|
//! * ["Working Set Selection Using Second Order Information for Training Support Vector Machines", Rong-En Fan et al., 2005](https://www.jmlr.org/papers/volume6/fan05a/fan05a.pdf)
|
||||||
|
//! * ["A tutorial on support vector regression", Smola A.J., Schölkopf B., 2003](https://alex.smola.org/papers/2004/SmoSch04.pdf)
|
||||||
|
|
||||||
use std::cell::{Ref, RefCell};
|
use std::cell::{Ref, RefCell};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
@@ -87,6 +95,7 @@ struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
|||||||
k: T,
|
k: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sequential Minimal Optimization algorithm
|
||||||
struct Optimizer<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
struct Optimizer<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
tol: T,
|
tol: T,
|
||||||
c: T,
|
c: T,
|
||||||
@@ -135,7 +144,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let optimizer = Optimizer::optimize(x, y, &kernel, &parameters);
|
let optimizer = Optimizer::new(x, y, &kernel, &parameters);
|
||||||
|
|
||||||
let (support_vectors, weight, b) = optimizer.smo();
|
let (support_vectors, weight, b) = optimizer.smo();
|
||||||
|
|
||||||
@@ -209,7 +218,7 @@ impl<T: RealNumber, V: BaseVector<T>> SupportVector<T, V> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a, T, M, K> {
|
impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a, T, M, K> {
|
||||||
fn optimize(
|
fn new(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
kernel: &'a K,
|
kernel: &'a K,
|
||||||
@@ -244,7 +253,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn minmax(&mut self) {
|
fn find_min_max_gradient(&mut self) {
|
||||||
self.gmin = T::max_value();
|
self.gmin = T::max_value();
|
||||||
self.gmax = T::min_value();
|
self.gmax = T::min_value();
|
||||||
|
|
||||||
@@ -278,10 +287,14 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Solves the quadratic programming (QP) problem that arises during the training of a support-vector machine (SVM).
|
||||||
|
/// Returns:
|
||||||
|
/// * support vectors
|
||||||
|
/// * hyperplane parameters: w and b
|
||||||
fn smo(mut self) -> (Vec<M::RowVector>, Vec<T>, T) {
|
fn smo(mut self) -> (Vec<M::RowVector>, Vec<T>, T) {
|
||||||
let cache: Cache<T> = Cache::new(self.sv.len());
|
let cache: Cache<T> = Cache::new(self.sv.len());
|
||||||
|
|
||||||
self.minmax();
|
self.find_min_max_gradient();
|
||||||
|
|
||||||
while self.gmax - self.gmin > self.tol {
|
while self.gmax - self.gmin > self.tol {
|
||||||
let v1 = self.svmax;
|
let v1 = self.svmax;
|
||||||
@@ -417,22 +430,22 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
v.grad[1] += si * k1[v.index] * delta_alpha_i + sj * k2[v.index] * delta_alpha_j;
|
v.grad[1] += si * k1[v.index] * delta_alpha_i + sj * k2[v.index] * delta_alpha_j;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.minmax();
|
self.find_min_max_gradient();
|
||||||
}
|
}
|
||||||
|
|
||||||
let b = -(self.gmax + self.gmin) / T::two();
|
let b = -(self.gmax + self.gmin) / T::two();
|
||||||
|
|
||||||
let mut result: Vec<M::RowVector> = Vec::new();
|
let mut support_vectors: Vec<M::RowVector> = Vec::new();
|
||||||
let mut alpha: Vec<T> = Vec::new();
|
let mut w: Vec<T> = Vec::new();
|
||||||
|
|
||||||
for v in self.sv {
|
for v in self.sv {
|
||||||
if v.alpha[0] != v.alpha[1] {
|
if v.alpha[0] != v.alpha[1] {
|
||||||
result.push(v.x);
|
support_vectors.push(v.x);
|
||||||
alpha.push(v.alpha[1] - v.alpha[0]);
|
w.push(v.alpha[1] - v.alpha[0]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(result, alpha, b)
|
(support_vectors, w, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,8 +510,6 @@ mod tests {
|
|||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
println!("{:?}", y_hat);
|
|
||||||
|
|
||||||
assert!(mean_squared_error(&y_hat, &y) < 2.5);
|
assert!(mean_squared_error(&y_hat, &y) < 2.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user