Disambiguate distances. Implement Fastpair. (#220)

This commit is contained in:
Lorenzo
2022-11-02 14:53:28 +00:00
committed by GitHub
parent 4b096ad558
commit b60329ca5d
8 changed files with 171 additions and 135 deletions
+77 -50
View File
@@ -1,5 +1,5 @@
///
/// # FastPair: Data-structure for the dynamic closest-pair problem.
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
///
/// Reference:
/// Eppstein, David: Fast hierarchical clustering and other applications of
@@ -7,8 +7,8 @@
///
/// Example:
/// ```
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
/// use smartcore::metrics::distance::PairwiseDistance;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
/// let x = DenseMatrix::<f64>::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
@@ -25,12 +25,14 @@
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap;
use crate::algorithm::neighbour::distances::PairwiseDistance;
use num::Bounded;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::numbers::realnum::RealNumber;
use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
///
/// Inspired by Python implementation:
@@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
PairwiseDistance {
node: index_row_i,
neighbour: Option::None,
distance: Some(T::MAX),
distance: Some(<T as Bounded>::max_value()),
},
);
}
@@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
);
let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(index_row_i)),
&(self.samples.get_row_as_vec(index_row_j)),
&Vec::from_iterator(
self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
);
if d < nbd.unwrap() {
if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour
index_closest = index_row_j;
nbd = Some(d);
nbd = Some(T::from(d).unwrap());
}
}
@@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
// No more neighbors, terminate conga line.
// Last person on the line has no neigbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len);
@@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
}
}
///
/// Brute force algorithm, used only for comparison and testing
///
#[allow(dead_code)]
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
use itertools::Itertools;
let m = self.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(T::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(pair[0])),
&(self.samples.get_row_as_vec(pair[1])),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
//
// Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix
@@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
distances.push(PairwiseDistance {
node: index_row,
neighbour: Some(*other),
distance: Some(Euclidian::squared_distance(
&(self.samples.get_row_as_vec(index_row)),
&(self.samples.get_row_as_vec(*other)),
)),
distance: Some(
T::from(Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(*other).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap(),
),
})
}
}
@@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
mod tests_fastpair {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
///
/// Brute force algorithm, used only for comparison and testing
///
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
use itertools::Itertools;
let m = fastpair.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&Vec::from_iterator(
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
fastpair.samples.shape().1,
),
&Vec::from_iterator(
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
fastpair.samples.shape().1,
),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
#[test]
fn fastpair_init() {
@@ -284,7 +306,7 @@ mod tests_fastpair {
};
assert_eq!(closest_pair, expected_closest_pair);
let closest_pair_brute = fastpair.closest_pair_brute();
let closest_pair_brute = closest_pair_brute(&fastpair);
assert_eq!(closest_pair_brute, expected_closest_pair);
}
@@ -302,7 +324,7 @@ mod tests_fastpair {
neighbour: Some(3),
distance: Some(4.0),
};
assert_eq!(closest_pair, fastpair.closest_pair_brute());
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
assert_eq!(closest_pair, expected_closest_pair);
}
@@ -459,11 +481,16 @@ mod tests_fastpair {
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
for i in 0..(x.shape().0 - 1) {
let input_node = result.samples.get_row_as_vec(i);
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
let distance = Euclidian::squared_distance(
&input_node,
&result.samples.get_row_as_vec(input_neighbour),
&Vec::from_iterator(
result.samples.get_row(i).iterator(0).copied(),
result.samples.shape().1,
),
&Vec::from_iterator(
result.samples.get_row(input_neighbour).iterator(0).copied(),
result.samples.shape().1,
),
);
assert_eq!(i, expected.get(&i).unwrap().node);
@@ -518,7 +545,7 @@ mod tests_fastpair {
let result = fastpair.unwrap();
let dissimilarity1 = result.closest_pair();
let dissimilarity2 = result.closest_pair_brute();
let dissimilarity2 = closest_pair_brute(&result);
assert_eq!(dissimilarity1, dissimilarity2);
}