//! This module contains utitilities to read-in matrices from csv files. //! ```rust //! use smartcore::readers::csv; //! use smartcore::linalg::basic::matrix::DenseMatrix; //! use std::fs; //! //! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0"); //! //! let mtx = csv::matrix_from_csv_source::, DenseMatrix<_>>( //! fs::File::open("identity.csv").unwrap(), //! csv::CSVDefinition::default() //! ) //! .unwrap(); //! println!("{}", &mtx); //! //! fs::remove_file("identity.csv"); //! ``` use crate::linalg::basic::arrays::{Array1, Array2}; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; use crate::readers::ReadingError; use std::io::Read; /// Define the structure of a CSV-file so that it can be read. #[derive(Clone, Debug, PartialEq, Eq)] pub struct CSVDefinition<'a> { /// How many rows does the header have? n_rows_header: usize, /// What seperates the fields in your csv-file? field_seperator: &'a str, } impl Default for CSVDefinition<'_> { fn default() -> Self { Self { n_rows_header: 1, field_seperator: ",", } } } /// Format definition for a single row in a csv file. /// This is used internally to validate rows of the csv file and /// be able to fail as early as possible. #[derive(Clone, Debug, PartialEq, Eq)] struct CSVRowFormat<'a> { field_seperator: &'a str, n_fields: usize, } impl<'a> CSVRowFormat<'a> { fn from_csv_definition(definition: &'a CSVDefinition<'_>, n_fields: usize) -> Self { CSVRowFormat { field_seperator: definition.field_seperator, n_fields, } } } /// Detect the row format for the csv file from the first row. fn detect_row_format<'a>( csv_text: &'a str, definition: &'a CSVDefinition<'_>, ) -> Result, ReadingError> { let first_line = csv_text .lines() .nth(definition.n_rows_header) .ok_or(ReadingError::NoRowsProvided)?; Ok(CSVRowFormat::from_csv_definition( definition, first_line.split(definition.field_seperator).count(), )) } /// Read in a matrix from a source that contains a csv file. pub fn matrix_from_csv_source( source: impl Read, definition: CSVDefinition<'_>, ) -> Result where T: Number + RealNumber + std::str::FromStr, RowVector: Array1, Matrix: Array2, { let csv_text = read_string_from_source(source)?; let rows: Vec> = extract_row_vectors_from_csv_text( &csv_text, &definition, detect_row_format(&csv_text, &definition)?, )?; let nrows = rows.len(); let ncols = rows[0].len(); // TODO: try to return ReadingError let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0); if array2.shape() != (nrows, ncols) { Err(ReadingError::ShapesDoNotMatch { msg: String::new() }) } else { Ok(array2) } } /// Given a string containing the contents of a csv file, extract its value /// into row-vectors. fn extract_row_vectors_from_csv_text<'a, T: Number + RealNumber + std::str::FromStr>( csv_text: &'a str, definition: &'a CSVDefinition<'_>, row_format: CSVRowFormat<'_>, ) -> Result>, ReadingError> { csv_text .lines() .skip(definition.n_rows_header) .enumerate() .map(|(row_index, line)| { enrich_reading_error( extract_vector_from_csv_line(line, &row_format), format!(", Row: {row_index}."), ) }) .collect::, ReadingError>>() } /// Read a string from source implementing `Read`. fn read_string_from_source(mut source: impl Read) -> Result { let mut string = String::new(); source.read_to_string(&mut string)?; Ok(string) } /// Extract a vector from a single line of a csv file. fn extract_vector_from_csv_line( line: &str, row_format: &CSVRowFormat<'_>, ) -> Result where T: Number + RealNumber + std::str::FromStr, RowVector: Array1, { validate_csv_row(line, row_format)?; let extracted_fields: Vec = extract_fields_from_csv_row(line, row_format)?; Ok(Array1::from_vec_slice(&extracted_fields[..])) } /// Extract the fields from a string containing the row of a csv file. fn extract_fields_from_csv_row( row: &str, row_format: &CSVRowFormat<'_>, ) -> Result, ReadingError> where T: Number + RealNumber + std::str::FromStr, { row.split(row_format.field_seperator) .enumerate() .map(|(field_number, csv_field)| { enrich_reading_error( extract_value_from_csv_field(csv_field.trim()), format!(" Column: {field_number}"), ) }) .collect::, ReadingError>>() } /// Ensure that a string containing a csv row conforms to a specified row format. fn validate_csv_row(row: &str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> { let actual_number_of_fields = row.split(row_format.field_seperator).count(); if row_format.n_fields == actual_number_of_fields { Ok(()) } else { Err(ReadingError::InvalidRow { msg: format!( "{} fields found but expected {}", actual_number_of_fields, row_format.n_fields ), }) } } /// Add additional text to the message of an error. /// In csv reading it is used to add the line-number / row-number /// The error occured that is only known in the functions above. fn enrich_reading_error( result: Result, additional_text: String, ) -> Result { result.map_err(|error| ReadingError::InvalidField { msg: format!( "{}{additional_text}", error.message().unwrap_or("Could not serialize value") ), }) } /// Extract the value from a single csv field. fn extract_value_from_csv_field(value_string: &str) -> Result where T: Number + RealNumber + std::str::FromStr, { // By default, `FromStr::Err` does not implement `Debug`. // Restricting it in the library leads to many breaking // changes therefore I have to reconstruct my own, printable // error as good as possible. match value_string.parse::().ok() { Some(value) => Ok(value), None => Err(ReadingError::InvalidField { msg: format!("Value '{value_string}' could not be read.",), }), } } #[cfg(test)] mod tests { mod matrix_from_csv_source { use super::super::{read_string_from_source, CSVDefinition, ReadingError}; use crate::linalg::basic::matrix::DenseMatrix; use crate::readers::{csv::matrix_from_csv_source, io_testing}; #[test] fn read_simple_string() { assert_eq!( read_string_from_source(io_testing::TestingDataSource::new("test-string")), Ok(String::from("test-string")) ) } #[test] fn read_simple_csv() { assert_eq!( matrix_from_csv_source::, DenseMatrix<_>>( io_testing::TestingDataSource::new( "'sepal.length','sepal.width','petal.length','petal.width'\n\ 5.1,3.5,1.4,0.2\n\ 4.9,3.0,1.4,0.2\n\ 4.7,3.2,1.3,0.2", ), CSVDefinition::default(), ), Ok(DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], ]) .unwrap()) ) } #[test] fn read_csv_semicolon_as_seperator() { assert_eq!( matrix_from_csv_source::, DenseMatrix<_>>( io_testing::TestingDataSource::new( "'sepal.length';'sepal.width';'petal.length';'petal.width'\n\ 'Length of sepals.';'Width of Sepals';'Length of petals';'Width of petals'\n\ 5.1;3.5;1.4;0.2\n\ 4.9;3.0;1.4;0.2\n\ 4.7;3.2;1.3;0.2", ), CSVDefinition { n_rows_header: 2, field_seperator: ";" }, ), Ok(DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2], &[4.7, 3.2, 1.3, 0.2], ]).unwrap()) ) } #[test] fn error_in_colum_1_row_1() { assert_eq!( matrix_from_csv_source::, DenseMatrix<_>>( io_testing::TestingDataSource::new( "'sepal.length','sepal.width','petal.length','petal.width'\n\ 5.1,3.5,1.4,0.2\n\ 4.9,invalid,1.4,0.2\n\ 4.7,3.2,1.3,0.2", ), CSVDefinition::default(), ), Err(ReadingError::InvalidField { msg: String::from("Value 'invalid' could not be read. Column: 1, Row: 1.") }) ) } #[test] fn different_number_of_columns() { assert_eq!( matrix_from_csv_source::, DenseMatrix<_>>( io_testing::TestingDataSource::new( "'field_1','field_2'\n\ 5.1,3.5\n\ 4.9,3.0,1.4", ), CSVDefinition::default(), ), Err(ReadingError::InvalidField { msg: String::from("3 fields found but expected 2, Row: 1.") }) ) } } mod extract_row_vectors_from_csv_text { use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat}; #[test] fn read_default_csv() { assert_eq!( extract_row_vectors_from_csv_text::( "column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0", &CSVDefinition::default(), CSVRowFormat { field_seperator: ",", n_fields: 3, }, ), Ok(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]) ); } } mod test_validate_csv_row { use super::super::{validate_csv_row, CSVRowFormat, ReadingError}; #[test] fn valid_row_with_comma() { assert_eq!( validate_csv_row( "1.0, 2.0, 3.0", &CSVRowFormat { field_seperator: ",", n_fields: 3, }, ), Ok(()) ) } #[test] fn valid_row_with_semicolon() { assert_eq!( validate_csv_row( "1.0; 2.0; 3.0; 4.0", &CSVRowFormat { field_seperator: ";", n_fields: 4, }, ), Ok(()) ) } #[test] fn invalid_number_of_fields() { assert_eq!( validate_csv_row( "1.0; 2.0; 3.0; 4.0", &CSVRowFormat { field_seperator: ";", n_fields: 3, }, ), Err(ReadingError::InvalidRow { msg: String::from("4 fields found but expected 3") }) ) } } mod extract_fields_from_csv_row { use super::super::{extract_fields_from_csv_row, CSVRowFormat}; #[test] fn read_four_values_from_csv_row() { assert_eq!( extract_fields_from_csv_row( "1.0; 2.0; 3.0; 4.0", &CSVRowFormat { field_seperator: ";", n_fields: 4 } ), Ok(vec![1.0, 2.0, 3.0, 4.0]) ) } } mod detect_row_format { use super::super::{detect_row_format, CSVDefinition, CSVRowFormat, ReadingError}; #[test] fn detect_2_fields_with_header() { assert_eq!( detect_row_format( "header-1\nheader-2\n1.0,2.0", &CSVDefinition { n_rows_header: 2, field_seperator: "," } ) .expect("The row format should be detectable with this input."), CSVRowFormat { field_seperator: ",", n_fields: 2 } ) } #[test] fn detect_3_fields_no_header() { assert_eq!( detect_row_format( "1.0,2.0,3.0", &CSVDefinition { n_rows_header: 0, field_seperator: "," } ) .expect("The row format should be detectable with this input."), CSVRowFormat { field_seperator: ",", n_fields: 3 } ) } #[test] fn detect_no_rows_provided() { assert_eq!( detect_row_format("header\n", &CSVDefinition::default()), Err(ReadingError::NoRowsProvided) ) } } mod extract_value_from_csv_field { use super::super::extract_value_from_csv_field; use crate::readers::ReadingError; #[test] fn deserialize_f64_from_floating_point() { assert_eq!(extract_value_from_csv_field::("1.0"), Ok(1.0)) } #[test] fn deserialize_f64_from_negative_floating_point() { assert_eq!(extract_value_from_csv_field::("-1.0"), Ok(-1.0)) } #[test] fn deserialize_f64_from_non_floating_point() { assert_eq!(extract_value_from_csv_field::("1"), Ok(1.0)) } #[test] fn cant_deserialize_f64_from_string() { assert_eq!( extract_value_from_csv_field::("Test"), Err(ReadingError::InvalidField { msg: String::from("Value 'Test' could not be read.") },) ) } #[test] fn deserialize_f32_from_non_floating_point() { assert_eq!(extract_value_from_csv_field::("12.0"), Ok(12.0)) } } mod extract_vector_from_csv_line { use super::super::{extract_vector_from_csv_line, CSVRowFormat, ReadingError}; #[test] fn extract_five_floating_point_values() { assert_eq!( extract_vector_from_csv_line::>( "-1.0,2.0,100.0,12", &CSVRowFormat { field_seperator: ",", n_fields: 4 } ), Ok(vec![-1.0, 2.0, 100.0, 12.0]) ) } #[test] fn cannot_extract_second_value() { assert_eq!( extract_vector_from_csv_line::>( "-1.0,test,100.0,12", &CSVRowFormat { field_seperator: ",", n_fields: 4 } ), Err(ReadingError::InvalidField { msg: String::from("Value 'test' could not be read. Column: 1") }) ) } } }