feat: documents ensemble models
This commit is contained in:
@@ -26,7 +26,8 @@
|
||||
//! * ["Faster cover trees." Izbicki et al., Proceedings of the 32nd International Conference on Machine Learning, ICML'15 (2015)](http://www.cs.ucr.edu/~cshelton/papers/index.cgi%3FIzbShe15)
|
||||
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
|
||||
@@ -43,7 +43,8 @@
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -1,2 +1,22 @@
|
||||
//! # Ensemble Methods
|
||||
//!
|
||||
//! Combining predictions of several base estimators is a general-purpose procedure for reducing the variance of a statistical learning method.
|
||||
//! When combined with bagging, ensemble models achive superior performance to individual estimators.
|
||||
//!
|
||||
//! The main idea behind bagging (or bootstrap aggregation) is to fit the same base model to a big number of random subsets of the original training
|
||||
//! set and then aggregate their individual predictions to form a final prediction. In classification setting the overall prediction is the most commonly
|
||||
//! occurring majority class among the individual predictions.
|
||||
//!
|
||||
//! In SmartCore you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
|
||||
//! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
|
||||
//! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
|
||||
//! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
|
||||
/// Random forest classifier
|
||||
pub mod random_forest_classifier;
|
||||
/// Random forest regressor
|
||||
pub mod random_forest_regressor;
|
||||
|
||||
@@ -1,3 +1,50 @@
|
||||
//! # Random Forest Classifier
|
||||
//! A random forest is an ensemble estimator that fits multiple [decision trees](../../tree/index.html) to random subsets of the dataset and averages predictions
|
||||
//! to improve the predictive accuracy and control over-fitting. See [ensemble models](../index.html) for more details.
|
||||
//!
|
||||
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
|
||||
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::ensemble::random_forest_classifier::*;
|
||||
//!
|
||||
//! // Iris dataset
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default());
|
||||
//! let y_hat = classifier.predict(&x); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
extern crate rand;
|
||||
|
||||
use std::default::Default;
|
||||
@@ -12,16 +59,25 @@ use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
};
|
||||
|
||||
/// Parameters of the Random Forest algorithm.
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: SplitCriterion,
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: usize,
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: u16,
|
||||
pub mtry: Option<usize>,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
}
|
||||
|
||||
/// Random Forest Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestClassifier<T: RealNumber> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
@@ -57,12 +113,15 @@ impl Default for RandomForestClassifierParameters {
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
mtry: Option::None,
|
||||
m: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
@@ -79,7 +138,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
let mtry = parameters.mtry.unwrap_or(
|
||||
let mtry = parameters.m.unwrap_or(
|
||||
(T::from(num_attributes).unwrap())
|
||||
.sqrt()
|
||||
.floor()
|
||||
@@ -110,6 +169,8 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
@@ -199,7 +260,7 @@ mod tests {
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
mtry: Option::None,
|
||||
m: Option::None,
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -1,3 +1,47 @@
|
||||
//! # Random Forest Regressor
|
||||
//! A random forest is an ensemble estimator that fits multiple [decision trees](../../tree/index.html) to random subsets of the dataset and averages predictions
|
||||
//! to improve the predictive accuracy and control over-fitting. See [ensemble models](../index.html) for more details.
|
||||
//!
|
||||
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
|
||||
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::ensemble::random_forest_regressor::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_array(&[
|
||||
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! let y = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
|
||||
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||
//! ];
|
||||
//!
|
||||
//! let regressor = RandomForestRegressor::fit(&x, &y, Default::default());
|
||||
//!
|
||||
//! let y_hat = regressor.predict(&x); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
extern crate rand;
|
||||
|
||||
use std::default::Default;
|
||||
@@ -13,14 +57,22 @@ use crate::tree::decision_tree_regressor::{
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
/// Parameters of the Random Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct RandomForestRegressorParameters {
|
||||
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_split: usize,
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: usize,
|
||||
pub mtry: Option<usize>,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
}
|
||||
|
||||
/// Random Forest Regressor
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestRegressor<T: RealNumber> {
|
||||
parameters: RandomForestRegressorParameters,
|
||||
@@ -34,7 +86,7 @@ impl Default for RandomForestRegressorParameters {
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 10,
|
||||
mtry: Option::None,
|
||||
m: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -55,6 +107,9 @@ impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
@@ -63,7 +118,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
let (n_rows, num_attributes) = x.shape();
|
||||
|
||||
let mtry = parameters
|
||||
.mtry
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
@@ -85,6 +140,8 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
@@ -162,7 +219,7 @@ mod tests {
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
mtry: Option::None,
|
||||
m: Option::None,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
|
||||
@@ -58,7 +58,8 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 3. Linear Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 15.4 General Linear Least Squares](http://numerical.recipes/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -49,7 +49,9 @@
|
||||
//! * ["Pattern Recognition and Machine Learning", C.M. Bishop, Linear Models for Classification](https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 4.3 Logistic Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["On the Limited Memory Method for Large Scale Optimization", Nocedal et al., Mathematical Programming, 1989](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited.pdf)
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
||||
+2
-1
@@ -17,7 +17,8 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 3. Linear Regression](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["The Statistical Sleuth, A Course in Methods of Data Analysis", Ramsey F.L., Schafer D.W., Ch 7, 8, 3rd edition, 2013](http://www.statisticalsleuth.com/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
pub mod linear_regression;
|
||||
pub mod logistic_regression;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//! let l2: f64 = Euclidian{}.distance(&x, &y);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@
|
||||
//! * ["Introduction to Multivariate Statistical Analysis in Chemometrics", Varmuza, K., Filzmoser, P., 2016, p.46](https://www.taylorfrancis.com/books/9780429145049)
|
||||
//! * ["Example of Calculating the Mahalanobis Distance", McCaffrey, J.D.](https://jamesmccaffrey.wordpress.com/2017/11/09/example-of-calculating-the-mahalanobis-distance/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
//!
|
||||
//! let l1: f64 = Manhattan {}.distance(&x, &y);
|
||||
//! ```
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
//! let l2: f64 = Minkowski { p: 2 }.distance(&x, &y);
|
||||
//!
|
||||
//! ```
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
//!
|
||||
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||
pub mod euclidian;
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
+2
-1
@@ -16,7 +16,8 @@
|
||||
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//! let score: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
+2
-1
@@ -16,7 +16,8 @@
|
||||
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
//! let score: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
//! ```
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
@@ -29,7 +29,8 @@
|
||||
//! * ["Nearest Neighbor Pattern Classification" Cover, T.M., IEEE Transactions on Information Theory (1967)](http://ssg.mit.edu/cal/abs/2000_spring/np_dens/classification/cover67.pdf)
|
||||
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
|
||||
@@ -55,7 +55,8 @@
|
||||
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use std::collections::LinkedList;
|
||||
use std::default::Default;
|
||||
@@ -187,7 +188,7 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
/// Build a regression tree regressor from the training data.
|
||||
/// Build a decision tree regressor from the training data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
|
||||
+3
-2
@@ -6,7 +6,7 @@
|
||||
//! and fit a simple prediction model within each region. In order to make a prediction for a given observation, \\(\hat{y}\\)
|
||||
//! decision tree typically use the mean or the mode of the training observations in the region \\(R_j\\) to which it belongs.
|
||||
//!
|
||||
//! Decision trees often does not deliver best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
|
||||
//! Decision trees suffer from high variance and often does not deliver best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
|
||||
//! Hence some techniques such as [Random Forests](../ensemble/index.html) use more than one decision tree to improve performance of the algorithm.
|
||||
//!
|
||||
//! SmartCore uses [CART](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees_.28CART.29) learning technique to build both classification and regression trees.
|
||||
@@ -16,7 +16,8 @@
|
||||
//! * ["Classification and regression trees", Breiman, L, Friedman, J H, Olshen, R A, and Stone, C J, 1984](https://www.sciencebase.gov/catalog/item/545d07dfe4b0ba8303f728c1)
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 8](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//!
|
||||
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
/// Classification tree for dependent variables that take a finite number of unordered values.
|
||||
pub mod decision_tree_classifier;
|
||||
|
||||
Reference in New Issue
Block a user