feat: documents tree

This commit is contained in:
Volodymyr Orlov
2020-09-03 17:49:58 -07:00
parent 32081852ad
commit 2c6a03ddc1
5 changed files with 176 additions and 4 deletions
+81 -2
View File
@@ -1,3 +1,67 @@
//! # Decision Tree Classifier
//!
//! The process of building a classification tree is similar to the task of building a [regression tree](../decision_tree_regressor/index.html).
//! However, in the classification setting one of these criteriums is used for making the binary splits:
//!
//! * Classification error rate, \\(E = 1 - \max_k(p_{mk})\\)
//!
//! * Gini index, \\(G = \sum_{k=1}^K p_{mk}(1 - p_{mk})\\)
//!
//! * Entropy, \\(D = -\sum_{k=1}^K p_{mk}\log p_{mk}\\)
//!
//! where \\(p_{mk}\\) represents the proportion of training observations in the *m*th region that are from the *k*th class.
//!
//! The classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class.
//! Classification error is not sufficiently sensitive for tree-growing, and in practice Gini index or Entropy are preferable.
//!
//! The Gini index is referred to as a measure of node purity. A small value indicates that a node contains predominantly observations from a single class.
//!
//! The Entropy, like Gini index will take on a small value if the *m*th node is pure.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::tree::decision_tree_classifier::*;
//!
//! // Iris dataset
//! let x = DenseMatrix::from_array(&[
//! &[5.1, 3.5, 1.4, 0.2],
//! &[4.9, 3.0, 1.4, 0.2],
//! &[4.7, 3.2, 1.3, 0.2],
//! &[4.6, 3.1, 1.5, 0.2],
//! &[5.0, 3.6, 1.4, 0.2],
//! &[5.4, 3.9, 1.7, 0.4],
//! &[4.6, 3.4, 1.4, 0.3],
//! &[5.0, 3.4, 1.5, 0.2],
//! &[4.4, 2.9, 1.4, 0.2],
//! &[4.9, 3.1, 1.5, 0.1],
//! &[7.0, 3.2, 4.7, 1.4],
//! &[6.4, 3.2, 4.5, 1.5],
//! &[6.9, 3.1, 4.9, 1.5],
//! &[5.5, 2.3, 4.0, 1.3],
//! &[6.5, 2.8, 4.6, 1.5],
//! &[5.7, 2.8, 4.5, 1.3],
//! &[6.3, 3.3, 4.7, 1.6],
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
//!
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
//!
//! let y_hat = tree.predict(&x); // use the same data for prediction
//! ```
//!
//!
//! ## References:
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::LinkedList;
use std::default::Default;
use std::fmt::Debug;
@@ -10,13 +74,19 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)]
/// Parameters of Decision Tree
pub struct DecisionTreeClassifierParameters {
/// Split criteria to use when building a tree.
pub criterion: SplitCriterion,
/// The maximum depth of the tree.
pub max_depth: Option<u16>,
/// The minimum number of samples required to be at a leaf node.
pub min_samples_leaf: usize,
/// The minimum number of samples required to split an internal node.
pub min_samples_split: usize,
}
/// Decision Tree
#[derive(Serialize, Deserialize, Debug)]
pub struct DecisionTreeClassifier<T: RealNumber> {
nodes: Vec<Node<T>>,
@@ -26,15 +96,19 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
depth: u16,
}
/// The function to measure the quality of a split.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum SplitCriterion {
/// [Gini index](../decision_tree_classifier/index.html)
Gini,
/// [Entropy](../decision_tree_classifier/index.html)
Entropy,
/// [Classification error](../decision_tree_classifier/index.html)
ClassificationError,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Node<T: RealNumber> {
struct Node<T: RealNumber> {
index: usize,
output: usize,
split_feature: usize,
@@ -194,6 +268,9 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
}
impl<T: RealNumber> DecisionTreeClassifier<T> {
/// Build a decision tree classifier from the training data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
@@ -204,7 +281,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
pub fn fit_weak_learner<M: Matrix<T>>(
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
samples: Vec<usize>,
@@ -268,6 +345,8 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
tree
}
/// Predict class value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0);