feat: adds 3 more SVM kernels, linalg refactoring
This commit is contained in:
+81
-8
@@ -5,7 +5,7 @@
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linear::linear_regression::*;
|
||||
//! use smartcore::svm::LinearKernel;
|
||||
//! use smartcore::svm::Kernels;
|
||||
//! use smartcore::svm::svc::{SVC, SVCParameters};
|
||||
//!
|
||||
//! // Iris dataset
|
||||
@@ -31,11 +31,11 @@
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y = vec![ -1., -1., -1., -1., -1., -1., -1., -1.,
|
||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
//!
|
||||
//! let svr = SVC::fit(&x, &y,
|
||||
//! LinearKernel {},
|
||||
//! Kernels::linear(),
|
||||
//! SVCParameters {
|
||||
//! epoch: 2,
|
||||
//! c: 200.0,
|
||||
@@ -83,6 +83,7 @@ pub struct SVCParameters<T: RealNumber> {
|
||||
))]
|
||||
/// Support Vector Classifier
|
||||
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
classes: Vec<T>,
|
||||
kernel: K,
|
||||
instances: Vec<M::RowVector>,
|
||||
w: Vec<T>,
|
||||
@@ -150,11 +151,32 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
||||
)));
|
||||
}
|
||||
|
||||
let optimizer = Optimizer::new(x, y, &kernel, ¶meters);
|
||||
let classes = y.unique();
|
||||
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes {}", classes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Make sure class labels are either 1 or -1
|
||||
let mut y = y.clone();
|
||||
for i in 0..y.len() {
|
||||
let y_v = y.get(i);
|
||||
if y_v != -T::one() || y_v != T::one() {
|
||||
match y_v == classes[0] {
|
||||
true => y.set(i, -T::one()),
|
||||
false => y.set(i, T::one())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let optimizer = Optimizer::new(x, &y, &kernel, ¶meters);
|
||||
|
||||
let (support_vectors, weight, b) = optimizer.optimize();
|
||||
|
||||
Ok(SVC {
|
||||
classes: classes,
|
||||
kernel: kernel,
|
||||
instances: support_vectors,
|
||||
w: weight,
|
||||
@@ -170,7 +192,11 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
||||
let mut y_hat = M::RowVector::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
|
||||
false => self.classes[0],
|
||||
true => self.classes[1]
|
||||
};
|
||||
y_hat.set(i, cls_idx);
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
@@ -647,13 +673,13 @@ mod tests {
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let y_hat = SVC::fit(
|
||||
&x,
|
||||
&y,
|
||||
LinearKernel {},
|
||||
Kernels::linear(),
|
||||
SVCParameters {
|
||||
epoch: 2,
|
||||
c: 200.0,
|
||||
@@ -663,6 +689,53 @@ mod tests {
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", y_hat);
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn svc_fit_predict_rbf() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let y_hat = SVC::fit(
|
||||
&x,
|
||||
&y,
|
||||
Kernels::rbf(0.7),
|
||||
SVCParameters {
|
||||
epoch: 2,
|
||||
c: 1.0,
|
||||
tol: 1e-3,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
}
|
||||
|
||||
@@ -695,7 +768,7 @@ mod tests {
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let svr = SVC::fit(&x, &y, LinearKernel {}, Default::default()).unwrap();
|
||||
let svr = SVC::fit(&x, &y, Kernels::linear(), Default::default()).unwrap();
|
||||
|
||||
let deserialized_svr: SVC<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user