From 700d320724c8dad09cdd31e3d73e5cc4d91c33ce Mon Sep 17 00:00:00 2001 From: gaxler Date: Wed, 3 Feb 2021 10:45:25 -0800 Subject: [PATCH] simplify SeriesEncoder trait --- src/preprocessing/series_encoder.rs | 134 ++++++++++++++-------------- 1 file changed, 68 insertions(+), 66 deletions(-) diff --git a/src/preprocessing/series_encoder.rs b/src/preprocessing/series_encoder.rs index 9d7e259..6975c0d 100644 --- a/src/preprocessing/series_encoder.rs +++ b/src/preprocessing/series_encoder.rs @@ -10,19 +10,22 @@ use std::hash::Hash; /// Bi-directional map category <-> label num. #[derive(Debug, Clone)] -pub struct CategoryMapper { - category_map: HashMap, - categories: Vec, +pub struct CategoryMapper { + category_map: HashMap, + categories: Vec, num_categories: usize, } -impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper { +impl<'a, C> CategoryMapper +where + C: 'a + Hash + Eq + Clone +{ /// Fit an encoder to a lable iterator - pub fn fit_to_iter(categories: impl Iterator) -> Self { - let mut category_map: HashMap = HashMap::new(); + pub fn fit_to_iter(categories: impl Iterator) -> Self { + let mut category_map: HashMap = HashMap::new(); let mut category_num = 0usize; - let mut unique_lables: Vec = Vec::new(); + let mut unique_lables: Vec = Vec::new(); for l in categories { if !category_map.contains_key(&l) { @@ -39,11 +42,11 @@ impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper { } /// Build an encoder from a predefined (category -> class number) map - pub fn from_category_map(category_map: HashMap) -> Self { - let mut _unique_cat: Vec<(CategoryType, usize)> = + pub fn from_category_map(category_map: HashMap) -> Self { + let mut _unique_cat: Vec<(C, usize)> = category_map.iter().map(|(k, v)| (k.clone(), *v)).collect(); _unique_cat.sort_by(|a, b| a.1.cmp(&b.1)); - let categories: Vec = _unique_cat.into_iter().map(|a| a.0).collect(); + let categories: Vec = _unique_cat.into_iter().map(|a| a.0).collect(); Self { num_categories: categories.len(), categories, @@ -52,8 +55,8 @@ impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper { } /// Build an encoder from a predefined positional category-class num vector - pub fn from_positional_category_vec(categories: Vec) -> Self { - let category_map: HashMap = categories + pub fn from_positional_category_vec(categories: Vec) -> Self { + let category_map: HashMap = categories .iter() .enumerate() .map(|(v, k)| (k.clone(), v)) @@ -66,64 +69,49 @@ impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper { } /// Get label num of a category - pub fn get_num(&self, category: &CategoryType) -> Option<&usize> { + pub fn get_num(&self, category: &C) -> Option<&usize> { self.category_map.get(category) } /// Return category corresponding to label num - pub fn get_cat(&self, num: usize) -> &CategoryType { + pub fn get_cat(&self, num: usize) -> &C { &self.categories[num] } /// List all categories (position = category number) - pub fn get_categories(&self) -> &[CategoryType] { + pub fn get_categories(&self) -> &[C] { &self.categories[..] } } /// Defines common behavior for series encoders(e.g. OneHot, Ordinal) -pub trait SeriesEncoder: +pub trait SeriesEncoder: where - CategoryType:Hash + Eq + Clone + C: Hash + Eq + Clone { /// Fit an encoder to a lable iterator - fn fit_to_iter(categories: impl Iterator) -> Self; + fn fit_to_iter(categories: impl Iterator) -> Self; /// Number of categories for categorical variable fn num_categories(&self) -> usize; - /// Build an encoder from a predefined (category -> class number) map - fn from_category_map(category_map: HashMap) -> Self; - - /// Build an encoder from a predefined positional category-class num vector - fn from_positional_category_vec(categories: Vec) -> Self; - /// Transform a single category type into a one-hot vector - fn transform_one>(&self, category: &CategoryType) -> Option; + fn transform_one>(&self, category: &C) -> Option; /// Invert one-hot vector, back to the category - fn invert_one>(&self, one_hot: V) -> Result; + fn invert_one>(&self, one_hot: V) -> Result; /// Get categories ordered by encoder's category enumeration - fn get_categories(&self) -> &[CategoryType]; + fn get_categories(&self) -> &[C]; /// Take an iterator as a series to transform + /// None is returned if unknown category is encountered fn transform_iter>( &self, - cat_it: impl Iterator, + cat_it: impl Iterator, ) -> Vec> { cat_it.map(|l| self.transform_one(&l)).collect() } - - /// Transform a slice of category types into one-hot vectors - /// None is returned if unknown category is encountered - fn transfrom_series>( - &self, - categories: &[CategoryType], - ) -> Vec> { - let v = categories.iter().cloned(); - self.transform_iter(v) - } } /// Make a one-hot encoded vector from a categorical variable @@ -153,22 +141,22 @@ pub fn make_one_hot>( /// Example: /// ``` /// use std::collections::HashMap; -/// use smartcore::preprocessing::series_encoder::SeriesOneHotEncoder; +/// use smartcore::preprocessing::series_encoder::{SeriesOneHotEncoder, SeriesEncoder}; /// /// let fake_categories: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; /// let it = fake_categories.iter().map(|&a| a); -/// let enc = SeriesOneHotEncoder::::fit_to_iter(it); +/// let enc: SeriesOneHotEncoder:: = SeriesEncoder::fit_to_iter(it); /// let oh_vec: Vec = enc.transform_one(&1).unwrap(); /// // notice that 1 is actually a zero-th positional category /// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]); /// ``` /// -/// You can also pass a predefined category enumeration such as a hashmap `HashMap` or a vector `Vec` +/// You can also pass a predefined category enumeration such as a hashmap `HashMap` or a vector `Vec` /// /// /// ``` /// use std::collections::HashMap; -/// use smartcore::preprocessing::series_encoder::SeriesOneHotEncoder; +/// use smartcore::preprocessing::series_encoder::{SeriesOneHotEncoder, SeriesEncoder, CategoryMapper}; /// /// let category_map: HashMap<&str, usize> = /// vec![("cat", 2), ("background",0), ("dog", 1)] @@ -176,43 +164,53 @@ pub fn make_one_hot>( /// .collect(); /// let category_vec = vec!["background", "dog", "cat"]; /// -/// let enc_lv = SeriesOneHotEncoder::<&str>::from_positional_category_vec(category_vec); -/// let enc_lm = SeriesOneHotEncoder::<&str>::from_category_map(category_map); +/// let enc_lv = SeriesOneHotEncoder::<&str>::new(CategoryMapper::from_positional_category_vec(category_vec)); +/// let enc_lm = SeriesOneHotEncoder::<&str>::new(CategoryMapper::from_category_map(category_map)); /// /// // ["background", "dog", "cat"] /// println!("{:?}", enc_lv.get_categories()); -/// assert_eq!(enc_lv.transform_one::(&"dog"), enc_lm.transform_one::(&"dog")) +/// let lv: Vec = enc_lv.transform_one(&"dog").unwrap(); +/// let lm: Vec = enc_lm.transform_one(&"dog").unwrap(); +/// assert_eq!(lv, lm); /// ``` #[derive(Debug, Clone)] -pub struct SeriesOneHotEncoder { - mapper: CategoryMapper, +pub struct SeriesOneHotEncoder { + mapper: CategoryMapper, } -impl SeriesEncoder for SeriesOneHotEncoder { +impl SeriesOneHotEncoder +where + C: Hash + Eq + Clone +{ + /// Create SeriesEncoder form existing mapper + pub fn new(mapper: CategoryMapper) -> Self { + Self {mapper} + } +} + +impl SeriesEncoder for SeriesOneHotEncoder +where + C: Hash + Eq + Clone +{ - fn fit_to_iter(categories: impl Iterator) -> Self { + + fn fit_to_iter(categories: impl Iterator) -> Self { Self {mapper:CategoryMapper::fit_to_iter(categories)} } - /// Build an encoder from a predefined (category -> class number) map - fn from_category_map(category_map: HashMap) -> Self { - Self {mapper: CategoryMapper::from_category_map(category_map)} - } - - /// Build an encoder from a predefined positional category-class num vector - fn from_positional_category_vec(categories: Vec) -> Self { - Self {mapper:CategoryMapper::from_positional_category_vec(categories)} - } - fn num_categories(&self) -> usize { self.mapper.num_categories } - fn get_categories(&self) -> &[CategoryType] { + fn get_categories(&self) -> &[C] { self.mapper.get_categories() } - fn invert_one>(&self, one_hot: V) -> Result + fn invert_one(&self, one_hot: V) -> Result + where + U: RealNumber, + V: BaseVector + { let pos = U::from_f64(1f64).unwrap(); @@ -234,7 +232,11 @@ impl SeriesEncoder for SeriesOneH Err(Failed::transform(&pos_entries[..])) } - fn transform_one>(&self, category: &CategoryType) -> Option { + fn transform_one(&self, category: &C) -> Option + where + U: RealNumber, + V: BaseVector + { match self.mapper.get_num(category) { None => None, Some(&idx) => Some(make_one_hot(idx, self.num_categories())), @@ -262,7 +264,7 @@ mod tests { fn build_fake_str_enc<'a>() -> SeriesOneHotEncoder<&'a str> { let fake_category_pos = vec!["background", "dog", "cat"]; - let enc = SeriesOneHotEncoder::<&str>::from_positional_category_vec(fake_category_pos); + let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_positional_category_vec(fake_category_pos)); enc } @@ -271,7 +273,7 @@ mod tests { let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] .into_iter() .collect(); - let enc = SeriesOneHotEncoder::<&str>::from_category_map(category_map); + let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_category_map(category_map)); let oh_vec: Vec = match enc.transform_one(&"dog") { None => panic!("Wrong categories"), Some(v) => v, @@ -306,8 +308,8 @@ mod tests { #[test] fn test_many_categorys() { let enc = build_fake_str_enc(); - let res: Vec>> = - enc.transfrom_series(&["dog", "cat", "fish", "background"]); + let cat_it = ["dog", "cat", "fish", "background"].iter().cloned(); + let res: Vec>> = enc.transform_iter(cat_it); let v = vec![ Some(vec![0.0, 1.0, 0.0]), Some(vec![0.0, 0.0, 1.0]),