feat: extend the Matrix interface to support a broad range of types
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
@@ -12,11 +16,11 @@ pub struct DecisionTreeClassifierParameters {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeClassifier {
|
||||
nodes: Vec<Node>,
|
||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
num_classes: usize,
|
||||
classes: Vec<f64>,
|
||||
classes: Vec<T>,
|
||||
depth: u16
|
||||
}
|
||||
|
||||
@@ -28,12 +32,12 @@ pub enum SplitCriterion {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Node {
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
split_value: f64,
|
||||
split_score: f64,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
@@ -50,21 +54,21 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl Node {
|
||||
impl<T: FloatExt> Node<T> {
|
||||
fn new(index: usize, output: usize) -> Self {
|
||||
Node {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: std::f64::NAN,
|
||||
split_score: std::f64::NAN,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NodeVisitor<'a, M: Matrix> {
|
||||
struct NodeVisitor<'a, T: FloatExt + Debug, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a Vec<usize>,
|
||||
node: usize,
|
||||
@@ -72,19 +76,20 @@ struct NodeVisitor<'a, M: Matrix> {
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
true_child_output: usize,
|
||||
false_child_output: usize,
|
||||
level: u16
|
||||
level: u16,
|
||||
phantom: PhantomData<&'a T>
|
||||
}
|
||||
|
||||
fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
|
||||
let mut impurity = 0.;
|
||||
fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
||||
let mut impurity = T::zero();
|
||||
|
||||
match criterion {
|
||||
SplitCriterion::Gini => {
|
||||
impurity = 1.0;
|
||||
impurity = T::one();
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
let p = count[i] as f64 / n as f64;
|
||||
impurity -= p * p;
|
||||
let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
|
||||
impurity = impurity - p * p;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,25 +97,25 @@ fn impurity(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> f64 {
|
||||
SplitCriterion::Entropy => {
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
let p = count[i] as f64 / n as f64;
|
||||
impurity -= p * p.log2();
|
||||
let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
|
||||
impurity = impurity - p * p.log2();
|
||||
}
|
||||
}
|
||||
}
|
||||
SplitCriterion::ClassificationError => {
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
impurity = impurity.max(count[i] as f64 / n as f64);
|
||||
impurity = impurity.max(T::from(count[i]).unwrap() / T::from(n).unwrap());
|
||||
}
|
||||
}
|
||||
impurity = (1. - impurity).abs();
|
||||
impurity = (T::one() - impurity).abs();
|
||||
}
|
||||
}
|
||||
|
||||
return impurity;
|
||||
}
|
||||
|
||||
impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
||||
impl<'a, T: FloatExt + Debug, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
|
||||
fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
|
||||
NodeVisitor {
|
||||
@@ -121,7 +126,8 @@ impl<'a, M: Matrix> NodeVisitor<'a, M> {
|
||||
order: order,
|
||||
true_child_output: 0,
|
||||
false_child_output: 0,
|
||||
level: level
|
||||
level: level,
|
||||
phantom: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,15 +147,15 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||
return which;
|
||||
}
|
||||
|
||||
impl DecisionTreeClassifier {
|
||||
impl<T: FloatExt + Debug> DecisionTreeClassifier<T> {
|
||||
|
||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
@@ -166,7 +172,7 @@ impl DecisionTreeClassifier {
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
let mut nodes: Vec<Node> = Vec::new();
|
||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||
|
||||
let mut count = vec![0; k];
|
||||
for i in 0..y_ncols {
|
||||
@@ -189,9 +195,9 @@ impl DecisionTreeClassifier {
|
||||
depth: 0
|
||||
};
|
||||
|
||||
let mut visitor = NodeVisitor::<M>::new(0, samples, &order, &x, &yi, 1);
|
||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
visitor_queue.push_back(visitor);
|
||||
@@ -207,7 +213,7 @@ impl DecisionTreeClassifier {
|
||||
tree
|
||||
}
|
||||
|
||||
pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
@@ -219,7 +225,7 @@ impl DecisionTreeClassifier {
|
||||
result.to_row_vector()
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = 0;
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -247,7 +253,7 @@ impl DecisionTreeClassifier {
|
||||
|
||||
}
|
||||
|
||||
fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, mtry: usize) -> bool {
|
||||
fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
|
||||
|
||||
let (n_rows, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -297,10 +303,10 @@ impl DecisionTreeClassifier {
|
||||
|
||||
}
|
||||
|
||||
fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: f64, j: usize){
|
||||
fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: T, j: usize){
|
||||
|
||||
let mut true_count = vec![0; self.num_classes];
|
||||
let mut prevx = std::f64::NAN;
|
||||
let mut prevx = T::nan();
|
||||
let mut prevy = 0;
|
||||
|
||||
for i in visitor.order[j].iter() {
|
||||
@@ -328,11 +334,11 @@ impl DecisionTreeClassifier {
|
||||
|
||||
let true_label = which_max(&true_count);
|
||||
let false_label = which_max(false_count);
|
||||
let gain = parent_impurity - tc as f64 / n as f64 * impurity(&self.parameters.criterion, &true_count, tc) - fc as f64 / n as f64 * impurity(&self.parameters.criterion, &false_count, fc);
|
||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / 2.;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
visitor.true_child_output = true_label;
|
||||
visitor.false_child_output = false_label;
|
||||
@@ -346,7 +352,7 @@ impl DecisionTreeClassifier {
|
||||
|
||||
}
|
||||
|
||||
fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
|
||||
fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
let mut fc = 0;
|
||||
@@ -366,8 +372,8 @@ impl DecisionTreeClassifier {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = std::f64::NAN;
|
||||
self.nodes[visitor.node].split_score = std::f64::NAN;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -381,13 +387,13 @@ impl DecisionTreeClassifier {
|
||||
|
||||
self.depth = u16::max(self.depth, visitor.level + 1);
|
||||
|
||||
let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
@@ -405,9 +411,9 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
assert!((impurity(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
|
||||
assert!((impurity(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
|
||||
assert!((impurity(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
|
||||
assert!((impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
|
||||
assert!((impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
|
||||
assert!((impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user