feat: adds KMeans clustering algorithm

This commit is contained in:
Volodymyr Orlov
2020-02-20 18:43:24 -08:00
parent 4359d66bfa
commit 0e89113297
13 changed files with 637 additions and 84 deletions
+2 -2
View File
@@ -412,7 +412,7 @@ mod tests {
#[test]
fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
@@ -444,7 +444,7 @@ mod tests {
#[test]
fn fit_predict_baloons() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[1.,1.,1.,0.],
&[1.,1.,1.,0.],
&[1.,1.,1.,1.],
+12 -5
View File
@@ -192,19 +192,26 @@ impl<M: Matrix> LogisticRegression<M> {
}
pub fn predict(&self, x: &M) -> M::RowVector {
let n = x.shape().0;
let mut result = M::zeros(1, n);
if self.num_classes == 2 {
let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
for i in 0..n {
result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]);
}
} else {
let (nrows, _) = x.shape();
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat = x_and_bias.dot(&self.weights.transpose());
let class_idxs = y_hat.argmax();
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
for i in 0..n {
result.set(0, i, self.classes[class_idxs[i]]);
}
}
result.to_row_vector()
}
pub fn coefficients(&self) -> M {
@@ -242,7 +249,7 @@ mod tests {
#[test]
fn multiclass_objective_f() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[ 2., 5.],
&[ 3., -2.],
@@ -282,7 +289,7 @@ mod tests {
#[test]
fn binary_objective_f() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[ 2., 5.],
&[ 3., -2.],
@@ -323,7 +330,7 @@ mod tests {
#[test]
fn lr_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[1., -5.],
&[ 2., 5.],
&[ 3., -2.],
+1 -1
View File
@@ -128,7 +128,7 @@ mod tests {
#[test]
fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[
let x = DenseMatrix::from_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],