//! # Decision Tree Classifier
//!
//! The process of building a classification tree is similar to the task of building a [regression tree](../decision_tree_regressor/index.html).
//! However, in the classification setting one of these criteria is used for making the binary splits:
//!
//! * Classification error rate, \\(E = 1 - \max_k(p_{mk})\\)
//!
//! * Gini index, \\(G = \sum_{k=1}^K p_{mk}(1 - p_{mk})\\)
//!
//! * Entropy, \\(D = -\sum_{k=1}^K p_{mk}\log p_{mk}\\)
//!
//! where \\(p_{mk}\\) represents the proportion of training observations in the *m*th region that are from the *k*th class.
//!
//! The classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class.
//! Classification error is not sufficiently sensitive for tree-growing, and in practice Gini index or Entropy are preferable.
//!
//! The Gini index is referred to as a measure of node purity. A small value indicates that a node contains predominantly observations from a single class.
//!
//! Entropy, like the Gini index, takes on a small value if the *m*th node is pure.
//!
//! Example:
//!
//! ```
//! use rand::Rng;
//!
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::tree::decision_tree_classifier::*;
//!
//! // Iris dataset
//! let x = DenseMatrix::from_2d_array(&[
//! &[5.1, 3.5, 1.4, 0.2],
//! &[4.9, 3.0, 1.4, 0.2],
//! &[4.7, 3.2, 1.3, 0.2],
//! &[4.6, 3.1, 1.5, 0.2],
//! &[5.0, 3.6, 1.4, 0.2],
//! &[5.4, 3.9, 1.7, 0.4],
//! &[4.6, 3.4, 1.4, 0.3],
//! &[5.0, 3.4, 1.5, 0.2],
//! &[4.4, 2.9, 1.4, 0.2],
//! &[4.9, 3.1, 1.5, 0.1],
//! &[7.0, 3.2, 4.7, 1.4],
//! &[6.4, 3.2, 4.5, 1.5],
//! &[6.9, 3.1, 4.9, 1.5],
//! &[5.5, 2.3, 4.0, 1.3],
//! &[6.5, 2.8, 4.6, 1.5],
//! &[5.7, 2.8, 4.5, 1.3],
//! &[6.3, 3.3, 4.7, 1.6],
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! let y = vec![ 0, 0, 0, 0, 0, 0, 0, 0,
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
//!
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//!
//! ## References:
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//!
//!
//!
use std::collections::LinkedList;
use std::default::Default;
use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Decision Tree
///
/// `Default::default()` yields the default [`SplitCriterion`], no `max_depth`,
/// `min_samples_leaf = 1`, `min_samples_split = 2` and no `seed`. Builder
/// methods such as `with_max_depth` override individual fields.
// NOTE(review): with the `serde` feature enabled, `serde(default)` fills a missing
// field from the field type's own `Default` (e.g. `0` for `usize`), not from the
// values listed above — confirm this is intended.
pub struct DecisionTreeClassifierParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree.
pub criterion: SplitCriterion,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum depth of the tree. `None` by default; `with_max_depth` stores a `u16`.
pub max_depth: Option,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. Defaults to `1`.
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. Defaults to `2`.
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the randomness of the estimator. `None` by default.
pub seed: Option,
}
/// Decision Tree
///
/// A classification tree together with the parameters it carries. The module
/// example builds one with `DecisionTreeClassifier::fit` and queries it with
/// `predict`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeClassifier<
TX: Number + PartialOrd,
TY: Number + Ord,
X: Array2,
Y: Array1,
> {
// Nodes of the tree; compared element by element in `PartialEq`.
// NOTE(review): presumably a flat list of `Node` values — confirm.
nodes: Vec,
// Parameters attached to this tree. `parameters()` unwraps this value, so it
// panics when the field is `None`.
parameters: Option,
// NOTE(review): presumably the number of distinct class labels — confirm.
num_classes: usize,
// NOTE(review): presumably the distinct class labels seen during fitting — confirm.
classes: Vec,
// Depth of the tree, as returned by `depth()`.
depth: u16,
// Zero-sized markers; presumably they tie the struct to its type parameters.
_phantom_tx: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
}
// Private read-only accessors for the fields of `DecisionTreeClassifier`.
impl, Y: Array1>
DecisionTreeClassifier
{
/// Get nodes, return a shared reference
fn nodes(&self) -> &Vec {
self.nodes.as_ref()
}
/// Get parameters, return a shared reference
///
/// # Panics
///
/// Panics if the `parameters` field is `None`, because the option is unwrapped.
fn parameters(&self) -> &DecisionTreeClassifierParameters {
self.parameters.as_ref().unwrap()
}
/// Get classes vector, return a shared reference
fn classes(&self) -> &Vec {
self.classes.as_ref()
}
/// Get depth of tree; the `u16` is returned by value
fn depth(&self) -> u16 {
self.depth
}
}
/// The function to measure the quality of a split.
///
/// [`SplitCriterion::Gini`] is the default variant.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum SplitCriterion {
    /// [Gini index](../decision_tree_classifier/index.html)
    // Marked `#[default]` so that `Default` can be derived rather than hand-written.
    #[default]
    Gini,
    /// [Entropy](../decision_tree_classifier/index.html)
    Entropy,
    /// [Classification error](../decision_tree_classifier/index.html)
    ClassificationError,
}
// A single node of the decision tree.
//
// `PartialEq for Node` compares `output`, `split_feature`, `split_value` and
// `split_score`; the two child fields do not take part in that comparison.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
// NOTE(review): presumably the index of the class predicted at this node — confirm.
output: usize,
// NOTE(review): presumably the index of the feature this node splits on — confirm.
split_feature: usize,
// Optional split threshold; compared with an `f64::EPSILON` tolerance in `PartialEq`.
split_value: Option,
// Optional score of the split; compared with an `f64::EPSILON` tolerance in `PartialEq`.
split_score: Option,
// NOTE(review): presumably refers to the child taken when the split test holds — confirm.
true_child: Option,
// NOTE(review): presumably refers to the child taken when the split test fails — confirm.
false_child: Option,
}
// Structural equality for trees.
//
// Two trees are equal when their depth, number of classes and node count match,
// and their classes and nodes are pairwise equal. Parameters are not compared.
impl, Y: Array1> PartialEq
for DecisionTreeClassifier
{
fn eq(&self, other: &Self) -> bool {
// Cheap scalar checks first, so mismatching trees are rejected without walking the vectors.
if self.depth != other.depth
|| self.num_classes != other.num_classes
|| self.nodes().len() != other.nodes().len()
{
false
} else {
// `zip` stops at the shorter iterator. Node lengths were checked above;
// the lengths of the class vectors are not compared directly here.
self.classes()
.iter()
.zip(other.classes().iter())
.all(|(a, b)| a == b)
&& self
.nodes()
.iter()
.zip(other.nodes().iter())
.all(|(a, b)| a == b)
}
}
}
impl PartialEq for Node {
    /// Compares two nodes by `output`, `split_feature`, `split_value` and `split_score`.
    ///
    /// The optional floating-point fields are equal when both are `None`, or when
    /// both are `Some` and differ by less than `f64::EPSILON`. The child fields are
    /// not part of the comparison.
    fn eq(&self, other: &Self) -> bool {
        // Shared rule for the two optional floating-point fields.
        let nearly_equal = |lhs: Option<f64>, rhs: Option<f64>| match (lhs, rhs) {
            (None, None) => true,
            (Some(l), Some(r)) => (l - r).abs() < f64::EPSILON,
            _ => false,
        };

        // The exact integer fields are checked first and reject early on mismatch.
        if self.output != other.output || self.split_feature != other.split_feature {
            return false;
        }

        nearly_equal(self.split_value, other.split_value)
            && nearly_equal(self.split_score, other.split_score)
    }
}
impl DecisionTreeClassifierParameters {
    /// Replaces the split criterion used when building a tree; all other fields are kept.
    pub fn with_criterion(self, criterion: SplitCriterion) -> Self {
        Self { criterion, ..self }
    }
    /// Sets the maximum depth of the tree to `Some(max_depth)`; all other fields are kept.
    pub fn with_max_depth(self, max_depth: u16) -> Self {
        Self {
            max_depth: Some(max_depth),
            ..self
        }
    }
    /// Sets the minimum number of samples required to be at a leaf node.
    pub fn with_min_samples_leaf(self, min_samples_leaf: usize) -> Self {
        Self {
            min_samples_leaf,
            ..self
        }
    }
    /// Sets the minimum number of samples required to split an internal node.
    pub fn with_min_samples_split(self, min_samples_split: usize) -> Self {
        Self {
            min_samples_split,
            ..self
        }
    }
}
impl Default for DecisionTreeClassifierParameters {
fn default() -> Self {
DecisionTreeClassifierParameters {
criterion: SplitCriterion::default(),
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
seed: Option::None,
}
}
}
/// DecisionTreeClassifier grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeClassifierSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: Vec,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec