If transform fails - fail before copying the whole matrix
(changed the order of coping, first do the categorical, than copy ther rest)
This commit is contained in:
@@ -156,7 +156,7 @@ impl OneHotEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Transform categorical variables to one-hot encoded and return a new matrix
|
/// Transform categorical variables to one-hot encoded and return a new matrix
|
||||||
pub fn transform<T: Categorizable, M: Matrix<T>>(&self, x: &M) -> Option<M> {
|
pub fn transform<T: Categorizable, M: Matrix<T>>(&self, x: &M) -> Result<M, Failed> {
|
||||||
let (nrows, p) = x.shape();
|
let (nrows, p) = x.shape();
|
||||||
let additional_params: Vec<usize> = self
|
let additional_params: Vec<usize> = self
|
||||||
.series_encoders
|
.series_encoders
|
||||||
@@ -164,28 +164,24 @@ impl OneHotEncoder {
|
|||||||
.map(|enc| enc.num_categories)
|
.map(|enc| enc.num_categories)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let new_param_num: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);
|
// Eac category of size v adds v-1 params
|
||||||
|
let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);
|
||||||
|
|
||||||
let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
|
let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
|
||||||
let mut res = M::zeros(nrows, new_param_num);
|
let mut res = M::zeros(nrows, expandws_p);
|
||||||
// copy old data in x to their new location
|
|
||||||
for (old_p, &new_p) in new_col_idx.iter().enumerate() {
|
|
||||||
for r in 0..nrows {
|
|
||||||
let val = x.get(r, old_p);
|
|
||||||
res.set(r, new_p, val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
|
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
|
||||||
let cidx = new_col_idx[old_cidx];
|
let cidx = new_col_idx[old_cidx];
|
||||||
let col_iter = (0..nrows).map(|r| res.get(r, cidx).to_category());
|
let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category());
|
||||||
let sencoder = &self.series_encoders[pidx];
|
let sencoder = &self.series_encoders[pidx];
|
||||||
let oh_series: Vec<Option<Vec<T>>> = sencoder.transform_iter(col_iter);
|
let oh_series: Vec<Option<Vec<T>>> = sencoder.transform_iter(col_iter);
|
||||||
|
|
||||||
for (row, oh_vec) in oh_series.iter().enumerate() {
|
for (row, oh_vec) in oh_series.iter().enumerate() {
|
||||||
match oh_vec {
|
match oh_vec {
|
||||||
None => {
|
None => {
|
||||||
// Bad value in a series causes in to be invalid
|
// Since we support T types, bad value in a series causes in to be invalid
|
||||||
// todo: proper error handling, so user can know where the bad value is
|
let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx);
|
||||||
return None;
|
return Err(Failed::transform(&msg[..]));
|
||||||
}
|
}
|
||||||
Some(v) => {
|
Some(v) => {
|
||||||
// copy one hot vectors to their place in the data matrix;
|
// copy one hot vectors to their place in the data matrix;
|
||||||
@@ -196,7 +192,27 @@ impl OneHotEncoder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(res)
|
|
||||||
|
// copy old data in x to their new location while skipping catergorical vars (already treated)
|
||||||
|
let mut skip_idx_iter = self.col_idx_categorical.iter();
|
||||||
|
let mut cur_skip = skip_idx_iter.next();
|
||||||
|
|
||||||
|
for (old_p, &new_p) in new_col_idx.iter().enumerate() {
|
||||||
|
// if found treated varible, skip it
|
||||||
|
if let Some(&v) = cur_skip {
|
||||||
|
if v == old_p {
|
||||||
|
cur_skip = skip_idx_iter.next();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for r in 0..nrows {
|
||||||
|
let val = x.get(r, old_p);
|
||||||
|
res.set(r, new_p, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user