Implement CSV reader with new traits (#209)
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
### I'm submitting a
|
### I'm submitting a
|
||||||
- [ ] bug report.
|
- [ ] bug report.
|
||||||
|
- [ ] improvement.
|
||||||
- [ ] feature request.
|
- [ ] feature request.
|
||||||
|
|
||||||
### Current Behaviour:
|
### Current Behaviour:
|
||||||
|
|||||||
+2
-2
@@ -105,8 +105,8 @@ pub mod neighbors;
|
|||||||
pub mod optimization;
|
pub mod optimization;
|
||||||
/// Preprocessing utilities
|
/// Preprocessing utilities
|
||||||
pub mod preprocessing;
|
pub mod preprocessing;
|
||||||
// /// Reading in Data.
|
/// Reading in data from serialized foramts
|
||||||
// pub mod readers;
|
pub mod readers;
|
||||||
/// Support Vector Machines
|
/// Support Vector Machines
|
||||||
pub mod svm;
|
pub mod svm;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
|
|||||||
+39
-32
@@ -1,23 +1,24 @@
|
|||||||
//! This module contains utitilities to read-in matrices from csv files.
|
//! This module contains utitilities to read-in matrices from csv files.
|
||||||
//! ```
|
//! ```rust
|
||||||
//! use smartcore::readers::csv;
|
//! use smartcore::readers::csv;
|
||||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
//! use crate::smartcore::linalg::BaseMatrix;
|
|
||||||
//! use std::fs;
|
//! use std::fs;
|
||||||
//!
|
//!
|
||||||
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||||
//! assert_eq!(
|
//!
|
||||||
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
//! let mtx = csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||||
//! fs::File::open("identity.csv").unwrap(),
|
//! fs::File::open("identity.csv").unwrap(),
|
||||||
//! csv::CSVDefinition::default()
|
//! csv::CSVDefinition::default()
|
||||||
//! )
|
//! )
|
||||||
//! .unwrap(),
|
//! .unwrap();
|
||||||
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
//! println!("{}", &mtx);
|
||||||
//! );
|
//!
|
||||||
//! fs::remove_file("identity.csv");
|
//! fs::remove_file("identity.csv");
|
||||||
//! ```
|
//! ```
|
||||||
use crate::linalg::{BaseMatrix, BaseVector};
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::numbers::realnum::RealNumber;
|
||||||
use crate::readers::ReadingError;
|
use crate::readers::ReadingError;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
|
|
||||||
@@ -77,35 +78,41 @@ pub fn matrix_from_csv_source<T, RowVector, Matrix>(
|
|||||||
definition: CSVDefinition<'_>,
|
definition: CSVDefinition<'_>,
|
||||||
) -> Result<Matrix, ReadingError>
|
) -> Result<Matrix, ReadingError>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: Number + RealNumber + std::str::FromStr,
|
||||||
RowVector: BaseVector<T>,
|
RowVector: Array1<T>,
|
||||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
Matrix: Array2<T>,
|
||||||
{
|
{
|
||||||
let csv_text = read_string_from_source(source)?;
|
let csv_text = read_string_from_source(source)?;
|
||||||
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
||||||
&csv_text,
|
&csv_text,
|
||||||
&definition,
|
&definition,
|
||||||
detect_row_format(&csv_text, &definition)?,
|
detect_row_format(&csv_text, &definition)?,
|
||||||
)?;
|
)?;
|
||||||
|
let nrows = rows.len();
|
||||||
|
let ncols = rows[0].len();
|
||||||
|
|
||||||
match Matrix::from_row_vectors(rows) {
|
// TODO: try to return ReadingError
|
||||||
Some(matrix) => Ok(matrix),
|
let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0);
|
||||||
None => Err(ReadingError::NoRowsProvided),
|
|
||||||
|
if array2.shape() != (nrows, ncols) {
|
||||||
|
Err(ReadingError::ShapesDoNotMatch { msg: String::new() })
|
||||||
|
} else {
|
||||||
|
Ok(array2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a string containing the contents of a csv file, extract its value
|
/// Given a string containing the contents of a csv file, extract its value
|
||||||
/// into row-vectors.
|
/// into row-vectors.
|
||||||
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
|
fn extract_row_vectors_from_csv_text<
|
||||||
|
'a,
|
||||||
|
T: Number + RealNumber + std::str::FromStr,
|
||||||
|
RowVector: Array1<T>,
|
||||||
|
Matrix: Array2<T>,
|
||||||
|
>(
|
||||||
csv_text: &'a str,
|
csv_text: &'a str,
|
||||||
definition: &'a CSVDefinition<'_>,
|
definition: &'a CSVDefinition<'_>,
|
||||||
row_format: CSVRowFormat<'_>,
|
row_format: CSVRowFormat<'_>,
|
||||||
) -> Result<Vec<RowVector>, ReadingError>
|
) -> Result<Vec<Vec<T>>, ReadingError> {
|
||||||
where
|
|
||||||
T: RealNumber,
|
|
||||||
RowVector: BaseVector<T>,
|
|
||||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
|
||||||
{
|
|
||||||
csv_text
|
csv_text
|
||||||
.lines()
|
.lines()
|
||||||
.skip(definition.n_rows_header)
|
.skip(definition.n_rows_header)
|
||||||
@@ -132,12 +139,12 @@ fn extract_vector_from_csv_line<T, RowVector>(
|
|||||||
row_format: &CSVRowFormat<'_>,
|
row_format: &CSVRowFormat<'_>,
|
||||||
) -> Result<RowVector, ReadingError>
|
) -> Result<RowVector, ReadingError>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: Number + RealNumber + std::str::FromStr,
|
||||||
RowVector: BaseVector<T>,
|
RowVector: Array1<T>,
|
||||||
{
|
{
|
||||||
validate_csv_row(line, row_format)?;
|
validate_csv_row(line, row_format)?;
|
||||||
let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
|
let extracted_fields: Vec<T> = extract_fields_from_csv_row(line, row_format)?;
|
||||||
Ok(BaseVector::from_array(&extracted_fields[..]))
|
Ok(Array1::from_vec_slice(&extracted_fields[..]))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract the fields from a string containing the row of a csv file.
|
/// Extract the fields from a string containing the row of a csv file.
|
||||||
@@ -146,7 +153,7 @@ fn extract_fields_from_csv_row<T>(
|
|||||||
row_format: &CSVRowFormat<'_>,
|
row_format: &CSVRowFormat<'_>,
|
||||||
) -> Result<Vec<T>, ReadingError>
|
) -> Result<Vec<T>, ReadingError>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: Number + RealNumber + std::str::FromStr,
|
||||||
{
|
{
|
||||||
row.split(row_format.field_seperator)
|
row.split(row_format.field_seperator)
|
||||||
.enumerate()
|
.enumerate()
|
||||||
@@ -192,7 +199,7 @@ fn enrich_reading_error<T>(
|
|||||||
/// Extract the value from a single csv field.
|
/// Extract the value from a single csv field.
|
||||||
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
|
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
|
||||||
where
|
where
|
||||||
T: RealNumber,
|
T: Number + RealNumber + std::str::FromStr,
|
||||||
{
|
{
|
||||||
// By default, `FromStr::Err` does not implement `Debug`.
|
// By default, `FromStr::Err` does not implement `Debug`.
|
||||||
// Restricting it in the library leads to many breaking
|
// Restricting it in the library leads to many breaking
|
||||||
@@ -210,7 +217,7 @@ where
|
|||||||
mod tests {
|
mod tests {
|
||||||
mod matrix_from_csv_source {
|
mod matrix_from_csv_source {
|
||||||
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
|
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::readers::{csv::matrix_from_csv_source, io_testing};
|
use crate::readers::{csv::matrix_from_csv_source, io_testing};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -298,7 +305,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
mod extract_row_vectors_from_csv_text {
|
mod extract_row_vectors_from_csv_text {
|
||||||
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
|
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn read_default_csv() {
|
fn read_default_csv() {
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ pub enum ReadingError {
|
|||||||
/// and where it happened.
|
/// and where it happened.
|
||||||
msg: String,
|
msg: String,
|
||||||
},
|
},
|
||||||
|
/// Shape after deserialization is wrong
|
||||||
|
ShapesDoNotMatch {
|
||||||
|
/// More details about what row could not be read
|
||||||
|
/// and where it happened.
|
||||||
|
msg: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
impl From<std::io::Error> for ReadingError {
|
impl From<std::io::Error> for ReadingError {
|
||||||
fn from(io_error: std::io::Error) -> Self {
|
fn from(io_error: std::io::Error) -> Self {
|
||||||
@@ -39,6 +45,7 @@ impl ReadingError {
|
|||||||
ReadingError::InvalidField { msg } => Some(msg),
|
ReadingError::InvalidField { msg } => Some(msg),
|
||||||
ReadingError::InvalidRow { msg } => Some(msg),
|
ReadingError::InvalidRow { msg } => Some(msg),
|
||||||
ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
|
ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
|
||||||
|
ReadingError::ShapesDoNotMatch { msg } => Some(msg),
|
||||||
ReadingError::NoRowsProvided => None,
|
ReadingError::NoRowsProvided => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-1
@@ -23,9 +23,10 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
/// search parameters
|
/// search parameters
|
||||||
pub mod search;
|
|
||||||
pub mod svc;
|
pub mod svc;
|
||||||
pub mod svr;
|
pub mod svr;
|
||||||
|
// /// search parameters space
|
||||||
|
// pub mod search;
|
||||||
|
|
||||||
use core::fmt::Debug;
|
use core::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|||||||
Reference in New Issue
Block a user