fix: fixes suggested by Clippy

This commit is contained in:
Volodymyr Orlov
2020-11-11 16:10:37 -08:00
parent c42fccdc22
commit cc26555bfd
+39 -37
View File
@@ -52,6 +52,7 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::cmp::Ordering;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -232,51 +233,53 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
yi[i] = classes.iter().position(|c| yc == *c).unwrap(); yi[i] = classes.iter().position(|c| yc == *c).unwrap();
} }
if k < 2 { match k.cmp(&2) {
Err(Failed::fit(&format!( Ordering::Less => Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.", "incorrect number of classes: {}. Should be >= 2.",
k k
))) ))),
} else if k == 2 { Ordering::Equal => {
let x0 = M::zeros(1, num_attributes + 1); let x0 = M::zeros(1, num_attributes + 1);
let objective = BinaryObjectiveFunction { let objective = BinaryObjectiveFunction {
x: x, x: x,
y: yi, y: yi,
phantom: PhantomData, phantom: PhantomData,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
let weights = result.x; let weights = result.x;
Ok(LogisticRegression { Ok(LogisticRegression {
coefficients: weights.slice(0..1, 0..num_attributes), coefficients: weights.slice(0..1, 0..num_attributes),
intercept: weights.slice(0..1, num_attributes..num_attributes + 1), intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
}) })
} else { }
let x0 = M::zeros(1, (num_attributes + 1) * k); Ordering::Greater => {
let x0 = M::zeros(1, (num_attributes + 1) * k);
let objective = MultiClassObjectiveFunction { let objective = MultiClassObjectiveFunction {
x: x, x: x,
y: yi, y: yi,
k: k, k: k,
phantom: PhantomData, phantom: PhantomData,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
let weights = result.x.reshape(k, num_attributes + 1); let weights = result.x.reshape(k, num_attributes + 1);
Ok(LogisticRegression { Ok(LogisticRegression {
coefficients: weights.slice(0..k, 0..num_attributes), coefficients: weights.slice(0..k, 0..num_attributes),
intercept: weights.slice(0..k, num_attributes..num_attributes + 1), intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
classes: classes, classes: classes,
num_attributes: num_attributes, num_attributes: num_attributes,
num_classes: k, num_classes: k,
}) })
}
} }
} }
@@ -286,7 +289,6 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let n = x.shape().0; let n = x.shape().0;
let mut result = M::zeros(1, n); let mut result = M::zeros(1, n);
if self.num_classes == 2 { if self.num_classes == 2 {
let (nrows, _) = x.shape();
let y_hat: Vec<T> = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0); let y_hat: Vec<T> = x.matmul(&self.coefficients.transpose()).get_col_as_vec(0);
let intercept = self.intercept.get(0, 0); let intercept = self.intercept.get(0, 0);
for i in 0..n { for i in 0..n {