From d31145b4fe24e0718aef3b0b9371e9e2834b31ce Mon Sep 17 00:00:00 2001 From: gaxler Date: Tue, 2 Feb 2021 18:19:36 -0800 Subject: [PATCH] Define common series encoder behavior --- src/preprocessing/series_encoder.rs | 146 +++++++++++++--------------- 1 file changed, 70 insertions(+), 76 deletions(-) diff --git a/src/preprocessing/series_encoder.rs b/src/preprocessing/series_encoder.rs index 4e9625e..4e9ddf9 100644 --- a/src/preprocessing/series_encoder.rs +++ b/src/preprocessing/series_encoder.rs @@ -75,6 +75,50 @@ impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper { } } +/// Defines common behavior for series encoders(e.g. OneHot, Ordinal) +pub trait SeriesEncoder: + where + CategoryType:Hash + Eq + Clone +{ + /// Fit an encoder to a lable list + fn fit_to_iter(categories: impl Iterator) -> Self; + + /// Number of categories for categorical variable + fn num_categories(&self) -> usize; + + /// Build an encoder from a predefined (category -> class number) map + fn from_category_map(category_map: HashMap) -> Self; + + /// Build an encoder from a predefined positional category-class num vector + fn from_positional_category_vec(categories: Vec) -> Self; + + /// Transform a single category type into a one-hot vector + fn transform_one>(&self, category: &CategoryType) -> Option; + + /// Invert one-hot vector, back to the category + fn invert_one>(&self, one_hot: V) -> Result; + + /// Get categories ordered by encoder's category enumeration + fn get_categories(&self) -> &[CategoryType]; + + /// Take an iterator as a series to transform + fn transform_iter>( + &self, + cat_it: impl Iterator, + ) -> Vec> { + cat_it.map(|l| self.transform_one(&l)).collect() + } + + /// Transform a slice of category types into one-hot vectors + /// None is returned if unknown category is encountered + fn transfrom_series>( + &self, + categories: &[CategoryType], + ) -> Vec> { + let v = categories.iter().cloned(); + self.transform_iter(v) + } +} /// Make a one-hot encoded vector from a categorical variable /// /// Example: @@ -134,104 +178,47 @@ pub fn make_one_hot>( /// ``` #[derive(Debug, Clone)] pub struct SeriesOneHotEncoder { - category_map: HashMap, - categories: Vec, - /// Number of categories for categorical variable - pub num_categories: usize, + mapper: CategoryMapper, } -impl<'a, CategoryType: 'a + Hash + Eq + Clone> SeriesOneHotEncoder { - /// Fit an encoder to a lable list - pub fn fit_to_iter(categories: impl Iterator) -> Self { - let mut category_map: HashMap = HashMap::new(); - let mut category_num = 0usize; - let mut unique_lables: Vec = Vec::new(); +impl SeriesEncoder for SeriesOneHotEncoder { - for l in categories { - if !category_map.contains_key(&l) { - category_map.insert(l.clone(), category_num); - unique_lables.push(l.clone()); - category_num += 1; + fn fit_to_iter(categories: impl Iterator) -> Self { + Self {mapper:CategoryMapper::fit_to_iter(categories)} } - } - Self { - category_map, - num_categories: category_num, - categories: unique_lables, - } - } /// Build an encoder from a predefined (category -> class number) map - pub fn from_category_map(category_map: HashMap) -> Self { - let mut _unique_cat: Vec<(CategoryType, usize)> = - category_map.iter().map(|(k, v)| (k.clone(), *v)).collect(); - _unique_cat.sort_by(|a, b| a.1.cmp(&b.1)); - let categories: Vec = _unique_cat.into_iter().map(|a| a.0).collect(); - Self { - num_categories: categories.len(), - categories, - category_map, + fn from_category_map(category_map: HashMap) -> Self { + Self {mapper: CategoryMapper::from_category_map(category_map)} } - } /// Build an encoder from a predefined positional category-class num vector - pub fn from_positional_category_vec(categories: Vec) -> Self { - let category_map: HashMap = categories - .iter() - .enumerate() - .map(|(v, k)| (k.clone(), v)) - .collect(); - Self { - num_categories: categories.len(), - category_map, - categories, + fn from_positional_category_vec(categories: Vec) -> Self { + Self {mapper:CategoryMapper::from_positional_category_vec(categories)} } + + fn num_categories(&self) -> usize { + self.mapper.num_categories } - /// Take an iterator as a series to transform - pub fn transform_iter( - &self, - cat_it: impl Iterator, - ) -> Vec>> { - cat_it.map(|l| self.transform_one(&l)).collect() + fn get_categories(&self) -> &[CategoryType] { + self.mapper.get_categories() } - /// Transform a slice of category types into one-hot vectors - /// None is returned if unknown category is encountered - pub fn transfrom_series( - &self, - categories: &'a [CategoryType], - ) -> Vec>> { - let v = categories.iter().cloned(); - self.transform_iter(v) - } - - /// Transform a single category type into a one-hot vector - pub fn transform_one(&self, category: &CategoryType) -> Option> { - match self.category_map.get(category) { - None => None, - Some(&idx) => Some(make_one_hot(idx, self.num_categories)), - } - } - - /// Get categories ordered by encoder's category enumeration - pub fn get_categories(&self) -> &Vec { - &self.categories - } - - /// Invert one-hot vector, back to the category - pub fn invert_one(&self, one_hot: Vec) -> Result { + fn invert_one>(&self, one_hot: V) -> Result + { let pos = U::from_f64(1f64).unwrap(); + + let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); - let s: Vec = one_hot - .into_iter() + let s: Vec = oh_it .enumerate() .filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) .collect(); if s.len() == 1 { let idx = s[0]; - return Ok(self.categories[idx].clone()); + return Ok(self.mapper.get_cat(idx).clone()); } let pos_entries = format!( "Expected a single positive entry, {} entires found", @@ -239,6 +226,13 @@ impl<'a, CategoryType: 'a + Hash + Eq + Clone> SeriesOneHotEncoder ); Err(Failed::transform(&pos_entries[..])) } + + fn transform_one>(&self, category: &CategoryType) -> Option { + match self.mapper.get_num(category) { + None => None, + Some(&idx) => Some(make_one_hot(idx, self.num_categories())), + } + } } #[cfg(test)]