diff --git a/src/lib.rs b/src/lib.rs
index c85596e..8c97bf7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -69,7 +69,6 @@
clippy::ptr_arg,
clippy::len_without_is_empty,
clippy::map_entry,
- clippy::comparison_chain,
clippy::type_complexity,
clippy::too_many_arguments,
clippy::many_single_char_names
diff --git a/src/linalg/lu.rs b/src/linalg/lu.rs
index cbe195f..bfc7fff 100644
--- a/src/linalg/lu.rs
+++ b/src/linalg/lu.rs
@@ -33,6 +33,7 @@
//!
#![allow(non_snake_case)]
+use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
@@ -78,12 +79,10 @@ impl> LU {
for i in 0..n_rows {
for j in 0..n_cols {
- if i > j {
- L.set(i, j, self.LU.get(i, j));
- } else if i == j {
- L.set(i, j, T::one());
- } else {
- L.set(i, j, T::zero());
+ match i.cmp(&j) {
+ Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
+ Ordering::Equal => L.set(i, j, T::one()),
+ Ordering::Less => L.set(i, j, T::zero()),
}
}
}
diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs
index ec90af1..796caed 100644
--- a/src/linear/logistic_regression.rs
+++ b/src/linear/logistic_regression.rs
@@ -52,6 +52,7 @@
//!
//!
//!
+use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
@@ -231,48 +232,50 @@ impl> LogisticRegression {
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
}
- if k < 2 {
- Err(Failed::fit(&format!(
+ match k.cmp(&2) {
+ Ordering::Less => Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.",
k
- )))
- } else if k == 2 {
- let x0 = M::zeros(1, num_attributes + 1);
+ ))),
+ Ordering::Greater => {
+ let x0 = M::zeros(1, (num_attributes + 1) * k);
- let objective = BinaryObjectiveFunction {
- x,
- y: yi,
- phantom: PhantomData,
- };
+ let objective = MultiClassObjectiveFunction {
+ x,
+ y: yi,
+ k,
+ phantom: PhantomData,
+ };
- let result = LogisticRegression::minimize(x0, objective);
+ let result = LogisticRegression::minimize(x0, objective);
- Ok(LogisticRegression {
- weights: result.x,
- classes,
- num_attributes,
- num_classes: k,
- })
- } else {
- let x0 = M::zeros(1, (num_attributes + 1) * k);
+ let weights = result.x.reshape(k, num_attributes + 1);
- let objective = MultiClassObjectiveFunction {
- x,
- y: yi,
- k,
- phantom: PhantomData,
- };
+ Ok(LogisticRegression {
+ weights,
+ classes,
+ num_attributes,
+ num_classes: k,
+ })
+ }
+ Ordering::Equal => {
+ let x0 = M::zeros(1, num_attributes + 1);
- let result = LogisticRegression::minimize(x0, objective);
+ let objective = BinaryObjectiveFunction {
+ x,
+ y: yi,
+ phantom: PhantomData,
+ };
- let weights = result.x.reshape(k, num_attributes + 1);
+ let result = LogisticRegression::minimize(x0, objective);
- Ok(LogisticRegression {
- weights,
- classes,
- num_attributes,
- num_classes: k,
- })
+ Ok(LogisticRegression {
+ weights: result.x,
+ classes,
+ num_attributes,
+ num_classes: k,
+ })
+ }
}
}