feat: simplifies LR API

This commit is contained in:
Volodymyr Orlov
2019-12-23 11:18:22 -08:00
parent c1d7c038a6
commit a4ff1cbe5f
4 changed files with 66 additions and 23 deletions
+6
View File
@@ -6,6 +6,12 @@ pub mod ndarray_bindings;
pub trait Matrix: Clone + Debug {
type RowVector: Clone + Debug;
fn from_row_vector(vec: Self::RowVector) -> Self;
fn to_row_vector(self) -> Self::RowVector;
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self;
fn from_vec(nrows: usize, ncols: usize, values: Vec<f64>) -> Self;
+20 -1
View File
@@ -107,7 +107,17 @@ impl Into<Vec<f64>> for DenseMatrix {
}
}
impl Matrix for DenseMatrix {
impl Matrix for DenseMatrix {
type RowVector = Vec<f64>;
fn from_row_vector(vec: Self::RowVector) -> Self{
DenseMatrix::from_vec(1, vec.len(), vec)
}
fn to_row_vector(self) -> Self::RowVector{
self.to_raw_vector()
}
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> DenseMatrix {
DenseMatrix::from_vec(nrows, ncols, Vec::from(values))
@@ -968,6 +978,15 @@ impl Matrix for DenseMatrix {
mod tests {
use super::*;
#[test]
fn from_to_row_vec() {
let vec = vec![ 1., 2., 3.];
assert_eq!(DenseMatrix::from_row_vector(vec.clone()), DenseMatrix::from_vec(1, 3, vec![1., 2., 3.]));
assert_eq!(DenseMatrix::from_row_vector(vec.clone()).to_row_vector(), vec![1., 2., 3.]);
}
#[test]
fn qr_solve_mut() {
+22 -2
View File
@@ -1,9 +1,20 @@
use std::ops::Range;
use crate::linalg::{Matrix};
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Axis, stack, s};
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
{
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
fn from_row_vector(vec: Self::RowVector) -> Self{
let vec_size = vec.len();
vec.into_shape((1, vec_size)).unwrap()
}
fn to_row_vector(self) -> Self::RowVector{
let vec_size = self.nrows() * self.ncols();
self.into_shape(vec_size).unwrap()
}
fn from_array(nrows: usize, ncols: usize, values: &[f64]) -> Self {
Array::from_shape_vec((nrows, ncols), values.to_vec()).unwrap()
@@ -248,7 +259,16 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array2};
use ndarray::{arr1, arr2, Array2};
#[test]
fn from_to_row_vec() {
let vec = arr1(&[ 1., 2., 3.]);
assert_eq!(Array2::from_row_vector(vec.clone()), arr2(&[[1., 2., 3.]]));
assert_eq!(Array2::from_row_vector(vec.clone()).to_row_vector(), arr1(&[1., 2., 3.]));
}
#[test]
fn add_mut() {