From ef06f45638ec42540d74f41ffd2171f2d97e793f Mon Sep 17 00:00:00 2001 From: gaxler Date: Tue, 2 Feb 2021 18:21:06 -0800 Subject: [PATCH] Switch to use SeriesEncoder trait --- src/preprocessing/categorical_encoder.rs | 35 ++++++++++++++---------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/preprocessing/categorical_encoder.rs b/src/preprocessing/categorical_encoder.rs index e3e8ce9..75cbf2b 100644 --- a/src/preprocessing/categorical_encoder.rs +++ b/src/preprocessing/categorical_encoder.rs @@ -6,7 +6,7 @@ //! ### Usage Example //! ``` //! use smartcore::linalg::naive::dense_matrix::DenseMatrix; -//! use smartcore::preprocessing::categorical_encoder::{OneHotEncoder, OneHotEncoderParams}; +//! use smartcore::preprocessing::categorical_encoder::{OneHotEnc, OneHotEncoderParams}; //! let data = DenseMatrix::from_2d_array(&[ //! &[1.5, 1.0, 1.5, 3.0], //! &[1.5, 2.0, 1.5, 4.0], @@ -15,7 +15,7 @@ //! ]); //! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]); //! // Infer number of categories from data and return a reusable encoder -//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap(); +//! let encoder = OneHotEnc::fit(&data, encoder_params).unwrap(); //! // Transform categorical to one-hot encoded (can transform similar) //! let oh_data = encoder.transform(&data).unwrap(); //! // Produces the following: @@ -30,7 +30,7 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable}; -use crate::preprocessing::series_encoder::SeriesOneHotEncoder; +use crate::preprocessing::series_encoder::{SeriesOneHotEncoder, SeriesEncoder}; /// OneHotEncoder Parameters #[derive(Debug, Clone)] @@ -97,17 +97,17 @@ fn validate_col_is_categorical(data: &[T]) -> bool { /// Encode Categorical variavbles of data matrix to one-hot #[derive(Debug, Clone)] -pub struct OneHotEncoder { - series_encoders: Vec>, +pub struct OneHotEncoder { + series_encoders: Vec, col_idx_categorical: Vec, } -impl OneHotEncoder { +impl> OneHotEncoder { /// Create an encoder instance with categories infered from data matrix pub fn fit>( data: &M, params: OneHotEncoderParams, - ) -> Result { + ) -> Result, Failed> { match (params.col_idx_categorical, params.infer_categorical) { (None, false) => Err(Failed::fit( "Must pass categorical series ids or infer flag", @@ -126,7 +126,7 @@ impl OneHotEncoder { // col buffer to avoid allocations let mut col_buf: Vec = iter::repeat(T::zero()).take(nrows).collect(); - let mut res: Vec> = + let mut res: Vec = Vec::with_capacity(idxs.len()); for &idx in &idxs { @@ -139,7 +139,7 @@ impl OneHotEncoder { return Err(Failed::fit(&msg[..])); } let hashable_col = col_buf.iter().map(|v| v.to_category()); - res.push(SeriesOneHotEncoder::fit_to_iter(hashable_col)); + res.push(E::fit_to_iter(hashable_col)); } Ok(Self { @@ -160,7 +160,7 @@ impl OneHotEncoder { let additional_params: Vec = self .series_encoders .iter() - .map(|enc| enc.num_categories) + .map(|enc| enc.num_categories()) .collect(); // Eac category of size v adds v-1 params @@ -215,12 +215,17 @@ impl OneHotEncoder { } } +/// Convinince type for common use +pub type OneHotEnc = OneHotEncoder>; + + #[cfg(test)] mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::preprocessing::series_encoder::SeriesOneHotEncoder; + #[test] fn adjust_idxs() { assert_eq!(find_new_idxs(0, &[], &[]), Vec::::new()); @@ -279,13 +284,13 @@ mod tests { fn test_fit() { let (x, _) = build_fake_matrix(); let params = OneHotEncoderParams::from_cat_idx(&[1, 3]); - let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + let oh_enc = OneHotEnc::fit(&x, params).unwrap(); assert_eq!(oh_enc.series_encoders.len(), 2); let num_cat: Vec = oh_enc .series_encoders .iter() - .map(|a| a.num_categories) + .map(|a| a.num_categories()) .collect(); assert_eq!(num_cat, vec![2, 4]); } @@ -294,13 +299,13 @@ mod tests { fn matrix_transform_test() { let (x, expected_x) = build_fake_matrix(); let params = OneHotEncoderParams::from_cat_idx(&[1, 3]); - let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + let oh_enc = OneHotEnc::fit(&x, params).unwrap(); let nm = oh_enc.transform(&x).unwrap(); assert_eq!(nm, expected_x); let (x, expected_x) = build_cat_first_and_last(); let params = OneHotEncoderParams::from_cat_idx(&[0, 2]); - let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + let oh_enc = OneHotEnc::fit(&x, params).unwrap(); let nm = oh_enc.transform(&x).unwrap(); assert_eq!(nm, expected_x); } @@ -315,7 +320,7 @@ mod tests { ]); let params = OneHotEncoderParams::from_cat_idx(&[1]); - match OneHotEncoder::fit(&m, params) { + match OneHotEnc::fit(&m, params) { Err(_) => { assert!(true); }