feat: adds KMeans clustering algorithm
This commit is contained in:
@@ -412,7 +412,7 @@ mod tests {
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
@@ -444,7 +444,7 @@ mod tests {
|
||||
#[test]
|
||||
fn fit_predict_baloons() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,1.],
|
||||
|
||||
@@ -192,19 +192,26 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
}
|
||||
|
||||
pub fn predict(&self, x: &M) -> M::RowVector {
|
||||
let n = x.shape().0;
|
||||
let mut result = M::zeros(1, n);
|
||||
if self.num_classes == 2 {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect()).to_row_vector()
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[if y_hat[i].sigmoid() > 0.5 { 1 } else { 0 }]);
|
||||
}
|
||||
|
||||
} else {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
||||
let class_idxs = y_hat.argmax();
|
||||
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect()).to_row_vector()
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[class_idxs[i]]);
|
||||
}
|
||||
}
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
pub fn coefficients(&self) -> M {
|
||||
@@ -242,7 +249,7 @@ mod tests {
|
||||
#[test]
|
||||
fn multiclass_objective_f() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., -5.],
|
||||
&[ 2., 5.],
|
||||
&[ 3., -2.],
|
||||
@@ -282,7 +289,7 @@ mod tests {
|
||||
#[test]
|
||||
fn binary_objective_f() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., -5.],
|
||||
&[ 2., 5.],
|
||||
&[ 3., -2.],
|
||||
@@ -323,7 +330,7 @@ mod tests {
|
||||
#[test]
|
||||
fn lr_fit_predict() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1., -5.],
|
||||
&[ 2., 5.],
|
||||
&[ 3., -2.],
|
||||
|
||||
@@ -128,7 +128,7 @@ mod tests {
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
|
||||
Reference in New Issue
Block a user