Disambiguate distances. Implement Fastpair. (#220)
This commit is contained in:
@@ -35,6 +35,7 @@ js = ["getrandom/js"]
|
|||||||
getrandom = { version = "0.2", optional = true }
|
getrandom = { version = "0.2", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
itertools = "*"
|
||||||
criterion = { version = "0.4", default-features = false }
|
criterion = { version = "0.4", default-features = false }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
bincode = "1.3.1"
|
bincode = "1.3.1"
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
//!
|
|
||||||
//! Dissimilarities for vector-vector distance
|
|
||||||
//!
|
|
||||||
//! Representing distances as pairwise dissimilarities, so to build a
|
|
||||||
//! graph of closest neighbours. This representation can be reused for
|
|
||||||
//! different implementations (initially used in this library for FastPair).
|
|
||||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
use crate::numbers::realnum::RealNumber;
|
|
||||||
|
|
||||||
///
|
|
||||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
|
||||||
/// The calling algorithm can store a list of distsances as
|
|
||||||
/// a list of these structures.
|
|
||||||
///
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct PairwiseDistance<T: RealNumber> {
|
|
||||||
/// index of the vector in the original `Matrix` or list
|
|
||||||
pub node: usize,
|
|
||||||
|
|
||||||
/// index of the closest neighbor in the original `Matrix` or same list
|
|
||||||
pub neighbour: Option<usize>,
|
|
||||||
|
|
||||||
/// measure of distance, according to the algorithm distance function
|
|
||||||
/// if the distance is None, the edge has value "infinite" or max distance
|
|
||||||
/// each algorithm has to match
|
|
||||||
pub distance: Option<T>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
|
||||||
|
|
||||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
self.node == other.node
|
|
||||||
&& self.neighbour == other.neighbour
|
|
||||||
&& self.distance == other.distance
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
||||||
self.distance.partial_cmp(&other.distance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
///
|
///
|
||||||
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
|
||||||
///
|
///
|
||||||
/// Reference:
|
/// Reference:
|
||||||
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||||
@@ -7,8 +7,8 @@
|
|||||||
///
|
///
|
||||||
/// Example:
|
/// Example:
|
||||||
/// ```
|
/// ```
|
||||||
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
/// use smartcore::metrics::distance::PairwiseDistance;
|
||||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||||
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
/// &[5.1, 3.5, 1.4, 0.2],
|
/// &[5.1, 3.5, 1.4, 0.2],
|
||||||
@@ -25,12 +25,14 @@
|
|||||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
use num::Bounded;
|
||||||
|
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::basic::arrays::Array2;
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
use crate::metrics::distance::euclidian::Euclidian;
|
use crate::metrics::distance::euclidian::Euclidian;
|
||||||
use crate::numbers::realnum::RealNumber;
|
use crate::metrics::distance::PairwiseDistance;
|
||||||
use crate::numbers::floatnum::FloatNumber;
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Inspired by Python implementation:
|
/// Inspired by Python implementation:
|
||||||
@@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
PairwiseDistance {
|
PairwiseDistance {
|
||||||
node: index_row_i,
|
node: index_row_i,
|
||||||
neighbour: Option::None,
|
neighbour: Option::None,
|
||||||
distance: Some(T::MAX),
|
distance: Some(<T as Bounded>::max_value()),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let d = Euclidian::squared_distance(
|
let d = Euclidian::squared_distance(
|
||||||
&(self.samples.get_row_as_vec(index_row_i)),
|
&Vec::from_iterator(
|
||||||
&(self.samples.get_row_as_vec(index_row_j)),
|
self.samples.get_row(index_row_i).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row_j).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
if d < nbd.unwrap() {
|
if d < nbd.unwrap().to_f64().unwrap() {
|
||||||
// set this j-value to be the closest neighbour
|
// set this j-value to be the closest neighbour
|
||||||
index_closest = index_row_j;
|
index_closest = index_row_j;
|
||||||
nbd = Some(d);
|
nbd = Some(T::from(d).unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
// No more neighbors, terminate conga line.
|
// No more neighbors, terminate conga line.
|
||||||
// Last person on the line has no neigbors
|
// Last person on the line has no neigbors
|
||||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
||||||
|
|
||||||
// compute sparse matrix (connectivity matrix)
|
// compute sparse matrix (connectivity matrix)
|
||||||
let mut sparse_matrix = M::zeros(len, len);
|
let mut sparse_matrix = M::zeros(len, len);
|
||||||
@@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
///
|
|
||||||
/// Brute force algorithm, used only for comparison and testing
|
|
||||||
///
|
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
|
||||||
use itertools::Itertools;
|
|
||||||
let m = self.samples.shape().0;
|
|
||||||
|
|
||||||
let mut closest_pair = PairwiseDistance {
|
|
||||||
node: 0,
|
|
||||||
neighbour: Option::None,
|
|
||||||
distance: Some(T::max_value()),
|
|
||||||
};
|
|
||||||
for pair in (0..m).combinations(2) {
|
|
||||||
let d = Euclidian::squared_distance(
|
|
||||||
&(self.samples.get_row_as_vec(pair[0])),
|
|
||||||
&(self.samples.get_row_as_vec(pair[1])),
|
|
||||||
);
|
|
||||||
if d < closest_pair.distance.unwrap() {
|
|
||||||
closest_pair.node = pair[0];
|
|
||||||
closest_pair.neighbour = Some(pair[1]);
|
|
||||||
closest_pair.distance = Some(d);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
closest_pair
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Compute distances from input to all other points in data-structure.
|
// Compute distances from input to all other points in data-structure.
|
||||||
// input is the row index of the sample matrix
|
// input is the row index of the sample matrix
|
||||||
@@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
distances.push(PairwiseDistance {
|
distances.push(PairwiseDistance {
|
||||||
node: index_row,
|
node: index_row,
|
||||||
neighbour: Some(*other),
|
neighbour: Some(*other),
|
||||||
distance: Some(Euclidian::squared_distance(
|
distance: Some(
|
||||||
&(self.samples.get_row_as_vec(index_row)),
|
T::from(Euclidian::squared_distance(
|
||||||
&(self.samples.get_row_as_vec(*other)),
|
&Vec::from_iterator(
|
||||||
)),
|
self.samples.get_row(index_row).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(*other).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
|||||||
mod tests_fastpair {
|
mod tests_fastpair {
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Brute force algorithm, used only for comparison and testing
|
||||||
|
///
|
||||||
|
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
|
||||||
|
use itertools::Itertools;
|
||||||
|
let m = fastpair.samples.shape().0;
|
||||||
|
|
||||||
|
let mut closest_pair = PairwiseDistance {
|
||||||
|
node: 0,
|
||||||
|
neighbour: Option::None,
|
||||||
|
distance: Some(f64::max_value()),
|
||||||
|
};
|
||||||
|
for pair in (0..m).combinations(2) {
|
||||||
|
let d = Euclidian::squared_distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
|
||||||
|
fastpair.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
|
||||||
|
fastpair.samples.shape().1,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
if d < closest_pair.distance.unwrap() {
|
||||||
|
closest_pair.node = pair[0];
|
||||||
|
closest_pair.neighbour = Some(pair[1]);
|
||||||
|
closest_pair.distance = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closest_pair
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn fastpair_init() {
|
fn fastpair_init() {
|
||||||
@@ -284,7 +306,7 @@ mod tests_fastpair {
|
|||||||
};
|
};
|
||||||
assert_eq!(closest_pair, expected_closest_pair);
|
assert_eq!(closest_pair, expected_closest_pair);
|
||||||
|
|
||||||
let closest_pair_brute = fastpair.closest_pair_brute();
|
let closest_pair_brute = closest_pair_brute(&fastpair);
|
||||||
assert_eq!(closest_pair_brute, expected_closest_pair);
|
assert_eq!(closest_pair_brute, expected_closest_pair);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,7 +324,7 @@ mod tests_fastpair {
|
|||||||
neighbour: Some(3),
|
neighbour: Some(3),
|
||||||
distance: Some(4.0),
|
distance: Some(4.0),
|
||||||
};
|
};
|
||||||
assert_eq!(closest_pair, fastpair.closest_pair_brute());
|
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
|
||||||
assert_eq!(closest_pair, expected_closest_pair);
|
assert_eq!(closest_pair, expected_closest_pair);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -459,11 +481,16 @@ mod tests_fastpair {
|
|||||||
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
||||||
|
|
||||||
for i in 0..(x.shape().0 - 1) {
|
for i in 0..(x.shape().0 - 1) {
|
||||||
let input_node = result.samples.get_row_as_vec(i);
|
|
||||||
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
||||||
let distance = Euclidian::squared_distance(
|
let distance = Euclidian::squared_distance(
|
||||||
&input_node,
|
&Vec::from_iterator(
|
||||||
&result.samples.get_row_as_vec(input_neighbour),
|
result.samples.get_row(i).iterator(0).copied(),
|
||||||
|
result.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
result.samples.get_row(input_neighbour).iterator(0).copied(),
|
||||||
|
result.samples.shape().1,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(i, expected.get(&i).unwrap().node);
|
assert_eq!(i, expected.get(&i).unwrap().node);
|
||||||
@@ -518,7 +545,7 @@ mod tests_fastpair {
|
|||||||
let result = fastpair.unwrap();
|
let result = fastpair.unwrap();
|
||||||
|
|
||||||
let dissimilarity1 = result.closest_pair();
|
let dissimilarity1 = result.closest_pair();
|
||||||
let dissimilarity2 = result.closest_pair_brute();
|
let dissimilarity2 = closest_pair_brute(&result);
|
||||||
|
|
||||||
assert_eq!(dissimilarity1, dissimilarity2);
|
assert_eq!(dissimilarity1, dissimilarity2);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,10 +41,8 @@ use serde::{Deserialize, Serialize};
|
|||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
|
||||||
pub mod distances;
|
|
||||||
/// fastpair closest neighbour algorithm
|
/// fastpair closest neighbour algorithm
|
||||||
// pub mod fastpair;
|
pub mod fastpair;
|
||||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
|
|
||||||
|
|||||||
+23
-16
@@ -10,34 +10,30 @@
|
|||||||
|
|
||||||
//! # SmartCore
|
//! # SmartCore
|
||||||
//!
|
//!
|
||||||
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
|
//! Welcome to SmartCore, machine learning in Rust!
|
||||||
//!
|
//!
|
||||||
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
|
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
|
||||||
//! as well as tools for model selection and model evaluation.
|
//! as well as tools for model selection and model evaluation.
|
||||||
//!
|
//!
|
||||||
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
|
//! SmartCore provides its own traits system that extends Rust standard library, to deal with linear algebra and common
|
||||||
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
|
//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
|
||||||
//! * [ndarray](https://docs.rs/ndarray)
|
//! structures) is available via optional features.
|
||||||
//!
|
//!
|
||||||
//! ## Getting Started
|
//! ## Getting Started
|
||||||
//!
|
//!
|
||||||
//! To start using SmartCore simply add the following to your Cargo.toml file:
|
//! To start using SmartCore simply add the following to your Cargo.toml file:
|
||||||
//! ```ignore
|
//! ```ignore
|
||||||
//! [dependencies]
|
//! [dependencies]
|
||||||
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "v0.5-wip" }
|
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
//! ## Using Jupyter
|
||||||
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
//! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
|
||||||
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
|
||||||
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
|
||||||
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
|
|
||||||
//! * [Tree-based Models](tree/index.html), classification and regression trees
|
|
||||||
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
|
|
||||||
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
|
|
||||||
//! * [SVM](svm/index.html), support vector machines
|
|
||||||
//!
|
//!
|
||||||
//!
|
//!
|
||||||
|
//! ## First Example
|
||||||
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
|
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
@@ -48,14 +44,14 @@
|
|||||||
//! // Various distance metrics
|
//! // Various distance metrics
|
||||||
//! use smartcore::metrics::distance::*;
|
//! use smartcore::metrics::distance::*;
|
||||||
//!
|
//!
|
||||||
//! // Turn Rust vectors with samples into a matrix
|
//! // Turn Rust vector-slices with samples into a matrix
|
||||||
//! let x = DenseMatrix::from_2d_array(&[
|
//! let x = DenseMatrix::from_2d_array(&[
|
||||||
//! &[1., 2.],
|
//! &[1., 2.],
|
||||||
//! &[3., 4.],
|
//! &[3., 4.],
|
||||||
//! &[5., 6.],
|
//! &[5., 6.],
|
||||||
//! &[7., 8.],
|
//! &[7., 8.],
|
||||||
//! &[9., 10.]]);
|
//! &[9., 10.]]);
|
||||||
//! // Our classes are defined as a Vector
|
//! // Our classes are defined as a vector
|
||||||
//! let y = vec![2, 2, 2, 3, 3];
|
//! let y = vec![2, 2, 2, 3, 3];
|
||||||
//!
|
//!
|
||||||
//! // Train classifier
|
//! // Train classifier
|
||||||
@@ -64,6 +60,17 @@
|
|||||||
//! // Predict classes
|
//! // Predict classes
|
||||||
//! let y_hat = knn.predict(&x).unwrap();
|
//! let y_hat = knn.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## Overview
|
||||||
|
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
||||||
|
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
||||||
|
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
||||||
|
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
||||||
|
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
|
||||||
|
//! * [Tree-based Models](tree/index.html), classification and regression trees
|
||||||
|
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
|
||||||
|
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
|
||||||
|
//! * [SVM](svm/index.html), support vector machines
|
||||||
|
|
||||||
/// Foundamental numbers traits
|
/// Foundamental numbers traits
|
||||||
pub mod numbers;
|
pub mod numbers;
|
||||||
|
|||||||
@@ -71,8 +71,8 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
|
|||||||
x
|
x
|
||||||
}
|
}
|
||||||
|
|
||||||
/// (reference)[http://en.wikipedia.org/wiki/Arithmetic_mean]
|
/// <http://en.wikipedia.org/wiki/Arithmetic_mean>
|
||||||
/// Taken from statistical
|
/// Taken from `statistical`
|
||||||
/// The MIT License (MIT)
|
/// The MIT License (MIT)
|
||||||
/// Copyright (c) 2015 Jeff Belgum
|
/// Copyright (c) 2015 Jeff Belgum
|
||||||
fn _mean_of_vector(v: &[T]) -> T {
|
fn _mean_of_vector(v: &[T]) -> T {
|
||||||
@@ -97,7 +97,7 @@ pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
|
|||||||
sum
|
sum
|
||||||
}
|
}
|
||||||
|
|
||||||
/// (Sample variance)[http://en.wikipedia.org/wiki/Variance#Sample_variance]
|
/// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
|
||||||
/// Taken from statistical
|
/// Taken from statistical
|
||||||
/// The MIT License (MIT)
|
/// The MIT License (MIT)
|
||||||
/// Copyright (c) 2015 Jeff Belgum
|
/// Copyright (c) 2015 Jeff Belgum
|
||||||
|
|||||||
@@ -24,9 +24,15 @@ pub mod manhattan;
|
|||||||
/// A generalization of both the Euclidean distance and the Manhattan distance.
|
/// A generalization of both the Euclidean distance and the Manhattan distance.
|
||||||
pub mod minkowski;
|
pub mod minkowski;
|
||||||
|
|
||||||
|
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||||
|
|
||||||
use crate::linalg::basic::arrays::Array2;
|
use crate::linalg::basic::arrays::Array2;
|
||||||
use crate::linalg::traits::lu::LUDecomposable;
|
use crate::linalg::traits::lu::LUDecomposable;
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Distance metric, a function that calculates distance between two points
|
/// Distance metric, a function that calculates distance between two points
|
||||||
pub trait Distance<T>: Clone {
|
pub trait Distance<T>: Clone {
|
||||||
@@ -66,3 +72,45 @@ impl Distances {
|
|||||||
mahalanobis::Mahalanobis::new(data)
|
mahalanobis::Mahalanobis::new(data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// ### Pairwise dissimilarities.
|
||||||
|
///
|
||||||
|
/// Representing distances as pairwise dissimilarities, so to build a
|
||||||
|
/// graph of closest neighbours. This representation can be reused for
|
||||||
|
/// different implementations
|
||||||
|
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
|
||||||
|
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||||
|
/// The calling algorithm can store a list of distances as
|
||||||
|
/// a list of these structures.
|
||||||
|
///
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct PairwiseDistance<T: RealNumber> {
|
||||||
|
/// index of the vector in the original `Matrix` or list
|
||||||
|
pub node: usize,
|
||||||
|
|
||||||
|
/// index of the closest neighbor in the original `Matrix` or same list
|
||||||
|
pub neighbour: Option<usize>,
|
||||||
|
|
||||||
|
/// measure of distance, according to the algorithm distance function
|
||||||
|
/// if the distance is None, the edge has value "infinite" or max distance
|
||||||
|
/// each algorithm has to match
|
||||||
|
pub distance: Option<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.node == other.node
|
||||||
|
&& self.neighbour == other.neighbour
|
||||||
|
&& self.distance == other.distance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
self.distance.partial_cmp(&other.distance)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+18
-15
@@ -84,7 +84,7 @@ use std::marker::PhantomData;
|
|||||||
/// A trait to be implemented by all metrics
|
/// A trait to be implemented by all metrics
|
||||||
pub trait Metrics<T> {
|
pub trait Metrics<T> {
|
||||||
/// instantiate a new Metrics trait-object
|
/// instantiate a new Metrics trait-object
|
||||||
/// https://doc.rust-lang.org/error-index.html#E0038
|
/// <https://doc.rust-lang.org/error-index.html#E0038>
|
||||||
fn new() -> Self
|
fn new() -> Self
|
||||||
where
|
where
|
||||||
Self: Sized;
|
Self: Sized;
|
||||||
@@ -133,10 +133,10 @@ impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
|
|||||||
f1::F1::new_with(beta)
|
f1::F1::new_with(beta)
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||||
// pub fn roc_auc_score() -> auc::AUC<T> {
|
pub fn roc_auc_score() -> auc::AUC<T> {
|
||||||
// auc::AUC::<T>::new()
|
auc::AUC::<T>::new()
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
|
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
|
||||||
@@ -212,16 +212,19 @@ pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
|||||||
obj.get_score(y_true, y_pred)
|
obj.get_score(y_true, y_pred)
|
||||||
}
|
}
|
||||||
|
|
||||||
// /// AUC score, see [AUC](auc/index.html).
|
/// AUC score, see [AUC](auc/index.html).
|
||||||
// /// * `y_true` - cround truth (correct) labels.
|
/// * `y_true` - cround truth (correct) labels.
|
||||||
// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||||
// pub fn roc_auc_score<T: Number + PartialOrd, V: ArrayView1<T> + Array1<T> + Array1<T>>(
|
pub fn roc_auc_score<
|
||||||
// y_true: &V,
|
T: Number + RealNumber + FloatNumber + PartialOrd,
|
||||||
// y_pred_probabilities: &V,
|
V: ArrayView1<T> + Array1<T> + Array1<T>,
|
||||||
// ) -> T {
|
>(
|
||||||
// let obj = ClassificationMetrics::<T>::roc_auc_score();
|
y_true: &V,
|
||||||
// obj.get_score(y_true, y_pred_probabilities)
|
y_pred_probabilities: &V,
|
||||||
// }
|
) -> f64 {
|
||||||
|
let obj = ClassificationMetrics::<T>::roc_auc_score();
|
||||||
|
obj.get_score(y_true, y_pred_probabilities)
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||||
/// * `y_true` - Ground truth (correct) target values.
|
/// * `y_true` - Ground truth (correct) target values.
|
||||||
|
|||||||
Reference in New Issue
Block a user