Add ordered pairs for FastPair (#252)
* Add ordered_pairs method to FastPair * add tests to fastpair
This commit is contained in:
@@ -173,6 +173,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Return order dissimilarities from closest to furthest
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||
let mut distances = self
|
||||
.distances
|
||||
.values()
|
||||
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
distances.into_iter()
|
||||
}
|
||||
|
||||
//
|
||||
// Compute distances from input to all other points in data-structure.
|
||||
// input is the row index of the sample matrix
|
||||
@@ -588,4 +603,103 @@ mod tests_fastpair {
|
||||
|
||||
assert_eq!(closest, min_dissimilarity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_ordered_pairs() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x).unwrap();
|
||||
|
||||
let ordered = fastpair.ordered_pairs();
|
||||
|
||||
let mut previous: f64 = -1.0;
|
||||
for p in ordered {
|
||||
if previous == -1.0 {
|
||||
previous = p.distance.unwrap();
|
||||
} else {
|
||||
let current = p.distance.unwrap();
|
||||
assert!(current >= previous);
|
||||
previous = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_set() {
|
||||
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
|
||||
let result = FastPair::new(&empty_matrix);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_point() {
|
||||
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
let result = FastPair::new(&single_point);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_points() {
|
||||
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||
let result = FastPair::new(&two_points);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_identical_points() {
|
||||
let identical_points =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
|
||||
let result = FastPair::new(&identical_points);
|
||||
assert!(result.is_ok());
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
assert_eq!(closest_pair.distance, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_unwrapping() {
|
||||
let valid_matrix =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
|
||||
.unwrap();
|
||||
|
||||
let result = FastPair::new(&valid_matrix);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should not panic
|
||||
let _fastpair = result.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user