feat: adds SVD

This commit is contained in:
Volodymyr Orlov
2020-02-28 09:21:00 -08:00
parent fe50509d3b
commit 619560a1cd
5 changed files with 269 additions and 119 deletions
+50 -18
View File
@@ -1,12 +1,14 @@
pub mod naive;
pub mod svd;
pub mod ndarray_bindings;
use std::ops::Range;
use std::fmt::Debug;
pub mod naive;
pub mod ndarray_bindings;
use svd::SVD;
pub trait Matrix: Clone + Debug {
type RowVector: Clone + Debug;
type RowVector: Clone + Debug;
fn from_row_vector(vec: Self::RowVector) -> Self;
@@ -16,25 +18,24 @@ pub trait Matrix: Clone + Debug {
fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
fn to_vector(&self) -> Vec<Vec<f64>> {
let (n, _) = self.shape();
let mut data = Vec::new();
for i in 0..n {
data.push(self.get_row_as_vec(i));
}
data
}
fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
fn set(&mut self, row: usize, col: usize, x: f64);
fn qr_solve_mut(&mut self, b: Self) -> Self;
fn svd_solve_mut(&mut self, b: Self) -> Self;
fn svd(&self) -> SVD<Self>;
fn svd_solve_mut(&mut self, b: Self) -> Self {
self.svd_solve(b)
}
fn svd_solve(&self, b: Self) -> Self {
let svd = self.svd();
svd.solve(b)
}
fn zeros(nrows: usize, ncols: usize) -> Self;
@@ -170,4 +171,35 @@ pub trait Matrix: Clone + Debug {
fn unique(&self) -> Vec<f64>;
}
pub fn row_iter<M: Matrix>(m: &M) -> RowIter<M> {
RowIter{
m: m,
pos: 0,
max_pos: m.shape().0
}
}
pub struct RowIter<'a, M: Matrix> {
m: &'a M,
pos: usize,
max_pos: usize
}
impl<'a, M: Matrix> Iterator for RowIter<'a, M> {
type Item = Vec<f64>;
fn next(&mut self) -> Option<Vec<f64>> {
let res;
if self.pos < self.max_pos {
res = Some(self.m.get_row_as_vec(self.pos))
} else {
res = None
}
self.pos += 1;
res
}
}