From fd6b2e801479f709870921f192153c6abeeab53d Mon Sep 17 00:00:00 2001 From: gaxler Date: Sat, 30 Jan 2021 19:29:58 -0800 Subject: [PATCH] Transform matrix --- src/preprocessing/categorical_encoders.rs | 42 +++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/preprocessing/categorical_encoders.rs b/src/preprocessing/categorical_encoders.rs index 794c1d6..585f13a 100644 --- a/src/preprocessing/categorical_encoders.rs +++ b/src/preprocessing/categorical_encoders.rs @@ -135,9 +135,51 @@ impl OneHotEncoder { } } } + + /// Transform categorical variables to one-hot encoded columns and return a new matrix + pub fn transform>(&self, x: &M) -> Option { + let (nrows, p) = x.shape(); + let additional_params: Vec = self + .series_encoders + .iter() + .map(|enc| enc.num_categories) + .collect(); + + let new_param_num: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1); + let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]); + let mut res = M::zeros(nrows, new_param_num); + // copy old data in x to their new location + for (old_p, &new_p) in new_col_idx.iter().enumerate() { + for r in 0..nrows { + let val = x.get(r, old_p); + res.set(r, new_p, val); } } + for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() { + let cidx = new_col_idx[old_cidx]; + let col_iter = (0..nrows).map(|r| res.get(r, cidx).to_category()); + let sencoder = &self.series_encoders[pidx]; + let oh_series: Vec>> = sencoder.transform_iter(col_iter); + + for (row, oh_vec) in oh_series.iter().enumerate() { + match oh_vec { + None => { + // Bad value in a series causes it to be invalid + // todo: proper error handling, so user can know where the bad value is + return None; + } + Some(v) => { + // copy one hot vectors to their place in the data matrix; + for (col_ofst, &val) in v.iter().enumerate() { + res.set(row, cidx + col_ofst, val); + } } + } + } + } + Some(res) + } +} fn build_series_encoders(data: &M, idxs: &[usize]) -> Vec> { 
let (nrows, _) = data.shape();