Implement a generic read_csv method (#147)

* feat: Add interface to build `Matrix` from rows.
* feat: Add option to derive `RealNumber` from string.
To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string.
* feat: Implement `Matrix::read_csv`.
This commit is contained in:
Tim Toebrock
2022-09-19 11:38:01 +02:00
committed by morenol
parent 1f2597be74
commit 2d75c2c405
7 changed files with 841 additions and 0 deletions
+100
View File
@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
use lu::LUDecomposableMatrix;
use qr::QRDecomposableMatrix;
use stats::{MatrixPreprocessing, MatrixStats};
use std::fs;
use svd::SVDDecomposableMatrix;
use crate::readers;
/// Column or row vector
pub trait BaseVector<T: RealNumber>: Clone + Debug {
/// Get an element of a vector
@@ -298,9 +301,60 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
/// represents a row in this matrix.
type RowVector: BaseVector<T> + Clone + Debug;
/// Create a matrix from a csv file.
/// ```
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
/// use smartcore::linalg::BaseMatrix;
/// use smartcore::readers::csv;
/// use std::fs;
///
/// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
/// assert_eq!(
/// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
/// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
/// );
/// fs::remove_file("identity.csv");
/// ```
fn from_csv(
path: &str,
definition: readers::csv::CSVDefinition<'_>,
) -> Result<Self, readers::ReadingError> {
readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
}
/// Transforms row vector `vec` into a 1xM matrix.
fn from_row_vector(vec: Self::RowVector) -> Self;
/// Transforms Vector of n rows with dimension m into
/// a matrix nxm.
/// ```
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
/// use crate::smartcore::linalg::BaseMatrix;
///
/// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
/// .unwrap();
///
/// assert_eq!(
/// eye,
/// DenseMatrix::from_2d_vec(&vec![
/// vec![1.0, 0.0, 0.0],
/// vec![0.0, 1.0, 0.0],
/// vec![0.0, 0.0, 1.0],
/// ])
/// );
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
if let Some(first_row) = rows.first().cloned() {
return Some(rows.iter().skip(1).cloned().fold(
Self::from_row_vector(first_row),
|current_matrix, new_row| {
current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
},
));
} else {
None
}
}
/// Transforms 1-d matrix of 1xM into a row vector.
fn to_row_vector(self) -> Self::RowVector;
@@ -782,4 +836,50 @@ mod tests {
"The second column was not extracted correctly"
);
}
mod matrix_from_csv {
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseMatrix;
use crate::readers::csv;
use crate::readers::io_testing;
use crate::readers::ReadingError;
#[test]
fn simple_read_default_csv() {
let test_csv_file = io_testing::TemporaryTextFile::new(
"'sepal.length','sepal.width','petal.length','petal.width'\n\
5.1,3.5,1.4,0.2\n\
4.9,3,1.4,0.2\n\
4.7,3.2,1.3,0.2",
);
assert_eq!(
DenseMatrix::<f64>::from_csv(
test_csv_file
.expect("Temporary file could not be written.")
.path(),
csv::CSVDefinition::default()
),
Ok(DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
]))
)
}
#[test]
fn non_existant_input_file() {
let potential_error =
DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
// The exact message is operating system dependant, therefore, I only test that the correct type
// error was returned.
assert_eq!(
potential_error.clone(),
Err(ReadingError::CouldNotReadFileSystem {
msg: String::from(potential_error.err().unwrap().message().unwrap())
})
)
}
}
}