feat: adds SVD
This commit is contained in:
+50
-18
@@ -1,12 +1,14 @@
|
||||
pub mod naive;
|
||||
pub mod svd;
|
||||
pub mod ndarray_bindings;
|
||||
|
||||
use std::ops::Range;
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub mod naive;
|
||||
pub mod ndarray_bindings;
|
||||
use svd::SVD;
|
||||
|
||||
pub trait Matrix: Clone + Debug {
|
||||
|
||||
type RowVector: Clone + Debug;
|
||||
type RowVector: Clone + Debug;
|
||||
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||
|
||||
@@ -16,25 +18,24 @@ pub trait Matrix: Clone + Debug {
|
||||
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
|
||||
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
|
||||
|
||||
fn to_vector(&self) -> Vec<Vec<f64>> {
|
||||
|
||||
let (n, _) = self.shape();
|
||||
let mut data = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
data.push(self.get_row_as_vec(i));
|
||||
}
|
||||
|
||||
data
|
||||
}
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, x: f64);
|
||||
|
||||
fn qr_solve_mut(&mut self, b: Self) -> Self;
|
||||
|
||||
fn svd_solve_mut(&mut self, b: Self) -> Self;
|
||||
fn svd(&self) -> SVD<Self>;
|
||||
|
||||
fn svd_solve_mut(&mut self, b: Self) -> Self {
|
||||
self.svd_solve(b)
|
||||
}
|
||||
|
||||
fn svd_solve(&self, b: Self) -> Self {
|
||||
|
||||
let svd = self.svd();
|
||||
svd.solve(b)
|
||||
|
||||
}
|
||||
|
||||
fn zeros(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
@@ -170,4 +171,35 @@ pub trait Matrix: Clone + Debug {
|
||||
|
||||
fn unique(&self) -> Vec<f64>;
|
||||
|
||||
}
|
||||
|
||||
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> {
|
||||
RowIter{
|
||||
m: m,
|
||||
pos: 0,
|
||||
max_pos: m.shape().0
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RowIter<'a, M: Matrix> {
|
||||
m: &'a M,
|
||||
pos: usize,
|
||||
max_pos: usize
|
||||
}
|
||||
|
||||
impl<'a, M: Matrix> Iterator for RowIter<'a, M> {
|
||||
|
||||
type Item = Vec<f64>;
|
||||
|
||||
fn next(&mut self) -> Option<Vec<f64>> {
|
||||
let res;
|
||||
if self.pos < self.max_pos {
|
||||
res = Some(self.m.get_row_as_vec(self.pos))
|
||||
} else {
|
||||
res = None
|
||||
}
|
||||
self.pos += 1;
|
||||
res
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user