No more SeriesEncoders.

This commit is contained in:
gaxler
2021-02-03 13:41:25 -08:00
parent 3cc20fd400
commit 374dfeceb9
+63 -41
View File
@@ -65,7 +65,7 @@ where
pub fn num_categories(&self) -> usize { pub fn num_categories(&self) -> usize {
self.num_categories self.num_categories
} }
/// Fit an encoder to a label iterator /// Fit an encoder to a label iterator
pub fn fit_to_iter(categories: impl Iterator<Item = C>) -> Self { pub fn fit_to_iter(categories: impl Iterator<Item = C>) -> Self {
let mut category_map: HashMap<C, usize> = HashMap::new(); let mut category_map: HashMap<C, usize> = HashMap::new();
@@ -85,7 +85,7 @@ where
categories: unique_lables, categories: unique_lables,
} }
} }
/// Build an encoder from a predefined (category -> class number) map /// Build an encoder from a predefined (category -> class number) map
pub fn from_category_map(category_map: HashMap<C, usize>) -> Self { pub fn from_category_map(category_map: HashMap<C, usize>) -> Self {
let mut _unique_cat: Vec<(C, usize)> = let mut _unique_cat: Vec<(C, usize)> =
@@ -98,7 +98,7 @@ where
category_map, category_map,
} }
} }
/// Build an encoder from a predefined positional category-class num vector /// Build an encoder from a predefined positional category-class num vector
pub fn from_positional_category_vec(categories: Vec<C>) -> Self { pub fn from_positional_category_vec(categories: Vec<C>) -> Self {
let category_map: HashMap<C, usize> = categories let category_map: HashMap<C, usize> = categories
@@ -130,54 +130,71 @@ where
/// Get one-hot encoding of the category /// Get one-hot encoding of the category
pub fn get_one_hot<U, V>(&self, category: &C) -> Option<V> pub fn get_one_hot<U, V>(&self, category: &C) -> Option<V>
where where
U: RealNumber, U: RealNumber,
V: BaseVector<U>, V: BaseVector<U>,
{ {
match self.get_num(category) { match self.get_num(category) {
None => None, None => None,
Some(&idx) => Some(make_one_hot::<U, V>(idx, self.num_categories)), Some(&idx) => Some(make_one_hot::<U, V>(idx, self.num_categories)),
}
} }
}
/// Invert one-hot vector, back to the category /// Invert one-hot vector, back to the category
pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed> pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed>
where where
U: RealNumber, U: RealNumber,
V: BaseVector<U> V: BaseVector<U>,
{
let pos = U::one();
{ let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx));
let pos = U::from_f64(1f64).unwrap();
let s: Vec<usize> = oh_it
let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); .enumerate()
.filter_map(|(idx, v)| if v == pos { Some(idx) } else { None })
let s: Vec<usize> = oh_it .collect();
.enumerate()
.filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) if s.len() == 1 {
.collect(); let idx = s[0];
return Ok(self.get_cat(idx).clone());
if s.len() == 1 {
let idx = s[0];
return Ok(self.mapper.get_cat(idx).clone());
}
let pos_entries = format!(
"Expected a single positive entry, {} entires found",
s.len()
);
Err(Failed::transform(&pos_entries[..]))
} }
let pos_entries = format!(
"Expected a single positive entry, {} entires found",
s.len()
);
Err(Failed::transform(&pos_entries[..]))
}
fn transform_one<U, V>(&self, category: &C) -> Option<V> /// Get ordinal encoding of the category
pub fn get_ordinal<U>(&self, category: &C) -> Option<U>
where where
U: RealNumber, U: RealNumber,
V: BaseVector<U>
{ {
match self.mapper.get_num(category) { match self.get_num(category) {
None => None, None => None,
Some(&idx) => Some(make_one_hot(idx, self.num_categories())), Some(&idx) => U::from_usize(idx),
} }
} }
}
/// Make a one-hot encoded vector from a categorical variable
///
/// Builds a vector with `V::zeros(num_categories)` and then sets the entry
/// at position `category_idx` to `T::one()`.
///
/// * `category_idx` - position of the entry that is set to one
/// * `num_categories` - length requested from `V::zeros`
///
/// NOTE(review): what happens when `category_idx >= num_categories` depends
/// on `BaseVector::set`, which is not visible here - presumably it panics;
/// confirm against the trait implementation.
///
/// Example:
/// ```
/// use smartcore::preprocessing::series_encoder::make_one_hot;
/// let one_hot: Vec<f64> = make_one_hot(2, 3);
/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]);
/// ```
pub fn make_one_hot<T, V>(category_idx: usize, num_categories: usize) -> V
where
T: RealNumber,
V: BaseVector<T>,
{
let pos = T::one();
let mut z = V::zeros(num_categories);
z.set(category_idx, pos);
z
} }
#[cfg(test)] #[cfg(test)]
@@ -188,8 +205,8 @@ mod tests {
fn from_categories() { fn from_categories() {
let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
let it = fake_categories.iter().map(|&a| a); let it = fake_categories.iter().map(|&a| a);
let enc = SeriesOneHotEncoder::<usize>::fit_to_iter(it); let enc = CategoryMapper::<usize>::fit_to_iter(it);
let oh_vec: Vec<f64> = match enc.transform_one(&1) { let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -197,19 +214,24 @@ mod tests {
assert_eq!(oh_vec, res); assert_eq!(oh_vec, res);
} }
fn build_fake_str_enc<'a>() -> SeriesOneHotEncoder<&'a str> { fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
let fake_category_pos = vec!["background", "dog", "cat"]; let fake_category_pos = vec!["background", "dog", "cat"];
let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_positional_category_vec(fake_category_pos)); let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
enc enc
} }
#[test]
fn ordinal_encoding() {
    // The fixture is built from the positional vector
    // ["background", "dog", "cat"]; the test expects ordinal 1 for "dog".
    let mapper = build_fake_str_enc();
    let dog_ordinal: Option<f64> = mapper.get_ordinal(&"dog");
    assert_eq!(dog_ordinal, Some(1f64));
}
#[test] #[test]
fn category_map_and_vec() { fn category_map_and_vec() {
let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)]
.into_iter() .into_iter()
.collect(); .collect();
let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_category_map(category_map)); let enc = CategoryMapper::<&str>::from_category_map(category_map);
let oh_vec: Vec<f64> = match enc.transform_one(&"dog") { let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -220,7 +242,7 @@ mod tests {
#[test] #[test]
fn positional_categories_vec() { fn positional_categories_vec() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let oh_vec: Vec<f64> = match enc.transform_one(&"dog") { let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -232,9 +254,9 @@ mod tests {
fn invert_label_test() { fn invert_label_test() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let res: Vec<f64> = vec![0.0, 1.0, 0.0]; let res: Vec<f64> = vec![0.0, 1.0, 0.0];
let lab = enc.invert_one(res).unwrap(); let lab = enc.invert_one_hot(res).unwrap();
assert_eq!(lab, "dog"); assert_eq!(lab, "dog");
if let Err(e) = enc.invert_one(vec![0.0, 0.0, 0.0]) { if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
let pos_entries = format!("Expected a single positive entry, 0 entires found"); let pos_entries = format!("Expected a single positive entry, 0 entires found");
assert_eq!(e, Failed::transform(&pos_entries[..])); assert_eq!(e, Failed::transform(&pos_entries[..]));
}; };
@@ -244,7 +266,7 @@ mod tests {
fn test_many_categorys() { fn test_many_categorys() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let cat_it = ["dog", "cat", "fish", "background"].iter().cloned(); let cat_it = ["dog", "cat", "fish", "background"].iter().cloned();
let res: Vec<Option<Vec<f64>>> = enc.transform_iter(cat_it); let res: Vec<Option<Vec<f64>>> = cat_it.map(|v| enc.get_one_hot(&v)).collect();
let v = vec![ let v = vec![
Some(vec![0.0, 1.0, 0.0]), Some(vec![0.0, 1.0, 0.0]),
Some(vec![0.0, 0.0, 1.0]), Some(vec![0.0, 0.0, 1.0]),