Disambiguate distances. Implement Fastpair. (#220)
This commit is contained in:
@@ -1,48 +0,0 @@
|
||||
//!
|
||||
//! Dissimilarities for vector-vector distance
|
||||
//!
|
||||
//! Representing distances as pairwise dissimilarities, so to build a
|
||||
//! graph of closest neighbours. This representation can be reused for
|
||||
//! different implementations (initially used in this library for FastPair).
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
///
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distsances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// index of the vector in the original `Matrix` or list
|
||||
pub node: usize,
|
||||
|
||||
/// index of the closest neighbor in the original `Matrix` or same list
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// measure of distance, according to the algorithm distance function
|
||||
/// if the distance is None, the edge has value "infinite" or max distance
|
||||
/// each algorithm has to match
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.node == other.node
|
||||
&& self.neighbour == other.neighbour
|
||||
&& self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
///
|
||||
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
///
|
||||
/// Reference:
|
||||
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||
@@ -7,8 +7,8 @@
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::metrics::distance::PairwiseDistance;
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
/// &[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -25,12 +25,14 @@
|
||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
||||
use num::Bounded;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::Euclidian;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::metrics::distance::PairwiseDistance;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
///
|
||||
/// Inspired by Python implementation:
|
||||
@@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
PairwiseDistance {
|
||||
node: index_row_i,
|
||||
neighbour: Option::None,
|
||||
distance: Some(T::MAX),
|
||||
distance: Some(<T as Bounded>::max_value()),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
);
|
||||
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row_i)),
|
||||
&(self.samples.get_row_as_vec(index_row_j)),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row_i).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row_j).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
);
|
||||
if d < nbd.unwrap() {
|
||||
if d < nbd.unwrap().to_f64().unwrap() {
|
||||
// set this j-value to be the closest neighbour
|
||||
index_closest = index_row_j;
|
||||
nbd = Some(d);
|
||||
nbd = Some(T::from(d).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
// No more neighbors, terminate conga line.
|
||||
// Last person on the line has no neigbors
|
||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
||||
|
||||
// compute sparse matrix (connectivity matrix)
|
||||
let mut sparse_matrix = M::zeros(len, len);
|
||||
@@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||
use itertools::Itertools;
|
||||
let m = self.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Option::None,
|
||||
distance: Some(T::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(pair[0])),
|
||||
&(self.samples.get_row_as_vec(pair[1])),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
}
|
||||
|
||||
//
|
||||
// Compute distances from input to all other points in data-structure.
|
||||
// input is the row index of the sample matrix
|
||||
@@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
distances.push(PairwiseDistance {
|
||||
node: index_row,
|
||||
neighbour: Some(*other),
|
||||
distance: Some(Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row)),
|
||||
&(self.samples.get_row_as_vec(*other)),
|
||||
)),
|
||||
distance: Some(
|
||||
T::from(Euclidian::squared_distance(
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(*other).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
))
|
||||
.unwrap(),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
mod tests_fastpair {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
|
||||
use itertools::Itertools;
|
||||
let m = fastpair.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Option::None,
|
||||
distance: Some(f64::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&Vec::from_iterator(
|
||||
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
|
||||
fastpair.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
|
||||
fastpair.samples.shape().1,
|
||||
),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_init() {
|
||||
@@ -284,7 +306,7 @@ mod tests_fastpair {
|
||||
};
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
|
||||
let closest_pair_brute = fastpair.closest_pair_brute();
|
||||
let closest_pair_brute = closest_pair_brute(&fastpair);
|
||||
assert_eq!(closest_pair_brute, expected_closest_pair);
|
||||
}
|
||||
|
||||
@@ -302,7 +324,7 @@ mod tests_fastpair {
|
||||
neighbour: Some(3),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, fastpair.closest_pair_brute());
|
||||
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
}
|
||||
|
||||
@@ -459,11 +481,16 @@ mod tests_fastpair {
|
||||
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
||||
|
||||
for i in 0..(x.shape().0 - 1) {
|
||||
let input_node = result.samples.get_row_as_vec(i);
|
||||
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
||||
let distance = Euclidian::squared_distance(
|
||||
&input_node,
|
||||
&result.samples.get_row_as_vec(input_neighbour),
|
||||
&Vec::from_iterator(
|
||||
result.samples.get_row(i).iterator(0).copied(),
|
||||
result.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
result.samples.get_row(input_neighbour).iterator(0).copied(),
|
||||
result.samples.shape().1,
|
||||
),
|
||||
);
|
||||
|
||||
assert_eq!(i, expected.get(&i).unwrap().node);
|
||||
@@ -518,7 +545,7 @@ mod tests_fastpair {
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
let dissimilarity1 = result.closest_pair();
|
||||
let dissimilarity2 = result.closest_pair_brute();
|
||||
let dissimilarity2 = closest_pair_brute(&result);
|
||||
|
||||
assert_eq!(dissimilarity1, dissimilarity2);
|
||||
}
|
||||
|
||||
@@ -41,10 +41,8 @@ use serde::{Deserialize, Serialize};
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
||||
pub mod distances;
|
||||
/// fastpair closest neighbour algorithm
|
||||
// pub mod fastpair;
|
||||
pub mod fastpair;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
pub mod linear_search;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user