Fit OneHotEncoder

This commit is contained in:
gaxler
2021-01-30 19:29:33 -08:00
parent dd39433ff8
commit cd5611079c
+45 -11
View File
@@ -75,32 +75,66 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], encoded_idxs: &[usize])
.collect();
new_param_idxs
}
/// Returns `true` when every value in `data` reports itself as a valid
/// category through `is_valid`, and `false` as soon as one value does not.
///
/// An empty column contains no invalid value, so it yields `true`.
///
/// Takes a slice rather than `&Vec<T>` so any contiguous buffer can be
/// checked; callers passing `&Vec<T>` keep working through deref coercion.
fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool {
    // `all` short-circuits on the first invalid value, like the early return it replaces.
    data.iter().all(|v| v.is_valid())
}
/// Encodes the categorical variables (columns) of a data matrix as one-hot vectors.
pub struct OneHotEncoder {
series_encoders: Vec<SeriesOneHotEncoder<HashableReal>>,
categorical_param_idxs: Vec<usize>,
/// One fitted series encoder per categorical column, stored in the same
/// order as `col_idx_categorical`.
series_encoders: Vec<SeriesOneHotEncoder<CategoricalFloat>>,
/// Indices of the categorical columns; `fit` sorts them before storing.
col_idx_categorical: Vec<usize>,
}
impl<T: RealNumber, M: Matrix<T>> OneHotEncoder {
impl OneHotEncoder {
/// Fits a `OneHotEncoder` to the categorical columns of `data`.
///
/// `params` must supply exactly one of an explicit list of categorical
/// column indices or the infer flag:
///
/// * neither supplied: returns `Err(Failed::fit(..))`;
/// * both supplied: returns `Err(Failed::fit(..))`;
/// * indices only: sorts the indices, checks every listed column with
///   `validate_col_is_categorical`, and fits one `SeriesOneHotEncoder`
///   per column from the values converted with `to_category`;
/// * infer flag only: not implemented yet.
///
/// # Errors
///
/// Returns a `Failed::fit` error for the two invalid parameter
/// combinations above, or when a listed column fails validation.
///
/// # Panics
///
/// Panics through `todo!` when only the infer flag is set.
pub fn fit(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed> {
match (params.categorical_param_idxs, params.infer_categorical) {
pub fn fit<T: Categorizable, M: Matrix<T>>(
data: &M,
params: OneHotEncoderParams,
) -> Result<OneHotEncoder, Failed> {
match (params.col_idx_categorical, params.infer_categorical) {
// Nothing tells us which columns are categorical.
(None, false) => Err(Failed::fit(
"Must pass categorical series ids or infer flag",
)),
(Some(idxs), true) => Err(Failed::fit(
// Explicit indices and inference are mutually exclusive; the indices are unused here.
(Some(_idxs), true) => Err(Failed::fit(
"Ambigous parameters, got both infer and categroy ids",
)),
(Some(idxs), false) => Ok(Self {
series_encoders: Self::build_series_encoders::<T, M>(data, &idxs[..]),
categorical_param_idxs: idxs,
}),
(Some(mut idxs), false) => {
// Sort ascending so the encoders below are stored in column order.
idxs.sort();
let (nrows, _) = data.shape();
// A single buffer of `nrows` elements, reused for every categorical
// column so no per-column allocation is needed.
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
let mut res: Vec<SeriesOneHotEncoder<CategoricalFloat>> = Vec::with_capacity(idxs.len());
for &idx in &idxs {
data.copy_col_as_vec(idx, &mut col_buf);
// Abort the whole fit on the first column that holds a value `is_valid` rejects.
if !validate_col_is_categorical(&col_buf) {
let msg = format!("Column {} of data matrix containts non categorizable (integer) values", idx);
return Err(Failed::fit(&msg[..]))
}
// Convert lazily; the series encoder consumes the iterator directly.
let hashable_col = col_buf.iter().map(|v| v.to_category());
res.push(SeriesOneHotEncoder::fit_to_iter(hashable_col));
}
Ok(Self {
series_encoders: res, //Self::build_series_encoders::<T, M>(data, &idxs[..]),
col_idx_categorical: idxs,
})
}
// Inference of categorical columns is a known gap: this arm panics.
(None, true) => {
todo!("implement categorical auto-inference")
todo!("Auto-Inference for Categorical Variables not yet implemented")
}
}
}
}
}
}