Disambiguate distances. Implement Fastpair. (#220)

This commit is contained in:
Lorenzo
2022-11-02 14:53:28 +00:00
committed by GitHub
parent 4b096ad558
commit b60329ca5d
8 changed files with 171 additions and 135 deletions
+48
View File
@@ -24,9 +24,15 @@ pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use std::cmp::{Eq, Ordering, PartialOrd};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T>: Clone {
@@ -66,3 +72,45 @@ impl Distances {
mahalanobis::Mahalanobis::new(data)
}
}
/// An edge of a closest-neighbour graph, expressed as a pairwise dissimilarity.
///
/// Distances are represented as pairwise dissimilarities so that a graph of
/// closest neighbours can be built from them. The representation is meant to
/// be reusable by different implementations; it was initially used in this
/// library for [FastPair](algorithm/neighbour/fastpair).
///
/// A calling algorithm can store its distances as a list of these structures.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T>
where
    T: RealNumber,
{
    /// Index of this vector in the original `Matrix` or list.
    pub node: usize,
    /// Index of the closest neighbour in the original `Matrix` or in the same
    /// list.
    pub neighbour: Option<usize>,
    /// Distance measured with the algorithm's distance function.
    ///
    /// `None` means the edge has an "infinite" (maximum) distance; each
    /// algorithm has to match on that case itself.
    pub distance: Option<T>,
}
/// Marker impl declaring the `PartialEq` relation of `PairwiseDistance` to be a
/// full equivalence relation.
// NOTE(review): `distance` is an `Option<T>`; if `T` can hold NaN (as floating
// point types do), a value with a NaN distance is not equal to itself, which
// `Eq` promises never happens. Confirm callers never store NaN distances.
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
/// Two edges are equal only when `node`, `neighbour` and `distance` all match.
impl<T> PartialEq for PairwiseDistance<T>
where
    T: RealNumber,
{
    fn eq(&self, other: &Self) -> bool {
        // Endpoints first, then the weight of the edge.
        let same_endpoints = self.node == other.node && self.neighbour == other.neighbour;
        same_endpoints && self.distance == other.distance
    }
}
/// Orders edges by their `distance` only; `node` and `neighbour` are ignored.
///
/// The comparison is the one `Option<T>` provides, so a `None` distance is
/// ordered before every `Some(_)` distance.
// NOTE(review): because only `distance` is compared, two edges with the same
// distance but different endpoints compare as `Some(Ordering::Equal)` here
// while `PartialEq::eq` reports them as different. Also, `None` sorting lowest
// is at odds with the field documentation describing `None` as "infinite".
// Confirm that callers rely on neither before changing this.
impl<T> PartialOrd for PairwiseDistance<T>
where
    T: RealNumber,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.distance, &other.distance)
    }
}
+18 -15
View File
@@ -84,7 +84,7 @@ use std::marker::PhantomData;
/// A trait to be implemented by all metrics
pub trait Metrics<T> {
/// instantiate a new Metrics trait-object
/// https://doc.rust-lang.org/error-index.html#E0038
/// <https://doc.rust-lang.org/error-index.html#E0038>
fn new() -> Self
where
Self: Sized;
@@ -133,10 +133,10 @@ impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
f1::F1::new_with(beta)
}
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
// pub fn roc_auc_score() -> auc::AUC<T> {
// auc::AUC::<T>::new()
// }
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
///
/// Returns a new `auc::AUC<T>` metric object created with `auc::AUC::<T>::new()`.
pub fn roc_auc_score() -> auc::AUC<T> {
auc::AUC::<T>::new()
}
}
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
@@ -212,16 +212,19 @@ pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
obj.get_score(y_true, y_pred)
}
// /// AUC score, see [AUC](auc/index.html).
// /// * `y_true` - ground truth (correct) labels.
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>(
// y_true: &V,
// y_pred_probabilities: &V,
// ) -> T {
// let obj = ClassificationMetrics::<T>::roc_auc_score();
// obj.get_score(y_true, y_pred_probabilities)
// }
/// AUC score, see [AUC](auc/index.html).
///
/// Builds the metric with `ClassificationMetrics::<T>::roc_auc_score()` and
/// returns the result of its `get_score` call on the two inputs.
///
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
pub fn roc_auc_score<
    T: Number + RealNumber + FloatNumber + PartialOrd,
    V: ArrayView1<T> + Array1<T>,
>(
    y_true: &V,
    y_pred_probabilities: &V,
) -> f64 {
    let metric = ClassificationMetrics::<T>::roc_auc_score();
    metric.get_score(y_true, y_pred_probabilities)
}
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
/// * `y_true` - Ground truth (correct) target values.