Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows. * feat: Add option to derive `RealNumber` from string. To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string. * feat: Implement `Matrix::read_csv`.
This commit is contained in:
@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
|
||||
use lu::LUDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
use stats::{MatrixPreprocessing, MatrixStats};
|
||||
use std::fs;
|
||||
use svd::SVDDecomposableMatrix;
|
||||
|
||||
use crate::readers;
|
||||
|
||||
/// Column or row vector
|
||||
pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||
/// Get an element of a vector
|
||||
@@ -298,9 +301,60 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
/// represents a row in this matrix.
|
||||
type RowVector: BaseVector<T> + Clone + Debug;
|
||||
|
||||
/// Create a matrix from a csv file.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::BaseMatrix;
|
||||
/// use smartcore::readers::csv;
|
||||
/// use std::fs;
|
||||
///
|
||||
/// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||
/// assert_eq!(
|
||||
/// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
|
||||
/// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||
/// );
|
||||
/// fs::remove_file("identity.csv");
|
||||
/// ```
|
||||
fn from_csv(
|
||||
path: &str,
|
||||
definition: readers::csv::CSVDefinition<'_>,
|
||||
) -> Result<Self, readers::ReadingError> {
|
||||
readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
|
||||
}
|
||||
|
||||
/// Transforms row vector `vec` into a 1xM matrix.
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||
|
||||
/// Transforms Vector of n rows with dimension m into
|
||||
/// a matrix nxm.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use crate::smartcore::linalg::BaseMatrix;
|
||||
///
|
||||
/// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
|
||||
/// .unwrap();
|
||||
///
|
||||
/// assert_eq!(
|
||||
/// eye,
|
||||
/// DenseMatrix::from_2d_vec(&vec![
|
||||
/// vec![1.0, 0.0, 0.0],
|
||||
/// vec![0.0, 1.0, 0.0],
|
||||
/// vec![0.0, 0.0, 1.0],
|
||||
/// ])
|
||||
/// );
|
||||
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
|
||||
if let Some(first_row) = rows.first().cloned() {
|
||||
return Some(rows.iter().skip(1).cloned().fold(
|
||||
Self::from_row_vector(first_row),
|
||||
|current_matrix, new_row| {
|
||||
current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
|
||||
},
|
||||
));
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Transforms 1-d matrix of 1xM into a row vector.
|
||||
fn to_row_vector(self) -> Self::RowVector;
|
||||
|
||||
@@ -782,4 +836,50 @@ mod tests {
|
||||
"The second column was not extracted correctly"
|
||||
);
|
||||
}
|
||||
mod matrix_from_csv {
|
||||
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::readers::csv;
|
||||
use crate::readers::io_testing;
|
||||
use crate::readers::ReadingError;
|
||||
|
||||
#[test]
|
||||
fn simple_read_default_csv() {
|
||||
let test_csv_file = io_testing::TemporaryTextFile::new(
|
||||
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||
5.1,3.5,1.4,0.2\n\
|
||||
4.9,3,1.4,0.2\n\
|
||||
4.7,3.2,1.3,0.2",
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
DenseMatrix::<f64>::from_csv(
|
||||
test_csv_file
|
||||
.expect("Temporary file could not be written.")
|
||||
.path(),
|
||||
csv::CSVDefinition::default()
|
||||
),
|
||||
Ok(DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_existant_input_file() {
|
||||
let potential_error =
|
||||
DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
|
||||
// The exact message is operating system dependant, therefore, I only test that the correct type
|
||||
// error was returned.
|
||||
assert_eq!(
|
||||
potential_error.clone(),
|
||||
Err(ReadingError::CouldNotReadFileSystem {
|
||||
msg: String::from(potential_error.err().unwrap().message().unwrap())
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user