feat: + cross_validate, trait Predictor, refactoring
This commit is contained in:
@@ -64,11 +64,12 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -78,7 +79,7 @@ pub enum LinearRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Linear Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
@@ -107,6 +108,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
/// Fits Linear Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
|
||||
Reference in New Issue
Block a user