feat: + cross_validate, trait Predictor, refactoring
This commit is contained in:
@@ -58,6 +58,7 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
@@ -66,7 +67,7 @@ use crate::math::num::RealNumber;
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
|
||||
/// Elastic net parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct ElasticNetParameters<T: RealNumber> {
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
@@ -108,6 +109,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
/// Fits elastic net regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
|
||||
+8
-1
@@ -26,6 +26,7 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
@@ -33,7 +34,7 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Lasso regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LassoParameters<T: RealNumber> {
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
@@ -71,6 +72,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
/// Fits Lasso regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
|
||||
@@ -64,11 +64,12 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -78,7 +79,7 @@ pub enum LinearRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Linear Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
@@ -107,6 +108,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
/// Fits Linear Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
@@ -58,6 +58,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -66,6 +67,11 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters {
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||
@@ -97,6 +103,13 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl Default for LogisticRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionParameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.num_classes != other.num_classes
|
||||
@@ -207,11 +220,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
/// Fits Logistic Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target class values
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(x: &M, y: &M::RowVector, _parameters: LogisticRegressionParameters) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (_, y_nrows) = y_m.shape();
|
||||
@@ -461,7 +481,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(lr.coefficients().shape(), (3, 2));
|
||||
assert_eq!(lr.intercept().shape(), (3, 1));
|
||||
@@ -484,7 +504,7 @@ mod tests {
|
||||
let x = DenseMatrix::from_vec(15, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
@@ -498,7 +518,7 @@ mod tests {
|
||||
let x = DenseMatrix::from_vec(20, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
@@ -526,7 +546,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
@@ -562,7 +582,7 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
|
||||
@@ -63,12 +63,13 @@ use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
@@ -78,7 +79,7 @@ pub enum RidgeRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Ridge Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: RidgeRegressionSolverName,
|
||||
@@ -114,6 +115,12 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
/// Fits ridge regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
|
||||
Reference in New Issue
Block a user