feat: adds KMeans clustering algorithm

This commit is contained in:
Volodymyr Orlov
2020-02-20 18:43:24 -08:00
parent 4359d66bfa
commit 0e89113297
13 changed files with 637 additions and 84 deletions
+7 -6
View File
@@ -19,14 +19,14 @@ impl<M: Matrix> LinearRegression<M> {
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<M>{
let b = y.transpose();
let (x_nrows, num_attributes) = x.shape();
let (y_nrows, _) = y.shape();
let (y_nrows, _) = b.shape();
if x_nrows != y_nrows {
panic!("Number of rows of X doesn't match number of rows of Y");
}
let b = y.clone();
let mut a = x.v_stack(&M::ones(x_nrows, 1));
let w = match solver {
@@ -52,7 +52,7 @@ impl<M: Matrix> Regression<M> for LinearRegression<M> {
let (nrows, _) = x.shape();
let mut y_hat = x.dot(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
y_hat
y_hat.transpose()
}
}
@@ -65,7 +65,7 @@ mod tests {
#[test]
fn ols_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
@@ -82,7 +82,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);