Fit OneHotEncoder

This commit is contained in:
gaxler
2021-01-30 19:29:33 -08:00
parent dd39433ff8
commit cd5611079c
+45 -11
View File
@@ -75,32 +75,66 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize])
.collect(); .collect();
new_param_idxs new_param_idxs
} }
fn validate_col_is_categorical<T: Categorizable>(data: &Vec<T>) -> bool {
for v in data {
if !v.is_valid() { return false}
}
true
}
/// Encode Categorical variavbles of data matrix to one-hot /// Encode Categorical variavbles of data matrix to one-hot
pub struct OneHotEncoder { pub struct OneHotEncoder {
series_encoders: Vec<SeriesOneHotEncoder<HashableReal>>, series_encoders: Vec<SeriesOneHotEncoder<CategoricalFloat>>,
categorical_param_idxs: Vec<usize>, col_idx_categorical: Vec<usize>,
} }
impl<T: RealNumber, M: Matrix<T>> OneHotEncoder { impl OneHotEncoder {
/// PlaceHolder /// PlaceHolder
pub fn fit(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed> { pub fn fit<T: Categorizable, M: Matrix<T>>(
match (params.categorical_param_idxs, params.infer_categorical) { data: &M,
params: OneHotEncoderParams,
) -> Result<OneHotEncoder, Failed> {
match (params.col_idx_categorical, params.infer_categorical) {
(None, false) => Err(Failed::fit( (None, false) => Err(Failed::fit(
"Must pass categorical series ids or infer flag", "Must pass categorical series ids or infer flag",
)), )),
(Some(idxs), true) => Err(Failed::fit( (Some(_idxs), true) => Err(Failed::fit(
"Ambigous parameters, got both infer and categroy ids", "Ambigous parameters, got both infer and categroy ids",
)), )),
(Some(idxs), false) => Ok(Self { (Some(mut idxs), false) => {
series_encoders: Self::build_series_encoders::<T, M>(data, &idxs[..]), // make sure categories have same order as data columns
categorical_param_idxs: idxs, idxs.sort();
}),
let (nrows, _) = data.shape();
// col buffer to avoid allocations
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
let mut res: Vec<SeriesOneHotEncoder<CategoricalFloat>> = Vec::with_capacity(idxs.len());
for &idx in &idxs {
data.copy_col_as_vec(idx, &mut col_buf);
if !validate_col_is_categorical(&col_buf) {
let msg = format!("Column {} of data matrix containts non categorizable (integer) values", idx);
return Err(Failed::fit(&msg[..]))
}
let hashable_col = col_buf.iter().map(|v| v.to_category());
res.push(SeriesOneHotEncoder::fit_to_iter(hashable_col));
}
Ok(Self {
series_encoders: res, //Self::build_series_encoders::<T, M>(data, &idxs[..]),
col_idx_categorical: idxs,
})
}
(None, true) => { (None, true) => {
todo!("implement categorical auto-inference") todo!("Auto-Inference for Categorical Variables not yet implemented")
}
}
}
} }
} }
} }