feat: refactoring, adds Result to most public API

This commit is contained in:
Volodymyr Orlov
2020-09-18 15:20:32 -07:00
parent 4921ae76f5
commit a9db970195
24 changed files with 389 additions and 298 deletions
+22 -16
View File
@@ -40,9 +40,9 @@
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let lr = LogisticRegression::fit(&x, &y);
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
//!
//! let y_hat = lr.predict(&x);
//! let y_hat = lr.predict(&x).unwrap();
//! ```
//!
//! ## References:
@@ -57,6 +57,7 @@ use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::optimization::first_order::lbfgs::LBFGS;
@@ -208,13 +209,15 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
/// Fits Logistic Regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target class values
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape();
let (_, y_nrows) = y_m.shape();
if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y");
return Err(Failed::fit(&format!(
"Number of rows of X doesn't match number of rows of Y"
)));
}
let classes = y_m.unique();
@@ -229,7 +232,10 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
}
if k < 2 {
panic!("Incorrect number of classes: {}", k);
Err(Failed::fit(&format!(
"incorrect number of classes: {}. Should be >= 2.",
k
)))
} else if k == 2 {
let x0 = M::zeros(1, num_attributes + 1);
@@ -241,12 +247,12 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let result = LogisticRegression::minimize(x0, objective);
LogisticRegression {
Ok(LogisticRegression {
weights: result.x,
classes: classes,
num_attributes: num_attributes,
num_classes: k,
}
})
} else {
let x0 = M::zeros(1, (num_attributes + 1) * k);
@@ -261,18 +267,18 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let weights = result.x.reshape(k, num_attributes + 1);
LogisticRegression {
Ok(LogisticRegression {
weights: weights,
classes: classes,
num_attributes: num_attributes,
num_classes: k,
}
})
}
}
/// Predict class labels for samples in `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> M::RowVector {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let n = x.shape().0;
let mut result = M::zeros(1, n);
if self.num_classes == 2 {
@@ -297,7 +303,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
result.set(0, i, self.classes[class_idxs[i]]);
}
}
result.to_row_vector()
Ok(result.to_row_vector())
}
/// Get estimates regression coefficients
@@ -444,7 +450,7 @@ mod tests {
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y).unwrap();
assert_eq!(lr.coefficients().shape(), (3, 2));
assert_eq!(lr.intercept().shape(), (3, 1));
@@ -452,7 +458,7 @@ mod tests {
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x).unwrap();
assert_eq!(
y_hat,
@@ -481,7 +487,7 @@ mod tests {
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y).unwrap();
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
@@ -517,9 +523,9 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let lr = LogisticRegression::fit(&x, &y);
let lr = LogisticRegression::fit(&x, &y).unwrap();
let y_hat = lr.predict(&x);
let y_hat = lr.predict(&x).unwrap();
let error: f64 = y
.into_iter()