diff --git a/src/lib.rs b/src/lib.rs index c85596e..8c97bf7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,7 +69,6 @@ clippy::ptr_arg, clippy::len_without_is_empty, clippy::map_entry, - clippy::comparison_chain, clippy::type_complexity, clippy::too_many_arguments, clippy::many_single_char_names diff --git a/src/linalg/lu.rs b/src/linalg/lu.rs index cbe195f..bfc7fff 100644 --- a/src/linalg/lu.rs +++ b/src/linalg/lu.rs @@ -33,6 +33,7 @@ //! #![allow(non_snake_case)] +use std::cmp::Ordering; use std::fmt::Debug; use std::marker::PhantomData; @@ -78,12 +79,10 @@ impl> LU { for i in 0..n_rows { for j in 0..n_cols { - if i > j { - L.set(i, j, self.LU.get(i, j)); - } else if i == j { - L.set(i, j, T::one()); - } else { - L.set(i, j, T::zero()); + match i.cmp(&j) { + Ordering::Greater => L.set(i, j, self.LU.get(i, j)), + Ordering::Equal => L.set(i, j, T::one()), + Ordering::Less => L.set(i, j, T::zero()), } } } diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index ec90af1..796caed 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -52,6 +52,7 @@ //! //! //! +use std::cmp::Ordering; use std::fmt::Debug; use std::marker::PhantomData; @@ -231,48 +232,50 @@ impl> LogisticRegression { yi[i] = classes.iter().position(|c| yc == *c).unwrap(); } - if k < 2 { - Err(Failed::fit(&format!( + match k.cmp(&2) { + Ordering::Less => Err(Failed::fit(&format!( "incorrect number of classes: {}. Should be >= 2.", k - ))) - } else if k == 2 { - let x0 = M::zeros(1, num_attributes + 1); + ))), + Ordering::Greater => { + let x0 = M::zeros(1, (num_attributes + 1) * k); - let objective = BinaryObjectiveFunction { - x, - y: yi, - phantom: PhantomData, - }; + let objective = MultiClassObjectiveFunction { + x, + y: yi, + k, + phantom: PhantomData, + }; - let result = LogisticRegression::minimize(x0, objective); + let result = LogisticRegression::minimize(x0, objective); - Ok(LogisticRegression { - weights: result.x, - classes, - num_attributes, - num_classes: k, - }) - } else { - let x0 = M::zeros(1, (num_attributes + 1) * k); + let weights = result.x.reshape(k, num_attributes + 1); - let objective = MultiClassObjectiveFunction { - x, - y: yi, - k, - phantom: PhantomData, - }; + Ok(LogisticRegression { + weights, + classes, + num_attributes, + num_classes: k, + }) + } + Ordering::Equal => { + let x0 = M::zeros(1, num_attributes + 1); - let result = LogisticRegression::minimize(x0, objective); + let objective = BinaryObjectiveFunction { + x, + y: yi, + phantom: PhantomData, + }; - let weights = result.x.reshape(k, num_attributes + 1); + let result = LogisticRegression::minimize(x0, objective); - Ok(LogisticRegression { - weights, - classes, - num_attributes, - num_classes: k, - }) + Ok(LogisticRegression { + weights: result.x, + classes, + num_attributes, + num_classes: k, + }) + } } }