From 2f03c1d6d74834d5bad990a5fd9c7cd7962fa351 Mon Sep 17 00:00:00 2001 From: gaxler Date: Sat, 30 Jan 2021 19:54:42 -0800 Subject: [PATCH] module name change --- ...cal_encoders.rs => categorical_encoder.rs} | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) rename src/preprocessing/{categorical_encoders.rs => categorical_encoder.rs} (89%) diff --git a/src/preprocessing/categorical_encoders.rs b/src/preprocessing/categorical_encoder.rs similarity index 89% rename from src/preprocessing/categorical_encoders.rs rename to src/preprocessing/categorical_encoder.rs index 063aa5c..22cd052 100644 --- a/src/preprocessing/categorical_encoders.rs +++ b/src/preprocessing/categorical_encoder.rs @@ -38,7 +38,7 @@ pub struct OneHotEncoderParams { /// Column number that contain categorical variable pub col_idx_categorical: Option>, /// (Currently not implemented) Try and infer which of the matrix columns are categorical variables - pub infer_categorical: bool, + infer_categorical: bool, } impl OneHotEncoderParams { @@ -86,14 +86,17 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize]) new_param_idxs } -fn validate_col_is_categorical(data: &Vec) -> bool { +fn validate_col_is_categorical(data: &[T]) -> bool { for v in data { - if !v.is_valid() { return false} + if !v.is_valid() { + return false; + } } true } /// Encode Categorical variavbles of data matrix to one-hot +#[derive(Debug, Clone)] pub struct OneHotEncoder { series_encoders: Vec>, col_idx_categorical: Vec, @@ -102,7 +105,7 @@ pub struct OneHotEncoder { impl OneHotEncoder { /// PlaceHolder - pub fn fit>( + pub fn fit>( data: &M, params: OneHotEncoderParams, ) -> Result { @@ -117,20 +120,24 @@ impl OneHotEncoder { (Some(mut idxs), false) => { // make sure categories have same order as data columns - idxs.sort(); + idxs.sort_unstable(); let (nrows, _) = data.shape(); // col buffer to avoid allocations let mut col_buf: Vec = iter::repeat(T::zero()).take(nrows).collect(); - - let mut res: Vec> = Vec::with_capacity(idxs.len()); - + + let mut res: Vec> = + Vec::with_capacity(idxs.len()); + for &idx in &idxs { data.copy_col_as_vec(idx, &mut col_buf); if !validate_col_is_categorical(&col_buf) { - let msg = format!("Column {} of data matrix containts non categorizable (integer) values", idx); - return Err(Failed::fit(&msg[..])) + let msg = format!( + "Column {} of data matrix containts non categorizable (integer) values", + idx + ); + return Err(Failed::fit(&msg[..])); } let hashable_col = col_buf.iter().map(|v| v.to_category()); res.push(SeriesOneHotEncoder::fit_to_iter(hashable_col)); @@ -149,7 +156,7 @@ impl OneHotEncoder { } /// Transform categorical variables to one-hot encoded and return a new matrix - pub fn transform>(&self, x: &M) -> Option { + pub fn transform>(&self, x: &M) -> Option { let (nrows, p) = x.shape(); let additional_params: Vec = self .series_encoders @@ -201,7 +208,7 @@ mod tests { #[test] fn adjust_idxs() { - assert_eq!(find_new_idxs(0, &[], &[]), Vec::new()); + assert_eq!(find_new_idxs(0, &[], &[]), Vec::::new()); // [0,1,2] -> [0, 1, 1, 1, 2] assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]); } @@ -282,4 +289,22 @@ mod tests { let nm = oh_enc.transform(&X).unwrap(); assert_eq!(nm, expectedX); } + + #[test] + fn fail_on_bad_category() { + let m = DenseMatrix::from_2d_array(&[ + &[1.0, 1.5, 3.0], + &[2.0, 1.5, 4.0], + &[1.0, 1.5, 5.0], + &[2.0, 1.5, 6.0], + ]); + + let params = OneHotEncoderParams::from_cat_idx(&[1]); + match OneHotEncoder::fit(&m, params) { + Err(_) => { + assert!(true); + } + _ => assert!(false), + } + } }