feat: add basic Matrix implementation for ndarray

This commit is contained in:
Volodymyr Orlov
2019-12-23 10:33:19 -08:00
parent 2425419d10
commit c1d7c038a6
6 changed files with 545 additions and 15 deletions
+41 -3
View File
@@ -193,13 +193,13 @@ impl<M: Matrix> LogisticRegression<M> {
pub fn predict(&self, x: &M) -> M {
if self.num_classes == 2 {
let (nrows, _) = x.shape();
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat: Vec<f64> = x_and_bias.dot(&self.weights.transpose()).to_raw_vector();
M::from_vec(1, nrows, y_hat.iter().map(|y_hat| self.classes[if y_hat.sigmoid() > 0.5 { 1 } else { 0 }]).collect())
} else {
let (nrows, _) = x.shape();
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
let x_and_bias = x.v_stack(&M::ones(nrows, 1));
let y_hat = x_and_bias.dot(&self.weights.transpose());
let class_idxs = y_hat.argmax();
M::from_vec(1, nrows, class_idxs.iter().map(|class_idx| self.classes[*class_idx]).collect())
@@ -235,7 +235,9 @@ impl<M: Matrix> LogisticRegression<M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::ndarray_bindings;
use ndarray::{arr2, Array};
#[test]
fn multiclass_objective_f() {
@@ -388,4 +390,40 @@ mod tests {
}
#[test]
fn tt() {
let x = arr2(&[
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.0, 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5.0, 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5.0, 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[7.0, 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4.0, 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1.0],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4]]);
let y = Array::from_shape_vec((1, 20), vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]).unwrap();
let lr = LogisticRegression::fit(&x, &y);
println!("{:?}", lr);
let y_hat = lr.predict(&x).to_raw_vector();
assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
}