feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -47,9 +47,9 @@
|
||||
//!
|
||||
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
|
||||
//! solver: LinearRegressionSolverName::QR, // or SVD
|
||||
//! });
|
||||
//! }).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -64,6 +64,7 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -115,39 +116,41 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LinearRegressionParameters,
|
||||
) -> LinearRegression<T, M> {
|
||||
) -> Result<LinearRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let b = y_m.transpose();
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (y_nrows, _) = b.shape();
|
||||
|
||||
if x_nrows != y_nrows {
|
||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of rows of X doesn't match number of rows of Y"
|
||||
)));
|
||||
}
|
||||
|
||||
let a = x.h_stack(&M::ones(x_nrows, 1));
|
||||
|
||||
let w = match parameters.solver {
|
||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b),
|
||||
LinearRegressionSolverName::SVD => a.svd_solve_mut(b),
|
||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
|
||||
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
|
||||
};
|
||||
|
||||
let wights = w.slice(0..num_attributes, 0..1);
|
||||
|
||||
LinearRegression {
|
||||
Ok(LinearRegression {
|
||||
intercept: w.get(num_attributes, 0),
|
||||
coefficients: wights,
|
||||
solver: parameters.solver,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
y_hat.transpose().to_row_vector()
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
@@ -199,9 +202,12 @@ mod tests {
|
||||
solver: LinearRegressionSolverName::QR,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default())
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(y
|
||||
.iter()
|
||||
@@ -239,7 +245,7 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let lr = LinearRegression::fit(&x, &y, Default::default());
|
||||
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
@@ -40,9 +40,9 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y);
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x);
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -57,6 +57,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::optimization::first_order::lbfgs::LBFGS;
|
||||
@@ -208,13 +209,15 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
/// Fits Logistic Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target class values
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> LogisticRegression<T, M> {
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (_, y_nrows) = y_m.shape();
|
||||
|
||||
if x_nrows != y_nrows {
|
||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of rows of X doesn't match number of rows of Y"
|
||||
)));
|
||||
}
|
||||
|
||||
let classes = y_m.unique();
|
||||
@@ -229,7 +232,10 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
}
|
||||
|
||||
if k < 2 {
|
||||
panic!("Incorrect number of classes: {}", k);
|
||||
Err(Failed::fit(&format!(
|
||||
"incorrect number of classes: {}. Should be >= 2.",
|
||||
k
|
||||
)))
|
||||
} else if k == 2 {
|
||||
let x0 = M::zeros(1, num_attributes + 1);
|
||||
|
||||
@@ -241,12 +247,12 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
|
||||
LogisticRegression {
|
||||
Ok(LogisticRegression {
|
||||
weights: result.x,
|
||||
classes: classes,
|
||||
num_attributes: num_attributes,
|
||||
num_classes: k,
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
||||
|
||||
@@ -261,18 +267,18 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
let weights = result.x.reshape(k, num_attributes + 1);
|
||||
|
||||
LogisticRegression {
|
||||
Ok(LogisticRegression {
|
||||
weights: weights,
|
||||
classes: classes,
|
||||
num_attributes: num_attributes,
|
||||
num_classes: k,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict class labels for samples in `x`.
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let n = x.shape().0;
|
||||
let mut result = M::zeros(1, n);
|
||||
if self.num_classes == 2 {
|
||||
@@ -297,7 +303,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
result.set(0, i, self.classes[class_idxs[i]]);
|
||||
}
|
||||
}
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
@@ -444,7 +450,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
|
||||
assert_eq!(lr.coefficients().shape(), (3, 2));
|
||||
assert_eq!(lr.intercept().shape(), (3, 1));
|
||||
@@ -452,7 +458,7 @@ mod tests {
|
||||
assert!((lr.coefficients().get(0, 0) - 0.0435).abs() < 1e-4);
|
||||
assert!((lr.intercept().get(0, 0) - 0.1250).abs() < 1e-4);
|
||||
|
||||
let y_hat = lr.predict(&x);
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
y_hat,
|
||||
@@ -481,7 +487,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
|
||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
@@ -517,9 +523,9 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x);
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
let error: f64 = y
|
||||
.into_iter()
|
||||
|
||||
Reference in New Issue
Block a user