feat: extends interface of Matrix to support for broad range of types
This commit is contained in:
+42
-37
@@ -6,11 +6,14 @@ pub mod ndarray_bindings;
|
||||
|
||||
use std::ops::Range;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use svd::SVDDecomposableMatrix;
|
||||
use evd::EVDDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
|
||||
pub trait BaseMatrix: Clone + Debug {
|
||||
pub trait BaseMatrix<T: FloatExt + Debug>: Clone + Debug {
|
||||
|
||||
type RowVector: Clone + Debug;
|
||||
|
||||
@@ -18,13 +21,13 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn to_row_vector(self) -> Self::RowVector;
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> f64;
|
||||
fn get(&self, row: usize, col: usize) -> T;
|
||||
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
|
||||
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, x: f64);
|
||||
fn set(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
fn eye(size: usize) -> Self;
|
||||
|
||||
@@ -32,9 +35,9 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn ones(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
fn to_raw_vector(&self) -> Vec<f64>;
|
||||
fn to_raw_vector(&self) -> Vec<T>;
|
||||
|
||||
fn fill(nrows: usize, ncols: usize, value: f64) -> Self;
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self;
|
||||
|
||||
fn shape(&self) -> (usize, usize);
|
||||
|
||||
@@ -44,11 +47,11 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn dot(&self, other: &Self) -> Self;
|
||||
|
||||
fn vector_dot(&self, other: &Self) -> f64;
|
||||
fn vector_dot(&self, other: &Self) -> T;
|
||||
|
||||
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
|
||||
|
||||
fn approximate_eq(&self, other: &Self, error: f64) -> bool;
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool;
|
||||
|
||||
fn add_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
@@ -58,13 +61,13 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn div_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: f64);
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: f64);
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: f64);
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: f64);
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
@@ -90,33 +93,33 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
r
|
||||
}
|
||||
|
||||
fn add_scalar_mut(&mut self, scalar: f64) -> &Self;
|
||||
fn add_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
fn sub_scalar_mut(&mut self, scalar: f64) -> &Self;
|
||||
fn sub_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
fn mul_scalar_mut(&mut self, scalar: f64) -> &Self;
|
||||
fn mul_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
fn div_scalar_mut(&mut self, scalar: f64) -> &Self;
|
||||
fn div_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
fn add_scalar(&self, scalar: f64) -> Self{
|
||||
fn add_scalar(&self, scalar: T) -> Self{
|
||||
let mut r = self.clone();
|
||||
r.add_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
fn sub_scalar(&self, scalar: f64) -> Self{
|
||||
fn sub_scalar(&self, scalar: T) -> Self{
|
||||
let mut r = self.clone();
|
||||
r.sub_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
fn mul_scalar(&self, scalar: f64) -> Self{
|
||||
fn mul_scalar(&self, scalar: T) -> Self{
|
||||
let mut r = self.clone();
|
||||
r.mul_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
fn div_scalar(&self, scalar: f64) -> Self{
|
||||
fn div_scalar(&self, scalar: T) -> Self{
|
||||
let mut r = self.clone();
|
||||
r.div_scalar_mut(scalar);
|
||||
r
|
||||
@@ -126,11 +129,11 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn rand(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
fn norm2(&self) -> f64;
|
||||
fn norm2(&self) -> T;
|
||||
|
||||
fn norm(&self, p:f64) -> f64;
|
||||
fn norm(&self, p:T) -> T;
|
||||
|
||||
fn column_mean(&self) -> Vec<f64>;
|
||||
fn column_mean(&self) -> Vec<T>;
|
||||
|
||||
fn negative_mut(&mut self);
|
||||
|
||||
@@ -152,15 +155,15 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
result
|
||||
}
|
||||
|
||||
fn sum(&self) -> f64;
|
||||
fn sum(&self) -> T;
|
||||
|
||||
fn max_diff(&self, other: &Self) -> f64;
|
||||
fn max_diff(&self, other: &Self) -> T;
|
||||
|
||||
fn softmax_mut(&mut self);
|
||||
|
||||
fn pow_mut(&mut self, p: f64) -> &Self;
|
||||
fn pow_mut(&mut self, p: T) -> &Self;
|
||||
|
||||
fn pow(&mut self, p: f64) -> Self {
|
||||
fn pow(&mut self, p: T) -> Self {
|
||||
let mut result = self.clone();
|
||||
result.pow_mut(p);
|
||||
result
|
||||
@@ -168,31 +171,33 @@ pub trait BaseMatrix: Clone + Debug {
|
||||
|
||||
fn argmax(&self) -> Vec<usize>;
|
||||
|
||||
fn unique(&self) -> Vec<f64>;
|
||||
fn unique(&self) -> Vec<T>;
|
||||
|
||||
}
|
||||
|
||||
pub trait Matrix: BaseMatrix + SVDDecomposableMatrix + EVDDecomposableMatrix + QRDecomposableMatrix {}
|
||||
pub trait Matrix<T: FloatExt + Debug>: BaseMatrix<T> + SVDDecomposableMatrix<T> + EVDDecomposableMatrix<T> + QRDecomposableMatrix<T> {}
|
||||
|
||||
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> {
|
||||
pub fn row_iter<F: FloatExt + Debug, M: Matrix<F>>(m: &M) -> RowIter<F, M> {
|
||||
RowIter{
|
||||
m: m,
|
||||
pos: 0,
|
||||
max_pos: m.shape().0
|
||||
max_pos: m.shape().0,
|
||||
phantom: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RowIter<'a, M: Matrix> {
|
||||
pub struct RowIter<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
m: &'a M,
|
||||
pos: usize,
|
||||
max_pos: usize
|
||||
max_pos: usize,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
impl<'a, M: Matrix> Iterator for RowIter<'a, M> {
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> Iterator for RowIter<'a, T, M> {
|
||||
|
||||
type Item = Vec<f64>;
|
||||
type Item = Vec<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Vec<f64>> {
|
||||
fn next(&mut self) -> Option<Vec<T>> {
|
||||
let res;
|
||||
if self.pos < self.max_pos {
|
||||
res = Some(self.m.get_row_as_vec(self.pos))
|
||||
|
||||
Reference in New Issue
Block a user