feat: + cross_validate, trait Predictor, refactoring

This commit is contained in:
Volodymyr Orlov
2020-12-22 15:41:53 -08:00
parent 40dfca702e
commit a2be9e117f
34 changed files with 977 additions and 369 deletions
+8 -1
View File
@@ -26,6 +26,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
@@ -33,7 +34,7 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
use crate::math::num::RealNumber;
/// Lasso regression parameters
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LassoParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function.
pub alpha: T,
@@ -71,6 +72,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
/// Fits Lasso regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.