diff --git a/Cargo.toml b/Cargo.toml index 0a6e32d..20eebf5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ nalgebra = { version = "0.22.0", optional = true } num-traits = "0.2.12" num = "0.3.0" rand = "0.7.3" +rand_distr = "0.3.0" serde = { version = "1.0.115", features = ["derive"] } serde_derive = "1.0.115" diff --git a/src/algorithm/neighbour/cover_tree.rs b/src/algorithm/neighbour/cover_tree.rs index 70a3d33..da870d2 100644 --- a/src/algorithm/neighbour/cover_tree.rs +++ b/src/algorithm/neighbour/cover_tree.rs @@ -100,7 +100,7 @@ impl> CoverTree /// Find k nearest neighbors of `p` /// * `p` - look for k nearest points to `p` /// * `k` - the number of nearest neighbors to return - pub fn find(&self, p: &T, k: usize) -> Result, Failed> { + pub fn find(&self, p: &T, k: usize) -> Result, Failed> { if k <= 0 { return Err(Failed::because(FailedError::FindFailed, "k should be > 0")); } @@ -164,13 +164,13 @@ impl> CoverTree current_cover_set = next_cover_set; } - let mut neighbors: Vec<(usize, F)> = Vec::new(); + let mut neighbors: Vec<(usize, F, &T)> = Vec::new(); let upper_bound = *heap.peek(); for ds in zero_set { if ds.0 <= upper_bound { let v = self.get_data_value(ds.1.idx); if !self.identical_excluded || v != p { - neighbors.push((ds.1.idx, ds.0)); + neighbors.push((ds.1.idx, ds.0, &v)); } } } @@ -178,6 +178,60 @@ impl> CoverTree Ok(neighbors.into_iter().take(k).collect()) } + /// Find all nearest neighbors within radius `radius` from `p` + /// * `p` - look for k nearest points to `p` + /// * `radius` - radius of the search + pub fn find_radius(&self, p: &T, radius: F) -> Result, Failed> { + if radius <= F::zero() { + return Err(Failed::because( + FailedError::FindFailed, + "radius should be > 0", + )); + } + + let mut neighbors: Vec<(usize, F, &T)> = Vec::new(); + + let mut current_cover_set: Vec<(F, &Node)> = Vec::new(); + let mut zero_set: Vec<(F, &Node)> = Vec::new(); + + let e = self.get_data_value(self.root.idx); + let mut d = self.distance.distance(&e, p); + current_cover_set.push((d, &self.root)); + + while !current_cover_set.is_empty() { + let mut next_cover_set: Vec<(F, &Node)> = Vec::new(); + for par in current_cover_set { + let parent = par.1; + for c in 0..parent.children.len() { + let child = &parent.children[c]; + if c == 0 { + d = par.0; + } else { + d = self.distance.distance(self.get_data_value(child.idx), p); + } + + if d <= radius + child.max_dist { + if !child.children.is_empty() { + next_cover_set.push((d, child)); + } else if d <= radius { + zero_set.push((d, child)); + } + } + } + } + current_cover_set = next_cover_set; + } + + for ds in zero_set { + let v = self.get_data_value(ds.1.idx); + if !self.identical_excluded || v != p { + neighbors.push((ds.1.idx, ds.0, &v)); + } + } + + Ok(neighbors) + } + fn new_leaf(&self, idx: usize) -> Node { Node { idx: idx, @@ -417,6 +471,11 @@ mod tests { knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); let knn: Vec = knn.iter().map(|v| v.0).collect(); assert_eq!(vec!(3, 4, 5), knn); + + let mut knn = tree.find_radius(&5, 2.0).unwrap(); + knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + let knn: Vec = knn.iter().map(|v| *v.2).collect(); + assert_eq!(vec!(3, 4, 5, 6, 7), knn); } #[test] diff --git a/src/algorithm/neighbour/linear_search.rs b/src/algorithm/neighbour/linear_search.rs index 3ac1a2b..e89a793 100644 --- a/src/algorithm/neighbour/linear_search.rs +++ b/src/algorithm/neighbour/linear_search.rs @@ -26,7 +26,7 @@ use std::cmp::{Ordering, PartialOrd}; use std::marker::PhantomData; use crate::algorithm::sort::heap_select::HeapSelection; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::math::distance::Distance; use crate::math::num::RealNumber; @@ -53,9 +53,12 @@ impl> LinearKNNSearch { /// Find k nearest neighbors /// * `from` - look for k nearest points to `from` /// * `k` - the number of nearest neighbors to return - pub fn find(&self, from: &T, k: usize) -> Result, Failed> { + pub fn find(&self, from: &T, k: usize) -> Result, Failed> { if k < 1 || k > self.data.len() { - panic!("k should be >= 1 and <= length(data)"); + return Err(Failed::because( + FailedError::FindFailed, + "k should be >= 1 and <= length(data)", + )); } let mut heap = HeapSelection::>::with_capacity(k); @@ -80,9 +83,33 @@ impl> LinearKNNSearch { Ok(heap .get() .into_iter() - .flat_map(|x| x.index.map(|i| (i, x.distance))) + .flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i]))) .collect()) } + + /// Find all nearest neighbors within radius `radius` from `p` + /// * `p` - look for k nearest points to `p` + /// * `radius` - radius of the search + pub fn find_radius(&self, from: &T, radius: F) -> Result, Failed> { + if radius <= F::zero() { + return Err(Failed::because( + FailedError::FindFailed, + "radius should be > 0", + )); + } + + let mut neighbors: Vec<(usize, F, &T)> = Vec::new(); + + for i in 0..self.data.len() { + let d = self.distance.distance(&from, &self.data[i]); + + if d <= radius { + neighbors.push((i, d, &self.data[i])); + } + } + + Ok(neighbors) + } } #[derive(Debug)] @@ -134,6 +161,16 @@ mod tests { assert_eq!(vec!(0, 1, 2), found_idxs1); + let mut found_idxs1: Vec = algorithm1 + .find_radius(&5, 3.0) + .unwrap() + .iter() + .map(|v| *v.2) + .collect(); + found_idxs1.sort(); + + assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1); + let data2 = vec![ vec![1., 1.], vec![2., 2.], diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 48c8835..0a4f21a 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -29,8 +29,68 @@ //! //! +use crate::algorithm::neighbour::cover_tree::CoverTree; +use crate::algorithm::neighbour::linear_search::LinearKNNSearch; +use crate::error::Failed; +use crate::math::distance::Distance; +use crate::math::num::RealNumber; +use serde::{Deserialize, Serialize}; + pub(crate) mod bbd_tree; /// tree data structure for fast nearest neighbor search pub mod cover_tree; /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. pub mod linear_search; + +/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries. +/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html) +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum KNNAlgorithmName { + /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) + LinearSearch, + /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html) + CoverTree, +} + +#[derive(Serialize, Deserialize, Debug)] +pub(crate) enum KNNAlgorithm, T>> { + LinearSearch(LinearKNNSearch, T, D>), + CoverTree(CoverTree, T, D>), +} + +impl KNNAlgorithmName { + pub(crate) fn fit, T>>( + &self, + data: Vec>, + distance: D, + ) -> Result, Failed> { + match *self { + KNNAlgorithmName::LinearSearch => { + LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a)) + } + KNNAlgorithmName::CoverTree => { + CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a)) + } + } + } +} + +impl, T>> KNNAlgorithm { + pub fn find(&self, from: &Vec, k: usize) -> Result)>, Failed> { + match *self { + KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k), + KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k), + } + } + + pub fn find_radius( + &self, + from: &Vec, + radius: T, + ) -> Result)>, Failed> { + match *self { + KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius), + KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius), + } + } +} diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs new file mode 100644 index 0000000..488a7ac --- /dev/null +++ b/src/cluster/dbscan.rs @@ -0,0 +1,252 @@ +//! # DBSCAN Clustering +//! +//! DBSCAN - Density-Based Spatial Clustering of Applications with Noise. +//! +//! Example: +//! +//! ``` +//! use smartcore::linalg::naive::dense_matrix::*; +//! use smartcore::cluster::dbscan::*; +//! use smartcore::math::distance::Distances; +//! use smartcore::neighbors::KNNAlgorithmName; +//! use smartcore::dataset::generator; +//! +//! // Generate three blobs +//! let blobs = generator::make_blobs(100, 2, 3); +//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data); +//! // Fit the algorithm and predict cluster labels +//! let labels = DBSCAN::fit(&x, Distances::euclidian(), DBSCANParameters{ +//! min_samples: 5, +//! eps: 3.0, +//! algorithm: KNNAlgorithmName::CoverTree +//! }).and_then(|dbscan| dbscan.predict(&x)); +//! +//! println!("{:?}", labels); +//! ``` +//! +//! ## References: +//! +//! * ["A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise", Ester M., Kriegel HP., Sander J., Xu X.](http://faculty.marshall.usc.edu/gareth-james/ISL/) +//! * ["Density-Based Clustering in Spatial Databases: The Algorithm GDBSCAN and its Applications", Sander J., Ester M., Kriegel HP., Xu X.](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1629&rep=rep1&type=pdf) + +extern crate rand; + +use std::fmt::Debug; +use std::iter::Sum; + +use serde::{Deserialize, Serialize}; + +use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; +use crate::error::Failed; +use crate::linalg::{row_iter, Matrix}; +use crate::math::distance::Distance; +use crate::math::num::RealNumber; +use crate::tree::decision_tree_classifier::which_max; + +/// DBSCAN clustering algorithm +#[derive(Serialize, Deserialize, Debug)] +pub struct DBSCAN, T>> { + cluster_labels: Vec, + num_classes: usize, + knn_algorithm: KNNAlgorithm, + eps: T, +} + +#[derive(Debug, Clone)] +/// DBSCAN clustering algorithm parameters +pub struct DBSCANParameters { + /// Maximum number of iterations of the k-means algorithm for a single run. + pub min_samples: usize, + /// The number of samples in a neighborhood for a point to be considered as a core point. + pub eps: T, + /// KNN algorithm to use. + pub algorithm: KNNAlgorithmName, +} + +impl, T>> PartialEq for DBSCAN { + fn eq(&self, other: &Self) -> bool { + self.cluster_labels.len() == other.cluster_labels.len() + && self.num_classes == other.num_classes + && self.eps == other.eps + && self.cluster_labels == other.cluster_labels + } +} + +impl Default for DBSCANParameters { + fn default() -> Self { + DBSCANParameters { + min_samples: 5, + eps: T::half(), + algorithm: KNNAlgorithmName::CoverTree, + } + } +} + +impl, T>> DBSCAN { + /// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features. + /// * `data` - training instances to cluster + /// * `k` - number of clusters + /// * `parameters` - cluster parameters + pub fn fit>( + x: &M, + distance: D, + parameters: DBSCANParameters, + ) -> Result, Failed> { + if parameters.min_samples < 1 { + return Err(Failed::fit(&format!("Invalid minPts"))); + } + + if parameters.eps <= T::zero() { + return Err(Failed::fit(&format!("Invalid radius: "))); + } + + let mut k = 0; + let unassigned = -2; + let outlier = -1; + + let n = x.shape().0; + let mut y = vec![unassigned; n]; + + let algo = parameters.algorithm.fit(row_iter(x).collect(), distance)?; + + for (i, e) in row_iter(x).enumerate() { + if y[i] == unassigned { + let mut neighbors = algo.find_radius(&e, parameters.eps)?; + if neighbors.len() < parameters.min_samples { + y[i] = outlier; + } else { + y[i] = k; + for j in 0..neighbors.len() { + if y[neighbors[j].0] == unassigned { + y[neighbors[j].0] = k; + + let mut secondary_neighbors = + algo.find_radius(neighbors[j].2, parameters.eps)?; + + if secondary_neighbors.len() >= parameters.min_samples { + neighbors.append(&mut secondary_neighbors); + } + } + + if y[neighbors[j].0] == outlier { + y[neighbors[j].0] = k; + } + } + k += 1; + } + } + } + + Ok(DBSCAN { + cluster_labels: y, + num_classes: k as usize, + knn_algorithm: algo, + eps: parameters.eps, + }) + } + + /// Predict clusters for `x` + /// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features. + pub fn predict>(&self, x: &M) -> Result { + let (n, m) = x.shape(); + let mut result = M::zeros(1, n); + let mut row = vec![T::zero(); m]; + + for i in 0..n { + x.copy_row_as_vec(i, &mut row); + let neighbors = self.knn_algorithm.find_radius(&row, self.eps)?; + let mut label = vec![0usize; self.num_classes + 1]; + for neighbor in neighbors { + let yi = self.cluster_labels[neighbor.0]; + if yi < 0 { + label[self.num_classes] += 1; + } else { + label[yi as usize] += 1; + } + } + let class = which_max(&label); + if class != self.num_classes { + result.set(0, i, T::from(class).unwrap()); + } else { + result.set(0, i, -T::one()); + } + } + + Ok(result.to_row_vector()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::math::distance::euclidian::Euclidian; + use crate::math::distance::Distances; + + #[test] + fn fit_predict_dbscan() { + let x = DenseMatrix::from_2d_array(&[ + &[1.0, 2.0], + &[1.1, 2.1], + &[0.9, 1.9], + &[1.2, 1.2], + &[0.8, 1.8], + &[2.0, 1.0], + &[2.1, 1.1], + &[2.2, 1.2], + &[1.9, 0.9], + &[1.8, 0.8], + &[3.0, 5.0], + ]); + + let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0]; + + let dbscan = DBSCAN::fit( + &x, + Distances::euclidian(), + DBSCANParameters { + min_samples: 5, + eps: 1.0, + algorithm: KNNAlgorithmName::CoverTree, + }, + ) + .unwrap(); + + let predicted_labels = dbscan.predict(&x).unwrap(); + + assert_eq!(expected_labels, predicted_labels); + } + + #[test] + fn serde() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + + let dbscan = DBSCAN::fit(&x, Distances::euclidian(), Default::default()).unwrap(); + + let deserialized_dbscan: DBSCAN = + serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap(); + + assert_eq!(dbscan, deserialized_dbscan); + } +} diff --git a/src/cluster/mod.rs b/src/cluster/mod.rs index 3201cda..be6ef9f 100644 --- a/src/cluster/mod.rs +++ b/src/cluster/mod.rs @@ -3,5 +3,6 @@ //! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters. +pub mod dbscan; /// An iterative clustering algorithm that aims to find local maxima in each iteration. pub mod kmeans; diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs new file mode 100644 index 0000000..fd4f400 --- /dev/null +++ b/src/dataset/generator.rs @@ -0,0 +1,129 @@ +//! # Dataset Generators +//! +use rand::distributions::Uniform; +use rand::prelude::*; +use rand_distr::Normal; + +use crate::dataset::Dataset; + +/// Generate `num_centers` clusters of normally distributed points +pub fn make_blobs( + num_samples: usize, + num_features: usize, + num_centers: usize, +) -> Dataset { + let center_box = Uniform::from(-10.0..10.0); + let cluster_std = 1.0; + let mut centers: Vec>> = Vec::with_capacity(num_centers); + + let mut rng = rand::thread_rng(); + for _ in 0..num_centers { + centers.push( + (0..num_features) + .map(|_| Normal::new(center_box.sample(&mut rng), cluster_std).unwrap()) + .collect(), + ); + } + + let mut y: Vec = Vec::with_capacity(num_samples); + let mut x: Vec = Vec::with_capacity(num_samples); + + for i in 0..num_samples { + let label = i % num_centers; + y.push(label as f32); + for j in 0..num_features { + x.push(centers[label][j].sample(&mut rng)); + } + } + + Dataset { + data: x, + target: y, + num_samples: num_samples, + num_features: num_features, + feature_names: (0..num_features).map(|n| n.to_string()).collect(), + target_names: vec!["label".to_string()], + description: "Isotropic Gaussian blobs".to_string(), + } +} + +/// Make a large circle containing a smaller circle in 2d. +pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset { + if factor >= 1.0 || factor < 0.0 { + panic!("'factor' has to be between 0 and 1."); + } + + let num_samples_out = num_samples / 2; + let num_samples_in = num_samples - num_samples_out; + + let linspace_out = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_out); + let linspace_in = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_in); + + println!("{:?}", linspace_out); + println!("{:?}", linspace_in); + let noise = Normal::new(0.0, noise).unwrap(); + let mut rng = rand::thread_rng(); + + let mut x: Vec = Vec::with_capacity(num_samples * 2); + let mut y: Vec = Vec::with_capacity(num_samples); + + for v in linspace_out { + x.push(v.cos() + noise.sample(&mut rng)); + x.push(v.sin() + noise.sample(&mut rng)); + y.push(0.0); + } + + for v in linspace_in { + x.push(v.cos() * factor + noise.sample(&mut rng)); + x.push(v.sin() * factor + noise.sample(&mut rng)); + y.push(1.0); + } + + Dataset { + data: x, + target: y, + num_samples: num_samples, + num_features: 2, + feature_names: (0..2).map(|n| n.to_string()).collect(), + target_names: vec!["label".to_string()], + description: "Large circle containing a smaller circle in 2d".to_string(), + } +} + +fn linspace(start: f32, stop: f32, num: usize) -> Vec { + let div = num as f32; + let delta = stop - start; + let step = delta / div; + (0..num).map(|v| v as f32 * step).collect() +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_make_blobs() { + let dataset = make_blobs(10, 2, 3); + assert_eq!( + dataset.data.len(), + dataset.num_features * dataset.num_samples + ); + assert_eq!(dataset.target.len(), dataset.num_samples); + assert_eq!(dataset.num_features, 2); + assert_eq!(dataset.num_samples, 10); + } + + #[test] + fn test_make_circles() { + let dataset = make_circles(10, 0.5, 0.05); + println!("{:?}", dataset.as_matrix()); + assert_eq!( + dataset.data.len(), + dataset.num_features * dataset.num_samples + ); + assert_eq!(dataset.target.len(), dataset.num_samples); + assert_eq!(dataset.num_features, 2); + assert_eq!(dataset.num_samples, 10); + } +} diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 8d7a4e2..bfcd1c9 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -5,6 +5,7 @@ pub mod boston; pub mod breast_cancer; pub mod diabetes; pub mod digits; +pub mod generator; pub mod iris; use crate::math::num::RealNumber; diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index 1f75949..3ad4297 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -34,11 +34,12 @@ use serde::{Deserialize, Serialize}; +use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::error::Failed; use crate::linalg::{row_iter, Matrix}; use crate::math::distance::Distance; use crate::math::num::RealNumber; -use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction}; +use crate::neighbors::KNNWeightFunction; /// `KNNClassifier` parameters. Use `Default::default()` for default values. #[derive(Serialize, Deserialize, Debug)] diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 5c979a3..04fbd35 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -36,11 +36,12 @@ //! use serde::{Deserialize, Serialize}; +use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::error::Failed; use crate::linalg::{row_iter, BaseVector, Matrix}; use crate::math::distance::Distance; use crate::math::num::RealNumber; -use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction}; +use crate::neighbors::KNNWeightFunction; /// `KNNRegressor` parameters. Use `Default::default()` for default values. #[derive(Serialize, Deserialize, Debug)] diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs index 8251117..6d542f6 100644 --- a/src/neighbors/mod.rs +++ b/src/neighbors/mod.rs @@ -32,10 +32,6 @@ //! //! -use crate::algorithm::neighbour::cover_tree::CoverTree; -use crate::algorithm::neighbour::linear_search::LinearKNNSearch; -use crate::error::Failed; -use crate::math::distance::Distance; use crate::math::num::RealNumber; use serde::{Deserialize, Serialize}; @@ -44,15 +40,12 @@ pub mod knn_classifier; /// K Nearest Neighbors Regressor pub mod knn_regressor; -/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries. /// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html) -#[derive(Serialize, Deserialize, Debug)] -pub enum KNNAlgorithmName { - /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) - LinearSearch, - /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html) - CoverTree, -} +#[deprecated( + since = "0.2.0", + note = "please use `smartcore::algorithm::neighbour::KNNAlgorithmName` instead" +)] +pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName; /// Weight function that is used to determine estimated value. #[derive(Serialize, Deserialize, Debug)] @@ -63,12 +56,6 @@ pub enum KNNWeightFunction { Distance, } -#[derive(Serialize, Deserialize, Debug)] -enum KNNAlgorithm, T>> { - LinearSearch(LinearKNNSearch, T, D>), - CoverTree(CoverTree, T, D>), -} - impl KNNWeightFunction { fn calc_weights(&self, distances: Vec) -> std::vec::Vec { match *self { @@ -88,29 +75,3 @@ impl KNNWeightFunction { } } } - -impl KNNAlgorithmName { - fn fit, T>>( - &self, - data: Vec>, - distance: D, - ) -> Result, Failed> { - match *self { - KNNAlgorithmName::LinearSearch => { - LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a)) - } - KNNAlgorithmName::CoverTree => { - CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a)) - } - } - } -} - -impl, T>> KNNAlgorithm { - fn find(&self, from: &Vec, k: usize) -> Result, Failed> { - match *self { - KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k), - KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k), - } - } -}