diff --git a/src/lib.rs b/src/lib.rs index c5802d2..6e6205f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,9 +91,9 @@ pub mod naive_bayes; /// Supervised neighbors-based learning methods pub mod neighbors; pub(crate) mod optimization; +/// Preprocessing utilities +pub mod preprocessing; /// Support Vector Machines pub mod svm; /// Supervised tree-based learning methods pub mod tree; -/// Preprocessing utilities -pub mod preprocessing; diff --git a/src/preprocessing/mod.rs b/src/preprocessing/mod.rs index e4b5190..c70f7dc 100644 --- a/src/preprocessing/mod.rs +++ b/src/preprocessing/mod.rs @@ -1 +1 @@ -pub mod target_encoders; \ No newline at end of file +pub mod target_encoders; diff --git a/src/preprocessing/target_encoders.rs b/src/preprocessing/target_encoders.rs index 1894361..81cbdbd 100644 --- a/src/preprocessing/target_encoders.rs +++ b/src/preprocessing/target_encoders.rs @@ -1,22 +1,18 @@ - #![allow(clippy::ptr_arg)] //! # Encode categorical features as a one-hot or multi-class numeric array. -//! +//! -use std::hash::Hash; -use std::collections::HashMap; - -use crate::math::num::RealNumber; use crate::error::Failed; - +use crate::math::num::RealNumber; +use std::collections::HashMap; +use std::hash::Hash; /// Turn a collection of label types into a one-hot vectors. /// This struct encodes single class per exmample pub struct OneHotEncoder { label_to_idx: HashMap, labels: Vec, - num_classes: usize - + num_classes: usize, } enum LabelDefinition { @@ -27,13 +23,18 @@ enum LabelDefinition { /// Crearte a vector of size num_labels with zeros everywhere and 1 at label_idx (one-hot vector) pub fn make_one_hot(label_idx: usize, num_labels: usize) -> Vec { let (pos, neg) = (T::from_f64(1f64).unwrap(), T::from_f64(0f64).unwrap()); - (0..num_labels).map(|idx| if idx == label_idx {pos.clone()} else {neg.clone()}).collect() - + (0..num_labels) + .map(|idx| { + if idx == label_idx { + pos.clone() + } else { + neg.clone() + } + }) + .collect() } -impl<'a, T: Hash + Eq + Clone> OneHotEncoder -{ - +impl<'a, T: Hash + Eq + Clone> OneHotEncoder { /// Fit an encoder to a lable list /// /// Label numbers will be assigned in the order they are encountered @@ -45,23 +46,24 @@ impl<'a, T: Hash + Eq + Clone> OneHotEncoder /// assert_eq!(oh_vec, vec![1f64,0f64,0f64,0f64,0f64]); /// ``` pub fn fit(labels: &[T]) -> Self { - let mut label_map: HashMap = HashMap::new(); let mut class_num = 0usize; let mut unique_lables: Vec = Vec::new(); - for l in labels - { + for l in labels { if !label_map.contains_key(&l) { label_map.insert(l.clone(), class_num); unique_lables.push(l.clone()); class_num += 1; } } - Self {label_to_idx: label_map, num_classes: class_num, labels:unique_lables} + Self { + label_to_idx: label_map, + num_classes: class_num, + labels: unique_lables, + } } - /// Build an encoder from a predefined (label -> class number) map /// /// Definition example: @@ -84,21 +86,18 @@ impl<'a, T: Hash + Eq + Clone> OneHotEncoder pub fn from_positional_label_vec(labels: Vec) -> Self { Self::from_label_def(LabelDefinition::PositionalLabel(labels)) } - + /// Transform a slice of label types into one-hot vectors - /// None is returned if unknown label is encountered + /// None is returned if unknown label is encountered pub fn transform(&self, labels: &[T]) -> Vec>> { - labels - .into_iter() - .map(|l| self.transform_one(l)) - .collect() + labels.into_iter().map(|l| self.transform_one(l)).collect() } /// Transform a single label type into a one-hot vector pub fn transform_one(&self, label: &T) -> Option> { match self.label_to_idx.get(label) { None => None, - Some(&idx) => Some(make_one_hot(idx, self.num_classes)) + Some(&idx) => Some(make_one_hot(idx, self.num_classes)), } } @@ -111,99 +110,104 @@ impl<'a, T: Hash + Eq + Clone> OneHotEncoder let pos = U::from_f64(1f64).unwrap(); let s: Vec = one_hot - .into_iter() - .enumerate() - .filter_map(|(idx, v)| if v == pos {Some(idx)} else {None}) - .collect(); - + .into_iter() + .enumerate() + .filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) + .collect(); + if s.len() == 1 { let idx = s[0]; - return Ok(self.labels[idx].clone()) + return Ok(self.labels[idx].clone()); } - let pos_entries = format!("Expected a single positive entry, {} entires found", s.len()); + let pos_entries = format!( + "Expected a single positive entry, {} entires found", + s.len() + ); Err(Failed::transform(&pos_entries[..])) } - fn from_label_def(labels: LabelDefinition) -> Self { - let (label_map, class_num, unique_lables) = match labels { LabelDefinition::LabelToClsNumMap(h) => { - let mut _unique_lab: Vec<(T, usize)> = h.iter().map(|(k,v)| (k.clone(), v.clone())).collect(); - _unique_lab.sort_by(|a,b| a.1.cmp(&b.1)); + let mut _unique_lab: Vec<(T, usize)> = + h.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + _unique_lab.sort_by(|a, b| a.1.cmp(&b.1)); let unique_lab: Vec = _unique_lab.into_iter().map(|a| a.0).collect(); (h, unique_lab.len(), unique_lab) - }, + } LabelDefinition::PositionalLabel(unique_lab) => { - let h: HashMap = unique_lab.iter().enumerate().map(|(v, k)| (k.clone(),v)).collect(); + let h: HashMap = unique_lab + .iter() + .enumerate() + .map(|(v, k)| (k.clone(), v)) + .collect(); (h, unique_lab.len(), unique_lab) } }; - Self {label_to_idx: label_map, num_classes: class_num, labels:unique_lables} - + Self { + label_to_idx: label_map, + num_classes: class_num, + labels: unique_lables, + } } } - #[cfg(test)] mod tests { use super::*; #[test] fn from_labels() { - let fake_labels: Vec = vec![1,2,3,4,5,3,5,3,1,2,4]; + let fake_labels: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let enc = OneHotEncoder::::fit(&fake_labels[0..]); let oh_vec = match enc.transform_one(&1) { None => panic!("Wrong labels"), - Some(v) => v + Some(v) => v, }; - let res: Vec = vec![1f64,0f64,0f64,0f64,0f64]; + let res: Vec = vec![1f64, 0f64, 0f64, 0f64, 0f64]; assert_eq!(oh_vec, res); } - - fn build_fake_str_enc<'a>() -> OneHotEncoder<&'a str>{ - let fake_label_pos = vec!["background","dog", "cat"]; + fn build_fake_str_enc<'a>() -> OneHotEncoder<&'a str> { + let fake_label_pos = vec!["background", "dog", "cat"]; let enc = OneHotEncoder::<&str>::from_positional_label_vec(fake_label_pos); enc } #[test] fn label_map_and_vec() { - let fake_label_map: HashMap<&str, usize> = vec![("background",0), ("dog", 1), ("cat", 2)].into_iter().collect(); - let enc = OneHotEncoder::<&str>::from_label_map(fake_label_map); - let oh_vec = match enc.transform_one(&"dog") { - None => panic!("Wrong labels"), - Some(v) => v - }; - let res: Vec = vec![0f64, 1f64,0f64]; - assert_eq!(oh_vec, res); - } - + let fake_label_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] + .into_iter() + .collect(); + let enc = OneHotEncoder::<&str>::from_label_map(fake_label_map); + let oh_vec = match enc.transform_one(&"dog") { + None => panic!("Wrong labels"), + Some(v) => v, + }; + let res: Vec = vec![0f64, 1f64, 0f64]; + assert_eq!(oh_vec, res); + } + #[test] fn positional_labels_vec() { - let enc = build_fake_str_enc(); - let oh_vec = match enc.transform_one(&"dog") { - None => panic!("Wrong labels"), - Some(v) => v - }; - let res: Vec = vec![0f64, 1f64,0f64]; - assert_eq!(oh_vec, res); + let enc = build_fake_str_enc(); + let oh_vec = match enc.transform_one(&"dog") { + None => panic!("Wrong labels"), + Some(v) => v, + }; + let res: Vec = vec![0.0, 1.0, 0.0]; + assert_eq!(oh_vec, res); } #[test] fn invert_label_test() { let enc = build_fake_str_enc(); - let res: Vec = vec![0f64, 1f64,0f64]; + let res: Vec = vec![0.0, 1.0, 0.0]; let lab = enc.invert_one(res).unwrap(); assert_eq!(lab, "dog"); - - if let Err(e) = enc.invert_one(vec![0.0, 0.0,0.0]) { + if let Err(e) = enc.invert_one(vec![0.0, 0.0, 0.0]) { let pos_entries = format!("Expected a single positive entry, 0 entires found"); assert_eq!(e, Failed::transform(&pos_entries[..])); }; } - - - -} \ No newline at end of file +}