feat: consolidates API
This commit is contained in:
@@ -58,7 +58,7 @@ use std::marker::PhantomData;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::base::Predictor;
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -218,6 +218,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters>
|
||||
for LogisticRegression<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LogisticRegressionParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
LogisticRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
|
||||
Reference in New Issue
Block a user