86 Commits

Author SHA1 Message Date
Lorenzo Mec-iS
a62c293244 Add another pairwise distance algorithm 2025-01-28 00:30:57 +00:00
Lorenzo Mec-iS
39f87aa5c2 add tests to fastpair 2025-01-28 00:20:29 +00:00
Lorenzo Mec-iS
8cc02cdd48 fix test 2025-01-27 23:43:42 +00:00
Lorenzo Mec-iS
d60ba63862 Merge branch 'main' of github.com:smartcorelib/smartcore into march-2023-improvements 2025-01-27 23:34:45 +00:00
Lorenzo
5dd5c2f0d0 Merge branch 'development' into march-2023-improvements 2025-01-27 23:28:58 +00:00
Lorenzo (Mec-iS)
074cfaf14f rustfmt 2023-03-24 12:06:54 +09:00
Lorenzo
393cf15534 Merge branch 'development' into march-2023-improvements 2023-03-24 12:05:06 +09:00
Lorenzo (Mec-iS)
80c406b37d Merge branch 'development' of github.com:smartcorelib/smartcore into march-2023-improvements 2023-03-21 17:38:35 +09:00
Lorenzo (Mec-iS)
0e1bf6ce7f Add ordered_pairs method to FastPair 2023-03-21 14:46:33 +09:00
Lorenzo (Mec-iS)
0c9c70f8d2 Merge 2022-11-09 12:05:17 +00:00
morenol
62de25b2ae Handle kernel serialization (#232)
* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 11:29:56 -05:00
morenol
7d87451333 Fixes for release (#237)
* Fixes for release
* add new test
* Remove change applied in development branch
* Only add dependency for wasm32
* Update ci.yml

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
265fd558e7 make work cargo build --target wasm32-unknown-unknown 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
e25e2aea2b update CHANGELOG 2022-11-08 11:29:56 -05:00
Lorenzo
2f6dd1325e update comment 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
b0dece9476 use getrandom/js 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
c507d976be Update CHANGELOG 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
fa54d5ee86 Remove unused tests flags 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
459d558d48 minor fixes to doc 2022-11-08 11:29:56 -05:00
Lorenzo
1b7dda30a2 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
c1bd1df5f6 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
cf751f05aa minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
63ed89aadd minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
890e9d644c minor fix 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
af0a740394 Fix std_rand feature 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
616e38c282 cleanup 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
a449fdd4ea fmt 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
669f87f812 Use getrandom as default (for no-std feature) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
6d529b34d2 Add static analyzer to doc 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
3ec9e4f0db Exclude datasets test for wasm/wasi 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
527477dea7 minor fixes 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
5b517c5048 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
2df0795be9 Release 0.3 2022-11-08 11:29:56 -05:00
Lorenzo
0dc97a4e9b Create DEVELOPERS.md 2022-11-08 11:29:56 -05:00
Lorenzo
6c0fd37222 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
d8d0fb6903 Update README.md 2022-11-08 11:29:56 -05:00
morenol
8d07efd921 Use Box in SVM and remove lifetimes (#228)
* Do not change external API
Authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
morenol
ba27dd2a55 Fix CI (#227)
* Update ci.yml
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
ed9769f651 Implement CSV reader with new traits (#209) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
b427e5d8b1 Improve options conditionals 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
fabe362755 Implement Display for NaiveBayes 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
ee6b6a53d6 cargo clippy 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
19f3a2fcc0 Fix signature of metrics tests 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
e09c4ba724 Add kernels' parameters to public interface 2022-11-08 11:29:56 -05:00
Lorenzo
6624732a65 Fix svr tests (#222) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
1cbde3ba22 Refactor modules structure in src/svm 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
551a6e34a5 clean up svm 2022-11-08 11:29:56 -05:00
Lorenzo
c45bab491a Support Wasi as target (#216)
* Improve features
* Add wasm32-wasi as a target
* Update .github/workflows/ci.yml
Co-authored-by: morenol <22335041+morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
7f35dc54e4 Disambiguate distances. Implement Fastpair. (#220) 2022-11-08 11:29:56 -05:00
morenol
8f1a7dfd79 build: fix compilation without default features (#218)
* build: fix compilation with optional features
* Remove unused config from Cargo.toml
* Fix cache keys
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
712c478af6 Improve features (#215) 2022-11-08 11:29:56 -05:00
Lorenzo
4d36b7f34f Fix metrics::auc (#212)
* Fix metrics::auc
2022-11-08 11:29:56 -05:00
Lorenzo
a16927aa16 Port ensemble. Add Display to naive_bayes (#208) 2022-11-08 11:29:56 -05:00
Lorenzo
d91f4f7ce4 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
a7fa0585eb Merge potential next release v0.4 (#187) Breaking Changes
* First draft of the new n-dimensional arrays + NB use case
* Improves default implementation of multiple Array methods
* Refactors tree methods
* Adds matrix decomposition routines
* Adds matrix decomposition methods to ndarray and nalgebra bindings
* Refactoring + linear regression now uses array2
* Ridge & Linear regression
* LBFGS optimizer & logistic regression
* LBFGS optimizer & logistic regression
* Changes linear methods, metrics and model selection methods to new n-dimensional arrays
* Switches KNN and clustering algorithms to new n-d array layer
* Refactors distance metrics
* Optimizes knn and clustering methods
* Refactors metrics module
* Switches decomposition methods to n-dimensional arrays
* Linalg refactoring - cleanup rng merge (#172)
* Remove legacy DenseMatrix and BaseMatrix implementation. Port the new Number, FloatNumber and Array implementation into module structure.
* Exclude AUC metrics. Needs reimplementation
* Improve developers walkthrough

New traits system in place at `src/numbers` and `src/linalg`
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Provide SupervisedEstimator with a constructor to avoid explicit dynamical box allocation in 'cross_validate' and 'cross_validate_predict' as required by the use of 'dyn' as per Rust 2021
* Implement getters to use as_ref() in src/neighbors
* Implement getters to use as_ref() in src/naive_bayes
* Implement getters to use as_ref() in src/linear
* Add Clone to src/naive_bayes
* Change signature for cross_validate and other model_selection functions to abide to use of dyn in Rust 2021
* Implement ndarray-bindings. Remove FloatNumber from implementations
* Drop nalgebra-bindings support (as decided in conf-call to go for ndarray)
* Remove benches. Benches will have their own repo at smartcore-benches
* Implement SVC
* Implement SVC serialization. Move search parameters in dedicated module
* Implement SVR. Definitely too slow
* Fix compilation issues for wasm (#202)

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
* Fix tests (#203)

* Port linalg/traits/stats.rs
* Improve methods naming
* Improve Display for DenseMatrix

Co-authored-by: Montana Low <montanalow@users.noreply.github.com>
Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
2022-11-08 11:29:56 -05:00
RJ Nowling
a32eb66a6a Dataset doc cleanup (#205)
* Update iris.rs

* Update mod.rs

* Update digits.rs
2022-11-08 11:29:56 -05:00
Lorenzo
f605f6e075 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
3b1aaaadf7 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
d015b12402 Update CONTRIBUTING.md 2022-11-08 11:29:56 -05:00
morenol
d5200074c2 fix: fix issue with iterator for svc search (#182) 2022-11-08 11:29:56 -05:00
morenol
473cdfc44d refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
morenol
ad2e6c2900 feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
9ea3133c27 Update CONTRIBUTING.md 2022-11-08 11:29:56 -05:00
Lorenzo
e4c47c7540 Add contribution guidelines (#178) 2022-11-08 11:29:56 -05:00
Montana Low
f4fd4d2239 make default params available to serde (#167)
* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
2022-11-08 11:29:56 -05:00
Montana Low
05dfffad5c add seed param to search params (#168) 2022-11-08 11:29:56 -05:00
morenol
a37b552a7d Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Montana Low
55e1158581 Complete grid search params (#166)
* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
2022-11-08 11:29:56 -05:00
morenol
cfa824d7db Provide better output in flaky tests (#163) 2022-11-08 11:29:56 -05:00
morenol
bb5b437a32 feat: allocate first and then proceed to create matrix from Vec of Ro… (#159)
* feat: allocate first and then proceed to create matrix from Vec of RowVectors
2022-11-08 11:29:56 -05:00
morenol
851533dfa7 Make rand_distr optional (#161) 2022-11-08 11:29:56 -05:00
Lorenzo
0d996edafe Update LICENSE 2022-11-08 11:29:56 -05:00
morenol
f291b71f4a fix: fix compilation warnings when running only with default features (#160)
* fix: fix compilation warnings when running only with default features
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Tim Toebrock
2d75c2c405 Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows.
* feat: Add option to derive `RealNumber` from string.
To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string.
* feat: Implement `Matrix::read_csv`.
2022-11-08 11:29:56 -05:00
Montana Low
1f2597be74 grid search (#154)
* grid search draft
* hyperparam search for linear estimators
2022-11-08 11:29:56 -05:00
Montana Low
0f442e96c0 Handle multiclass precision/recall (#152)
* handle multiclass precision/recall
2022-11-08 11:29:56 -05:00
dependabot[bot]
44e4be23a6 Update criterion requirement from 0.3 to 0.4 (#150)
* Update criterion requirement from 0.3 to 0.4

Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/bheisler/criterion.rs/releases)
- [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/bheisler/criterion.rs/compare/0.3.0...0.4.0)

---
updated-dependencies:
- dependency-name: criterion
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix criterion

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Christos Katsakioris
01f753f86d Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for
  `StandardScaler`.
* Add relevant unit test.

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
2022-11-08 11:29:56 -05:00
Tim Toebrock
df766eaf79 Implementation of Standard scaler (#143)
* docs: Fix typo in doc for categorical transformer.
* feat: Add option to take a column from Matrix.
I created the method `Matrix::take_column` that uses the `Matrix::take`-interface to extract a single column from a matrix. I need that feature in the implementation of  `StandardScaler`.
* feat: Add `StandardScaler`.
Authored-by: titoeb <timtoebrock@googlemail.com>
2022-11-08 11:29:56 -05:00
Lorenzo
09d9205696 Add example for FastPair (#144)
* Add example

* Move to top

* Add imports to example

* Fix imports
2022-11-08 11:29:56 -05:00
Lorenzo
dc7f01db4a Implement fastpair (#142)
* initial fastpair implementation
* FastPair initial implementation
* implement fastpair
* Add random test
* Add bench for fastpair
* Refactor with constructor for FastPair
* Add serialization for PairwiseDistance
* Add fp_bench feature for fastpair bench
2022-11-08 11:29:56 -05:00
Chris McComb
eb4b49d552 Added additional doctest and fixed indices (#141) 2022-11-08 11:29:56 -05:00
morenol
98e3465e7b Fix clippy warnings (#139)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
ferrouille
ea39024fd2 Add SVC::decision_function (#135) 2022-11-08 11:29:56 -05:00
dependabot[bot]
4e94feb872 Update nalgebra requirement from 0.23.0 to 0.31.0 (#128)
Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.31.0)

---
updated-dependencies:
- dependency-name: nalgebra
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
dependabot-preview[bot]
fa802d2d3f build(deps): update nalgebra requirement from 0.23.0 to 0.26.2 (#98)
* build(deps): update nalgebra requirement from 0.23.0 to 0.26.2

Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.26.2)

Signed-off-by: dependabot-preview[bot] <support@dependabot.com>

* fix: updates for nalgebre

* test: explicitly call pow_mut from BaseVector since now it conflicts with nalgebra implementation

* Don't be strict with dependencies

Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
23 changed files with 627 additions and 1332 deletions
+1
View File
@@ -2,5 +2,6 @@
# the repo. Unless a later match takes precedence, # the repo. Unless a later match takes precedence,
# Developers in this list will be requested for # Developers in this list will be requested for
# review when someone opens a pull request. # review when someone opens a pull request.
* @VolodymyrOrlov
* @morenol * @morenol
* @Mec-iS * @Mec-iS
+1 -1
View File
@@ -50,9 +50,9 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
1. After a PR is opened maintainers are notified 1. After a PR is opened maintainers are notified
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass: 2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting * **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings` * **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
* **Testing**: multiple test pipelines are run for different targets * **Testing**: multiple test pipelines are run for different targets
3. When everything is OK, code is merged. 3. When everything is OK, code is merged.
+15 -5
View File
@@ -19,13 +19,14 @@ jobs:
{ os: "ubuntu", target: "i686-unknown-linux-gnu" }, { os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" }, { os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" }, { os: "macos", target: "aarch64-apple-darwin" },
{ os: "ubuntu", target: "wasm32-wasi" },
] ]
env: env:
TZ: "/usr/share/zoneinfo/your/location" TZ: "/usr/share/zoneinfo/your/location"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Cache .cargo and target - name: Cache .cargo and target
uses: actions/cache@v4 uses: actions/cache@v2
with: with:
path: | path: |
~/.cargo ~/.cargo
@@ -35,13 +36,16 @@ jobs:
- name: Install Rust toolchain - name: Install Rust toolchain
uses: actions-rs/toolchain@v1 uses: actions-rs/toolchain@v1
with: with:
toolchain: stable toolchain: 1.81 # 1.82 seems to break wasm32 tests https://github.com/rustwasm/wasm-bindgen/issues/4274
target: ${{ matrix.platform.target }} target: ${{ matrix.platform.target }}
profile: minimal profile: minimal
default: true default: true
- name: Install test runner for wasm - name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown' if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Install test runner for wasi
if: matrix.platform.target == 'wasm32-wasi'
run: curl https://wasmtime.dev/install.sh -sSf | bash
- name: Stable Build with all features - name: Stable Build with all features
uses: actions-rs/cargo@v1 uses: actions-rs/cargo@v1
with: with:
@@ -61,6 +65,12 @@ jobs:
- name: Tests in WASM - name: Tests in WASM
if: matrix.platform.target == 'wasm32-unknown-unknown' if: matrix.platform.target == 'wasm32-unknown-unknown'
run: wasm-pack test --node -- --all-features run: wasm-pack test --node -- --all-features
- name: Tests in WASI
if: matrix.platform.target == 'wasm32-wasi'
run: |
export WASMTIME_HOME="$HOME/.wasmtime"
export PATH="$WASMTIME_HOME/bin:$PATH"
cargo install cargo-wasi && cargo wasi test
check_features: check_features:
runs-on: "${{ matrix.platform.os }}-latest" runs-on: "${{ matrix.platform.os }}-latest"
@@ -71,9 +81,9 @@ jobs:
env: env:
TZ: "/usr/share/zoneinfo/your/location" TZ: "/usr/share/zoneinfo/your/location"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v3
- name: Cache .cargo and target - name: Cache .cargo and target
uses: actions/cache@v4 uses: actions/cache@v2
with: with:
path: | path: |
~/.cargo ~/.cargo
+2 -2
View File
@@ -12,9 +12,9 @@ jobs:
env: env:
TZ: "/usr/share/zoneinfo/your/location" TZ: "/usr/share/zoneinfo/your/location"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v2
- name: Cache .cargo - name: Cache .cargo
uses: actions/cache@v4 uses: actions/cache@v2
with: with:
path: | path: |
~/.cargo ~/.cargo
+1 -1
View File
@@ -14,7 +14,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Cache .cargo and target - name: Cache .cargo and target
uses: actions/cache@v4 uses: actions/cache@v2
with: with:
path: | path: |
~/.cargo ~/.cargo
+1 -1
View File
@@ -2,7 +2,7 @@
name = "smartcore" name = "smartcore"
description = "Machine Learning in Rust." description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org" homepage = "https://smartcorelib.org"
version = "0.4.2" version = "0.4.0"
authors = ["smartcore Developers"] authors = ["smartcore Developers"]
edition = "2021" edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
+1 -1
View File
@@ -18,4 +18,4 @@
----- -----
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md). To start getting familiar with the new smartcore v0.3 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
+219
View File
@@ -0,0 +1,219 @@
//! This module provides FastPair, a data-structure for efficiently tracking the dynamic
//! closest pairs in a set of points, with an example usage in hierarchical clustering.[2][3][5]
//!
//! ## Purpose
//!
//! FastPair allows quick retrieval of the nearest neighbor for each data point by maintaining
//! a "conga line" of closest pairs. Each point retains a link to its known nearest neighbor,
//! and updates in the data structure propagate accordingly. This can be leveraged in
//! agglomerative clustering steps, where merging or insertion of new points must be reflected
//! in nearest-neighbor relationships.
//!
//! ## Example
//!
//! ```
//! use smartcore::metrics::distance::PairwiseDistance;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::algorithm::neighbour::fastpair::FastPair;
//!
//! let x = DenseMatrix::from_2d_array(&[
//! &[5.1, 3.5, 1.4, 0.2],
//! &[4.9, 3.0, 1.4, 0.2],
//! &[4.7, 3.2, 1.3, 0.2],
//! &[4.6, 3.1, 1.5, 0.2],
//! &[5.0, 3.6, 1.4, 0.2],
//! &[5.4, 3.9, 1.7, 0.4],
//! ]).unwrap();
//!
//! let fastpair = FastPair::new(&x).unwrap();
//! let closest = fastpair.closest_pair();
//! println!("Closest pair: {:?}", closest);
//! ```
use std::collections::HashMap;
use num::Bounded;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array, Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
/// Eppstein dynamic closet-pair structure
/// 'M' can be a matrix-like trait that provides row access
#[derive(Debug)]
pub struct EppsteinDCP<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
samples: &'a M,
// "buckets" store, for each row, a small structure recording potential neighbors
neighbors: HashMap<usize, PairwiseDistance<T>>,
}
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> EppsteinDCP<'a, T, M> {
/// Creates a new EppsteinDCP instance with the given data
pub fn new(m: &'a M) -> Result<Self, Failed> {
if m.shape().0 < 3 {
return Err(Failed::because(
FailedError::FindFailed,
"min number of rows should be 3",
));
}
let mut this = Self {
samples: m,
neighbors: HashMap::with_capacity(m.shape().0),
};
this.initialize();
Ok(this)
}
/// Build an initial "conga line" or chain of potential neighbors
/// akin to Eppsteins technique[2].
fn initialize(&mut self) {
let n = self.samples.shape().0;
if n < 2 {
return;
}
// Assign each row i some large distance by default
for i in 0..n {
self.neighbors.insert(
i,
PairwiseDistance {
node: i,
neighbour: None,
distance: Some(<T as Bounded>::max_value()),
},
);
}
// Example: link each i to the next, forming a chain
// (depending on the actual Eppstein approach, can refine)
for i in 0..(n - 1) {
let dist = self.compute_dist(i, i + 1);
self.neighbors.entry(i).and_modify(|pd| {
pd.neighbour = Some(i + 1);
pd.distance = Some(dist);
});
}
// Potential refinement steps omitted for brevity
}
/// Insert a point into the structure.
pub fn insert(&mut self, row_idx: usize) {
// Expand data, find neighbor to link with
// For example, link row_idx to nearest among existing
let mut best_neighbor = None;
let mut best_d = <T as Bounded>::max_value();
for (i, _) in &self.neighbors {
let d = self.compute_dist(*i, row_idx);
if d < best_d {
best_d = d;
best_neighbor = Some(*i);
}
}
self.neighbors.insert(
row_idx,
PairwiseDistance {
node: row_idx,
neighbour: best_neighbor,
distance: Some(best_d),
},
);
// For the best_neighbor, you might want to see if row_idx becomes closer
if let Some(kn) = best_neighbor {
let dist = self.compute_dist(row_idx, kn);
let entry = self.neighbors.get_mut(&kn).unwrap();
if dist < entry.distance.unwrap() {
entry.neighbour = Some(row_idx);
entry.distance = Some(dist);
}
}
}
/// For hierarchical clustering, discover minimal pairs, then merge
pub fn closest_pair(&self) -> Option<PairwiseDistance<T>> {
let mut min_pair: Option<PairwiseDistance<T>> = None;
for (_, pd) in &self.neighbors {
if let Some(d) = pd.distance {
if min_pair.is_none() || d < min_pair.as_ref().unwrap().distance.unwrap() {
min_pair = Some(pd.clone());
}
}
}
min_pair
}
fn compute_dist(&self, i: usize, j: usize) -> T {
// Example: Euclidean
let row_i = self.samples.get_row(i);
let row_j = self.samples.get_row(j);
row_i
.iterator(0)
.zip(row_j.iterator(0))
.map(|(a, b)| (*a - *b) * (*a - *b))
.sum()
}
}
/// Simple usage
#[cfg(test)]
mod tests_eppstein {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
#[test]
fn test_eppstein() {
let matrix =
DenseMatrix::from_2d_array(&[&vec![1.0, 2.0], &vec![2.0, 2.0], &vec![5.0, 3.0]])
.unwrap();
let mut dcp = EppsteinDCP::new(&matrix).unwrap();
dcp.insert(2);
let cp = dcp.closest_pair();
assert!(cp.is_some());
}
#[test]
fn compare_fastpair_eppstein() {
use crate::algorithm::neighbour::fastpair::FastPair;
// Assuming EppsteinDCP is implemented in a similar module
use crate::algorithm::neighbour::eppstein::EppsteinDCP;
// Create a static example matrix
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
])
.unwrap();
// Build FastPair
let fastpair = FastPair::new(&x).unwrap();
let pair_fastpair = fastpair.closest_pair();
// Build EppsteinDCP
let eppstein = EppsteinDCP::new(&x).unwrap();
let pair_eppstein = eppstein.closest_pair();
// Compare the results
assert_eq!(pair_fastpair.node, pair_eppstein.as_ref().unwrap().node);
assert_eq!(
pair_fastpair.neighbour.unwrap(),
pair_eppstein.as_ref().unwrap().neighbour.unwrap()
);
// Use a small epsilon for floating-point comparison
let epsilon = 1e-9;
let diff: f64 =
pair_fastpair.distance.unwrap() - pair_eppstein.as_ref().unwrap().distance.unwrap();
assert!(diff.abs() < epsilon);
println!("FastPair result: {:?}", pair_fastpair);
println!("EppsteinDCP result: {:?}", pair_eppstein);
}
}
+3 -1
View File
@@ -41,7 +41,9 @@ use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree; pub(crate) mod bbd_tree;
/// tree data structure for fast nearest neighbor search /// tree data structure for fast nearest neighbor search
pub mod cover_tree; pub mod cover_tree;
/// fastpair closest neighbour algorithm /// eppstein pairwise closest neighbour algorithm
pub mod eppstein;
/// fastpair pairwise closest neighbour algorithm
pub mod fastpair; pub mod fastpair;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search; pub mod linear_search;
-315
View File
@@ -1,315 +0,0 @@
//! # Agglomerative Hierarchical Clustering
//!
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
//!
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
//!
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
//!
//! ## Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
//!
//! // A dataset with 2 distinct groups of points.
//! let x = DenseMatrix::from_2d_array(&[
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
//! ]).unwrap();
//!
//! // Set parameters to find 2 clusters.
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
//!
//! // Fit the model to the data.
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
//!
//! // Get the cluster assignments.
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
//! ```
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::api::UnsupervisedEstimator;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
/// The number of clusters to find.
pub n_clusters: usize,
}
impl AgglomerativeClusteringParameters {
/// Sets the number of clusters.
///
/// # Arguments
/// * `n_clusters` - The desired number of clusters.
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
self.n_clusters = n_clusters;
self
}
}
impl Default for AgglomerativeClusteringParameters {
fn default() -> Self {
AgglomerativeClusteringParameters { n_clusters: 2 }
}
}
/// Agglomerative Clustering model.
///
/// This implementation uses single-linkage clustering, which is mathematically
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
/// The core logic is an efficient implementation of Kruskal's algorithm, which
/// processes all pairwise distances in increasing order and uses a Disjoint
/// Set Union (DSU) data structure to track cluster membership.
#[derive(Debug)]
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
/// The cluster label assigned to each sample.
pub labels: Vec<usize>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
/// Fits the agglomerative clustering model to the data.
///
/// # Arguments
/// * `data` - A reference to the input data matrix.
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
///
/// # Returns
/// A `Result` containing the fitted model with cluster labels, or an error if
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
let (num_samples, _) = data.shape();
let n_clusters = parameters.n_clusters;
if n_clusters > num_samples {
return Err(Failed::because(
FailedError::ParametersError,
&format!("n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"),
));
}
let mut distance_pairs = Vec::new();
for i in 0..num_samples {
for j in (i + 1)..num_samples {
let distance: f64 = data
.get_row(i)
.iterator(0)
.zip(data.get_row(j).iterator(0))
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
.sum::<f64>();
distance_pairs.push((distance, i, j));
}
}
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
let mut parent = HashMap::new();
let mut children = HashMap::new();
for i in 0..num_samples {
parent.insert(i, i);
children.insert(i, vec![i]);
}
let mut merge_history = Vec::new();
let num_merges_needed = num_samples - 1;
while merge_history.len() < num_merges_needed {
let (_, p1, p2) = distance_pairs.pop().unwrap();
let root1 = parent[&p1];
let root2 = parent[&p2];
if root1 != root2 {
let root2_children = children.remove(&root2).unwrap();
for child in root2_children.iter() {
parent.insert(*child, root1);
}
let root1_children = children.get_mut(&root1).unwrap();
root1_children.extend(root2_children);
merge_history.push((root1, root2));
}
}
let mut clusters = HashMap::new();
let mut assignments = HashMap::new();
for i in 0..num_samples {
clusters.insert(i, vec![i]);
assignments.insert(i, i);
}
let merges_to_apply = num_samples - n_clusters;
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
let root1_cluster = assignments[root1];
let root2_cluster = assignments[root2];
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
for assignment in root2_assignments.iter() {
assignments.insert(*assignment, root1_cluster);
}
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
root1_assignments.extend(root2_assignments);
}
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
cluster_keys.sort();
for (i, key) in cluster_keys.into_iter().enumerate() {
for index in clusters[key].iter() {
labels[*index] = i;
}
}
Ok(AgglomerativeClustering {
labels,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
for AgglomerativeClustering<TX, TY, X, Y>
{
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
AgglomerativeClustering::fit(x, parameters)
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::matrix::DenseMatrix;
use std::collections::HashSet;
use super::*;
#[test]
fn test_simple_clustering() {
// Two distinct clusters, far apart.
let data = vec![
0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
];
let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
// Using f64 for TY as usize doesn't satisfy the Number trait bound.
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Check that all points in the first group have the same label.
let first_group_label = labels[0];
assert!(labels[0..3].iter().all(|&l| l == first_group_label));
// Check that all points in the second group have the same label.
let second_group_label = labels[3];
assert!(labels[3..6].iter().all(|&l| l == second_group_label));
// Check that the two groups have different labels.
assert_ne!(first_group_label, second_group_label);
}
#[test]
fn test_four_clusters() {
// Four distinct clusters in the corners of a square.
let data = vec![
0.0, 0.0, 1.0, 1.0, // Cluster A
100.0, 100.0, 101.0, 101.0, // Cluster B
0.0, 100.0, 1.0, 101.0, // Cluster C
100.0, 0.0, 101.0, 1.0, // Cluster D
];
let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Verify that there are exactly 4 unique labels produced.
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
assert_eq!(unique_labels.len(), 4);
// Verify that points within each original group were assigned the same cluster label.
let label_a = labels[0];
assert_eq!(label_a, labels[1]);
let label_b = labels[2];
assert_eq!(label_b, labels[3]);
let label_c = labels[4];
assert_eq!(label_c, labels[5]);
let label_d = labels[6];
assert_eq!(label_d, labels[7]);
// Verify that all four groups received different labels.
assert_ne!(label_a, label_b);
assert_ne!(label_a, label_c);
assert_ne!(label_a, label_d);
assert_ne!(label_b, label_c);
assert_ne!(label_b, label_d);
assert_ne!(label_c, label_d);
}
#[test]
fn test_n_clusters_equal_to_samples() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// Each point should be its own cluster. Sorting makes the test deterministic.
let mut labels = clustering.labels;
labels.sort();
assert_eq!(labels, vec![0, 1, 2]);
}
#[test]
fn test_one_cluster() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// All points should be in the same cluster.
assert_eq!(clustering.labels, vec![0, 0, 0]);
}
#[test]
fn test_error_on_too_many_clusters() {
let data = vec![0.0, 0.0, 5.0, 5.0];
let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
);
assert!(result.is_err());
}
}
-1
View File
@@ -3,7 +3,6 @@
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups //! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters. //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
pub mod agglomerative;
pub mod dbscan; pub mod dbscan;
/// An iterative clustering algorithm that aims to find local maxima in each iteration. /// An iterative clustering algorithm that aims to find local maxima in each iteration.
pub mod kmeans; pub mod kmeans;
+1 -26
View File
@@ -619,7 +619,7 @@ pub trait MutArrayView1<T: Debug + Display + Copy + Sized>:
T: Number + PartialOrd, T: Number + PartialOrd,
{ {
let stack_size = 64; let stack_size = 64;
let mut jstack: i32 = -1; let mut jstack = -1;
let mut l = 0; let mut l = 0;
let mut istack = vec![0; stack_size]; let mut istack = vec![0; stack_size];
let mut ir = self.shape() - 1; let mut ir = self.shape() - 1;
@@ -2190,29 +2190,4 @@ mod tests {
assert_eq!(result, [65, 581, 30]) assert_eq!(result, [65, 581, 30])
} }
#[test]
fn test_argsort_mut_exact_boundary() {
// Test index == length - 1 case
let boundary =
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, f64::MAX], &[3.0, f64::MAX, 0.0, 2.0]])
.unwrap();
let mut view0: Vec<f64> = boundary.get_col(0).iterator(0).copied().collect();
let indices = view0.argsort_mut();
assert_eq!(indices.last(), Some(&1));
assert_eq!(indices.first(), Some(&0));
let mut view1: Vec<f64> = boundary.get_col(3).iterator(0).copied().collect();
let indices = view1.argsort_mut();
assert_eq!(indices.last(), Some(&0));
assert_eq!(indices.first(), Some(&1));
}
#[test]
fn test_argsort_mut_filled_array() {
let matrix = DenseMatrix::<f64>::rand(1000, 1000);
let mut view: Vec<f64> = matrix.get_col(0).iterator(0).copied().collect();
let sorted = view.argsort_mut();
assert_eq!(sorted.len(), 1000);
}
} }
-1
View File
@@ -663,7 +663,6 @@ mod tests {
#[test] #[test]
fn test_instantiate_err_view3() { fn test_instantiate_err_view3() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap(); let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
#[allow(clippy::reversed_empty_ranges)]
let v = DenseMatrixView::new(&x, 0..3, 4..3); let v = DenseMatrixView::new(&x, 0..3, 4..3);
assert!(v.is_err()); assert!(v.is_err());
} }
+2 -1
View File
@@ -257,7 +257,8 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
/// * `y` - vector with target values (classes) of length N. /// * `y` - vector with target values (classes) of length N.
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data. /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
/// priors are adjusted according to the data.
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter. /// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
/// * `binarize` - Threshold for binarizing. /// * `binarize` - Threshold for binarizing.
fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>( fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>(
+2 -1
View File
@@ -174,7 +174,8 @@ impl<TY: Number + Ord + Unsigned> GaussianNBDistribution<TY> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
/// * `y` - vector with target values (classes) of length N. /// * `y` - vector with target values (classes) of length N.
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data. /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
/// priors are adjusted according to the data.
pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>( pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>(
x: &X, x: &X,
y: &Y, y: &Y,
+2 -1
View File
@@ -207,7 +207,8 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
/// * `x` - training data. /// * `x` - training data.
/// * `y` - vector with target values (classes) of length N. /// * `y` - vector with target values (classes) of length N.
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data. /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
/// priors are adjusted according to the data.
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter. /// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>( pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>(
x: &X, x: &X,
+1 -1
View File
@@ -64,7 +64,7 @@ impl KNNWeightFunction {
KNNWeightFunction::Distance => { KNNWeightFunction::Distance => {
// if there are any points that has zero distance from one or more training points, // if there are any points that has zero distance from one or more training points,
// those training points are weighted as 1.0 and the other points as 0.0 // those training points are weighted as 1.0 and the other points as 0.0
if distances.contains(&0f64) { if distances.iter().any(|&e| e == 0f64) {
distances distances
.iter() .iter()
.map(|e| if *e == 0f64 { 1f64 } else { 0f64 }) .map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
+7 -3
View File
@@ -24,7 +24,7 @@
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0] //! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0] //! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
//! ``` //! ```
use std::iter::repeat_n; use std::iter;
use crate::error::Failed; use crate::error::Failed;
use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::arrays::Array2;
@@ -75,7 +75,11 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) ->
let offset = (0..1).chain(offset_); let offset = (0..1).chain(offset_);
let new_param_idxs: Vec<usize> = (0..num_params) let new_param_idxs: Vec<usize> = (0..num_params)
.zip(repeats.zip(offset).flat_map(|(r, o)| repeat_n(o, r))) .zip(
repeats
.zip(offset)
.flat_map(|(r, o)| iter::repeat(o).take(r)),
)
.map(|(idx, ofst)| idx + ofst) .map(|(idx, ofst)| idx + ofst)
.collect(); .collect();
new_param_idxs new_param_idxs
@@ -120,7 +124,7 @@ impl OneHotEncoder {
let (nrows, _) = data.shape(); let (nrows, _) = data.shape();
// col buffer to avoid allocations // col buffer to avoid allocations
let mut col_buf: Vec<T> = repeat_n(T::zero(), nrows).collect(); let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len()); let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
+177 -281
View File
@@ -25,18 +25,14 @@
/// search parameters /// search parameters
pub mod svc; pub mod svc;
pub mod svr; pub mod svr;
// search parameters space // /// search parameters space
pub mod search; // pub mod search;
use core::fmt::Debug; use core::fmt::Debug;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// Only import typetag if not compiling for wasm32 and serde is enabled
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
use typetag;
use crate::error::{Failed, FailedError}; use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, ArrayView1}; use crate::linalg::basic::arrays::{Array1, ArrayView1};
@@ -52,281 +48,197 @@ pub trait Kernel: Debug {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>; fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
} }
/// A enumerator for all the kernels type to support. /// Pre-defined kernel functions
/// This allows kernel selection and parameterization ergonomic, type-safe, and ready for use in parameter structs like SVRParameters.
/// You can construct kernels using the provided variants and builder-style methods.
///
/// # Examples
///
/// ```
/// use smartcore::svm::Kernels;
///
/// let linear = Kernels::linear();
/// let rbf = Kernels::rbf().with_gamma(0.5);
/// let poly = Kernels::polynomial().with_degree(3.0).with_gamma(0.5).with_coef0(1.0);
/// let sigmoid = Kernels::sigmoid().with_gamma(0.2).with_coef0(0.0);
/// ```
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone)]
pub enum Kernels { pub struct Kernels;
/// Linear kernel (default).
///
/// Computes the standard dot product between vectors.
Linear,
/// Radial Basis Function (RBF) kernel.
///
/// Formula: K(x, y) = exp(-gamma * ||x-y||²)
RBF {
/// Controls the width of the Gaussian RBF kernel.
///
/// Larger values of gamma lead to higher bias and lower variance.
/// This parameter is inversely proportional to the radius of influence
/// of samples selected by the model as support vectors.
gamma: Option<f64>,
},
/// Polynomial kernel.
///
/// Formula: K(x, y) = (gamma * <x, y> + coef0)^degree
Polynomial {
/// The degree of the polynomial kernel.
///
/// Integer values are typical (2 = quadratic, 3 = cubic), but any positive real value is valid.
/// Higher degree values create decision boundaries with higher complexity.
degree: Option<f64>,
/// Kernel coefficient for the dot product.
///
/// Controls the influence of higher-degree versus lower-degree terms in the polynomial.
/// If None, a default value will be used.
gamma: Option<f64>,
/// Independent term in the polynomial kernel.
///
/// Controls the influence of higher-degree versus lower-degree terms.
/// If None, a default value of 1.0 will be used.
coef0: Option<f64>,
},
/// Sigmoid kernel.
///
/// Formula: K(x, y) = tanh(gamma * <x, y> + coef0)
Sigmoid {
/// Kernel coefficient for the dot product.
///
/// Controls the scaling of the dot product in the sigmoid function.
/// If None, a default value will be used.
gamma: Option<f64>,
/// Independent term in the sigmoid kernel.
///
/// Acts as a threshold/bias term in the sigmoid function.
/// If None, a default value of 1.0 will be used.
coef0: Option<f64>,
},
}
impl Kernels { impl Kernels {
/// Create a linear kernel. /// Return a default linear
/// pub fn linear() -> LinearKernel {
/// The linear kernel computes the dot product between two vectors: LinearKernel
/// K(x, y) = <x, y>
pub fn linear() -> Self {
Kernels::Linear
} }
/// Return a default RBF
/// Create an RBF kernel with unspecified gamma. pub fn rbf() -> RBFKernel {
/// RBFKernel::default()
/// The RBF kernel is defined as:
/// K(x, y) = exp(-gamma * ||x-y||²)
///
/// You should specify gamma using `with_gamma()` before using this kernel.
pub fn rbf() -> Self {
Kernels::RBF { gamma: None }
} }
/// Return a default polynomial
/// Create a polynomial kernel with default parameters. pub fn polynomial() -> PolynomialKernel {
/// PolynomialKernel::default()
/// The polynomial kernel is defined as: }
/// K(x, y) = (gamma * <x, y> + coef0)^degree /// Return a default sigmoid
/// pub fn sigmoid() -> SigmoidKernel {
/// Default values: SigmoidKernel::default()
/// - gamma: None (must be specified)
/// - degree: None (must be specified)
/// - coef0: 1.0
pub fn polynomial() -> Self {
Kernels::Polynomial {
gamma: None,
degree: None,
coef0: Some(1.0),
} }
} }
/// Create a sigmoid kernel with default parameters. /// Linear Kernel
/// #[allow(clippy::derive_partial_eq_without_eq)]
/// The sigmoid kernel is defined as: #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// K(x, y) = tanh(gamma * <x, y> + coef0) #[derive(Debug, Clone, PartialEq, Eq, Default)]
/// pub struct LinearKernel;
/// Default values:
/// - gamma: None (must be specified) /// Radial basis function (Gaussian) kernel
/// - coef0: 1.0 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// #[derive(Debug, Default, Clone, PartialEq)]
pub fn sigmoid() -> Self { pub struct RBFKernel {
Kernels::Sigmoid { /// kernel coefficient
gamma: None, pub gamma: Option<f64>,
coef0: Some(1.0),
}
} }
/// Set the `gamma` parameter for RBF, polynomial, or sigmoid kernels. #[allow(dead_code)]
/// impl RBFKernel {
/// The gamma parameter has different interpretations depending on the kernel: /// assign gamma parameter to kernel (required)
/// - For RBF: Controls the width of the Gaussian. Larger values mean tighter fit. /// ```rust
/// - For Polynomial: Scaling factor for the dot product. /// use smartcore::svm::RBFKernel;
/// - For Sigmoid: Scaling factor for the dot product. /// let knl = RBFKernel::default().with_gamma(0.7);
///
pub fn with_gamma(self, gamma: f64) -> Self {
match self {
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) },
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial {
gamma: Some(gamma),
degree,
coef0,
},
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid {
gamma: Some(gamma),
coef0,
},
other => other,
}
}
/// Set the `degree` parameter for the polynomial kernel.
///
/// The degree parameter controls the flexibility of the decision boundary.
/// Higher degrees create more complex boundaries but may lead to overfitting.
///
pub fn with_degree(self, degree: f64) -> Self {
match self {
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial {
degree: Some(degree),
gamma,
coef0,
},
other => other,
}
}
/// Set the `coef0` parameter for polynomial or sigmoid kernels.
///
/// The coef0 parameter is the independent term in the kernel function:
/// - For Polynomial: Controls the influence of higher-degree vs. lower-degree terms.
/// - For Sigmoid: Acts as a threshold/bias term.
///
pub fn with_coef0(self, coef0: f64) -> Self {
match self {
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
degree,
gamma,
coef0: Some(coef0),
},
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
gamma,
coef0: Some(coef0),
},
other => other,
}
}
}
/// Implementation of the [`Kernel`] trait for the [`Kernels`] enum in smartcore.
///
/// This method computes the value of the kernel function between two feature vectors `x_i` and `x_j`,
/// according to the variant and parameters of the [`Kernels`] enum. This enables flexible and type-safe
/// selection of kernel functions for SVM and SVR models in smartcore.
///
/// # Supported Kernels
///
/// - [`Kernels::Linear`]: Computes the standard dot product between `x_i` and `x_j`.
/// - [`Kernels::RBF`]: Computes the Radial Basis Function (Gaussian) kernel. Requires `gamma`.
/// - [`Kernels::Polynomial`]: Computes the polynomial kernel. Requires `degree`, `gamma`, and `coef0`.
/// - [`Kernels::Sigmoid`]: Computes the sigmoid kernel. Requires `gamma` and `coef0`.
///
/// # Parameters
///
/// - `x_i`: First input vector (feature vector).
/// - `x_j`: Second input vector (feature vector).
///
/// # Returns
///
/// - `Ok(f64)`: The computed kernel value.
/// - `Err(Failed)`: If any required kernel parameter is missing.
///
/// # Errors
///
/// Returns `Err(Failed)` if a required parameter (such as `gamma`, `degree`, or `coef0`)
/// is `None` for the selected kernel variant.
///
/// # Example
///
/// ``` /// ```
/// use smartcore::svm::Kernels; pub fn with_gamma(mut self, gamma: f64) -> Self {
/// use smartcore::svm::Kernel; self.gamma = Some(gamma);
/// self
/// let x = vec![1.0, 2.0, 3.0]; }
/// let y = vec![4.0, 5.0, 6.0]; }
/// let kernel = Kernels::rbf().with_gamma(0.5);
/// let value = kernel.apply(&x, &y).unwrap(); /// Polynomial kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct PolynomialKernel {
/// degree of the polynomial
pub degree: Option<f64>,
/// kernel coefficient
pub gamma: Option<f64>,
/// independent term in kernel function
pub coef0: Option<f64>,
}
impl Default for PolynomialKernel {
fn default() -> Self {
Self {
gamma: Option::None,
degree: Option::None,
coef0: Some(1f64),
}
}
}
impl PolynomialKernel {
/// set parameters for kernel
/// ```rust
/// use smartcore::svm::PolynomialKernel;
/// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
/// ``` /// ```
/// pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
/// # Notes self.degree = Some(degree);
/// self.gamma = Some(gamma);
/// - This implementation follows smartcore's philosophy: pure Rust, no macros, no unsafe code, self.coef0 = Some(coef0);
/// and an accessible, pythonic API surface for both ML practitioners and Rust beginners. self
/// - All kernel parameters must be set before calling `apply`; missing parameters will result in an error. }
/// /// set gamma parameter for kernel
/// See the [`Kernels`] enum documentation for more details on each kernel type and its parameters. /// ```rust
/// use smartcore::svm::PolynomialKernel;
/// let knl = PolynomialKernel::default().with_gamma(0.7);
/// ```
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.gamma = Some(gamma);
self
}
/// set degree parameter for kernel
/// ```rust
/// use smartcore::svm::PolynomialKernel;
/// let knl = PolynomialKernel::default().with_degree(3.0, 100);
/// ```
pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
self.with_params(degree, 1f64, 1f64 / n_features as f64)
}
}
/// Sigmoid (hyperbolic tangent) kernel
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SigmoidKernel {
/// kernel coefficient
pub gamma: Option<f64>,
/// independent term in kernel function
pub coef0: Option<f64>,
}
impl Default for SigmoidKernel {
fn default() -> Self {
Self {
gamma: Option::None,
coef0: Some(1f64),
}
}
}
impl SigmoidKernel {
/// set parameters for kernel
/// ```rust
/// use smartcore::svm::SigmoidKernel;
/// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
/// ```
pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
self.gamma = Some(gamma);
self.coef0 = Some(coef0);
self
}
/// set gamma parameter for kernel
/// ```rust
/// use smartcore::svm::SigmoidKernel;
/// let knl = SigmoidKernel::default().with_gamma(0.7);
/// ```
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.gamma = Some(gamma);
self
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)] #[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for Kernels { impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> { fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
match self { Ok(x_i.dot(x_j))
Kernels::Linear => Ok(x_i.dot(x_j)), }
Kernels::RBF { gamma } => { }
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set") #[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
})?; impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() {
return Err(Failed::because(
FailedError::ParametersError,
"gamma should be set, use {Kernel}::default().with_gamma(..)",
));
}
let v_diff = x_i.sub(x_j); let v_diff = x_i.sub(x_j);
Ok((-gamma * v_diff.mul(&v_diff).sum()).exp()) Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
return Err(Failed::because(
FailedError::ParametersError, "gamma, coef0, degree should be set,
use {Kernel}::default().with_{parameter}(..)")
);
} }
Kernels::Polynomial {
degree,
gamma,
coef0,
} => {
let degree = degree.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "degree not set")
})?;
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set")
})?;
let coef0 = coef0.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "coef0 not set")
})?;
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok((gamma * dot + coef0).powf(degree)) Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() {
return Err(Failed::because(
FailedError::ParametersError, "gamma, coef0, degree should be set,
use {Kernel}::default().with_{parameter}(..)")
);
} }
Kernels::Sigmoid { gamma, coef0 } => {
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set")
})?;
let coef0 = coef0.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "coef0 not set")
})?;
let dot = x_i.dot(x_j); let dot = x_i.dot(x_j);
Ok((gamma * dot + coef0).tanh()) Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
}
}
} }
} }
@@ -335,18 +247,6 @@ mod tests {
use super::*; use super::*;
use crate::svm::Kernels; use crate::svm::Kernels;
#[test]
fn rbf_kernel() {
let v1 = vec![1., 2., 3.];
let v2 = vec![4., 5., 6.];
let result = Kernels::rbf()
.with_gamma(0.055)
.apply(&v1, &v2)
.unwrap()
.abs();
assert!((0.2265f64 - result) < 1e-4);
}
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
@@ -364,7 +264,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
)] )]
#[test] #[test]
fn test_rbf_kernel() { fn rbf_kernel() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
let v2 = vec![4., 5., 6.]; let v2 = vec![4., 5., 6.];
@@ -387,10 +287,7 @@ mod tests {
let v2 = vec![4., 5., 6.]; let v2 = vec![4., 5., 6.];
let result = Kernels::polynomial() let result = Kernels::polynomial()
.with_gamma(0.5) .with_params(3.0, 0.5, 1.0)
.with_degree(3.0)
.with_coef0(1.0)
//.with_params(3.0, 0.5, 1.0)
.apply(&v1, &v2) .apply(&v1, &v2)
.unwrap() .unwrap()
.abs(); .abs();
@@ -408,8 +305,7 @@ mod tests {
let v2 = vec![4., 5., 6.]; let v2 = vec![4., 5., 6.];
let result = Kernels::sigmoid() let result = Kernels::sigmoid()
.with_gamma(0.01) .with_params(0.01, 0.1)
.with_coef0(0.1)
.apply(&v1, &v2) .apply(&v1, &v2)
.unwrap() .unwrap()
.abs(); .abs();
-2
View File
@@ -1,5 +1,3 @@
//! SVC and Grid Search
/// SVC search parameters /// SVC search parameters
pub mod svc_params; pub mod svc_params;
/// SVC search parameters /// SVC search parameters
+101 -282
View File
@@ -1,293 +1,112 @@
//! # SVR Grid Search Parameters // /// SVR grid search parameters
//! // #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
//! This module provides utilities for defining and iterating over grid search parameter spaces // #[derive(Debug, Clone)]
//! for Support Vector Regression (SVR) models in [smartcore](https://github.com/smartcorelib/smartcore). // pub struct SVRSearchParameters<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
//! // /// Epsilon in the epsilon-SVR model.
//! The main struct, [`SVRSearchParameters`], allows users to specify multiple values for each // pub eps: Vec<T>,
//! SVR hyperparameter (epsilon, regularization parameter C, tolerance, and kernel function). // /// Regularization parameter.
//! The provided iterator yields all possible combinations (the Cartesian product) of these parameters, // pub c: Vec<T>,
//! enabling exhaustive grid search for hyperparameter tuning. // /// Tolerance for stopping eps.
//! // pub tol: Vec<T>,
//! // /// The kernel function.
//! ## Example // pub kernel: Vec<K>,
//! ``` // /// Unused parameter.
//! use smartcore::svm::Kernels; // m: PhantomData<M>,
//! use smartcore::svm::search::svr_params::SVRSearchParameters; // }
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//!
//! let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
//! eps: vec![0.1, 0.2],
//! c: vec![1.0, 10.0],
//! tol: vec![1e-3],
//! kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
//! m: std::marker::PhantomData,
//! };
//!
//! // for param_set in params.into_iter() {
//! // Use param_set (of type svr::SVRParameters) to fit and evaluate your SVR model.
//! // }
//! ```
//!
//!
//! ## Note
//! This module is intended for use with smartcore version 0.4 or later. The API is not compatible with older versions[1].
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::Array2; // /// SVR grid search iterator
use crate::numbers::basenum::Number; // pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
use crate::numbers::floatnum::FloatNumber; // svr_search_parameters: SVRSearchParameters<T, M, K>,
use crate::numbers::realnum::RealNumber; // current_eps: usize,
use crate::svm::{svr, Kernels}; // current_c: usize,
use std::marker::PhantomData; // current_tol: usize,
// current_kernel: usize,
// }
/// ## SVR grid search parameters // impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
/// A struct representing a grid of hyperparameters for SVR grid search in smartcore. // for SVRSearchParameters<T, M, K>
/// // {
/// Each field is a vector of possible values for the corresponding SVR hyperparameter. // type Item = SVRParameters<T, M, K>;
/// The [`IntoIterator`] implementation yields every possible combination of these parameters // type IntoIter = SVRSearchParametersIterator<T, M, K>;
/// as an `svr::SVRParameters` struct, suitable for use in model selection routines.
///
/// # Type Parameters
/// - `T`: Numeric type for parameters (e.g., `f64`)
/// - `M`: Matrix type implementing [`Array2<T>`]
///
/// # Fields
/// - `eps`: Vector of epsilon values for the epsilon-insensitive loss in SVR.
/// - `c`: Vector of regularization parameters (C) for SVR.
/// - `tol`: Vector of tolerance values for the stopping criterion.
/// - `kernel`: Vector of kernel function variants (see [`Kernels`]).
/// - `m`: Phantom data for the matrix type parameter.
///
/// # Example
/// ```
/// use smartcore::svm::Kernels;
/// use smartcore::svm::search::svr_params::SVRSearchParameters;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
///
/// let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
/// eps: vec![0.1, 0.2],
/// c: vec![1.0, 10.0],
/// tol: vec![1e-3],
/// kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
/// m: std::marker::PhantomData,
/// };
/// ```
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVRSearchParameters<T: Number + RealNumber, M: Array2<T>> {
/// Epsilon in the epsilon-SVR model.
pub eps: Vec<T>,
/// Regularization parameter.
pub c: Vec<T>,
/// Tolerance for stopping eps.
pub tol: Vec<T>,
/// The kernel function.
pub kernel: Vec<Kernels>,
/// Unused parameter.
pub m: PhantomData<M>,
}
/// SVR grid search iterator // fn into_iter(self) -> Self::IntoIter {
pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Array2<T>> { // SVRSearchParametersIterator {
svr_search_parameters: SVRSearchParameters<T, M>, // svr_search_parameters: self,
current_eps: usize, // current_eps: 0,
current_c: usize, // current_c: 0,
current_tol: usize, // current_tol: 0,
current_kernel: usize, // current_kernel: 0,
} // }
// }
// }
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> IntoIterator // impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
for SVRSearchParameters<T, M> // for SVRSearchParametersIterator<T, M, K>
{ // {
type Item = svr::SVRParameters<T>; // type Item = SVRParameters<T, M, K>;
type IntoIter = SVRSearchParametersIterator<T, M>;
fn into_iter(self) -> Self::IntoIter { // fn next(&mut self) -> Option<Self::Item> {
SVRSearchParametersIterator { // if self.current_eps == self.svr_search_parameters.eps.len()
svr_search_parameters: self, // && self.current_c == self.svr_search_parameters.c.len()
current_eps: 0, // && self.current_tol == self.svr_search_parameters.tol.len()
current_c: 0, // && self.current_kernel == self.svr_search_parameters.kernel.len()
current_tol: 0, // {
current_kernel: 0, // return None;
} // }
}
}
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Iterator // let next = SVRParameters::<T, M, K> {
for SVRSearchParametersIterator<T, M> // eps: self.svr_search_parameters.eps[self.current_eps],
{ // c: self.svr_search_parameters.c[self.current_c],
type Item = svr::SVRParameters<T>; // tol: self.svr_search_parameters.tol[self.current_tol],
// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
// m: PhantomData,
// };
fn next(&mut self) -> Option<Self::Item> { // if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
if self.current_eps == self.svr_search_parameters.eps.len() // self.current_eps += 1;
&& self.current_c == self.svr_search_parameters.c.len() // } else if self.current_c + 1 < self.svr_search_parameters.c.len() {
&& self.current_tol == self.svr_search_parameters.tol.len() // self.current_eps = 0;
&& self.current_kernel == self.svr_search_parameters.kernel.len() // self.current_c += 1;
{ // } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
return None; // self.current_eps = 0;
} // self.current_c = 0;
// self.current_tol += 1;
// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
// self.current_eps = 0;
// self.current_c = 0;
// self.current_tol = 0;
// self.current_kernel += 1;
// } else {
// self.current_eps += 1;
// self.current_c += 1;
// self.current_tol += 1;
// self.current_kernel += 1;
// }
let next = svr::SVRParameters::<T> { // Some(next)
eps: self.svr_search_parameters.eps[self.current_eps], // }
c: self.svr_search_parameters.c[self.current_c], // }
tol: self.svr_search_parameters.tol[self.current_tol],
kernel: Some(self.svr_search_parameters.kernel[self.current_kernel].clone()),
};
if self.current_eps + 1 < self.svr_search_parameters.eps.len() { // impl<T: Number + RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
self.current_eps += 1; // fn default() -> Self {
} else if self.current_c + 1 < self.svr_search_parameters.c.len() { // let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
self.current_eps = 0;
self.current_c += 1;
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
self.current_eps = 0;
self.current_c = 0;
self.current_tol += 1;
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
self.current_eps = 0;
self.current_c = 0;
self.current_tol = 0;
self.current_kernel += 1;
} else {
self.current_eps += 1;
self.current_c += 1;
self.current_tol += 1;
self.current_kernel += 1;
}
Some(next) // SVRSearchParameters {
} // eps: vec![default_params.eps],
} // c: vec![default_params.c],
// tol: vec![default_params.tol],
// kernel: vec![default_params.kernel],
// m: PhantomData,
// }
// }
// }
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Default for SVRSearchParameters<T, M> { // #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
fn default() -> Self { // #[derive(Debug)]
let default_params: svr::SVRParameters<T> = svr::SVRParameters::default(); // #[cfg_attr(
// feature = "serde",
SVRSearchParameters { // serde(bound(
eps: vec![default_params.eps], // serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
c: vec![default_params.c], // deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
tol: vec![default_params.tol], // ))
kernel: vec![default_params.kernel.unwrap_or_else(Kernels::linear)], // )]
m: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::svm::Kernels;
type T = f64;
type M = DenseMatrix<T>;
#[test]
fn test_default_parameters() {
let params = SVRSearchParameters::<T, M>::default();
assert_eq!(params.eps.len(), 1);
assert_eq!(params.c.len(), 1);
assert_eq!(params.tol.len(), 1);
assert_eq!(params.kernel.len(), 1);
// Check that the default kernel is linear
assert_eq!(params.kernel[0], Kernels::linear());
}
#[test]
fn test_single_grid_iteration() {
let params = SVRSearchParameters::<T, M> {
eps: vec![0.1],
c: vec![1.0],
tol: vec![1e-3],
kernel: vec![Kernels::rbf().with_gamma(0.5)],
m: PhantomData,
};
let mut iter = params.into_iter();
let param = iter.next().unwrap();
assert_eq!(param.eps, 0.1);
assert_eq!(param.c, 1.0);
assert_eq!(param.tol, 1e-3);
assert_eq!(param.kernel, Some(Kernels::rbf().with_gamma(0.5)));
assert!(iter.next().is_none());
}
#[test]
fn test_cartesian_grid_iteration() {
let params = SVRSearchParameters::<T, M> {
eps: vec![0.1, 0.2],
c: vec![1.0, 2.0],
tol: vec![1e-3],
kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
m: PhantomData,
};
let expected_count =
params.eps.len() * params.c.len() * params.tol.len() * params.kernel.len();
let results: Vec<_> = params.into_iter().collect();
assert_eq!(results.len(), expected_count);
// Check that all parameter combinations are present
let mut seen = vec![];
for p in &results {
seen.push((p.eps, p.c, p.tol, p.kernel.clone().unwrap()));
}
for &eps in &[0.1, 0.2] {
for &c in &[1.0, 2.0] {
for &tol in &[1e-3] {
for kernel in &[Kernels::linear(), Kernels::rbf().with_gamma(0.5)] {
assert!(seen.contains(&(eps, c, tol, kernel.clone())));
}
}
}
}
}
#[test]
fn test_empty_grid() {
let params = SVRSearchParameters::<T, M> {
eps: vec![],
c: vec![],
tol: vec![],
kernel: vec![],
m: PhantomData,
};
let mut iter = params.into_iter();
assert!(iter.next().is_none());
}
#[test]
fn test_kernel_enum_variants() {
let lin = Kernels::linear();
let rbf = Kernels::rbf().with_gamma(0.2);
let poly = Kernels::polynomial()
.with_degree(2.0)
.with_gamma(1.0)
.with_coef0(0.5);
let sig = Kernels::sigmoid().with_gamma(0.3).with_coef0(0.1);
assert_eq!(lin, Kernels::Linear);
match rbf {
Kernels::RBF { gamma } => assert_eq!(gamma, Some(0.2)),
_ => panic!("Not RBF"),
}
match poly {
Kernels::Polynomial {
degree,
gamma,
coef0,
} => {
assert_eq!(degree, Some(2.0));
assert_eq!(gamma, Some(1.0));
assert_eq!(coef0, Some(0.5));
}
_ => panic!("Not Polynomial"),
}
match sig {
Kernels::Sigmoid { gamma, coef0 } => {
assert_eq!(gamma, Some(0.3));
assert_eq!(coef0, Some(0.1));
}
_ => panic!("Not Sigmoid"),
}
}
}
+61 -375
View File
@@ -58,11 +58,10 @@
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
//! //!
//! let knl = Kernels::linear(); //! let knl = Kernels::linear();
//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl); //! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
//! let svc = SVC::fit(&x, &y, parameters).unwrap(); //! let svc = SVC::fit(&x, &y, params).unwrap();
//! //!
//! let y_hat = svc.predict(&x).unwrap(); //! let y_hat = svc.predict(&x).unwrap();
//!
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -85,194 +84,12 @@ use serde::{Deserialize, Serialize};
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow}; use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
use crate::error::{Failed, FailedError}; use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray}; use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber; use crate::numbers::realnum::RealNumber;
use crate::rand_custom::get_rng_impl; use crate::rand_custom::get_rng_impl;
use crate::svm::Kernel; use crate::svm::Kernel;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Configuration for a multi-class Support Vector Machine (SVM) classifier.
/// This struct holds the indices of the data points relevant to a specific binary
/// classification problem within a multi-class context, and the two classes
/// being discriminated.
struct MultiClassConfig<TY: Number + Ord> {
/// The indices of the data points from the original dataset that belong to the two `classes`.
indices: Vec<usize>,
/// A tuple representing the two classes that this configuration is designed to distinguish.
classes: (TY, TY),
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>>
for MultiClassSVC<'a, TX, TY, X, Y>
{
/// Creates a new, empty `MultiClassSVC` instance.
fn new() -> Self {
Self {
classifiers: Option::None,
}
}
/// Fits the `MultiClassSVC` model to the provided data and parameters.
///
/// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array).
/// * `y` - A reference to the target labels (1D array).
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
///
/// # Returns
/// A `Result` indicating success (`Self`) or failure (`Failed`).
fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<Self, Failed> {
MultiClassSVC::fit(x, y, parameters)
}
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y>
{
/// Predicts the class labels for new data points.
///
/// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) for which to make predictions.
///
/// # Returns
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
Ok(self.predict(x).unwrap())
}
}
/// A multi-class Support Vector Machine (SVM) classifier.
///
/// This struct implements a multi-class SVM using the "one-vs-one" strategy,
/// where a separate binary SVC classifier is trained for every pair of classes.
///
/// # Type Parameters
/// * `'a` - Lifetime parameter for borrowed data.
/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
/// * `X` - The type representing the 2D array of input features (e.g., a matrix).
/// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
pub struct MultiClassSVC<
'a,
TX: Number + RealNumber,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
/// An optional vector of binary `SVC` classifiers.
classifiers: Option<Vec<SVC<'a, TX, TY, X, Y>>>,
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
MultiClassSVC<'a, TX, TY, X, Y>
{
/// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
///
/// This method identifies all unique classes in the target labels `y` and then
/// trains a binary `SVC` for every unique pair of classes. For each pair, it
/// extracts the relevant data points and their labels, and then trains a
/// specialized `SVC` for that binary classification task.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array).
/// * `y` - A reference to the target labels (1D array).
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
///
///
/// # Returns
/// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
pub fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<MultiClassSVC<'a, TX, TY, X, Y>, Failed> {
let unique_classes = y.unique();
let mut classifiers = Vec::new();
// Iterate through all unique pairs of classes (one-vs-one strategy)
for i in 0..unique_classes.len() {
for j in i..unique_classes.len() {
if i == j {
continue;
}
let class0 = unique_classes[j];
let class1 = unique_classes[i];
let mut indices = Vec::new();
// Collect indices of data points belonging to the current pair of classes
for (index, v) in y.iterator(0).enumerate() {
if *v == class0 || *v == class1 {
indices.push(index)
}
}
let classes = (class0, class1);
let multiclass_config = MultiClassConfig { classes, indices };
// Fit a binary SVC for the current pair of classes
let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap();
classifiers.push(svc);
}
}
Ok(Self {
classifiers: Some(classifiers),
})
}
/// Predicts the class labels for new data points using the trained multi-class SVM.
///
/// This method uses a "voting" scheme (majority vote) among all the binary
/// classifiers to determine the final prediction for each data point.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) for which to make predictions.
///
/// # Returns
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
///
pub fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
// Initialize a HashMap for each data point to store votes for each class
let mut polls = vec![HashMap::new(); x.shape().0];
// Retrieve the trained binary classifiers
let classifiers = self.classifiers.as_ref().unwrap();
// Iterate through each binary classifier
for i in 0..classifiers.len() {
let svc = classifiers.get(i).unwrap();
let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier
// For each prediction from the current binary classifier
for (j, prediction) in predictions.iter().enumerate() {
let prediction = prediction.to_i32().unwrap();
let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point
// Increment the vote for the predicted class
if let Some(count) = poll.get_mut(&prediction) {
*count += 1
} else {
poll.insert(prediction, 1);
}
}
}
// Determine the final prediction for each data point based on majority vote
Ok(polls
.iter()
.map(|v| {
// Find the class with the maximum votes for each data point
TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap()
})
.collect())
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)] #[derive(Debug)]
/// SVC Parameters /// SVC Parameters
@@ -306,7 +123,7 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
)] )]
/// Support Vector Classifier /// Support Vector Classifier
pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> { pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
classes: Option<(TY, TY)>, classes: Option<Vec<TY>>,
instances: Option<Vec<Vec<TX>>>, instances: Option<Vec<Vec<TX>>>,
#[cfg_attr(feature = "serde", serde(skip))] #[cfg_attr(feature = "serde", serde(skip))]
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>, parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
@@ -335,9 +152,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> { struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
indices: Option<Vec<usize>>,
parameters: &'a SVCParameters<TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
classes: &'a (TY, TY),
svmin: usize, svmin: usize,
svmax: usize, svmax: usize,
gmin: TX, gmin: TX,
@@ -365,12 +180,12 @@ impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
self.tol = tol; self.tol = tol;
self self
} }
/// The kernel function. /// The kernel function.
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self { pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(Box::new(kernel)); self.kernel = Some(Box::new(kernel));
self self
} }
/// Seed for the pseudo random number generator. /// Seed for the pseudo random number generator.
pub fn with_seed(mut self, seed: Option<u64>) -> Self { pub fn with_seed(mut self, seed: Option<u64>) -> Self {
self.seed = seed; self.seed = seed;
@@ -426,98 +241,17 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a> impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
SVC<'a, TX, TY, X, Y> SVC<'a, TX, TY, X, Y>
{ {
/// Fits a binary Support Vector Classifier (SVC) to the provided data. /// Fits SVC to your data.
/// /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// This is the primary `fit` method for a standalone binary SVC. It expects /// * `y` - class labels
/// the target labels `y` to contain exactly two unique classes. If more or /// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
/// fewer than two classes are found, it returns an error. It then extracts
/// these two classes and proceeds to optimize and fit the SVC model.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) of the training data.
/// * `y` - A reference to the target labels (1D array) of the training data. `y` must contain exactly two unique class labels.
/// * `parameters` - A reference to the `SVCParameters` controlling the training process.
///
/// # Returns
/// A `Result` which is:
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
/// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails.
pub fn fit( pub fn fit(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> { ) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let classes = y.unique(); let (n, _) = x.shape();
// Validate that there are exactly two unique classes in the target labels.
if classes.len() != 2 {
return Err(Failed::fit(&format!(
"Incorrect number of classes: {}. A binary SVC requires exactly two classes.",
classes.len()
)));
}
let classes = (classes[0], classes[1]);
let svc = Self::optimize_and_fit(x, y, parameters, classes, None);
svc
}
/// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios.
///
/// This function is intended to be called by a multi-class strategy (e.g., one-vs-one)
/// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies
/// the two classes this SVC should discriminate and the subset of data indices
/// relevant to these classes. It then delegates the actual optimization and fitting
/// to `optimize_and_fit`.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) of the training data.
/// * `y` - A reference to the target labels (1D array) of the training data.
/// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance).
/// * `multiclass_config` - A `MultiClassConfig` struct containing:
/// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish.
/// - `indices`: A `Vec<usize>` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.`
///
/// # Returns
/// A `Result` which is:
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
/// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters).
fn multiclass_fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
multiclass_config: MultiClassConfig<TY>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let classes = multiclass_config.classes;
let indices = multiclass_config.indices;
let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices));
svc
}
/// Internal function to optimize and fit the Support Vector Classifier.
///
/// This is the core logic for training a binary SVC. It performs several checks
/// (e.g., kernel presence, data shape consistency) and then initializes an
/// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`).
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) of the training data.
/// * `y` - A reference to the target labels (1D array) of the training data.
/// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration.
/// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate.
/// * `indices` - An `Option<Vec<usize>>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered.
/// # Returns
/// A `Result` which is:
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias).
/// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails.
fn optimize_and_fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
classes: (TY, TY),
indices: Option<Vec<usize>>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let (n_samples, _) = x.shape();
// Validate that a kernel has been defined in the parameters.
if parameters.kernel.is_none() { if parameters.kernel.is_none() {
return Err(Failed::because( return Err(Failed::because(
FailedError::ParametersError, FailedError::ParametersError,
@@ -525,39 +259,55 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
)); ));
} }
// Validate that the number of samples in X matches the number of labels in Y. if n != y.shape() {
if n_samples != y.shape() {
return Err(Failed::fit( return Err(Failed::fit(
"Number of rows of X doesn't match number of rows of Y", "Number of rows of X doesn\'t match number of rows of Y",
)); ));
} }
let optimizer: Optimizer<'_, TX, TY, X, Y> = let classes = y.unique();
Optimizer::new(x, y, indices, parameters, &classes);
if classes.len() != 2 {
return Err(Failed::fit(&format!(
"Incorrect number of classes: {}",
classes.len()
)));
}
// Make sure class labels are either 1 or -1
for e in y.iterator(0) {
let y_v = e.to_i32().unwrap();
if y_v != -1 && y_v != 1 {
return Err(Failed::because(
FailedError::ParametersError,
"Class labels must be 1 or -1",
));
}
}
let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters);
// Perform the optimization to find the support vectors, weight vector, and bias.
// This is where the core SVM algorithm (e.g., SMO) would run.
let (support_vectors, weight, b) = optimizer.optimize(); let (support_vectors, weight, b) = optimizer.optimize();
// Construct and return the fitted SVC model.
Ok(SVC::<'a> { Ok(SVC::<'a> {
classes: Some(classes), // Store the two classes the SVC was trained on. classes: Some(classes),
instances: Some(support_vectors), // Store the data points that are support vectors. instances: Some(support_vectors),
parameters: Some(parameters), // Reference to the parameters used for fitting. parameters: Some(parameters),
w: Some(weight), // The learned weight vector (for linear kernels). w: Some(weight),
b: Some(b), // The learned bias term. b: Some(b),
phantomdata: PhantomData, // Placeholder for type parameters not directly stored. phantomdata: PhantomData,
}) })
} }
/// Predicts estimated class labels from `x` /// Predicts estimated class labels from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> { pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
let mut y_hat: Vec<TX> = self.decision_function(x)?; let mut y_hat: Vec<TX> = self.decision_function(x)?;
for i in 0..y_hat.len() { for i in 0..y_hat.len() {
let cls_idx = match *y_hat.get(i) > TX::zero() { let cls_idx = match *y_hat.get(i).unwrap() > TX::zero() {
false => TX::from(self.classes.as_ref().unwrap().0).unwrap(), false => TX::from(self.classes.as_ref().unwrap()[0]).unwrap(),
true => TX::from(self.classes.as_ref().unwrap().1).unwrap(), true => TX::from(self.classes.as_ref().unwrap()[1]).unwrap(),
}; };
y_hat.set(i, cls_idx); y_hat.set(i, cls_idx);
@@ -695,18 +445,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn new( fn new(
x: &'a X, x: &'a X,
y: &'a Y, y: &'a Y,
indices: Option<Vec<usize>>,
parameters: &'a SVCParameters<TX, TY, X, Y>, parameters: &'a SVCParameters<TX, TY, X, Y>,
classes: &'a (TY, TY),
) -> Optimizer<'a, TX, TY, X, Y> { ) -> Optimizer<'a, TX, TY, X, Y> {
let (n, _) = x.shape(); let (n, _) = x.shape();
Optimizer { Optimizer {
x, x,
y, y,
indices,
parameters, parameters,
classes,
svmin: 0, svmin: 0,
svmax: 0, svmax: 0,
gmin: <TX as Bounded>::max_value(), gmin: <TX as Bounded>::max_value(),
@@ -732,12 +478,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
for i in self.permutate(n) { for i in self.permutate(n) {
x.clear(); x.clear();
x.extend(self.x.get_row(i).iterator(0).take(n).copied()); x.extend(self.x.get_row(i).iterator(0).take(n).copied());
let y = if *self.y.get(i) == self.classes.1 { self.process(i, &x, *self.y.get(i), &mut cache);
1
} else {
-1
} as f64;
self.process(i, &x, y, &mut cache);
loop { loop {
self.reprocess(tol, &mut cache); self.reprocess(tol, &mut cache);
self.find_min_max_gradient(); self.find_min_max_gradient();
@@ -773,16 +514,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
for i in self.permutate(n) { for i in self.permutate(n) {
x.clear(); x.clear();
x.extend(self.x.get_row(i).iterator(0).take(n).copied()); x.extend(self.x.get_row(i).iterator(0).take(n).copied());
let y = if *self.y.get(i) == self.classes.1 { if *self.y.get(i) == TY::one() && cp < few {
1 if self.process(i, &x, *self.y.get(i), cache) {
} else {
-1
} as f64;
if y == 1.0 && cp < few {
if self.process(i, &x, y, cache) {
cp += 1; cp += 1;
} }
} else if y == -1.0 && cn < few && self.process(i, &x, y, cache) { } else if *self.y.get(i) == TY::from(-1).unwrap()
&& cn < few
&& self.process(i, &x, *self.y.get(i), cache)
{
cn += 1; cn += 1;
} }
@@ -792,14 +531,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
} }
} }
fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache<TX, TY, X, Y>) -> bool { fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
for j in 0..self.sv.len() { for j in 0..self.sv.len() {
if self.sv[j].index == i { if self.sv[j].index == i {
return true; return true;
} }
} }
let mut g = y; let mut g: f64 = y.to_f64().unwrap();
let mut cache_values: Vec<((usize, usize), TX)> = Vec::new(); let mut cache_values: Vec<((usize, usize), TX)> = Vec::new();
@@ -820,8 +559,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
self.find_min_max_gradient(); self.find_min_max_gradient();
if self.gmin < self.gmax if self.gmin < self.gmax
&& ((y > 0.0 && g < self.gmin.to_f64().unwrap()) && ((y > TY::zero() && g < self.gmin.to_f64().unwrap())
|| (y < 0.0 && g > self.gmax.to_f64().unwrap())) || (y < TY::zero() && g > self.gmax.to_f64().unwrap()))
{ {
return false; return false;
} }
@@ -851,7 +590,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
), ),
); );
if y > 0.0 { if y > TY::zero() {
self.smo(None, Some(0), TX::zero(), cache); self.smo(None, Some(0), TX::zero(), cache);
} else { } else {
self.smo(Some(0), None, TX::zero(), cache); self.smo(Some(0), None, TX::zero(), cache);
@@ -908,6 +647,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
let gmin = self.gmin; let gmin = self.gmin;
let mut idxs_to_drop: HashSet<usize> = HashSet::new(); let mut idxs_to_drop: HashSet<usize> = HashSet::new();
self.sv.retain(|v| { self.sv.retain(|v| {
if v.alpha == 0f64 if v.alpha == 0f64
&& ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap()) && ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap())
@@ -926,11 +666,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
fn permutate(&self, n: usize) -> Vec<usize> { fn permutate(&self, n: usize) -> Vec<usize> {
let mut rng = get_rng_impl(self.parameters.seed); let mut rng = get_rng_impl(self.parameters.seed);
let mut range = if let Some(indices) = self.indices.clone() { let mut range: Vec<usize> = (0..n).collect();
indices
} else {
(0..n).collect::<Vec<usize>>()
};
range.shuffle(&mut rng); range.shuffle(&mut rng);
range range
} }
@@ -1229,12 +965,12 @@ mod tests {
]; ];
let knl = Kernels::linear(); let knl = Kernels::linear();
let parameters = SVCParameters::default() let params = SVCParameters::default()
.with_c(200.0) .with_c(200.0)
.with_kernel(knl) .with_kernel(knl)
.with_seed(Some(100)); .with_seed(Some(100));
let y_hat = SVC::fit(&x, &y, &parameters) let y_hat = SVC::fit(&x, &y, &params)
.and_then(|lr| lr.predict(&x)) .and_then(|lr| lr.predict(&x))
.unwrap(); .unwrap();
let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
@@ -1334,56 +1070,6 @@ mod tests {
assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9"); assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9");
} }
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_multiclass_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2];
let knl = Kernels::linear();
let parameters = SVCParameters::default()
.with_c(200.0)
.with_kernel(knl)
.with_seed(Some(100));
let y_hat = MultiClassSVC::fit(&x, &y, &parameters)
.and_then(|lr| lr.predict(&x))
.unwrap();
let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
assert!(
acc >= 0.9,
"Multiclass accuracy ({acc}) is not larger or equal to 0.9"
);
}
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
@@ -1420,8 +1106,8 @@ mod tests {
]; ];
let knl = Kernels::linear(); let knl = Kernels::linear();
let parameters = SVCParameters::default().with_kernel(knl); let params = SVCParameters::default().with_kernel(knl);
let svc = SVC::fit(&x, &y, &parameters).unwrap(); let svc = SVC::fit(&x, &y, &params).unwrap();
// serialization // serialization
let deserialized_svc: SVC<'_, f64, i32, _, _> = let deserialized_svc: SVC<'_, f64, i32, _, _> =
+20 -21
View File
@@ -51,9 +51,9 @@
//! //!
//! let knl = Kernels::linear(); //! let knl = Kernels::linear();
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl); //! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
//! let svr = SVR::fit(&x, &y, params).unwrap(); //! // let svr = SVR::fit(&x, &y, params).unwrap();
//! //!
//! let y_hat = svr.predict(&x).unwrap(); //! // let y_hat = svr.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -80,12 +80,11 @@ use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2, MutArray}; use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber; use crate::numbers::floatnum::FloatNumber;
use crate::svm::Kernel;
use crate::svm::{Kernel, Kernels};
/// SVR Parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)] #[derive(Debug)]
/// SVR Parameters
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> { pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
/// Epsilon in the epsilon-SVR model. /// Epsilon in the epsilon-SVR model.
pub eps: T, pub eps: T,
@@ -98,7 +97,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
all(feature = "serde", target_arch = "wasm32"), all(feature = "serde", target_arch = "wasm32"),
serde(skip_serializing, skip_deserializing) serde(skip_serializing, skip_deserializing)
)] )]
pub kernel: Option<Kernels>, pub kernel: Option<Box<dyn Kernel>>,
} }
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -161,8 +160,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
self self
} }
/// The kernel function. /// The kernel function.
pub fn with_kernel(mut self, kernel: Kernels) -> Self { pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
self.kernel = Some(kernel); self.kernel = Some(Box::new(kernel));
self self
} }
} }
@@ -598,25 +597,25 @@ mod tests {
use super::*; use super::*;
use crate::linalg::basic::matrix::DenseMatrix; use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_squared_error; use crate::metrics::mean_squared_error;
use crate::svm::search::svr_params::SVRSearchParameters;
use crate::svm::Kernels; use crate::svm::Kernels;
#[test] // #[test]
fn search_parameters() { // fn search_parameters() {
let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters { // let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
eps: vec![0., 1.], // SVRSearchParameters {
kernel: vec![Kernels::linear()], // eps: vec![0., 1.],
..Default::default() // kernel: vec![LinearKernel {}],
}; // ..Default::default()
let mut iter = parameters.into_iter(); // };
let next = iter.next().unwrap(); // let mut iter = parameters.into_iter();
assert_eq!(next.eps, 0.); // let next = iter.next().unwrap();
// assert_eq!(next.eps, 0.);
// assert_eq!(next.kernel, LinearKernel {}); // assert_eq!(next.kernel, LinearKernel {});
// let next = iter.next().unwrap(); // let next = iter.next().unwrap();
// assert_eq!(next.eps, 1.); // assert_eq!(next.eps, 1.);
// assert_eq!(next.kernel, LinearKernel {}); // assert_eq!(next.kernel, LinearKernel {});
// assert!(iter.next().is_none()); // assert!(iter.next().is_none());
} // }
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -649,7 +648,7 @@ mod tests {
114.2, 115.7, 116.9, 114.2, 115.7, 116.9,
]; ];
let knl: Kernels = Kernels::linear(); let knl = Kernels::linear();
let y_hat = SVR::fit( let y_hat = SVR::fit(
&x, &x,
&y, &y,