Add SVC::decision_function (#135)

This commit is contained in:
ferrouille
2022-06-21 18:48:16 +02:00
committed by GitHub
parent ff456df0a4
commit b4a807eb9f
+57 -10
View File
@@ -263,21 +263,33 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
/// Predicts estimated class labels from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
let mut y_hat = self.decision_function(x)?;
let mut y_hat = M::RowVector::zeros(n);
for i in 0..n {
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
for i in 0..y_hat.len() {
let cls_idx = match y_hat.get(i) > T::zero() {
false => self.classes[0],
true => self.classes[1],
};
y_hat.set(i, cls_idx);
}
Ok(y_hat)
}
/// Evaluates the decision function for the rows in `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
let mut y_hat = M::RowVector::zeros(n);
for i in 0..n {
y_hat.set(i, self.predict_for_row(x.get_row(i)));
}
Ok(y_hat)
}
fn predict_for_row(&self, x: M::RowVector) -> T {
let mut f = self.b;
@@ -285,11 +297,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
}
if f > T::zero() {
T::one()
} else {
-T::one()
}
f
}
}
@@ -772,6 +780,45 @@ mod tests {
assert!(accuracy(&y_hat, &y) >= 0.9);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svc_fit_decision_function() {
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
let x2 = DenseMatrix::from_2d_array(&[
&[3.0, 3.0],
&[4.0, 4.0],
&[6.0, 6.0],
&[10.0, 10.0],
&[1.0, 1.0],
&[0.0, 0.0],
]);
let y: Vec<f64> = vec![0., 0., 1., 1.];
let y_hat = SVC::fit(
&x,
&y,
SVCParameters::default()
.with_c(200.0)
.with_kernel(Kernels::linear()),
)
.and_then(|lr| lr.decision_function(&x2))
.unwrap();
// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
// so the score should increase as points get further away from that line
println!("{:?}", y_hat);
assert!(y_hat[1] < y_hat[2]);
assert!(y_hat[2] < y_hat[3]);
// for negative scores the score should decrease
assert!(y_hat[4] > y_hat[5]);
// y_hat[0] is on the line, so its score should be close to 0
assert!(y_hat[0].abs() <= 0.1);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svc_fit_predict_rbf() {