/// Bidirectional mapping between category values and dense `usize` indices.
///
/// The forward direction (`category -> index`) is stored in a hash map; the
/// reverse direction (`index -> category`) is stored positionally in a vector,
/// so that `categories[i]` is the category that was assigned index `i`.
#[derive(Debug, Clone)]
pub struct CategoryMapper<CategoryType> {
    /// Forward lookup: category value -> assigned index.
    category_map: HashMap<CategoryType, usize>,
    /// Reverse lookup: position `i` holds the category with index `i`.
    categories: Vec<CategoryType>,
    /// Number of entries in `categories`.
    num_categories: usize,
}

impl<CategoryType: Hash + Eq + Clone> CategoryMapper<CategoryType> {
    /// Build a mapper from an iterator of (possibly repeated) category values.
    ///
    /// Each distinct value receives the next free index in order of first
    /// appearance, starting at `0`. Repeated values are ignored.
    fn fit_to_iter(categories: impl Iterator<Item = CategoryType>) -> Self {
        let mut category_map: HashMap<CategoryType, usize> = HashMap::new();
        let mut unique_labels: Vec<CategoryType> = Vec::new();

        for category in categories {
            // The next index always equals the number of distinct values seen so far.
            let next_index = unique_labels.len();
            // `entry` hashes the key once; only unseen categories are cloned and stored.
            if let std::collections::hash_map::Entry::Vacant(slot) = category_map.entry(category) {
                unique_labels.push(slot.key().clone());
                slot.insert(next_index);
            }
        }

        Self {
            num_categories: unique_labels.len(),
            category_map,
            categories: unique_labels,
        }
    }

    /// Build a mapper from an existing `category -> index` map.
    ///
    /// The reverse lookup is produced by ordering the categories by their
    /// mapped index. The indices are not validated: if they are not exactly
    /// `0..n` without gaps or repeats, positions in the reverse lookup will
    /// not agree with the indices stored in the map.
    fn from_category_map(category_map: HashMap<CategoryType, usize>) -> Self {
        // Sort borrowed keys by index first, then clone each key exactly once.
        let mut by_index: Vec<(&CategoryType, usize)> =
            category_map.iter().map(|(cat, &idx)| (cat, idx)).collect();
        by_index.sort_unstable_by_key(|&(_, idx)| idx);
        let categories: Vec<CategoryType> =
            by_index.into_iter().map(|(cat, _)| cat.clone()).collect();

        Self {
            num_categories: categories.len(),
            categories,
            category_map,
        }
    }

    /// Build a mapper from a vector whose positions are the category indices.
    ///
    /// `categories[i]` is assigned index `i`. If the same value occurs more
    /// than once, the forward lookup keeps the index of its last occurrence
    /// while the vector (and the stored count) keep every occurrence.
    fn from_positional_category_vec(categories: Vec<CategoryType>) -> Self {
        let category_map: HashMap<CategoryType, usize> = categories
            .iter()
            .cloned()
            .enumerate()
            .map(|(idx, cat)| (cat, idx))
            .collect();

        Self {
            num_categories: categories.len(),
            category_map,
            categories,
        }
    }

    /// Get label num of a category, or `None` if the category is unknown.
    fn get_num(&self, category: &CategoryType) -> Option<&usize> {
        self.category_map.get(category)
    }

    /// Return category corresponding to label num.
    ///
    /// # Panics
    ///
    /// Panics if `num` is not smaller than the number of stored categories.
    fn get_cat(&self, num: usize) -> &CategoryType {
        &self.categories[num]
    }

    /// Return all categories, ordered by their position in the reverse lookup.
    fn get_categories(&self) -> &[CategoryType] {
        &self.categories
    }
}