feat: + cross_validate, trait Predictor, refactoring
This commit is contained in:
@@ -40,7 +40,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
@@ -58,6 +58,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -66,6 +67,11 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters {
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||
@@ -97,6 +103,13 @@ struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl Default for LogisticRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionParameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.num_classes != other.num_classes
|
||||
@@ -207,11 +220,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
/// Fits Logistic Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target class values
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(x: &M, y: &M::RowVector, _parameters: LogisticRegressionParameters) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (_, y_nrows) = y_m.shape();
|
||||
@@ -461,7 +481,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(lr.coefficients().shape(), (3, 2));
|
||||
assert_eq!(lr.intercept().shape(), (3, 1));
|
||||
@@ -484,7 +504,7 @@ mod tests {
|
||||
let x = DenseMatrix::from_vec(15, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
@@ -498,7 +518,7 @@ mod tests {
|
||||
let x = DenseMatrix::from_vec(20, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
@@ -526,7 +546,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
@@ -562,7 +582,7 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user