feat: documents KNN Classifier
This commit is contained in:
@@ -1,3 +1,37 @@
|
||||
//! # K Nearest Neighbors Classifier
//!
//! SmartCore relies on 2 backend algorithms to speed up KNN queries:
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
//! * [`CoverTree`](../../algorithm/neighbour/cover_tree/index.html)
//!
//! The parameter `k` controls the stability of the KNN estimate: when `k` is small the algorithm is sensitive to noise in the data,
//! and when `k` increases the estimator becomes more stable.
//! In terms of the bias-variance trade-off, the variance decreases with `k` while the bias is likely to increase with `k`.
//!
//! When you don't know which search algorithm and `k` value to use, go with the default parameters defined by `Default::default()`.
//!
//! To fit the model to a 5 x 2 matrix with 5 training samples, 2 features per sample:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::neighbors::knn_classifier::*;
//! use smartcore::math::distance::*;
//!
//! // Your explanatory variables. Each row is a training sample with 2 numerical features.
//! let x = DenseMatrix::from_array(&[
//!     &[1., 2.],
//!     &[3., 4.],
//!     &[5., 6.],
//!     &[7., 8.],
//!     &[9., 10.]]);
//! let y = vec![2., 2., 2., 3., 3.]; // your class labels
//!
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default());
//! let y_hat = knn.predict(&x);
//! ```
//!
//! The variable `y_hat` will hold a vector with estimates of the class labels.
//!
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
@@ -5,12 +39,16 @@ use crate::math::distance::Distance;
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName};
|
||||
|
||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug)]
pub struct KNNClassifierParameters {
    /// Backend search algorithm used to answer nearest-neighbour queries.
    /// See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
    pub algorithm: KNNAlgorithmName,
    /// Number of training samples to consider when estimating the class for a new point.
    /// Default value is 3. Small `k` is sensitive to noise; larger `k` gives a more
    /// stable (lower-variance) estimate at the likely cost of higher bias.
    pub k: usize,
}
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifier<T: FloatExt, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
@@ -52,6 +90,13 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
}
|
||||
|
||||
impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
/// Fits KNN Classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
@@ -94,6 +139,9 @@ impl<T: FloatExt, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user