feat: consolidates API

This commit is contained in:
Volodymyr Orlov
2020-12-24 18:36:23 -08:00
parent a69fb3aada
commit 810a5c429b
25 changed files with 400 additions and 98 deletions
+9 -1
View File
@@ -58,7 +58,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
@@ -139,6 +139,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, ElasticNetParameters<T>>
for ElasticNet<T, M>
{
fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters<T>) -> Result<Self, Failed> {
ElasticNet::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
+9 -1
View File
@@ -26,7 +26,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
@@ -95,6 +95,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LassoParameters<T>>
for Lasso<T, M>
{
fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters<T>) -> Result<Self, Failed> {
Lasso::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
+13 -1
View File
@@ -64,7 +64,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
@@ -116,6 +116,18 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LinearRegressionParameters>
for LinearRegression<T, M>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: LinearRegressionParameters,
) -> Result<Self, Failed> {
LinearRegression::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
+13 -1
View File
@@ -58,7 +58,7 @@ use std::marker::PhantomData;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
@@ -218,6 +218,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters>
for LogisticRegression<T, M>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: LogisticRegressionParameters,
) -> Result<Self, Failed> {
LogisticRegression::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
+13 -1
View File
@@ -60,7 +60,7 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::base::Predictor;
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
@@ -130,6 +130,18 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, RidgeRegressionParameters<T>>
for RidgeRegression<T, M>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: RidgeRegressionParameters<T>,
) -> Result<Self, Failed> {
RidgeRegression::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)