diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs
index 116d700..2df9b87 100644
--- a/src/linear/logistic_regression.rs
+++ b/src/linear/logistic_regression.rs
@@ -52,6 +52,7 @@
//!
//!
//!
+use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
@@ -232,51 +233,53 @@ impl> LogisticRegression {
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
}
- if k < 2 {
- Err(Failed::fit(&format!(
+ match k.cmp(&2) {
+ Ordering::Less => Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.",
k
- )))
- } else if k == 2 {
- let x0 = M::zeros(1, num_attributes + 1);
+ ))),
+ Ordering::Equal => {
+ let x0 = M::zeros(1, num_attributes + 1);
- let objective = BinaryObjectiveFunction {
- x: x,
- y: yi,
- phantom: PhantomData,
- };
+ let objective = BinaryObjectiveFunction {
+ x: x,
+ y: yi,
+ phantom: PhantomData,
+ };
- let result = LogisticRegression::minimize(x0, objective);
- let weights = result.x;
+ let result = LogisticRegression::minimize(x0, objective);
+ let weights = result.x;
- Ok(LogisticRegression {
- coefficients: weights.slice(0..1, 0..num_attributes),
- intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
- classes: classes,
- num_attributes: num_attributes,
- num_classes: k,
- })
- } else {
- let x0 = M::zeros(1, (num_attributes + 1) * k);
+ Ok(LogisticRegression {
+ coefficients: weights.slice(0..1, 0..num_attributes),
+ intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
+ classes: classes,
+ num_attributes: num_attributes,
+ num_classes: k,
+ })
+ }
+ Ordering::Greater => {
+ let x0 = M::zeros(1, (num_attributes + 1) * k);
- let objective = MultiClassObjectiveFunction {
- x: x,
- y: yi,
- k: k,
- phantom: PhantomData,
- };
+ let objective = MultiClassObjectiveFunction {
+ x: x,
+ y: yi,
+ k: k,
+ phantom: PhantomData,
+ };
- let result = LogisticRegression::minimize(x0, objective);
+ let result = LogisticRegression::minimize(x0, objective);
- let weights = result.x.reshape(k, num_attributes + 1);
+ let weights = result.x.reshape(k, num_attributes + 1);
- Ok(LogisticRegression {
- coefficients: weights.slice(0..k, 0..num_attributes),
- intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
- classes: classes,
- num_attributes: num_attributes,
- num_classes: k,
- })
+ Ok(LogisticRegression {
+ coefficients: weights.slice(0..k, 0..num_attributes),
+ intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
+ classes: classes,
+ num_attributes: num_attributes,
+ num_classes: k,
+ })
+ }
}
}
@@ -286,7 +289,6 @@ impl> LogisticRegression {
let n = x.shape().0;
let mut result = M::zeros(1, n);
if self.num_classes == 2 {
- let (nrows, _) = x.shape();
let y_hat: Vec = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0);
let intercept = self.intercept.get(0, 0);
for i in 0..n {