Implement CSV reader with new traits (#209)

This commit is contained in:
Lorenzo
2022-11-03 15:49:00 +00:00
committed by morenol
parent b427e5d8b1
commit ed9769f651
5 changed files with 54 additions and 38 deletions
+2 -2
View File
@@ -105,8 +105,8 @@ pub mod neighbors;
pub mod optimization;
/// Preprocessing utilities
pub mod preprocessing;
// /// Reading in Data.
// pub mod readers;
/// Reading in data from serialized foramts
pub mod readers;
/// Support Vector Machines
pub mod svm;
/// Supervised tree-based learning methods
+42 -35
View File
@@ -1,23 +1,24 @@
//! This module contains utitilities to read-in matrices from csv files.
//! ```
//! ```rust
//! use smartcore::readers::csv;
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
//! use crate::smartcore::linalg::BaseMatrix;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use std::fs;
//!
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
//! assert_eq!(
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
//! fs::File::open("identity.csv").unwrap(),
//! csv::CSVDefinition::default()
//! )
//! .unwrap(),
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
//! );
//!
//! let mtx = csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
//! fs::File::open("identity.csv").unwrap(),
//! csv::CSVDefinition::default()
//! )
//! .unwrap();
//! println!("{}", &mtx);
//!
//! fs::remove_file("identity.csv");
//! ```
use crate::linalg::{BaseMatrix, BaseVector};
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::readers::ReadingError;
use std::io::Read;
@@ -77,35 +78,41 @@ pub fn matrix_from_csv_source<T, RowVector, Matrix>(
definition: CSVDefinition<'_>,
) -> Result<Matrix, ReadingError>
where
T: RealNumber,
RowVector: BaseVector<T>,
Matrix: BaseMatrix<T, RowVector = RowVector>,
T: Number + RealNumber + std::str::FromStr,
RowVector: Array1<T>,
Matrix: Array2<T>,
{
let csv_text = read_string_from_source(source)?;
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
&csv_text,
&definition,
detect_row_format(&csv_text, &definition)?,
)?;
let nrows = rows.len();
let ncols = rows[0].len();
match Matrix::from_row_vectors(rows) {
Some(matrix) => Ok(matrix),
None => Err(ReadingError::NoRowsProvided),
// TODO: try to return ReadingError
let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0);
if array2.shape() != (nrows, ncols) {
Err(ReadingError::ShapesDoNotMatch { msg: String::new() })
} else {
Ok(array2)
}
}
/// Given a string containing the contents of a csv file, extract its value
/// into row-vectors.
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
fn extract_row_vectors_from_csv_text<
'a,
T: Number + RealNumber + std::str::FromStr,
RowVector: Array1<T>,
Matrix: Array2<T>,
>(
csv_text: &'a str,
definition: &'a CSVDefinition<'_>,
row_format: CSVRowFormat<'_>,
) -> Result<Vec<RowVector>, ReadingError>
where
T: RealNumber,
RowVector: BaseVector<T>,
Matrix: BaseMatrix<T, RowVector = RowVector>,
{
) -> Result<Vec<Vec<T>>, ReadingError> {
csv_text
.lines()
.skip(definition.n_rows_header)
@@ -132,12 +139,12 @@ fn extract_vector_from_csv_line<T, RowVector>(
row_format: &CSVRowFormat<'_>,
) -> Result<RowVector, ReadingError>
where
T: RealNumber,
RowVector: BaseVector<T>,
T: Number + RealNumber + std::str::FromStr,
RowVector: Array1<T>,
{
validate_csv_row(line, row_format)?;
let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
Ok(BaseVector::from_array(&extracted_fields[..]))
let extracted_fields: Vec<T> = extract_fields_from_csv_row(line, row_format)?;
Ok(Array1::from_vec_slice(&extracted_fields[..]))
}
/// Extract the fields from a string containing the row of a csv file.
@@ -146,7 +153,7 @@ fn extract_fields_from_csv_row<T>(
row_format: &CSVRowFormat<'_>,
) -> Result<Vec<T>, ReadingError>
where
T: RealNumber,
T: Number + RealNumber + std::str::FromStr,
{
row.split(row_format.field_seperator)
.enumerate()
@@ -192,7 +199,7 @@ fn enrich_reading_error<T>(
/// Extract the value from a single csv field.
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
where
T: RealNumber,
T: Number + RealNumber + std::str::FromStr,
{
// By default, `FromStr::Err` does not implement `Debug`.
// Restricting it in the library leads to many breaking
@@ -210,7 +217,7 @@ where
mod tests {
mod matrix_from_csv_source {
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::readers::{csv::matrix_from_csv_source, io_testing};
#[test]
@@ -298,7 +305,7 @@ mod tests {
}
mod extract_row_vectors_from_csv_text {
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
#[test]
fn read_default_csv() {
+7
View File
@@ -24,6 +24,12 @@ pub enum ReadingError {
/// and where it happened.
msg: String,
},
/// Shape after deserialization is wrong
ShapesDoNotMatch {
/// More details about what row could not be read
/// and where it happened.
msg: String,
},
}
impl From<std::io::Error> for ReadingError {
fn from(io_error: std::io::Error) -> Self {
@@ -39,6 +45,7 @@ impl ReadingError {
ReadingError::InvalidField { msg } => Some(msg),
ReadingError::InvalidRow { msg } => Some(msg),
ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
ReadingError::ShapesDoNotMatch { msg } => Some(msg),
ReadingError::NoRowsProvided => None,
}
}
+2 -1
View File
@@ -23,9 +23,10 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// search parameters
pub mod search;
pub mod svc;
pub mod svr;
// /// search parameters space
// pub mod search;
use core::fmt::Debug;
use std::marker::PhantomData;