+ DBSCAN and data generator. Improves KNN API

This commit is contained in:
Vadim Zaliva
2020-10-02 14:04:01 -07:00
parent 6602de0d51
commit c43990e932
11 changed files with 556 additions and 53 deletions
+62 -3
View File
@@ -100,7 +100,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
/// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
if k <= 0 {
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
}
@@ -164,13 +164,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
current_cover_set = next_cover_set;
}
let mut neighbors: Vec<(usize, F)> = Vec::new();
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let upper_bound = *heap.peek();
for ds in zero_set {
if ds.0 <= upper_bound {
let v = self.get_data_value(ds.1.idx);
if !self.identical_excluded || v != p {
neighbors.push((ds.1.idx, ds.0));
neighbors.push((ds.1.idx, ds.0, &v));
}
}
}
@@ -178,6 +178,60 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
Ok(neighbors.into_iter().take(k).collect())
}
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(&e, p);
current_cover_set.push((d, &self.root));
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
let child = &parent.children[c];
if c == 0 {
d = par.0;
} else {
d = self.distance.distance(self.get_data_value(child.idx), p);
}
if d <= radius + child.max_dist {
if !child.children.is_empty() {
next_cover_set.push((d, child));
} else if d <= radius {
zero_set.push((d, child));
}
}
}
}
current_cover_set = next_cover_set;
}
for ds in zero_set {
let v = self.get_data_value(ds.1.idx);
if !self.identical_excluded || v != p {
neighbors.push((ds.1.idx, ds.0, &v));
}
}
Ok(neighbors)
}
fn new_leaf(&self, idx: usize) -> Node<F> {
Node {
idx: idx,
@@ -417,6 +471,11 @@ mod tests {
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
assert_eq!(vec!(3, 4, 5), knn);
let mut knn = tree.find_radius(&5, 2.0).unwrap();
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
}
#[test]
+41 -4
View File
@@ -26,7 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::Failed;
use crate::error::{Failed, FailedError};
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
@@ -53,9 +53,12 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
/// Find k nearest neighbors
/// * `from` - look for k nearest points to `from`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
if k < 1 || k > self.data.len() {
panic!("k should be >= 1 and <= length(data)");
return Err(Failed::because(
FailedError::FindFailed,
"k should be >= 1 and <= length(data)",
));
}
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
@@ -80,9 +83,33 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
Ok(heap
.get()
.into_iter()
.flat_map(|x| x.index.map(|i| (i, x.distance)))
.flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
.collect())
}
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
for i in 0..self.data.len() {
let d = self.distance.distance(&from, &self.data[i]);
if d <= radius {
neighbors.push((i, d, &self.data[i]));
}
}
Ok(neighbors)
}
}
#[derive(Debug)]
@@ -134,6 +161,16 @@ mod tests {
assert_eq!(vec!(0, 1, 2), found_idxs1);
let mut found_idxs1: Vec<i32> = algorithm1
.find_radius(&5, 3.0)
.unwrap()
.iter()
.map(|v| *v.2)
.collect();
found_idxs1.sort();
assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1);
let data2 = vec![
vec![1., 1.],
vec![2., 2.],
+60
View File
@@ -29,8 +29,68 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search
pub mod cover_tree;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search;
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
CoverTree,
}
#[derive(Serialize, Deserialize, Debug)]
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>),
}
impl KNNAlgorithmName {
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
&self,
data: Vec<Vec<T>>,
distance: D,
) -> Result<KNNAlgorithm<T, D>, Failed> {
match *self {
KNNAlgorithmName::LinearSearch => {
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
}
KNNAlgorithmName::CoverTree => {
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
}
}
}
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
}
}
pub fn find_radius(
&self,
from: &Vec<T>,
radius: T,
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
}
}
}