Implement the feature importance for Decision Tree Classifier (#275)
* store impurity in the node
* add number of features
* add a TODO
* draft feature importance
* feat
* n_samples of node
* compute_feature_importances
* unit tests
* always calculate impurity
* fix bug
* fix linter
This commit is contained in:
@@ -116,6 +116,7 @@ pub struct DecisionTreeClassifier<
|
|||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
classes: Vec<TY>,
|
classes: Vec<TY>,
|
||||||
depth: u16,
|
depth: u16,
|
||||||
|
num_features: usize,
|
||||||
_phantom_tx: PhantomData<TX>,
|
_phantom_tx: PhantomData<TX>,
|
||||||
_phantom_x: PhantomData<X>,
|
_phantom_x: PhantomData<X>,
|
||||||
_phantom_y: PhantomData<Y>,
|
_phantom_y: PhantomData<Y>,
|
||||||
@@ -159,11 +160,13 @@ pub enum SplitCriterion {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Node {
|
struct Node {
|
||||||
output: usize,
|
output: usize,
|
||||||
|
n_node_samples: usize,
|
||||||
split_feature: usize,
|
split_feature: usize,
|
||||||
split_value: Option<f64>,
|
split_value: Option<f64>,
|
||||||
split_score: Option<f64>,
|
split_score: Option<f64>,
|
||||||
true_child: Option<usize>,
|
true_child: Option<usize>,
|
||||||
false_child: Option<usize>,
|
false_child: Option<usize>,
|
||||||
|
impurity: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||||
@@ -400,14 +403,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Node {
|
impl Node {
|
||||||
fn new(output: usize) -> Self {
|
fn new(output: usize, n_node_samples: usize) -> Self {
|
||||||
Node {
|
Node {
|
||||||
output,
|
output,
|
||||||
|
n_node_samples,
|
||||||
split_feature: 0,
|
split_feature: 0,
|
||||||
split_value: Option::None,
|
split_value: Option::None,
|
||||||
split_score: Option::None,
|
split_score: Option::None,
|
||||||
true_child: Option::None,
|
true_child: Option::None,
|
||||||
false_child: Option::None,
|
false_child: Option::None,
|
||||||
|
impurity: Option::None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -507,6 +512,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
num_classes: 0usize,
|
num_classes: 0usize,
|
||||||
classes: vec![],
|
classes: vec![],
|
||||||
depth: 0u16,
|
depth: 0u16,
|
||||||
|
num_features: 0usize,
|
||||||
_phantom_tx: PhantomData,
|
_phantom_tx: PhantomData,
|
||||||
_phantom_x: PhantomData,
|
_phantom_x: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
@@ -578,7 +584,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
count[yi[i]] += samples[i];
|
count[yi[i]] += samples[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
let root = Node::new(which_max(&count));
|
let root = Node::new(which_max(&count), y_ncols);
|
||||||
change_nodes.push(root);
|
change_nodes.push(root);
|
||||||
let mut order: Vec<Vec<usize>> = Vec::new();
|
let mut order: Vec<Vec<usize>> = Vec::new();
|
||||||
|
|
||||||
@@ -593,6 +599,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
num_classes: k,
|
num_classes: k,
|
||||||
classes,
|
classes,
|
||||||
depth: 0u16,
|
depth: 0u16,
|
||||||
|
num_features: num_attributes,
|
||||||
_phantom_tx: PhantomData,
|
_phantom_tx: PhantomData,
|
||||||
_phantom_x: PhantomData,
|
_phantom_x: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
@@ -678,16 +685,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_pure {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let n = visitor.samples.iter().sum();
|
let n = visitor.samples.iter().sum();
|
||||||
|
|
||||||
if n <= self.parameters().min_samples_split {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut count = vec![0; self.num_classes];
|
let mut count = vec![0; self.num_classes];
|
||||||
let mut false_count = vec![0; self.num_classes];
|
let mut false_count = vec![0; self.num_classes];
|
||||||
for i in 0..n_rows {
|
for i in 0..n_rows {
|
||||||
@@ -696,7 +694,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let parent_impurity = impurity(&self.parameters().criterion, &count, n);
|
self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
|
||||||
|
|
||||||
|
if is_pure {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if n <= self.parameters().min_samples_split {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||||
|
|
||||||
@@ -705,14 +711,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
}
|
}
|
||||||
|
|
||||||
for variable in variables.iter().take(mtry) {
|
for variable in variables.iter().take(mtry) {
|
||||||
self.find_best_split(
|
self.find_best_split(visitor, n, &count, &mut false_count, *variable);
|
||||||
visitor,
|
|
||||||
n,
|
|
||||||
&count,
|
|
||||||
&mut false_count,
|
|
||||||
parent_impurity,
|
|
||||||
*variable,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.nodes()[visitor.node].split_score.is_some()
|
self.nodes()[visitor.node].split_score.is_some()
|
||||||
@@ -724,7 +723,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
n: usize,
|
n: usize,
|
||||||
count: &[usize],
|
count: &[usize],
|
||||||
false_count: &mut [usize],
|
false_count: &mut [usize],
|
||||||
parent_impurity: f64,
|
|
||||||
j: usize,
|
j: usize,
|
||||||
) {
|
) {
|
||||||
let mut true_count = vec![0; self.num_classes];
|
let mut true_count = vec![0; self.num_classes];
|
||||||
@@ -760,6 +758,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
|
|
||||||
let true_label = which_max(&true_count);
|
let true_label = which_max(&true_count);
|
||||||
let false_label = which_max(false_count);
|
let false_label = which_max(false_count);
|
||||||
|
let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
|
||||||
let gain = parent_impurity
|
let gain = parent_impurity
|
||||||
- tc as f64 / n as f64
|
- tc as f64 / n as f64
|
||||||
* impurity(&self.parameters().criterion, &true_count, tc)
|
* impurity(&self.parameters().criterion, &true_count, tc)
|
||||||
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
|
|
||||||
let true_child_idx = self.nodes().len();
|
let true_child_idx = self.nodes().len();
|
||||||
|
|
||||||
self.nodes.push(Node::new(visitor.true_child_output));
|
self.nodes.push(Node::new(visitor.true_child_output, tc));
|
||||||
let false_child_idx = self.nodes().len();
|
let false_child_idx = self.nodes().len();
|
||||||
self.nodes.push(Node::new(visitor.false_child_output));
|
self.nodes.push(Node::new(visitor.false_child_output, fc));
|
||||||
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
||||||
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
||||||
|
|
||||||
@@ -863,6 +862,33 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute feature importances for the fitted tree.
|
||||||
|
pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
|
||||||
|
let mut importances = vec![0f64; self.num_features];
|
||||||
|
|
||||||
|
for node in self.nodes().iter() {
|
||||||
|
if node.true_child.is_none() && node.false_child.is_none() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let left = &self.nodes()[node.true_child.unwrap()];
|
||||||
|
let right = &self.nodes()[node.false_child.unwrap()];
|
||||||
|
|
||||||
|
importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
|
||||||
|
- left.n_node_samples as f64 * left.impurity.unwrap()
|
||||||
|
- right.n_node_samples as f64 * right.impurity.unwrap();
|
||||||
|
}
|
||||||
|
for item in importances.iter_mut() {
|
||||||
|
*item /= self.nodes()[0].n_node_samples as f64;
|
||||||
|
}
|
||||||
|
if normalize {
|
||||||
|
let sum = importances.iter().sum::<f64>();
|
||||||
|
for importance in importances.iter_mut() {
|
||||||
|
*importance /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
importances
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -1016,6 +1042,42 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_compute_feature_importances() {
    // Twenty rows over four binary features, with the labels given in `y` below.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1., 1., 1., 0.],
        &[1., 1., 1., 0.],
        &[1., 1., 1., 1.],
        &[1., 1., 0., 0.],
        &[1., 1., 0., 1.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 1.],
        &[1., 0., 0., 0.],
        &[1., 0., 0., 1.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 1.],
        &[0., 1., 0., 0.],
        &[0., 1., 0., 1.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 1.],
        &[0., 0., 0., 0.],
        &[0., 0., 0., 1.],
    ]);
    let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();

    // Compare with a small tolerance: the importances come out of chained
    // floating-point operations, so exact equality would be brittle.
    let assert_close = |actual: Vec<f64>, expected: [f64; 4]| {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-12, "expected {}, got {}", e, a);
        }
    };

    assert_close(
        tree.compute_feature_importances(false),
        [0., 0., 0.21333333333333332, 0.26666666666666666],
    );
    assert_close(
        tree.compute_feature_importances(true),
        [0., 0., 0.4444444444444444, 0.5555555555555556],
    );
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
Reference in New Issue
Block a user