Compare commits
86 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a62c293244 | ||
|
|
39f87aa5c2 | ||
|
|
8cc02cdd48 | ||
|
|
d60ba63862 | ||
|
|
5dd5c2f0d0 | ||
|
|
074cfaf14f | ||
|
|
393cf15534 | ||
|
|
80c406b37d | ||
|
|
0e1bf6ce7f | ||
|
|
0c9c70f8d2 | ||
|
|
62de25b2ae | ||
|
|
7d87451333 | ||
|
|
265fd558e7 | ||
|
|
e25e2aea2b | ||
|
|
2f6dd1325e | ||
|
|
b0dece9476 | ||
|
|
c507d976be | ||
|
|
fa54d5ee86 | ||
|
|
459d558d48 | ||
|
|
1b7dda30a2 | ||
|
|
c1bd1df5f6 | ||
|
|
cf751f05aa | ||
|
|
63ed89aadd | ||
|
|
890e9d644c | ||
|
|
af0a740394 | ||
|
|
616e38c282 | ||
|
|
a449fdd4ea | ||
|
|
669f87f812 | ||
|
|
6d529b34d2 | ||
|
|
3ec9e4f0db | ||
|
|
527477dea7 | ||
|
|
5b517c5048 | ||
|
|
2df0795be9 | ||
|
|
0dc97a4e9b | ||
|
|
6c0fd37222 | ||
|
|
d8d0fb6903 | ||
|
|
8d07efd921 | ||
|
|
ba27dd2a55 | ||
|
|
ed9769f651 | ||
|
|
b427e5d8b1 | ||
|
|
fabe362755 | ||
|
|
ee6b6a53d6 | ||
|
|
19f3a2fcc0 | ||
|
|
e09c4ba724 | ||
|
|
6624732a65 | ||
|
|
1cbde3ba22 | ||
|
|
551a6e34a5 | ||
|
|
c45bab491a | ||
|
|
7f35dc54e4 | ||
|
|
8f1a7dfd79 | ||
|
|
712c478af6 | ||
|
|
4d36b7f34f | ||
|
|
a16927aa16 | ||
|
|
d91f4f7ce4 | ||
|
|
a7fa0585eb | ||
|
|
a32eb66a6a | ||
|
|
f605f6e075 | ||
|
|
3b1aaaadf7 | ||
|
|
d015b12402 | ||
|
|
d5200074c2 | ||
|
|
473cdfc44d | ||
|
|
ad2e6c2900 | ||
|
|
9ea3133c27 | ||
|
|
e4c47c7540 | ||
|
|
f4fd4d2239 | ||
|
|
05dfffad5c | ||
|
|
a37b552a7d | ||
|
|
55e1158581 | ||
|
|
cfa824d7db | ||
|
|
bb5b437a32 | ||
|
|
851533dfa7 | ||
|
|
0d996edafe | ||
|
|
f291b71f4a | ||
|
|
2d75c2c405 | ||
|
|
1f2597be74 | ||
|
|
0f442e96c0 | ||
|
|
44e4be23a6 | ||
|
|
01f753f86d | ||
|
|
df766eaf79 | ||
|
|
09d9205696 | ||
|
|
dc7f01db4a | ||
|
|
eb4b49d552 | ||
|
|
98e3465e7b | ||
|
|
ea39024fd2 | ||
|
|
4e94feb872 | ||
|
|
fa802d2d3f |
@@ -0,0 +1,219 @@
|
||||
//! This module provides FastPair, a data-structure for efficiently tracking the dynamic
|
||||
//! closest pairs in a set of points, with an example usage in hierarchical clustering.[2][3][5]
|
||||
//!
|
||||
//! ## Purpose
|
||||
//!
|
||||
//! FastPair allows quick retrieval of the nearest neighbor for each data point by maintaining
|
||||
//! a "conga line" of closest pairs. Each point retains a link to its known nearest neighbor,
|
||||
//! and updates in the data structure propagate accordingly. This can be leveraged in
|
||||
//! agglomerative clustering steps, where merging or insertion of new points must be reflected
|
||||
//! in nearest-neighbor relationships.
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::distance::PairwiseDistance;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
//!
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let fastpair = FastPair::new(&x).unwrap();
|
||||
//! let closest = fastpair.closest_pair();
|
||||
//! println!("Closest pair: {:?}", closest);
|
||||
//! ```
|
||||
use std::collections::HashMap;
|
||||
|
||||
use num::Bounded;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::Euclidian;
|
||||
use crate::metrics::distance::PairwiseDistance;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
/// Eppstein dynamic closet-pair structure
|
||||
/// 'M' can be a matrix-like trait that provides row access
|
||||
#[derive(Debug)]
|
||||
pub struct EppsteinDCP<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
||||
samples: &'a M,
|
||||
// "buckets" store, for each row, a small structure recording potential neighbors
|
||||
neighbors: HashMap<usize, PairwiseDistance<T>>,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> EppsteinDCP<'a, T, M> {
|
||||
/// Creates a new EppsteinDCP instance with the given data
|
||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||
if m.shape().0 < 3 {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"min number of rows should be 3",
|
||||
));
|
||||
}
|
||||
|
||||
let mut this = Self {
|
||||
samples: m,
|
||||
neighbors: HashMap::with_capacity(m.shape().0),
|
||||
};
|
||||
this.initialize();
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Build an initial "conga line" or chain of potential neighbors
|
||||
/// akin to Eppstein’s technique[2].
|
||||
fn initialize(&mut self) {
|
||||
let n = self.samples.shape().0;
|
||||
if n < 2 {
|
||||
return;
|
||||
}
|
||||
// Assign each row i some large distance by default
|
||||
for i in 0..n {
|
||||
self.neighbors.insert(
|
||||
i,
|
||||
PairwiseDistance {
|
||||
node: i,
|
||||
neighbour: None,
|
||||
distance: Some(<T as Bounded>::max_value()),
|
||||
},
|
||||
);
|
||||
}
|
||||
// Example: link each i to the next, forming a chain
|
||||
// (depending on the actual Eppstein approach, can refine)
|
||||
for i in 0..(n - 1) {
|
||||
let dist = self.compute_dist(i, i + 1);
|
||||
self.neighbors.entry(i).and_modify(|pd| {
|
||||
pd.neighbour = Some(i + 1);
|
||||
pd.distance = Some(dist);
|
||||
});
|
||||
}
|
||||
// Potential refinement steps omitted for brevity
|
||||
}
|
||||
|
||||
/// Insert a point into the structure.
|
||||
pub fn insert(&mut self, row_idx: usize) {
|
||||
// Expand data, find neighbor to link with
|
||||
// For example, link row_idx to nearest among existing
|
||||
let mut best_neighbor = None;
|
||||
let mut best_d = <T as Bounded>::max_value();
|
||||
for (i, _) in &self.neighbors {
|
||||
let d = self.compute_dist(*i, row_idx);
|
||||
if d < best_d {
|
||||
best_d = d;
|
||||
best_neighbor = Some(*i);
|
||||
}
|
||||
}
|
||||
self.neighbors.insert(
|
||||
row_idx,
|
||||
PairwiseDistance {
|
||||
node: row_idx,
|
||||
neighbour: best_neighbor,
|
||||
distance: Some(best_d),
|
||||
},
|
||||
);
|
||||
// For the best_neighbor, you might want to see if row_idx becomes closer
|
||||
if let Some(kn) = best_neighbor {
|
||||
let dist = self.compute_dist(row_idx, kn);
|
||||
let entry = self.neighbors.get_mut(&kn).unwrap();
|
||||
if dist < entry.distance.unwrap() {
|
||||
entry.neighbour = Some(row_idx);
|
||||
entry.distance = Some(dist);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// For hierarchical clustering, discover minimal pairs, then merge
|
||||
pub fn closest_pair(&self) -> Option<PairwiseDistance<T>> {
|
||||
let mut min_pair: Option<PairwiseDistance<T>> = None;
|
||||
for (_, pd) in &self.neighbors {
|
||||
if let Some(d) = pd.distance {
|
||||
if min_pair.is_none() || d < min_pair.as_ref().unwrap().distance.unwrap() {
|
||||
min_pair = Some(pd.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
min_pair
|
||||
}
|
||||
|
||||
fn compute_dist(&self, i: usize, j: usize) -> T {
|
||||
// Example: Euclidean
|
||||
let row_i = self.samples.get_row(i);
|
||||
let row_j = self.samples.get_row(j);
|
||||
row_i
|
||||
.iterator(0)
|
||||
.zip(row_j.iterator(0))
|
||||
.map(|(a, b)| (*a - *b) * (*a - *b))
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple usage
|
||||
#[cfg(test)]
|
||||
mod tests_eppstein {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn test_eppstein() {
|
||||
let matrix =
|
||||
DenseMatrix::from_2d_array(&[&vec![1.0, 2.0], &vec![2.0, 2.0], &vec![5.0, 3.0]])
|
||||
.unwrap();
|
||||
let mut dcp = EppsteinDCP::new(&matrix).unwrap();
|
||||
dcp.insert(2);
|
||||
let cp = dcp.closest_pair();
|
||||
assert!(cp.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compare_fastpair_eppstein() {
|
||||
use crate::algorithm::neighbour::fastpair::FastPair;
|
||||
// Assuming EppsteinDCP is implemented in a similar module
|
||||
use crate::algorithm::neighbour::eppstein::EppsteinDCP;
|
||||
|
||||
// Create a static example matrix
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
// Build FastPair
|
||||
let fastpair = FastPair::new(&x).unwrap();
|
||||
let pair_fastpair = fastpair.closest_pair();
|
||||
|
||||
// Build EppsteinDCP
|
||||
let eppstein = EppsteinDCP::new(&x).unwrap();
|
||||
let pair_eppstein = eppstein.closest_pair();
|
||||
|
||||
// Compare the results
|
||||
assert_eq!(pair_fastpair.node, pair_eppstein.as_ref().unwrap().node);
|
||||
assert_eq!(
|
||||
pair_fastpair.neighbour.unwrap(),
|
||||
pair_eppstein.as_ref().unwrap().neighbour.unwrap()
|
||||
);
|
||||
|
||||
// Use a small epsilon for floating-point comparison
|
||||
let epsilon = 1e-9;
|
||||
let diff: f64 =
|
||||
pair_fastpair.distance.unwrap() - pair_eppstein.as_ref().unwrap().distance.unwrap();
|
||||
assert!(diff.abs() < epsilon);
|
||||
|
||||
println!("FastPair result: {:?}", pair_fastpair);
|
||||
println!("EppsteinDCP result: {:?}", pair_eppstein);
|
||||
}
|
||||
}
|
||||
@@ -173,6 +173,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Return order dissimilarities from closest to furthest
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||
let mut distances = self
|
||||
.distances
|
||||
.values()
|
||||
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
distances.into_iter()
|
||||
}
|
||||
|
||||
//
|
||||
// Compute distances from input to all other points in data-structure.
|
||||
// input is the row index of the sample matrix
|
||||
@@ -588,4 +603,103 @@ mod tests_fastpair {
|
||||
|
||||
assert_eq!(closest, min_dissimilarity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_ordered_pairs() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x).unwrap();
|
||||
|
||||
let ordered = fastpair.ordered_pairs();
|
||||
|
||||
let mut previous: f64 = -1.0;
|
||||
for p in ordered {
|
||||
if previous == -1.0 {
|
||||
previous = p.distance.unwrap();
|
||||
} else {
|
||||
let current = p.distance.unwrap();
|
||||
assert!(current >= previous);
|
||||
previous = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_set() {
|
||||
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
|
||||
let result = FastPair::new(&empty_matrix);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_point() {
|
||||
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
let result = FastPair::new(&single_point);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_points() {
|
||||
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||
let result = FastPair::new(&two_points);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_identical_points() {
|
||||
let identical_points =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
|
||||
let result = FastPair::new(&identical_points);
|
||||
assert!(result.is_ok());
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
assert_eq!(closest_pair.distance, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_unwrapping() {
|
||||
let valid_matrix =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
|
||||
.unwrap();
|
||||
|
||||
let result = FastPair::new(&valid_matrix);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should not panic
|
||||
let _fastpair = result.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,9 @@ use serde::{Deserialize, Serialize};
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// fastpair closest neighbour algorithm
|
||||
/// eppstein pairwise closest neighbour algorithm
|
||||
pub mod eppstein;
|
||||
/// fastpair pairwise closest neighbour algorithm
|
||||
pub mod fastpair;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
pub mod linear_search;
|
||||
|
||||
Reference in New Issue
Block a user