Separate mapper object
This commit is contained in:
@@ -8,6 +8,73 @@ use crate::math::num::RealNumber;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoryMapper<CategoryType> {
|
||||
category_map: HashMap<CategoryType, usize>,
|
||||
categories: Vec<CategoryType>,
|
||||
num_categories: usize,
|
||||
}
|
||||
|
||||
impl<'a, CategoryType: 'a + Hash + Eq + Clone> CategoryMapper<CategoryType> {
|
||||
fn fit_to_iter(categories: impl Iterator<Item = CategoryType>) -> Self {
|
||||
let mut category_map: HashMap<CategoryType, usize> = HashMap::new();
|
||||
let mut category_num = 0usize;
|
||||
let mut unique_lables: Vec<CategoryType> = Vec::new();
|
||||
|
||||
for l in categories {
|
||||
if !category_map.contains_key(&l) {
|
||||
category_map.insert(l.clone(), category_num);
|
||||
unique_lables.push(l.clone());
|
||||
category_num += 1;
|
||||
}
|
||||
}
|
||||
Self {
|
||||
category_map,
|
||||
num_categories: category_num,
|
||||
categories: unique_lables,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_category_map(category_map: HashMap<CategoryType, usize>) -> Self {
|
||||
let mut _unique_cat: Vec<(CategoryType, usize)> =
|
||||
category_map.iter().map(|(k, v)| (k.clone(), *v)).collect();
|
||||
_unique_cat.sort_by(|a, b| a.1.cmp(&b.1));
|
||||
let categories: Vec<CategoryType> = _unique_cat.into_iter().map(|a| a.0).collect();
|
||||
Self {
|
||||
num_categories: categories.len(),
|
||||
categories,
|
||||
category_map,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_positional_category_vec(categories: Vec<CategoryType>) -> Self {
|
||||
let category_map: HashMap<CategoryType, usize> = categories
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(v, k)| (k.clone(), v))
|
||||
.collect();
|
||||
Self {
|
||||
num_categories: categories.len(),
|
||||
category_map,
|
||||
categories,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get label num of a category
|
||||
fn get_num(&self, category: &CategoryType) -> Option<&usize> {
|
||||
self.category_map.get(category)
|
||||
}
|
||||
|
||||
/// Return category corresponding to label num
|
||||
fn get_cat(&self, num: usize) -> &CategoryType {
|
||||
&self.categories[num]
|
||||
}
|
||||
|
||||
fn get_categories(&self) -> &[CategoryType] {
|
||||
&self.categories[..]
|
||||
}
|
||||
}
|
||||
|
||||
/// Make a one-hot encoded vector from a categorical variable
|
||||
///
|
||||
/// Example:
|
||||
|
||||
Reference in New Issue
Block a user