From cd5611079caae782f148397a0ebad465aea6faef Mon Sep 17 00:00:00 2001 From: gaxler Date: Sat, 30 Jan 2021 19:29:33 -0800 Subject: [PATCH] Fit OneHotEncoder --- src/preprocessing/categorical_encoders.rs | 56 ++++++++++++++++++----- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/src/preprocessing/categorical_encoders.rs b/src/preprocessing/categorical_encoders.rs index 31d3500..794c1d6 100644 --- a/src/preprocessing/categorical_encoders.rs +++ b/src/preprocessing/categorical_encoders.rs @@ -75,32 +75,66 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize]) .collect(); new_param_idxs } +fn validate_col_is_categorical(data: &Vec) -> bool { + for v in data { + if !v.is_valid() { return false} + } + true +} /// Encode Categorical variavbles of data matrix to one-hot pub struct OneHotEncoder { - series_encoders: Vec>, - categorical_param_idxs: Vec, + series_encoders: Vec>, + col_idx_categorical: Vec, } -impl> OneHotEncoder { +impl OneHotEncoder { /// PlaceHolder - pub fn fit(data: &M, params: OneHotEncoderParams) -> Result { - match (params.categorical_param_idxs, params.infer_categorical) { + pub fn fit>( + data: &M, + params: OneHotEncoderParams, + ) -> Result { + match (params.col_idx_categorical, params.infer_categorical) { (None, false) => Err(Failed::fit( "Must pass categorical series ids or infer flag", )), - (Some(idxs), true) => Err(Failed::fit( + (Some(_idxs), true) => Err(Failed::fit( "Ambigous parameters, got both infer and categroy ids", )), - (Some(idxs), false) => Ok(Self { - series_encoders: Self::build_series_encoders::(data, &idxs[..]), - categorical_param_idxs: idxs, - }), + (Some(mut idxs), false) => { + // make sure categories have same order as data columns + idxs.sort(); + + let (nrows, _) = data.shape(); + + // col buffer to avoid allocations + let mut col_buf: Vec = iter::repeat(T::zero()).take(nrows).collect(); + + let mut res: Vec> = Vec::with_capacity(idxs.len()); + + for &idx in &idxs { + data.copy_col_as_vec(idx, &mut col_buf); + if !validate_col_is_categorical(&col_buf) { + let msg = format!("Column {} of data matrix containts non categorizable (integer) values", idx); + return Err(Failed::fit(&msg[..])) + } + let hashable_col = col_buf.iter().map(|v| v.to_category()); + res.push(SeriesOneHotEncoder::fit_to_iter(hashable_col)); + } + + Ok(Self { + series_encoders: res, //Self::build_series_encoders::(data, &idxs[..]), + col_idx_categorical: idxs, + }) + } (None, true) => { - todo!("implement categorical auto-inference") + todo!("Auto-Inference for Categorical Variables not yet implemented") + } + } + } } } }