From 374dfeceb906262a2797967cfa02514b5ca2d48d Mon Sep 17 00:00:00 2001 From: gaxler Date: Wed, 3 Feb 2021 13:41:25 -0800 Subject: [PATCH] No more SeriesEncoders. --- src/preprocessing/series_encoder.rs | 104 +++++++++++++++++----------- 1 file changed, 63 insertions(+), 41 deletions(-) diff --git a/src/preprocessing/series_encoder.rs b/src/preprocessing/series_encoder.rs index cdbae16..e24eca1 100644 --- a/src/preprocessing/series_encoder.rs +++ b/src/preprocessing/series_encoder.rs @@ -65,7 +65,7 @@ where pub fn num_categories(&self) -> usize { self.num_categories } - + /// Fit an encoder to a lable iterator pub fn fit_to_iter(categories: impl Iterator) -> Self { let mut category_map: HashMap = HashMap::new(); @@ -85,7 +85,7 @@ where categories: unique_lables, } } - + /// Build an encoder from a predefined (category -> class number) map pub fn from_category_map(category_map: HashMap) -> Self { let mut _unique_cat: Vec<(C, usize)> = @@ -98,7 +98,7 @@ where category_map, } } - + /// Build an encoder from a predefined positional category-class num vector pub fn from_positional_category_vec(categories: Vec) -> Self { let category_map: HashMap = categories @@ -130,54 +130,71 @@ where /// Get one-hot encoding of the category pub fn get_one_hot(&self, category: &C) -> Option - where + where U: RealNumber, V: BaseVector, -{ + { match self.get_num(category) { None => None, Some(&idx) => Some(make_one_hot::(idx, self.num_categories)), + } } -} /// Invert one-hot vector, back to the category pub fn invert_one_hot(&self, one_hot: V) -> Result where U: RealNumber, - V: BaseVector + V: BaseVector, + { + let pos = U::one(); - { - let pos = U::from_f64(1f64).unwrap(); - - let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); - - let s: Vec = oh_it - .enumerate() - .filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) - .collect(); - - if s.len() == 1 { - let idx = s[0]; - return Ok(self.mapper.get_cat(idx).clone()); - } - let pos_entries = format!( - "Expected a single positive entry, {} entires found", - s.len() - ); - Err(Failed::transform(&pos_entries[..])) + let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); + + let s: Vec = oh_it + .enumerate() + .filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) + .collect(); + + if s.len() == 1 { + let idx = s[0]; + return Ok(self.get_cat(idx).clone()); } + let pos_entries = format!( + "Expected a single positive entry, {} entires found", + s.len() + ); + Err(Failed::transform(&pos_entries[..])) + } - fn transform_one(&self, category: &C) -> Option + /// Get ordinal encoding of the catergory + pub fn get_ordinal(&self, category: &C) -> Option where U: RealNumber, - V: BaseVector { - match self.mapper.get_num(category) { + match self.get_num(category) { None => None, - Some(&idx) => Some(make_one_hot(idx, self.num_categories())), + Some(&idx) => U::from_usize(idx), } } - +} + +/// Make a one-hot encoded vector from a categorical variable +/// +/// Example: +/// ``` +/// use smartcore::preprocessing::series_encoder::make_one_hot; +/// let one_hot: Vec = make_one_hot(2, 3); +/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]); +/// ``` +pub fn make_one_hot(category_idx: usize, num_categories: usize) -> V +where + T: RealNumber, + V: BaseVector, +{ + let pos = T::one(); + let mut z = V::zeros(num_categories); + z.set(category_idx, pos); + z } #[cfg(test)] @@ -188,8 +205,8 @@ mod tests { fn from_categories() { let fake_categories: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let it = fake_categories.iter().map(|&a| a); - let enc = SeriesOneHotEncoder::::fit_to_iter(it); - let oh_vec: Vec = match enc.transform_one(&1) { + let enc = CategoryMapper::::fit_to_iter(it); + let oh_vec: Vec = match enc.get_one_hot(&1) { None => panic!("Wrong categories"), Some(v) => v, }; @@ -197,19 +214,24 @@ mod tests { assert_eq!(oh_vec, res); } - fn build_fake_str_enc<'a>() -> SeriesOneHotEncoder<&'a str> { + fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> { let fake_category_pos = vec!["background", "dog", "cat"]; - let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_positional_category_vec(fake_category_pos)); + let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos); enc } + #[test] + fn ordinal_encoding() { + let enc = build_fake_str_enc(); + assert_eq!(1f64, enc.get_ordinal::(&"dog").unwrap()) + } #[test] fn category_map_and_vec() { let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] .into_iter() .collect(); - let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_category_map(category_map)); - let oh_vec: Vec = match enc.transform_one(&"dog") { + let enc = CategoryMapper::<&str>::from_category_map(category_map); + let oh_vec: Vec = match enc.get_one_hot(&"dog") { None => panic!("Wrong categories"), Some(v) => v, }; @@ -220,7 +242,7 @@ mod tests { #[test] fn positional_categories_vec() { let enc = build_fake_str_enc(); - let oh_vec: Vec = match enc.transform_one(&"dog") { + let oh_vec: Vec = match enc.get_one_hot(&"dog") { None => panic!("Wrong categories"), Some(v) => v, }; @@ -232,9 +254,9 @@ mod tests { fn invert_label_test() { let enc = build_fake_str_enc(); let res: Vec = vec![0.0, 1.0, 0.0]; - let lab = enc.invert_one(res).unwrap(); + let lab = enc.invert_one_hot(res).unwrap(); assert_eq!(lab, "dog"); - if let Err(e) = enc.invert_one(vec![0.0, 0.0, 0.0]) { + if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) { let pos_entries = format!("Expected a single positive entry, 0 entires found"); assert_eq!(e, Failed::transform(&pos_entries[..])); }; @@ -244,7 +266,7 @@ mod tests { fn test_many_categorys() { let enc = build_fake_str_enc(); let cat_it = ["dog", "cat", "fish", "background"].iter().cloned(); - let res: Vec>> = enc.transform_iter(cat_it); + let res: Vec>> = cat_it.map(|v| enc.get_one_hot(&v)).collect(); let v = vec![ Some(vec![0.0, 1.0, 0.0]), Some(vec![0.0, 0.0, 1.0]),