From dcf636a5f1118e24142fc6a9caed5e96c08b4659 Mon Sep 17 00:00:00 2001 From: Volodymyr Orlov Date: Thu, 27 Aug 2020 19:47:11 -0700 Subject: [PATCH] feat: documents KNN Classifier --- src/lib.rs | 12 ++++----- src/neighbors/knn_classifier.rs | 48 +++++++++++++++++++++++++++++++++ src/neighbors/mod.rs | 38 +++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index c0e0171..59b236e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,12 +6,12 @@ //! Welcome to SmartCore library, the most complete machine learning library for Rust! //! //! In SmartCore you will find implementation of these ML algorithms: -//! * Regression: Linear Regression (OLS), Decision Tree Regressor, Random Forest Regressor -//! * Classification: Logistic Regressor, Decision Tree Classifier, Random Forest Classifier, Unsupervised Nearest Neighbors (KNN) -//! * Clustering: K-Means -//! * Matrix decomposition: PCA, LU, QR, SVD, EVD -//! * Distance Metrics: Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis -//! * Evaluation Metrics: Accuracy, AUC, Recall, Precision, F1, Mean Absolute Error, Mean Squared Error, R2 +//! * __Regression__: Linear Regression (OLS), Decision Tree Regressor, Random Forest Regressor, K Nearest Neighbors +//! * __Classification__: Logistic Regressor, Decision Tree Classifier, Random Forest Classifier, Supervised Nearest Neighbors (KNN) +//! * __Clustering__: K-Means +//! * __Matrix Decomposition__: PCA, LU, QR, SVD, EVD +//! * __Distance Metrics__: Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis +//! * __Evaluation Metrics__: Accuracy, AUC, Recall, Precision, F1, Mean Absolute Error, Mean Squared Error, R2 //! //! Most of algorithms implemented in SmartCore operate on n-dimentional arrays. While you can use Rust vectors with all functions defined in this library //! we do recommend to go with one of the popular linear algebra libraries available in Rust. 
At this moment we support these packages: diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index 553011d..2e11697 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -1,3 +1,37 @@ +//! # K Nearest Neighbors Classifier +//! +//! SmartCore relies on 2 backend algorithms to speed up KNN queries: +//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html) +//! * [`CoverTree`](../../algorithm/neighbour/cover_tree/index.html) +//! +//! The parameter `k` controls the stability of the KNN estimate: when `k` is small the algorithm is sensitive to the noise in data. When `k` increases the estimator becomes more stable. +//! In terms of the bias variance trade-off the variance decreases with `k` and the bias is likely to increase with `k`. +//! +//! When you don't know which search algorithm and `k` value to use go with default parameters defined by `Default::default()` +//! +//! To fit the model to a 5 x 2 matrix with 5 training samples, 2 features per sample: +//! +//! ``` +//! use smartcore::linalg::naive::dense_matrix::*; +//! use smartcore::neighbors::knn_classifier::*; +//! use smartcore::math::distance::*; +//! +//! //your explanatory variables. Each row is a training sample with 2 numerical features +//! let x = DenseMatrix::from_array(&[ +//! &[1., 2.], +//! &[3., 4.], +//! &[5., 6.], +//! &[7., 8.], +//! &[9., 10.]]); +//! let y = vec![2., 2., 2., 3., 3.]; //your class labels +//! +//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()); +//! let y_hat = knn.predict(&x); +//! ``` +//! +//! the variable `y_hat` will hold a vector with estimates of class labels +//! + use serde::{Deserialize, Serialize}; use crate::linalg::{row_iter, Matrix}; @@ -5,12 +39,16 @@ use crate::math::distance::Distance; use crate::math::num::FloatExt; use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName}; +/// `KNNClassifier` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug)] pub struct KNNClassifierParameters { + /// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default. pub algorithm: KNNAlgorithmName, + /// number of training samples to consider when estimating class for new point. Default value is 3. pub k: usize, } +/// K Nearest Neighbors Classifier #[derive(Serialize, Deserialize, Debug)] pub struct KNNClassifier, T>> { classes: Vec, @@ -52,6 +90,13 @@ impl, T>> PartialEq for KNNClassifier { } impl, T>> KNNClassifier { + /// Fits KNN Classifier to an NxM matrix where N is number of samples and M is number of features. + /// * `x` - training data + /// * `y` - vector with target values (classes) of length N + /// * `distance` - a function that defines a distance between each pair of points in training data. + /// This function should implement [`Distance`](../../math/distance/trait.Distance.html) trait. + /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. + /// * `parameters` - additional parameters like search algorithm and k pub fn fit>( x: &M, y: &M::RowVector, @@ -94,6 +139,9 @@ impl, T>> KNNClassifier { } } + /// Estimates the class labels for the provided data. + /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features. + /// Returns a vector of size N with class estimates. pub fn predict>(&self, x: &M) -> M::RowVector { let mut result = M::zeros(1, x.shape().0); diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs index f727726..c9aca51 100644 --- a/src/neighbors/mod.rs +++ b/src/neighbors/mod.rs @@ -1,4 +1,35 @@ //! # Nearest Neighbors +//! +//! +//! +//! The k-nearest neighbors (KNN) algorithm is a simple supervised machine learning algorithm that can be used to solve both classification and regression problems. +//! KNN is a non-parametric method that assumes that similar things exist in close proximity. +//! +//! 
During training the algorithm memorizes all training samples. To make a prediction it finds a predefined set of training samples closest in distance to the new +//! point and uses labels of the found samples to calculate the value of the new point. The number of samples (k) is defined by the user and does not change after training. +//! +//! The distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) +//! and follows three conditions: +//! 1. \\( d(x, y) = 0 \\) if and only if \\( x = y \\), positive definiteness +//! 1. \\( d(x, y) = d(y, x) \\), symmetry +//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality +//! +//! for all \\(x, y, z \in Z \\) +//! +//! Neighbors-based methods are very simple and are known as non-generalizing machine learning methods since they simply remember all of their training data and are prone to overfitting. +//! Despite these disadvantages, nearest neighbors algorithms have been very successful in a large number of applications because of their flexibility and speed. +//! +//! __Advantages__ +//! * The algorithm is simple and fast. +//! * The algorithm is non-parametric: there’s no need to build a model, the algorithm simply stores all training samples in memory. +//! * The algorithm is versatile. It can be used for classification, regression. +//! +//! __Disadvantages__ +//! * The algorithm gets significantly slower as the number of examples and/or predictors/independent variables increases. +//! +//! ## References: +//! * ["Nearest Neighbor Pattern Classification" Cover, T.M., IEEE Transactions on Information Theory (1967)](http://ssg.mit.edu/cal/abs/2000_spring/np_dens/classification/cover67.pdf) +//! 
* ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/) use crate::algorithm::neighbour::cover_tree::CoverTree; use crate::algorithm::neighbour::linear_search::LinearKNNSearch; @@ -6,13 +37,18 @@ use crate::math::distance::Distance; use crate::math::num::FloatExt; use serde::{Deserialize, Serialize}; -/// +/// K Nearest Neighbors Classifier pub mod knn_classifier; +/// K Nearest Neighbors Regressor pub mod knn_regressor; +/// Both KNN classifier and regressor benefit from underlying search algorithms that help to speed up queries. +/// `KNNAlgorithmName` maintains a list of supported search algorithms #[derive(Serialize, Deserialize, Debug)] pub enum KNNAlgorithmName { + /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) LinearSearch, + /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html) CoverTree, }