Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f53cb36b9d | ||
|
|
c57a4370ba | ||
|
|
78f18505b1 | ||
|
|
58a8624fa9 | ||
|
|
18de2aa244 | ||
|
|
2bf5f7a1a5 | ||
|
|
0caa8306ff | ||
|
|
2f63148de4 | ||
|
|
f9e473c919 | ||
|
|
70d8a0f34b | ||
|
|
0e42a97514 | ||
|
|
36efd582a5 | ||
|
|
70212c71e0 | ||
|
|
63f86f7bc9 | ||
|
|
e633afa520 | ||
|
|
b6e32fb328 | ||
|
|
948d78a4d0 | ||
|
|
448b6f77e3 | ||
|
|
09be4681cf |
+10
-30
@@ -31,33 +31,21 @@ jobs:
|
|||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
toolchain: stable
|
targets: ${{ matrix.platform.target }}
|
||||||
target: ${{ matrix.platform.target }}
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Install test runner for wasm
|
- name: Install test runner for wasm
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||||
- name: Stable Build with all features
|
- name: Stable Build with all features
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --all-features --target ${{ matrix.platform.target }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --all-features --target ${{ matrix.platform.target }}
|
|
||||||
- name: Stable Build without features
|
- name: Stable Build without features
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --target ${{ matrix.platform.target }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --target ${{ matrix.platform.target }}
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo test --all-features
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --all-features
|
|
||||||
- name: Tests in WASM
|
- name: Tests in WASM
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: wasm-pack test --node -- --all-features
|
run: wasm-pack test --node -- --all-features
|
||||||
@@ -78,17 +66,9 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-cargo-features
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
|
||||||
toolchain: stable
|
|
||||||
target: ${{ matrix.platform.target }}
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Stable Build
|
- name: Stable Build
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --no-default-features ${{ matrix.features }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --no-default-features ${{ matrix.features }}
|
|
||||||
|
|||||||
@@ -19,26 +19,15 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-coverage-cargo
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@nightly
|
||||||
with:
|
|
||||||
toolchain: nightly
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Install cargo-tarpaulin
|
- name: Install cargo-tarpaulin
|
||||||
uses: actions-rs/install@v0.1
|
run: cargo install cargo-tarpaulin
|
||||||
with:
|
|
||||||
crate: cargo-tarpaulin
|
|
||||||
version: latest
|
|
||||||
use-tool-cache: true
|
|
||||||
- name: Run cargo-tarpaulin
|
- name: Run cargo-tarpaulin
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
|
||||||
with:
|
|
||||||
command: tarpaulin
|
|
||||||
args: --out Lcov --all-features -- --test-threads 1
|
|
||||||
- name: Upload to codecov.io
|
- name: Upload to codecov.io
|
||||||
uses: codecov/codecov-action@v2
|
uses: codecov/codecov-action@v4
|
||||||
with:
|
with:
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|||||||
@@ -6,36 +6,27 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches: [ development ]
|
branches: [ development ]
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- name: Cache .cargo and target
|
- name: Cache .cargo and target
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-lint-cargo
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
toolchain: stable
|
components: rustfmt, clippy
|
||||||
profile: minimal
|
- name: Check format
|
||||||
default: true
|
run: cargo fmt --all -- --check
|
||||||
- run: rustup component add rustfmt
|
|
||||||
- name: Check formt
|
|
||||||
uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: fmt
|
|
||||||
args: --all -- --check
|
|
||||||
- run: rustup component add clippy
|
|
||||||
- name: Run clippy
|
- name: Run clippy
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
|
||||||
with:
|
|
||||||
command: clippy
|
|
||||||
args: --all-features -- -Drust-2018-idioms -Dwarnings
|
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [0.4.8] - 2025-11-29
|
||||||
|
- WARNING: Breaking changes!
|
||||||
|
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
|
||||||
|
|
||||||
|
|
||||||
## [0.4.0] - 2023-04-05
|
## [0.4.0] - 2023-04-05
|
||||||
|
|
||||||
## Added
|
## Added
|
||||||
|
|||||||
@@ -0,0 +1,41 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
message: "If this software contributes to published work, please cite smartcore."
|
||||||
|
type: software
|
||||||
|
title: "smartcore: Machine Learning in Rust"
|
||||||
|
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration."
|
||||||
|
repository-code: "https://github.com/smartcorelib/smartcore"
|
||||||
|
url: "https://github.com/smartcorelib"
|
||||||
|
license: "MIT"
|
||||||
|
keywords:
|
||||||
|
- Rust
|
||||||
|
- machine learning
|
||||||
|
- numerical computing
|
||||||
|
- linear algebra
|
||||||
|
- classification
|
||||||
|
- regression
|
||||||
|
- clustering
|
||||||
|
- SVM
|
||||||
|
- Random Forest
|
||||||
|
- XGBoost
|
||||||
|
authors:
|
||||||
|
- name: "smartcore Developers"
|
||||||
|
- name: "Lorenzo (contributor)"
|
||||||
|
- name: "Community contributors"
|
||||||
|
version: "0.4.2"
|
||||||
|
date-released: "2025-09-14"
|
||||||
|
preferred-citation:
|
||||||
|
type: software
|
||||||
|
title: "smartcore: Machine Learning in Rust"
|
||||||
|
authors:
|
||||||
|
- name: "smartcore Developers"
|
||||||
|
url: "https://github.com/smartcorelib"
|
||||||
|
repository-code: "https://github.com/smartcorelib/smartcore"
|
||||||
|
license: "MIT"
|
||||||
|
references:
|
||||||
|
- type: manual
|
||||||
|
title: "smartcore Documentation"
|
||||||
|
url: "https://docs.rs/smartcore"
|
||||||
|
- type: webpage
|
||||||
|
title: "smartcore Homepage"
|
||||||
|
url: "https://github.com/smartcorelib"
|
||||||
|
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost."
|
||||||
+2
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.2"
|
version = "0.4.9"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
@@ -28,6 +28,7 @@ num = "0.4"
|
|||||||
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
||||||
rand_distr = { version = "0.4", optional = true }
|
rand_distr = { version = "0.4", optional = true }
|
||||||
serde = { version = "1", features = ["derive"], optional = true }
|
serde = { version = "1", features = ["derive"], optional = true }
|
||||||
|
ordered-float = "5.1.0"
|
||||||
|
|
||||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||||
typetag = { version = "0.2", optional = true }
|
typetag = { version = "0.2", optional = true }
|
||||||
|
|||||||
@@ -16,6 +16,132 @@
|
|||||||
</p>
|
</p>
|
||||||
|
|
||||||
-----
|
-----
|
||||||
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
|
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [](https://doi.org/10.5281/zenodo.17219259)
|
||||||
|
|
||||||
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||||
|
|
||||||
|
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
|
||||||
|
|
||||||
|
|
||||||
|
## Highlights
|
||||||
|
|
||||||
|
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
|
||||||
|
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
|
||||||
|
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
|
||||||
|
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
|
||||||
|
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Add to Cargo.toml:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[dependencies]
|
||||||
|
smartcore = "^0.4.3"
|
||||||
|
```
|
||||||
|
|
||||||
|
For the latest development branch:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[dependencies]
|
||||||
|
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional features (examples):
|
||||||
|
|
||||||
|
- datasets
|
||||||
|
- serde
|
||||||
|
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
|
||||||
|
|
||||||
|
Check Cargo.toml for available features and compatibility notes.
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use smartcore::neighbors::knn_classifier::KNNClassifier;
|
||||||
|
|
||||||
|
// Turn vector slices into a matrix
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2.],
|
||||||
|
&[3., 4.],
|
||||||
|
&[5., 6.],
|
||||||
|
&[7., 8.],
|
||||||
|
&[9., 10.],
|
||||||
|
]).unwrap();
|
||||||
|
|
||||||
|
// Class labels
|
||||||
|
let y = vec![2, 2, 2, 3, 3];
|
||||||
|
|
||||||
|
// Train classifier
|
||||||
|
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
// Predict
|
||||||
|
let yhat = knn.predict(&x).unwrap();
|
||||||
|
```
|
||||||
|
|
||||||
|
This example mirrors the “First Example” section of the crate docs and demonstrates smartcore’s ergonomic API surface.
|
||||||
|
|
||||||
|
## Algorithms
|
||||||
|
|
||||||
|
smartcore organizes algorithms into clear modules with consistent traits:
|
||||||
|
|
||||||
|
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
|
||||||
|
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
|
||||||
|
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
|
||||||
|
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
|
||||||
|
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
|
||||||
|
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
|
||||||
|
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
|
||||||
|
- Preprocessing: encoders, split utilities, and common transforms.
|
||||||
|
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
|
||||||
|
|
||||||
|
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
|
||||||
|
|
||||||
|
## Data access and readers
|
||||||
|
|
||||||
|
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
|
||||||
|
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
|
||||||
|
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
|
||||||
|
|
||||||
|
|
||||||
|
## WebAssembly and portability
|
||||||
|
|
||||||
|
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
|
||||||
|
|
||||||
|
## Notebooks
|
||||||
|
|
||||||
|
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
|
||||||
|
|
||||||
|
## Roadmap and recent changes
|
||||||
|
|
||||||
|
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
|
||||||
|
- Move to Rust 2021 edition and cleanup of duplicate code paths.
|
||||||
|
- Seeds and deterministic controls across algorithms using RNG plumbing.
|
||||||
|
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
|
||||||
|
- Tree and forest components refactored for reuse; Extra Trees added.
|
||||||
|
- SVM multiclass support; SVR kernel enum and related improvements.
|
||||||
|
- XGBoost-style regression introduced; single-linkage clustering implemented.
|
||||||
|
|
||||||
|
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions are welcome:
|
||||||
|
|
||||||
|
- Open an issue describing the change and link it in the PR.
|
||||||
|
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
|
||||||
|
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
|
||||||
|
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
|
||||||
|
|
||||||
|
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
smartcore’s design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
#![allow(clippy::ptr_arg)]
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! # Nearest Neighbors Search Algorithms and Data Structures
|
//! # Nearest Neighbors Search Algorithms and Data Structures
|
||||||
//!
|
//!
|
||||||
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
|
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
|
||||||
@@ -39,6 +39,8 @@ use crate::numbers::basenum::Number;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
|
/// a variant of fastpair using cosine distance
|
||||||
|
pub mod cosinepair;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
/// fastpair closest neighbour algorithm
|
/// fastpair closest neighbour algorithm
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use num_traits::Num;
|
use num_traits::Num;
|
||||||
|
|
||||||
pub trait QuickArgSort {
|
pub trait QuickArgSort {
|
||||||
|
#[allow(dead_code)]
|
||||||
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! # Clustering
|
//! # Clustering
|
||||||
//!
|
//!
|
||||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! Datasets
|
//! Datasets
|
||||||
//!
|
//!
|
||||||
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
|
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.ncols > 0 && self.nrows > 0
|
self.ncols < 1 || self.nrows < 1
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||||
|
|||||||
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
|
|||||||
+145
-55
@@ -9,7 +9,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Lasso coefficient estimates solve the problem:
|
//! Lasso coefficient estimates solve the problem:
|
||||||
//!
|
//!
|
||||||
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||||
//!
|
//!
|
||||||
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
||||||
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
||||||
@@ -53,6 +53,9 @@ pub struct LassoParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fit_intercept: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -86,6 +89,12 @@ impl LassoParameters {
|
|||||||
self.max_iter = max_iter;
|
self.max_iter = max_iter;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
|
||||||
|
self.fit_intercept = fit_intercept;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LassoParameters {
|
impl Default for LassoParameters {
|
||||||
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
|
|||||||
normalize: true,
|
normalize: true,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
coefficients: Option::None,
|
coefficients: None,
|
||||||
intercept: Option::None,
|
intercept: None,
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
}
|
}
|
||||||
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: Vec<usize>,
|
pub max_iter: Vec<usize>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fit_intercept: Vec<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lasso grid search iterator
|
/// Lasso grid search iterator
|
||||||
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
|
|||||||
current_normalize: usize,
|
current_normalize: usize,
|
||||||
current_tol: usize,
|
current_tol: usize,
|
||||||
current_max_iter: usize,
|
current_max_iter: usize,
|
||||||
|
current_fit_intercept: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for LassoSearchParameters {
|
impl IntoIterator for LassoSearchParameters {
|
||||||
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
|
|||||||
current_normalize: 0,
|
current_normalize: 0,
|
||||||
current_tol: 0,
|
current_tol: 0,
|
||||||
current_max_iter: 0,
|
current_max_iter: 0,
|
||||||
|
current_fit_intercept: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||||
|
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||||
|
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||||
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
self.current_normalize = 0;
|
self.current_normalize = 0;
|
||||||
self.current_tol = 0;
|
self.current_tol = 0;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
|
||||||
|
{
|
||||||
|
self.current_alpha = 0;
|
||||||
|
self.current_normalize = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_max_iter = 0;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_alpha += 1;
|
self.current_alpha += 1;
|
||||||
self.current_normalize += 1;
|
self.current_normalize += 1;
|
||||||
self.current_tol += 1;
|
self.current_tol += 1;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
|
|||||||
normalize: vec![default_params.normalize],
|
normalize: vec![default_params.normalize],
|
||||||
tol: vec![default_params.tol],
|
tol: vec![default_params.tol],
|
||||||
max_iter: vec![default_params.max_iter],
|
max_iter: vec![default_params.max_iter],
|
||||||
|
fit_intercept: vec![default_params.fit_intercept],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,7 +272,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
||||||
let (n, p) = x.shape();
|
let (n, p) = x.shape();
|
||||||
|
|
||||||
if n <= p {
|
if n < p {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
"Number of rows in X should be >= number of columns in X",
|
"Number of rows in X should be >= number of columns in X",
|
||||||
));
|
));
|
||||||
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||||
w[j] /= *col_std_j;
|
w[j] /= *col_std_j;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut b = TX::zero();
|
let b = if parameters.fit_intercept {
|
||||||
|
let mut xw_mean = TX::zero();
|
||||||
|
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||||
|
xw_mean += w[i] * *col_mean_i;
|
||||||
|
}
|
||||||
|
|
||||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
|
||||||
b += w[i] * *col_mean_i;
|
} else {
|
||||||
}
|
None
|
||||||
|
};
|
||||||
b = TX::from_f64(y.mean_by()).unwrap() - b;
|
|
||||||
(X::from_column(&w), b)
|
(X::from_column(&w), b)
|
||||||
} else {
|
} else {
|
||||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||||
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
|
(
|
||||||
|
X::from_column(&w),
|
||||||
|
if parameters.fit_intercept {
|
||||||
|
Some(TX::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Lasso {
|
Ok(Lasso {
|
||||||
intercept: Some(b),
|
intercept: b,
|
||||||
coefficients: Some(w),
|
coefficients: Some(w),
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
@@ -369,6 +407,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::linalg::basic::arrays::Array;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
@@ -377,30 +416,28 @@ mod tests {
|
|||||||
let parameters = LassoSearchParameters {
|
let parameters = LassoSearchParameters {
|
||||||
alpha: vec![0., 1.],
|
alpha: vec![0., 1.],
|
||||||
max_iter: vec![10, 100],
|
max_iter: vec![10, 100],
|
||||||
|
fit_intercept: vec![false, true],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut iter = parameters.into_iter();
|
|
||||||
let next = iter.next().unwrap();
|
let mut iter = parameters.clone().into_iter();
|
||||||
assert_eq!(next.alpha, 0.);
|
for current_fit_intercept in 0..parameters.fit_intercept.len() {
|
||||||
assert_eq!(next.max_iter, 10);
|
for current_max_iter in 0..parameters.max_iter.len() {
|
||||||
let next = iter.next().unwrap();
|
for current_alpha in 0..parameters.alpha.len() {
|
||||||
assert_eq!(next.alpha, 1.);
|
let next = iter.next().unwrap();
|
||||||
assert_eq!(next.max_iter, 10);
|
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
|
||||||
let next = iter.next().unwrap();
|
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
|
||||||
assert_eq!(next.alpha, 0.);
|
assert_eq!(
|
||||||
assert_eq!(next.max_iter, 100);
|
next.fit_intercept,
|
||||||
let next = iter.next().unwrap();
|
parameters.fit_intercept[current_fit_intercept]
|
||||||
assert_eq!(next.alpha, 1.);
|
);
|
||||||
assert_eq!(next.max_iter, 100);
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
assert!(iter.next().is_none());
|
assert!(iter.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
|
||||||
)]
|
|
||||||
#[test]
|
|
||||||
fn lasso_fit_predict() {
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -426,6 +463,17 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
(x, y)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn lasso_fit_predict() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
|
||||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -440,6 +488,7 @@ mod tests {
|
|||||||
normalize: false,
|
normalize: false,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
@@ -448,35 +497,76 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn test_full_rank_x() {
|
||||||
|
// x: randn(3,3) * 10, demean, then round to 2 decimal points
|
||||||
|
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
|
||||||
|
let param = LassoParameters::default()
|
||||||
|
.with_normalize(false)
|
||||||
|
.with_alpha(200.0);
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[-8.9, -2.24, 8.89],
|
||||||
|
&[-4.02, 8.89, 12.33],
|
||||||
|
&[12.92, -6.65, -21.22],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let y = vec![-116.12, -75.41, 191.53];
|
||||||
|
let w = Lasso::fit(&x, &y, param)
|
||||||
|
.unwrap()
|
||||||
|
.coefficients()
|
||||||
|
.iterator(0)
|
||||||
|
.copied()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
|
||||||
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn test_fit_intercept() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
let fit_result = Lasso::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LassoParameters {
|
||||||
|
alpha: 0.1,
|
||||||
|
normalize: false,
|
||||||
|
tol: 1e-8,
|
||||||
|
max_iter: 1000,
|
||||||
|
fit_intercept: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let w = fit_result.coefficients().iterator(0).copied().collect();
|
||||||
|
// by sklearn LassoLars. coordinate descent doesn't converge well
|
||||||
|
let expected_w = vec![
|
||||||
|
0.18335684,
|
||||||
|
0.02106526,
|
||||||
|
0.00703214,
|
||||||
|
-1.35952542,
|
||||||
|
0.09295222,
|
||||||
|
0.,
|
||||||
|
];
|
||||||
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
|
||||||
|
assert_eq!(fit_result.intercept, None);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
// #[test]
|
// #[test]
|
||||||
// #[cfg(feature = "serde")]
|
// #[cfg(feature = "serde")]
|
||||||
// fn serde() {
|
// fn serde() {
|
||||||
// let x = DenseMatrix::from_2d_array(&[
|
// let (x, y) = get_lasso_sample_x_y();
|
||||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
|
||||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
|
||||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
|
||||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
|
||||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
|
||||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
|
||||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
|
||||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
|
||||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
|
||||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
|
||||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
|
||||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
|
||||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
|
||||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
|
||||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
|
||||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
|
||||||
// ]);
|
|
||||||
|
|
||||||
// let y = vec![
|
|
||||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
|
||||||
// 114.2, 115.7, 116.9,
|
|
||||||
// ];
|
|
||||||
|
|
||||||
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
lambda: T,
|
lambda: T,
|
||||||
max_iter: usize,
|
max_iter: usize,
|
||||||
tol: T,
|
tol: T,
|
||||||
|
fit_intercept: bool,
|
||||||
) -> Result<Vec<T>, Failed> {
|
) -> Result<Vec<T>, Failed> {
|
||||||
let (n, p) = x.shape();
|
let (n, p) = x.shape();
|
||||||
let p_f64 = T::from_usize(p).unwrap();
|
let p_f64 = T::from_usize(p).unwrap();
|
||||||
@@ -52,6 +53,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
let lambda = lambda.max(T::epsilon());
|
let lambda = lambda.max(T::epsilon());
|
||||||
|
|
||||||
//parameters
|
//parameters
|
||||||
|
let max_ls_iter = 100;
|
||||||
let pcgmaxi = 5000;
|
let pcgmaxi = 5000;
|
||||||
let min_pcgtol = T::from_f64(0.1).unwrap();
|
let min_pcgtol = T::from_f64(0.1).unwrap();
|
||||||
let eta = T::from_f64(1E-3).unwrap();
|
let eta = T::from_f64(1E-3).unwrap();
|
||||||
@@ -61,9 +63,12 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
let mu = T::two();
|
let mu = T::two();
|
||||||
|
|
||||||
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
||||||
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
|
let y = if fit_intercept {
|
||||||
|
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
y.to_owned()
|
||||||
|
};
|
||||||
|
|
||||||
let mut max_ls_iter = 100;
|
|
||||||
let mut pitr = 0;
|
let mut pitr = 0;
|
||||||
let mut w = Vec::zeros(p);
|
let mut w = Vec::zeros(p);
|
||||||
let mut neww = w.clone();
|
let mut neww = w.clone();
|
||||||
@@ -165,7 +170,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
s = T::one();
|
s = T::one();
|
||||||
let gdx = grad.dot(&dxu);
|
let gdx = grad.dot(&dxu);
|
||||||
|
|
||||||
let lsiter = 0;
|
let mut lsiter = 0;
|
||||||
while lsiter < max_ls_iter {
|
while lsiter < max_ls_iter {
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
neww[i] = w[i] + s * dx[i];
|
neww[i] = w[i] + s * dx[i];
|
||||||
@@ -190,7 +195,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = beta * s;
|
s = beta * s;
|
||||||
max_ls_iter += 1;
|
lsiter += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if lsiter == max_ls_iter {
|
if lsiter == max_ls_iter {
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
//! # Cosine Distance Metric
|
||||||
|
//!
|
||||||
|
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
|
||||||
|
//!
|
||||||
|
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
|
||||||
|
//!
|
||||||
|
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
|
||||||
|
//! are their respective magnitudes (Euclidean norms).
|
||||||
|
//!
|
||||||
|
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
|
||||||
|
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
|
||||||
|
//! greater angular separation.
|
||||||
|
//!
|
||||||
|
//! Example:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::metrics::distance::Distance;
|
||||||
|
//! use smartcore::metrics::distance::cosine::Cosine;
|
||||||
|
//!
|
||||||
|
//! let x = vec![1., 1.];
|
||||||
|
//! let y = vec![2., 2.];
|
||||||
|
//!
|
||||||
|
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use crate::linalg::basic::arrays::ArrayView1;
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
|
||||||
|
use super::Distance;
|
||||||
|
|
||||||
|
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
|
||||||
|
/// It is defined as 1 minus the cosine similarity of the vectors.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Cosine<T> {
|
||||||
|
_t: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Default for Cosine<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Cosine<T> {
|
||||||
|
/// Instantiate the initial structure
|
||||||
|
pub fn new() -> Cosine<T> {
|
||||||
|
Cosine { _t: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
if x.shape() != y.shape() {
|
||||||
|
panic!("Input vector sizes are different.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the built-in dot product method from ArrayView1 trait
|
||||||
|
x.dot(y).to_f64().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the squared magnitude (norm squared) of a vector
|
||||||
|
#[inline]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
x.iterator(0)
|
||||||
|
.map(|&a| {
|
||||||
|
let val = a.to_f64().unwrap();
|
||||||
|
val * val
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
// Use the built-in norm2 method from ArrayView1 trait
|
||||||
|
x.norm2()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate cosine similarity between two vectors
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
let dot_product = Self::dot_product(x, y);
|
||||||
|
let magnitude_x = Self::magnitude(x);
|
||||||
|
let magnitude_y = Self::magnitude(y);
|
||||||
|
|
||||||
|
if magnitude_x == 0.0 || magnitude_y == 0.0 {
|
||||||
|
return f64::MIN;
|
||||||
|
}
|
||||||
|
|
||||||
|
dot_product / (magnitude_x * magnitude_y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
|
||||||
|
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||||
|
let similarity = Cosine::cosine_similarity(x, y);
|
||||||
|
1.0 - similarity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_identical_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 0.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_orthogonal_vectors() {
|
||||||
|
let a = vec![1, 0];
|
||||||
|
let b = vec![0, 1];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 1.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_opposite_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![-1, -2, -3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 2.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_general_case() {
|
||||||
|
let a = vec![1.0, 2.0, 3.0];
|
||||||
|
let b = vec![2.0, 1.0, 3.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
|
||||||
|
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
|
||||||
|
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
|
||||||
|
let expected_dist = 1.0 - (13.0 / 14.0);
|
||||||
|
assert!((dist - expected_dist).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Input vector sizes are different.")]
|
||||||
|
fn cosine_distance_different_sizes() {
|
||||||
|
let a = vec![1, 2];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let _dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_zero_vector() {
|
||||||
|
let a = vec![0, 0, 0];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
assert!(dist > 1e300)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_float_precision() {
|
||||||
|
let a = vec![1.0f32, 2.0, 3.0];
|
||||||
|
let b = vec![4.0f32, 5.0, 6.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Calculate expected value manually
|
||||||
|
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
|
||||||
|
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
|
||||||
|
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
|
||||||
|
let expected_similarity = dot_product / (mag_a * mag_b);
|
||||||
|
let expected_distance = 1.0 - expected_similarity;
|
||||||
|
|
||||||
|
assert!((dist - expected_distance).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,8 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
/// Cosine distance
|
||||||
|
pub mod cosine;
|
||||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||||
pub mod euclidian;
|
pub mod euclidian;
|
||||||
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
||||||
|
|||||||
+88
-23
@@ -4,7 +4,9 @@
|
|||||||
//!
|
//!
|
||||||
//! \\[precision = \frac{tp}{tp + fp}\\]
|
//! \\[precision = \frac{tp}{tp + fp}\\]
|
||||||
//!
|
//!
|
||||||
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
|
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
|
||||||
|
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
|
||||||
|
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
|
||||||
//!
|
//!
|
||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
@@ -19,7 +21,8 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashSet;
|
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
let n = y_true.shape();
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
|
||||||
}
|
|
||||||
let classes = classes.len();
|
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut classes_set: HashSet<u64> = HashSet::new();
|
||||||
let mut fp = 0;
|
for i in 0..n {
|
||||||
for i in 0..y_true.shape() {
|
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
}
|
||||||
if classes == 2 {
|
let classes: usize = classes_set.len();
|
||||||
if *y_true.get(i) == T::one() {
|
|
||||||
|
if classes == 2 {
|
||||||
|
// Binary case: precision for positive class (assumed T::one())
|
||||||
|
let positive = T::one();
|
||||||
|
let mut tp: usize = 0;
|
||||||
|
let mut fp_count: usize = 0;
|
||||||
|
for i in 0..n {
|
||||||
|
let t = *y_true.get(i);
|
||||||
|
let p = *y_pred.get(i);
|
||||||
|
if p == t {
|
||||||
|
if t == positive {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else if t != positive {
|
||||||
tp += 1;
|
fp_count += 1;
|
||||||
}
|
|
||||||
} else if classes == 2 {
|
|
||||||
if *y_true.get(i) == T::one() {
|
|
||||||
fp += 1;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if tp + fp_count == 0 {
|
||||||
|
0.0
|
||||||
} else {
|
} else {
|
||||||
fp += 1;
|
tp as f64 / (tp + fp_count) as f64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Multiclass case: macro-averaged precision
|
||||||
|
let mut predicted: HashMap<u64, usize> = HashMap::new();
|
||||||
|
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||||
|
for i in 0..n {
|
||||||
|
let p_bits = y_pred.get(i).to_f64_bits();
|
||||||
|
*predicted.entry(p_bits).or_insert(0) += 1;
|
||||||
|
if *y_true.get(i) == *y_pred.get(i) {
|
||||||
|
*tp_map.entry(p_bits).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut precision_sum = 0.0;
|
||||||
|
for &bits in &classes_set {
|
||||||
|
let pred_count = *predicted.get(&bits).unwrap_or(&0);
|
||||||
|
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||||
|
let prec = if pred_count > 0 {
|
||||||
|
tp as f64 / pred_count as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
precision_sum += prec;
|
||||||
|
}
|
||||||
|
if classes == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
precision_sum / classes as f64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tp as f64 / (tp as f64 + fp as f64)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,7 +147,7 @@ mod tests {
|
|||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||||
|
|
||||||
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
|
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
assert!((score3 - 0.6666666666).abs() < 1e-8);
|
assert!((score3 - 0.5).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -132,4 +165,36 @@ mod tests {
|
|||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn precision_multiclass_imbalanced() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||||
|
|
||||||
|
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
|
let expected = (0.5 + 0.5 + 1.0) / 3.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn precision_multiclass_unpredicted_class() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
|
||||||
|
|
||||||
|
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
|
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
|
||||||
|
// Class 1: pred=2, tp=1 -> 0.5
|
||||||
|
// Class 2: pred=2, tp=2 -> 1.0
|
||||||
|
// Class 3: pred=0, tp=0 -> 0.0
|
||||||
|
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+64
-24
@@ -4,7 +4,9 @@
|
|||||||
//!
|
//!
|
||||||
//! \\[recall = \frac{tp}{tp + fn}\\]
|
//! \\[recall = \frac{tp}{tp + fn}\\]
|
||||||
//!
|
//!
|
||||||
//! where tp (true positive) - correct result, fn (false negative) - missing result
|
//! where tp (true positive) - correct result, fn (false negative) - missing result.
|
||||||
|
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
|
||||||
|
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
|
||||||
//!
|
//!
|
||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
@@ -20,8 +22,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::TryInto;
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
@@ -52,7 +53,7 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Calculated recall score
|
/// Calculated recall score
|
||||||
/// * `y_true` - cround truth (correct) labels.
|
/// * `y_true` - ground truth (correct) labels.
|
||||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||||
if y_true.shape() != y_pred.shape() {
|
if y_true.shape() != y_pred.shape() {
|
||||||
@@ -63,32 +64,57 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
let n = y_true.shape();
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
|
||||||
}
|
|
||||||
let classes: i64 = classes.len().try_into().unwrap();
|
|
||||||
|
|
||||||
let mut tp = 0;
|
let mut classes_set = HashSet::new();
|
||||||
let mut fne = 0;
|
for i in 0..n {
|
||||||
for i in 0..y_true.shape() {
|
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
}
|
||||||
if classes == 2 {
|
let classes: usize = classes_set.len();
|
||||||
if *y_true.get(i) == T::one() {
|
|
||||||
|
if classes == 2 {
|
||||||
|
// Binary case: recall for positive class (assumed T::one())
|
||||||
|
let positive = T::one();
|
||||||
|
let mut tp: usize = 0;
|
||||||
|
let mut fn_count: usize = 0;
|
||||||
|
for i in 0..n {
|
||||||
|
let t = *y_true.get(i);
|
||||||
|
let p = *y_pred.get(i);
|
||||||
|
if p == t {
|
||||||
|
if t == positive {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else if t == positive {
|
||||||
tp += 1;
|
fn_count += 1;
|
||||||
}
|
|
||||||
} else if classes == 2 {
|
|
||||||
if *y_true.get(i) != T::one() {
|
|
||||||
fne += 1;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if tp + fn_count == 0 {
|
||||||
|
0.0
|
||||||
} else {
|
} else {
|
||||||
fne += 1;
|
tp as f64 / (tp + fn_count) as f64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Multiclass case: macro-averaged recall
|
||||||
|
let mut support: HashMap<u64, usize> = HashMap::new();
|
||||||
|
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||||
|
for i in 0..n {
|
||||||
|
let t_bits = y_true.get(i).to_f64_bits();
|
||||||
|
*support.entry(t_bits).or_insert(0) += 1;
|
||||||
|
if *y_true.get(i) == *y_pred.get(i) {
|
||||||
|
*tp_map.entry(t_bits).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut recall_sum = 0.0;
|
||||||
|
for (&bits, &sup) in &support {
|
||||||
|
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||||
|
recall_sum += tp as f64 / sup as f64;
|
||||||
|
}
|
||||||
|
if support.is_empty() {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
recall_sum / support.len() as f64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tp as f64 / (tp as f64 + fne as f64)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +141,7 @@ mod tests {
|
|||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||||
|
|
||||||
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
|
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||||
assert!((score3 - 0.5).abs() < 1e-8);
|
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -133,4 +159,18 @@ mod tests {
|
|||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn recall_multiclass_imbalanced() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||||
|
|
||||||
|
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||||
|
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
//! # K Nearest Neighbors Regressor
|
//! # K Nearest Neighbors Regressor with Feature Sparsing
|
||||||
//!
|
//!
|
||||||
//! Regressor that predicts estimated values as a function of k nearest neightbours.
|
//! Regressor that predicts estimated values as a function of k nearest neightbours.
|
||||||
|
//! Now supports feature sparsing - the ability to consider only a subset of features during prediction.
|
||||||
//!
|
//!
|
||||||
//! `KNNRegressor` relies on 2 backend algorithms to speedup KNN queries:
|
//! `KNNRegressor` relies on 2 backend algorithms to speedup KNN queries:
|
||||||
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
|
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
|
||||||
@@ -29,6 +30,10 @@
|
|||||||
//!
|
//!
|
||||||
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
//! let y_hat = knn.predict(&x).unwrap();
|
//! let y_hat = knn.predict(&x).unwrap();
|
||||||
|
//!
|
||||||
|
//! // Predict using only features at indices 0
|
||||||
|
//! let feature_indices = vec![0];
|
||||||
|
//! let y_hat_sparse = knn.predict_sparse(&x, &feature_indices).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! variable `y_hat` will hold predicted value
|
//! variable `y_hat` will hold predicted value
|
||||||
@@ -77,12 +82,13 @@ pub struct KNNRegressorParameters<T: Number, D: Distance<Vec<T>>> {
|
|||||||
pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||||
{
|
{
|
||||||
y: Option<Y>,
|
y: Option<Y>,
|
||||||
|
x: Option<X>, // Store training data for sparse feature prediction
|
||||||
knn_algorithm: Option<KNNAlgorithm<TX, D>>,
|
knn_algorithm: Option<KNNAlgorithm<TX, D>>,
|
||||||
|
distance: Option<D>, // Store distance function for sparse prediction
|
||||||
weight: Option<KNNWeightFunction>,
|
weight: Option<KNNWeightFunction>,
|
||||||
k: Option<usize>,
|
k: Option<usize>,
|
||||||
_phantom_tx: PhantomData<TX>,
|
_phantom_tx: PhantomData<TX>,
|
||||||
_phantom_ty: PhantomData<TY>,
|
_phantom_ty: PhantomData<TY>,
|
||||||
_phantom_x: PhantomData<X>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||||
@@ -92,12 +98,20 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
|||||||
self.y.as_ref().unwrap()
|
self.y.as_ref().unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn x(&self) -> &X {
|
||||||
|
self.x.as_ref().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
|
fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
|
||||||
self.knn_algorithm
|
self.knn_algorithm
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.expect("Missing parameter: KNNAlgorithm")
|
.expect("Missing parameter: KNNAlgorithm")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn distance(&self) -> &D {
|
||||||
|
self.distance.as_ref().expect("Missing parameter: distance")
|
||||||
|
}
|
||||||
|
|
||||||
fn weight(&self) -> &KNNWeightFunction {
|
fn weight(&self) -> &KNNWeightFunction {
|
||||||
self.weight.as_ref().expect("Missing parameter: weight")
|
self.weight.as_ref().expect("Missing parameter: weight")
|
||||||
}
|
}
|
||||||
@@ -176,12 +190,13 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
|||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
y: Option::None,
|
y: Option::None,
|
||||||
|
x: Option::None,
|
||||||
knn_algorithm: Option::None,
|
knn_algorithm: Option::None,
|
||||||
|
distance: Option::None,
|
||||||
weight: Option::None,
|
weight: Option::None,
|
||||||
k: Option::None,
|
k: Option::None,
|
||||||
_phantom_tx: PhantomData,
|
_phantom_tx: PhantomData,
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_x: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,16 +246,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let knn_algo = parameters.algorithm.fit(data, parameters.distance)?;
|
let knn_algo = parameters.algorithm.fit(data, parameters.distance.clone())?;
|
||||||
|
|
||||||
Ok(KNNRegressor {
|
Ok(KNNRegressor {
|
||||||
y: Some(y.clone()),
|
y: Some(y.clone()),
|
||||||
|
x: Some(x.clone()),
|
||||||
k: Some(parameters.k),
|
k: Some(parameters.k),
|
||||||
knn_algorithm: Some(knn_algo),
|
knn_algorithm: Some(knn_algo),
|
||||||
|
distance: Some(parameters.distance),
|
||||||
weight: Some(parameters.weight),
|
weight: Some(parameters.weight),
|
||||||
_phantom_tx: PhantomData,
|
_phantom_tx: PhantomData,
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_x: PhantomData,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,6 +278,45 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
|||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Predict the target for the provided data using only specified features.
|
||||||
|
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||||
|
/// * `feature_indices` - indices of features to consider (e.g., [0, 2, 4] to use only features at positions 0, 2, and 4)
|
||||||
|
///
|
||||||
|
/// Returns a vector of size N with estimates.
|
||||||
|
pub fn predict_sparse(&self, x: &X, feature_indices: &[usize]) -> Result<Y, Failed> {
|
||||||
|
let (n_samples, n_features) = x.shape();
|
||||||
|
|
||||||
|
// Validate feature indices
|
||||||
|
for &idx in feature_indices {
|
||||||
|
if idx >= n_features {
|
||||||
|
return Err(Failed::predict(&format!(
|
||||||
|
"Feature index {} out of bounds (max: {})",
|
||||||
|
idx,
|
||||||
|
n_features - 1
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if feature_indices.is_empty() {
|
||||||
|
return Err(Failed::predict(
|
||||||
|
"feature_indices cannot be empty"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Y::zeros(n_samples);
|
||||||
|
|
||||||
|
let mut row_vec = vec![TX::zero(); feature_indices.len()];
|
||||||
|
for (i, row) in x.row_iter().enumerate() {
|
||||||
|
// Extract only the specified features
|
||||||
|
for (j, &feat_idx) in feature_indices.iter().enumerate() {
|
||||||
|
row_vec[j] = *row.get(feat_idx);
|
||||||
|
}
|
||||||
|
result.set(i, self.predict_for_row_sparse(&row_vec, feature_indices)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, row: &Vec<TX>) -> Result<TY, Failed> {
|
fn predict_for_row(&self, row: &Vec<TX>) -> Result<TY, Failed> {
|
||||||
let search_result = self.knn_algorithm().find(row, self.k.unwrap())?;
|
let search_result = self.knn_algorithm().find(row, self.k.unwrap())?;
|
||||||
let mut result = TY::zero();
|
let mut result = TY::zero();
|
||||||
@@ -277,6 +332,50 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
|||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn predict_for_row_sparse(
|
||||||
|
&self,
|
||||||
|
row: &Vec<TX>,
|
||||||
|
feature_indices: &[usize],
|
||||||
|
) -> Result<TY, Failed> {
|
||||||
|
let training_data = self.x();
|
||||||
|
let (n_training_samples, _) = training_data.shape();
|
||||||
|
let k = self.k.unwrap();
|
||||||
|
|
||||||
|
// Manually compute distances using only specified features
|
||||||
|
let mut distances: Vec<(usize, f64)> = Vec::with_capacity(n_training_samples);
|
||||||
|
|
||||||
|
for i in 0..n_training_samples {
|
||||||
|
let train_row = training_data.get_row(i);
|
||||||
|
|
||||||
|
// Extract sparse features from training data
|
||||||
|
let mut train_sparse = Vec::with_capacity(feature_indices.len());
|
||||||
|
for &feat_idx in feature_indices {
|
||||||
|
train_sparse.push(*train_row.get(feat_idx));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute distance using only selected features
|
||||||
|
let dist = self.distance().distance(row, &train_sparse);
|
||||||
|
distances.push((i, dist));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by distance and take k nearest
|
||||||
|
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
let k_nearest: Vec<(usize, f64)> = distances.into_iter().take(k).collect();
|
||||||
|
|
||||||
|
// Compute weighted prediction
|
||||||
|
let mut result = TY::zero();
|
||||||
|
let weights = self
|
||||||
|
.weight()
|
||||||
|
.calc_weights(k_nearest.iter().map(|v| v.1).collect());
|
||||||
|
let w_sum: f64 = weights.iter().copied().sum();
|
||||||
|
|
||||||
|
for (neighbor, w) in k_nearest.iter().zip(weights.iter()) {
|
||||||
|
result += *self.y().get(neighbor.0) * TY::from_f64(*w / w_sum).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -332,6 +431,91 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn knn_predict_sparse() {
|
||||||
|
// Training data with 3 features
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2., 10.],
|
||||||
|
&[3., 4., 20.],
|
||||||
|
&[5., 6., 30.],
|
||||||
|
&[7., 8., 40.],
|
||||||
|
&[9., 10., 50.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
|
|
||||||
|
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
// Test data
|
||||||
|
let x_test = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2., 999.], // Third feature is very different
|
||||||
|
&[5., 6., 999.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Predict using only first two features (ignore the third)
|
||||||
|
let feature_indices = vec![0, 1];
|
||||||
|
let y_hat_sparse = knn.predict_sparse(&x_test, &feature_indices).unwrap();
|
||||||
|
|
||||||
|
// Should get good predictions since we're ignoring the mismatched third feature
|
||||||
|
assert_eq!(2, Vec::len(&y_hat_sparse));
|
||||||
|
assert!((y_hat_sparse[0] - 2.0).abs() < 1.0); // Should be close to 1-2
|
||||||
|
assert!((y_hat_sparse[1] - 3.0).abs() < 1.0); // Should be close to 3
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn knn_predict_sparse_single_feature() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 100., 1000.],
|
||||||
|
&[2., 200., 2000.],
|
||||||
|
&[3., 300., 3000.],
|
||||||
|
&[4., 400., 4000.],
|
||||||
|
&[5., 500., 5000.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||||
|
|
||||||
|
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
let x_test = DenseMatrix::from_2d_array(&[&[1.5, 999., 9999.]]).unwrap();
|
||||||
|
|
||||||
|
// Use only first feature
|
||||||
|
let y_hat = knn.predict_sparse(&x_test, &[0]).unwrap();
|
||||||
|
|
||||||
|
// Should predict based on first feature only
|
||||||
|
assert_eq!(1, Vec::len(&y_hat));
|
||||||
|
assert!((y_hat[0] - 1.5).abs() < 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn knn_predict_sparse_invalid_indices() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap();
|
||||||
|
let y: Vec<f64> = vec![1., 2.];
|
||||||
|
|
||||||
|
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
let x_test = DenseMatrix::from_2d_array(&[&[1., 2.]]).unwrap();
|
||||||
|
|
||||||
|
// Index out of bounds
|
||||||
|
let result = knn.predict_sparse(&x_test, &[5]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
// Empty indices
|
||||||
|
let result = knn.predict_sparse(&x_test, &[]);
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ pub trait LineSearchMethod<T: Float> {
|
|||||||
/// Find alpha that satisfies strong Wolfe conditions.
|
/// Find alpha that satisfies strong Wolfe conditions.
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
df: &(dyn Fn(T) -> T),
|
df: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
@@ -55,8 +55,8 @@ impl<T: Float> Default for Backtracking<T> {
|
|||||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
_: &(dyn Fn(T) -> T),
|
_: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
|
|||||||
@@ -674,15 +674,20 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
) -> bool {
|
) -> bool {
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
let mut label = Option::None;
|
let mut label = None;
|
||||||
let mut is_pure = true;
|
let mut is_pure = true;
|
||||||
for i in 0..n_rows {
|
for i in 0..n_rows {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if label.is_none() {
|
match label {
|
||||||
label = Option::Some(visitor.y[i]);
|
None => {
|
||||||
} else if visitor.y[i] != label.unwrap() {
|
label = Some(visitor.y[i]);
|
||||||
is_pure = false;
|
}
|
||||||
break;
|
Some(current_label) => {
|
||||||
|
if visitor.y[i] != current_label {
|
||||||
|
is_pure = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,10 +53,14 @@ use crate::{
|
|||||||
rand_custom::get_rng_impl,
|
rand_custom::get_rng_impl,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Defines the objective function to be optimized.
|
/// Defines the objective function to be optimized.
|
||||||
/// The objective function provides the loss, gradient (first derivative), and
|
/// The objective function provides the loss, gradient (first derivative), and
|
||||||
/// hessian (second derivative) required for the XGBoost algorithm.
|
/// hessian (second derivative) required for the XGBoost algorithm.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
pub enum Objective {
|
pub enum Objective {
|
||||||
/// The objective for regression tasks using Mean Squared Error.
|
/// The objective for regression tasks using Mean Squared Error.
|
||||||
/// Loss: 0.5 * (y_true - y_pred)^2
|
/// Loss: 0.5 * (y_true - y_pred)^2
|
||||||
@@ -96,7 +100,7 @@ impl Objective {
|
|||||||
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
|
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
|
||||||
match self {
|
match self {
|
||||||
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
|
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
|
||||||
.map(|(true_val, pred_val)| (*pred_val - true_val.to_f64().unwrap()))
|
.map(|(true_val, pred_val)| *pred_val - true_val.to_f64().unwrap())
|
||||||
.collect(),
|
.collect(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -122,6 +126,8 @@ impl Objective {
|
|||||||
/// This is a recursive data structure where each `TreeRegressor` is a node
|
/// This is a recursive data structure where each `TreeRegressor` is a node
|
||||||
/// that can have a left and a right child, also of type `TreeRegressor`.
|
/// that can have a left and a right child, also of type `TreeRegressor`.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
@@ -374,6 +380,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
/// Parameters for the `jRegressor` model.
|
/// Parameters for the `jRegressor` model.
|
||||||
///
|
///
|
||||||
/// This struct holds all the hyperparameters that control the training process.
|
/// This struct holds all the hyperparameters that control the training process.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct XGRegressorParameters {
|
pub struct XGRegressorParameters {
|
||||||
/// The number of boosting rounds or trees to build.
|
/// The number of boosting rounds or trees to build.
|
||||||
@@ -494,6 +501,8 @@ impl XGRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
parameters: Option<XGRegressorParameters>,
|
parameters: Option<XGRegressorParameters>,
|
||||||
|
|||||||
Reference in New Issue
Block a user