Adapt column numbers to the new columns introduced by categorical variables.

This commit is contained in:
gaxler
2021-01-30 16:05:45 -08:00
parent 3480e728af
commit 3dc8a42832
+34
View File
@@ -41,6 +41,40 @@ pub struct OneHotEncoderParams {
pub categorical_param_idxs: Option<Vec<usize>>,
pub infer_categorical: bool,
}
/// Calculate the offset to parameters to due introduction of one-hot encoding
fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize]) -> Vec<usize> {
// This functions uses iterators and returns a vector.
// In case we get a huge amount of paramenters this might be a problem
// todo: Change this such that it will return an iterator
let cat_idx = encoded_idxs.iter().copied().chain((num_params..).take(1));
// Offset is constant between two categorical values, here we calculate the number of steps
// that remain constant
let repeats = cat_idx.scan(0, |a, v| {
let im = v + 1 - *a;
*a = v;
Some(im)
});
// Calculate the offset to parameter idx due to newly intorduced one-hot vectors
let offset_ = cat_sizes.iter().scan(0, |a, &v| {
*a = *a + v - 1;
Some(*a)
});
let offset = (0..1).chain(offset_);
let new_param_idxs: Vec<usize> = (0..num_params)
.zip(
repeats
.zip(offset)
.map(|(r, o)| iter::repeat(o).take(r))
.flatten(),
)
.map(|(idx, ofst)| idx + ofst)
.collect();
new_param_idxs
}
/// Encode Categorical variavbles of data matrix to one-hot
pub struct OneHotEncoder {
series_encoders: Vec<SeriesOneHotEncoder<HashableReal>>,