Fix clippy::comparison_chain

This commit is contained in:
Luis Moreno
2020-11-09 00:02:22 -04:00
parent 3c1969bdf5
commit 5e887634db
3 changed files with 41 additions and 40 deletions
-1
View File
@@ -69,7 +69,6 @@
clippy::ptr_arg, clippy::ptr_arg,
clippy::len_without_is_empty, clippy::len_without_is_empty,
clippy::map_entry, clippy::map_entry,
clippy::comparison_chain,
clippy::type_complexity, clippy::type_complexity,
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::many_single_char_names clippy::many_single_char_names
+5 -6
View File
@@ -33,6 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)] #![allow(non_snake_case)]
use std::cmp::Ordering;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -78,12 +79,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
for i in 0..n_rows { for i in 0..n_rows {
for j in 0..n_cols { for j in 0..n_cols {
if i > j { match i.cmp(&j) {
L.set(i, j, self.LU.get(i, j)); Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
} else if i == j { Ordering::Equal => L.set(i, j, T::one()),
L.set(i, j, T::one()); Ordering::Less => L.set(i, j, T::zero()),
} else {
L.set(i, j, T::zero());
} }
} }
} }
+36 -33
View File
@@ -52,6 +52,7 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::cmp::Ordering;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -231,48 +232,50 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
if k < 2 { match k.cmp(&2) {
Err(Failed::fit(&format!( Ordering::Less => Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.", "incorrect number of classes: {}. Should be >= 2.",
k k
))) ))),
} else if k == 2 { Ordering::Greater => {
let x0 = M::zeros(1, num_attributes + 1); let x0 = M::zeros(1, (num_attributes + 1) * k);
let objective = BinaryObjectiveFunction { let objective = MultiClassObjectiveFunction {
x, x,
y: yi, y: yi,
phantom: PhantomData, k,
}; phantom: PhantomData,
};
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
Ok(LogisticRegression { let weights = result.x.reshape(k, num_attributes + 1);
weights: result.x,
classes,
num_attributes,
num_classes: k,
})
} else {
let x0 = M::zeros(1, (num_attributes + 1) * k);
let objective = MultiClassObjectiveFunction { Ok(LogisticRegression {
x, weights,
y: yi, classes,
k, num_attributes,
phantom: PhantomData, num_classes: k,
}; })
}
Ordering::Equal => {
let x0 = M::zeros(1, num_attributes + 1);
let result = LogisticRegression::minimize(x0, objective); let objective = BinaryObjectiveFunction {
x,
y: yi,
phantom: PhantomData,
};
let weights = result.x.reshape(k, num_attributes + 1); let result = LogisticRegression::minimize(x0, objective);
Ok(LogisticRegression { Ok(LogisticRegression {
weights, weights: result.x,
classes, classes,
num_attributes, num_attributes,
num_classes: k, num_classes: k,
}) })
}
} }
} }