Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows. * feat: Add option to derive `RealNumber` from string. To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string. * feat: Implement `Matrix::read_csv`.
This commit is contained in:
@@ -95,6 +95,8 @@ pub mod neighbors;
|
|||||||
pub(crate) mod optimization;
|
pub(crate) mod optimization;
|
||||||
/// Preprocessing utilities
|
/// Preprocessing utilities
|
||||||
pub mod preprocessing;
|
pub mod preprocessing;
|
||||||
|
/// Reading in Data.
|
||||||
|
pub mod readers;
|
||||||
/// Support Vector Machines
|
/// Support Vector Machines
|
||||||
pub mod svm;
|
pub mod svm;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
|
|||||||
@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
|
|||||||
use lu::LUDecomposableMatrix;
|
use lu::LUDecomposableMatrix;
|
||||||
use qr::QRDecomposableMatrix;
|
use qr::QRDecomposableMatrix;
|
||||||
use stats::{MatrixPreprocessing, MatrixStats};
|
use stats::{MatrixPreprocessing, MatrixStats};
|
||||||
|
use std::fs;
|
||||||
use svd::SVDDecomposableMatrix;
|
use svd::SVDDecomposableMatrix;
|
||||||
|
|
||||||
|
use crate::readers;
|
||||||
|
|
||||||
/// Column or row vector
|
/// Column or row vector
|
||||||
pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||||
/// Get an element of a vector
|
/// Get an element of a vector
|
||||||
@@ -298,9 +301,60 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
|||||||
/// represents a row in this matrix.
|
/// represents a row in this matrix.
|
||||||
type RowVector: BaseVector<T> + Clone + Debug;
|
type RowVector: BaseVector<T> + Clone + Debug;
|
||||||
|
|
||||||
|
/// Create a matrix from a csv file.
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
/// use smartcore::linalg::BaseMatrix;
|
||||||
|
/// use smartcore::readers::csv;
|
||||||
|
/// use std::fs;
|
||||||
|
///
|
||||||
|
/// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||||
|
/// assert_eq!(
|
||||||
|
/// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
|
||||||
|
/// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||||
|
/// );
|
||||||
|
/// fs::remove_file("identity.csv");
|
||||||
|
/// ```
|
||||||
|
fn from_csv(
|
||||||
|
path: &str,
|
||||||
|
definition: readers::csv::CSVDefinition<'_>,
|
||||||
|
) -> Result<Self, readers::ReadingError> {
|
||||||
|
readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
|
||||||
|
}
|
||||||
|
|
||||||
/// Transforms row vector `vec` into a 1xM matrix.
|
/// Transforms row vector `vec` into a 1xM matrix.
|
||||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||||
|
|
||||||
|
/// Transforms Vector of n rows with dimension m into
|
||||||
|
/// a matrix nxm.
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
/// use crate::smartcore::linalg::BaseMatrix;
|
||||||
|
///
|
||||||
|
/// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
|
||||||
|
/// .unwrap();
|
||||||
|
///
|
||||||
|
/// assert_eq!(
|
||||||
|
/// eye,
|
||||||
|
/// DenseMatrix::from_2d_vec(&vec![
|
||||||
|
/// vec![1.0, 0.0, 0.0],
|
||||||
|
/// vec![0.0, 1.0, 0.0],
|
||||||
|
/// vec![0.0, 0.0, 1.0],
|
||||||
|
/// ])
|
||||||
|
/// );
|
||||||
|
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
|
||||||
|
if let Some(first_row) = rows.first().cloned() {
|
||||||
|
return Some(rows.iter().skip(1).cloned().fold(
|
||||||
|
Self::from_row_vector(first_row),
|
||||||
|
|current_matrix, new_row| {
|
||||||
|
current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
|
||||||
|
},
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Transforms 1-d matrix of 1xM into a row vector.
|
/// Transforms 1-d matrix of 1xM into a row vector.
|
||||||
fn to_row_vector(self) -> Self::RowVector;
|
fn to_row_vector(self) -> Self::RowVector;
|
||||||
|
|
||||||
@@ -782,4 +836,50 @@ mod tests {
|
|||||||
"The second column was not extracted correctly"
|
"The second column was not extracted correctly"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
mod matrix_from_csv {
|
||||||
|
|
||||||
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::linalg::BaseMatrix;
|
||||||
|
use crate::readers::csv;
|
||||||
|
use crate::readers::io_testing;
|
||||||
|
use crate::readers::ReadingError;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn simple_read_default_csv() {
|
||||||
|
let test_csv_file = io_testing::TemporaryTextFile::new(
|
||||||
|
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||||
|
5.1,3.5,1.4,0.2\n\
|
||||||
|
4.9,3,1.4,0.2\n\
|
||||||
|
4.7,3.2,1.3,0.2",
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
DenseMatrix::<f64>::from_csv(
|
||||||
|
test_csv_file
|
||||||
|
.expect("Temporary file could not be written.")
|
||||||
|
.path(),
|
||||||
|
csv::CSVDefinition::default()
|
||||||
|
),
|
||||||
|
Ok(DenseMatrix::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
]))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn non_existant_input_file() {
|
||||||
|
let potential_error =
|
||||||
|
DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
|
||||||
|
// The exact message is operating system dependant, therefore, I only test that the correct type
|
||||||
|
// error was returned.
|
||||||
|
assert_eq!(
|
||||||
|
potential_error.clone(),
|
||||||
|
Err(ReadingError::CouldNotReadFileSystem {
|
||||||
|
msg: String::from(potential_error.err().unwrap().message().unwrap())
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use rand::prelude::*;
|
|||||||
use std::fmt::{Debug, Display};
|
use std::fmt::{Debug, Display};
|
||||||
use std::iter::{Product, Sum};
|
use std::iter::{Product, Sum};
|
||||||
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
/// Defines real number
|
/// Defines real number
|
||||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||||
@@ -22,6 +23,7 @@ pub trait RealNumber:
|
|||||||
+ SubAssign
|
+ SubAssign
|
||||||
+ MulAssign
|
+ MulAssign
|
||||||
+ DivAssign
|
+ DivAssign
|
||||||
|
+ FromStr
|
||||||
{
|
{
|
||||||
/// Copy sign from `sign` - another real number
|
/// Copy sign from `sign` - another real number
|
||||||
fn copysign(self, sign: Self) -> Self;
|
fn copysign(self, sign: Self) -> Self;
|
||||||
@@ -154,4 +156,14 @@ mod tests {
|
|||||||
assert_eq!(41.0.sigmoid(), 1.);
|
assert_eq!(41.0.sigmoid(), 1.);
|
||||||
assert_eq!((-41.0).sigmoid(), 0.);
|
assert_eq!((-41.0).sigmoid(), 0.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `f32` can be parsed from its decimal text representation via `FromStr`.
#[test]
fn f32_from_string() {
    let expected: f32 = 1.111111;
    let parsed = f32::from_str("1.111111").unwrap();
    assert_eq!(parsed, expected)
}
|
||||||
|
|
||||||
|
/// `f64` can be parsed from its decimal text representation via `FromStr`.
#[test]
fn f64_from_string() {
    let expected: f64 = 1.111111111;
    let parsed = f64::from_str("1.111111111").unwrap();
    assert_eq!(parsed, expected)
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,487 @@
|
|||||||
|
//! This module contains utilities to read in matrices from csv files.
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::readers::csv;
|
||||||
|
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
//! use crate::smartcore::linalg::BaseMatrix;
|
||||||
|
//! use std::fs;
|
||||||
|
//!
|
||||||
|
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||||
|
//! assert_eq!(
|
||||||
|
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
//! fs::File::open("identity.csv").unwrap(),
|
||||||
|
//! csv::CSVDefinition::default()
|
||||||
|
//! )
|
||||||
|
//! .unwrap(),
|
||||||
|
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||||
|
//! );
|
||||||
|
//! fs::remove_file("identity.csv");
|
||||||
|
//! ```
|
||||||
|
use crate::linalg::{BaseMatrix, BaseVector};
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
use crate::readers::ReadingError;
|
||||||
|
use std::io::Read;
|
||||||
|
|
||||||
|
/// Define the structure of a CSV-file so that it can be read.
///
/// Start from [`CSVDefinition::default`] and adjust it with the `with_*`
/// methods to describe files that use another separator or header size.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CSVDefinition<'a> {
    /// How many rows does the header have?
    n_rows_header: usize,
    /// What separates the fields in your csv-file?
    field_seperator: &'a str,
}

impl<'a> CSVDefinition<'a> {
    /// Return this definition with the number of header rows set to `n_rows_header`.
    #[must_use]
    pub fn with_n_rows_header(mut self, n_rows_header: usize) -> Self {
        self.n_rows_header = n_rows_header;
        self
    }

    /// Return this definition with the field separator set to `field_seperator`.
    #[must_use]
    pub fn with_field_seperator(mut self, field_seperator: &'a str) -> Self {
        self.field_seperator = field_seperator;
        self
    }
}

impl<'a> Default for CSVDefinition<'a> {
    /// One header row and `,` as the field separator.
    fn default() -> Self {
        Self {
            n_rows_header: 1,
            field_seperator: ",",
        }
    }
}
|
||||||
|
|
||||||
|
/// Format definition for a single row in a csv file.
|
||||||
|
/// This is used internally to validate rows of the csv file and
|
||||||
|
/// be able to fail as early as possible.
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
|
struct CSVRowFormat<'a> {
|
||||||
|
field_seperator: &'a str,
|
||||||
|
n_fields: usize,
|
||||||
|
}
|
||||||
|
impl<'a> CSVRowFormat<'a> {
|
||||||
|
fn from_csv_definition(definition: &'a CSVDefinition<'_>, n_fields: usize) -> Self {
|
||||||
|
CSVRowFormat {
|
||||||
|
field_seperator: definition.field_seperator,
|
||||||
|
n_fields,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect the row format for the csv file from the first row.
|
||||||
|
fn detect_row_format<'a>(
|
||||||
|
csv_text: &'a str,
|
||||||
|
definition: &'a CSVDefinition<'_>,
|
||||||
|
) -> Result<CSVRowFormat<'a>, ReadingError> {
|
||||||
|
let first_line = csv_text
|
||||||
|
.lines()
|
||||||
|
.nth(definition.n_rows_header)
|
||||||
|
.ok_or(ReadingError::NoRowsProvided)?;
|
||||||
|
|
||||||
|
Ok(CSVRowFormat::from_csv_definition(
|
||||||
|
definition,
|
||||||
|
first_line.split(definition.field_seperator).count(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read in a matrix from a source that contains a csv file.
|
||||||
|
pub fn matrix_from_csv_source<T, RowVector, Matrix>(
|
||||||
|
source: impl Read,
|
||||||
|
definition: CSVDefinition<'_>,
|
||||||
|
) -> Result<Matrix, ReadingError>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
RowVector: BaseVector<T>,
|
||||||
|
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||||
|
{
|
||||||
|
let csv_text = read_string_from_source(source)?;
|
||||||
|
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
||||||
|
&csv_text,
|
||||||
|
&definition,
|
||||||
|
detect_row_format(&csv_text, &definition)?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
match Matrix::from_row_vectors(rows) {
|
||||||
|
Some(matrix) => Ok(matrix),
|
||||||
|
None => Err(ReadingError::NoRowsProvided),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given a string containing the contents of a csv file, extract its value
|
||||||
|
/// into row-vectors.
|
||||||
|
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
|
||||||
|
csv_text: &'a str,
|
||||||
|
definition: &'a CSVDefinition<'_>,
|
||||||
|
row_format: CSVRowFormat<'_>,
|
||||||
|
) -> Result<Vec<RowVector>, ReadingError>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
RowVector: BaseVector<T>,
|
||||||
|
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||||
|
{
|
||||||
|
csv_text
|
||||||
|
.lines()
|
||||||
|
.skip(definition.n_rows_header)
|
||||||
|
.enumerate()
|
||||||
|
.map(|(row_index, line)| {
|
||||||
|
enrich_reading_error(
|
||||||
|
extract_vector_from_csv_line(line, &row_format),
|
||||||
|
format!(", Row: {row_index}."),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, ReadingError>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a string from source implementing `Read`.
|
||||||
|
fn read_string_from_source(mut source: impl Read) -> Result<String, ReadingError> {
|
||||||
|
let mut string = String::new();
|
||||||
|
source.read_to_string(&mut string)?;
|
||||||
|
Ok(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract a vector from a single line of a csv file.
|
||||||
|
fn extract_vector_from_csv_line<T, RowVector>(
|
||||||
|
line: &str,
|
||||||
|
row_format: &CSVRowFormat<'_>,
|
||||||
|
) -> Result<RowVector, ReadingError>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
RowVector: BaseVector<T>,
|
||||||
|
{
|
||||||
|
validate_csv_row(line, row_format)?;
|
||||||
|
let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
|
||||||
|
Ok(BaseVector::from_array(&extracted_fields[..]))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the fields from a string containing the row of a csv file.
|
||||||
|
fn extract_fields_from_csv_row<T>(
|
||||||
|
row: &str,
|
||||||
|
row_format: &CSVRowFormat<'_>,
|
||||||
|
) -> Result<Vec<T>, ReadingError>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
{
|
||||||
|
row.split(row_format.field_seperator)
|
||||||
|
.enumerate()
|
||||||
|
.map(|(field_number, csv_field)| {
|
||||||
|
enrich_reading_error(
|
||||||
|
extract_value_from_csv_field(csv_field.trim()),
|
||||||
|
format!(" Column: {field_number}"),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<T>, ReadingError>>()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ensure that a string containing a csv row conforms to a specified row format.
|
||||||
|
fn validate_csv_row<'a>(row: &'a str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> {
|
||||||
|
let actual_number_of_fields = row.split(row_format.field_seperator).count();
|
||||||
|
if row_format.n_fields == actual_number_of_fields {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(ReadingError::InvalidRow {
|
||||||
|
msg: format!(
|
||||||
|
"{} fields found but expected {}",
|
||||||
|
actual_number_of_fields, row_format.n_fields
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add additional text to the message of an error.
|
||||||
|
/// In csv reading it is used to add the line-number / row-number
|
||||||
|
/// The error occured that is only known in the functions above.
|
||||||
|
fn enrich_reading_error<T>(
|
||||||
|
result: Result<T, ReadingError>,
|
||||||
|
additional_text: String,
|
||||||
|
) -> Result<T, ReadingError> {
|
||||||
|
result.map_err(|error| ReadingError::InvalidField {
|
||||||
|
msg: format!(
|
||||||
|
"{}{additional_text}",
|
||||||
|
error.message().unwrap_or("Could not serialize value")
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the value from a single csv field.
|
||||||
|
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
{
|
||||||
|
// By default, `FromStr::Err` does not implement `Debug`.
|
||||||
|
// Restricting it in the library leads to many breaking
|
||||||
|
// changes therefore I have to reconstruct my own, printable
|
||||||
|
// error as good as possible.
|
||||||
|
match value_string.parse::<T>().ok() {
|
||||||
|
Some(value) => Ok(value),
|
||||||
|
None => Err(ReadingError::InvalidField {
|
||||||
|
msg: format!("Value '{}' could not be read.", value_string,),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
mod matrix_from_csv_source {
|
||||||
|
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
|
||||||
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
use crate::readers::{csv::matrix_from_csv_source, io_testing};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn read_simple_string() {
|
||||||
|
assert_eq!(
|
||||||
|
read_string_from_source(io_testing::TestingDataSource::new("test-string")),
|
||||||
|
Ok(String::from("test-string"))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn read_simple_csv() {
|
||||||
|
assert_eq!(
|
||||||
|
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
io_testing::TestingDataSource::new(
|
||||||
|
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||||
|
5.1,3.5,1.4,0.2\n\
|
||||||
|
4.9,3.0,1.4,0.2\n\
|
||||||
|
4.7,3.2,1.3,0.2",
|
||||||
|
),
|
||||||
|
CSVDefinition::default(),
|
||||||
|
),
|
||||||
|
Ok(DenseMatrix::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
]))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn read_csv_semicolon_as_seperator() {
|
||||||
|
assert_eq!(
|
||||||
|
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
io_testing::TestingDataSource::new(
|
||||||
|
"'sepal.length';'sepal.width';'petal.length';'petal.width'\n\
|
||||||
|
'Length of sepals.';'Width of Sepals';'Length of petals';'Width of petals'\n\
|
||||||
|
5.1;3.5;1.4;0.2\n\
|
||||||
|
4.9;3.0;1.4;0.2\n\
|
||||||
|
4.7;3.2;1.3;0.2",
|
||||||
|
),
|
||||||
|
CSVDefinition {
|
||||||
|
n_rows_header: 2,
|
||||||
|
field_seperator: ";"
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Ok(DenseMatrix::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
]))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn error_in_colum_1_row_1() {
|
||||||
|
assert_eq!(
|
||||||
|
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
io_testing::TestingDataSource::new(
|
||||||
|
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||||
|
5.1,3.5,1.4,0.2\n\
|
||||||
|
4.9,invalid,1.4,0.2\n\
|
||||||
|
4.7,3.2,1.3,0.2",
|
||||||
|
),
|
||||||
|
CSVDefinition::default(),
|
||||||
|
),
|
||||||
|
Err(ReadingError::InvalidField {
|
||||||
|
msg: String::from("Value 'invalid' could not be read. Column: 1, Row: 1.")
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn different_number_of_columns() {
|
||||||
|
assert_eq!(
|
||||||
|
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
io_testing::TestingDataSource::new(
|
||||||
|
"'field_1','field_2'\n\
|
||||||
|
5.1,3.5\n\
|
||||||
|
4.9,3.0,1.4",
|
||||||
|
),
|
||||||
|
CSVDefinition::default(),
|
||||||
|
),
|
||||||
|
Err(ReadingError::InvalidField {
|
||||||
|
msg: String::from("3 fields found but expected 2, Row: 1.")
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod extract_row_vectors_from_csv_text {
|
||||||
|
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
|
||||||
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn read_default_csv() {
|
||||||
|
assert_eq!(
|
||||||
|
extract_row_vectors_from_csv_text::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
|
"column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0",
|
||||||
|
&CSVDefinition::default(),
|
||||||
|
CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 3,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Ok(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod test_validate_csv_row {
|
||||||
|
use super::super::{validate_csv_row, CSVRowFormat, ReadingError};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_row_with_comma() {
|
||||||
|
assert_eq!(
|
||||||
|
validate_csv_row(
|
||||||
|
"1.0, 2.0, 3.0",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 3,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Ok(())
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn valid_row_with_semicolon() {
|
||||||
|
assert_eq!(
|
||||||
|
validate_csv_row(
|
||||||
|
"1.0; 2.0; 3.0; 4.0",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ";",
|
||||||
|
n_fields: 4,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Ok(())
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn invalid_number_of_fields() {
|
||||||
|
assert_eq!(
|
||||||
|
validate_csv_row(
|
||||||
|
"1.0; 2.0; 3.0; 4.0",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ";",
|
||||||
|
n_fields: 3,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Err(ReadingError::InvalidRow {
|
||||||
|
msg: String::from("4 fields found but expected 3")
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod extract_fields_from_csv_row {
|
||||||
|
use super::super::{extract_fields_from_csv_row, CSVRowFormat};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn read_four_values_from_csv_row() {
|
||||||
|
assert_eq!(
|
||||||
|
extract_fields_from_csv_row(
|
||||||
|
"1.0; 2.0; 3.0; 4.0",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ";",
|
||||||
|
n_fields: 4
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Ok(vec![1.0, 2.0, 3.0, 4.0])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod detect_row_format {
|
||||||
|
use super::super::{detect_row_format, CSVDefinition, CSVRowFormat, ReadingError};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_2_fields_with_header() {
|
||||||
|
assert_eq!(
|
||||||
|
detect_row_format(
|
||||||
|
"header-1\nheader-2\n1.0,2.0",
|
||||||
|
&CSVDefinition {
|
||||||
|
n_rows_header: 2,
|
||||||
|
field_seperator: ","
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.expect("The row format should be detectable with this input."),
|
||||||
|
CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 2
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn detect_3_fields_no_header() {
|
||||||
|
assert_eq!(
|
||||||
|
detect_row_format(
|
||||||
|
"1.0,2.0,3.0",
|
||||||
|
&CSVDefinition {
|
||||||
|
n_rows_header: 0,
|
||||||
|
field_seperator: ","
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.expect("The row format should be detectable with this input."),
|
||||||
|
CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 3
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn detect_no_rows_provided() {
|
||||||
|
assert_eq!(
|
||||||
|
detect_row_format("header\n", &CSVDefinition::default()),
|
||||||
|
Err(ReadingError::NoRowsProvided)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod extract_value_from_csv_field {
|
||||||
|
use super::super::extract_value_from_csv_field;
|
||||||
|
use crate::readers::ReadingError;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deserialize_f64_from_floating_point() {
|
||||||
|
assert_eq!(extract_value_from_csv_field::<f64>("1.0"), Ok(1.0))
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn deserialize_f64_from_negative_floating_point() {
|
||||||
|
assert_eq!(extract_value_from_csv_field::<f64>("-1.0"), Ok(-1.0))
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn deserialize_f64_from_non_floating_point() {
|
||||||
|
assert_eq!(extract_value_from_csv_field::<f64>("1"), Ok(1.0))
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn cant_deserialize_f64_from_string() {
|
||||||
|
assert_eq!(
|
||||||
|
extract_value_from_csv_field::<f64>("Test"),
|
||||||
|
Err(ReadingError::InvalidField {
|
||||||
|
msg: String::from("Value 'Test' could not be read.")
|
||||||
|
},)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn deserialize_f32_from_non_floating_point() {
|
||||||
|
assert_eq!(extract_value_from_csv_field::<f32>("12.0"), Ok(12.0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mod extract_vector_from_csv_line {
|
||||||
|
use super::super::{extract_vector_from_csv_line, CSVRowFormat, ReadingError};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_five_floating_point_values() {
|
||||||
|
assert_eq!(
|
||||||
|
extract_vector_from_csv_line::<f64, Vec<f64>>(
|
||||||
|
"-1.0,2.0,100.0,12",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 4
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Ok(vec![-1.0, 2.0, 100.0, 12.0])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn cannot_extract_second_value() {
|
||||||
|
assert_eq!(
|
||||||
|
extract_vector_from_csv_line::<f64, Vec<f64>>(
|
||||||
|
"-1.0,test,100.0,12",
|
||||||
|
&CSVRowFormat {
|
||||||
|
field_seperator: ",",
|
||||||
|
n_fields: 4
|
||||||
|
}
|
||||||
|
),
|
||||||
|
Err(ReadingError::InvalidField {
|
||||||
|
msg: String::from("Value 'test' could not be read. Column: 1")
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
//! The module contains the errors that can happen in the `readers` folder and
|
||||||
|
//! utility functions.
|
||||||
|
|
||||||
|
/// Error wrapping all failures that can happen during loading from file.
///
/// It implements `std::fmt::Display` and `std::error::Error`, so it can be
/// propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReadingError {
    /// The file could not be read from the file-system.
    CouldNotReadFileSystem {
        /// More details about the specific file-system error
        /// that occurred.
        msg: String,
    },
    /// No rows exist in the CSV-file.
    NoRowsProvided,
    /// A field in the csv file could not be read.
    InvalidField {
        /// More details about what field could not be
        /// read and where it happened.
        msg: String,
    },
    /// A row from the csv is invalid.
    InvalidRow {
        /// More details about what row could not be read
        /// and where it happened.
        msg: String,
    },
}

impl std::fmt::Display for ReadingError {
    /// Write a short description of the variant followed by its message, if any.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ReadingError::CouldNotReadFileSystem { msg } => {
                write!(formatter, "Could not read from the file-system: {}", msg)
            }
            ReadingError::NoRowsProvided => formatter.write_str("No rows were provided."),
            ReadingError::InvalidField { msg } => write!(formatter, "Invalid field: {}", msg),
            ReadingError::InvalidRow { msg } => write!(formatter, "Invalid row: {}", msg),
        }
    }
}

impl std::error::Error for ReadingError {}

impl From<std::io::Error> for ReadingError {
    /// Wrap the text of an io error into `CouldNotReadFileSystem`.
    fn from(io_error: std::io::Error) -> Self {
        Self::CouldNotReadFileSystem {
            msg: io_error.to_string(),
        }
    }
}

impl ReadingError {
    /// Extract the error-message from a `ReadingError`.
    /// Returns `None` for `NoRowsProvided`, the only variant without a message.
    pub fn message(&self) -> Option<&str> {
        match self {
            ReadingError::InvalidField { msg } => Some(msg),
            ReadingError::InvalidRow { msg } => Some(msg),
            ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
            ReadingError::NoRowsProvided => None,
        }
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::ReadingError;
    use std::io;

    /// An io error becomes `CouldNotReadFileSystem` carrying the text of the io error.
    #[test]
    fn reading_error_from_io_error() {
        let parsed_reading_error = ReadingError::from(io::Error::new(
            io::ErrorKind::AlreadyExists,
            "File already exists .",
        ));
        assert_eq!(
            parsed_reading_error,
            ReadingError::CouldNotReadFileSystem {
                msg: String::from("File already exists .")
            }
        )
    }

    /// `message` hands out the text stored in the variant.
    #[test]
    fn extract_message_from_reading_error() {
        let error_content = "Path does not exist";
        let error = ReadingError::CouldNotReadFileSystem {
            msg: String::from(error_content),
        };
        assert_eq!(error.message(), Some(error_content))
    }
}
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
//! This module contains functionality to test IO. It has both functions that write
|
||||||
|
//! to the file-system for end-to-end tests, but also abstractions to avoid this by
|
||||||
|
//! reading from strings instead.
|
||||||
|
use rand::distributions::{Alphanumeric, DistString};
|
||||||
|
use std::fs;
|
||||||
|
use std::io::Bytes;
|
||||||
|
use std::io::Read;
|
||||||
|
use std::io::{Chain, IoSliceMut, Take, Write};
|
||||||
|
|
||||||
|
/// Writing out a temporary csv file at a random location and cleaning
|
||||||
|
/// it up on `Drop`.
|
||||||
|
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||||
|
pub struct TemporaryTextFile {
|
||||||
|
random_path: String,
|
||||||
|
}
|
||||||
|
impl TemporaryTextFile {
|
||||||
|
pub fn new(contents: &str) -> std::io::Result<Self> {
|
||||||
|
let test_text_file = TemporaryTextFile {
|
||||||
|
random_path: Alphanumeric.sample_string(&mut rand::thread_rng(), 16),
|
||||||
|
};
|
||||||
|
string_to_file(contents, &test_text_file.random_path)?;
|
||||||
|
Ok(test_text_file)
|
||||||
|
}
|
||||||
|
pub fn path(&self) -> &str {
|
||||||
|
&self.random_path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// On `Drop` we cleanup the file-system by remove the file.
|
||||||
|
impl Drop for TemporaryTextFile {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
fs::remove_file(self.path())
|
||||||
|
.unwrap_or_else(|_| panic!("Could not clean up temporary file {}.", self.random_path));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Create (or truncate) the file at `file_path` and write `string` into it.
///
/// # Errors
/// Returns the `std::io::Error` of the failed creation or write.
pub(crate) fn string_to_file(string: &str, file_path: &str) -> std::io::Result<()> {
    fs::File::create(file_path).and_then(|mut file| file.write_all(string.as_bytes()))
}
|
||||||
|
|
||||||
|
/// This is used as an alternative struct that implements `Read` so
/// that instead of reading from the file-system, we can test the same
/// functionality without any file-system interaction.
pub(crate) struct TestingDataSource {
    text: String,
}

impl TestingDataSource {
    /// Create a source that hands out a copy of `text`.
    pub(crate) fn new(text: &str) -> Self {
        Self {
            text: String::from(text),
        }
    }
}

/// This is the trait that `fs::File` also implements, so by implementing
/// it for `TestingDataSource` we can test functionality that reads from
/// file in a more lightweight way.
///
/// Only `read_to_string` is supported; every other method panics through
/// `unimplemented!()`.
impl Read for TestingDataSource {
    fn read(&mut self, _buf: &mut [u8]) -> Result<usize, std::io::Error> {
        unimplemented!()
    }

    fn read_vectored(&mut self, _bufs: &mut [IoSliceMut<'_>]) -> Result<usize, std::io::Error> {
        unimplemented!()
    }

    fn read_to_end(&mut self, _buf: &mut Vec<u8>) -> Result<usize, std::io::Error> {
        unimplemented!()
    }

    /// Append the whole text to `buf` and report the number of bytes
    /// appended, as the contract of `Read::read_to_string` demands.
    fn read_to_string(&mut self, buf: &mut String) -> Result<usize, std::io::Error> {
        buf.push_str(&self.text);
        Ok(self.text.len())
    }

    fn read_exact(&mut self, _buf: &mut [u8]) -> Result<(), std::io::Error> {
        unimplemented!()
    }

    fn by_ref(&mut self) -> &mut Self
    where
        Self: Sized,
    {
        unimplemented!()
    }

    fn bytes(self) -> Bytes<Self>
    where
        Self: Sized,
    {
        unimplemented!()
    }

    fn chain<R: Read>(self, _next: R) -> Chain<Self, R>
    where
        Self: Sized,
    {
        unimplemented!()
    }

    fn take(self, _limit: u64) -> Take<Self>
    where
        Self: Sized,
    {
        unimplemented!()
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod test {
    use super::TestingDataSource;
    use super::{string_to_file, TemporaryTextFile};
    use std::fs;
    use std::io::Read;
    use std::path;

    /// The file exists with the given contents while the object is alive
    /// and is gone once the object has been dropped.
    #[test]
    fn test_temporary_text_file() {
        let path_of_temporary_file = {
            let hello_world_file = TemporaryTextFile::new("Hello World!")
                .expect("`hello_world_file` should be able to write file.");
            let written_path = String::from(hello_world_file.path());

            let written_contents = fs::read_to_string(&written_path).expect(
                "This field should have been written by the `hello_world_file`-object.",
            );
            assert_eq!(written_contents, "Hello World!");

            written_path
        };
        // `hello_world_file` went out of scope with the block above, which
        // should have removed the file.
        assert!(!path::Path::new(&path_of_temporary_file).exists())
    }

    /// Contents written with `string_to_file` can be read back unchanged.
    #[test]
    fn test_string_to_file() {
        let path_of_test_file = "test.file";
        let contents_of_test_file = "Hello IO-World";

        string_to_file(contents_of_test_file, path_of_test_file)
            .expect("The file should have been written out.");

        let contents_read_back = fs::read_to_string(path_of_test_file)
            .expect("The file we test for should have been written.");
        assert_eq!(contents_read_back, String::from(contents_of_test_file));

        // Cleanup the temporary file.
        fs::remove_file(path_of_test_file)
            .expect("The test file should exist before and be removed here.");
    }

    /// `read_to_string` copies the text of the source into the buffer.
    #[test]
    fn read_from_testing_data_source() {
        let test_data_content = "Hello non-IO world!";
        let mut data_source = TestingDataSource::new(test_data_content);
        let mut test_buffer = String::new();

        data_source
            .read_to_string(&mut test_buffer)
            .expect("Text should have been written to buffer `test_buffer`.");
        assert_eq!(test_buffer, test_data_content)
    }
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
/// Read in from csv.
|
||||||
|
pub mod csv;
|
||||||
|
|
||||||
|
/// Error definition for readers.
|
||||||
|
mod error;
|
||||||
|
/// Utilities to help with testing functionality using IO.
|
||||||
|
/// Only meant for internal usage.
|
||||||
|
#[cfg(test)]
|
||||||
|
pub(crate) mod io_testing;
|
||||||
|
|
||||||
|
pub use error::ReadingError;
|
||||||
Reference in New Issue
Block a user