diff --git a/src/preprocessing/target_encoders.rs b/src/preprocessing/target_encoders.rs index 44a5c05..76f4c92 100644 --- a/src/preprocessing/target_encoders.rs +++ b/src/preprocessing/target_encoders.rs @@ -7,11 +7,47 @@ use crate::math::num::RealNumber; use std::collections::HashMap; use std::hash::Hash; -/// Turn a collection of label types into a one-hot vectors. +/// Turn a collection of `LabelType`s into a one-hot vectors. /// This struct encodes single class per exmample -pub struct OneHotEncoder { - label_to_idx: HashMap, - labels: Vec, +/// +/// You can fit a label enumeration by passing a collection of labels. +/// Label numbers will be assigned in the order they are encountered +/// +/// Example: +/// ``` +/// use std::collections::HashMap; +/// use smartcore::preprocessing::target_encoders::OneHotEncoder; +/// +/// let fake_labels: Vec = vec![1,2,3,4,5,3,5,3,1,2,4]; +/// let enc = OneHotEncoder::::fit(&fake_labels[..]); +/// let oh_vec: Vec = enc.transform_one(&1).unwrap(); +/// // notice that 1 is actually a zero-th positional label +/// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]); +/// ``` +/// +/// You can also pass a predefined label enumeration such as a hashmap `HashMap` or a vector `Vec` +/// +/// +/// ``` +/// use std::collections::HashMap; +/// use smartcore::preprocessing::target_encoders::OneHotEncoder; +/// +/// let label_map: HashMap<&str, usize> = +/// vec![("cat", 2), ("background",0), ("dog", 1)] +/// .into_iter() +/// .collect(); +/// let label_vec = vec!["background", "dog", "cat"]; +/// +/// let enc_lv = OneHotEncoder::<&str>::from_positional_label_vec(label_vec); +/// let enc_lm = OneHotEncoder::<&str>::from_label_map(label_map); +/// +/// // ["background", "dog", "cat"] +/// println!("{:?}", enc_lv.get_labels()); +/// assert_eq!(enc_lv.transform_one::(&"dog"), enc_lm.transform_one::(&"dog")) +/// ``` +pub struct OneHotEncoder { + label_to_idx: HashMap, + labels: Vec, num_classes: usize, } @@ -28,21 +64,12 @@ pub fn make_one_hot(label_idx: usize, num_labels: usize) -> Vec OneHotEncoder { +impl<'a, LabelType: Hash + Eq + Clone> OneHotEncoder { /// Fit an encoder to a lable list - /// - /// Label numbers will be assigned in the order they are encountered - /// Example: - /// ``` - /// let fake_labels: Vec = vec![1,2,3,4,5,3,5,3,1,2,4]; - /// let enc = OneHotEncoder::::fit(&fake_labels[0..]); - /// let oh_vec = enc.transform_one(&1); // notice that 1 is actually a zero-th positional label - /// assert_eq!(oh_vec, vec![1f64,0f64,0f64,0f64,0f64]); - /// ``` - pub fn fit(labels: &[T]) -> Self { - let mut label_map: HashMap = HashMap::new(); + pub fn fit(labels: &[LabelType]) -> Self { + let mut label_map: HashMap = HashMap::new(); let mut class_num = 0usize; - let mut unique_lables: Vec = Vec::new(); + let mut unique_lables: Vec = Vec::new(); for l in labels { if !label_map.contains_key(&l) { @@ -59,48 +86,35 @@ impl<'a, T: Hash + Eq + Clone> OneHotEncoder { } /// Build an encoder from a predefined (label -> class number) map - /// - /// Definition example: - /// ``` - /// let fake_label_map: HashMap<&str, u32> = vec![("background",0), ("dog", 1), ("cat", 2)] - /// .into_iter() - /// .collect(); - /// let enc = OneHotEncoder::<&str>::from_label_map(fake_label_map); - /// ``` - pub fn from_label_map(labels: HashMap) -> Self { + pub fn from_label_map(labels: HashMap) -> Self { Self::from_label_def(LabelDefinition::LabelToClsNumMap(labels)) } /// Build an encoder from a predefined positional label-class num vector - /// - /// Definition example: - /// ``` - /// let fake_label_pos = vec!["background","dog", "cat"]; - /// let enc = OneHotEncoder::<&str>::from_positional_label_vec(fake_label_pos); - /// ``` - pub fn from_positional_label_vec(labels: Vec) -> Self { + pub fn from_positional_label_vec(labels: Vec) -> Self { Self::from_label_def(LabelDefinition::PositionalLabel(labels)) } /// Transform a slice of label types into one-hot vectors /// None is returned if unknown label is encountered - pub fn transform(&self, labels: &[T]) -> Vec>> { + pub fn transform(&self, labels: &[LabelType]) -> Vec>> { labels.iter().map(|l| self.transform_one(l)).collect() } /// Transform a single label type into a one-hot vector - pub fn transform_one(&self, label: &T) -> Option> { + pub fn transform_one(&self, label: &LabelType) -> Option> { match self.label_to_idx.get(label) { None => None, Some(&idx) => Some(make_one_hot(idx, self.num_classes)), } } + /// Get labels ordered by encoder's label enumeration + pub fn get_labels(&self) -> &Vec { + &self.labels + } + /// Invert one-hot vector, back to the label - ///``` - /// let lab = enc.invert_one(res)?; // e.g. res = [0,1,0,0...] "dog" == class 1 - /// assert_eq!(lab, "dog") - /// ``` - pub fn invert_one(&self, one_hot: Vec) -> Result { + pub fn invert_one(&self, one_hot: Vec) -> Result { let pos = U::from_f64(1f64).unwrap(); let s: Vec = one_hot @@ -120,17 +134,17 @@ impl<'a, T: Hash + Eq + Clone> OneHotEncoder { Err(Failed::transform(&pos_entries[..])) } - fn from_label_def(labels: LabelDefinition) -> Self { + fn from_label_def(labels: LabelDefinition) -> Self { let (label_map, class_num, unique_lables) = match labels { LabelDefinition::LabelToClsNumMap(h) => { - let mut _unique_lab: Vec<(T, usize)> = + let mut _unique_lab: Vec<(LabelType, usize)> = h.iter().map(|(k, v)| (k.clone(), *v)).collect(); _unique_lab.sort_by(|a, b| a.1.cmp(&b.1)); - let unique_lab: Vec = _unique_lab.into_iter().map(|a| a.0).collect(); + let unique_lab: Vec = _unique_lab.into_iter().map(|a| a.0).collect(); (h, unique_lab.len(), unique_lab) } LabelDefinition::PositionalLabel(unique_lab) => { - let h: HashMap = unique_lab + let h: HashMap = unique_lab .iter() .enumerate() .map(|(v, k)| (k.clone(), v)) @@ -154,7 +168,7 @@ mod tests { fn from_labels() { let fake_labels: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let enc = OneHotEncoder::::fit(&fake_labels[0..]); - let oh_vec = match enc.transform_one(&1) { + let oh_vec: Vec = match enc.transform_one(&1) { None => panic!("Wrong labels"), Some(v) => v, }; @@ -170,11 +184,11 @@ mod tests { #[test] fn label_map_and_vec() { - let fake_label_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] + let label_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] .into_iter() .collect(); - let enc = OneHotEncoder::<&str>::from_label_map(fake_label_map); - let oh_vec = match enc.transform_one(&"dog") { + let enc = OneHotEncoder::<&str>::from_label_map(label_map); + let oh_vec: Vec = match enc.transform_one(&"dog") { None => panic!("Wrong labels"), Some(v) => v, }; @@ -185,7 +199,7 @@ mod tests { #[test] fn positional_labels_vec() { let enc = build_fake_str_enc(); - let oh_vec = match enc.transform_one(&"dog") { + let oh_vec: Vec = match enc.transform_one(&"dog") { None => panic!("Wrong labels"), Some(v) => v, }; @@ -204,4 +218,17 @@ mod tests { assert_eq!(e, Failed::transform(&pos_entries[..])); }; } + + #[test] + fn test_many_labels() { + let enc = build_fake_str_enc(); + let res: Vec>> = enc.transform(&["dog", "cat", "fish", "background"]); + let v = vec![ + Some(vec![0.0, 1.0, 0.0]), + Some(vec![0.0, 0.0, 1.0]), + None, + Some(vec![1.0, 0.0, 0.0]), + ]; + assert_eq!(res, v) + } }