Disambiguate distances. Implement Fastpair. (#220)

This commit is contained in:
Lorenzo
2022-11-02 14:53:28 +00:00
committed by morenol
parent 8f1a7dfd79
commit 7f35dc54e4
8 changed files with 171 additions and 135 deletions
+1
View File
@@ -35,6 +35,7 @@ js = ["getrandom/js"]
getrandom = { version = "0.2", optional = true } getrandom = { version = "0.2", optional = true }
[dev-dependencies] [dev-dependencies]
itertools = "*"
criterion = { version = "0.4", default-features = false } criterion = { version = "0.4", default-features = false }
serde_json = "1.0" serde_json = "1.0"
bincode = "1.3.1" bincode = "1.3.1"
-48
View File
@@ -1,48 +0,0 @@
//!
//! Dissimilarities for vector-vector distance
//!
//! Representing distances as pairwise dissimilarities, so to build a
//! graph of closest neighbours. This representation can be reused for
//! different implementations (initially used in this library for FastPair).
use std::cmp::{Eq, Ordering, PartialOrd};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::numbers::realnum::RealNumber;
///
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distsances as
/// a list of these structures.
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
/// index of the vector in the original `Matrix` or list
pub node: usize,
/// index of the closest neighbor in the original `Matrix` or same list
pub neighbour: Option<usize>,
/// measure of distance, according to the algorithm distance function
/// if the distance is None, the edge has value "infinite" or max distance
/// each algorithm has to match
pub distance: Option<T>,
}
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.neighbour == other.neighbour
&& self.distance == other.distance
}
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
+77 -50
View File
@@ -1,5 +1,5 @@
/// ///
/// # FastPair: Data-structure for the dynamic closest-pair problem. /// ### FastPair: Data-structure for the dynamic closest-pair problem.
/// ///
/// Reference: /// Reference:
/// Eppstein, David: Fast hierarchical clustering and other applications of /// Eppstein, David: Fast hierarchical clustering and other applications of
@@ -7,8 +7,8 @@
/// ///
/// Example: /// Example:
/// ``` /// ```
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance; /// use smartcore::metrics::distance::PairwiseDistance;
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix; /// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::algorithm::neighbour::fastpair::FastPair; /// use smartcore::algorithm::neighbour::fastpair::FastPair;
/// let x = DenseMatrix::<f64>::from_2d_array(&[ /// let x = DenseMatrix::<f64>::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2], /// &[5.1, 3.5, 1.4, 0.2],
@@ -25,12 +25,14 @@
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> /// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap; use std::collections::HashMap;
use crate::algorithm::neighbour::distances::PairwiseDistance; use num::Bounded;
use crate::error::{Failed, FailedError}; use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian; use crate::metrics::distance::euclidian::Euclidian;
use crate::numbers::realnum::RealNumber; use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber; use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
/// ///
/// Inspired by Python implementation: /// Inspired by Python implementation:
@@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
PairwiseDistance { PairwiseDistance {
node: index_row_i, node: index_row_i,
neighbour: Option::None, neighbour: Option::None,
distance: Some(T::MAX), distance: Some(<T as Bounded>::max_value()),
}, },
); );
} }
@@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
); );
let d = Euclidian::squared_distance( let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(index_row_i)), &Vec::from_iterator(
&(self.samples.get_row_as_vec(index_row_j)), self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
); );
if d < nbd.unwrap() { if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour // set this j-value to be the closest neighbour
index_closest = index_row_j; index_closest = index_row_j;
nbd = Some(d); nbd = Some(T::from(d).unwrap());
} }
} }
@@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
// No more neighbors, terminate conga line. // No more neighbors, terminate conga line.
// Last person on the line has no neigbors // Last person on the line has no neigbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value()); distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix) // compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len); let mut sparse_matrix = M::zeros(len, len);
@@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
} }
} }
///
/// Brute force algorithm, used only for comparison and testing
///
#[allow(dead_code)]
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
use itertools::Itertools;
let m = self.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(T::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&(self.samples.get_row_as_vec(pair[0])),
&(self.samples.get_row_as_vec(pair[1])),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
// //
// Compute distances from input to all other points in data-structure. // Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix // input is the row index of the sample matrix
@@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
distances.push(PairwiseDistance { distances.push(PairwiseDistance {
node: index_row, node: index_row,
neighbour: Some(*other), neighbour: Some(*other),
distance: Some(Euclidian::squared_distance( distance: Some(
&(self.samples.get_row_as_vec(index_row)), T::from(Euclidian::squared_distance(
&(self.samples.get_row_as_vec(*other)), &Vec::from_iterator(
)), self.samples.get_row(index_row).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(*other).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap(),
),
}) })
} }
} }
@@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
mod tests_fastpair { mod tests_fastpair {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
///
/// Brute force algorithm, used only for comparison and testing
///
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
use itertools::Itertools;
let m = fastpair.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&Vec::from_iterator(
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
fastpair.samples.shape().1,
),
&Vec::from_iterator(
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
fastpair.samples.shape().1,
),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
#[test] #[test]
fn fastpair_init() { fn fastpair_init() {
@@ -284,7 +306,7 @@ mod tests_fastpair {
}; };
assert_eq!(closest_pair, expected_closest_pair); assert_eq!(closest_pair, expected_closest_pair);
let closest_pair_brute = fastpair.closest_pair_brute(); let closest_pair_brute = closest_pair_brute(&fastpair);
assert_eq!(closest_pair_brute, expected_closest_pair); assert_eq!(closest_pair_brute, expected_closest_pair);
} }
@@ -302,7 +324,7 @@ mod tests_fastpair {
neighbour: Some(3), neighbour: Some(3),
distance: Some(4.0), distance: Some(4.0),
}; };
assert_eq!(closest_pair, fastpair.closest_pair_brute()); assert_eq!(closest_pair, closest_pair_brute(&fastpair));
assert_eq!(closest_pair, expected_closest_pair); assert_eq!(closest_pair, expected_closest_pair);
} }
@@ -459,11 +481,16 @@ mod tests_fastpair {
let expected: HashMap<_, _> = dissimilarities.into_iter().collect(); let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
for i in 0..(x.shape().0 - 1) { for i in 0..(x.shape().0 - 1) {
let input_node = result.samples.get_row_as_vec(i);
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap(); let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
let distance = Euclidian::squared_distance( let distance = Euclidian::squared_distance(
&input_node, &Vec::from_iterator(
&result.samples.get_row_as_vec(input_neighbour), result.samples.get_row(i).iterator(0).copied(),
result.samples.shape().1,
),
&Vec::from_iterator(
result.samples.get_row(input_neighbour).iterator(0).copied(),
result.samples.shape().1,
),
); );
assert_eq!(i, expected.get(&i).unwrap().node); assert_eq!(i, expected.get(&i).unwrap().node);
@@ -518,7 +545,7 @@ mod tests_fastpair {
let result = fastpair.unwrap(); let result = fastpair.unwrap();
let dissimilarity1 = result.closest_pair(); let dissimilarity1 = result.closest_pair();
let dissimilarity2 = result.closest_pair_brute(); let dissimilarity2 = closest_pair_brute(&result);
assert_eq!(dissimilarity1, dissimilarity2); assert_eq!(dissimilarity1, dissimilarity2);
} }
+1 -3
View File
@@ -41,10 +41,8 @@ use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree; pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search /// tree data structure for fast nearest neighbor search
pub mod cover_tree; pub mod cover_tree;
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
pub mod distances;
/// fastpair closest neighbour algorithm /// fastpair closest neighbour algorithm
// pub mod fastpair; pub mod fastpair;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search; pub mod linear_search;
+23 -16
View File
@@ -10,34 +10,30 @@
//! # SmartCore //! # SmartCore
//! //!
//! Welcome to SmartCore, the most advanced machine learning library in Rust! //! Welcome to SmartCore, machine learning in Rust!
//! //!
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN, //! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! as well as tools for model selection and model evaluation. //! as well as tools for model selection and model evaluation.
//! //!
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment, //! SmartCore provides its own traits system that extends Rust standard library, to deal with linear algebra and common
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages: //! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
//! * [ndarray](https://docs.rs/ndarray) //! structures) is available via optional features.
//! //!
//! ## Getting Started //! ## Getting Started
//! //!
//! To start using SmartCore simply add the following to your Cargo.toml file: //! To start using SmartCore simply add the following to your Cargo.toml file:
//! ```ignore //! ```ignore
//! [dependencies] //! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "v0.5-wip" } //! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
//! ``` //! ```
//! //!
//! All machine learning algorithms in SmartCore are grouped into these broad categories: //! ## Using Jupyter
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data. //! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition. //! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables //! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
//! * [Tree-based Models](tree/index.html), classification and regression trees
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
//! //!
//! //!
//! ## First Example
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector: //! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//! //!
//! ``` //! ```
@@ -48,14 +44,14 @@
//! // Various distance metrics //! // Various distance metrics
//! use smartcore::metrics::distance::*; //! use smartcore::metrics::distance::*;
//! //!
//! // Turn Rust vectors with samples into a matrix //! // Turn Rust vector-slices with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[ //! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.], //! &[1., 2.],
//! &[3., 4.], //! &[3., 4.],
//! &[5., 6.], //! &[5., 6.],
//! &[7., 8.], //! &[7., 8.],
//! &[9., 10.]]); //! &[9., 10.]]);
//! // Our classes are defined as a Vector //! // Our classes are defined as a vector
//! let y = vec![2, 2, 2, 3, 3]; //! let y = vec![2, 2, 2, 3, 3];
//! //!
//! // Train classifier //! // Train classifier
@@ -64,6 +60,17 @@
//! // Predict classes //! // Predict classes
//! let y_hat = knn.predict(&x).unwrap(); //! let y_hat = knn.predict(&x).unwrap();
//! ``` //! ```
//!
//! ## Overview
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
//! * [Tree-based Models](tree/index.html), classification and regression trees
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
/// Foundamental numbers traits /// Foundamental numbers traits
pub mod numbers; pub mod numbers;
+3 -3
View File
@@ -71,8 +71,8 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
x x
} }
/// (reference)[http://en.wikipedia.org/wiki/Arithmetic_mean] /// <http://en.wikipedia.org/wiki/Arithmetic_mean>
/// Taken from statistical /// Taken from `statistical`
/// The MIT License (MIT) /// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum /// Copyright (c) 2015 Jeff Belgum
fn _mean_of_vector(v: &[T]) -> T { fn _mean_of_vector(v: &[T]) -> T {
@@ -97,7 +97,7 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
sum sum
} }
/// (Sample variance)[http://en.wikipedia.org/wiki/Variance#Sample_variance] /// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
/// Taken from statistical /// Taken from statistical
/// The MIT License (MIT) /// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum /// Copyright (c) 2015 Jeff Belgum
+48
View File
@@ -24,9 +24,15 @@ pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance. /// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski; pub mod minkowski;
use std::cmp::{Eq, Ordering, PartialOrd};
use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::lu::LUDecomposable; use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Distance metric, a function that calculates distance between two points /// Distance metric, a function that calculates distance between two points
pub trait Distance<T>: Clone { pub trait Distance<T>: Clone {
@@ -66,3 +72,45 @@ impl Distances {
mahalanobis::Mahalanobis::new(data) mahalanobis::Mahalanobis::new(data)
} }
} }
///
/// ### Pairwise dissimilarities.
///
/// Representing distances as pairwise dissimilarities, so to build a
/// graph of closest neighbours. This representation can be reused for
/// different implementations
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distances as
/// a list of these structures.
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
/// index of the vector in the original `Matrix` or list
pub node: usize,
/// index of the closest neighbor in the original `Matrix` or same list
pub neighbour: Option<usize>,
/// measure of distance, according to the algorithm distance function
/// if the distance is None, the edge has value "infinite" or max distance
/// each algorithm has to match
pub distance: Option<T>,
}
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.neighbour == other.neighbour
&& self.distance == other.distance
}
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
+18 -15
View File
@@ -84,7 +84,7 @@ use std::marker::PhantomData;
/// A trait to be implemented by all metrics /// A trait to be implemented by all metrics
pub trait Metrics<T> { pub trait Metrics<T> {
/// instantiate a new Metrics trait-object /// instantiate a new Metrics trait-object
/// https://doc.rust-lang.org/error-index.html#E0038 /// <https://doc.rust-lang.org/error-index.html#E0038>
fn new() -> Self fn new() -> Self
where where
Self: Sized; Self: Sized;
@@ -133,10 +133,10 @@ impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
f1::F1::new_with(beta) f1::F1::new_with(beta)
} }
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html). /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
// pub fn roc_auc_score() -> auc::AUC<T> { pub fn roc_auc_score() -> auc::AUC<T> {
// auc::AUC::<T>::new() auc::AUC::<T>::new()
// } }
} }
impl<T: Number + Ord> ClassificationMetricsOrd<T> { impl<T: Number + Ord> ClassificationMetricsOrd<T> {
@@ -212,16 +212,19 @@ pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
obj.get_score(y_true, y_pred) obj.get_score(y_true, y_pred)
} }
// /// AUC score, see [AUC](auc/index.html). /// AUC score, see [AUC](auc/index.html).
// /// * `y_true` - cround truth (correct) labels. /// * `y_true` - cround truth (correct) labels.
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier. /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>( pub fn roc_auc_score<
// y_true: &V, T: Number + RealNumber + FloatNumber + PartialOrd,
// y_pred_probabilities: &V, V: ArrayView1<T> + Array1<T> + Array1<T>,
// ) -> T { >(
// let obj = ClassificationMetrics::<T>::roc_auc_score(); y_true: &V,
// obj.get_score(y_true, y_pred_probabilities) y_pred_probabilities: &V,
// } ) -> f64 {
let obj = ClassificationMetrics::<T>::roc_auc_score();
obj.get_score(y_true, y_pred_probabilities)
}
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html). /// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
/// * `y_true` - Ground truth (correct) target values. /// * `y_true` - Ground truth (correct) target values.