diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 116d700..2df9b87 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -52,6 +52,7 @@ //! //! //! +use std::cmp::Ordering; use std::fmt::Debug; use std::marker::PhantomData; @@ -232,51 +233,53 @@ impl> LogisticRegression { yi[i] = classes.iter().position(|c| yc == *c).unwrap(); } - if k < 2 { - Err(Failed::fit(&format!( + match k.cmp(&2) { + Ordering::Less => Err(Failed::fit(&format!( "incorrect number of classes: {}. Should be >= 2.", k - ))) - } else if k == 2 { - let x0 = M::zeros(1, num_attributes + 1); + ))), + Ordering::Equal => { + let x0 = M::zeros(1, num_attributes + 1); - let objective = BinaryObjectiveFunction { - x: x, - y: yi, - phantom: PhantomData, - }; + let objective = BinaryObjectiveFunction { + x: x, + y: yi, + phantom: PhantomData, + }; - let result = LogisticRegression::minimize(x0, objective); - let weights = result.x; + let result = LogisticRegression::minimize(x0, objective); + let weights = result.x; - Ok(LogisticRegression { - coefficients: weights.slice(0..1, 0..num_attributes), - intercept: weights.slice(0..1, num_attributes..num_attributes + 1), - classes: classes, - num_attributes: num_attributes, - num_classes: k, - }) - } else { - let x0 = M::zeros(1, (num_attributes + 1) * k); + Ok(LogisticRegression { + coefficients: weights.slice(0..1, 0..num_attributes), + intercept: weights.slice(0..1, num_attributes..num_attributes + 1), + classes: classes, + num_attributes: num_attributes, + num_classes: k, + }) + } + Ordering::Greater => { + let x0 = M::zeros(1, (num_attributes + 1) * k); - let objective = MultiClassObjectiveFunction { - x: x, - y: yi, - k: k, - phantom: PhantomData, - }; + let objective = MultiClassObjectiveFunction { + x: x, + y: yi, + k: k, + phantom: PhantomData, + }; - let result = LogisticRegression::minimize(x0, objective); + let result = LogisticRegression::minimize(x0, objective); - let weights = result.x.reshape(k, num_attributes + 1); + let weights = result.x.reshape(k, num_attributes + 1); - Ok(LogisticRegression { - coefficients: weights.slice(0..k, 0..num_attributes), - intercept: weights.slice(0..k, num_attributes..num_attributes + 1), - classes: classes, - num_attributes: num_attributes, - num_classes: k, - }) + Ok(LogisticRegression { + coefficients: weights.slice(0..k, 0..num_attributes), + intercept: weights.slice(0..k, num_attributes..num_attributes + 1), + classes: classes, + num_attributes: num_attributes, + num_classes: k, + }) + } } } @@ -286,7 +289,6 @@ impl> LogisticRegression { let n = x.shape().0; let mut result = M::zeros(1, n); if self.num_classes == 2 { - let (nrows, _) = x.shape(); let y_hat: Vec = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0); let intercept = self.intercept.get(0, 0); for i in 0..n {