No more SeriesEncoders.

This commit is contained in:
gaxler
2021-02-03 13:41:25 -08:00
parent 3cc20fd400
commit 374dfeceb9
+40 -18
View File
@@ -144,10 +144,9 @@ where
pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed> pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed>
where where
U: RealNumber, U: RealNumber,
V: BaseVector<U> V: BaseVector<U>,
{ {
let pos = U::from_f64(1f64).unwrap(); let pos = U::one();
let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx));
@@ -158,7 +157,7 @@ where
if s.len() == 1 { if s.len() == 1 {
let idx = s[0]; let idx = s[0];
return Ok(self.mapper.get_cat(idx).clone()); return Ok(self.get_cat(idx).clone());
} }
let pos_entries = format!( let pos_entries = format!(
"Expected a single positive entry, {} entires found", "Expected a single positive entry, {} entires found",
@@ -167,17 +166,35 @@ where
Err(Failed::transform(&pos_entries[..])) Err(Failed::transform(&pos_entries[..]))
} }
fn transform_one<U, V>(&self, category: &C) -> Option<V> /// Get the ordinal encoding of the category.
///
/// Looks `category` up with `get_num` and converts the resulting index to `U`
/// with `U::from_usize`. Returns `None` when `get_num` yields `None`, or when
/// the conversion itself returns `None`.
pub fn get_ordinal<U>(&self, category: &C) -> Option<U>
where where
U: RealNumber, U: RealNumber,
V: BaseVector<U>
{ {
match self.mapper.get_num(category) { match self.get_num(category) {
None => None, None => None,
Some(&idx) => Some(make_one_hot(idx, self.num_categories())), Some(&idx) => U::from_usize(idx),
}
} }
} }
/// Make a one-hot encoded vector from a categorical variable.
///
/// Builds a vector of `num_categories` zeros with `V::zeros` and sets the
/// entry at position `category_idx` to `T::one()`.
///
/// NOTE(review): `category_idx` is presumably expected to be smaller than
/// `num_categories`; what `set` does for an out-of-range index is not visible
/// here — confirm against `BaseVector::set`.
///
/// Example:
/// ```
/// use smartcore::preprocessing::series_encoder::make_one_hot;
/// let one_hot: Vec<f64> = make_one_hot(2, 3);
/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]);
/// ```
pub fn make_one_hot<T, V>(category_idx: usize, num_categories: usize) -> V
where
T: RealNumber,
V: BaseVector<T>,
{
let pos = T::one();
let mut z = V::zeros(num_categories);
z.set(category_idx, pos);
z
} }
#[cfg(test)] #[cfg(test)]
@@ -188,8 +205,8 @@ mod tests {
fn from_categories() { fn from_categories() {
let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
let it = fake_categories.iter().map(|&a| a); let it = fake_categories.iter().map(|&a| a);
let enc = SeriesOneHotEncoder::<usize>::fit_to_iter(it); let enc = CategoryMapper::<usize>::fit_to_iter(it);
let oh_vec: Vec<f64> = match enc.transform_one(&1) { let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -197,19 +214,24 @@ mod tests {
assert_eq!(oh_vec, res); assert_eq!(oh_vec, res);
} }
// Test fixture: builds the encoder from the positional category list
// ["background", "dog", "cat"] shared by the string-category tests.
fn build_fake_str_enc<'a>() -> SeriesOneHotEncoder<&'a str> { fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
let fake_category_pos = vec!["background", "dog", "cat"]; let fake_category_pos = vec!["background", "dog", "cat"];
let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_positional_category_vec(fake_category_pos)); let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
enc enc
} }
#[test]
fn ordinal_encoding() {
    // "dog" is the second entry of the fixture's category list, so its
    // ordinal encoding is expected to be 1.
    let encoder = build_fake_str_enc();
    let dog_ordinal: f64 = encoder.get_ordinal(&"dog").unwrap();
    assert_eq!(1f64, dog_ordinal);
}
#[test] #[test]
fn category_map_and_vec() { fn category_map_and_vec() {
let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)]
.into_iter() .into_iter()
.collect(); .collect();
let enc = SeriesOneHotEncoder::<&str>::new( CategoryMapper::from_category_map(category_map)); let enc = CategoryMapper::<&str>::from_category_map(category_map);
let oh_vec: Vec<f64> = match enc.transform_one(&"dog") { let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -220,7 +242,7 @@ mod tests {
#[test] #[test]
fn positional_categories_vec() { fn positional_categories_vec() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let oh_vec: Vec<f64> = match enc.transform_one(&"dog") { let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
Some(v) => v, Some(v) => v,
}; };
@@ -232,9 +254,9 @@ mod tests {
fn invert_label_test() { fn invert_label_test() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let res: Vec<f64> = vec![0.0, 1.0, 0.0]; let res: Vec<f64> = vec![0.0, 1.0, 0.0];
let lab = enc.invert_one(res).unwrap(); let lab = enc.invert_one_hot(res).unwrap();
assert_eq!(lab, "dog"); assert_eq!(lab, "dog");
if let Err(e) = enc.invert_one(vec![0.0, 0.0, 0.0]) { if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
let pos_entries = format!("Expected a single positive entry, 0 entires found"); let pos_entries = format!("Expected a single positive entry, 0 entires found");
assert_eq!(e, Failed::transform(&pos_entries[..])); assert_eq!(e, Failed::transform(&pos_entries[..]));
}; };
@@ -244,7 +266,7 @@ mod tests {
fn test_many_categorys() { fn test_many_categorys() {
let enc = build_fake_str_enc(); let enc = build_fake_str_enc();
let cat_it = ["dog", "cat", "fish", "background"].iter().cloned(); let cat_it = ["dog", "cat", "fish", "background"].iter().cloned();
let res: Vec<Option<Vec<f64>>> = enc.transform_iter(cat_it); let res: Vec<Option<Vec<f64>>> = cat_it.map(|v| enc.get_one_hot(&v)).collect();
let v = vec![ let v = vec![
Some(vec![0.0, 1.0, 0.0]), Some(vec![0.0, 1.0, 0.0]),
Some(vec![0.0, 0.0, 1.0]), Some(vec![0.0, 0.0, 1.0]),