diff --git a/src/preprocessing/categorical_encoder.rs b/src/preprocessing/categorical_encoder.rs index 706670b..7e71119 100644 --- a/src/preprocessing/categorical_encoder.rs +++ b/src/preprocessing/categorical_encoder.rs @@ -156,7 +156,7 @@ impl OneHotEncoder { } /// Transform categorical variables to one-hot encoded and return a new matrix - pub fn transform>(&self, x: &M) -> Option { + pub fn transform>(&self, x: &M) -> Result { let (nrows, p) = x.shape(); let additional_params: Vec = self .series_encoders @@ -164,28 +164,24 @@ impl OneHotEncoder { .map(|enc| enc.num_categories) .collect(); - let new_param_num: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1); + // Eac category of size v adds v-1 params + let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1); + let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]); - let mut res = M::zeros(nrows, new_param_num); - // copy old data in x to their new location - for (old_p, &new_p) in new_col_idx.iter().enumerate() { - for r in 0..nrows { - let val = x.get(r, old_p); - res.set(r, new_p, val); - } - } + let mut res = M::zeros(nrows, expandws_p); + for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() { let cidx = new_col_idx[old_cidx]; - let col_iter = (0..nrows).map(|r| res.get(r, cidx).to_category()); + let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category()); let sencoder = &self.series_encoders[pidx]; let oh_series: Vec>> = sencoder.transform_iter(col_iter); for (row, oh_vec) in oh_series.iter().enumerate() { match oh_vec { None => { - // Bad value in a series causes in to be invalid - // todo: proper error handling, so user can know where the bad value is - return None; + // Since we support T types, bad value in a series causes in to be invalid + let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx); + return Err(Failed::transform(&msg[..])); } Some(v) => { // copy one hot vectors to their place in the data matrix; @@ -196,7 +192,27 @@ impl OneHotEncoder { } } } - Some(res) + + // copy old data in x to their new location while skipping catergorical vars (already treated) + let mut skip_idx_iter = self.col_idx_categorical.iter(); + let mut cur_skip = skip_idx_iter.next(); + + for (old_p, &new_p) in new_col_idx.iter().enumerate() { + // if found treated varible, skip it + if let Some(&v) = cur_skip { + if v == old_p { + cur_skip = skip_idx_iter.next(); + continue; + } + } + + for r in 0..nrows { + let val = x.get(r, old_p); + res.set(r, new_p, val); + } + } + + Ok(res) } }