feat: extends interface of Matrix to support for broad range of types
This commit is contained in:
@@ -1,15 +1,25 @@
|
||||
use std::ops::Range;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
use std::ops::AddAssign;
|
||||
use std::ops::SubAssign;
|
||||
use std::ops::MulAssign;
|
||||
use std::ops::DivAssign;
|
||||
|
||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
||||
use ndarray::ScalarOperand;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use ndarray::{Array, ArrayBase, OwnedRepr, Ix2, Ix1, Axis, stack, s};
|
||||
use rand::prelude::*;
|
||||
|
||||
impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> BaseMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
type RowVector = ArrayBase<OwnedRepr<f64>, Ix1>;
|
||||
type RowVector = ArrayBase<OwnedRepr<T>, Ix1>;
|
||||
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self{
|
||||
let vec_size = vec.len();
|
||||
@@ -21,19 +31,19 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self.into_shape(vec_size).unwrap()
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> f64 {
|
||||
fn get(&self, row: usize, col: usize) -> T {
|
||||
self[[row, col]]
|
||||
}
|
||||
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<f64> {
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T> {
|
||||
self.row(row).to_vec()
|
||||
}
|
||||
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<f64> {
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T> {
|
||||
self.column(col).to_vec()
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, x: f64) {
|
||||
fn set(&mut self, row: usize, col: usize, x: T) {
|
||||
self[[row, col]] = x;
|
||||
}
|
||||
|
||||
@@ -49,11 +59,11 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
Array::ones((nrows, ncols))
|
||||
}
|
||||
|
||||
fn to_raw_vector(&self) -> Vec<f64> {
|
||||
fn to_raw_vector(&self) -> Vec<T> {
|
||||
self.to_owned().iter().map(|v| *v).collect()
|
||||
}
|
||||
|
||||
fn fill(nrows: usize, ncols: usize, value: f64) -> Self {
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
|
||||
Array::from_elem((nrows, ncols), value)
|
||||
}
|
||||
|
||||
@@ -73,7 +83,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self.dot(other)
|
||||
}
|
||||
|
||||
fn vector_dot(&self, other: &Self) -> f64 {
|
||||
fn vector_dot(&self, other: &Self) -> T {
|
||||
self.dot(&other.view().reversed_axes())[[0, 0]]
|
||||
}
|
||||
|
||||
@@ -81,7 +91,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self.slice(s![rows, cols]).to_owned()
|
||||
}
|
||||
|
||||
fn approximate_eq(&self, other: &Self, error: f64) -> bool {
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
||||
(self - other).iter().all(|v| v.abs() <= error)
|
||||
}
|
||||
|
||||
@@ -105,22 +115,22 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self
|
||||
}
|
||||
|
||||
fn add_scalar_mut(&mut self, scalar: f64) -> &Self{
|
||||
fn add_scalar_mut(&mut self, scalar: T) -> &Self{
|
||||
*self += scalar;
|
||||
self
|
||||
}
|
||||
|
||||
fn sub_scalar_mut(&mut self, scalar: f64) -> &Self{
|
||||
fn sub_scalar_mut(&mut self, scalar: T) -> &Self{
|
||||
*self -= scalar;
|
||||
self
|
||||
}
|
||||
|
||||
fn mul_scalar_mut(&mut self, scalar: f64) -> &Self{
|
||||
fn mul_scalar_mut(&mut self, scalar: T) -> &Self{
|
||||
*self *= scalar;
|
||||
self
|
||||
}
|
||||
|
||||
fn div_scalar_mut(&mut self, scalar: f64) -> &Self{
|
||||
fn div_scalar_mut(&mut self, scalar: T) -> &Self{
|
||||
*self /= scalar;
|
||||
self
|
||||
}
|
||||
@@ -129,21 +139,20 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self.clone().reversed_axes()
|
||||
}
|
||||
|
||||
fn rand(nrows: usize, ncols: usize) -> Self{
|
||||
let mut rng = rand::thread_rng();
|
||||
let values: Vec<f64> = (0..nrows*ncols).map(|_| {
|
||||
rng.gen()
|
||||
fn rand(nrows: usize, ncols: usize) -> Self{
|
||||
let values: Vec<T> = (0..nrows*ncols).map(|_| {
|
||||
T::rand()
|
||||
}).collect();
|
||||
Array::from_shape_vec((nrows, ncols), values).unwrap()
|
||||
}
|
||||
|
||||
fn norm2(&self) -> f64{
|
||||
self.iter().map(|x| x * x).sum::<f64>().sqrt()
|
||||
fn norm2(&self) -> T{
|
||||
self.iter().map(|x| *x * *x).sum::<T>().sqrt()
|
||||
}
|
||||
|
||||
fn norm(&self, p:f64) -> f64 {
|
||||
fn norm(&self, p:T) -> T {
|
||||
if p.is_infinite() && p.is_sign_positive() {
|
||||
self.iter().fold(std::f64::NEG_INFINITY, |f, &val| {
|
||||
self.iter().fold(T::neg_infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f > v {
|
||||
f
|
||||
@@ -152,7 +161,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
}
|
||||
})
|
||||
} else if p.is_infinite() && p.is_sign_negative() {
|
||||
self.iter().fold(std::f64::INFINITY, |f, &val| {
|
||||
self.iter().fold(T::infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f < v {
|
||||
f
|
||||
@@ -162,38 +171,38 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
})
|
||||
} else {
|
||||
|
||||
let mut norm = 0f64;
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm += xi.abs().powf(p);
|
||||
norm = norm + xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(1.0/p)
|
||||
norm.powf(T::one()/p)
|
||||
}
|
||||
}
|
||||
|
||||
fn column_mean(&self) -> Vec<f64> {
|
||||
fn column_mean(&self) -> Vec<T> {
|
||||
self.mean_axis(Axis(0)).unwrap().to_vec()
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: f64){
|
||||
self[[row, col]] /= x;
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: T){
|
||||
self[[row, col]] = self[[row, col]] / x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64){
|
||||
self[[row, col]] *= x;
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: T){
|
||||
self[[row, col]] = self[[row, col]] * x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: f64){
|
||||
self[[row, col]] += x;
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: T){
|
||||
self[[row, col]] = self[[row, col]] + x;
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64){
|
||||
self[[row, col]] -= x;
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: T){
|
||||
self[[row, col]] = self[[row, col]] - x;
|
||||
}
|
||||
|
||||
fn negative_mut(&mut self){
|
||||
*self *= -1.;
|
||||
*self *= -T::one();
|
||||
}
|
||||
|
||||
fn reshape(&self, nrows: usize, ncols: usize) -> Self{
|
||||
@@ -208,12 +217,12 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
self
|
||||
}
|
||||
|
||||
fn sum(&self) -> f64{
|
||||
fn sum(&self) -> T{
|
||||
self.sum()
|
||||
}
|
||||
|
||||
fn max_diff(&self, other: &Self) -> f64{
|
||||
let mut max_diff = 0f64;
|
||||
fn max_diff(&self, other: &Self) -> T{
|
||||
let mut max_diff = T::zero();
|
||||
for r in 0..self.nrows() {
|
||||
for c in 0..self.ncols() {
|
||||
max_diff = max_diff.max((self[(r, c)] - other[(r, c)]).abs());
|
||||
@@ -223,13 +232,13 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
}
|
||||
|
||||
fn softmax_mut(&mut self){
|
||||
let max = self.iter().map(|x| x.abs()).fold(std::f64::NEG_INFINITY, |a, b| a.max(b));
|
||||
let mut z = 0.;
|
||||
let max = self.iter().map(|x| x.abs()).fold(T::neg_infinity(), |a, b| a.max(b));
|
||||
let mut z = T::zero();
|
||||
for r in 0..self.nrows() {
|
||||
for c in 0..self.ncols() {
|
||||
let p = (self[(r, c)] - max).exp();
|
||||
self.set(r, c, p);
|
||||
z += p;
|
||||
z = z + p;
|
||||
}
|
||||
}
|
||||
for r in 0..self.nrows() {
|
||||
@@ -239,7 +248,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
}
|
||||
}
|
||||
|
||||
fn pow_mut(&mut self, p: f64) -> &Self{
|
||||
fn pow_mut(&mut self, p: T) -> &Self{
|
||||
for r in 0..self.nrows() {
|
||||
for c in 0..self.ncols() {
|
||||
self.set(r, c, self[(r, c)].powf(p));
|
||||
@@ -252,7 +261,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
let mut res = vec![0usize; self.nrows()];
|
||||
|
||||
for r in 0..self.nrows() {
|
||||
let mut max = std::f64::NEG_INFINITY;
|
||||
let mut max = T::neg_infinity();
|
||||
let mut max_pos = 0usize;
|
||||
for c in 0..self.ncols() {
|
||||
let v = self[(r, c)];
|
||||
@@ -268,7 +277,7 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
|
||||
}
|
||||
|
||||
fn unique(&self) -> Vec<f64> {
|
||||
fn unique(&self) -> Vec<T> {
|
||||
let mut result = self.clone().into_raw_vec();
|
||||
result.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
result.dedup();
|
||||
@@ -277,13 +286,13 @@ impl BaseMatrix for ArrayBase<OwnedRepr<f64>, Ix2>
|
||||
|
||||
}
|
||||
|
||||
impl SVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> SVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl EVDDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> EVDDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl QRDecomposableMatrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> QRDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2> {}
|
||||
impl<T: FloatExt + Debug + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -541,7 +550,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn softmax_mut(){
|
||||
let mut prob = arr2(&[[1., 2., 3.]]);
|
||||
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
|
||||
prob.softmax_mut();
|
||||
assert!((BaseMatrix::get(&prob, 0, 0) - 0.09).abs() < 0.01);
|
||||
assert!((BaseMatrix::get(&prob, 0, 1) - 0.24).abs() < 0.01);
|
||||
|
||||
Reference in New Issue
Block a user