diff --git a/src/preprocessing/target_encoders.rs b/src/preprocessing/target_encoders.rs index 56a97ed..3f2592b 100644 --- a/src/preprocessing/target_encoders.rs +++ b/src/preprocessing/target_encoders.rs @@ -6,7 +6,13 @@ use crate::math::num::RealNumber; use std::collections::HashMap; use std::hash::Hash; -/// Turn a collection of `LabelType`s into a one-hot vectors. +/// Make a one-hot encoded vector from a categorical variable +pub fn make_one_hot>(label_idx: usize, num_labels: usize) -> V { + let pos = T::from_f64(1f64).unwrap(); + let mut z = V::zeros(num_labels); + z.set(label_idx, pos); + z +} /// This struct encodes single class per exmample /// /// You can fit a label enumeration by passing a collection of labels.