feat: simplifies LR API

This commit is contained in:
Volodymyr Orlov
2019-12-23 11:18:22 -08:00
parent c1d7c038a6
commit a4ff1cbe5f
4 changed files with 66 additions and 23 deletions
+22 -2
View File
@@ -1,9 +1,20 @@
use std::ops::Range;
use crate::linalg::{Matrix};
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Axis, stack, s};
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
{
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
fn from_row_vector(vec: Self::RowVector) -> Self{
let vec_size = vec.len();
vec.into_shape((1, vec_size)).unwrap()
}
fn to_row_vector(self) -> Self::RowVector{
let vec_size = self.nrows() * self.ncols();
self.into_shape(vec_size).unwrap()
}
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
@@ -248,7 +259,16 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array2};
use ndarray::{arr1, arr2, Array2};
#[test]
fn from_to_row_vec() {
let vec = arr1(&[ 1., 2., 3.]);
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.]));
}
#[test]
fn add_mut() {