18 Commits

Author SHA1 Message Date
dependabot[bot]
4db94badbd Update bincode requirement from 1.3.1 to 3.0.0
Updates the requirements on [bincode](https://github.com/bincode-org/bincode) to permit the latest version.
- [Commits](https://github.com/bincode-org/bincode/commits)

---
updated-dependencies:
- dependency-name: bincode
  dependency-version: 3.0.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-19 17:02:16 +00:00
Lorenzo Mec-iS
c57a4370ba bump version tp 0.4.9 2026-01-09 06:14:44 +00:00
Georeth Chow
78f18505b1 fix LASSO (#346)
* fix lasso doc typo
* fix lasso optimizer bug
2025-12-05 17:49:07 +09:00
Lorenzo
58a8624fa9 v0.4.8 (#345) 2025-11-29 02:54:35 +00:00
Georeth Chow
18de2aa244 add fit_intercept to LASSO (#344)
* add fit_intercept to LASSO
* lasso: intercept=None if fit_intercept is false
* update CHANGELOG.md to reflect lasso changes
* lasso: minor
2025-11-29 02:46:14 +00:00
Georeth Chow
2bf5f7a1a5 Fix LASSO (first two of #342) (#343)
* Fix LASSO (#342)
* change loss function in doc to match code
* allow `n == p` case
* lasso add test_full_rank_x

---------

Co-authored-by: Zhou Xiaozhou <zxz@jiweifund.com>
2025-11-28 12:15:43 +09:00
Lorenzo
0caa8306ff Modernise CI toolchain to avoid deprecation (#341)
* fix cache failing to find Cargo.toml
2025-11-24 02:25:36 +00:00
Lorenzo
2f63148de4 fix CI (#340)
* fix CI workflow
2025-11-24 02:07:49 +00:00
Lorenzo
f9e473c919 v0.4.7 (#339) 2025-11-24 01:57:25 +00:00
Charlie Martin
70d8a0f34b fix precision and recall calculations (#338)
* fix precision and recall calculations
2025-11-24 01:46:56 +00:00
Charlie Martin
0e42a97514 add serde support for XGRegressor (#337)
* add serde support for XGBoostRegressor
* add traits to dependent structs
2025-11-16 19:31:21 +09:00
Lorenzo
36efd582a5 Fix is_empty method logic in matrix.rs (#336)
* Fix is_empty method logic in matrix.rs
* bump to 0.4.6
* silence some clippy
2025-11-15 05:22:42 +00:00
Lorenzo
70212c71e0 Update Cargo.toml (#333) 2025-10-09 17:37:02 +01:00
Lorenzo
63f86f7bc9 Add with_top_k to CosineSimilarity (#332)
* Implement cosine similarity and cosinepair
* formatting
* fix clippy
* Add top k CosinePair
* fix distance computation
* set min similarity for constant zeros
* bump version to 0.4.5
2025-10-09 17:27:54 +01:00
Lorenzo
e633afa520 set min similarity for constant zeros (#331)
* set min similarity for constant zeros
* bump version
2025-10-02 15:41:18 +01:00
Lorenzo
b6e32fb328 Update README.md (#330) 2025-09-28 16:04:12 +01:00
Lorenzo
948d78a4d0 Create CITATION.cff (#329) 2025-09-28 15:50:50 +01:00
Lorenzo
448b6f77e3 Update README.md (#328) 2025-09-28 15:43:46 +01:00
19 changed files with 890 additions and 253 deletions
+10 -30
View File
@@ -31,33 +31,21 @@ jobs:
~/.cargo
./target
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
target: ${{ matrix.platform.target }}
profile: minimal
default: true
targets: ${{ matrix.platform.target }}
- name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build with all features
uses: actions-rs/cargo@v1
with:
command: build
args: --all-features --target ${{ matrix.platform.target }}
run: cargo build --all-features --target ${{ matrix.platform.target }}
- name: Stable Build without features
uses: actions-rs/cargo@v1
with:
command: build
args: --target ${{ matrix.platform.target }}
run: cargo build --target ${{ matrix.platform.target }}
- name: Tests
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
run: cargo test --all-features
- name: Tests in WASM
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: wasm-pack test --node -- --all-features
@@ -78,17 +66,9 @@ jobs:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-features
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
target: ${{ matrix.platform.target }}
profile: minimal
default: true
uses: dtolnay/rust-toolchain@stable
- name: Stable Build
uses: actions-rs/cargo@v1
with:
command: build
args: --no-default-features ${{ matrix.features }}
run: cargo build --no-default-features ${{ matrix.features }}
+6 -17
View File
@@ -19,26 +19,15 @@ jobs:
path: |
~/.cargo
./target
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
default: true
uses: dtolnay/rust-toolchain@nightly
- name: Install cargo-tarpaulin
uses: actions-rs/install@v0.1
with:
crate: cargo-tarpaulin
version: latest
use-tool-cache: true
run: cargo install cargo-tarpaulin
- name: Run cargo-tarpaulin
uses: actions-rs/cargo@v1
with:
command: tarpaulin
args: --out Lcov --all-features -- --test-threads 1
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
- name: Upload to codecov.io
uses: codecov/codecov-action@v2
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: false
+9 -18
View File
@@ -6,36 +6,27 @@ on:
pull_request:
branches: [ development ]
jobs:
lint:
runs-on: ubuntu-latest
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Cache .cargo and target
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
profile: minimal
default: true
- run: rustup component add rustfmt
- name: Check formt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- run: rustup component add clippy
components: rustfmt, clippy
- name: Check format
run: cargo fmt --all -- --check
- name: Run clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -Drust-2018-idioms -Dwarnings
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
+5
View File
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.4.8] - 2025-11-29
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
## [0.4.0] - 2023-04-05
## Added
+41
View File
@@ -0,0 +1,41 @@
cff-version: 1.2.0
message: "If this software contributes to published work, please cite smartcore."
type: software
title: "smartcore: Machine Learning in Rust"
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration." [web:5][web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
url: "https://github.com/smartcorelib" [web:3]
license: "MIT" [web:13]
keywords:
- Rust
- machine learning
- numerical computing
- linear algebra
- classification
- regression
- clustering
- SVM
- Random Forest
- XGBoost [web:5]
authors:
- name: "smartcore Developers" [web:7]
- name: "Lorenzo (contributor)" [web:16]
- name: "Community contributors" [web:7]
version: "0.4.2" [attached_file:1]
date-released: "2025-09-14" [attached_file:1]
preferred-citation:
type: software
title: "smartcore: Machine Learning in Rust"
authors:
- name: "smartcore Developers" [web:7]
url: "https://github.com/smartcorelib" [web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
license: "MIT" [web:13]
references:
- type: manual
title: "smartcore Documentation"
url: "https://docs.rs/smartcore" [web:5]
- type: webpage
title: "smartcore Homepage"
url: "https://github.com/smartcorelib" [web:3]
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost." [web:5]
+3 -2
View File
@@ -2,7 +2,7 @@
name = "smartcore"
description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org"
version = "0.4.3"
version = "0.4.9"
authors = ["smartcore Developers"]
edition = "2021"
license = "Apache-2.0"
@@ -28,6 +28,7 @@ num = "0.4"
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
ordered-float = "5.1.0"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }
@@ -50,7 +51,7 @@ wasm-bindgen-test = "0.3"
[dev-dependencies]
itertools = "0.13.0"
serde_json = "1.0"
bincode = "1.3.1"
bincode = "3.0.0"
[workspace]
+127 -1
View File
@@ -16,6 +16,132 @@
</p>
-----
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17219259.svg)](https://doi.org/10.5281/zenodo.17219259)
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
## Highlights
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
## Install
Add to Cargo.toml:
```toml
[dependencies]
smartcore = "^0.4.3"
```
For the latest development branch:
```toml
[dependencies]
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
```
Optional features (examples):
- datasets
- serde
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
Check Cargo.toml for available features and compatibility notes.
## Quick start
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
```rust
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::neighbors::knn_classifier::KNNClassifier;
// Turn vector slices into a matrix
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
]).unwrap;
// Class labels
let y = vec![2, 2, 2, 3, 3];
// Train classifier
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
// Predict
let yhat = knn.predict(&x).unwrap();
```
This example mirrors the “First Example” section of the crate docs and demonstrates smartcores ergonomic API surface.
## Algorithms
smartcore organizes algorithms into clear modules with consistent traits:
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
- Preprocessing: encoders, split utilities, and common transforms.
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
## Data access and readers
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
## WebAssembly and portability
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
## Notebooks
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
## Roadmap and recent changes
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
- Move to Rust 2021 edition and cleanup of duplicate code paths.
- Seeds and deterministic controls across algorithms using RNG plumbing.
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
- Tree and forest components refactored for reuse; Extra Trees added.
- SVM multiclass support; SVR kernel enum and related improvements.
- XGBoost-style regression introduced; single-linkage clustering implemented.
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
## Contributing
Contributions are welcome:
- Open an issue describing the change and link it in the PR.
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
## License
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
## Acknowledgments
smartcores design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
+365 -74
View File
@@ -23,7 +23,10 @@
/// ```
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap;
use ordered_float::{FloatCore, OrderedFloat};
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use num::Bounded;
@@ -34,6 +37,25 @@ use crate::metrics::distance::{Distance, PairwiseDistance};
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
/// Parameters for CosinePair construction
#[derive(Debug, Clone)]
pub struct CosinePairParameters {
/// Maximum number of neighbors to consider per point (default: all points)
pub top_k: Option<usize>,
/// Whether to use approximate nearest neighbor search
pub approximate: bool,
}
#[allow(clippy::derivable_impls)]
impl Default for CosinePairParameters {
fn default() -> Self {
Self {
top_k: None,
approximate: false,
}
}
}
///
/// Inspired by Python implementation:
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
@@ -49,12 +71,29 @@ pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
pub distances: HashMap<usize, PairwiseDistance<T>>,
/// conga line used to keep track of the closest pair
pub neighbours: Vec<usize>,
/// parameters used during construction
pub parameters: CosinePairParameters,
}
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
/// Constructor
/// Instantiate and initialize the algorithm
impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2<T>> CosinePair<'a, T, M> {
/// Constructor with default parameters (backward compatibility)
pub fn new(m: &'a M) -> Result<Self, Failed> {
Self::with_parameters(m, CosinePairParameters::default())
}
/// Constructor with top-k limiting for faster performance
pub fn with_top_k(m: &'a M, top_k: usize) -> Result<Self, Failed> {
Self::with_parameters(
m,
CosinePairParameters {
top_k: Some(top_k),
approximate: false,
},
)
}
/// Constructor with full parameter control
pub fn with_parameters(m: &'a M, parameters: CosinePairParameters) -> Result<Self, Failed> {
if m.shape().0 < 2 {
return Err(Failed::because(
FailedError::FindFailed,
@@ -64,96 +103,156 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
let mut init = Self {
samples: m,
// to be computed in init(..)
distances: HashMap::with_capacity(m.shape().0),
neighbours: Vec::with_capacity(m.shape().0 + 1),
neighbours: Vec::with_capacity(m.shape().0),
parameters,
};
init.init();
Ok(init)
}
/// Initialise `CosinePair` by passing a `Array2`.
/// Build a CosinePairs data-structure from a set of (new) points.
/// Helper function to create ordered float wrapper
fn ordered_float(value: T) -> OrderedFloat<T> {
OrderedFloat(value)
}
/// Helper function to extract value from ordered float wrapper
fn extract_float(ordered: OrderedFloat<T>) -> T {
ordered.into_inner()
}
/// Optimized initialization with top-k neighbor limiting
fn init(&mut self) {
// basic measures
let len = self.samples.shape().0;
let max_index = self.samples.shape().0 - 1;
let max_neighbors: usize = self.parameters.top_k.unwrap_or(len - 1).min(len - 1);
// Store all closest neighbors
let _distances = Box::new(HashMap::with_capacity(len));
let _neighbours = Box::new(Vec::with_capacity(len));
let mut distances = HashMap::with_capacity(len);
let mut neighbours = Vec::with_capacity(len);
let mut distances = *_distances;
let mut neighbours = *_neighbours;
// fill neighbours with -1 values
neighbours.extend(0..len);
// init closest neighbour pairwise data
for index_row_i in 0..(max_index) {
// Initialize with max distances
for i in 0..len {
distances.insert(
index_row_i,
i,
PairwiseDistance {
node: index_row_i,
neighbour: Option::None,
node: i,
neighbour: None,
distance: Some(<T as Bounded>::max_value()),
},
);
}
// loop through indeces and neighbours
for index_row_i in 0..(len) {
// start looking for the neighbour in the second element
let mut index_closest = index_row_i + 1; // closest neighbour index
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
for index_row_j in (index_row_i + 1)..len {
distances.insert(
index_row_j,
PairwiseDistance {
node: index_row_j,
neighbour: Some(index_row_i),
distance: nbd,
},
);
// Compute distances for each point using top-k optimization
for i in 0..len {
let mut candidate_distances = BinaryHeap::new();
let d = Cosine::new().distance(
&Vec::from_iterator(
self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
);
if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour
index_closest = index_row_j;
nbd = Some(T::from(d).unwrap());
for j in 0..len {
if i != j {
let distance = T::from(Cosine::new().distance(
&Vec::from_iterator(
self.samples.get_row(i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(j).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap();
// Use OrderedFloat for stable ordering
candidate_distances.push(Reverse((Self::ordered_float(distance), j)));
if candidate_distances.len() > max_neighbors {
candidate_distances.pop();
}
}
}
// Add that edge
distances.entry(index_row_i).and_modify(|e| {
e.distance = nbd;
e.neighbour = Some(index_closest);
});
}
// No more neighbors, terminate conga line.
// Last person on the line has no neigbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len);
for (_, p) in distances.iter() {
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
// Find the closest neighbor from candidates
if let Some(Reverse((closest_distance, closest_neighbor))) =
candidate_distances.iter().min_by_key(|Reverse((d, _))| *d)
{
distances.entry(i).and_modify(|e| {
e.distance = Some(Self::extract_float(*closest_distance));
e.neighbour = Some(*closest_neighbor);
});
}
}
self.distances = distances;
self.neighbours = neighbours;
}
/// Fast query using top-k pre-computed neighbors with ordered-float
pub fn query_row_top_k(
&self,
query_row_index: usize,
k: usize,
) -> Result<Vec<(T, usize)>, Failed> {
if query_row_index >= self.samples.shape().0 {
return Err(Failed::because(
FailedError::FindFailed,
"Query row index out of bounds",
));
}
if k == 0 {
return Ok(Vec::new());
}
let max_candidates = self.parameters.top_k.unwrap_or(self.samples.shape().0);
let actual_k: usize = k.min(max_candidates);
// Use binary heap with ordered-float for reliable ordering
let mut heap = BinaryHeap::with_capacity(actual_k + 1);
let candidates = if let Some(top_k) = self.parameters.top_k {
let step = (self.samples.shape().0 / top_k).max(1);
(0..self.samples.shape().0)
.step_by(step)
.filter(|&i| i != query_row_index)
.take(top_k)
.collect::<Vec<_>>()
} else {
(0..self.samples.shape().0)
.filter(|&i| i != query_row_index)
.collect::<Vec<_>>()
};
for &candidate_idx in &candidates {
let distance = T::from(Cosine::new().distance(
&Vec::from_iterator(
self.samples.get_row(query_row_index).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(candidate_idx).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap();
heap.push(Reverse((Self::ordered_float(distance), candidate_idx)));
if heap.len() > actual_k {
heap.pop();
}
}
// Convert heap to sorted vector
let mut neighbors: Vec<_> = heap
.into_vec()
.into_iter()
.map(|Reverse((dist, idx))| (Self::extract_float(dist), idx))
.collect();
neighbors.sort_by(|a, b| Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0)));
Ok(neighbors)
}
/// Query k nearest neighbors for a row that's already in the dataset
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result<Vec<(T, usize)>, Failed> {
if query_row_index >= self.samples.shape().0 {
@@ -318,7 +417,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
mod tests {
use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
use approx::assert_relative_eq;
use approx::{assert_relative_eq, relative_eq};
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -499,10 +598,6 @@ mod tests {
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_pair_query_row_bounds_error() {
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
@@ -520,10 +615,6 @@ mod tests {
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_pair_query_row_k_zero() {
let x =
@@ -635,6 +726,206 @@ mod tests {
assert!(distance >= 0.0 && distance <= 2.0);
}
#[test]
fn query_row_top_k_top_k_limiting() {
// Test that query_row_top_k respects top_k parameter and returns correct results
let x = DenseMatrix::<f64>::from_2d_array(&[
&[1.0, 0.0, 0.0], // Point 0
&[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0
&[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0
&[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2
&[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel)
&[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel)
&[0.0, 1.0, 1.0], // Point 6 - far from point 0
&[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0
])
.unwrap();
// Create CosinePair with top_k=4 to limit candidates
let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap();
// Query for 3 nearest neighbors to point 0
let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap();
// Should return exactly 3 neighbors
assert_eq!(neighbors.len(), 3);
// Verify that distances are in ascending order
for i in 1..neighbors.len() {
assert!(
neighbors[i - 1].0 <= neighbors[i].0,
"Distances should be in ascending order: {} <= {}",
neighbors[i - 1].0,
neighbors[i].0
);
}
// All distances should be valid cosine distances (0 to 2)
for (distance, index) in &neighbors {
assert!(
*distance >= 0.0 && *distance <= 2.0,
"Cosine distance {} should be between 0 and 2",
distance
);
assert!(
*index < x.shape().0,
"Neighbor index {} should be less than dataset size {}",
index,
x.shape().0
);
assert!(
*index != 0,
"Neighbor index should not include query point itself"
);
}
// The closest neighbor should be either point 4 or 5 (parallel vectors)
// These should have cosine distance ≈ 0
let closest_distance = neighbors[0].0;
assert!(
closest_distance < 0.01,
"Closest parallel vector should have distance close to 0, got {}",
closest_distance
);
// Verify that we get different results with different top_k values
let cosine_pair_full = CosinePair::new(&x).unwrap();
let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap();
// Results should be the same or very close since we're asking for top 3
// but the algorithm might find different candidates due to top_k limiting
assert_eq!(neighbors.len(), neighbors_full.len());
// The closest neighbor should be the same in both cases
let closest_idx_fast = neighbors[0].1;
let closest_idx_full = neighbors_full[0].1;
let closest_dist_fast = neighbors[0].0;
let closest_dist_full = neighbors_full[0].0;
// Either we get the same closest neighbor, or distances are very close
if closest_idx_fast == closest_idx_full {
assert!(relative_eq!(
closest_dist_fast,
closest_dist_full,
epsilon = 1e-10
));
} else {
// Different neighbors, but distances should be very close (parallel vectors)
assert!(relative_eq!(
closest_dist_fast,
closest_dist_full,
epsilon = 1e-6
));
}
}
#[test]
fn query_row_top_k_performance_vs_accuracy() {
// Test that query_row_top_k provides reasonable performance/accuracy tradeoff
// and handles edge cases properly
let large_dataset = DenseMatrix::<f32>::from_2d_array(&[
&[1.0f32, 2.0, 3.0, 4.0], // Point 0 - query point
&[1.1f32, 2.1, 3.1, 4.1], // Point 1 - very close to 0
&[1.05f32, 2.05, 3.05, 4.05], // Point 2 - very close to 0
&[2.0f32, 4.0, 6.0, 8.0], // Point 3 - parallel to 0 (2x scaling)
&[0.5f32, 1.0, 1.5, 2.0], // Point 4 - parallel to 0 (0.5x scaling)
&[-1.0f32, -2.0, -3.0, -4.0], // Point 5 - opposite to 0
&[4.0f32, 3.0, 2.0, 1.0], // Point 6 - different direction
&[0.0f32, 0.0, 0.0, 0.1], // Point 7 - mostly orthogonal
&[10.0f32, 20.0, 30.0, 40.0], // Point 8 - parallel but far
&[1.0f32, 0.0, 0.0, 0.0], // Point 9 - partially similar
&[0.0f32, 2.0, 0.0, 0.0], // Point 10 - partially similar
&[0.0f32, 0.0, 3.0, 0.0], // Point 11 - partially similar
])
.unwrap();
// Test with aggressive top_k limiting (only consider 5 out of 11 other points)
let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap();
// Query for 4 nearest neighbors
let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap();
// Should return exactly 4 neighbors
assert_eq!(neighbors_limited.len(), 4);
// Test error handling - out of bounds query
let result_oob = cosine_pair_limited.query_row_top_k(15, 2);
assert!(result_oob.is_err());
if let Err(e) = result_oob {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "Query row index out of bounds")
);
}
// Test k=0 case
let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap();
assert_eq!(neighbors_zero.len(), 0);
// Test k > available candidates
let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap();
assert!(neighbors_large_k.len() <= 11); // At most 11 other points
// Verify ordering is correct
for i in 1..neighbors_limited.len() {
assert!(
neighbors_limited[i - 1].0 <= neighbors_limited[i].0,
"Distance ordering violation at position {}: {} > {}",
i,
neighbors_limited[i - 1].0,
neighbors_limited[i].0
);
}
// The closest neighbors should be the parallel vectors (points 1, 2, 3, 4)
// since they have the smallest cosine distances
let closest_distance = neighbors_limited[0].0;
assert!(
closest_distance < 0.1,
"Closest neighbor should be nearly parallel, distance: {}",
closest_distance
);
// Compare with full algorithm for accuracy assessment
let cosine_pair_full = CosinePair::new(&large_dataset).unwrap();
let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap();
// The fast version might not find the exact same neighbors due to sampling,
// but the closest neighbor's distance should be very similar
let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs();
assert!(
dist_diff < 0.01,
"Fast and full algorithms should give similar closest distances. Diff: {}",
dist_diff
);
// Verify that all returned indices are valid and unique
let mut indices: Vec<usize> = neighbors_limited.iter().map(|(_, idx)| *idx).collect();
indices.sort();
indices.dedup();
assert_eq!(
indices.len(),
neighbors_limited.len(),
"All neighbor indices should be unique"
);
for &idx in &indices {
assert!(
idx < large_dataset.shape().0,
"Neighbor index {} should be valid",
idx
);
assert!(idx != 0, "Neighbor should not include query point itself");
}
// Test with f32 precision to ensure type compatibility
for (distance, _) in &neighbors_limited {
assert!(!distance.is_nan(), "Distance should not be NaN");
assert!(distance.is_finite(), "Distance should be finite");
assert!(*distance >= 0.0, "Distance should be non-negative");
}
}
#[test]
fn cosine_pair_float_precision() {
// Test with f32 precision
+1 -1
View File
@@ -1,4 +1,4 @@
#![allow(clippy::ptr_arg)]
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! # Nearest Neighbors Search Algorithms and Data Structures
//!
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
+1
View File
@@ -1,3 +1,4 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! # Clustering
//!
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
+1
View File
@@ -1,3 +1,4 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! Datasets
//!
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
+1 -1
View File
@@ -385,7 +385,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
}
fn is_empty(&self) -> bool {
self.ncols > 0 && self.nrows > 0
self.ncols < 1 || self.nrows < 1
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
+2
View File
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
+145 -55
View File
@@ -9,7 +9,7 @@
//!
//! Lasso coefficient estimates solve the problem:
//!
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//!
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
//! but is able to solve them with high accuracy with relatively small additional computational cost.
@@ -53,6 +53,9 @@ pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -86,6 +89,12 @@ impl LassoParameters {
self.max_iter = max_iter;
self
}
/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
}
impl Default for LassoParameters {
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
normalize: true,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
}
}
}
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
coefficients: None,
intercept: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: Vec<bool>,
}
/// Lasso grid search iterator
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
current_fit_intercept: usize,
}
impl IntoIterator for LassoSearchParameters {
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
current_fit_intercept: 0,
}
}
}
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{
return None;
}
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
};
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
self.current_fit_intercept += 1;
}
Some(next)
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
}
}
}
@@ -246,7 +272,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
let (n, p) = x.shape();
if n <= p {
if n < p {
return Err(Failed::fit(
"Number of rows in X should be >= number of columns in X",
));
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w[j] /= *col_std_j;
}
let mut b = TX::zero();
let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w[i] * *col_mean_i;
}
b = TX::from_f64(y.mean_by()).unwrap() - b;
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
} else {
None
};
(X::from_column(&w), b)
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
(
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
};
Ok(Lasso {
intercept: Some(b),
intercept: b,
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
@@ -369,6 +407,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
@@ -377,30 +416,28 @@ mod tests {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
let mut iter = parameters.clone().into_iter();
for current_fit_intercept in 0..parameters.fit_intercept.len() {
for current_max_iter in 0..parameters.max_iter.len() {
for current_alpha in 0..parameters.alpha.len() {
let next = iter.next().unwrap();
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(
next.fit_intercept,
parameters.fit_intercept[current_fit_intercept]
);
}
}
}
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -426,6 +463,17 @@ mod tests {
114.2, 115.7, 116.9,
];
(x, y)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();
let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
@@ -440,6 +488,7 @@ mod tests {
normalize: false,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
},
)
.and_then(|lr| lr.predict(&x))
@@ -448,35 +497,76 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_full_rank_x() {
// x: randn(3,3) * 10, demean, then round to 2 decimal points
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
let param = LassoParameters::default()
.with_normalize(false)
.with_alpha(200.0);
let x = DenseMatrix::from_2d_array(&[
&[-8.9, -2.24, 8.89],
&[-4.02, 8.89, 12.33],
&[12.92, -6.65, -21.22],
])
.unwrap();
let y = vec![-116.12, -75.41, 191.53];
let w = Lasso::fit(&x, &y, param)
.unwrap()
.coefficients()
.iterator(0)
.copied()
.collect();
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
let (x, y) = get_example_x_y();
let fit_result = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: false,
tol: 1e-8,
max_iter: 1000,
fit_intercept: false,
},
)
.unwrap();
let w = fit_result.coefficients().iterator(0).copied().collect();
// by sklearn LassoLars. coordinate descent doesn't converge well
let expected_w = vec![
0.18335684,
0.02106526,
0.00703214,
-1.35952542,
0.09295222,
0.,
];
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
assert_eq!(fit_result.intercept, None);
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
// let (x, y) = get_lasso_sample_x_y();
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
+9 -4
View File
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
lambda: T,
max_iter: usize,
tol: T,
fit_intercept: bool,
) -> Result<Vec<T>, Failed> {
let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap();
@@ -52,6 +53,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let lambda = lambda.max(T::epsilon());
//parameters
let max_ls_iter = 100;
let pcgmaxi = 5000;
let min_pcgtol = T::from_f64(0.1).unwrap();
let eta = T::from_f64(1E-3).unwrap();
@@ -61,9 +63,12 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let mu = T::two();
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};
let mut max_ls_iter = 100;
let mut pitr = 0;
let mut w = Vec::zeros(p);
let mut neww = w.clone();
@@ -165,7 +170,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
s = T::one();
let gdx = grad.dot(&dxu);
let lsiter = 0;
let mut lsiter = 0;
while lsiter < max_ls_iter {
for i in 0..p {
neww[i] = w[i] + s * dx[i];
@@ -190,7 +195,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
}
}
s = beta * s;
max_ls_iter += 1;
lsiter += 1;
}
if lsiter == max_ls_iter {
+3 -3
View File
@@ -92,7 +92,7 @@ impl<T: Number> Cosine<T> {
let magnitude_y = Self::magnitude(y);
if magnitude_x == 0.0 || magnitude_y == 0.0 {
panic!("Cannot compute cosine distance for zero-magnitude vectors.");
return f64::MIN;
}
dot_product / (magnitude_x * magnitude_y)
@@ -188,12 +188,12 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[should_panic(expected = "Cannot compute cosine distance for zero-magnitude vectors.")]
fn cosine_distance_zero_vector() {
let a = vec![0, 0, 0];
let b = vec![1, 2, 3];
let _dist: f64 = Cosine::new().distance(&a, &b);
let dist: f64 = Cosine::new().distance(&a, &b);
assert!(dist > 1e300)
}
#[cfg_attr(
+88 -23
View File
@@ -4,7 +4,9 @@
//!
//! \\[precision = \frac{tp}{tp + fp}\\]
//!
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
//!
//! Example:
//!
@@ -19,7 +21,8 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
);
}
let mut classes = HashSet::new();
for i in 0..y_true.shape() {
classes.insert(y_true.get(i).to_f64_bits());
}
let classes = classes.len();
let n = y_true.shape();
let mut tp = 0;
let mut fp = 0;
for i in 0..y_true.shape() {
if y_pred.get(i) == y_true.get(i) {
if classes == 2 {
if *y_true.get(i) == T::one() {
let mut classes_set: HashSet<u64> = HashSet::new();
for i in 0..n {
classes_set.insert(y_true.get(i).to_f64_bits());
}
let classes: usize = classes_set.len();
if classes == 2 {
// Binary case: precision for positive class (assumed T::one())
let positive = T::one();
let mut tp: usize = 0;
let mut fp_count: usize = 0;
for i in 0..n {
let t = *y_true.get(i);
let p = *y_pred.get(i);
if p == t {
if t == positive {
tp += 1;
}
} else {
tp += 1;
}
} else if classes == 2 {
if *y_true.get(i) == T::one() {
fp += 1;
} else if t != positive {
fp_count += 1;
}
}
if tp + fp_count == 0 {
0.0
} else {
fp += 1;
tp as f64 / (tp + fp_count) as f64
}
} else {
// Multiclass case: macro-averaged precision
let mut predicted: HashMap<u64, usize> = HashMap::new();
let mut tp_map: HashMap<u64, usize> = HashMap::new();
for i in 0..n {
let p_bits = y_pred.get(i).to_f64_bits();
*predicted.entry(p_bits).or_insert(0) += 1;
if *y_true.get(i) == *y_pred.get(i) {
*tp_map.entry(p_bits).or_insert(0) += 1;
}
}
let mut precision_sum = 0.0;
for &bits in &classes_set {
let pred_count = *predicted.get(&bits).unwrap_or(&0);
let tp = *tp_map.get(&bits).unwrap_or(&0);
let prec = if pred_count > 0 {
tp as f64 / pred_count as f64
} else {
0.0
};
precision_sum += prec;
}
if classes == 0 {
0.0
} else {
precision_sum / classes as f64
}
}
tp as f64 / (tp as f64 + fp as f64)
}
}
@@ -114,7 +147,7 @@ mod tests {
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
assert!((score3 - 0.6666666666).abs() < 1e-8);
assert!((score3 - 0.5).abs() < 1e-8);
}
#[cfg_attr(
@@ -132,4 +165,36 @@ mod tests {
assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_imbalanced() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
let expected = (0.5 + 0.5 + 1.0) / 3.0;
assert!((score - expected).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_unpredicted_class() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
// Class 1: pred=2, tp=1 -> 0.5
// Class 2: pred=2, tp=2 -> 1.0
// Class 3: pred=0, tp=0 -> 0.0
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
assert!((score - expected).abs() < 1e-8);
}
}
+64 -24
View File
@@ -4,7 +4,9 @@
//!
//! \\[recall = \frac{tp}{tp + fn}\\]
//!
//! where tp (true positive) - correct result, fn (false negative) - missing result
//! where tp (true positive) - correct result, fn (false negative) - missing result.
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
//!
//! Example:
//!
@@ -20,8 +22,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashSet;
use std::convert::TryInto;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
@@ -52,7 +53,7 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
}
}
/// Calculated recall score
/// * `y_true` - cround truth (correct) labels.
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
@@ -63,32 +64,57 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
);
}
let mut classes = HashSet::new();
for i in 0..y_true.shape() {
classes.insert(y_true.get(i).to_f64_bits());
}
let classes: i64 = classes.len().try_into().unwrap();
let n = y_true.shape();
let mut tp = 0;
let mut fne = 0;
for i in 0..y_true.shape() {
if y_pred.get(i) == y_true.get(i) {
if classes == 2 {
if *y_true.get(i) == T::one() {
let mut classes_set = HashSet::new();
for i in 0..n {
classes_set.insert(y_true.get(i).to_f64_bits());
}
let classes: usize = classes_set.len();
if classes == 2 {
// Binary case: recall for positive class (assumed T::one())
let positive = T::one();
let mut tp: usize = 0;
let mut fn_count: usize = 0;
for i in 0..n {
let t = *y_true.get(i);
let p = *y_pred.get(i);
if p == t {
if t == positive {
tp += 1;
}
} else {
tp += 1;
}
} else if classes == 2 {
if *y_true.get(i) != T::one() {
fne += 1;
} else if t == positive {
fn_count += 1;
}
}
if tp + fn_count == 0 {
0.0
} else {
fne += 1;
tp as f64 / (tp + fn_count) as f64
}
} else {
// Multiclass case: macro-averaged recall
let mut support: HashMap<u64, usize> = HashMap::new();
let mut tp_map: HashMap<u64, usize> = HashMap::new();
for i in 0..n {
let t_bits = y_true.get(i).to_f64_bits();
*support.entry(t_bits).or_insert(0) += 1;
if *y_true.get(i) == *y_pred.get(i) {
*tp_map.entry(t_bits).or_insert(0) += 1;
}
}
let mut recall_sum = 0.0;
for (&bits, &sup) in &support {
let tp = *tp_map.get(&bits).unwrap_or(&0);
recall_sum += tp as f64 / sup as f64;
}
if support.is_empty() {
0.0
} else {
recall_sum / support.len() as f64
}
}
tp as f64 / (tp as f64 + fne as f64)
}
}
@@ -115,7 +141,7 @@ mod tests {
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
assert!((score3 - 0.5).abs() < 1e-8);
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
}
#[cfg_attr(
@@ -133,4 +159,18 @@ mod tests {
assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn recall_multiclass_imbalanced() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
assert!((score - expected).abs() < 1e-8);
}
}
+9
View File
@@ -53,10 +53,14 @@ use crate::{
rand_custom::get_rng_impl,
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Defines the objective function to be optimized.
/// The objective function provides the loss, gradient (first derivative), and
/// hessian (second derivative) required for the XGBoost algorithm.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Objective {
/// The objective for regression tasks using Mean Squared Error.
/// Loss: 0.5 * (y_true - y_pred)^2
@@ -122,6 +126,8 @@ impl Objective {
/// This is a recursive data structure where each `TreeRegressor` is a node
/// that can have a left and a right child, also of type `TreeRegressor`.
#[allow(dead_code)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
@@ -374,6 +380,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
/// Parameters for the `jRegressor` model.
///
/// This struct holds all the hyperparameters that control the training process.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct XGRegressorParameters {
/// The number of boosting rounds or trees to build.
@@ -494,6 +501,8 @@ impl XGRegressorParameters {
}
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
parameters: Option<XGRegressorParameters>,