+ DBSCAN and data generator. Improves KNN API
This commit is contained in:
@@ -100,7 +100,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
/// Find k nearest neighbors of `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if k <= 0 {
|
||||
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
|
||||
}
|
||||
@@ -164,13 +164,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
current_cover_set = next_cover_set;
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F)> = Vec::new();
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
let upper_bound = *heap.peek();
|
||||
for ds in zero_set {
|
||||
if ds.0 <= upper_bound {
|
||||
let v = self.get_data_value(ds.1.idx);
|
||||
if !self.identical_excluded || v != p {
|
||||
neighbors.push((ds.1.idx, ds.0));
|
||||
neighbors.push((ds.1.idx, ds.0, &v));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -178,6 +178,60 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
Ok(neighbors.into_iter().take(k).collect())
|
||||
}
|
||||
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
|
||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
let mut d = self.distance.distance(&e, p);
|
||||
current_cover_set.push((d, &self.root));
|
||||
|
||||
while !current_cover_set.is_empty() {
|
||||
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
for par in current_cover_set {
|
||||
let parent = par.1;
|
||||
for c in 0..parent.children.len() {
|
||||
let child = &parent.children[c];
|
||||
if c == 0 {
|
||||
d = par.0;
|
||||
} else {
|
||||
d = self.distance.distance(self.get_data_value(child.idx), p);
|
||||
}
|
||||
|
||||
if d <= radius + child.max_dist {
|
||||
if !child.children.is_empty() {
|
||||
next_cover_set.push((d, child));
|
||||
} else if d <= radius {
|
||||
zero_set.push((d, child));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cover_set = next_cover_set;
|
||||
}
|
||||
|
||||
for ds in zero_set {
|
||||
let v = self.get_data_value(ds.1.idx);
|
||||
if !self.identical_excluded || v != p {
|
||||
neighbors.push((ds.1.idx, ds.0, &v));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(neighbors)
|
||||
}
|
||||
|
||||
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||
Node {
|
||||
idx: idx,
|
||||
@@ -417,6 +471,11 @@ mod tests {
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||
assert_eq!(vec!(3, 4, 5), knn);
|
||||
|
||||
let mut knn = tree.find_radius(&5, 2.0).unwrap();
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
|
||||
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -26,7 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::Failed;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -53,9 +53,12 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
/// Find k nearest neighbors
|
||||
/// * `from` - look for k nearest points to `from`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"k should be >= 1 and <= length(data)",
|
||||
));
|
||||
}
|
||||
|
||||
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
|
||||
@@ -80,9 +83,33 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
Ok(heap
|
||||
.get()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
let d = self.distance.distance(&from, &self.data[i]);
|
||||
|
||||
if d <= radius {
|
||||
neighbors.push((i, d, &self.data[i]));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(neighbors)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -134,6 +161,16 @@ mod tests {
|
||||
|
||||
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
||||
|
||||
let mut found_idxs1: Vec<i32> = algorithm1
|
||||
.find_radius(&5, 3.0)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| *v.2)
|
||||
.collect();
|
||||
found_idxs1.sort();
|
||||
|
||||
assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1);
|
||||
|
||||
let data2 = vec![
|
||||
vec![1., 1.],
|
||||
vec![2., 2.],
|
||||
|
||||
@@ -29,8 +29,68 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
pub mod linear_search;
|
||||
|
||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum KNNAlgorithmName {
|
||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||
LinearSearch,
|
||||
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
) -> Result<KNNAlgorithm<T, D>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => {
|
||||
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => {
|
||||
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find_radius(
|
||||
&self,
|
||||
from: &Vec<T>,
|
||||
radius: T,
|
||||
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user