feat: add basic Matrix implementation for ndarray
This commit is contained in:
@@ -193,13 +193,13 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
pub fn predict(&self, x: &M) -> M {
|
||||
if self.num_classes == 2 {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
|
||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
|
||||
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect())
|
||||
|
||||
} else {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
|
||||
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
|
||||
let y_hat = x_and_bias.dot(&self.weights.transpose());
|
||||
let class_idxs = y_hat.argmax();
|
||||
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect())
|
||||
@@ -235,7 +235,9 @@ impl<M: Matrix> LogisticRegression<M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::ndarray_bindings;
|
||||
use ndarray::{arr2, Array};
|
||||
|
||||
#[test]
|
||||
fn multiclass_objective_f() {
|
||||
@@ -388,4 +390,40 @@ mod tests {
|
||||
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tt() {
|
||||
let x = arr2(&[
|
||||
[5.1, 3.5, 1.4, 0.2],
|
||||
[4.9, 3.0, 1.4, 0.2],
|
||||
[4.7, 3.2, 1.3, 0.2],
|
||||
[4.6, 3.1, 1.5, 0.2],
|
||||
[5.0, 3.6, 1.4, 0.2],
|
||||
[5.4, 3.9, 1.7, 0.4],
|
||||
[4.6, 3.4, 1.4, 0.3],
|
||||
[5.0, 3.4, 1.5, 0.2],
|
||||
[4.4, 2.9, 1.4, 0.2],
|
||||
[4.9, 3.1, 1.5, 0.1],
|
||||
[7.0, 3.2, 4.7, 1.4],
|
||||
[6.4, 3.2, 4.5, 1.5],
|
||||
[6.9, 3.1, 4.9, 1.5],
|
||||
[5.5, 2.3, 4.0, 1.3],
|
||||
[6.5, 2.8, 4.6, 1.5],
|
||||
[5.7, 2.8, 4.5, 1.3],
|
||||
[6.3, 3.3, 4.7, 1.6],
|
||||
[4.9, 2.4, 3.3, 1.0],
|
||||
[6.6, 2.9, 4.6, 1.3],
|
||||
[5.2, 2.7, 3.9, 1.4]]);
|
||||
let y = Array::from_shape_vec((1, 20), vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).unwrap();
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y);
|
||||
|
||||
println!("{:?}", lr);
|
||||
|
||||
let y_hat = lr.predict(&x).to_raw_vector();
|
||||
|
||||
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user