Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e633afa520 | ||
|
|
b6e32fb328 | ||
|
|
948d78a4d0 | ||
|
|
448b6f77e3 | ||
|
|
09be4681cf |
@@ -0,0 +1,41 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
message: "If this software contributes to published work, please cite smartcore."
|
||||||
|
type: software
|
||||||
|
title: "smartcore: Machine Learning in Rust"
|
||||||
|
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration." [web:5][web:3]
|
||||||
|
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
|
||||||
|
url: "https://github.com/smartcorelib" [web:3]
|
||||||
|
license: "MIT" [web:13]
|
||||||
|
keywords:
|
||||||
|
- Rust
|
||||||
|
- machine learning
|
||||||
|
- numerical computing
|
||||||
|
- linear algebra
|
||||||
|
- classification
|
||||||
|
- regression
|
||||||
|
- clustering
|
||||||
|
- SVM
|
||||||
|
- Random Forest
|
||||||
|
- XGBoost [web:5]
|
||||||
|
authors:
|
||||||
|
- name: "smartcore Developers" [web:7]
|
||||||
|
- name: "Lorenzo (contributor)" [web:16]
|
||||||
|
- name: "Community contributors" [web:7]
|
||||||
|
version: "0.4.2" [attached_file:1]
|
||||||
|
date-released: "2025-09-14" [attached_file:1]
|
||||||
|
preferred-citation:
|
||||||
|
type: software
|
||||||
|
title: "smartcore: Machine Learning in Rust"
|
||||||
|
authors:
|
||||||
|
- name: "smartcore Developers" [web:7]
|
||||||
|
url: "https://github.com/smartcorelib" [web:3]
|
||||||
|
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
|
||||||
|
license: "MIT" [web:13]
|
||||||
|
references:
|
||||||
|
- type: manual
|
||||||
|
title: "smartcore Documentation"
|
||||||
|
url: "https://docs.rs/smartcore" [web:5]
|
||||||
|
- type: webpage
|
||||||
|
title: "smartcore Homepage"
|
||||||
|
url: "https://github.com/smartcorelib" [web:3]
|
||||||
|
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost." [web:5]
|
||||||
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.2"
|
version = "0.4.4"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -16,6 +16,132 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
-----
|
-----
|
||||||
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
|
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [](https://doi.org/10.5281/zenodo.17219259)
|
||||||
|
|
||||||
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||||
|
|
||||||
|
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
|
||||||
|
|
||||||
|
|
||||||
|
## Highlights
|
||||||
|
|
||||||
|
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
|
||||||
|
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
|
||||||
|
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
|
||||||
|
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
|
||||||
|
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Add to Cargo.toml:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[dependencies]
|
||||||
|
smartcore = "^0.4.3"
|
||||||
|
```
|
||||||
|
|
||||||
|
For the latest development branch:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[dependencies]
|
||||||
|
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional features (examples):
|
||||||
|
|
||||||
|
- datasets
|
||||||
|
- serde
|
||||||
|
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
|
||||||
|
|
||||||
|
Check Cargo.toml for available features and compatibility notes.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use smartcore::neighbors::knn_classifier::KNNClassifier;
|
||||||
|
|
||||||
|
// Turn vector slices into a matrix
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2.],
|
||||||
|
&[3., 4.],
|
||||||
|
&[5., 6.],
|
||||||
|
&[7., 8.],
|
||||||
|
&[9., 10.],
|
||||||
|
]).unwrap;
|
||||||
|
|
||||||
|
// Class labels
|
||||||
|
let y = vec![2, 2, 2, 3, 3];
|
||||||
|
|
||||||
|
// Train classifier
|
||||||
|
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
// Predict
|
||||||
|
let yhat = knn.predict(&x).unwrap();
|
||||||
|
```
|
||||||
|
|
||||||
|
This example mirrors the “First Example” section of the crate docs and demonstrates smartcore’s ergonomic API surface.
|
||||||
|
|
||||||
|
## Algorithms
|
||||||
|
|
||||||
|
smartcore organizes algorithms into clear modules with consistent traits:
|
||||||
|
|
||||||
|
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
|
||||||
|
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
|
||||||
|
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
|
||||||
|
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
|
||||||
|
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
|
||||||
|
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
|
||||||
|
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
|
||||||
|
- Preprocessing: encoders, split utilities, and common transforms.
|
||||||
|
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
|
||||||
|
|
||||||
|
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
|
||||||
|
|
||||||
|
## Data access and readers
|
||||||
|
|
||||||
|
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
|
||||||
|
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
|
||||||
|
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
|
||||||
|
|
||||||
|
|
||||||
|
## WebAssembly and portability
|
||||||
|
|
||||||
|
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
|
||||||
|
|
||||||
|
## Notebooks
|
||||||
|
|
||||||
|
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
|
||||||
|
|
||||||
|
## Roadmap and recent changes
|
||||||
|
|
||||||
|
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
|
||||||
|
- Move to Rust 2021 edition and cleanup of duplicate code paths.
|
||||||
|
- Seeds and deterministic controls across algorithms using RNG plumbing.
|
||||||
|
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
|
||||||
|
- Tree and forest components refactored for reuse; Extra Trees added.
|
||||||
|
- SVM multiclass support; SVR kernel enum and related improvements.
|
||||||
|
- XGBoost-style regression introduced; single-linkage clustering implemented.
|
||||||
|
|
||||||
|
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome:
|
||||||
|
|
||||||
|
- Open an issue describing the change and link it in the PR.
|
||||||
|
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
|
||||||
|
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
|
||||||
|
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
|
||||||
|
|
||||||
|
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
smartcore’s design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
|
||||||
|
|||||||
@@ -0,0 +1,777 @@
|
|||||||
|
///
|
||||||
|
/// ### CosinePair: Data-structure for the dynamic closest-pair problem.
|
||||||
|
///
|
||||||
|
/// Reference:
|
||||||
|
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||||
|
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::metrics::distance::PairwiseDistance;
|
||||||
|
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
/// use smartcore::algorithm::neighbour::cosinepair::CosinePair;
|
||||||
|
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
/// &[5.1, 3.5, 1.4, 0.2],
|
||||||
|
/// &[4.9, 3.0, 1.4, 0.2],
|
||||||
|
/// &[4.7, 3.2, 1.3, 0.2],
|
||||||
|
/// &[4.6, 3.1, 1.5, 0.2],
|
||||||
|
/// &[5.0, 3.6, 1.4, 0.2],
|
||||||
|
/// &[5.4, 3.9, 1.7, 0.4],
|
||||||
|
/// ]).unwrap();
|
||||||
|
/// let cosinepair = CosinePair::new(&x);
|
||||||
|
/// let closest_pair: PairwiseDistance<f64> = cosinepair.unwrap().closest_pair();
|
||||||
|
/// ```
|
||||||
|
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use num::Bounded;
|
||||||
|
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::metrics::distance::cosine::Cosine;
|
||||||
|
use crate::metrics::distance::{Distance, PairwiseDistance};
|
||||||
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Inspired by Python implementation:
|
||||||
|
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
||||||
|
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
|
||||||
|
///
|
||||||
|
/// affinity used is Cosine as it is the most used
|
||||||
|
///
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
||||||
|
/// initial matrix
|
||||||
|
pub samples: &'a M,
|
||||||
|
/// closest pair hashmap (connectivity matrix for closest pairs)
|
||||||
|
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
||||||
|
/// conga line used to keep track of the closest pair
|
||||||
|
pub neighbours: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
|
||||||
|
/// Constructor
|
||||||
|
/// Instantiate and initialize the algorithm
|
||||||
|
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||||
|
if m.shape().0 < 2 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"min number of rows should be 2",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut init = Self {
|
||||||
|
samples: m,
|
||||||
|
// to be computed in init(..)
|
||||||
|
distances: HashMap::with_capacity(m.shape().0),
|
||||||
|
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
||||||
|
};
|
||||||
|
init.init();
|
||||||
|
Ok(init)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialise `CosinePair` by passing a `Array2`.
|
||||||
|
/// Build a CosinePairs data-structure from a set of (new) points.
|
||||||
|
fn init(&mut self) {
|
||||||
|
// basic measures
|
||||||
|
let len = self.samples.shape().0;
|
||||||
|
let max_index = self.samples.shape().0 - 1;
|
||||||
|
|
||||||
|
// Store all closest neighbors
|
||||||
|
let _distances = Box::new(HashMap::with_capacity(len));
|
||||||
|
let _neighbours = Box::new(Vec::with_capacity(len));
|
||||||
|
|
||||||
|
let mut distances = *_distances;
|
||||||
|
let mut neighbours = *_neighbours;
|
||||||
|
|
||||||
|
// fill neighbours with -1 values
|
||||||
|
neighbours.extend(0..len);
|
||||||
|
|
||||||
|
// init closest neighbour pairwise data
|
||||||
|
for index_row_i in 0..(max_index) {
|
||||||
|
distances.insert(
|
||||||
|
index_row_i,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_i,
|
||||||
|
neighbour: Option::None,
|
||||||
|
distance: Some(<T as Bounded>::max_value()),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop through indeces and neighbours
|
||||||
|
for index_row_i in 0..(len) {
|
||||||
|
// start looking for the neighbour in the second element
|
||||||
|
let mut index_closest = index_row_i + 1; // closest neighbour index
|
||||||
|
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
||||||
|
for index_row_j in (index_row_i + 1)..len {
|
||||||
|
distances.insert(
|
||||||
|
index_row_j,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_j,
|
||||||
|
neighbour: Some(index_row_i),
|
||||||
|
distance: nbd,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let d = Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row_i).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row_j).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
if d < nbd.unwrap().to_f64().unwrap() {
|
||||||
|
// set this j-value to be the closest neighbour
|
||||||
|
index_closest = index_row_j;
|
||||||
|
nbd = Some(T::from(d).unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add that edge
|
||||||
|
distances.entry(index_row_i).and_modify(|e| {
|
||||||
|
e.distance = nbd;
|
||||||
|
e.neighbour = Some(index_closest);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// No more neighbors, terminate conga line.
|
||||||
|
// Last person on the line has no neigbors
|
||||||
|
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||||
|
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
||||||
|
|
||||||
|
// compute sparse matrix (connectivity matrix)
|
||||||
|
let mut sparse_matrix = M::zeros(len, len);
|
||||||
|
for (_, p) in distances.iter() {
|
||||||
|
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.distances = distances;
|
||||||
|
self.neighbours = neighbours;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query k nearest neighbors for a row that's already in the dataset
|
||||||
|
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
if query_row_index >= self.samples.shape().0 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"Query row index out of bounds",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if k == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get distances to all other points
|
||||||
|
let mut distances = self.distances_from(query_row_index);
|
||||||
|
|
||||||
|
// Sort by distance (ascending)
|
||||||
|
distances.sort_by(|a, b| {
|
||||||
|
a.distance
|
||||||
|
.unwrap()
|
||||||
|
.partial_cmp(&b.distance.unwrap())
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Take top k neighbors and convert to (distance, index) format
|
||||||
|
let neighbors: Vec<(T, usize)> = distances
|
||||||
|
.into_iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|pd| (pd.distance.unwrap(), pd.neighbour.unwrap()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(neighbors)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query k nearest neighbors for an external query vector
|
||||||
|
pub fn query(&self, query_vector: &Vec<T>, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
if query_vector.len() != self.samples.shape().1 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"Query vector dimension mismatch",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if k == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute distances from query vector to all points in the dataset
|
||||||
|
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||||
|
|
||||||
|
for i in 0..self.samples.shape().0 {
|
||||||
|
let dataset_point = Vec::from_iterator(
|
||||||
|
self.samples.get_row(i).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let distance = T::from(Cosine::new().distance(query_vector, &dataset_point)).unwrap();
|
||||||
|
|
||||||
|
distances.push(PairwiseDistance {
|
||||||
|
node: i, // This represents the dataset point index
|
||||||
|
neighbour: Some(i),
|
||||||
|
distance: Some(distance),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by distance (ascending)
|
||||||
|
distances.sort_by(|a, b| {
|
||||||
|
a.distance
|
||||||
|
.unwrap()
|
||||||
|
.partial_cmp(&b.distance.unwrap())
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Take top k neighbors and convert to (distance, index) format
|
||||||
|
let neighbors: Vec<(T, usize)> = distances
|
||||||
|
.into_iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|pd| (pd.distance.unwrap(), pd.node))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(neighbors)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Optimized version that reuses the existing distances_from method
|
||||||
|
/// This is more efficient for queries that are points already in the dataset
|
||||||
|
pub fn query_optimized(
|
||||||
|
&self,
|
||||||
|
query_row_index: usize,
|
||||||
|
k: usize,
|
||||||
|
) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
// Reuse existing method and sort the results
|
||||||
|
self.query_row(query_row_index, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find closest pair by scanning list of nearest neighbors.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||||
|
let mut a = self.neighbours[0]; // Start with first point
|
||||||
|
let mut d = self.distances[&a].distance;
|
||||||
|
for p in self.neighbours.iter() {
|
||||||
|
if self.distances[p].distance < d {
|
||||||
|
a = *p; // Update `a` and distance `d`
|
||||||
|
d = self.distances[p].distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let b = self.distances[&a].neighbour;
|
||||||
|
PairwiseDistance {
|
||||||
|
node: a,
|
||||||
|
neighbour: b,
|
||||||
|
distance: d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Return order dissimilarities from closest to furthest
|
||||||
|
///
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||||
|
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||||
|
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||||
|
let mut distances = self
|
||||||
|
.distances
|
||||||
|
.values()
|
||||||
|
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||||
|
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||||
|
distances.into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Compute distances from input to all other points in data-structure.
|
||||||
|
// input is the row index of the sample matrix
|
||||||
|
//
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
|
||||||
|
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||||
|
for other in self.neighbours.iter() {
|
||||||
|
if index_row != *other {
|
||||||
|
distances.push(PairwiseDistance {
|
||||||
|
node: index_row,
|
||||||
|
neighbour: Some(*other),
|
||||||
|
distance: Some(
|
||||||
|
T::from(Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(*other).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
distances
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||||
|
use approx::assert_relative_eq;
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_initialization() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x);
|
||||||
|
|
||||||
|
assert!(cosine_pair.is_ok());
|
||||||
|
let cp = cosine_pair.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cp.samples.shape().0, 6);
|
||||||
|
assert_eq!(cp.distances.len(), 6);
|
||||||
|
assert_eq!(cp.neighbours.len(), 6);
|
||||||
|
assert!(!cp.distances.is_empty());
|
||||||
|
assert!(!cp.neighbours.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_minimum_rows_error() {
|
||||||
|
// Test with only one row - should fail
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[&[5.1, 3.5, 1.4, 0.2]]).unwrap();
|
||||||
|
|
||||||
|
let result = CosinePair::new(&x);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
let expected_error =
|
||||||
|
Failed::because(FailedError::FindFailed, "min number of rows should be 2");
|
||||||
|
assert_eq!(e, expected_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_closest_pair() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0],
|
||||||
|
&[0.0, 1.0],
|
||||||
|
&[1.0, 1.0],
|
||||||
|
&[2.0, 2.0], // This should be closest to [1.0, 1.0] with cosine distance
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let closest_pair = cosine_pair.closest_pair();
|
||||||
|
|
||||||
|
// Verify structure
|
||||||
|
assert!(closest_pair.distance.is_some());
|
||||||
|
assert!(closest_pair.neighbour.is_some());
|
||||||
|
|
||||||
|
// The closest pair should have the smallest cosine distance
|
||||||
|
let distance = closest_pair.distance.unwrap();
|
||||||
|
assert!(distance >= 0.0 && distance <= 2.0); // Cosine distance range
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_identical_vectors() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 2.0, 3.0],
|
||||||
|
&[1.0, 2.0, 3.0], // Identical vector
|
||||||
|
&[4.0, 5.0, 6.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let closest_pair = cosine_pair.closest_pair();
|
||||||
|
|
||||||
|
// Distance between identical vectors should be 0
|
||||||
|
let distance = closest_pair.distance.unwrap();
|
||||||
|
assert!((distance - 0.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_orthogonal_vectors() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0],
|
||||||
|
&[0.0, 1.0], // Orthogonal to first
|
||||||
|
&[2.0, 3.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Check that orthogonal vectors have cosine distance of 1.0
|
||||||
|
let distances_from_first = cosine_pair.distances_from(0);
|
||||||
|
let orthogonal_distance = distances_from_first
|
||||||
|
.iter()
|
||||||
|
.find(|pd| pd.neighbour == Some(1))
|
||||||
|
.unwrap()
|
||||||
|
.distance
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!((orthogonal_distance - 1.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_ordered_pairs() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 2.0],
|
||||||
|
&[2.0, 1.0],
|
||||||
|
&[3.0, 4.0],
|
||||||
|
&[4.0, 3.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let ordered_pairs: Vec<_> = cosine_pair.ordered_pairs().collect();
|
||||||
|
|
||||||
|
assert_eq!(ordered_pairs.len(), 4);
|
||||||
|
|
||||||
|
// Check that pairs are ordered by distance (ascending)
|
||||||
|
for i in 1..ordered_pairs.len() {
|
||||||
|
let prev_distance = ordered_pairs[i - 1].distance.unwrap();
|
||||||
|
let curr_distance = ordered_pairs[i].distance.unwrap();
|
||||||
|
assert!(prev_distance <= curr_distance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_row() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0, 0.0],
|
||||||
|
&[0.0, 1.0, 0.0],
|
||||||
|
&[0.0, 0.0, 1.0],
|
||||||
|
&[1.0, 1.0, 0.0],
|
||||||
|
&[0.0, 1.0, 1.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Query k=2 nearest neighbors for row 0
|
||||||
|
let neighbors = cosine_pair.query_row(0, 2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(neighbors.len(), 2);
|
||||||
|
|
||||||
|
// Check that distances are in ascending order
|
||||||
|
assert!(neighbors[0].0 <= neighbors[1].0);
|
||||||
|
|
||||||
|
// All distances should be valid cosine distances (0 to 2)
|
||||||
|
for (distance, _) in &neighbors {
|
||||||
|
assert!(*distance >= 0.0 && *distance <= 2.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_row_bounds_error() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Query with out-of-bounds row index
|
||||||
|
let result = cosine_pair.query_row(5, 1);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
if let Err(e) = result {
|
||||||
|
let expected_error =
|
||||||
|
Failed::because(FailedError::FindFailed, "Query row index out of bounds");
|
||||||
|
assert_eq!(e, expected_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_row_k_zero() {
|
||||||
|
let x =
|
||||||
|
DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let neighbors = cosine_pair.query_row(0, 0).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(neighbors.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_external_vector() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0, 0.0],
|
||||||
|
&[0.0, 1.0, 0.0],
|
||||||
|
&[0.0, 0.0, 1.0],
|
||||||
|
&[1.0, 1.0, 0.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Query with external vector
|
||||||
|
let query_vector = vec![1.0, 0.5, 0.0];
|
||||||
|
let neighbors = cosine_pair.query(&query_vector, 2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(neighbors.len(), 2);
|
||||||
|
|
||||||
|
// Verify distances are valid and ordered
|
||||||
|
assert!(neighbors[0].0 <= neighbors[1].0);
|
||||||
|
for (distance, index) in &neighbors {
|
||||||
|
assert!(*distance >= 0.0 && *distance <= 2.0);
|
||||||
|
assert!(*index < x.shape().0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_dimension_mismatch() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Query with mismatched dimensions
|
||||||
|
let query_vector = vec![1.0, 2.0]; // Only 2 dimensions, but data has 3
|
||||||
|
let result = cosine_pair.query(&query_vector, 1);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(e) = result {
|
||||||
|
let expected_error =
|
||||||
|
Failed::because(FailedError::FindFailed, "Query vector dimension mismatch");
|
||||||
|
assert_eq!(e, expected_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_query_k_zero_external() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let query_vector = vec![1.0, 1.0];
|
||||||
|
let neighbors = cosine_pair.query(&query_vector, 0).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(neighbors.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_large_dataset() {
|
||||||
|
// Test with larger dataset (similar to Iris)
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(cosine_pair.samples.shape().0, 15);
|
||||||
|
assert_eq!(cosine_pair.distances.len(), 15);
|
||||||
|
assert_eq!(cosine_pair.neighbours.len(), 15);
|
||||||
|
|
||||||
|
// Test closest pair computation
|
||||||
|
let closest_pair = cosine_pair.closest_pair();
|
||||||
|
assert!(closest_pair.distance.is_some());
|
||||||
|
assert!(closest_pair.neighbour.is_some());
|
||||||
|
|
||||||
|
let distance = closest_pair.distance.unwrap();
|
||||||
|
assert!(distance >= 0.0 && distance <= 2.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_float_precision() {
|
||||||
|
// Test with f32 precision
|
||||||
|
let x = DenseMatrix::<f32>::from_2d_array(&[
|
||||||
|
&[1.0f32, 2.0, 3.0],
|
||||||
|
&[4.0f32, 5.0, 6.0],
|
||||||
|
&[7.0f32, 8.0, 9.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let closest_pair = cosine_pair.closest_pair();
|
||||||
|
|
||||||
|
assert!(closest_pair.distance.is_some());
|
||||||
|
let distance = closest_pair.distance.unwrap();
|
||||||
|
assert!(distance >= 0.0 && distance <= 2.0);
|
||||||
|
|
||||||
|
// Test querying
|
||||||
|
let neighbors = cosine_pair.query_row(0, 2).unwrap();
|
||||||
|
assert_eq!(neighbors.len(), 2);
|
||||||
|
assert_eq!(neighbors[0].1, 1);
|
||||||
|
assert_relative_eq!(neighbors[0].0, 0.025368154);
|
||||||
|
assert_eq!(neighbors[1].1, 2);
|
||||||
|
assert_relative_eq!(neighbors[1].0, 0.040588055);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_distances_from() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0],
|
||||||
|
&[0.0, 1.0],
|
||||||
|
&[1.0, 1.0],
|
||||||
|
&[2.0, 0.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let distances = cosine_pair.distances_from(0);
|
||||||
|
|
||||||
|
// Should have 3 distances (excluding self)
|
||||||
|
assert_eq!(distances.len(), 3);
|
||||||
|
|
||||||
|
// All should be from node 0
|
||||||
|
for pd in &distances {
|
||||||
|
assert_eq!(pd.node, 0);
|
||||||
|
assert!(pd.neighbour.is_some());
|
||||||
|
assert!(pd.distance.is_some());
|
||||||
|
assert!(pd.neighbour.unwrap() != 0); // Should not include self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_consistency_check() {
|
||||||
|
// Verify that different query methods return consistent results
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 2.0, 3.0],
|
||||||
|
&[4.0, 5.0, 6.0],
|
||||||
|
&[7.0, 8.0, 9.0],
|
||||||
|
&[2.0, 3.0, 4.0],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
|
||||||
|
// Query row 0 using internal method
|
||||||
|
let neighbors_internal = cosine_pair.query_row(0, 2).unwrap();
|
||||||
|
|
||||||
|
// Query row 0 using optimized method (should be same)
|
||||||
|
let neighbors_optimized = cosine_pair.query_optimized(0, 2).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(neighbors_internal.len(), neighbors_optimized.len());
|
||||||
|
for i in 0..neighbors_internal.len() {
|
||||||
|
let (dist1, idx1) = neighbors_internal[i];
|
||||||
|
let (dist2, idx2) = neighbors_optimized[i];
|
||||||
|
assert!((dist1 - dist2).abs() < 1e-10);
|
||||||
|
assert_eq!(idx1, idx2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Brute force algorithm for testing/comparison
|
||||||
|
fn closest_pair_brute_force(
|
||||||
|
cosine_pair: &CosinePair<'_, f64, DenseMatrix<f64>>,
|
||||||
|
) -> PairwiseDistance<f64> {
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
let m = cosine_pair.samples.shape().0;
|
||||||
|
let mut closest_pair = PairwiseDistance {
|
||||||
|
node: 0,
|
||||||
|
neighbour: None,
|
||||||
|
distance: Some(f64::MAX),
|
||||||
|
};
|
||||||
|
|
||||||
|
for pair in (0..m).combinations(2) {
|
||||||
|
let d = Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
cosine_pair.samples.get_row(pair[0]).iterator(0).copied(),
|
||||||
|
cosine_pair.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
cosine_pair.samples.get_row(pair[1]).iterator(0).copied(),
|
||||||
|
cosine_pair.samples.shape().1,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
if d < closest_pair.distance.unwrap() {
|
||||||
|
closest_pair.node = pair[0];
|
||||||
|
closest_pair.neighbour = Some(pair[1]);
|
||||||
|
closest_pair.distance = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
closest_pair
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_pair_vs_brute_force() {
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 2.0, 3.0],
|
||||||
|
&[4.0, 5.0, 6.0],
|
||||||
|
&[7.0, 8.0, 9.0],
|
||||||
|
&[1.1, 2.1, 3.1], // Close to first point
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let cosine_pair = CosinePair::new(&x).unwrap();
|
||||||
|
let cp_result = cosine_pair.closest_pair();
|
||||||
|
let brute_result = closest_pair_brute_force(&cosine_pair);
|
||||||
|
|
||||||
|
// Results should be identical or very close
|
||||||
|
assert!((cp_result.distance.unwrap() - brute_result.distance.unwrap()).abs() < 1e-10);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -39,6 +39,8 @@ use crate::numbers::basenum::Number;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
|
/// a variant of fastpair using cosine distance
|
||||||
|
pub mod cosinepair;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
/// fastpair closest neighbour algorithm
|
/// fastpair closest neighbour algorithm
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use num_traits::Num;
|
use num_traits::Num;
|
||||||
|
|
||||||
pub trait QuickArgSort {
|
pub trait QuickArgSort {
|
||||||
|
#[allow(dead_code)]
|
||||||
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
//! # Cosine Distance Metric
|
||||||
|
//!
|
||||||
|
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
|
||||||
|
//!
|
||||||
|
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
|
||||||
|
//!
|
||||||
|
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
|
||||||
|
//! are their respective magnitudes (Euclidean norms).
|
||||||
|
//!
|
||||||
|
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
|
||||||
|
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
|
||||||
|
//! greater angular separation.
|
||||||
|
//!
|
||||||
|
//! Example:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::metrics::distance::Distance;
|
||||||
|
//! use smartcore::metrics::distance::cosine::Cosine;
|
||||||
|
//!
|
||||||
|
//! let x = vec![1., 1.];
|
||||||
|
//! let y = vec![2., 2.];
|
||||||
|
//!
|
||||||
|
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use crate::linalg::basic::arrays::ArrayView1;
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
|
||||||
|
use super::Distance;
|
||||||
|
|
||||||
|
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
|
||||||
|
/// It is defined as 1 minus the cosine similarity of the vectors.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Cosine<T> {
|
||||||
|
_t: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Default for Cosine<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Cosine<T> {
|
||||||
|
/// Instantiate the initial structure
|
||||||
|
pub fn new() -> Cosine<T> {
|
||||||
|
Cosine { _t: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
if x.shape() != y.shape() {
|
||||||
|
panic!("Input vector sizes are different.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the built-in dot product method from ArrayView1 trait
|
||||||
|
x.dot(y).to_f64().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the squared magnitude (norm squared) of a vector
|
||||||
|
#[inline]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
x.iterator(0)
|
||||||
|
.map(|&a| {
|
||||||
|
let val = a.to_f64().unwrap();
|
||||||
|
val * val
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
// Use the built-in norm2 method from ArrayView1 trait
|
||||||
|
x.norm2()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate cosine similarity between two vectors
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
let dot_product = Self::dot_product(x, y);
|
||||||
|
let magnitude_x = Self::magnitude(x);
|
||||||
|
let magnitude_y = Self::magnitude(y);
|
||||||
|
|
||||||
|
if magnitude_x == 0.0 || magnitude_y == 0.0 {
|
||||||
|
return f64::MIN;
|
||||||
|
}
|
||||||
|
|
||||||
|
dot_product / (magnitude_x * magnitude_y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
|
||||||
|
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||||
|
let similarity = Cosine::cosine_similarity(x, y);
|
||||||
|
1.0 - similarity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_identical_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 0.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_orthogonal_vectors() {
|
||||||
|
let a = vec![1, 0];
|
||||||
|
let b = vec![0, 1];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 1.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_opposite_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![-1, -2, -3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 2.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_general_case() {
|
||||||
|
let a = vec![1.0, 2.0, 3.0];
|
||||||
|
let b = vec![2.0, 1.0, 3.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
|
||||||
|
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
|
||||||
|
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
|
||||||
|
let expected_dist = 1.0 - (13.0 / 14.0);
|
||||||
|
assert!((dist - expected_dist).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Input vector sizes are different.")]
|
||||||
|
fn cosine_distance_different_sizes() {
|
||||||
|
let a = vec![1, 2];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let _dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_zero_vector() {
|
||||||
|
let a = vec![0, 0, 0];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
assert!(dist > 1e300)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_float_precision() {
|
||||||
|
let a = vec![1.0f32, 2.0, 3.0];
|
||||||
|
let b = vec![4.0f32, 5.0, 6.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Calculate expected value manually
|
||||||
|
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
|
||||||
|
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
|
||||||
|
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
|
||||||
|
let expected_similarity = dot_product / (mag_a * mag_b);
|
||||||
|
let expected_distance = 1.0 - expected_similarity;
|
||||||
|
|
||||||
|
assert!((dist - expected_distance).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,8 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
/// Cosine distance
|
||||||
|
pub mod cosine;
|
||||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||||
pub mod euclidian;
|
pub mod euclidian;
|
||||||
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ pub trait LineSearchMethod<T: Float> {
|
|||||||
/// Find alpha that satisfies strong Wolfe conditions.
|
/// Find alpha that satisfies strong Wolfe conditions.
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
df: &(dyn Fn(T) -> T),
|
df: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
@@ -55,8 +55,8 @@ impl<T: Float> Default for Backtracking<T> {
|
|||||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
_: &(dyn Fn(T) -> T),
|
_: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
|
|||||||
@@ -674,15 +674,20 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
) -> bool {
|
) -> bool {
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
let mut label = Option::None;
|
let mut label = None;
|
||||||
let mut is_pure = true;
|
let mut is_pure = true;
|
||||||
for i in 0..n_rows {
|
for i in 0..n_rows {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if label.is_none() {
|
match label {
|
||||||
label = Option::Some(visitor.y[i]);
|
None => {
|
||||||
} else if visitor.y[i] != label.unwrap() {
|
label = Some(visitor.y[i]);
|
||||||
is_pure = false;
|
}
|
||||||
break;
|
Some(current_label) => {
|
||||||
|
if visitor.y[i] != current_label {
|
||||||
|
is_pure = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ impl Objective {
|
|||||||
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
|
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
|
||||||
match self {
|
match self {
|
||||||
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
|
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
|
||||||
.map(|(true_val, pred_val)| (*pred_val - true_val.to_f64().unwrap()))
|
.map(|(true_val, pred_val)| *pred_val - true_val.to_f64().unwrap())
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user