Add SVC::decision_function (#135)
This commit is contained in:
+57
-10
@@ -263,21 +263,33 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
|||||||
/// Predicts estimated class labels from `x`
|
/// Predicts estimated class labels from `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let (n, _) = x.shape();
|
let mut y_hat = self.decision_function(x)?;
|
||||||
|
|
||||||
let mut y_hat = M::RowVector::zeros(n);
|
for i in 0..y_hat.len() {
|
||||||
|
let cls_idx = match y_hat.get(i) > T::zero() {
|
||||||
for i in 0..n {
|
|
||||||
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
|
|
||||||
false => self.classes[0],
|
false => self.classes[0],
|
||||||
true => self.classes[1],
|
true => self.classes[1],
|
||||||
};
|
};
|
||||||
|
|
||||||
y_hat.set(i, cls_idx);
|
y_hat.set(i, cls_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(y_hat)
|
Ok(y_hat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Evaluates the decision function for the rows in `x`
|
||||||
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
|
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
let mut y_hat = M::RowVector::zeros(n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(y_hat)
|
||||||
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: M::RowVector) -> T {
|
fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||||
let mut f = self.b;
|
let mut f = self.b;
|
||||||
|
|
||||||
@@ -285,11 +297,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
|||||||
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if f > T::zero() {
|
f
|
||||||
T::one()
|
|
||||||
} else {
|
|
||||||
-T::one()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -772,6 +780,45 @@ mod tests {
|
|||||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn svc_fit_decision_function() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
|
||||||
|
|
||||||
|
let x2 = DenseMatrix::from_2d_array(&[
|
||||||
|
&[3.0, 3.0],
|
||||||
|
&[4.0, 4.0],
|
||||||
|
&[6.0, 6.0],
|
||||||
|
&[10.0, 10.0],
|
||||||
|
&[1.0, 1.0],
|
||||||
|
&[0.0, 0.0],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let y: Vec<f64> = vec![0., 0., 1., 1.];
|
||||||
|
|
||||||
|
let y_hat = SVC::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
SVCParameters::default()
|
||||||
|
.with_c(200.0)
|
||||||
|
.with_kernel(Kernels::linear()),
|
||||||
|
)
|
||||||
|
.and_then(|lr| lr.decision_function(&x2))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
|
||||||
|
// so the score should increase as points get further away from that line
|
||||||
|
println!("{:?}", y_hat);
|
||||||
|
assert!(y_hat[1] < y_hat[2]);
|
||||||
|
assert!(y_hat[2] < y_hat[3]);
|
||||||
|
|
||||||
|
// for negative scores the score should decrease
|
||||||
|
assert!(y_hat[4] > y_hat[5]);
|
||||||
|
|
||||||
|
// y_hat[0] is on the line, so its score should be close to 0
|
||||||
|
assert!(y_hat[0].abs() <= 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svc_fit_predict_rbf() {
|
fn svc_fit_predict_rbf() {
|
||||||
|
|||||||
Reference in New Issue
Block a user