fix: clippy, documentation and formatting

This commit is contained in:
Volodymyr Orlov
2020-12-22 16:35:28 -08:00
parent a2be9e117f
commit 9b221979da
7 changed files with 80 additions and 62 deletions
+12 -14
View File
@@ -134,7 +134,7 @@ struct Cache<T: Clone> {
data: Vec<RefCell<Option<Vec<T>>>>,
}
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVRParameters<T, M, K> {
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVRParameters<T, M, K> {
/// Epsilon in the epsilon-SVR model.
pub fn with_eps(mut self, eps: T) -> Self {
self.eps = eps;
@@ -153,11 +153,11 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVRParameters<T, M
/// The kernel function.
pub fn with_kernel<KK: Kernel<T, M::RowVector>>(&self, kernel: KK) -> SVRParameters<T, M, KK> {
SVRParameters {
eps: self.eps,
eps: self.eps,
c: self.c,
tol: self.tol,
kernel: kernel,
m: PhantomData
kernel,
m: PhantomData,
}
}
}
@@ -169,12 +169,14 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVRParameters<T, M, LinearKernel>
c: T::one(),
tol: T::from_f64(1e-3).unwrap(),
kernel: Kernels::linear(),
m: PhantomData
m: PhantomData,
}
}
}
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Predictor<M, M::RowVector> for SVR<T, M, K> {
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Predictor<M, M::RowVector>
for SVR<T, M, K>
{
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
@@ -188,7 +190,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
/// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &M,
y: &M::RowVector,
y: &M::RowVector,
parameters: SVRParameters<T, M, K>,
) -> Result<SVR<T, M, K>, Failed> {
let (n, _) = x.shape();
@@ -544,13 +546,9 @@ mod tests {
114.2, 115.7, 116.9,
];
let y_hat = SVR::fit(
&x,
&y,
SVRParameters::default().with_eps(2.0).with_c(10.0),
)
.and_then(|lr| lr.predict(&x))
.unwrap();
let y_hat = SVR::fit(&x, &y, SVRParameters::default().with_eps(2.0).with_c(10.0))
.and_then(|lr| lr.predict(&x))
.unwrap();
assert!(mean_squared_error(&y_hat, &y) < 2.5);
}