fix: formatting
This commit is contained in:
@@ -26,7 +26,7 @@ pub struct LinearRegression<T: FloatExt, M: Matrix<T>> {
|
||||
impl Default for LinearRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::SVD
|
||||
solver: LinearRegressionSolverName::SVD,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,11 @@ impl<T: FloatExt, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
}
|
||||
|
||||
impl<T: FloatExt, M: Matrix<T>> LinearRegression<T, M> {
|
||||
pub fn fit(x: &M, y: &M::RowVector, parameters: LinearRegressionParameters) -> LinearRegression<T, M> {
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LinearRegressionParameters,
|
||||
) -> LinearRegression<T, M> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let b = y_m.transpose();
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
@@ -103,7 +107,14 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
]);
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionParameters{solver: LinearRegressionSolverName::QR}).predict(&x);
|
||||
let y_hat_qr = LinearRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::QR,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
|
||||
@@ -143,7 +154,14 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionParameters{solver: LinearRegressionSolverName::QR}).predict(&x);
|
||||
let y_hat_qr = LinearRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::QR,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
|
||||
let y_hat_svd = LinearRegression::fit(&x, &y, Default::default()).predict(&x);
|
||||
|
||||
|
||||
+12
-11
@@ -15,7 +15,7 @@ pub enum KNNAlgorithmName {
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifierParameters {
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
pub k: usize
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
@@ -36,7 +36,7 @@ impl Default for KNNClassifierParameters {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
k: 3
|
||||
k: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -93,7 +93,7 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNClassifierParameters
|
||||
parameters: KNNClassifierParameters,
|
||||
) -> KNNClassifier<T, D> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -118,7 +118,10 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
)
|
||||
);
|
||||
|
||||
assert!(parameters.k > 1, format!("k should be > 1, k=[{}]", parameters.k));
|
||||
assert!(
|
||||
parameters.k > 1,
|
||||
format!("k should be > 1, k=[{}]", parameters.k)
|
||||
);
|
||||
|
||||
KNNClassifier {
|
||||
classes: classes,
|
||||
@@ -169,7 +172,10 @@ mod tests {
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNClassifierParameters{k: 3, algorithm: KNNAlgorithmName::LinearSearch}
|
||||
KNNClassifierParameters {
|
||||
k: 3,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
},
|
||||
);
|
||||
let r = knn.predict(&x);
|
||||
assert_eq!(5, Vec::len(&r));
|
||||
@@ -181,12 +187,7 @@ mod tests {
|
||||
let x = DenseMatrix::from_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
Default::default()
|
||||
);
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user