From 3dc8a4283298d6622a6a0c74cd008339d6b8e9c4 Mon Sep 17 00:00:00 2001 From: gaxler Date: Sat, 30 Jan 2021 16:05:45 -0800 Subject: [PATCH] Adapt column numbers to the new columns introduced by categorical variables. --- src/preprocessing/categorical_encoders.rs | 34 +++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/preprocessing/categorical_encoders.rs b/src/preprocessing/categorical_encoders.rs index 0436787..31d3500 100644 --- a/src/preprocessing/categorical_encoders.rs +++ b/src/preprocessing/categorical_encoders.rs @@ -41,6 +41,40 @@ pub struct OneHotEncoderParams { pub categorical_param_idxs: Option>, pub infer_categorical: bool, } +/// Calculate the offset to parameters to due introduction of one-hot encoding +fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize]) -> Vec { + // This functions uses iterators and returns a vector. + // In case we get a huge amount of paramenters this might be a problem + // todo: Change this such that it will return an iterator + + let cat_idx = encoded_idxs.iter().copied().chain((num_params..).take(1)); + + // Offset is constant between two categorical values, here we calculate the number of steps + // that remain constant + let repeats = cat_idx.scan(0, |a, v| { + let im = v + 1 - *a; + *a = v; + Some(im) + }); + + // Calculate the offset to parameter idx due to newly intorduced one-hot vectors + let offset_ = cat_sizes.iter().scan(0, |a, &v| { + *a = *a + v - 1; + Some(*a) + }); + let offset = (0..1).chain(offset_); + + let new_param_idxs: Vec = (0..num_params) + .zip( + repeats + .zip(offset) + .map(|(r, o)| iter::repeat(o).take(r)) + .flatten(), + ) + .map(|(idx, ofst)| idx + ofst) + .collect(); + new_param_idxs +} /// Encode Categorical variavbles of data matrix to one-hot pub struct OneHotEncoder { series_encoders: Vec>,