From 17dc9f3bbfef6adbd60db01c0b5f46e09768f88b Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Tue, 28 Jan 2025 00:48:08 +0000 Subject: [PATCH] Add ordered pairs for FastPair (#252) * Add ordered_pairs method to FastPair * add tests to fastpair --- src/algorithm/neighbour/fastpair.rs | 114 ++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index 4e99261..f494a7d 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -173,6 +173,21 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { } } + /// + /// Return order dissimilarities from closest to furthest + /// + #[allow(dead_code)] + pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance> { + // improvement: implement this to return `impl Iterator>` + // need to implement trait `Iterator` for `Vec<&PairwiseDistance>` + let mut distances = self + .distances + .values() + .collect::>>(); + distances.sort_by(|a, b| a.partial_cmp(b).unwrap()); + distances.into_iter() + } + // // Compute distances from input to all other points in data-structure. // input is the row index of the sample matrix @@ -588,4 +603,103 @@ mod tests_fastpair { assert_eq!(closest, min_dissimilarity); } + + #[test] + fn fastpair_ordered_pairs() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + ]) + .unwrap(); + let fastpair = FastPair::new(&x).unwrap(); + + let ordered = fastpair.ordered_pairs(); + + let mut previous: f64 = -1.0; + for p in ordered { + if previous == -1.0 { + previous = p.distance.unwrap(); + } else { + let current = p.distance.unwrap(); + assert!(current >= previous); + previous = current; + } + } + } + + #[test] + fn test_empty_set() { + let empty_matrix = DenseMatrix::::zeros(0, 0); + let result = FastPair::new(&empty_matrix); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + e, + Failed::because(FailedError::FindFailed, "min number of rows should be 3") + ); + } + } + + #[test] + fn test_single_point() { + let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap(); + let result = FastPair::new(&single_point); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + e, + Failed::because(FailedError::FindFailed, "min number of rows should be 3") + ); + } + } + + #[test] + fn test_two_points() { + let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap(); + let result = FastPair::new(&two_points); + assert!(result.is_err()); + if let Err(e) = result { + assert_eq!( + e, + Failed::because(FailedError::FindFailed, "min number of rows should be 3") + ); + } + } + + #[test] + fn test_three_identical_points() { + let identical_points = + DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap(); + let result = FastPair::new(&identical_points); + assert!(result.is_ok()); + let fastpair = result.unwrap(); + let closest_pair = fastpair.closest_pair(); + assert_eq!(closest_pair.distance, Some(0.0)); + } + + #[test] + fn test_result_unwrapping() { + let valid_matrix = + DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]]) + .unwrap(); + + let result = FastPair::new(&valid_matrix); + assert!(result.is_ok()); + + // This should not panic + let _fastpair = result.unwrap(); + } }