fix: code cleanup, documentation
This commit is contained in:
@@ -6,16 +6,29 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::FloatExt;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum LinearRegressionSolver {
|
||||
pub enum LinearRegressionSolverName {
|
||||
QR,
|
||||
SVD,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
solver: LinearRegressionSolver,
|
||||
solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::SVD
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
@@ -26,7 +39,7 @@ impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||
pub fn fit(x: &M, y: &M::RowVector, solver: LinearRegressionSolver) -> LinearRegression<T, M> {
|
||||
pub fn fit(x: &M, y: &M::RowVector, parameters: LinearRegressionParameters) -> LinearRegression<T, M> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let b = y_m.transpose();
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
@@ -38,9 +51,9 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||
|
||||
let a = x.v_stack(&M::ones(x_nrows, 1));
|
||||
|
||||
let w = match solver {
|
||||
LinearRegressionSolver::QR => a.qr_solve_mut(b),
|
||||
LinearRegressionSolver::SVD => a.svd_solve_mut(b),
|
||||
let w = match parameters.solver {
|
||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b),
|
||||
LinearRegressionSolverName::SVD => a.svd_solve_mut(b),
|
||||
};
|
||||
|
||||
let wights = w.slice(0..num_attributes, 0..1);
|
||||
@@ -48,7 +61,7 @@ impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||
LinearRegression {
|
||||
intercept: w.get(num_attributes, 0),
|
||||
coefficients: wights,
|
||||
solver: solver,
|
||||
solver: parameters.solver,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,9 +103,9 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
]);
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionParameters{solver: LinearRegressionSolverName::QR}).predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
|
||||
assert!(y
|
||||
.iter()
|
||||
@@ -130,9 +143,9 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionParameters{solver: LinearRegressionSolverName::QR}).predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, LinearRegressionSolver::SVD).predict(&x);
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
|
||||
assert!(y
|
||||
.iter()
|
||||
@@ -170,7 +183,7 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let lr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR);
|
||||
let lr = LinearRegression::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user