diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 4fee515..1177761 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,5 +1,6 @@ ### I'm submitting a - [ ] bug report. +- [ ] improvement. - [ ] feature request. ### Current Behaviour: diff --git a/src/lib.rs b/src/lib.rs index 11c5b38..a955de2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,8 +105,8 @@ pub mod neighbors; pub mod optimization; /// Preprocessing utilities pub mod preprocessing; -// /// Reading in Data. -// pub mod readers; +/// Reading in data from serialized foramts +pub mod readers; /// Support Vector Machines pub mod svm; /// Supervised tree-based learning methods diff --git a/src/readers/csv.rs b/src/readers/csv.rs index e80d99b..0b2c18c 100644 --- a/src/readers/csv.rs +++ b/src/readers/csv.rs @@ -1,23 +1,24 @@ //! This module contains utitilities to read-in matrices from csv files. -//! ``` +//! ```rust //! use smartcore::readers::csv; -//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; -//! use crate::smartcore::linalg::BaseMatrix; +//! use smartcore::linalg::basic::matrix::DenseMatrix; //! use std::fs; //! //! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0"); -//! assert_eq!( -//! csv::matrix_from_csv_source::, DenseMatrix<_>>( -//! fs::File::open("identity.csv").unwrap(), -//! csv::CSVDefinition::default() -//! ) -//! .unwrap(), -//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap() -//! ); +//! +//! let mtx = csv::matrix_from_csv_source::, DenseMatrix<_>>( +//! fs::File::open("identity.csv").unwrap(), +//! csv::CSVDefinition::default() +//! ) +//! .unwrap(); +//! println!("{}", &mtx); +//! //! fs::remove_file("identity.csv"); //! ``` -use crate::linalg::{BaseMatrix, BaseVector}; -use crate::math::num::RealNumber; + +use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::numbers::basenum::Number; +use crate::numbers::realnum::RealNumber; use crate::readers::ReadingError; use std::io::Read; @@ -77,35 +78,41 @@ pub fn matrix_from_csv_source( definition: CSVDefinition<'_>, ) -> Result where - T: RealNumber, - RowVector: BaseVector, - Matrix: BaseMatrix, + T: Number + RealNumber + std::str::FromStr, + RowVector: Array1, + Matrix: Array2, { let csv_text = read_string_from_source(source)?; - let rows = extract_row_vectors_from_csv_text::( + let rows: Vec> = extract_row_vectors_from_csv_text::( &csv_text, &definition, detect_row_format(&csv_text, &definition)?, )?; + let nrows = rows.len(); + let ncols = rows[0].len(); - match Matrix::from_row_vectors(rows) { - Some(matrix) => Ok(matrix), - None => Err(ReadingError::NoRowsProvided), + // TODO: try to return ReadingError + let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0); + + if array2.shape() != (nrows, ncols) { + Err(ReadingError::ShapesDoNotMatch { msg: String::new() }) + } else { + Ok(array2) } } /// Given a string containing the contents of a csv file, extract its value /// into row-vectors. -fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>( +fn extract_row_vectors_from_csv_text< + 'a, + T: Number + RealNumber + std::str::FromStr, + RowVector: Array1, + Matrix: Array2, +>( csv_text: &'a str, definition: &'a CSVDefinition<'_>, row_format: CSVRowFormat<'_>, -) -> Result, ReadingError> -where - T: RealNumber, - RowVector: BaseVector, - Matrix: BaseMatrix, -{ +) -> Result>, ReadingError> { csv_text .lines() .skip(definition.n_rows_header) @@ -132,12 +139,12 @@ fn extract_vector_from_csv_line( row_format: &CSVRowFormat<'_>, ) -> Result where - T: RealNumber, - RowVector: BaseVector, + T: Number + RealNumber + std::str::FromStr, + RowVector: Array1, { validate_csv_row(line, row_format)?; - let extracted_fields = extract_fields_from_csv_row(line, row_format)?; - Ok(BaseVector::from_array(&extracted_fields[..])) + let extracted_fields: Vec = extract_fields_from_csv_row(line, row_format)?; + Ok(Array1::from_vec_slice(&extracted_fields[..])) } /// Extract the fields from a string containing the row of a csv file. @@ -146,7 +153,7 @@ fn extract_fields_from_csv_row( row_format: &CSVRowFormat<'_>, ) -> Result, ReadingError> where - T: RealNumber, + T: Number + RealNumber + std::str::FromStr, { row.split(row_format.field_seperator) .enumerate() @@ -192,7 +199,7 @@ fn enrich_reading_error( /// Extract the value from a single csv field. fn extract_value_from_csv_field(value_string: &str) -> Result where - T: RealNumber, + T: Number + RealNumber + std::str::FromStr, { // By default, `FromStr::Err` does not implement `Debug`. // Restricting it in the library leads to many breaking @@ -210,7 +217,7 @@ where mod tests { mod matrix_from_csv_source { use super::super::{read_string_from_source, CSVDefinition, ReadingError}; - use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::linalg::basic::matrix::DenseMatrix; use crate::readers::{csv::matrix_from_csv_source, io_testing}; #[test] @@ -298,7 +305,7 @@ mod tests { } mod extract_row_vectors_from_csv_text { use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat}; - use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::linalg::basic::matrix::DenseMatrix; #[test] fn read_default_csv() { diff --git a/src/readers/error.rs b/src/readers/error.rs index 16e910d..047092a 100644 --- a/src/readers/error.rs +++ b/src/readers/error.rs @@ -24,6 +24,12 @@ pub enum ReadingError { /// and where it happened. msg: String, }, + /// Shape after deserialization is wrong + ShapesDoNotMatch { + /// More details about what row could not be read + /// and where it happened. + msg: String, + }, } impl From for ReadingError { fn from(io_error: std::io::Error) -> Self { @@ -39,6 +45,7 @@ impl ReadingError { ReadingError::InvalidField { msg } => Some(msg), ReadingError::InvalidRow { msg } => Some(msg), ReadingError::CouldNotReadFileSystem { msg } => Some(msg), + ReadingError::ShapesDoNotMatch { msg } => Some(msg), ReadingError::NoRowsProvided => None, } } diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 46898c9..febfead 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -23,9 +23,10 @@ //! //! /// search parameters -pub mod search; pub mod svc; pub mod svr; +// /// search parameters space +// pub mod search; use core::fmt::Debug; use std::marker::PhantomData;