Disambiguate distances. Implement Fastpair. (#220)
This commit is contained in:
@@ -24,9 +24,15 @@ pub mod manhattan;
|
||||
/// A generalization of both the Euclidean distance and the Manhattan distance.
|
||||
pub mod minkowski;
|
||||
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::traits::lu::LUDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Distance metric, a function that calculates distance between two points
|
||||
pub trait Distance<T>: Clone {
|
||||
@@ -66,3 +72,45 @@ impl Distances {
|
||||
mahalanobis::Mahalanobis::new(data)
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// ### Pairwise dissimilarities.
|
||||
///
|
||||
/// Represents distances as pairwise dissimilarities, in order to build a
|
||||
/// graph of closest neighbours. This representation can be reused for
|
||||
/// different implementations
|
||||
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// Index of the vector in the original `Matrix` or list.
|
||||
pub node: usize,
|
||||
|
||||
/// Index of the closest neighbour in the original `Matrix` or in the same list.
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// Measure of distance, according to the distance function of the algorithm.
|
||||
/// If the distance is `None`, the edge has the value "infinite" (maximum distance);
|
||||
/// each algorithm using this structure has to handle that case accordingly.
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
// Marker impl: declares that the `==` defined by the `PartialEq` impl for
// `PairwiseDistance` is a full equivalence relation.
// NOTE(review): `distance` is an `Option<T>`; if `T` can represent NaN, a stored NaN
// would make `x == x` false and break the reflexivity `Eq` promises — confirm that
// callers never store a NaN distance.
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
/// Two edges are equal only when `node`, `neighbour` and `distance` all match.
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element, left to right, and stops at
        // the first mismatch — the same order as comparing the fields one by one.
        let lhs = (&self.node, &self.neighbour, &self.distance);
        let rhs = (&other.node, &other.neighbour, &other.distance);
        lhs == rhs
    }
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
+18
-15
@@ -84,7 +84,7 @@ use std::marker::PhantomData;
|
||||
/// A trait to be implemented by all metrics
|
||||
pub trait Metrics<T> {
|
||||
/// instantiate a new Metrics trait-object
|
||||
/// https://doc.rust-lang.org/error-index.html#E0038
|
||||
/// <https://doc.rust-lang.org/error-index.html#E0038>
|
||||
fn new() -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
@@ -133,10 +133,10 @@ impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
|
||||
f1::F1::new_with(beta)
|
||||
}
|
||||
|
||||
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
// pub fn roc_auc_score() -> auc::AUC<T> {
|
||||
// auc::AUC::<T>::new()
|
||||
// }
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
pub fn roc_auc_score() -> auc::AUC<T> {
|
||||
auc::AUC::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
|
||||
@@ -212,16 +212,19 @@ pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
// /// AUC score, see [AUC](auc/index.html).
|
||||
// /// * `y_true` - cround truth (correct) labels.
|
||||
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>(
|
||||
// y_true: &V,
|
||||
// y_pred_probabilities: &V,
|
||||
// ) -> T {
|
||||
// let obj = ClassificationMetrics::<T>::roc_auc_score();
|
||||
// obj.get_score(y_true, y_pred_probabilities)
|
||||
// }
|
||||
/// AUC score, see [AUC](auc/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
pub fn roc_auc_score<
|
||||
T: Number + RealNumber + FloatNumber + PartialOrd,
|
||||
V: ArrayView1<T> + Array1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred_probabilities: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::roc_auc_score();
|
||||
obj.get_score(y_true, y_pred_probabilities)
|
||||
}
|
||||
|
||||
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
|
||||
Reference in New Issue
Block a user