diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 9f1697c..4fb3ebf 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -343,16 +343,19 @@ pub trait BaseMatrix: Clone + Debug { /// ]) /// ); fn from_row_vectors(rows: Vec) -> Option { - if let Some(first_row) = rows.first().cloned() { - return Some(rows.iter().skip(1).cloned().fold( - Self::from_row_vector(first_row), - |current_matrix, new_row| { - current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row)) - }, - )); - } else { - None + if rows.is_empty() { + return None; } + let n = rows.len(); + let m = rows[0].len(); + + let mut result = Self::zeros(n, m); + + for (row_idx, row) in rows.into_iter().enumerate() { + result.set_row(row_idx, row); + } + + Some(result) } /// Transforms 1-d matrix of 1xM into a row vector. @@ -376,6 +379,13 @@ pub trait BaseMatrix: Clone + Debug { /// * `result` - receiver for the row fn copy_row_as_vec(&self, row: usize, result: &mut Vec); + /// Set row vector at row `row_idx`. + fn set_row(&mut self, row_idx: usize, row: Self::RowVector) { + for (col_idx, val) in row.to_vec().into_iter().enumerate() { + self.set(row_idx, col_idx, val); + } + } + /// Get a vector with elements of the `col`'th column /// * `col` - column number fn get_col_as_vec(&self, col: usize) -> Vec; @@ -836,6 +846,32 @@ mod tests { "The second column was not extracted correctly" ); } + + #[test] + fn test_from_row_vectors_simple() { + let eye = DenseMatrix::from_row_vectors(vec![ + vec![1., 0., 0.], + vec![0., 1., 0.], + vec![0., 0., 1.], + ]) + .unwrap(); + assert_eq!( + eye, + DenseMatrix::from_2d_vec(&vec![ + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + ]) + ); + } + + #[test] + fn test_from_row_vectors_large() { + let eye = DenseMatrix::from_row_vectors(vec![vec![4.25; 5000]; 5000]).unwrap(); + + assert_eq!(eye.shape(), (5000, 5000)); + assert_eq!(eye.get_row(5), vec![4.25; 5000]); + } mod matrix_from_csv { use crate::linalg::naive::dense_matrix::DenseMatrix;