feat: documents tree
This commit is contained in:
@@ -1,3 +1,67 @@
|
||||
//! # Decision Tree Classifier
|
||||
//!
|
||||
//! The process of building a classification tree is similar to the task of building a [regression tree](../decision_tree_regressor/index.html).
|
||||
//! However, in the classification setting one of these criteria is used for making the binary splits:
|
||||
//!
|
||||
//! * Classification error rate, \\(E = 1 - \max_k(p_{mk})\\)
|
||||
//!
|
||||
//! * Gini index, \\(G = \sum_{k=1}^K p_{mk}(1 - p_{mk})\\)
|
||||
//!
|
||||
//! * Entropy, \\(D = -\sum_{k=1}^K p_{mk}\log p_{mk}\\)
|
||||
//!
|
||||
//! where \\(p_{mk}\\) represents the proportion of training observations in the *m*th region that are from the *k*th class.
|
||||
//!
|
||||
//! The classification error rate is simply the fraction of the training observations in that region that do not belong to the most common class.
|
||||
//! Classification error is not sufficiently sensitive for tree-growing, and in practice Gini index or Entropy are preferable.
|
||||
//!
|
||||
//! The Gini index is referred to as a measure of node purity. A small value indicates that a node contains predominantly observations from a single class.
|
||||
//!
|
||||
//! The Entropy, like the Gini index, will take on a small value if the *m*th node is pure.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::tree::decision_tree_classifier::*;
|
||||
//!
|
||||
//! // Iris dataset
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
//!
|
||||
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||
//!
|
||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//!
|
||||
//! <script src="https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::LinkedList;
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
@@ -10,13 +74,19 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// Parameters of Decision Tree
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
/// Split criterion to use when building a tree.
|
||||
pub criterion: SplitCriterion,
|
||||
/// The maximum depth of the tree.
|
||||
pub max_depth: Option<u16>,
|
||||
/// The minimum number of samples required to be at a leaf node.
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub min_samples_split: usize,
|
||||
}
|
||||
|
||||
/// Decision Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
@@ -26,15 +96,19 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
depth: u16,
|
||||
}
|
||||
|
||||
/// The function to measure the quality of a split.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
/// [Gini index](../decision_tree_classifier/index.html)
|
||||
Gini,
|
||||
/// [Entropy](../decision_tree_classifier/index.html)
|
||||
Entropy,
|
||||
/// [Classification error](../decision_tree_classifier/index.html)
|
||||
ClassificationError,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: RealNumber> {
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
@@ -194,6 +268,9 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
/// Build a decision tree classifier from the training data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
@@ -204,7 +281,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub fn fit_weak_learner<M: Matrix<T>>(
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
samples: Vec<usize>,
|
||||
@@ -268,6 +345,8 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
tree
|
||||
}
|
||||
|
||||
/// Predict class value for `x`.
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
|
||||
@@ -1,3 +1,62 @@
|
||||
//! # Decision Tree Regressor
|
||||
//!
|
||||
//! The process of building a decision tree can be simplified to these two steps:
|
||||
//!
|
||||
//! 1. Divide the predictor space \\(X\\) into K distinct and non-overlapping regions, \\(R_1, R_2, ..., R_K\\).
|
||||
//! 1. For every observation that falls into the region \\(R_k\\), we make the same prediction, which is simply the mean of the response values for the training observations in \\(R_k\\).
|
||||
//!
|
||||
//! Regions \\(R_1, R_2, ..., R_K\\) are built in such a way that minimizes the residual sum of squares (RSS) given by
|
||||
//!
|
||||
//! \\[RSS = \sum_{k=1}^K\sum_{i \in R_k} (y_i - \hat{y}_{Rk})^2\\]
|
||||
//!
|
||||
//! where \\(\hat{y}_{Rk}\\) is the mean response for the training observations within region _k_.
|
||||
//!
|
||||
//! SmartCore uses a recursive binary splitting approach to build \\(R_1, R_2, ..., R_K\\) regions. The approach begins at the top of the tree and then successively splits the predictor space
|
||||
//! one predictor at a time. At each step of the tree-building process, the best split is made at that particular step, rather than looking ahead and picking a split that will lead to a better
|
||||
//! tree in some future step.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::tree::decision_tree_regressor::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0,
|
||||
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
|
||||
//! ];
|
||||
//!
|
||||
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||
//!
|
||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
|
||||
use std::collections::LinkedList;
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
@@ -9,12 +68,17 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
/// Parameters of Regression Tree
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
/// The maximum depth of the tree.
|
||||
pub max_depth: Option<u16>,
|
||||
/// The minimum number of samples required to be at a leaf node.
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub min_samples_split: usize,
|
||||
}
|
||||
|
||||
/// Regression Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
@@ -23,7 +87,7 @@ pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: RealNumber> {
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
output: T,
|
||||
split_feature: usize,
|
||||
@@ -123,6 +187,9 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
/// Build a regression tree from the training data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
@@ -133,7 +200,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub fn fit_weak_learner<M: Matrix<T>>(
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
samples: Vec<usize>,
|
||||
@@ -191,6 +258,8 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
tree
|
||||
}
|
||||
|
||||
/// Predict regression value for `x`.
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
|
||||
@@ -1,2 +1,24 @@
|
||||
//! # Classification and regression trees
|
||||
//!
|
||||
//! Tree-based methods are simple, nonparametric and useful algorithms in machine learning that are easy to understand and interpret.
|
||||
//!
|
||||
//! Decision trees recursively partition the predictor space \\(X\\) into k distinct and non-overlapping rectangular regions \\(R_1, R_2,..., R_k\\)
|
||||
//! and fit a simple prediction model within each region. In order to make a prediction \\(\hat{y}\\) for a given observation,
|
||||
//! a decision tree typically uses the mean or the mode of the training observations in the region \\(R_j\\) to which the observation belongs.
|
||||
//!
|
||||
//! Decision trees often do not deliver the best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
|
||||
//! Hence some techniques such as [Random Forests](../ensemble/index.html) use more than one decision tree to improve performance of the algorithm.
|
||||
//!
|
||||
//! SmartCore uses the [CART](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees_.28CART.29) learning technique to build both classification and regression trees.
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
|
||||
/// Classification tree for dependent variables that take a finite number of unordered values.
|
||||
pub mod decision_tree_classifier;
|
||||
/// Regression tree for dependent variables that take continuous or ordered discrete values.
|
||||
pub mod decision_tree_regressor;
|
||||
|
||||
Reference in New Issue
Block a user