feat: adds KMeans clustering algorithm
This commit is contained in:
@@ -19,14 +19,14 @@ impl<M: Matrix> LinearRegression<M> {
|
||||
|
||||
pub fn fit(x: &M, y: &M, solver: LinearRegressionSolver) -> LinearRegression<M>{
|
||||
|
||||
let b = y.transpose();
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (y_nrows, _) = y.shape();
|
||||
let (y_nrows, _) = b.shape();
|
||||
|
||||
if x_nrows != y_nrows {
|
||||
panic!("Number of rows of X doesn't match number of rows of Y");
|
||||
}
|
||||
|
||||
let b = y.clone();
|
||||
|
||||
let mut a = x.v_stack(&M::ones(x_nrows, 1));
|
||||
|
||||
let w = match solver {
|
||||
@@ -52,7 +52,7 @@ impl<M: Matrix> Regression<M> for LinearRegression<M> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.dot(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
y_hat
|
||||
y_hat.transpose()
|
||||
}
|
||||
|
||||
}
|
||||
@@ -65,7 +65,7 @@ mod tests {
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
@@ -82,7 +82,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||
let y = DenseMatrix::from_array(16, 1, &[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]);
|
||||
|
||||
let y = DenseMatrix::from_array(&[&[83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]]);
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(&x, &y, LinearRegressionSolver::QR).predict(&x);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user