Disambiguate distances. Implement Fastpair. (#220)

This commit is contained in:
Lorenzo
2022-11-02 14:53:28 +00:00
committed by GitHub
parent 4b096ad558
commit b60329ca5d
8 changed files with 171 additions and 135 deletions
+1
View File
@@ -35,6 +35,7 @@ js = ["getrandom/js"]
getrandom = { version = "0.2", optional = true }
[dev-dependencies]
itertools = "*"
criterion = { version = "0.4", default-features = false }
serde_json = "1.0"
bincode = "1.3.1"
-48
View File
@@ -1,48 +0,0 @@
//!
//! Dissimilarities for vector-vector distance
//!
//! Representing distances as pairwise dissimilarities, so to build a
//! graph of closest neighbours. This representation can be reused for
//! different implementations (initially used in this library for FastPair).
use std::cmp::{Eq, Ordering, PartialOrd};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::numbers::realnum::RealNumber;
///
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distsances as
/// a list of these structures.
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
/// index of the vector in the original `Matrix` or list
pub node: usize,
/// index of the closest neighbor in the original `Matrix` or same list
pub neighbour: Option<usize>,
/// measure of distance, according to the algorithm distance function
/// if the distance is None, the edge has value "infinite" or max distance
/// each algorithm has to match
pub distance: Option<T>,
}
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.neighbour == other.neighbour
&& self.distance == other.distance
}
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
+77 -50
View File
@@ -1,5 +1,5 @@
///
/// # FastPair: Data-structure for the dynamic closest-pair problem.
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
///
/// Reference:
/// Eppstein, David: Fast hierarchical clustering and other applications of
@@ -7,8 +7,8 @@
///
/// Example:
/// ```
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
/// use smartcore::metrics::distance::PairwiseDistance;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
/// let x = DenseMatrix::<f64>::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
@@ -25,12 +25,14 @@
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap;
use crate::algorithm::neighbour::distances::PairwiseDistance;
use num::Bounded;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::numbers::realnum::RealNumber;
use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
///
/// Inspired by Python implementation:
@@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
PairwiseDistance {
node: index_row_i,
neighbour: Option::None,
distance: Some(T::MAX),
distance: Some(<T as Bounded>::max_value()),
},
);
}
@@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
);
let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(index_row_i)),
&(self.samples.get_row_as_vec(index_row_j)),
&Vec::from_iterator(
self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
);
if d < nbd.unwrap() {
if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour
index_closest = index_row_j;
nbd = Some(d);
nbd = Some(T::from(d).unwrap());
}
}
@@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
// No more neighbors, terminate conga line.
// Last person on the line has no neighbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len);
@@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
}
}
///
/// Brute force algorithm, used only for comparison and testing
///
#[allow(dead_code)]
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
use itertools::Itertools;
let m = self.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(T::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(pair[0])),
&(self.samples.get_row_as_vec(pair[1])),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
//
// Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix
@@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
distances.push(PairwiseDistance {
node: index_row,
neighbour: Some(*other),
distance: Some(Euclidian::squared_distance(
&(self.samples.get_row_as_vec(index_row)),
&(self.samples.get_row_as_vec(*other)),
)),
distance: Some(
T::from(Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(*other).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap(),
),
})
}
}
@@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
mod tests_fastpair {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
///
/// Brute force algorithm, used only for comparison and testing
///
/// Scans every unordered pair of rows in `fastpair.samples` (O(n^2) pairs)
/// and returns the pair with the smallest squared Euclidean distance as a
/// `PairwiseDistance`. Serves as the reference oracle the `FastPair`
/// structure is checked against in the tests below.
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
    use itertools::Itertools;
    // number of samples (rows) in the underlying matrix
    let m = fastpair.samples.shape().0;
    // seed with a sentinel "infinite" distance so the first real pair wins
    // NOTE(review): `f64::max_value()` — if this resolves to the std inherent
    // method it is deprecated in favor of `f64::MAX`; confirm whether a `num`
    // trait method is intended here instead.
    let mut closest_pair = PairwiseDistance {
        node: 0,
        neighbour: Option::None,
        distance: Some(f64::max_value()),
    };
    // iterate every unordered pair {i, j} with i < j
    for pair in (0..m).combinations(2) {
        // materialize both rows as Vec<f64> for the distance function
        let d = Euclidian::squared_distance(
            &Vec::from_iterator(
                fastpair.samples.get_row(pair[0]).iterator(0).copied(),
                fastpair.samples.shape().1,
            ),
            &Vec::from_iterator(
                fastpair.samples.get_row(pair[1]).iterator(0).copied(),
                fastpair.samples.shape().1,
            ),
        );
        // strictly-smaller check keeps the first-seen pair on ties
        if d < closest_pair.distance.unwrap() {
            closest_pair.node = pair[0];
            closest_pair.neighbour = Some(pair[1]);
            closest_pair.distance = Some(d);
        }
    }
    closest_pair
}
#[test]
fn fastpair_init() {
@@ -284,7 +306,7 @@ mod tests_fastpair {
};
assert_eq!(closest_pair, expected_closest_pair);
let closest_pair_brute = fastpair.closest_pair_brute();
let closest_pair_brute = closest_pair_brute(&fastpair);
assert_eq!(closest_pair_brute, expected_closest_pair);
}
@@ -302,7 +324,7 @@ mod tests_fastpair {
neighbour: Some(3),
distance: Some(4.0),
};
assert_eq!(closest_pair, fastpair.closest_pair_brute());
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
assert_eq!(closest_pair, expected_closest_pair);
}
@@ -459,11 +481,16 @@ mod tests_fastpair {
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
for i in 0..(x.shape().0 - 1) {
let input_node = result.samples.get_row_as_vec(i);
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
let distance = Euclidian::squared_distance(
&input_node,
&result.samples.get_row_as_vec(input_neighbour),
&Vec::from_iterator(
result.samples.get_row(i).iterator(0).copied(),
result.samples.shape().1,
),
&Vec::from_iterator(
result.samples.get_row(input_neighbour).iterator(0).copied(),
result.samples.shape().1,
),
);
assert_eq!(i, expected.get(&i).unwrap().node);
@@ -518,7 +545,7 @@ mod tests_fastpair {
let result = fastpair.unwrap();
let dissimilarity1 = result.closest_pair();
let dissimilarity2 = result.closest_pair_brute();
let dissimilarity2 = closest_pair_brute(&result);
assert_eq!(dissimilarity1, dissimilarity2);
}
+1 -3
View File
@@ -41,10 +41,8 @@ use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search
pub mod cover_tree;
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
pub mod distances;
/// fastpair closest neighbour algorithm
// pub mod fastpair;
pub mod fastpair;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search;
+23 -16
View File
@@ -10,34 +10,30 @@
//! # SmartCore
//!
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
//! Welcome to SmartCore, machine learning in Rust!
//!
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! as well as tools for model selection and model evaluation.
//!
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
//! * [ndarray](https://docs.rs/ndarray)
//! SmartCore provides its own traits system that extends Rust standard library, to deal with linear algebra and common
//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
//! structures) are available via optional features.
//!
//! ## Getting Started
//!
//! To start using SmartCore simply add the following to your Cargo.toml file:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "v0.5-wip" }
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
//! ```
//!
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
//! * [Tree-based Models](tree/index.html), classification and regression trees
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
//! ## Using Jupyter
//! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
//!
//!
//! ## First Example
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//!
//! ```
@@ -48,14 +44,14 @@
//! // Various distance metrics
//! use smartcore::metrics::distance::*;
//!
//! // Turn Rust vectors with samples into a matrix
//! // Turn Rust vector-slices with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]);
//! // Our classes are defined as a Vector
//! // Our classes are defined as a vector
//! let y = vec![2, 2, 2, 3, 3];
//!
//! // Train classifier
@@ -64,6 +60,17 @@
//! // Predict classes
//! let y_hat = knn.predict(&x).unwrap();
//! ```
//!
//! ## Overview
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
//! * [Tree-based Models](tree/index.html), classification and regression trees
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
/// Fundamental numbers traits
pub mod numbers;
+3 -3
View File
@@ -71,8 +71,8 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
x
}
/// (reference)[http://en.wikipedia.org/wiki/Arithmetic_mean]
/// Taken from statistical
/// <http://en.wikipedia.org/wiki/Arithmetic_mean>
/// Taken from `statistical`
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _mean_of_vector(v: &[T]) -> T {
@@ -97,7 +97,7 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
sum
}
/// (Sample variance)[http://en.wikipedia.org/wiki/Variance#Sample_variance]
/// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
/// Taken from statistical
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
+48
View File
@@ -24,9 +24,15 @@ pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use std::cmp::{Eq, Ordering, PartialOrd};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T>: Clone {
@@ -66,3 +72,45 @@ impl Distances {
mahalanobis::Mahalanobis::new(data)
}
}
///
/// ### Pairwise dissimilarities.
///
/// Representing distances as pairwise dissimilarities, so as to build a
/// graph of closest neighbours. This representation can be reused for
/// different implementations
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distances as
/// a list of these structures.
///
// `Copy` is derived, so edges are cheap to pass by value (requires `T: Copy`,
// which the derive enforces per instantiation).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
    /// index of the vector in the original `Matrix` or list
    pub node: usize,
    /// index of the closest neighbor in the original `Matrix` or same list;
    /// `None` means no neighbour has been assigned yet
    pub neighbour: Option<usize>,
    /// measure of distance, according to the algorithm distance function
    /// if the distance is None, the edge has value "infinite" or max distance
    /// each algorithm has to match
    pub distance: Option<T>,
}
// Marker impl promoting `PartialEq` to total equality.
// NOTE(review): `T: RealNumber` is typically a float type; a NaN stored in
// `distance` would break `Eq`'s reflexivity guarantee (`a == a`). Confirm NaN
// cannot reach `distance` (e.g. distances are non-negative squared norms).
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
    /// Two edges are equal only when every field matches: same origin
    /// node, same neighbour, and same measured distance.
    fn eq(&self, other: &Self) -> bool {
        let same_node = self.node == other.node;
        let same_neighbour = self.neighbour == other.neighbour;
        let same_distance = self.distance == other.distance;
        same_node && same_neighbour && same_distance
    }
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
    /// Orders edges primarily by `distance` (`None` sorts before any
    /// `Some` via `Option`'s ordering; incomparable floats yield `None`).
    ///
    /// Ties on distance are broken by `node`, then `neighbour`, so that
    /// `partial_cmp` returns `Some(Equal)` exactly when `PartialEq::eq`
    /// returns `true` — the consistency the `PartialOrd` contract requires.
    /// The previous impl compared `distance` alone, so two distinct edges
    /// with equal distances compared as `Equal` while `==` said `false`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        match self.distance.partial_cmp(&other.distance) {
            Some(Ordering::Equal) => Some(
                self.node
                    .cmp(&other.node)
                    .then(self.neighbour.cmp(&other.neighbour)),
            ),
            ord => ord,
        }
    }
}
+18 -15
View File
@@ -84,7 +84,7 @@ use std::marker::PhantomData;
/// A trait to be implemented by all metrics
pub trait Metrics<T> {
/// instantiate a new Metrics trait-object
/// https://doc.rust-lang.org/error-index.html#E0038
/// <https://doc.rust-lang.org/error-index.html#E0038>
fn new() -> Self
where
Self: Sized;
@@ -133,10 +133,10 @@ impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
f1::F1::new_with(beta)
}
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
// pub fn roc_auc_score() -> auc::AUC<T> {
// auc::AUC::<T>::new()
// }
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
pub fn roc_auc_score() -> auc::AUC<T> {
auc::AUC::<T>::new()
}
}
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
@@ -212,16 +212,19 @@ pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
obj.get_score(y_true, y_pred)
}
// /// AUC score, see [AUC](auc/index.html).
// /// * `y_true` - cround truth (correct) labels.
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>(
// y_true: &V,
// y_pred_probabilities: &V,
// ) -> T {
// let obj = ClassificationMetrics::<T>::roc_auc_score();
// obj.get_score(y_true, y_pred_probabilities)
// }
/// AUC score, see [AUC](auc/index.html).
///
/// Computes the Area Under the Receiver Operating Characteristic Curve by
/// delegating to [`ClassificationMetrics::roc_auc_score`].
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
// Fix: the `V` bound previously repeated `Array1<T>` twice
// (`ArrayView1<T> + Array1<T> + Array1<T>`); the duplicate is removed.
pub fn roc_auc_score<
    T: Number + RealNumber + FloatNumber + PartialOrd,
    V: ArrayView1<T> + Array1<T>,
>(
    y_true: &V,
    y_pred_probabilities: &V,
) -> f64 {
    let obj = ClassificationMetrics::<T>::roc_auc_score();
    obj.get_score(y_true, y_pred_probabilities)
}
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
/// * `y_true` - Ground truth (correct) target values.