Separate mapper object
This commit is contained in:
@@ -8,6 +8,73 @@ use crate::math::num::RealNumber;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
/// Two-way mapping between category values and the `usize` numbers assigned to them.
///
/// `category_map` answers "which number does this category have?", while
/// `categories` answers the reverse question by using the number as an index.
#[derive(Debug, Clone)]
pub struct CategoryMapper<CategoryType> {
    /// Category -> assigned number.
    category_map: HashMap<CategoryType, usize>,
    /// Categories stored so that a category's position is its number.
    categories: Vec<CategoryType>,
    /// Number of categories; every constructor sets this to `categories.len()`.
    num_categories: usize,
}

impl<CategoryType: Hash + Eq + Clone> CategoryMapper<CategoryType> {
    /// Build a mapper from an iterator of (possibly repeated) category values.
    ///
    /// Numbers are assigned in order of first appearance, starting at 0.
    /// Repeated values keep the number they received the first time.
    fn fit_to_iter(categories: impl Iterator<Item = CategoryType>) -> Self {
        let mut category_map: HashMap<CategoryType, usize> = HashMap::new();
        let mut unique_labels: Vec<CategoryType> = Vec::new();

        for label in categories {
            // The entry API hashes the key once; only unseen categories are cloned.
            if let std::collections::hash_map::Entry::Vacant(slot) = category_map.entry(label) {
                let category_num = unique_labels.len();
                unique_labels.push(slot.key().clone());
                slot.insert(category_num);
            }
        }

        Self {
            category_map,
            num_categories: unique_labels.len(),
            categories: unique_labels,
        }
    }

    /// Build a mapper from an existing category -> number map.
    ///
    /// The positional list of categories is obtained by sorting the map's
    /// entries by number. The numbers themselves are only used for ordering,
    /// so `get_cat(n)` agrees with `get_num` only when the map's values are
    /// exactly `0..category_map.len()` with no gaps or repeats.
    fn from_category_map(category_map: HashMap<CategoryType, usize>) -> Self {
        // Sort borrowed keys first, then clone each key once in its final position.
        let mut numbered: Vec<(&CategoryType, usize)> =
            category_map.iter().map(|(cat, &num)| (cat, num)).collect();
        numbered.sort_unstable_by_key(|&(_, num)| num);
        let categories: Vec<CategoryType> =
            numbered.into_iter().map(|(cat, _)| cat.clone()).collect();

        Self {
            num_categories: categories.len(),
            categories,
            category_map,
        }
    }

    /// Build a mapper from a vector in which each category's index is its number.
    ///
    /// If a value occurs more than once, the lookup map keeps the index of its
    /// last occurrence, while `categories` and `num_categories` still reflect
    /// the full input vector.
    fn from_positional_category_vec(categories: Vec<CategoryType>) -> Self {
        let category_map: HashMap<CategoryType, usize> = categories
            .iter()
            .cloned()
            .enumerate()
            .map(|(num, cat)| (cat, num))
            .collect();

        Self {
            num_categories: categories.len(),
            category_map,
            categories,
        }
    }

    /// Get label num of a category, or `None` if the category is unknown.
    fn get_num(&self, category: &CategoryType) -> Option<&usize> {
        self.category_map.get(category)
    }

    /// Return category corresponding to label num.
    ///
    /// # Panics
    ///
    /// Panics if `num` is not a valid index into the stored categories.
    fn get_cat(&self, num: usize) -> &CategoryType {
        &self.categories[num]
    }

    /// Return all categories as a slice, positioned by their number.
    fn get_categories(&self) -> &[CategoryType] {
        &self.categories
    }
}
|
||||||
|
|
||||||
/// Make a one-hot encoded vector from a categorical variable
|
/// Make a one-hot encoded vector from a categorical variable
|
||||||
///
|
///
|
||||||
/// Example:
|
/// Example:
|
||||||
|
|||||||
Reference in New Issue
Block a user