From b4a807eb9f1121f85e274c9939a7acbca17cc9fc Mon Sep 17 00:00:00 2001 From: ferrouille <93612259+ferrouille@users.noreply.github.com> Date: Tue, 21 Jun 2022 18:48:16 +0200 Subject: [PATCH] Add SVC::decision_function (#135) --- src/svm/svc.rs | 67 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 10 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 7432b9c..74f31c7 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -263,21 +263,33 @@ impl, K: Kernel> SVC { /// Predicts estimated class labels from `x` /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. pub fn predict(&self, x: &M) -> Result { - let (n, _) = x.shape(); + let mut y_hat = self.decision_function(x)?; - let mut y_hat = M::RowVector::zeros(n); - - for i in 0..n { - let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() { + for i in 0..y_hat.len() { + let cls_idx = match y_hat.get(i) > T::zero() { false => self.classes[0], true => self.classes[1], }; + y_hat.set(i, cls_idx); } Ok(y_hat) } + /// Evaluates the decision function for the rows in `x` + /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. + pub fn decision_function(&self, x: &M) -> Result { + let (n, _) = x.shape(); + let mut y_hat = M::RowVector::zeros(n); + + for i in 0..n { + y_hat.set(i, self.predict_for_row(x.get_row(i))); + } + + Ok(y_hat) + } + fn predict_for_row(&self, x: M::RowVector) -> T { let mut f = self.b; @@ -285,11 +297,7 @@ impl, K: Kernel> SVC { f += self.w[i] * self.kernel.apply(&x, &self.instances[i]); } - if f > T::zero() { - T::one() - } else { - -T::one() - } + f } } @@ -772,6 +780,45 @@ mod tests { assert!(accuracy(&y_hat, &y) >= 0.9); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn svc_fit_decision_function() { + let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]); + + let x2 = DenseMatrix::from_2d_array(&[ + &[3.0, 3.0], + &[4.0, 4.0], + &[6.0, 6.0], + &[10.0, 10.0], + &[1.0, 1.0], + &[0.0, 0.0], + ]); + + let y: Vec = vec![0., 0., 1., 1.]; + + let y_hat = SVC::fit( + &x, + &y, + SVCParameters::default() + .with_c(200.0) + .with_kernel(Kernels::linear()), + ) + .and_then(|lr| lr.decision_function(&x2)) + .unwrap(); + + // x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0], + // so the score should increase as points get further away from that line + println!("{:?}", y_hat); + assert!(y_hat[1] < y_hat[2]); + assert!(y_hat[2] < y_hat[3]); + + // for negative scores the score should decrease + assert!(y_hat[4] > y_hat[5]); + + // y_hat[0] is on the line, so its score should be close to 0 + assert!(y_hat[0].abs() <= 0.1); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svc_fit_predict_rbf() {