diff --git a/src/preprocessing/target_encoders.rs b/src/preprocessing/target_encoders.rs
index 3f2592b..ff9fa6e 100644
--- a/src/preprocessing/target_encoders.rs
+++ b/src/preprocessing/target_encoders.rs
@@ -91,12 +91,31 @@ impl<'a, LabelType: Hash + Eq + Clone> OneHotEncoder<LabelType> {
     }
 
     /// Build an encoder from a predefined (label -> class number) map
-    pub fn from_label_map(labels: HashMap<LabelType, usize>) -> Self {
-        Self::from_label_def(LabelDefinition::LabelToClsNumMap(labels))
+    pub fn from_label_map(category_map: HashMap<CategoryType, usize>) -> Self {
+        let mut _unique_cat: Vec<(CategoryType, usize)> =
+            category_map.iter().map(|(k, v)| (k.clone(), *v)).collect();
+        _unique_cat.sort_by(|a, b| a.1.cmp(&b.1));
+        let categories: Vec<CategoryType> = _unique_cat.into_iter().map(|a| a.0).collect();
+        Self {
+            num_categories: categories.len(),
+            categories,
+            category_map,
         }
+    }
+
     /// Build an encoder from a predefined positional label-class num vector
-    pub fn from_positional_label_vec(labels: Vec<LabelType>) -> Self {
-        Self::from_label_def(LabelDefinition::PositionalLabel(labels))
+    pub fn from_positional_label_vec(categories: Vec<CategoryType>) -> Self {
+        // Self::from_label_def(LabelDefinition::PositionalLabel(categories))
+        let category_map: HashMap<CategoryType, usize> = categories
+            .iter()
+            .enumerate()
+            .map(|(v, k)| (k.clone(), v))
+            .collect();
+        Self {
+            num_categories: categories.len(),
+            category_map,
+            categories,
+        }
     }
 
     /// Transform a slice of label types into one-hot vectors