feat: + cross_validate, trait Predictor, refactoring

This commit is contained in:
Volodymyr Orlov
2020-12-22 15:41:53 -08:00
parent 40dfca702e
commit a2be9e117f
34 changed files with 977 additions and 369 deletions
+27 -7
View File
@@ -40,7 +40,7 @@
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = lr.predict(&x).unwrap();
//! ```
@@ -58,6 +58,7 @@ use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
@@ -66,6 +67,11 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder;
/// Logistic Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LogisticRegressionParameters {
}
/// Logistic Regression
#[derive(Serialize, Deserialize, Debug)]
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
@@ -97,6 +103,13 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
phantom: PhantomData<&'a T>,
}
impl Default for LogisticRegressionParameters {
fn default() -> Self {
LogisticRegressionParameters {
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
if self.num_classes != other.num_classes
@@ -207,11 +220,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
/// Fits Logistic Regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target class values
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(x: &M, y: &M::RowVector, _parameters: LogisticRegressionParameters) -> Result<LogisticRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape();
let (_, y_nrows) = y_m.shape();
@@ -461,7 +481,7 @@ mod tests {
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
assert_eq!(lr.coefficients().shape(), (3, 2));
assert_eq!(lr.intercept().shape(), (3, 1));
@@ -484,7 +504,7 @@ mod tests {
let x = DenseMatrix::from_vec(15, 4, &blobs.data);
let y = blobs.target;
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let y_hat = lr.predict(&x).unwrap();
@@ -498,7 +518,7 @@ mod tests {
let x = DenseMatrix::from_vec(20, 4, &blobs.data);
let y = blobs.target;
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let y_hat = lr.predict(&x).unwrap();
@@ -526,7 +546,7 @@ mod tests {
]);
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
@@ -562,7 +582,7 @@ mod tests {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let lr = LogisticRegression::fit(&x, &y).unwrap();
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let y_hat = lr.predict(&x).unwrap();