fix: more refactoring

This commit is contained in:
Volodymyr Orlov
2020-03-13 11:24:53 -07:00
parent cb4323f26e
commit 4f8318e933
15 changed files with 51 additions and 66 deletions
+2
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use num::complex::Complex;
use crate::linalg::BaseMatrix;
+1 -3
View File
@@ -122,9 +122,7 @@ pub trait BaseMatrix: Clone + Debug {
r
}
fn transpose(&self) -> Self;
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self;
fn transpose(&self) -> Self;
fn rand(nrows: usize, ncols: usize) -> Self;
+1 -11
View File
@@ -366,12 +366,7 @@ impl BaseMatrix for DenseMatrix {
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64) {
self.values[col*self.nrows + row] -= x;
}
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self {
let m = DenseMatrix::rand(nrows, ncols);
m.dot(&m.transpose())
}
}
fn transpose(&self) -> Self {
let mut m = DenseMatrix {
@@ -723,11 +718,6 @@ mod tests {
}
}
#[test]
fn generate_positive_definite() {
let m = DenseMatrix::generate_positive_definite(3, 3);
}
#[test]
fn reshape() {
let m_orig = DenseMatrix::vector_from_array(&[1., 2., 3., 4., 5., 6.]);
+29 -6
View File
@@ -5,6 +5,7 @@ use crate::linalg::svd::SVDDecomposableMatrix;
use crate::linalg::evd::EVDDecomposableMatrix;
use crate::linalg::qr::QRDecomposableMatrix;
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
use rand::prelude::*;
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
{
@@ -81,7 +82,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
}
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
false
(self - other).iter().all(|v| v.abs() <= error)
}
fn add_mut(&mut self, other: &Self) -> &Self {
@@ -128,12 +129,12 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
self.clone().reversed_axes()
}
fn generate_positive_definite(nrows: usize, ncols: usize) -> Self{
panic!("generate_positive_definite method is not implemented for ndarray");
}
fn rand(nrows: usize, ncols: usize) -> Self{
panic!("rand method is not implemented for ndarray");
let mut rng = rand::thread_rng();
let values: Vec<f64> = (0..nrows*ncols).map(|_| {
rng.gen()
}).collect();
Array::from_shape_vec((nrows, ncols), values).unwrap()
}
fn norm2(&self) -> f64{
@@ -600,4 +601,26 @@ mod tests {
let res: Array2<f64> = BaseMatrix::eye(3);
assert_eq!(res, a);
}
#[test]
fn rand() {
let m: Array2<f64> = BaseMatrix::rand(3, 3);
for c in 0..3 {
for r in 0..3 {
assert!(m[[r, c]] != 0f64);
}
}
}
#[test]
fn approximate_eq() {
let a = arr2(&[[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]]);
let noise = arr2(&[[1e-5, 2e-5, 3e-5],
[4e-5, 5e-5, 6e-5],
[7e-5, 8e-5, 9e-5]]);
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
}
}
+2
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)]
+3 -1
View File
@@ -1,3 +1,5 @@
#![allow(non_snake_case)]
use crate::linalg::BaseMatrix;
#[derive(Debug, Clone)]
@@ -504,7 +506,7 @@ mod tests {
#[test]
fn solve() {
let mut a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let a = DenseMatrix::from_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_array(&[&[0.5, 0.2],&[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_array(&[
&[-0.20, -1.28],