From 89a5136191522a2882ffa3f8a10bda92161024b5 Mon Sep 17 00:00:00 2001 From: morenol Date: Wed, 25 Nov 2020 14:39:02 -0400 Subject: [PATCH] Change implementation of to_row_vector for nalgebra (#34) * Add failing test * Change implementation of to_row_vector for nalgebra --- Cargo.toml | 4 ++-- src/linalg/naive/dense_matrix.rs | 6 ++++++ src/linalg/nalgebra_bindings.rs | 11 +++++++++-- src/linalg/ndarray_bindings.rs | 6 ++++++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 20eebf5..6e15f88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ datasets = [] [dependencies] ndarray = { version = "0.13", optional = true } -nalgebra = { version = "0.22.0", optional = true } +nalgebra = { version = "0.23.0", optional = true } num-traits = "0.2.12" num = "0.3.0" rand = "0.7.3" @@ -35,4 +35,4 @@ bincode = "1.3.1" [[bench]] name = "distance" -harness = false \ No newline at end of file +harness = false diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index 7486329..9279c3c 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -1064,6 +1064,12 @@ mod tests { ); } + #[test] + fn col_matrix_to_row_vector() { + let m: DenseMatrix = BaseMatrix::zeros(10, 1); + assert_eq!(m.to_row_vector().len(), 10) + } + #[test] fn iter() { let vec = vec![1., 2., 3., 4., 5., 6.]; diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index 8ddfdb6..da2ec05 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -185,14 +185,15 @@ impl BaseVector for MatrixMN { impl BaseMatrix for Matrix> { - type RowVector = MatrixMN; + type RowVector = RowDVector; fn from_row_vector(vec: Self::RowVector) -> Self { Matrix::from_rows(&[vec]) } fn to_row_vector(self) -> Self::RowVector { - self.row(0).into_owned() + let (nrows, ncols) = self.shape(); + self.reshape_generic(U1, Dynamic::new(nrows * ncols)) } fn get(&self, row: usize, col: usize) -> T { @@ -697,6 +698,12 @@ mod tests { assert_eq!(m.to_row_vector(), expected); } + #[test] + fn col_matrix_to_row_vector() { + let m: DMatrix = BaseMatrix::zeros(10, 1); + assert_eq!(m.to_row_vector().len(), 10) + } + #[test] fn get_row_col_as_vec() { let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index b5058ab..308e355 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -563,6 +563,12 @@ mod tests { ); } + #[test] + fn col_matrix_to_row_vector() { + let m: Array2 = BaseMatrix::zeros(10, 1); + assert_eq!(m.to_row_vector().len(), 10) + } + #[test] fn add_mut() { let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);