use std::fmt::{Debug, Display}; use std::ops::Range; use crate::linalg::basic::arrays::{ Array as BaseArray, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2, }; use crate::linalg::traits::cholesky::CholeskyDecomposable; use crate::linalg::traits::evd::EVDDecomposable; use crate::linalg::traits::lu::LUDecomposable; use crate::linalg::traits::qr::QRDecomposable; use crate::linalg::traits::svd::SVDDecomposable; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr}; impl BaseArray for ArrayBase, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } fn shape(&self) -> (usize, usize) { (self.nrows(), self.ncols()) } fn is_empty(&self) -> bool { self.len() > 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { assert!( axis == 1 || axis == 0, "For two dimensional array `axis` should be either 0 or 1" ); match axis { 0 => Box::new(self.iter()), _ => Box::new( (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), ), } } } impl MutArray for ArrayBase, Ix2> { fn set(&mut self, pos: (usize, usize), x: T) { self[[pos.0, pos.1]] = x } fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let ptr = self.as_mut_ptr(); let stride = self.strides(); let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { 0 => Box::new(self.iter_mut()), _ => Box::new((0..self.ncols()).flat_map(move |c| { (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) })), } } } impl ArrayView2 for ArrayBase, Ix2> {} impl MutArrayView2 for ArrayBase, Ix2> {} impl BaseArray for ArrayView<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } fn shape(&self) -> (usize, usize) { (self.nrows(), self.ncols()) } fn is_empty(&self) -> bool { self.len() > 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { assert!( axis == 1 || axis == 0, "For two dimensional array `axis` should be either 0 or 1" ); match axis { 0 => Box::new(self.iter()), _ => Box::new( (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), ), } } } impl Array2 for ArrayBase, Ix2> { fn get_row<'a>(&'a self, row: usize) -> Box + 'a> { Box::new(self.row(row)) } fn get_col<'a>(&'a self, col: usize) -> Box + 'a> { Box::new(self.column(col)) } fn slice<'a>(&'a self, rows: Range, cols: Range) -> Box + 'a> { Box::new(self.slice(s![rows, cols])) } fn slice_mut<'a>( &'a mut self, rows: Range, cols: Range, ) -> Box + 'a> where Self: Sized, { Box::new(self.slice_mut(s![rows, cols])) } fn fill(nrows: usize, ncols: usize, value: T) -> Self { Array::from_elem([nrows, ncols], value) } fn from_iterator>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self { let a = Array::from_iter(iter.take(nrows * ncols)) .into_shape((nrows, ncols)) .unwrap(); match axis { 0 => a, _ => a.reversed_axes().into_shape((nrows, ncols)).unwrap(), } } fn transpose(&self) -> Self { self.t().to_owned() } } impl QRDecomposable for ArrayBase, Ix2> {} impl CholeskyDecomposable for ArrayBase, Ix2> {} impl EVDDecomposable for ArrayBase, Ix2> {} impl LUDecomposable for ArrayBase, Ix2> {} impl SVDDecomposable for ArrayBase, Ix2> {} impl ArrayView2 for ArrayView<'_, T, Ix2> {} impl BaseArray for ArrayViewMut<'_, T, Ix2> { fn get(&self, pos: (usize, usize)) -> &T { &self[[pos.0, pos.1]] } fn shape(&self) -> (usize, usize) { (self.nrows(), self.ncols()) } fn is_empty(&self) -> bool { self.len() > 0 } fn iterator<'b>(&'b self, axis: u8) -> Box + 'b> { assert!( axis == 1 || axis == 0, "For two dimensional array `axis` should be either 0 or 1" ); match axis { 0 => Box::new(self.iter()), _ => Box::new( (0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])), ), } } } impl MutArray for ArrayViewMut<'_, T, Ix2> { fn set(&mut self, pos: (usize, usize), x: T) { self[[pos.0, pos.1]] = x } fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box + 'b> { let ptr = self.as_mut_ptr(); let stride = self.strides(); let (rstride, cstride) = (stride[0] as usize, stride[1] as usize); match axis { 0 => Box::new(self.iter_mut()), _ => Box::new((0..self.ncols()).flat_map(move |c| { (0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) }) })), } } } impl MutArrayView2 for ArrayViewMut<'_, T, Ix2> {} impl ArrayView2 for ArrayViewMut<'_, T, Ix2> {} #[cfg(test)] mod tests { use super::*; use ndarray::{arr2, Array2 as NDArray2}; #[test] fn test_get_set() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); assert_eq!(*BaseArray::get(&a, (1, 1)), 5); a.set((1, 1), 9); assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]])); } #[test] fn test_iterator() { let a = arr2(&[[1, 2, 3], [4, 5, 6]]); let v: Vec = a.iterator(0).copied().collect(); assert_eq!(v, vec!(1, 2, 3, 4, 5, 6)); } #[test] fn test_mut_iterator() { let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]); a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i); assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]])); a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i); assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]])); } #[test] fn test_slice() { let x = arr2(&[[1, 2, 3], [4, 5, 6]]); let x_slice = Array2::slice(&x, 0..2, 1..2); assert_eq!((2, 1), x_slice.shape()); let v: Vec = x_slice.iterator(0).copied().collect(); assert_eq!(v, [2, 5]); } #[test] fn test_slice_iter() { let x = arr2(&[[1, 2, 3], [4, 5, 6]]); let x_slice = Array2::slice(&x, 0..2, 0..3); assert_eq!( x_slice.iterator(0).copied().collect::>(), vec![1, 2, 3, 4, 5, 6] ); assert_eq!( x_slice.iterator(1).copied().collect::>(), vec![1, 4, 2, 5, 3, 6] ); } #[test] fn test_slice_mut_iter() { let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]); { let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); x_slice .iterator_mut(0) .enumerate() .for_each(|(i, v)| *v = i); } assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]])); { let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3); x_slice .iterator_mut(1) .enumerate() .for_each(|(i, v)| *v = i); } assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]])); } #[test] fn test_c_from_iterator() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; let a: NDArray2 = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0); println!("{a}"); let a: NDArray2 = Array2::from_iterator(data.into_iter(), 4, 3, 1); println!("{a}"); } }