feat: document ensemble models

This commit is contained in:
Volodymyr Orlov
2020-09-04 10:04:22 -07:00
parent 2c6a03ddc1
commit ecfbaac167
24 changed files with 191 additions and 31 deletions
+2 -1
@@ -26,7 +26,8 @@
//! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15)
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search
+2 -1
@@ -43,7 +43,8 @@
//!
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
+20
@@ -1,2 +1,22 @@
//! # Ensemble Methods
//!
//! Combining the predictions of several base estimators is a general-purpose procedure for reducing the variance of a statistical learning method.
//! When combined with bagging, ensemble models achieve superior performance to individual estimators.
//!
//! The main idea behind bagging (or bootstrap aggregation) is to fit the same base model to a large number of random subsets of the original training
//! set and then aggregate their individual predictions to form a final prediction. In a classification setting, the overall prediction is the majority
//! class among the individual predictions.
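The majority-vote aggregation described above can be sketched in plain Rust; `majority_class` is a hypothetical helper for illustration, not part of this crate:

```rust
use std::collections::HashMap;

// Majority vote over per-estimator class predictions for one observation.
// `votes` holds the class predicted by each base estimator.
fn majority_class(votes: &[usize]) -> usize {
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    // Pick the class with the highest vote count.
    counts
        .into_iter()
        .max_by_key(|&(_, count)| count)
        .map(|(class, _)| class)
        .unwrap()
}

fn main() {
    // Three of five estimators vote for class 1, so class 1 wins.
    assert_eq!(majority_class(&[1, 0, 1, 1, 0]), 1);
}
```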
//!
//! In SmartCore you will find an implementation of Random Forest, a popular averaging algorithm based on randomized [decision trees](../tree/index.html).
//! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
//! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
//! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
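The per-split predictor sampling can be sketched as follows; `sample_predictors` and the inlined pseudo-random generator are illustrative stand-ins, not this crate's internals:

```rust
// Choose m split candidates at random from the p available predictors,
// without replacement, via a partial Fisher-Yates shuffle.
// A tiny linear congruential generator stands in for the `rand` crate
// so the sketch stays dependency-free.
fn sample_predictors(p: usize, m: usize, seed: &mut u64) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..p).collect();
    for i in 0..m.min(p) {
        *seed = seed
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Swap a random element from the unshuffled tail into position i.
        let j = i + (*seed >> 33) as usize % (p - i);
        indices.swap(i, j);
    }
    indices.truncate(m);
    indices
}

fn main() {
    let mut seed = 42u64;
    let candidates = sample_predictors(10, 3, &mut seed);
    // Exactly m distinct predictor indices, all within 0..p.
    assert_eq!(candidates.len(), 3);
    assert!(candidates.iter().all(|&i| i < 10));
    let mut sorted = candidates.clone();
    sorted.sort();
    sorted.dedup();
    assert_eq!(sorted.len(), 3);
}
```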
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
/// Random forest classifier
pub mod random_forest_classifier;
/// Random forest regressor
pub mod random_forest_regressor;
+65 -4
@@ -1,3 +1,50 @@
//! # Random Forest Classifier
//! A random forest is an ensemble estimator that fits multiple [decision trees](../../tree/index.html) to random subsets of the dataset and averages predictions
//! to improve the predictive accuracy and control over-fitting. See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the performance of the algorithm, at the cost of increased training time.
//! The number of randomly sampled predictors _m_ is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
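A standalone sketch of that default, mirroring the `floor(sqrt(p))` rule applied in `fit` when no value is given (the helper name is hypothetical):

```rust
// Default number of split candidates when not set explicitly:
// floor(sqrt(p)), where p is the number of predictors.
fn default_mtry(num_attributes: usize) -> usize {
    (num_attributes as f64).sqrt().floor() as usize
}

fn main() {
    assert_eq!(default_mtry(4), 2); // e.g. iris: 4 features -> 2 candidates
    assert_eq!(default_mtry(10), 3);
    assert_eq!(default_mtry(16), 4);
}
```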
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::ensemble::random_forest_classifier::*;
//!
//! // Iris dataset
//! let x = DenseMatrix::from_array(&[
//! &[5.1, 3.5, 1.4, 0.2],
//! &[4.9, 3.0, 1.4, 0.2],
//! &[4.7, 3.2, 1.3, 0.2],
//! &[4.6, 3.1, 1.5, 0.2],
//! &[5.0, 3.6, 1.4, 0.2],
//! &[5.4, 3.9, 1.7, 0.4],
//! &[4.6, 3.4, 1.4, 0.3],
//! &[5.0, 3.4, 1.5, 0.2],
//! &[4.4, 2.9, 1.4, 0.2],
//! &[4.9, 3.1, 1.5, 0.1],
//! &[7.0, 3.2, 4.7, 1.4],
//! &[6.4, 3.2, 4.5, 1.5],
//! &[6.9, 3.1, 4.9, 1.5],
//! &[5.5, 2.3, 4.0, 1.3],
//! &[6.5, 2.8, 4.6, 1.5],
//! &[5.7, 2.8, 4.5, 1.3],
//! &[6.3, 3.3, 4.7, 1.6],
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! let y = vec![
//! 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default());
//! let y_hat = classifier.predict(&x); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
extern crate rand;
use std::default::Default;
@@ -12,16 +59,25 @@ use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};
/// Parameters of the Random Forest algorithm.
/// Some parameters here are passed directly to the base estimator.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct RandomForestClassifierParameters {
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: SplitCriterion,
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Option<u16>,
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: usize,
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: usize,
/// The number of trees in the forest.
pub n_trees: u16,
pub mtry: Option<usize>,
/// The number of predictors to sample at random as split candidates.
pub m: Option<usize>,
}
/// Random Forest Classifier
#[derive(Serialize, Deserialize, Debug)]
pub struct RandomForestClassifier<T: RealNumber> {
parameters: RandomForestClassifierParameters,
@@ -57,12 +113,15 @@ impl Default for RandomForestClassifierParameters {
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
mtry: Option::None,
m: Option::None,
}
}
}
impl<T: RealNumber> RandomForestClassifier<T> {
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
@@ -79,7 +138,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
}
let mtry = parameters.mtry.unwrap_or(
let mtry = parameters.m.unwrap_or(
(T::from(num_attributes).unwrap())
.sqrt()
.floor()
@@ -110,6 +169,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
}
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0);
@@ -199,7 +260,7 @@ mod tests {
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
mtry: Option::None,
m: Option::None,
},
);
+61 -4
@@ -1,3 +1,47 @@
//! # Random Forest Regressor
//! A random forest is an ensemble estimator that fits multiple [decision trees](../../tree/index.html) to random subsets of the dataset and averages predictions
//! to improve the predictive accuracy and control over-fitting. See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the performance of the algorithm, at the cost of increased training time.
//! The number of randomly sampled predictors _m_ is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
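For regression, aggregation is a plain mean over the individual tree outputs; a minimal sketch (the helper name is hypothetical, not this crate's API):

```rust
// Regression aggregation: the forest prediction for one observation
// is the mean of the individual tree predictions.
fn average_prediction(tree_predictions: &[f64]) -> f64 {
    tree_predictions.iter().sum::<f64>() / tree_predictions.len() as f64
}

fn main() {
    // Mean of 1.0, 2.0 and 3.0 is 2.0.
    assert!((average_prediction(&[1.0, 2.0, 3.0]) - 2.0).abs() < 1e-12);
}
```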
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::ensemble::random_forest_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
//! let x = DenseMatrix::from_array(&[
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]);
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default());
//!
//! let y_hat = regressor.predict(&x); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
extern crate rand;
use std::default::Default;
@@ -13,14 +57,22 @@ use crate::tree::decision_tree_regressor::{
};
#[derive(Serialize, Deserialize, Debug, Clone)]
/// Parameters of the Random Forest Regressor
/// Some parameters here are passed directly to the base estimator.
pub struct RandomForestRegressorParameters {
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
/// The number of trees in the forest.
pub n_trees: usize,
pub mtry: Option<usize>,
/// The number of predictors to sample at random as split candidates.
pub m: Option<usize>,
}
/// Random Forest Regressor
#[derive(Serialize, Deserialize, Debug)]
pub struct RandomForestRegressor<T: RealNumber> {
parameters: RandomForestRegressorParameters,
@@ -34,7 +86,7 @@ impl Default for RandomForestRegressorParameters {
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
mtry: Option::None,
m: Option::None,
}
}
}
@@ -55,6 +107,9 @@ impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
}
impl<T: RealNumber> RandomForestRegressor<T> {
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
@@ -63,7 +118,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
let (n_rows, num_attributes) = x.shape();
let mtry = parameters
.mtry
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
@@ -85,6 +140,8 @@ impl<T: RealNumber> RandomForestRegressor<T> {
}
}
/// Predict target values for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0);
@@ -162,7 +219,7 @@ mod tests {
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
mtry: Option::None,
m: Option::None,
},
)
.predict(&x);
+2 -1
@@ -58,7 +58,8 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 3. Linear Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 15.4 General Linear Least Squares](http://numerical.recipes/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
+3 -1
@@ -49,7 +49,9 @@
//! * ["Pattern Recognition and Machine Learning", C.M. Bishop, Linear Models for Classification](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 4.3 Logistic Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["On the Limited Memory Method for Large Scale Optimization", Nocedal et al., Mathematical Programming, 1989](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited.pdf)
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
+2 -1
@@ -17,7 +17,8 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 3. Linear Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["The Statistical Sleuth, A Course in Methods of Data Analysis", Ramsey F.L., Schafer D.W., Ch 7, 8, 3rd edition, 2013](http://www.statisticalsleuth.com/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
pub mod linear_regression;
pub mod logistic_regression;
+2 -1
@@ -16,7 +16,8 @@
//! let l2: f64 = Euclidian{}.distance(&x, &y);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
+2 -1
@@ -16,7 +16,8 @@
//!
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
+2 -1
@@ -38,7 +38,8 @@
//! * ["Introduction to Multivariate Statistical Analysis in Chemometrics", Varmuza, K., Filzmoser, P., 2016, p.46](https://www.taylorfrancis.com/books/9780429145049)
//! * ["Example of Calculating the Mahalanobis Distance", McCaffrey, J.D.](https://jamesmccaffrey.wordpress.com/2017/11/09/example-of-calculating-the-mahalanobis-distance/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use std::marker::PhantomData;
+2 -1
@@ -15,7 +15,8 @@
//!
//! let l1: f64 = Manhattan {}.distance(&x, &y);
//! ```
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
+2 -1
@@ -18,7 +18,8 @@
//! let l2: f64 = Minkowski { p: 2 }.distance(&x, &y);
//!
//! ```
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
+2 -1
@@ -10,7 +10,8 @@
//!
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Euclidean distance is the straight-line distance between two points in Euclidean space, and represents the shortest distance between these points.
pub mod euclidian;
+2 -1
@@ -14,7 +14,8 @@
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let score: f64 = Precision {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -16,7 +16,8 @@
//! let score: f64 = Recall {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
+2 -1
@@ -29,7 +29,8 @@
//! * ["Nearest Neighbor Pattern Classification" Cover, T.M., IEEE Transactions on Information Theory (1967)](http://ssg.mit.edu/cal/abs/2000_spring/np_dens/classification/cover67.pdf)
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
+3 -2
@@ -55,7 +55,8 @@
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::LinkedList;
use std::default::Default;
@@ -187,7 +188,7 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
}
impl<T: RealNumber> DecisionTreeRegressor<T> {
/// Build a regression tree regressor from the training data.
/// Build a decision tree regressor from the training data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit<M: Matrix<T>>(
+3 -2
@@ -6,7 +6,7 @@
//! and fit a simple prediction model within each region. In order to make a prediction \\(\hat{y}\\) for a given observation, a
//! decision tree typically uses the mean or the mode of the training observations in the region \\(R_j\\) to which it belongs.
//!
//! Decision trees often does not deliver best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
//! Decision trees suffer from high variance and often do not deliver the best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
//! Hence some techniques such as [Random Forests](../ensemble/index.html) use more than one decision tree to improve performance of the algorithm.
//!
//! SmartCore uses [CART](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees_.28CART.29) learning technique to build both classification and regression trees.
@@ -16,7 +16,8 @@
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Classification tree for dependent variables that take a finite number of unordered values.
pub mod decision_tree_classifier;