Transform matrix

This commit is contained in:
gaxler
2021-01-30 19:29:58 -08:00
parent cd5611079c
commit fd6b2e8014
+42
View File
@@ -135,8 +135,50 @@ impl OneHotEncoder {
}
}
}
/// Transform categorical variables to one-hot encoded and return a new matrix
pub fn transform<T: Categorizable, M: Matrix<T>>(&self, x: &M) -> Option<M> {
let (nrows, p) = x.shape();
let additional_params: Vec<usize> = self
.series_encoders
.iter()
.map(|enc| enc.num_categories)
.collect();
let new_param_num: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);
let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
let mut res = M::zeros(nrows, new_param_num);
// copy old data in x to their new location
for (old_p, &new_p) in new_col_idx.iter().enumerate() {
for r in 0..nrows {
let val = x.get(r, old_p);
res.set(r, new_p, val);
}
}
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
let cidx = new_col_idx[old_cidx];
let col_iter = (0..nrows).map(|r| res.get(r, cidx).to_category());
let sencoder = &self.series_encoders[pidx];
let oh_series: Vec<Option<Vec<T>>> = sencoder.transform_iter(col_iter);
for (row, oh_vec) in oh_series.iter().enumerate() {
match oh_vec {
None => {
// Bad value in a series causes in to be invalid
// todo: proper error handling, so user can know where the bad value is
return None;
}
Some(v) => {
// copy one hot vectors to their place in the data matrix;
for (col_ofst, &val) in v.iter().enumerate() {
res.set(row, cidx + col_ofst, val);
}
}
}
}
}
Some(res)
}
}
fn build_series_encoders(data: &M, idxs: &[usize]) -> Vec<SeriesOneHotEncoder<HashableReal>> {