fix: formatting

This commit is contained in:
Volodymyr Orlov
2020-10-28 17:23:40 -07:00
parent cf4f658f01
commit 797dc3c8e0
2 changed files with 8 additions and 6 deletions
+1 -1
View File
@@ -79,7 +79,7 @@ impl Kernels {
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct LinearKernel {} pub struct LinearKernel {}
/// Radial basis function (Gaussian) kernel /// Radial basis function (Gaussian) kernel
pub struct RBFKernel<T: RealNumber> { pub struct RBFKernel<T: RealNumber> {
/// kernel coefficient /// kernel coefficient
pub gamma: T, pub gamma: T,
+7 -5
View File
@@ -155,7 +155,8 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
if classes.len() != 2 { if classes.len() != 2 {
return Err(Failed::fit(&format!( return Err(Failed::fit(&format!(
"Incorrect number of classes {}", classes.len() "Incorrect number of classes {}",
classes.len()
))); )));
} }
@@ -166,7 +167,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
if y_v != -T::one() || y_v != T::one() { if y_v != -T::one() || y_v != T::one() {
match y_v == classes[0] { match y_v == classes[0] {
true => y.set(i, -T::one()), true => y.set(i, -T::one()),
false => y.set(i, T::one()) false => y.set(i, T::one()),
} }
} }
} }
@@ -194,7 +195,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
for i in 0..n { for i in 0..n {
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() { let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
false => self.classes[0], false => self.classes[0],
true => self.classes[1] true => self.classes[1],
}; };
y_hat.set(i, cls_idx); y_hat.set(i, cls_idx);
} }
@@ -720,7 +721,8 @@ mod tests {
]); ]);
let y: Vec<f64> = vec![ let y: Vec<f64> = vec![
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1.,
]; ];
let y_hat = SVC::fit( let y_hat = SVC::fit(
@@ -734,7 +736,7 @@ mod tests {
}, },
) )
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
.unwrap(); .unwrap();
assert!(accuracy(&y_hat, &y) >= 0.9); assert!(accuracy(&y_hat, &y) >= 0.9);
} }