Implement CSV reader with new traits (#209)

This commit is contained in:
Lorenzo
2022-11-03 15:49:00 +00:00
committed by morenol
parent b427e5d8b1
commit ed9769f651
5 changed files with 54 additions and 38 deletions
+1
View File
@@ -1,5 +1,6 @@
### I'm submitting a ### I'm submitting a
- [ ] bug report. - [ ] bug report.
- [ ] improvement.
- [ ] feature request. - [ ] feature request.
### Current Behaviour: ### Current Behaviour:
+2 -2
View File
@@ -105,8 +105,8 @@ pub mod neighbors;
pub mod optimization; pub mod optimization;
/// Preprocessing utilities /// Preprocessing utilities
pub mod preprocessing; pub mod preprocessing;
// /// Reading in Data. /// Reading in data from serialized formats
// pub mod readers; pub mod readers;
/// Support Vector Machines /// Support Vector Machines
pub mod svm; pub mod svm;
/// Supervised tree-based learning methods /// Supervised tree-based learning methods
+39 -32
View File
@@ -1,23 +1,24 @@
//! This module contains utilities to read-in matrices from csv files. //! This module contains utilities to read-in matrices from csv files.
//! ``` //! ```rust
//! use smartcore::readers::csv; //! use smartcore::readers::csv;
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; //! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use crate::smartcore::linalg::BaseMatrix;
//! use std::fs; //! use std::fs;
//! //!
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0"); //! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
//! assert_eq!( //!
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>( //! let mtx = csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
//! fs::File::open("identity.csv").unwrap(), //! fs::File::open("identity.csv").unwrap(),
//! csv::CSVDefinition::default() //! csv::CSVDefinition::default()
//! ) //! )
//! .unwrap(), //! .unwrap();
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap() //! println!("{}", &mtx);
//! ); //!
//! fs::remove_file("identity.csv"); //! fs::remove_file("identity.csv");
//! ``` //! ```
use crate::linalg::{BaseMatrix, BaseVector};
use crate::math::num::RealNumber; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::readers::ReadingError; use crate::readers::ReadingError;
use std::io::Read; use std::io::Read;
@@ -77,35 +78,41 @@ pub fn matrix_from_csv_source<T, RowVector, Matrix>(
definition: CSVDefinition<'_>, definition: CSVDefinition<'_>,
) -> Result<Matrix, ReadingError> ) -> Result<Matrix, ReadingError>
where where
T: RealNumber, T: Number + RealNumber + std::str::FromStr,
RowVector: BaseVector<T>, RowVector: Array1<T>,
Matrix: BaseMatrix<T, RowVector = RowVector>, Matrix: Array2<T>,
{ {
let csv_text = read_string_from_source(source)?; let csv_text = read_string_from_source(source)?;
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>( let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
&csv_text, &csv_text,
&definition, &definition,
detect_row_format(&csv_text, &definition)?, detect_row_format(&csv_text, &definition)?,
)?; )?;
let nrows = rows.len();
let ncols = rows[0].len();
match Matrix::from_row_vectors(rows) { // TODO: try to return ReadingError
Some(matrix) => Ok(matrix), let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0);
None => Err(ReadingError::NoRowsProvided),
if array2.shape() != (nrows, ncols) {
Err(ReadingError::ShapesDoNotMatch { msg: String::new() })
} else {
Ok(array2)
} }
} }
/// Given a string containing the contents of a csv file, extract its value /// Given a string containing the contents of a csv file, extract its value
/// into row-vectors. /// into row-vectors.
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>( fn extract_row_vectors_from_csv_text<
'a,
T: Number + RealNumber + std::str::FromStr,
RowVector: Array1<T>,
Matrix: Array2<T>,
>(
csv_text: &'a str, csv_text: &'a str,
definition: &'a CSVDefinition<'_>, definition: &'a CSVDefinition<'_>,
row_format: CSVRowFormat<'_>, row_format: CSVRowFormat<'_>,
) -> Result<Vec<RowVector>, ReadingError> ) -> Result<Vec<Vec<T>>, ReadingError> {
where
T: RealNumber,
RowVector: BaseVector<T>,
Matrix: BaseMatrix<T, RowVector = RowVector>,
{
csv_text csv_text
.lines() .lines()
.skip(definition.n_rows_header) .skip(definition.n_rows_header)
@@ -132,12 +139,12 @@ fn extract_vector_from_csv_line<T, RowVector>(
row_format: &CSVRowFormat<'_>, row_format: &CSVRowFormat<'_>,
) -> Result<RowVector, ReadingError> ) -> Result<RowVector, ReadingError>
where where
T: RealNumber, T: Number + RealNumber + std::str::FromStr,
RowVector: BaseVector<T>, RowVector: Array1<T>,
{ {
validate_csv_row(line, row_format)?; validate_csv_row(line, row_format)?;
let extracted_fields = extract_fields_from_csv_row(line, row_format)?; let extracted_fields: Vec<T> = extract_fields_from_csv_row(line, row_format)?;
Ok(BaseVector::from_array(&extracted_fields[..])) Ok(Array1::from_vec_slice(&extracted_fields[..]))
} }
/// Extract the fields from a string containing the row of a csv file. /// Extract the fields from a string containing the row of a csv file.
@@ -146,7 +153,7 @@ fn extract_fields_from_csv_row<T>(
row_format: &CSVRowFormat<'_>, row_format: &CSVRowFormat<'_>,
) -> Result<Vec<T>, ReadingError> ) -> Result<Vec<T>, ReadingError>
where where
T: RealNumber, T: Number + RealNumber + std::str::FromStr,
{ {
row.split(row_format.field_seperator) row.split(row_format.field_seperator)
.enumerate() .enumerate()
@@ -192,7 +199,7 @@ fn enrich_reading_error<T>(
/// Extract the value from a single csv field. /// Extract the value from a single csv field.
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError> fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
where where
T: RealNumber, T: Number + RealNumber + std::str::FromStr,
{ {
// By default, `FromStr::Err` does not implement `Debug`. // By default, `FromStr::Err` does not implement `Debug`.
// Restricting it in the library leads to many breaking // Restricting it in the library leads to many breaking
@@ -210,7 +217,7 @@ where
mod tests { mod tests {
mod matrix_from_csv_source { mod matrix_from_csv_source {
use super::super::{read_string_from_source, CSVDefinition, ReadingError}; use super::super::{read_string_from_source, CSVDefinition, ReadingError};
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::basic::matrix::DenseMatrix;
use crate::readers::{csv::matrix_from_csv_source, io_testing}; use crate::readers::{csv::matrix_from_csv_source, io_testing};
#[test] #[test]
@@ -298,7 +305,7 @@ mod tests {
} }
mod extract_row_vectors_from_csv_text { mod extract_row_vectors_from_csv_text {
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat}; use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::basic::matrix::DenseMatrix;
#[test] #[test]
fn read_default_csv() { fn read_default_csv() {
+7
View File
@@ -24,6 +24,12 @@ pub enum ReadingError {
/// and where it happened. /// and where it happened.
msg: String, msg: String,
}, },
/// Shape after deserialization is wrong
ShapesDoNotMatch {
/// More details about how the shape after deserialization
/// differs from the expected one.
msg: String,
},
} }
impl From<std::io::Error> for ReadingError { impl From<std::io::Error> for ReadingError {
fn from(io_error: std::io::Error) -> Self { fn from(io_error: std::io::Error) -> Self {
@@ -39,6 +45,7 @@ impl ReadingError {
ReadingError::InvalidField { msg } => Some(msg), ReadingError::InvalidField { msg } => Some(msg),
ReadingError::InvalidRow { msg } => Some(msg), ReadingError::InvalidRow { msg } => Some(msg),
ReadingError::CouldNotReadFileSystem { msg } => Some(msg), ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
ReadingError::ShapesDoNotMatch { msg } => Some(msg),
ReadingError::NoRowsProvided => None, ReadingError::NoRowsProvided => None,
} }
} }
+2 -1
View File
@@ -23,9 +23,10 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// search parameters /// search parameters
pub mod search;
pub mod svc; pub mod svc;
pub mod svr; pub mod svr;
// /// search parameters space
// pub mod search;
use core::fmt::Debug; use core::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;