Implement CSV reader with new traits (#209)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
### I'm submitting a
|
||||
- [ ] bug report.
|
||||
- [ ] improvement.
|
||||
- [ ] feature request.
|
||||
|
||||
### Current Behaviour:
|
||||
|
||||
+2
-2
@@ -105,8 +105,8 @@ pub mod neighbors;
|
||||
pub mod optimization;
|
||||
/// Preprocessing utilities
|
||||
pub mod preprocessing;
|
||||
// /// Reading in Data.
|
||||
// pub mod readers;
|
||||
/// Reading in data from serialized foramts
|
||||
pub mod readers;
|
||||
/// Support Vector Machines
|
||||
pub mod svm;
|
||||
/// Supervised tree-based learning methods
|
||||
|
||||
+39
-32
@@ -1,23 +1,24 @@
|
||||
//! This module contains utitilities to read-in matrices from csv files.
|
||||
//! ```
|
||||
//! ```rust
|
||||
//! use smartcore::readers::csv;
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use crate::smartcore::linalg::BaseMatrix;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use std::fs;
|
||||
//!
|
||||
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||
//! assert_eq!(
|
||||
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
//!
|
||||
//! let mtx = csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
//! fs::File::open("identity.csv").unwrap(),
|
||||
//! csv::CSVDefinition::default()
|
||||
//! )
|
||||
//! .unwrap(),
|
||||
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||
//! );
|
||||
//! .unwrap();
|
||||
//! println!("{}", &mtx);
|
||||
//!
|
||||
//! fs::remove_file("identity.csv");
|
||||
//! ```
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::readers::ReadingError;
|
||||
use std::io::Read;
|
||||
|
||||
@@ -77,35 +78,41 @@ pub fn matrix_from_csv_source<T, RowVector, Matrix>(
|
||||
definition: CSVDefinition<'_>,
|
||||
) -> Result<Matrix, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||
T: Number + RealNumber + std::str::FromStr,
|
||||
RowVector: Array1<T>,
|
||||
Matrix: Array2<T>,
|
||||
{
|
||||
let csv_text = read_string_from_source(source)?;
|
||||
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
||||
let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
||||
&csv_text,
|
||||
&definition,
|
||||
detect_row_format(&csv_text, &definition)?,
|
||||
)?;
|
||||
let nrows = rows.len();
|
||||
let ncols = rows[0].len();
|
||||
|
||||
match Matrix::from_row_vectors(rows) {
|
||||
Some(matrix) => Ok(matrix),
|
||||
None => Err(ReadingError::NoRowsProvided),
|
||||
// TODO: try to return ReadingError
|
||||
let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0);
|
||||
|
||||
if array2.shape() != (nrows, ncols) {
|
||||
Err(ReadingError::ShapesDoNotMatch { msg: String::new() })
|
||||
} else {
|
||||
Ok(array2)
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a string containing the contents of a csv file, extract its value
|
||||
/// into row-vectors.
|
||||
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
|
||||
fn extract_row_vectors_from_csv_text<
|
||||
'a,
|
||||
T: Number + RealNumber + std::str::FromStr,
|
||||
RowVector: Array1<T>,
|
||||
Matrix: Array2<T>,
|
||||
>(
|
||||
csv_text: &'a str,
|
||||
definition: &'a CSVDefinition<'_>,
|
||||
row_format: CSVRowFormat<'_>,
|
||||
) -> Result<Vec<RowVector>, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||
{
|
||||
) -> Result<Vec<Vec<T>>, ReadingError> {
|
||||
csv_text
|
||||
.lines()
|
||||
.skip(definition.n_rows_header)
|
||||
@@ -132,12 +139,12 @@ fn extract_vector_from_csv_line<T, RowVector>(
|
||||
row_format: &CSVRowFormat<'_>,
|
||||
) -> Result<RowVector, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
T: Number + RealNumber + std::str::FromStr,
|
||||
RowVector: Array1<T>,
|
||||
{
|
||||
validate_csv_row(line, row_format)?;
|
||||
let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
|
||||
Ok(BaseVector::from_array(&extracted_fields[..]))
|
||||
let extracted_fields: Vec<T> = extract_fields_from_csv_row(line, row_format)?;
|
||||
Ok(Array1::from_vec_slice(&extracted_fields[..]))
|
||||
}
|
||||
|
||||
/// Extract the fields from a string containing the row of a csv file.
|
||||
@@ -146,7 +153,7 @@ fn extract_fields_from_csv_row<T>(
|
||||
row_format: &CSVRowFormat<'_>,
|
||||
) -> Result<Vec<T>, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
T: Number + RealNumber + std::str::FromStr,
|
||||
{
|
||||
row.split(row_format.field_seperator)
|
||||
.enumerate()
|
||||
@@ -192,7 +199,7 @@ fn enrich_reading_error<T>(
|
||||
/// Extract the value from a single csv field.
|
||||
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
T: Number + RealNumber + std::str::FromStr,
|
||||
{
|
||||
// By default, `FromStr::Err` does not implement `Debug`.
|
||||
// Restricting it in the library leads to many breaking
|
||||
@@ -210,7 +217,7 @@ where
|
||||
mod tests {
|
||||
mod matrix_from_csv_source {
|
||||
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::readers::{csv::matrix_from_csv_source, io_testing};
|
||||
|
||||
#[test]
|
||||
@@ -298,7 +305,7 @@ mod tests {
|
||||
}
|
||||
mod extract_row_vectors_from_csv_text {
|
||||
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn read_default_csv() {
|
||||
|
||||
@@ -24,6 +24,12 @@ pub enum ReadingError {
|
||||
/// and where it happened.
|
||||
msg: String,
|
||||
},
|
||||
/// Shape after deserialization is wrong
|
||||
ShapesDoNotMatch {
|
||||
/// More details about what row could not be read
|
||||
/// and where it happened.
|
||||
msg: String,
|
||||
},
|
||||
}
|
||||
impl From<std::io::Error> for ReadingError {
|
||||
fn from(io_error: std::io::Error) -> Self {
|
||||
@@ -39,6 +45,7 @@ impl ReadingError {
|
||||
ReadingError::InvalidField { msg } => Some(msg),
|
||||
ReadingError::InvalidRow { msg } => Some(msg),
|
||||
ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
|
||||
ReadingError::ShapesDoNotMatch { msg } => Some(msg),
|
||||
ReadingError::NoRowsProvided => None,
|
||||
}
|
||||
}
|
||||
|
||||
+2
-1
@@ -23,9 +23,10 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
/// search parameters
|
||||
pub mod search;
|
||||
pub mod svc;
|
||||
pub mod svr;
|
||||
// /// search parameters space
|
||||
// pub mod search;
|
||||
|
||||
use core::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
Reference in New Issue
Block a user