fit SeriesOneHotEncoders to predefined columns

This commit is contained in:
gaxler
2021-01-27 19:37:54 -08:00
parent 5c400f40d2
commit f91b1f9942
+42
View File
@@ -25,3 +25,45 @@ pub struct OneHotEncoder {
categorical_param_idxs: Vec<usize>,
}
impl OneHotEncoder {
    /// Fits a one-hot encoder to the columns of `data` selected by `params`.
    ///
    /// One `SeriesOneHotEncoder` is fitted per index listed in
    /// `params.categorical_param_idxs`, in the order the indices are given.
    /// The generic parameters live on the method (not on the `impl`) because
    /// `OneHotEncoder` itself is not generic over the matrix or number type.
    ///
    /// # Errors
    ///
    /// Returns `Failed::fit` when neither explicit column indices nor the
    /// `infer_categorical` flag are supplied, or when both are supplied.
    ///
    /// # Panics
    ///
    /// Categorical auto-inference (`infer_categorical == true` without
    /// explicit indices) is not implemented yet and currently hits `todo!`.
    pub fn fit<T: RealNumber, M: Matrix<T>>(
        data: &M,
        params: OneHotEncoderParams,
    ) -> Result<OneHotEncoder, Failed> {
        match (params.categorical_param_idxs, params.infer_categorical) {
            (None, false) => Err(Failed::fit(
                "Must pass categorical series ids or infer flag",
            )),
            (Some(_), true) => Err(Failed::fit(
                "Ambigous parameters, got both infer and categroy ids",
            )),
            (Some(idxs), false) => Ok(Self {
                series_encoders: Self::build_series_encoders::<T, M>(data, &idxs),
                categorical_param_idxs: idxs,
            }),
            (None, true) => {
                todo!("implement categorical auto-inference")
            }
        }
    }

    /// Fits one `SeriesOneHotEncoder` per column index in `idxs`.
    ///
    /// Each selected column of `data` is copied into a reusable buffer, its
    /// values are converted with `hashable_num`, and the resulting iterator
    /// is handed to `SeriesOneHotEncoder::fit_to_iter`. The returned vector
    /// has the same length and order as `idxs`.
    fn build_series_encoders<T: RealNumber, M: Matrix<T>>(
        data: &M,
        idxs: &[usize],
    ) -> Vec<SeriesOneHotEncoder<HashableReal>> {
        let (nrows, _) = data.shape();
        // The buffer is pre-filled to `nrows` elements instead of merely
        // reserving capacity: `copy_col_as_vec` is expected to overwrite the
        // existing slots of the receiver rather than push onto it, so a
        // zero-length buffer would leave every column empty.
        // NOTE(review): confirm against the `copy_col_as_vec` implementation in use.
        let mut col_buf: Vec<T> = vec![T::zero(); nrows];
        idxs.iter()
            .map(|&idx| {
                data.copy_col_as_vec(idx, &mut col_buf);
                SeriesOneHotEncoder::fit_to_iter(col_buf.iter().map(|v| hashable_num::<T>(v)))
            })
            .collect()
    }
}