feat: documents KNN Classifier

Volodymyr Orlov
2020-08-27 19:47:11 -07:00
parent 762bc3d765
commit dcf636a5f1
3 changed files with 91 additions and 7 deletions
+6 -6
View File
@@ -6,12 +6,12 @@
//! Welcome to SmartCore, the most complete machine learning library for Rust!
//!
//! In SmartCore you will find implementations of these ML algorithms:
//! * Regression: Linear Regression (OLS), Decision Tree Regressor, Random Forest Regressor
//! * Classification: Logistic Regressor, Decision Tree Classifier, Random Forest Classifier, Unsupervised Nearest Neighbors (KNN)
//! * Clustering: K-Means
//! * Matrix decomposition: PCA, LU, QR, SVD, EVD
//! * Distance Metrics: Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis
//! * Evaluation Metrics: Accuracy, AUC, Recall, Precision, F1, Mean Absolute Error, Mean Squared Error, R2
//! * __Regression__: Linear Regression (OLS), Decision Tree Regressor, Random Forest Regressor, K Nearest Neighbors
//! * __Classification__: Logistic Regressor, Decision Tree Classifier, Random Forest Classifier, Supervised Nearest Neighbors (KNN)
//! * __Clustering__: K-Means
//! * __Matrix Decomposition__: PCA, LU, QR, SVD, EVD
//! * __Distance Metrics__: Euclidean, Minkowski, Manhattan, Hamming, Mahalanobis
//! * __Evaluation Metrics__: Accuracy, AUC, Recall, Precision, F1, Mean Absolute Error, Mean Squared Error, R2
//!
//! Most of the algorithms implemented in SmartCore operate on n-dimensional arrays. While you can use Rust vectors with all functions defined in this library,
//! we recommend one of the popular linear algebra libraries available in Rust. At this moment we support these packages:
+48
View File
@@ -1,3 +1,37 @@
//! # K Nearest Neighbors Classifier
//!
//! SmartCore relies on 2 backend algorithms to speed up KNN queries:
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
//! * [`CoverTree`](../../algorithm/neighbour/cover_tree/index.html)
//!
//! The parameter `k` controls the stability of the KNN estimate: when `k` is small the algorithm is sensitive to noise in the data; as `k` increases the estimator becomes more stable.
//! In terms of the bias-variance trade-off, the variance decreases with `k` while the bias is likely to increase with `k`.
//!
//! When you don't know which search algorithm and `k` value to use, go with the default parameters defined by `Default::default()`.
//!
//! To fit the model to a 5 x 2 matrix with 5 training samples, 2 features per sample:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::neighbors::knn_classifier::*;
//! use smartcore::math::distance::*;
//!
//! // Your explanatory variables. Each row is a training sample with 2 numerical features.
//! let x = DenseMatrix::from_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]);
//! let y = vec![2., 2., 2., 3., 3.]; // your class labels
//!
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
//! let y_hat = knn.predict(&x);
//! ```
//!
//! The variable `y_hat` will hold a vector with estimates of the class labels.
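//!
//! If the defaults do not fit your problem, you can spell the parameters out explicitly. The sketch below relies only on the two public `KNNClassifierParameters` fields declared later in this file (`algorithm` and `k`); the particular values are illustrative, not recommendations:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::neighbors::knn_classifier::*;
//! use smartcore::neighbors::KNNAlgorithmName;
//! use smartcore::math::distance::*;
//!
//! let x = DenseMatrix::from_array(&[
//!     &[1., 2.],
//!     &[3., 4.],
//!     &[5., 6.],
//!     &[7., 8.],
//!     &[9., 10.]]);
//! let y = vec![2., 2., 2., 3., 3.];
//!
//! // Use a linear scan instead of the default cover tree, and vote
//! // over the 5 nearest training samples instead of the default 3.
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(),
//!     KNNClassifierParameters {
//!         algorithm: KNNAlgorithmName::LinearSearch,
//!         k: 5,
//!     });
//! let y_hat = knn.predict(&x);
//! ```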
//!
use serde::{Deserialize, Serialize};
use crate::linalg::{row_iter, Matrix};
@@ -5,12 +39,16 @@ use crate::math::distance::Distance;
use crate::math::num::FloatExt;
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug)]
pub struct KNNClassifierParameters {
/// Backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is the default.
pub algorithm: KNNAlgorithmName,
/// Number of training samples to consider when estimating the class of a new point. The default value is 3.
pub k: usize,
}
/// K Nearest Neighbors Classifier
#[derive(Serialize, Deserialize, Debug)]
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
classes: Vec<T>,
@@ -52,6 +90,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
}
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
/// Fits KNN Classifier to an NxM matrix where N is the number of samples and M is the number of features.
/// * `x` - training data
/// * `y` - vector with target values (classes) of length N
/// * `distance` - a function that defines a distance between each pair of points in the training data.
///    This function should implement the [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
/// * `parameters` - additional parameters like search algorithm and k
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
@@ -94,6 +139,9 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
}
}
/// Estimates the class labels for the provided data.
/// * `x` - data of shape NxM where N is the number of data points to estimate and M is the number of features.
/// Returns a vector of size N with class estimates.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
let mut result = M::zeros(1, x.shape().0);
+37 -1
View File
@@ -1,4 +1,35 @@
//! # Nearest Neighbors
//!
//! <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
//!
//! The k-nearest neighbors (KNN) algorithm is a simple supervised machine learning algorithm that can be used to solve both classification and regression problems.
//! KNN is a non-parametric method that assumes that similar things exist in close proximity.
//!
//! During training the algorithm memorizes all training samples. To make a prediction it finds a predefined number of training samples closest in distance to the new
//! point and uses the labels of those samples to calculate the value of the new point. The number of samples (k) is defined by the user and does not change after training.
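//!
//! As a rough sketch of this procedure (a self-contained toy example, not SmartCore's actual implementation), a naive KNN classification over plain vectors might look like this:
//!
//! ```
//! use std::collections::HashMap;
//!
//! // Rank all training points by squared Euclidean distance to `query`,
//! // then take a majority vote over the labels of the `k` nearest ones.
//! fn knn_predict(train: &[(Vec<f64>, usize)], query: &[f64], k: usize) -> usize {
//!     let mut by_distance: Vec<(f64, usize)> = train
//!         .iter()
//!         .map(|(x, label)| {
//!             let d = x.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
//!             (d, *label)
//!         })
//!         .collect();
//!     by_distance.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
//!     let mut votes: HashMap<usize, usize> = HashMap::new();
//!     for (_, label) in by_distance.into_iter().take(k) {
//!         *votes.entry(label).or_insert(0) += 1;
//!     }
//!     votes.into_iter().max_by_key(|&(_, count)| count).unwrap().0
//! }
//!
//! let train = vec![
//!     (vec![1., 2.], 0),
//!     (vec![3., 4.], 0),
//!     (vec![7., 8.], 1),
//! ];
//! // With k = 1 the single nearest neighbor, (7, 8), decides the label.
//! assert_eq!(knn_predict(&train, &[6.5, 8.0], 1), 1);
//! ```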
//!
//! The distance can be any metric, that is, a measure \\( d(x, y) \geq 0\\)
//! that satisfies three conditions:
//! 1. \\( d(x, y) = 0 \\) if and only if \\( x = y \\), positive definiteness
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
//!
//! for all \\(x, y, z \in Z \\)
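//!
//! For example, the Euclidean metric used throughout this module satisfies all three conditions. A quick illustration, assuming the `distance` method exposed by implementations of the [`Distance`](../math/distance/trait.Distance.html) trait:
//!
//! ```
//! use smartcore::math::distance::*;
//!
//! // d((1, 1), (4, 5)) = sqrt(3^2 + 4^2) = 5
//! let d = Distances::euclidian().distance(&vec![1., 1.], &vec![4., 5.]);
//! assert!((d - 5.).abs() < 1e-8);
//! ```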
//!
//! Neighbors-based methods are very simple and are known as non-generalizing machine learning methods, since they simply remember all of their training data and are prone to overfitting.
//! Despite these disadvantages, nearest neighbors algorithms have been very successful in a large number of applications because of their flexibility and speed.
//!
//! __Advantages__
//! * The algorithm is simple and fast.
//! * The algorithm is non-parametric: there's no need to build a model, the algorithm simply stores all training samples in memory.
//! * The algorithm is versatile: it can be used for both classification and regression.
//!
//! __Disadvantages__
//! * The algorithm gets significantly slower as the number of examples and/or predictors (independent variables) increases.
//!
//! ## References:
//! * ["Nearest Neighbor Pattern Classification" Cover, T.M., IEEE Transactions on Information Theory (1967)](http://ssg.mit.edu/cal/abs/2000_spring/np_dens/classification/cover67.pdf)
//! * ["The Elements of Statistical Learning: Data Mining, Inference, and Prediction" Trevor et al., 2nd edition, chapter 13](https://web.stanford.edu/~hastie/ElemStatLearn/)
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
@@ -6,13 +37,18 @@ use crate::math::distance::Distance;
use crate::math::num::FloatExt;
use serde::{Deserialize, Serialize};
/// K Nearest Neighbors Classifier
pub mod knn_classifier;
/// K Nearest Neighbors Regressor
pub mod knn_regressor;
/// Both the KNN classifier and regressor benefit from underlying search algorithms that help to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms.
#[derive(Serialize, Deserialize, Debug)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
CoverTree,
}