76 Commits

Author SHA1 Message Date
morenol
62de25b2ae Handle kernel serialization (#232)
* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 11:29:56 -05:00
morenol
7d87451333 Fixes for release (#237)
* Fixes for release
* add new test
* Remove change applied in development branch
* Only add dependency for wasm32
* Update ci.yml

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
265fd558e7 make work cargo build --target wasm32-unknown-unknown 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
e25e2aea2b update CHANGELOG 2022-11-08 11:29:56 -05:00
Lorenzo
2f6dd1325e update comment 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
b0dece9476 use getrandom/js 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
c507d976be Update CHANGELOG 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
fa54d5ee86 Remove unused tests flags 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
459d558d48 minor fixes to doc 2022-11-08 11:29:56 -05:00
Lorenzo
1b7dda30a2 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
c1bd1df5f6 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
cf751f05aa minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
63ed89aadd minor fix 2022-11-08 11:29:56 -05:00
Lorenzo
890e9d644c minor fix 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
af0a740394 Fix std_rand feature 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
616e38c282 cleanup 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
a449fdd4ea fmt 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
669f87f812 Use getrandom as default (for no-std feature) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
6d529b34d2 Add static analyzer to doc 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
3ec9e4f0db Exclude datasets test for wasm/wasi 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
527477dea7 minor fixes 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
5b517c5048 minor fix 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
2df0795be9 Release 0.3 2022-11-08 11:29:56 -05:00
Lorenzo
0dc97a4e9b Create DEVELOPERS.md 2022-11-08 11:29:56 -05:00
Lorenzo
6c0fd37222 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
d8d0fb6903 Update README.md 2022-11-08 11:29:56 -05:00
morenol
8d07efd921 Use Box in SVM and remove lifetimes (#228)
* Do not change external API
Authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
morenol
ba27dd2a55 Fix CI (#227)
* Update ci.yml
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
ed9769f651 Implement CSV reader with new traits (#209) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
b427e5d8b1 Improve options conditionals 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
fabe362755 Implement Display for NaiveBayes 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
ee6b6a53d6 cargo clippy 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
19f3a2fcc0 Fix signature of metrics tests 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
e09c4ba724 Add kernels' parameters to public interface 2022-11-08 11:29:56 -05:00
Lorenzo
6624732a65 Fix svr tests (#222) 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
1cbde3ba22 Refactor modules structure in src/svm 2022-11-08 11:29:56 -05:00
Lorenzo (Mec-iS)
551a6e34a5 clean up svm 2022-11-08 11:29:56 -05:00
Lorenzo
c45bab491a Support Wasi as target (#216)
* Improve features
* Add wasm32-wasi as a target
* Update .github/workflows/ci.yml
Co-authored-by: morenol <22335041+morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
7f35dc54e4 Disambiguate distances. Implement Fastpair. (#220) 2022-11-08 11:29:56 -05:00
morenol
8f1a7dfd79 build: fix compilation without default features (#218)
* build: fix compilation with optional features
* Remove unused config from Cargo.toml
* Fix cache keys
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
712c478af6 Improve features (#215) 2022-11-08 11:29:56 -05:00
Lorenzo
4d36b7f34f Fix metrics::auc (#212)
* Fix metrics::auc
2022-11-08 11:29:56 -05:00
Lorenzo
a16927aa16 Port ensemble. Add Display to naive_bayes (#208) 2022-11-08 11:29:56 -05:00
Lorenzo
d91f4f7ce4 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
a7fa0585eb Merge potential next release v0.4 (#187) Breaking Changes
* First draft of the new n-dimensional arrays + NB use case
* Improves default implementation of multiple Array methods
* Refactors tree methods
* Adds matrix decomposition routines
* Adds matrix decomposition methods to ndarray and nalgebra bindings
* Refactoring + linear regression now uses array2
* Ridge & Linear regression
* LBFGS optimizer & logistic regression
* LBFGS optimizer & logistic regression
* Changes linear methods, metrics and model selection methods to new n-dimensional arrays
* Switches KNN and clustering algorithms to new n-d array layer
* Refactors distance metrics
* Optimizes knn and clustering methods
* Refactors metrics module
* Switches decomposition methods to n-dimensional arrays
* Linalg refactoring - cleanup rng merge (#172)
* Remove legacy DenseMatrix and BaseMatrix implementation. Port the new Number, FloatNumber and Array implementation into module structure.
* Exclude AUC metrics. Needs reimplementation
* Improve developers walkthrough

New traits system in place at `src/numbers` and `src/linalg`
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Provide SupervisedEstimator with a constructor to avoid explicit dynamical box allocation in 'cross_validate' and 'cross_validate_predict' as required by the use of 'dyn' as per Rust 2021
* Implement getters to use as_ref() in src/neighbors
* Implement getters to use as_ref() in src/naive_bayes
* Implement getters to use as_ref() in src/linear
* Add Clone to src/naive_bayes
* Change signature for cross_validate and other model_selection functions to abide to use of dyn in Rust 2021
* Implement ndarray-bindings. Remove FloatNumber from implementations
* Drop nalgebra-bindings support (as decided in conf-call to go for ndarray)
* Remove benches. Benches will have their own repo at smartcore-benches
* Implement SVC
* Implement SVC serialization. Move search parameters in dedicated module
* Implement SVR. Definitely too slow
* Fix compilation issues for wasm (#202)

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
* Fix tests (#203)

* Port linalg/traits/stats.rs
* Improve methods naming
* Improve Display for DenseMatrix

Co-authored-by: Montana Low <montanalow@users.noreply.github.com>
Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
2022-11-08 11:29:56 -05:00
RJ Nowling
a32eb66a6a Dataset doc cleanup (#205)
* Update iris.rs

* Update mod.rs

* Update digits.rs
2022-11-08 11:29:56 -05:00
Lorenzo
f605f6e075 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
3b1aaaadf7 Update README.md 2022-11-08 11:29:56 -05:00
Lorenzo
d015b12402 Update CONTRIBUTING.md 2022-11-08 11:29:56 -05:00
morenol
d5200074c2 fix: fix issue with iterator for svc search (#182) 2022-11-08 11:29:56 -05:00
morenol
473cdfc44d refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
morenol
ad2e6c2900 feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Lorenzo
9ea3133c27 Update CONTRIBUTING.md 2022-11-08 11:29:56 -05:00
Lorenzo
e4c47c7540 Add contribution guidelines (#178) 2022-11-08 11:29:56 -05:00
Montana Low
f4fd4d2239 make default params available to serde (#167)
* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
2022-11-08 11:29:56 -05:00
Montana Low
05dfffad5c add seed param to search params (#168) 2022-11-08 11:29:56 -05:00
morenol
a37b552a7d Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Montana Low
55e1158581 Complete grid search params (#166)
* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
2022-11-08 11:29:56 -05:00
morenol
cfa824d7db Provide better output in flaky tests (#163) 2022-11-08 11:29:56 -05:00
morenol
bb5b437a32 feat: allocate first and then proceed to create matrix from Vec of Ro… (#159)
* feat: allocate first and then proceed to create matrix from Vec of RowVectors
2022-11-08 11:29:56 -05:00
morenol
851533dfa7 Make rand_distr optional (#161) 2022-11-08 11:29:56 -05:00
Lorenzo
0d996edafe Update LICENSE 2022-11-08 11:29:56 -05:00
morenol
f291b71f4a fix: fix compilation warnings when running only with default features (#160)
* fix: fix compilation warnings when running only with default features
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Tim Toebrock
2d75c2c405 Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows.
* feat: Add option to derive `RealNumber` from string.
To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string.
* feat: Implement `Matrix::read_csv`.
2022-11-08 11:29:56 -05:00
Montana Low
1f2597be74 grid search (#154)
* grid search draft
* hyperparam search for linear estimators
2022-11-08 11:29:56 -05:00
Montana Low
0f442e96c0 Handle multiclass precision/recall (#152)
* handle multiclass precision/recall
2022-11-08 11:29:56 -05:00
dependabot[bot]
44e4be23a6 Update criterion requirement from 0.3 to 0.4 (#150)
* Update criterion requirement from 0.3 to 0.4

Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/bheisler/criterion.rs/releases)
- [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/bheisler/criterion.rs/compare/0.3.0...0.4.0)

---
updated-dependencies:
- dependency-name: criterion
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix criterion

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
Christos Katsakioris
01f753f86d Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for
  `StandardScaler`.
* Add relevant unit test.

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
2022-11-08 11:29:56 -05:00
Tim Toebrock
df766eaf79 Implementation of Standard scaler (#143)
* docs: Fix typo in doc for categorical transformer.
* feat: Add option to take a column from Matrix.
I created the method `Matrix::take_column` that uses the `Matrix::take`-interface to extract a single column from a matrix. I need that feature in the implementation of  `StandardScaler`.
* feat: Add `StandardScaler`.
Authored-by: titoeb <timtoebrock@googlemail.com>
2022-11-08 11:29:56 -05:00
Lorenzo
09d9205696 Add example for FastPair (#144)
* Add example

* Move to top

* Add imports to example

* Fix imports
2022-11-08 11:29:56 -05:00
Lorenzo
dc7f01db4a Implement fastpair (#142)
* initial fastpair implementation
* FastPair initial implementation
* implement fastpair
* Add random test
* Add bench for fastpair
* Refactor with constructor for FastPair
* Add serialization for PairwiseDistance
* Add fp_bench feature for fastpair bench
2022-11-08 11:29:56 -05:00
Chris McComb
eb4b49d552 Added additional doctest and fixed indices (#141) 2022-11-08 11:29:56 -05:00
morenol
98e3465e7b Fix clippy warnings (#139)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
ferrouille
ea39024fd2 Add SVC::decision_function (#135) 2022-11-08 11:29:56 -05:00
dependabot[bot]
4e94feb872 Update nalgebra requirement from 0.23.0 to 0.31.0 (#128)
Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.31.0)

---
updated-dependencies:
- dependency-name: nalgebra
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
dependabot-preview[bot]
fa802d2d3f build(deps): update nalgebra requirement from 0.23.0 to 0.26.2 (#98)
* build(deps): update nalgebra requirement from 0.23.0 to 0.26.2

Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.26.2)

Signed-off-by: dependabot-preview[bot] <support@dependabot.com>

* fix: updates for nalgebre

* test: explicitly call pow_mut from BaseVector since now it conflicts with nalgebra implementation

* Don't be strict with dependencies

Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-08 11:29:56 -05:00
91 changed files with 2049 additions and 5432 deletions
+1
@@ -2,5 +2,6 @@
 # the repo. Unless a later match takes precedence,
 # Developers in this list will be requested for
 # review when someone opens a pull request.
+* @VolodymyrOrlov
 * @morenol
 * @Mec-iS
+1 -3
@@ -37,8 +37,6 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
 ```
 * find more information about what happens in your binary with [`twiggy`](https://rustwasm.github.io/twiggy/install.html). This need a compiled binary so create a brief `main {}` function using `smartcore` and then point `twiggy` to that file.
-* Please take a look to the output of a profiler to spot most evident performance problems, see [this guide about using a profiler](http://www.codeofview.com/fix-rs/2017/01/24/how-to-optimize-rust-programs-on-linux/).
 ## Issue Report Process
 1. Go to the project's issues.
@@ -50,9 +48,9 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
 1. After a PR is opened maintainers are notified
 2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
-   * **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
    * **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
    * **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
+   * **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
    * **Testing**: multiple test pipelines are run for different targets
 3. When everything is OK, code is merged.
+14 -4
@@ -19,13 +19,14 @@ jobs:
         { os: "ubuntu", target: "i686-unknown-linux-gnu" },
         { os: "ubuntu", target: "wasm32-unknown-unknown" },
         { os: "macos", target: "aarch64-apple-darwin" },
+        { os: "ubuntu", target: "wasm32-wasi" },
       ]
     env:
       TZ: "/usr/share/zoneinfo/your/location"
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v3
       - name: Cache .cargo and target
-        uses: actions/cache@v4
+        uses: actions/cache@v2
        with:
          path: |
            ~/.cargo
@@ -42,6 +43,9 @@ jobs:
       - name: Install test runner for wasm
         if: matrix.platform.target == 'wasm32-unknown-unknown'
         run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
+      - name: Install test runner for wasi
+        if: matrix.platform.target == 'wasm32-wasi'
+        run: curl https://wasmtime.dev/install.sh -sSf | bash
       - name: Stable Build with all features
         uses: actions-rs/cargo@v1
         with:
@@ -61,6 +65,12 @@ jobs:
       - name: Tests in WASM
         if: matrix.platform.target == 'wasm32-unknown-unknown'
         run: wasm-pack test --node -- --all-features
+      - name: Tests in WASI
+        if: matrix.platform.target == 'wasm32-wasi'
+        run: |
+          export WASMTIME_HOME="$HOME/.wasmtime"
+          export PATH="$WASMTIME_HOME/bin:$PATH"
+          cargo install cargo-wasi && cargo wasi test
   check_features:
     runs-on: "${{ matrix.platform.os }}-latest"
@@ -71,9 +81,9 @@ jobs:
     env:
       TZ: "/usr/share/zoneinfo/your/location"
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v3
       - name: Cache .cargo and target
-        uses: actions/cache@v4
+        uses: actions/cache@v2
        with:
          path: |
            ~/.cargo
+3 -3
@@ -12,9 +12,9 @@ jobs:
     env:
       TZ: "/usr/share/zoneinfo/your/location"
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@v2
       - name: Cache .cargo
-        uses: actions/cache@v4
+        uses: actions/cache@v2
       with:
         path: |
           ~/.cargo
@@ -41,4 +41,4 @@ jobs:
       - name: Upload to codecov.io
         uses: codecov/codecov-action@v2
         with:
-          fail_ci_if_error: false
+          fail_ci_if_error: true
+1 -1
@@ -14,7 +14,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Cache .cargo and target
-        uses: actions/cache@v4
+        uses: actions/cache@v2
       with:
         path: |
           ~/.cargo
-6
@@ -4,12 +4,6 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## [0.4.0] - 2023-04-05
-## Added
-- WARNING: Breaking changes!
-- `DenseMatrix` constructor now returns `Result` to avoid user instantiating inconsistent rows/cols count. Their return values need to be unwrapped with `unwrap()`, see tests
 ## [0.3.0] - 2022-11-09
 ## Added
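The changelog entry removed in the diff above describes a breaking change: the `DenseMatrix` constructor returns a `Result` so that callers cannot build a matrix with inconsistent row/column counts. A minimal standalone sketch of that validation pattern follows; it is illustrative only, not smartcore's actual implementation, and all names (`Matrix`, `from_2d_array`) are stand-ins:

```rust
// Hypothetical sketch: a constructor that rejects ragged 2D input by
// returning Result, mirroring the changelog note about DenseMatrix.
#[derive(Debug)]
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f64>,
}

pub fn from_2d_array(rows: &[&[f64]]) -> Result<Matrix, String> {
    let ncols = rows.first().map_or(0, |r| r.len());
    // Fail fast if any row has a different number of columns.
    if rows.iter().any(|r| r.len() != ncols) {
        return Err("inconsistent rows/cols count".to_string());
    }
    Ok(Matrix {
        rows: rows.len(),
        cols: ncols,
        // Flatten row-major into a single backing vector.
        data: rows.iter().flat_map(|r| r.iter().copied()).collect(),
    })
}

fn main() {
    // Consistent rows: Ok. Ragged rows: Err, which callers must unwrap or propagate.
    assert!(from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).is_ok());
    assert!(from_2d_array(&[&[1.0, 2.0], &[3.0]]).is_err());
}
```

This is why the test code in the diffs below gains or loses a trailing `.unwrap()` depending on which side of the change it sits on.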
+3 -3
@@ -2,7 +2,7 @@
 name = "smartcore"
 description = "Machine Learning in Rust."
 homepage = "https://smartcorelib.org"
-version = "0.4.2"
+version = "0.3.0"
 authors = ["smartcore Developers"]
 edition = "2021"
 license = "Apache-2.0"
@@ -42,13 +42,13 @@ std_rand = ["rand/std_rng", "rand/std"]
 js = ["getrandom/js"]
 [target.'cfg(target_arch = "wasm32")'.dependencies]
-getrandom = { version = "0.2.8", optional = true }
+getrandom = { version = "*", optional = true }
 [target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies]
 wasm-bindgen-test = "0.3"
 [dev-dependencies]
-itertools = "0.13.0"
+itertools = "*"
 serde_json = "1.0"
 bincode = "1.3.1"
+1 -1
@@ -18,4 +18,4 @@
 -----
 [![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
-To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
+To start getting familiar with the new smartcore v0.5 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
+15
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="RUST_MODULE" version="4">
+  <component name="NewModuleRootManager" inherit-compiler-output="true">
+    <exclude-output />
+    <content url="file://$MODULE_DIR$">
+      <sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
+      <sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
+      <sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
+      <sourceFolder url="file://$MODULE_DIR$/benches" isTestSource="true" />
+      <excludeFolder url="file://$MODULE_DIR$/target" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
+3 -4
@@ -40,11 +40,11 @@ impl BBDTreeNode {
 impl BBDTree {
     pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
-        let nodes: Vec<BBDTreeNode> = Vec::new();
+        let nodes = Vec::new();
         let (n, _) = data.shape();
-        let index = (0..n).collect::<Vec<usize>>();
+        let index = (0..n).collect::<Vec<_>>();
         let mut tree = BBDTree {
             nodes,
@@ -343,8 +343,7 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
         let tree = BBDTree::new(&data);
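The hunk above trades an explicit turbofish element type (`Vec<usize>`) for an inferred one (`Vec<_>`): the compiler already knows the element type from the range being collected. A tiny self-contained illustration of that inference (not smartcore code):

```rust
fn main() {
    let n: usize = 5;
    // With Vec<_>, the element type is inferred from the range `0..n`,
    // so the explicit `Vec<usize>` annotation is redundant.
    let index = (0..n).collect::<Vec<_>>();
    assert_eq!(index, vec![0, 1, 2, 3, 4]);
}
```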
+4 -4
@@ -124,7 +124,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
         current_cover_set.push((d, &self.root));
         let mut heap = HeapSelection::with_capacity(k);
-        heap.add(f64::MAX);
+        heap.add(std::f64::MAX);
         let mut empty_heap = true;
         if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
@@ -145,7 +145,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
         }
         let upper_bound = if empty_heap {
-            f64::INFINITY
+            std::f64::INFINITY
         } else {
             *heap.peek()
         };
@@ -291,7 +291,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
         } else {
             let max_dist = self.max(point_set);
             let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
-            if next_scale == i64::MIN {
+            if next_scale == std::i64::MIN {
                 let mut children: Vec<Node> = Vec::new();
                 let mut leaf = self.new_leaf(p);
                 children.push(leaf);
@@ -435,7 +435,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
     fn get_scale(&self, d: f64) -> i64 {
         if d == 0f64 {
-            i64::MIN
+            std::i64::MIN
         } else {
             (self.inv_log_base * d.ln()).ceil() as i64
         }
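The hunks above only swap notation: the module-level constants (`std::f64::MAX`, `std::i64::MIN`) and the newer associated constants on the primitive types (`f64::MAX`, `i64::MIN`) name exactly the same values, so the cover-tree logic is unchanged. A quick sketch checking that equivalence:

```rust
// The std::f64 / std::i64 module constants are the legacy spelling;
// the associated constants on the primitives are the current idiom.
#[allow(deprecated)]
fn main() {
    assert_eq!(std::f64::MAX, f64::MAX);
    assert_eq!(std::f64::INFINITY, f64::INFINITY);
    assert_eq!(std::i64::MIN, i64::MIN);
    // INFINITY still compares greater than every finite f64, including MAX.
    assert!(f64::INFINITY > f64::MAX);
}
```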
+30 -138
View File
@@ -17,7 +17,7 @@
/// &[4.6, 3.1, 1.5, 0.2], /// &[4.6, 3.1, 1.5, 0.2],
/// &[5.0, 3.6, 1.4, 0.2], /// &[5.0, 3.6, 1.4, 0.2],
/// &[5.4, 3.9, 1.7, 0.4], /// &[5.4, 3.9, 1.7, 0.4],
/// ]).unwrap(); /// ]);
/// let fastpair = FastPair::new(&x); /// let fastpair = FastPair::new(&x);
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair(); /// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
/// ``` /// ```
@@ -52,8 +52,10 @@ pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
} }
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> { impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
///
/// Constructor /// Constructor
/// Instantiate and initialize the algorithm /// Instantiate and inizialise the algorithm
///
pub fn new(m: &'a M) -> Result<Self, Failed> { pub fn new(m: &'a M) -> Result<Self, Failed> {
if m.shape().0 < 3 { if m.shape().0 < 3 {
return Err(Failed::because( return Err(Failed::because(
@@ -72,8 +74,10 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
Ok(init) Ok(init)
} }
///
/// Initialise `FastPair` by passing a `Array2`. /// Initialise `FastPair` by passing a `Array2`.
/// Build a FastPairs data-structure from a set of (new) points. /// Build a FastPairs data-structure from a set of (new) points.
///
fn init(&mut self) { fn init(&mut self) {
// basic measures // basic measures
let len = self.samples.shape().0; let len = self.samples.shape().0;
@@ -154,7 +158,9 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
self.neighbours = neighbours; self.neighbours = neighbours;
} }
///
/// Find closest pair by scanning list of nearest neighbors. /// Find closest pair by scanning list of nearest neighbors.
///
#[allow(dead_code)] #[allow(dead_code)]
pub fn closest_pair(&self) -> PairwiseDistance<T> { pub fn closest_pair(&self) -> PairwiseDistance<T> {
let mut a = self.neighbours[0]; // Start with first point let mut a = self.neighbours[0]; // Start with first point
@@ -173,21 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
} }
} }
///
/// Return order dissimilarities from closest to furthest
///
#[allow(dead_code)]
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
let mut distances = self
.distances
.values()
.collect::<Vec<&PairwiseDistance<T>>>();
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
distances.into_iter()
}
// //
// Compute distances from input to all other points in data-structure. // Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix // input is the row index of the sample matrix
@@ -226,10 +217,10 @@ mod tests_fastpair {
use super::*; use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
///
/// Brute force algorithm, used only for comparison and testing /// Brute force algorithm, used only for comparison and testing
pub fn closest_pair_brute( ///
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>, pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
) -> PairwiseDistance<f64> {
use itertools::Itertools; use itertools::Itertools;
let m = fastpair.samples.shape().0; let m = fastpair.samples.shape().0;
@@ -269,8 +260,8 @@ mod tests_fastpair {
let distances = fastpair.distances; let distances = fastpair.distances;
let neighbours = fastpair.neighbours; let neighbours = fastpair.neighbours;
assert!(!distances.is_empty()); assert!(distances.len() != 0);
assert!(!neighbours.is_empty()); assert!(neighbours.len() != 0);
assert_eq!(10, neighbours.len()); assert_eq!(10, neighbours.len());
assert_eq!(10, distances.len()); assert_eq!(10, distances.len());
@@ -280,24 +271,28 @@ mod tests_fastpair {
fn dataset_has_at_least_three_points() { fn dataset_has_at_least_three_points() {
// Create a dataset which consists of only two points: // Create a dataset which consists of only two points:
// A(0.0, 0.0) and B(1.0, 1.0). // A(0.0, 0.0) and B(1.0, 1.0).
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap(); let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);
// We expect an error when we run `FastPair` on this dataset, // We expect an error when we run `FastPair` on this dataset,
// becuase `FastPair` currently only works on a minimum of 3 // becuase `FastPair` currently only works on a minimum of 3
// points. // points.
let fastpair = FastPair::new(&dataset); let _fastpair = FastPair::new(&dataset);
assert!(fastpair.is_err());
if let Err(e) = fastpair { match _fastpair {
let expected_error = Err(e) => {
Failed::because(FailedError::FindFailed, "min number of rows should be 3"); let expected_error =
assert_eq!(e, expected_error) Failed::because(FailedError::FindFailed, "min number of rows should be 3");
assert_eq!(e, expected_error)
}
_ => {
assert!(false);
}
} }
} }
#[test] #[test]
fn one_dimensional_dataset_minimal() { fn one_dimensional_dataset_minimal() {
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap(); let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);
let result = FastPair::new(&dataset); let result = FastPair::new(&dataset);
assert!(result.is_ok()); assert!(result.is_ok());
@@ -317,8 +312,7 @@ mod tests_fastpair {
#[test] #[test]
fn one_dimensional_dataset_2() { fn one_dimensional_dataset_2() {
let dataset = let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);
DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();
let result = FastPair::new(&dataset); let result = FastPair::new(&dataset);
assert!(result.is_ok()); assert!(result.is_ok());
@@ -353,8 +347,7 @@ mod tests_fastpair {
&[6.9, 3.1, 4.9, 1.5], &[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3], &[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5], &[6.5, 2.8, 4.6, 1.5],
]) ]);
.unwrap();
let fastpair = FastPair::new(&x); let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok()); assert!(fastpair.is_ok());
@@ -527,8 +520,7 @@ mod tests_fastpair {
&[6.9, 3.1, 4.9, 1.5], &[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3], &[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5], &[6.5, 2.8, 4.6, 1.5],
]) ]);
.unwrap();
// compute // compute
let fastpair = FastPair::new(&x); let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok()); assert!(fastpair.is_ok());
@@ -576,8 +568,7 @@ mod tests_fastpair {
             &[6.9, 3.1, 4.9, 1.5],
             &[5.5, 2.3, 4.0, 1.3],
             &[6.5, 2.8, 4.6, 1.5],
-        ])
-        .unwrap();
+        ]);
         // compute
         let fastpair = FastPair::new(&x);
         assert!(fastpair.is_ok());
@@ -591,7 +582,7 @@ mod tests_fastpair {
         };
         for p in dissimilarities.iter() {
             if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
-                min_dissimilarity = *p
+                min_dissimilarity = p.clone()
             }
         }
@@ -603,103 +594,4 @@ mod tests_fastpair {
         assert_eq!(closest, min_dissimilarity);
     }
#[test]
fn fastpair_ordered_pairs() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
])
.unwrap();
let fastpair = FastPair::new(&x).unwrap();
let ordered = fastpair.ordered_pairs();
let mut previous: f64 = -1.0;
for p in ordered {
if previous == -1.0 {
previous = p.distance.unwrap();
} else {
let current = p.distance.unwrap();
assert!(current >= previous);
previous = current;
}
}
}
#[test]
fn test_empty_set() {
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
let result = FastPair::new(&empty_matrix);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_single_point() {
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
let result = FastPair::new(&single_point);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_two_points() {
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = FastPair::new(&two_points);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_three_identical_points() {
let identical_points =
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
let result = FastPair::new(&identical_points);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
assert_eq!(closest_pair.distance, Some(0.0));
}
#[test]
fn test_result_unwrapping() {
let valid_matrix =
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
.unwrap();
let result = FastPair::new(&valid_matrix);
assert!(result.is_ok());
// This should not panic
let _fastpair = result.unwrap();
}
 }
+2 -2
@@ -61,7 +61,7 @@ impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
         for _ in 0..k {
             heap.add(KNNPoint {
-                distance: f64::INFINITY,
+                distance: std::f64::INFINITY,
                 index: None,
             });
         }
@@ -215,7 +215,7 @@ mod tests {
         };
         let point_inf = KNNPoint {
-            distance: f64::INFINITY,
+            distance: std::f64::INFINITY,
             index: Some(3),
         };
+7 -2
@@ -49,15 +49,20 @@ pub mod linear_search;
 /// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
 /// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone)]
 pub enum KNNAlgorithmName {
     /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
     LinearSearch,
     /// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
-    #[default]
     CoverTree,
 }

+impl Default for KNNAlgorithmName {
+    fn default() -> Self {
+        KNNAlgorithmName::CoverTree
+    }
+}
+
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug)]
 pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
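The hunk above trades the derived form, where `#[derive(Default)]` on an enum picks the variant marked `#[default]` (stable since Rust 1.62), for a manual `impl Default`. A minimal standalone sketch of the two equivalent spellings (hypothetical enum names, not the smartcore types):

```rust
// Two equivalent ways to give an enum a default variant.

// Since Rust 1.62: derive Default and mark the default variant.
#[derive(Debug, Clone, Default, PartialEq)]
enum SearchAlgorithm {
    LinearSearch,
    #[default]
    CoverTree,
}

// On older toolchains the same effect needs a manual impl.
#[derive(Debug, Clone, PartialEq)]
enum SearchAlgorithmManual {
    LinearSearch,
    CoverTree,
}

impl Default for SearchAlgorithmManual {
    fn default() -> Self {
        SearchAlgorithmManual::CoverTree
    }
}

fn main() {
    // Both spellings yield the same default variant.
    assert_eq!(SearchAlgorithm::default(), SearchAlgorithm::CoverTree);
    assert_eq!(SearchAlgorithmManual::default(), SearchAlgorithmManual::CoverTree);
}
```

The branch here keeps the manual impl, which is why the `Default` derive and the `#[default]` attribute appear as removed lines.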
+2 -2
@@ -133,7 +133,7 @@ mod tests {
     #[test]
     fn test_add1() {
         let mut heap = HeapSelection::with_capacity(3);
-        heap.add(f64::INFINITY);
+        heap.add(std::f64::INFINITY);
         heap.add(-5f64);
         heap.add(4f64);
         heap.add(-1f64);
@@ -151,7 +151,7 @@ mod tests {
     #[test]
     fn test_add2() {
         let mut heap = HeapSelection::with_capacity(3);
-        heap.add(f64::INFINITY);
+        heap.add(std::f64::INFINITY);
         heap.add(0.0);
         heap.add(8.4852);
         heap.add(5.6568);
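The `HeapSelection` tests above seed the heap with `f64::INFINITY` sentinels so that the first k real values displace them. `HeapSelection` itself is internal to smartcore, but the underlying idea, keeping the k smallest values with a max-heap whose root is evicted by anything smaller, can be sketched with std's `BinaryHeap` (all names here are illustrative, not the smartcore API):

```rust
use std::collections::BinaryHeap;

// f64 is not Ord, so wrap it and compare via total_cmp.
#[derive(PartialEq)]
struct OrdF64(f64);
impl Eq for OrdF64 {}
impl PartialOrd for OrdF64 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.0.total_cmp(&other.0))
    }
}
impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}

// Keep the k smallest values seen so far: a max-heap of size k whose
// root (the largest kept value) is evicted when something smaller arrives.
struct KSmallest {
    k: usize,
    heap: BinaryHeap<OrdF64>,
}

impl KSmallest {
    fn new(k: usize) -> Self {
        KSmallest { k, heap: BinaryHeap::new() }
    }
    fn add(&mut self, x: f64) {
        if self.heap.len() < self.k {
            self.heap.push(OrdF64(x));
        } else if let Some(top) = self.heap.peek() {
            if x < top.0 {
                self.heap.pop();
                self.heap.push(OrdF64(x));
            }
        }
    }
    // Largest of the kept values.
    fn peek(&self) -> Option<f64> {
        self.heap.peek().map(|v| v.0)
    }
}

fn main() {
    let mut sel = KSmallest::new(3);
    for x in [f64::INFINITY, -5.0, 4.0, -1.0] {
        sel.add(x);
    }
    // Kept values are {-5.0, -1.0, 4.0}; the heap root is 4.0.
    assert_eq!(sel.peek(), Some(4.0));
}
```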
-1
@@ -3,7 +3,6 @@ use num_traits::Num;
 pub trait QuickArgSort {
     fn quick_argsort_mut(&mut self) -> Vec<usize>;

-    #[allow(dead_code)]
     fn quick_argsort(&self) -> Vec<usize>;
 }
-317
@@ -1,317 +0,0 @@
//! # Agglomerative Hierarchical Clustering
//!
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
//!
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
//!
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
//!
//! ## Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
//!
//! // A dataset with 2 distinct groups of points.
//! let x = DenseMatrix::from_2d_array(&[
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
//! ]).unwrap();
//!
//! // Set parameters to find 2 clusters.
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
//!
//! // Fit the model to the data.
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
//!
//! // Get the cluster assignments.
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
//! ```
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::api::UnsupervisedEstimator;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
/// The number of clusters to find.
pub n_clusters: usize,
}
impl AgglomerativeClusteringParameters {
/// Sets the number of clusters.
///
/// # Arguments
/// * `n_clusters` - The desired number of clusters.
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
self.n_clusters = n_clusters;
self
}
}
impl Default for AgglomerativeClusteringParameters {
fn default() -> Self {
AgglomerativeClusteringParameters { n_clusters: 2 }
}
}
/// Agglomerative Clustering model.
///
/// This implementation uses single-linkage clustering, which is mathematically
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
/// The core logic is an efficient implementation of Kruskal's algorithm, which
/// processes all pairwise distances in increasing order and uses a Disjoint
/// Set Union (DSU) data structure to track cluster membership.
#[derive(Debug)]
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
/// The cluster label assigned to each sample.
pub labels: Vec<usize>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
/// Fits the agglomerative clustering model to the data.
///
/// # Arguments
/// * `data` - A reference to the input data matrix.
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
///
/// # Returns
/// A `Result` containing the fitted model with cluster labels, or an error if
/// `n_clusters` is greater than the number of samples.
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
let (num_samples, _) = data.shape();
let n_clusters = parameters.n_clusters;
if n_clusters > num_samples {
return Err(Failed::because(
FailedError::ParametersError,
&format!(
"n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"
),
));
}
let mut distance_pairs = Vec::new();
for i in 0..num_samples {
for j in (i + 1)..num_samples {
let distance: f64 = data
.get_row(i)
.iterator(0)
.zip(data.get_row(j).iterator(0))
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
.sum::<f64>();
distance_pairs.push((distance, i, j));
}
}
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
let mut parent = HashMap::new();
let mut children = HashMap::new();
for i in 0..num_samples {
parent.insert(i, i);
children.insert(i, vec![i]);
}
let mut merge_history = Vec::new();
let num_merges_needed = num_samples - 1;
while merge_history.len() < num_merges_needed {
let (_, p1, p2) = distance_pairs.pop().unwrap();
let root1 = parent[&p1];
let root2 = parent[&p2];
if root1 != root2 {
let root2_children = children.remove(&root2).unwrap();
for child in root2_children.iter() {
parent.insert(*child, root1);
}
let root1_children = children.get_mut(&root1).unwrap();
root1_children.extend(root2_children);
merge_history.push((root1, root2));
}
}
let mut clusters = HashMap::new();
let mut assignments = HashMap::new();
for i in 0..num_samples {
clusters.insert(i, vec![i]);
assignments.insert(i, i);
}
let merges_to_apply = num_samples - n_clusters;
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
let root1_cluster = assignments[root1];
let root2_cluster = assignments[root2];
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
for assignment in root2_assignments.iter() {
assignments.insert(*assignment, root1_cluster);
}
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
root1_assignments.extend(root2_assignments);
}
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
cluster_keys.sort();
for (i, key) in cluster_keys.into_iter().enumerate() {
for index in clusters[key].iter() {
labels[*index] = i;
}
}
Ok(AgglomerativeClustering {
labels,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
for AgglomerativeClustering<TX, TY, X, Y>
{
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
AgglomerativeClustering::fit(x, parameters)
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::matrix::DenseMatrix;
use std::collections::HashSet;
use super::*;
#[test]
fn test_simple_clustering() {
// Two distinct clusters, far apart.
let data = vec![
0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
];
let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
// Using f64 for TY as usize doesn't satisfy the Number trait bound.
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Check that all points in the first group have the same label.
let first_group_label = labels[0];
assert!(labels[0..3].iter().all(|&l| l == first_group_label));
// Check that all points in the second group have the same label.
let second_group_label = labels[3];
assert!(labels[3..6].iter().all(|&l| l == second_group_label));
// Check that the two groups have different labels.
assert_ne!(first_group_label, second_group_label);
}
#[test]
fn test_four_clusters() {
// Four distinct clusters in the corners of a square.
let data = vec![
0.0, 0.0, 1.0, 1.0, // Cluster A
100.0, 100.0, 101.0, 101.0, // Cluster B
0.0, 100.0, 1.0, 101.0, // Cluster C
100.0, 0.0, 101.0, 1.0, // Cluster D
];
let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Verify that there are exactly 4 unique labels produced.
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
assert_eq!(unique_labels.len(), 4);
// Verify that points within each original group were assigned the same cluster label.
let label_a = labels[0];
assert_eq!(label_a, labels[1]);
let label_b = labels[2];
assert_eq!(label_b, labels[3]);
let label_c = labels[4];
assert_eq!(label_c, labels[5]);
let label_d = labels[6];
assert_eq!(label_d, labels[7]);
// Verify that all four groups received different labels.
assert_ne!(label_a, label_b);
assert_ne!(label_a, label_c);
assert_ne!(label_a, label_d);
assert_ne!(label_b, label_c);
assert_ne!(label_b, label_d);
assert_ne!(label_c, label_d);
}
#[test]
fn test_n_clusters_equal_to_samples() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// Each point should be its own cluster. Sorting makes the test deterministic.
let mut labels = clustering.labels;
labels.sort();
assert_eq!(labels, vec![0, 1, 2]);
}
#[test]
fn test_one_cluster() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// All points should be in the same cluster.
assert_eq!(clustering.labels, vec![0, 0, 0]);
}
#[test]
fn test_error_on_too_many_clusters() {
let data = vec![0.0, 0.0, 5.0, 5.0];
let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
);
assert!(result.is_err());
}
}
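The removed module's fit routine is Kruskal's MST algorithm stopped early: sort pairwise squared distances, union the two points' clusters through a parent map, and stop once `n_clusters` components remain. A self-contained sketch of that idea with a plain union-find (hypothetical helper names, not the smartcore API):

```rust
// Union-find `find` with path compression over a flat parent vector.
fn find(parent: &mut Vec<usize>, i: usize) -> usize {
    let p = parent[i];
    if p == i {
        return i;
    }
    let root = find(parent, p);
    parent[i] = root; // path compression
    root
}

// Single-linkage clustering of 2-D points: Kruskal's algorithm stopped
// when exactly `n_clusters` connected components remain.
fn single_linkage(points: &[(f64, f64)], n_clusters: usize) -> Vec<usize> {
    let n = points.len();
    let mut edges = Vec::new();
    for i in 0..n {
        for j in (i + 1)..n {
            let d = (points[i].0 - points[j].0).powi(2) + (points[i].1 - points[j].1).powi(2);
            edges.push((d, i, j));
        }
    }
    edges.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    let mut parent: Vec<usize> = (0..n).collect();
    let mut components = n;
    for (_, i, j) in edges {
        if components == n_clusters {
            break;
        }
        let ri = find(&mut parent, i);
        let rj = find(&mut parent, j);
        if ri != rj {
            parent[rj] = ri;
            components -= 1;
        }
    }

    // Relabel roots to compact ids 0..n_clusters in first-seen order.
    let mut labels = vec![0; n];
    let mut seen = Vec::new();
    for i in 0..n {
        let root = find(&mut parent, i);
        let id = match seen.iter().position(|&s| s == root) {
            Some(id) => id,
            None => {
                seen.push(root);
                seen.len() - 1
            }
        };
        labels[i] = id;
    }
    labels
}

fn main() {
    // Same two well-separated groups as the module's doc example.
    let pts = [(0.0, 0.0), (1.0, 1.0), (0.5, 0.5), (10.0, 10.0), (11.0, 11.0), (10.5, 10.5)];
    assert_eq!(single_linkage(&pts, 2), vec![0, 0, 0, 1, 1, 1]);
}
```

The removed code does the same thing with `HashMap`-based parent/children maps and replays a recorded merge history, which also lets it answer any `n_clusters` cut from one pass.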
+6 -7
@@ -18,7 +18,7 @@
 //!
 //! Example:
 //!
-//! ```ignore
+//! ```
 //! use smartcore::linalg::basic::matrix::DenseMatrix;
 //! use smartcore::linalg::basic::arrays::Array2;
 //! use smartcore::cluster::dbscan::*;
@@ -315,7 +315,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
             }
         }

-        while let Some(neighbor) = neighbors.pop() {
+        while !neighbors.is_empty() {
+            let neighbor = neighbors.pop().unwrap();
             let index = neighbor.0;
             if y[index] == outlier {
@@ -442,8 +443,7 @@ mod tests {
             &[2.2, 1.2],
             &[1.8, 0.8],
             &[3.0, 5.0],
-        ])
-        .unwrap();
+        ]);

         let expected_labels = vec![1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0];
@@ -488,8 +488,7 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);

         let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
@@ -512,6 +511,6 @@ mod tests {
             .and_then(|dbscan| dbscan.predict(&x))
             .unwrap();
-        println!("{labels:?}");
+        println!("{:?}", labels);
     }
 }
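The neighbor-queue hunk above is behavior-neutral: `while let Some(...) = v.pop()` drains a `Vec` exactly like the explicit `is_empty`/`pop().unwrap()` pair it replaces, just without the unwrap. A small sketch of the equivalence:

```rust
fn main() {
    // Index/distance pairs, standing in for DBSCAN's neighbor queue.
    let neighbors: Vec<(usize, f64)> = vec![(0, 0.1), (3, 0.5), (7, 0.2)];

    // Old style: explicit emptiness check plus unwrap.
    let mut drained_old = Vec::new();
    let mut a = neighbors.clone();
    while !a.is_empty() {
        let neighbor = a.pop().unwrap();
        drained_old.push(neighbor.0);
    }

    // New style: `while let` pops and binds in one step, no unwrap needed.
    let mut drained_new = Vec::new();
    let mut b = neighbors.clone();
    while let Some(neighbor) = b.pop() {
        drained_new.push(neighbor.0);
    }

    // Both drain from the back in the same order.
    assert_eq!(drained_old, drained_new);
    assert_eq!(drained_new, vec![7, 3, 0]);
}
```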
+10 -12
@@ -41,7 +41,7 @@
 //! &[4.9, 2.4, 3.3, 1.0],
 //! &[6.6, 2.9, 4.6, 1.3],
 //! &[5.2, 2.7, 3.9, 1.4],
-//! ]).unwrap();
+//! ]);
 //!
 //! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
 //! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
@@ -96,7 +96,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<
                 return false;
             }
             for j in 0..self.centroids[i].len() {
-                if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
+                if (self.centroids[i][j] - other.centroids[i][j]).abs() > std::f64::EPSILON {
                     return false;
                 }
             }
@@ -270,7 +270,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
         let (n, d) = data.shape();
-        let mut distortion = f64::MAX;
+        let mut distortion = std::f64::MAX;
         let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
         let mut size = vec![0; parameters.k];
         let mut centroids = vec![vec![0f64; d]; parameters.k];
@@ -331,7 +331,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
         let mut row = vec![0f64; x.shape().1];
         for i in 0..n {
-            let mut min_dist = f64::MAX;
+            let mut min_dist = std::f64::MAX;
             let mut best_cluster = 0;
             for j in 0..self.k {
@@ -361,7 +361,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
             .cloned()
             .collect();
-        let mut d = vec![f64::MAX; n];
+        let mut d = vec![std::f64::MAX; n];
         let mut row = vec![TX::zero(); data.shape().1];
         for j in 1..k {
@@ -424,7 +424,7 @@ mod tests {
     )]
     #[test]
     fn invalid_k() {
-        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]);

         assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
             &x,
@@ -492,15 +492,14 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);

         let kmeans = KMeans::fit(&x, Default::default()).unwrap();

         let y: Vec<usize> = kmeans.predict(&x).unwrap();
-        for (i, _y_i) in y.iter().enumerate() {
-            assert_eq!({ y[i] }, kmeans._y[i]);
+        for i in 0..y.len() {
+            assert_eq!(y[i] as usize, kmeans._y[i]);
         }
     }
@@ -532,8 +531,7 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);

         let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
             KMeans::fit(&x, Default::default()).unwrap();
-1
@@ -3,7 +3,6 @@
 //! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
 //! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.

-pub mod agglomerative;
 pub mod dbscan;
 /// An iterative clustering algorithm that aims to find local maxima in each iteration.
 pub mod kmeans;
+1 -1
@@ -31,7 +31,7 @@ use crate::dataset::Dataset;
 pub fn load_dataset() -> Dataset<f32, f32> {
     let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("boston.xy"))
     {
-        Err(why) => panic!("Can't deserialize boston.xy. {why}"),
+        Err(why) => panic!("Can't deserialize boston.xy. {}", why),
         Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
     };
+1 -1
@@ -33,7 +33,7 @@ use crate::dataset::Dataset;
 pub fn load_dataset() -> Dataset<f32, u32> {
     let (x, y, num_samples, num_features) =
         match deserialize_data(std::include_bytes!("breast_cancer.xy")) {
-            Err(why) => panic!("Can't deserialize breast_cancer.xy. {why}"),
+            Err(why) => panic!("Can't deserialize breast_cancer.xy. {}", why),
             Ok((x, y, num_samples, num_features)) => (
                 x,
                 y.into_iter().map(|x| x as u32).collect(),
+2 -2
View File
@@ -26,7 +26,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, u32> { pub fn load_dataset() -> Dataset<f32, u32> {
     let (x, y, num_samples, num_features) =
         match deserialize_data(std::include_bytes!("diabetes.xy")) {
-            Err(why) => panic!("Can't deserialize diabetes.xy. {why}"),
+            Err(why) => panic!("Can't deserialize diabetes.xy. {}", why),
             Ok((x, y, num_samples, num_features)) => (
                 x,
                 y.into_iter().map(|x| x as u32).collect(),
@@ -40,7 +40,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
         target: y,
         num_samples,
         num_features,
-        feature_names: [
+        feature_names: vec![
             "Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
         ]
         .iter()
+6 -4
@@ -16,7 +16,7 @@ use crate::dataset::Dataset;
 pub fn load_dataset() -> Dataset<f32, f32> {
     let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
     {
-        Err(why) => panic!("Can't deserialize digits.xy. {why}"),
+        Err(why) => panic!("Can't deserialize digits.xy. {}", why),
         Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
     };
@@ -25,14 +25,16 @@ pub fn load_dataset() -> Dataset<f32, f32> {
         target: y,
         num_samples,
         num_features,
-        feature_names: ["sepal length (cm)",
+        feature_names: vec![
+            "sepal length (cm)",
             "sepal width (cm)",
             "petal length (cm)",
-            "petal width (cm)"]
+            "petal width (cm)",
+        ]
         .iter()
         .map(|s| s.to_string())
         .collect(),
-        target_names: ["setosa", "versicolor", "virginica"]
+        target_names: vec!["setosa", "versicolor", "virginica"]
             .iter()
             .map(|s| s.to_string())
             .collect(),
+3 -3
@@ -22,7 +22,7 @@ use crate::dataset::Dataset;
 pub fn load_dataset() -> Dataset<f32, u32> {
     let (x, y, num_samples, num_features): (Vec<f32>, Vec<u32>, usize, usize) =
         match deserialize_data(std::include_bytes!("iris.xy")) {
-            Err(why) => panic!("Can't deserialize iris.xy. {why}"),
+            Err(why) => panic!("Can't deserialize iris.xy. {}", why),
             Ok((x, y, num_samples, num_features)) => (
                 x,
                 y.into_iter().map(|x| x as u32).collect(),
@@ -36,7 +36,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
         target: y,
         num_samples,
         num_features,
-        feature_names: [
+        feature_names: vec![
             "sepal length (cm)",
             "sepal width (cm)",
             "petal length (cm)",
@@ -45,7 +45,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
         .iter()
         .map(|s| s.to_string())
         .collect(),
-        target_names: ["setosa", "versicolor", "virginica"]
+        target_names: vec!["setosa", "versicolor", "virginica"]
             .iter()
             .map(|s| s.to_string())
             .collect(),
+1 -1
@@ -78,7 +78,7 @@ pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
                 .collect();
             file.write_all(&y)?;
         }
-        Err(why) => panic!("couldn't create (unknown): {why}"),
+        Err(why) => panic!("couldn't create {}: {}", filename, why),
     }
     Ok(())
 }
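The panic-message hunks in the dataset loaders all make the same edition-related swap: positional `{}` arguments versus identifiers captured directly in the format string (stable since Rust 1.58). The two forms produce identical output:

```rust
fn main() {
    let why = "file not found";
    // Positional argument (works on all editions):
    let a = format!("Can't deserialize iris.xy. {}", why);
    // Captured identifier (Rust 1.58+): the variable is named inside the braces.
    let b = format!("Can't deserialize iris.xy. {why}");
    assert_eq!(a, b);
}
```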
+18 -22
@@ -35,7 +35,7 @@
 //! &[4.9, 2.4, 3.3, 1.0],
 //! &[6.6, 2.9, 4.6, 1.3],
 //! &[5.2, 2.7, 3.9, 1.4],
-//! ]).unwrap();
+//! ]);
 //!
 //! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
 //!
@@ -231,7 +231,8 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
         if parameters.n_components > n {
             return Err(Failed::fit(&format!(
-                "Number of components, n_components should be <= number of attributes ({n})"
+                "Number of components, n_components should be <= number of attributes ({})",
+                n
             )));
         }
@@ -373,20 +374,21 @@ mod tests {
         let parameters = PCASearchParameters {
             n_components: vec![2, 4],
             use_correlation_matrix: vec![true, false],
-            ..Default::default()
         };
         let mut iter = parameters.into_iter();

         let next = iter.next().unwrap();
         assert_eq!(next.n_components, 2);
-        assert!(next.use_correlation_matrix);
+        assert_eq!(next.use_correlation_matrix, true);
         let next = iter.next().unwrap();
         assert_eq!(next.n_components, 4);
-        assert!(next.use_correlation_matrix);
+        assert_eq!(next.use_correlation_matrix, true);
         let next = iter.next().unwrap();
         assert_eq!(next.n_components, 2);
-        assert!(!next.use_correlation_matrix);
+        assert_eq!(next.use_correlation_matrix, false);
         let next = iter.next().unwrap();
         assert_eq!(next.n_components, 4);
-        assert!(!next.use_correlation_matrix);
+        assert_eq!(next.use_correlation_matrix, false);
         assert!(iter.next().is_none());
     }
@@ -443,7 +445,6 @@ mod tests {
             &[2.6, 53.0, 66.0, 10.8],
             &[6.8, 161.0, 60.0, 15.6],
         ])
-        .unwrap()
     }

     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -458,8 +459,7 @@ mod tests {
             &[0.9952, 0.0588],
             &[0.0463, 0.9769],
             &[0.0752, 0.2007],
-        ])
-        .unwrap();
+        ]);

         let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
@@ -502,8 +502,7 @@ mod tests {
                 -0.974080592182491,
                 0.0723250196376097,
             ],
-        ])
-        .unwrap();
+        ]);

         let expected_projection = DenseMatrix::from_2d_array(&[
             &[-64.8022, -11.448, 2.4949, -2.4079],
@@ -556,8 +555,7 @@ mod tests {
             &[91.5446, -22.9529, 0.402, -0.7369],
             &[118.1763, 5.5076, 2.7113, -0.205],
             &[10.4345, -5.9245, 3.7944, 0.5179],
-        ])
-        .unwrap();
+        ]);

         let expected_eigenvalues: Vec<f64> = vec![
             343544.6277001563,
@@ -574,8 +572,8 @@ mod tests {
             epsilon = 1e-4
         ));

-        for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
-            assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
+        for i in 0..pca.eigenvalues.len() {
+            assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
         }

         let us_arrests_t = pca.transform(&us_arrests).unwrap();
@@ -620,8 +618,7 @@ mod tests {
                 -0.0881962972508558,
                 -0.0096011588898465,
             ],
-        ])
-        .unwrap();
+        ]);

         let expected_projection = DenseMatrix::from_2d_array(&[
             &[0.9856, -1.1334, 0.4443, -0.1563],
@@ -674,8 +671,7 @@ mod tests {
&[-2.1086, -1.4248, -0.1048, -0.1319], &[-2.1086, -1.4248, -0.1048, -0.1319],
&[-2.0797, 0.6113, 0.1389, -0.1841], &[-2.0797, 0.6113, 0.1389, -0.1841],
&[-0.6294, -0.321, 0.2407, 0.1667], &[-0.6294, -0.321, 0.2407, 0.1667],
]) ]);
.unwrap();
let expected_eigenvalues: Vec<f64> = vec![ let expected_eigenvalues: Vec<f64> = vec![
2.480241579149493, 2.480241579149493,
@@ -698,8 +694,8 @@ mod tests {
epsilon = 1e-4 epsilon = 1e-4
)); ));
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() { for i in 0..pca.eigenvalues.len() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8); assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
} }
let us_arrests_t = pca.transform(&us_arrests).unwrap(); let us_arrests_t = pca.transform(&us_arrests).unwrap();
@@ -738,7 +734,7 @@ mod tests {
// &[4.9, 2.4, 3.3, 1.0], // &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3], // &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4], // &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap(); // ]);
// let pca = PCA::fit(&iris, Default::default()).unwrap(); // let pca = PCA::fit(&iris, Default::default()).unwrap();
+9 -8
@@ -32,7 +32,7 @@
//! &[4.9, 2.4, 3.3, 1.0], //! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3], //! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4], //! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap(); //! ]);
//! //!
//! let svd = SVD::fit(&iris, SVDParameters::default(). //! let svd = SVD::fit(&iris, SVDParameters::default().
//! with_n_components(2)).unwrap(); // Reduce number of features to 2 //! with_n_components(2)).unwrap(); // Reduce number of features to 2
@@ -180,7 +180,8 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
if parameters.n_components >= p { if parameters.n_components >= p {
return Err(Failed::fit(&format!( return Err(Failed::fit(&format!(
"Number of components, n_components should be < number of attributes ({p})" "Number of components, n_components should be < number of attributes ({})",
p
))); )));
} }
@@ -201,7 +202,8 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
let (p_c, k) = self.components.shape(); let (p_c, k) = self.components.shape();
if p_c != p { if p_c != p {
return Err(Failed::transform(&format!( return Err(Failed::transform(&format!(
"Cannot transform a {n}x{p} matrix into {n}x{k} matrix, incorrect input dimensions" "Cannot transform a {}x{} matrix into {}x{} matrix, incorrect input dimensions",
n, p, n, k
))); )));
} }
@@ -225,6 +227,7 @@ mod tests {
fn search_parameters() { fn search_parameters() {
let parameters = SVDSearchParameters { let parameters = SVDSearchParameters {
n_components: vec![10, 100], n_components: vec![10, 100],
..Default::default()
}; };
let mut iter = parameters.into_iter(); let mut iter = parameters.into_iter();
let next = iter.next().unwrap(); let next = iter.next().unwrap();
@@ -292,8 +295,7 @@ mod tests {
&[5.7, 81.0, 39.0, 9.3], &[5.7, 81.0, 39.0, 9.3],
&[2.6, 53.0, 66.0, 10.8], &[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6], &[6.8, 161.0, 60.0, 15.6],
]) ]);
.unwrap();
let expected = DenseMatrix::from_2d_array(&[ let expected = DenseMatrix::from_2d_array(&[
&[243.54655757, -18.76673788], &[243.54655757, -18.76673788],
@@ -301,8 +303,7 @@ mod tests {
&[305.93972467, -15.39087376], &[305.93972467, -15.39087376],
&[197.28420365, -11.66808306], &[197.28420365, -11.66808306],
&[293.43187394, 1.91163633], &[293.43187394, 1.91163633],
]) ]);
.unwrap();
let svd = SVD::fit(&x, Default::default()).unwrap(); let svd = SVD::fit(&x, Default::default()).unwrap();
let x_transformed = svd.transform(&x).unwrap(); let x_transformed = svd.transform(&x).unwrap();
@@ -343,7 +344,7 @@ mod tests {
// &[4.9, 2.4, 3.3, 1.0], // &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3], // &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4], // &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap(); // ]);
// let svd = SVD::fit(&iris, Default::default()).unwrap(); // let svd = SVD::fit(&iris, Default::default()).unwrap();
-214
@@ -1,214 +0,0 @@
use rand::Rng;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor
/// Some parameters here are passed directly into the base estimator.
pub struct BaseForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of predictors to sample at random as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
#[cfg_attr(feature = "serde", serde(default))]
pub bootstrap: bool,
#[cfg_attr(feature = "serde", serde(default))]
pub splitter: Splitter,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for BaseForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
}
}
/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
samples: Option<Vec<Vec<bool>>>,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseForestRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseForestRegressorParameters,
) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();
if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should equal len(y)"));
}
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
let mut samples: Vec<usize> = vec![1; n_rows];
for _ in 0..parameters.n_trees {
if parameters.bootstrap {
samples =
BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
}
// keep samples if the flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
splitter: parameters.splitter.clone(),
};
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
trees.push(tree);
}
Ok(BaseForestRegressor {
trees: Some(trees),
samples: maybe_all_samples,
})
}
/// Predict target values for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
}
/// Predict OOB values for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
}
}
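The out-of-bag path above averages each row only over the trees whose bootstrap sample left that row out. A minimal standalone sketch of that averaging step (the names `predict_row_oob` and `in_bag` are illustrative, not the crate's API) also shows one way to resolve the edge case flagged in the TODO, where no tree leaves the row out:

```rust
/// Average per-tree predictions for one row, using only trees that did
/// not train on it (`in_bag[i]` is true if tree `i` saw the row).
/// Returns None instead of dividing by zero when every tree saw the row.
fn predict_row_oob(tree_preds: &[f64], in_bag: &[bool]) -> Option<f64> {
    let (mut sum, mut n_trees) = (0.0, 0usize);
    for (pred, &used) in tree_preds.iter().zip(in_bag) {
        if !used {
            sum += pred;
            n_trees += 1;
        }
    }
    if n_trees == 0 {
        None
    } else {
        Some(sum / n_trees as f64)
    }
}

fn main() {
    // Trees 1 and 2 never saw the row, so only their outputs are averaged.
    assert_eq!(
        predict_row_oob(&[10.0, 20.0, 30.0], &[true, false, false]),
        Some(25.0)
    );
    // Every tree trained on the row: no OOB estimate exists.
    assert_eq!(predict_row_oob(&[10.0], &[true]), None);
}
```

Returning `Option` makes the "no OOB trees" case explicit; this is only one possible answer to the question raised in the TODO comment.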
-318
@@ -1,318 +0,0 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! A larger number of estimators generally improves the performance of the algorithm at the cost of increased training time.
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset ([https://www.statsmodels.org/stable/datasets/generated/longley.html](https://www.statsmodels.org/stable/datasets/generated/longley.html))
//! let x = DenseMatrix::from_2d_array(&[
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
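The \\(\sqrt{p}\\) default mentioned above can be pinned down as a one-liner; this mirrors the `unwrap_or((num_attributes as f64).sqrt().floor() as usize)` fallback the forest `fit` uses when `m` is not set (the helper name `default_mtry` is illustrative, not part of the API):

```rust
/// Default number of split-candidate features when `m` is unset:
/// floor(sqrt(p)) for p attributes.
fn default_mtry(num_attributes: usize) -> usize {
    (num_attributes as f64).sqrt().floor() as usize
}

fn main() {
    assert_eq!(default_mtry(6), 2); // e.g. the 6-feature Longley data above
    assert_eq!(default_mtry(16), 4);
}
```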
use std::default::Default;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor
/// Some parameters here are passed directly into the base estimator.
pub struct ExtraTreesRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of predictors to sample at random as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
impl ExtraTreesRegressorParameters {
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
self.max_depth = Some(max_depth);
self
}
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
self.min_samples_leaf = min_samples_leaf;
self
}
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
self.min_samples_split = min_samples_split;
self
}
/// The number of trees in the forest.
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
self.n_trees = n_trees;
self
}
/// Number of predictors to sample at random as split candidates.
pub fn with_m(mut self, m: usize) -> Self {
self.m = Some(m);
self
}
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
/// Seed used for bootstrap sampling and feature selection for each tree.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
}
impl Default for ExtraTreesRegressorParameters {
fn default() -> Self {
ExtraTreesRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
m: Option::None,
keep_samples: false,
seed: 0,
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn new() -> Self {
Self {
forest_regressor: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
ExtraTreesRegressor::fit(x, y, parameters)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ExtraTreesRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit(
x: &X,
y: &Y,
parameters: ExtraTreesRegressorParameters,
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: false,
splitter: Splitter::Random,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
Ok(ExtraTreesRegressor {
forest_regressor: Some(forest_regressor),
})
}
/// Predict target values for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
}
/// Predict OOB values for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_squared_error;
#[test]
fn test_extra_trees_regressor_fit_predict() {
// Use a simpler, more predictable dataset for unit testing.
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
&[13., 14.],
&[15., 16.],
])
.unwrap();
let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
// A basic check to ensure the model is learning something.
// The error should be significantly less than the variance of y.
let mse = mean_squared_error(&y, &y_hat);
// With this simple dataset, the error should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_fit_predict_higher_dims() {
// Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
let x = DenseMatrix::from_2d_array(&[
// The 3rd column is the important one. The rest are noise.
&[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
&[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
&[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
&[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
&[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
&[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
])
.unwrap();
let y = vec![10., 20., 30., 40., 55., 65.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
let mse = mean_squared_error(&y, &y_hat);
// The model should be able to learn this simple relationship perfectly,
// ignoring the noise features. The MSE should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_reproducibility() {
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
])
.unwrap();
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let params = ExtraTreesRegressorParameters::default().with_seed(42);
let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat1 = regressor1.predict(&x).unwrap();
let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat2 = regressor2.predict(&x).unwrap();
assert_eq!(y_hat1, y_hat2);
}
}
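The second difference listed in the module docs, choosing a random split point per feature instead of the optimal one, can be sketched in isolation. The `Lcg` below is a toy deterministic RNG for the example only (the crate actually draws from `get_rng_impl`), and `random_threshold` is a hypothetical helper, not part of the smartcore API:

```rust
/// Toy deterministic RNG for the sketch; stands in for the crate's RNG.
struct Lcg(u64);

impl Lcg {
    /// Next value in [0, 1).
    fn next_f64(&mut self) -> f64 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Extra-Trees split rule for one feature: draw a uniformly random
/// threshold inside the feature's observed range, rather than scanning
/// for the error-minimizing cut point as a classic Random Forest does.
fn random_threshold(values: &[f64], rng: &mut Lcg) -> f64 {
    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    min + rng.next_f64() * (max - min)
}

fn main() {
    let feature = [10.0, 20.0, 30.0, 55.0, 65.0];
    let mut rng = Lcg(42);
    let t = random_threshold(&feature, &mut rng);
    // The random cut always lands inside the observed range.
    assert!((10.0..=65.0).contains(&t));
}
```

Skipping the best-split search is what makes Extra-Trees training cheaper per node; randomness in the cut point, averaged over many trees, is what trades a little bias for lower variance.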
-2
@@ -16,8 +16,6 @@
//! //!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/) //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
mod base_forest_regressor;
pub mod extra_trees_regressor;
/// Random forest classifier /// Random forest classifier
pub mod random_forest_classifier; pub mod random_forest_classifier;
/// Random forest regressor /// Random forest regressor
+5 -36
@@ -33,7 +33,7 @@
//! &[4.9, 2.4, 3.3, 1.0], //! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3], //! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4], //! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap(); //! ]);
//! let y = vec![ //! let y = vec![
//! 0, 0, 0, 0, 0, 0, 0, 0, //! 0, 0, 0, 0, 0, 0, 0, 0,
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
@@ -454,12 +454,8 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
y: &Y, y: &Y,
parameters: RandomForestClassifierParameters, parameters: RandomForestClassifierParameters,
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> { ) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let (_, num_attributes) = x.shape();
let y_ncols = y.shape(); let y_ncols = y.shape();
if x_nrows != y_ncols {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mut yi: Vec<usize> = vec![0; y_ncols]; let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y.unique(); let classes = y.unique();
@@ -660,8 +656,7 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4], &[5.2, 2.7, 3.9, 1.4],
]) ]);
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let classifier = RandomForestClassifier::fit( let classifier = RandomForestClassifier::fit(
@@ -683,30 +678,6 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
} }
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let fail = RandomForestClassifier::fit(
&x_rand,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
);
assert!(fail.is_err());
}
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test wasm_bindgen_test::wasm_bindgen_test
@@ -734,8 +705,7 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4], &[5.2, 2.7, 3.9, 1.4],
]) ]);
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let classifier = RandomForestClassifier::fit( let classifier = RandomForestClassifier::fit(
@@ -788,8 +758,7 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0], &[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3], &[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4], &[5.2, 2.7, 3.9, 1.4],
]) ]);
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
+129 -56
@@ -29,7 +29,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], //! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], //! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], //! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap(); //! ]);
//! let y = vec![ //! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, //! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9 //! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
@@ -43,6 +43,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::Rng;
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
@@ -50,12 +51,15 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters}; use crate::error::{Failed, FailedError};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2}; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber; use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
use crate::rand_custom::get_rng_impl;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -94,7 +98,8 @@ pub struct RandomForestRegressor<
X: Array2<TX>, X: Array2<TX>,
Y: Array1<TY>, Y: Array1<TY>,
> { > {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>, trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
samples: Option<Vec<Vec<bool>>>,
} }
impl RandomForestRegressorParameters { impl RandomForestRegressorParameters {
@@ -154,7 +159,14 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
for RandomForestRegressor<TX, TY, X, Y> for RandomForestRegressor<TX, TY, X, Y>
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.forest_regressor == other.forest_regressor if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
} }
} }
@@ -164,7 +176,8 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
forest_regressor: Option::None, trees: Option::None,
samples: Option::None,
} }
} }
@@ -384,35 +397,124 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
y: &Y, y: &Y,
parameters: RandomForestRegressorParameters, parameters: RandomForestRegressorParameters,
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> { ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters { let (n_rows, num_attributes) = x.shape();
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, let mtry = parameters
min_samples_split: parameters.min_samples_split, .m
n_trees: parameters.n_trees, .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
m: parameters.m,
keep_samples: parameters.keep_samples, let mut rng = get_rng_impl(Some(parameters.seed));
seed: parameters.seed, let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
bootstrap: true,
splitter: Splitter::Best, let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
}; if parameters.keep_samples {
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?; // TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples: Vec<usize> =
RandomForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
// keep samples if the flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
};
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree);
}
Ok(RandomForestRegressor { Ok(RandomForestRegressor {
forest_regressor: Some(forest_regressor), trees: Some(trees),
samples: maybe_all_samples,
}) })
} }
/// Predict target values for `x` /// Predict target values for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> { pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap(); let mut result = Y::zeros(x.shape().0);
forest_regressor.predict(x)
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
} }
/// Predict OOB values for `x`. `x` is expected to be equal to the dataset used in training. /// Predict OOB values for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> { pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap(); let (n, _) = x.shape();
forest_regressor.predict_oob(x) if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
} }
} }
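The bootstrap logic in `sample_with_replacement` and its relation to OOB prediction can be sketched in a few lines. This is an illustration, not smartcore's code: a tiny xorshift step stands in for the `rand::Rng` trait, so the block is self-contained. Each of the `nrows` draws increments a per-row counter; rows whose counter stays at 0 are "out-of-bag" for that tree, which is exactly what `predict_for_row_oob` relies on.

```rust
// Hypothetical sketch of bootstrap sampling with replacement.
// `state` and the xorshift step are stand-ins for rand::Rng.
fn sample_with_replacement(nrows: usize, state: &mut u64) -> Vec<usize> {
    let mut samples = vec![0usize; nrows];
    for _ in 0..nrows {
        // xorshift64 step: illustrative pseudo-randomness only
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        samples[(*state % nrows as u64) as usize] += 1;
    }
    samples
}

fn main() {
    let mut state = 0x9E37_79B9_7F4A_7C15;
    let counts = sample_with_replacement(8, &mut state);
    // Exactly `nrows` draws are made in total, however they land.
    assert_eq!(counts.iter().sum::<usize>(), 8);
    // Rows never drawn are out-of-bag for this bootstrap sample.
    let oob: Vec<usize> = (0..8).filter(|&i| counts[i] == 0).collect();
    println!("counts = {:?}, oob = {:?}", counts, oob);
}
```

Because the counts sum to `nrows`, roughly a third of the rows end up out-of-bag on average, which is why OOB error is a usable validation signal without a held-out set.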
@@ -468,8 +570,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
             114.2, 115.7, 116.9,
@@ -494,32 +595,6 @@ mod tests {
         assert!(mean_absolute_error(&y, &y_hat) < 1.0);
     }

-    #[test]
-    fn test_random_matrix_with_wrong_rownum() {
-        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
-        let y = vec![
-            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
-            114.2, 115.7, 116.9,
-        ];
-        let fail = RandomForestRegressor::fit(
-            &x_rand,
-            &y,
-            RandomForestRegressorParameters {
-                max_depth: Option::None,
-                min_samples_leaf: 1,
-                min_samples_split: 2,
-                n_trees: 1000,
-                m: Option::None,
-                keep_samples: false,
-                seed: 87,
-            },
-        );
-        assert!(fail.is_err());
-    }
-
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
@@ -543,8 +618,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
             114.2, 115.7, 116.9,
@@ -598,8 +672,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
             114.2, 115.7, 116.9,
+2 -21
@@ -30,10 +30,8 @@ pub enum FailedError {
     DecompositionFailed,
     /// Can't solve for x
     SolutionFailed,
-    /// Error in input parameters
+    /// Erro in input
     ParametersError,
-    /// Invalid state error (should never happen)
-    InvalidStateError,
 }

 impl Failed {
@@ -66,22 +64,6 @@ impl Failed {
         }
     }

-    /// new instance of `FailedError::ParametersError`
-    pub fn input(msg: &str) -> Self {
-        Failed {
-            err: FailedError::ParametersError,
-            msg: msg.to_string(),
-        }
-    }
-
-    /// new instance of `FailedError::InvalidStateError`
-    pub fn invalid_state(msg: &str) -> Self {
-        Failed {
-            err: FailedError::InvalidStateError,
-            msg: msg.to_string(),
-        }
-    }
-
     /// new instance of `err`
     pub fn because(err: FailedError, msg: &str) -> Self {
         Failed {
@@ -115,9 +97,8 @@ impl fmt::Display for FailedError {
             FailedError::DecompositionFailed => "Decomposition failed",
             FailedError::SolutionFailed => "Can't find solution",
             FailedError::ParametersError => "Error in input, check parameters",
-            FailedError::InvalidStateError => "Invalid state, this should never happen", // useful in development phase of lib
         };
-        write!(f, "{failed_err_str}")
+        write!(f, "{}", failed_err_str)
     }
 }
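The error type above pairs an error kind with a message through constructor helpers such as `Failed::because`. A minimal, self-contained sketch of the same pattern (the names mirror smartcore's `Failed`/`FailedError`, but this is an illustration, not the library's actual code, and only a subset of the variants is shown):

```rust
use std::fmt;

// Illustrative subset of error kinds, mirroring FailedError above.
#[derive(Debug)]
enum FailedError {
    FitFailed,
    PredictFailed,
    ParametersError,
}

#[derive(Debug)]
struct Failed {
    err: FailedError,
    msg: String,
}

impl Failed {
    // Mirrors `Failed::because`: attach a message to an error kind.
    fn because(err: FailedError, msg: &str) -> Self {
        Failed { err, msg: msg.to_string() }
    }
}

impl fmt::Display for Failed {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = match self.err {
            FailedError::FitFailed => "Fit failed",
            FailedError::PredictFailed => "Predict failed",
            FailedError::ParametersError => "Error in input, check parameters",
        };
        write!(f, "{}: {}", kind, self.msg)
    }
}

fn main() {
    let e = Failed::because(FailedError::PredictFailed, "Need samples=true for OOB predictions.");
    println!("{}", e);
}
```

Keeping the kind as an enum and the message as free text is what lets call sites like `predict_oob` report precise failures while callers still match on the kind.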
+3 -4
@@ -3,10 +3,10 @@
     clippy::too_many_arguments,
     clippy::many_single_char_names,
     clippy::unnecessary_wraps,
-    clippy::upper_case_acronyms,
-    clippy::approx_constant
+    clippy::upper_case_acronyms
 )]
 #![warn(missing_docs)]
+#![warn(rustdoc::missing_doc_code_examples)]

 //! # smartcore
 //!
@@ -63,7 +63,7 @@
 //!           &[3., 4.],
 //!           &[5., 6.],
 //!           &[7., 8.],
-//!           &[9., 10.]]).unwrap();
+//!           &[9., 10.]]);
 //! // Our classes are defined as a vector
 //! let y = vec![2, 2, 2, 3, 3];
 //!
@@ -130,6 +130,5 @@ pub mod readers;
 pub mod svm;
 /// Supervised tree-based learning methods
 pub mod tree;
-pub mod xgboost;

 pub(crate) mod rand_custom;
+195 -239
File diff suppressed because it is too large
+107 -236
@@ -19,8 +19,6 @@ use crate::linalg::traits::svd::SVDDecomposable;
 use crate::numbers::basenum::Number;
 use crate::numbers::realnum::RealNumber;

-use crate::error::Failed;
-
 /// Dense matrix
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -52,26 +50,26 @@ pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
} }
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> { impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
fn new( fn new(m: &'a DenseMatrix<T>, rows: Range<usize>, cols: Range<usize>) -> Self {
m: &'a DenseMatrix<T>, let (start, end, stride) = if m.column_major {
vrows: Range<usize>, (
vcols: Range<usize>, rows.start + cols.start * m.nrows,
) -> Result<Self, Failed> { rows.end + (cols.end - 1) * m.nrows,
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) { m.nrows,
Err(Failed::input( )
"The specified view is outside of the matrix range",
))
} else { } else {
let (start, end, stride) = (
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major); rows.start * m.ncols + cols.start,
(rows.end - 1) * m.ncols + cols.end,
Ok(DenseMatrixView { m.ncols,
values: &m.values[start..end], )
stride, };
nrows: vrows.end - vrows.start, DenseMatrixView {
ncols: vcols.end - vcols.start, values: &m.values[start..end],
column_major: m.column_major, stride,
}) nrows: rows.end - rows.start,
ncols: cols.end - cols.start,
column_major: m.column_major,
} }
} }
@@ -91,7 +89,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
     }
 }

-impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
             f,
 }

 impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
-    fn new(
-        m: &'a mut DenseMatrix<T>,
-        vrows: Range<usize>,
-        vcols: Range<usize>,
-    ) -> Result<Self, Failed> {
-        if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
-            Err(Failed::input(
-                "The specified view is outside of the matrix range",
-            ))
-        } else {
-            let (start, end, stride) =
-                m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
-
-            Ok(DenseMatrixMutView {
-                values: &mut m.values[start..end],
-                stride,
-                nrows: vrows.end - vrows.start,
-                ncols: vcols.end - vcols.start,
-                column_major: m.column_major,
-            })
-        }
+    fn new(m: &'a mut DenseMatrix<T>, rows: Range<usize>, cols: Range<usize>) -> Self {
+        let (start, end, stride) = if m.column_major {
+            (
+                rows.start + cols.start * m.nrows,
+                rows.end + (cols.end - 1) * m.nrows,
+                m.nrows,
+            )
+        } else {
+            (
+                rows.start * m.ncols + cols.start,
+                (rows.end - 1) * m.ncols + cols.end,
+                m.ncols,
+            )
+        };
+        DenseMatrixMutView {
+            values: &mut m.values[start..end],
+            stride,
+            nrows: rows.end - rows.start,
+            ncols: cols.end - cols.start,
+            column_major: m.column_major,
+        }
     }
@@ -142,7 +140,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
         }
     }

-    fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
+    fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> {
         let column_major = self.column_major;
         let stride = self.stride;
         let ptr = self.values.as_mut_ptr();
@@ -169,7 +167,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
     }
 }

-impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         writeln!(
             f,
@@ -184,102 +182,42 @@ impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_,
 impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
     /// Create new instance of `DenseMatrix` without copying data.
     /// `values` should be in column-major order.
-    pub fn new(
-        nrows: usize,
-        ncols: usize,
-        values: Vec<T>,
-        column_major: bool,
-    ) -> Result<Self, Failed> {
-        let data_len = values.len();
-        if nrows * ncols != values.len() {
-            Err(Failed::input(&format!(
-                "The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
-            )))
-        } else {
-            Ok(DenseMatrix {
-                ncols,
-                nrows,
-                values,
-                column_major,
-            })
-        }
+    pub fn new(nrows: usize, ncols: usize, values: Vec<T>, column_major: bool) -> Self {
+        DenseMatrix {
+            ncols,
+            nrows,
+            values,
+            column_major,
+        }
     }

     /// New instance of `DenseMatrix` from 2d array.
-    pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
+    pub fn from_2d_array(values: &[&[T]]) -> Self {
         DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
     }

     /// New instance of `DenseMatrix` from 2d vector.
-    #[allow(clippy::ptr_arg)]
-    pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
-        if values.is_empty() || values[0].is_empty() {
-            Err(Failed::input(
-                "The 2d vec provided is empty; cannot instantiate the matrix",
-            ))
-        } else {
-            let nrows = values.len();
-            let ncols = values
-                .first()
-                .unwrap_or_else(|| {
-                    panic!("Invalid state: Cannot create 2d matrix from an empty vector")
-                })
-                .len();
-            let mut m_values = Vec::with_capacity(nrows * ncols);
-            for c in 0..ncols {
-                for r in values.iter().take(nrows) {
-                    m_values.push(r[c])
-                }
-            }
-            DenseMatrix::new(nrows, ncols, m_values, true)
-        }
+    pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Self {
+        let nrows = values.len();
+        let ncols = values
+            .first()
+            .unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector"))
+            .len();
+        let mut m_values = Vec::with_capacity(nrows * ncols);
+        for c in 0..ncols {
+            for r in values.iter().take(nrows) {
+                m_values.push(r[c])
+            }
+        }
+        DenseMatrix::new(nrows, ncols, m_values, true)
     }

     /// Iterate over values of matrix
     pub fn iter(&self) -> Iter<'_, T> {
         self.values.iter()
     }
-
-    /// Check if the size of the requested view is bounded to matrix rows/cols count
-    fn is_valid_view(
-        &self,
-        n_rows: usize,
-        n_cols: usize,
-        vrows: &Range<usize>,
-        vcols: &Range<usize>,
-    ) -> bool {
-        !(vrows.end <= n_rows
-            && vcols.end <= n_cols
-            && vrows.start <= n_rows
-            && vcols.start <= n_cols)
-    }
-
-    /// Compute the range of the requested view: start, end, size of the slice
-    fn stride_range(
-        &self,
-        n_rows: usize,
-        n_cols: usize,
-        vrows: &Range<usize>,
-        vcols: &Range<usize>,
-        column_major: bool,
-    ) -> (usize, usize, usize) {
-        let (start, end, stride) = if column_major {
-            (
-                vrows.start + vcols.start * n_rows,
-                vrows.end + (vcols.end - 1) * n_rows,
-                n_rows,
-            )
-        } else {
-            (
-                vrows.start * n_cols + vcols.start,
-                (vrows.end - 1) * n_cols + vcols.end,
-                n_cols,
-            )
-        };
-        (start, end, stride)
-    }
 }
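The view-range arithmetic in `stride_range` above maps a rectangular sub-view onto the flat backing `Vec`. A standalone sketch of the same computation (assumption: the ranges are already validated to lie inside the matrix, as `is_valid_view` checks in the diff):

```rust
use std::ops::Range;

// Compute (start, end, stride) of the flat slice backing the view
// `vrows` x `vcols` of an n_rows x n_cols matrix.
fn stride_range(
    n_rows: usize,
    n_cols: usize,
    vrows: &Range<usize>,
    vcols: &Range<usize>,
    column_major: bool,
) -> (usize, usize, usize) {
    if column_major {
        (
            vrows.start + vcols.start * n_rows,   // first element of the view
            vrows.end + (vcols.end - 1) * n_rows, // one past the last element
            n_rows,                               // step between consecutive columns
        )
    } else {
        (
            vrows.start * n_cols + vcols.start,
            (vrows.end - 1) * n_cols + vcols.end,
            n_cols,                               // step between consecutive rows
        )
    }
}

fn main() {
    // Row-major 3x3: row 1 (view 1..2 x 0..3) covers values[3..6], stride 3.
    assert_eq!(stride_range(3, 3, &(1..2), &(0..3), false), (3, 6, 3));
    // Column-major 3x3: column 1 (view 0..3 x 1..2) also covers values[3..6].
    assert_eq!(stride_range(3, 3, &(0..3), &(1..2), true), (3, 6, 3));
    println!("ok");
}
```

Note that `end` is an exclusive bound over the flat buffer, which is why the view constructors can slice `&m.values[start..end]` directly.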
 impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
@@ -366,7 +304,6 @@ where
 impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
     fn get(&self, pos: (usize, usize)) -> &T {
         let (row, col) = pos;
-
         if row >= self.nrows || col >= self.ncols {
             panic!(
                 "Invalid index ({},{}) for {}x{} matrix",
@@ -446,15 +383,15 @@ impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
 impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
     fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
-        Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
+        Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols))
     }

     fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
-        Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
+        Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1))
     }

     fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
-        Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
+        Box::new(DenseMatrixView::new(self, rows, cols))
     }
     fn slice_mut<'a>(
@@ -465,17 +402,15 @@ impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
     where
         Self: Sized,
     {
-        Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
+        Box::new(DenseMatrixMutView::new(self, rows, cols))
     }

-    // private function so for now assume infalible
     fn fill(nrows: usize, ncols: usize, value: T) -> Self {
-        DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
+        DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true)
     }

-    // private function so for now assume infalible
     fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
-        DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
+        DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0)
     }

     fn transpose(&self) -> Self {
@@ -493,12 +428,12 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
 impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
 impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}

-impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> {
     fn get(&self, pos: (usize, usize)) -> &T {
         if self.column_major {
-            &self.values[pos.0 + pos.1 * self.stride]
+            &self.values[(pos.0 + pos.1 * self.stride)]
         } else {
-            &self.values[pos.0 * self.stride + pos.1]
+            &self.values[(pos.0 * self.stride + pos.1)]
         }
     }
@@ -515,7 +450,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
     }
 }

-impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> {
     fn get(&self, i: usize) -> &T {
         if self.nrows == 1 {
             if self.column_major {
@@ -553,16 +488,16 @@ impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_,
     }
 }

-impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
-impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {}
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {}

-impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> {
     fn get(&self, pos: (usize, usize)) -> &T {
         if self.column_major {
-            &self.values[pos.0 + pos.1 * self.stride]
+            &self.values[(pos.0 + pos.1 * self.stride)]
         } else {
-            &self.values[pos.0 * self.stride + pos.1]
+            &self.values[(pos.0 * self.stride + pos.1)]
         }
     }
@@ -579,12 +514,14 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
     }
 }

-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
+impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
+    for DenseMatrixMutView<'a, T>
+{
     fn set(&mut self, pos: (usize, usize), x: T) {
         if self.column_major {
-            self.values[pos.0 + pos.1 * self.stride] = x;
+            self.values[(pos.0 + pos.1 * self.stride)] = x;
         } else {
-            self.values[pos.0 * self.stride + pos.1] = x;
+            self.values[(pos.0 * self.stride + pos.1)] = x;
         }
     }
@@ -593,90 +530,29 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMat
     }
 }

-impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
-impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
+impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {}
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {}

 impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}

 impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}

 #[cfg(test)]
-#[warn(clippy::reversed_empty_ranges)]
 mod tests {
     use super::*;
     use approx::relative_eq;
-    #[test]
-    fn test_instantiate_from_2d() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
-        assert!(x.is_ok());
-    }
-
-    #[test]
-    fn test_instantiate_from_2d_empty() {
-        let input: &[&[f64]] = &[&[]];
-        let x = DenseMatrix::from_2d_array(input);
-        assert!(x.is_err());
-    }
-
-    #[test]
-    fn test_instantiate_from_2d_empty2() {
-        let input: &[&[f64]] = &[&[], &[]];
-        let x = DenseMatrix::from_2d_array(input);
-        assert!(x.is_err());
-    }
-
-    #[test]
-    fn test_instantiate_ok_view1() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 0..2, 0..2);
-        assert!(v.is_ok());
-    }
-
-    #[test]
-    fn test_instantiate_ok_view2() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 0..3, 0..3);
-        assert!(v.is_ok());
-    }
-
-    #[test]
-    fn test_instantiate_ok_view3() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 2..3, 0..3);
-        assert!(v.is_ok());
-    }
-
-    #[test]
-    fn test_instantiate_ok_view4() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 3..3, 0..3);
-        assert!(v.is_ok());
-    }
-
-    #[test]
-    fn test_instantiate_err_view1() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 3..4, 0..3);
-        assert!(v.is_err());
-    }
-
-    #[test]
-    fn test_instantiate_err_view2() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        let v = DenseMatrixView::new(&x, 0..3, 3..4);
-        assert!(v.is_err());
-    }
-
-    #[test]
-    fn test_instantiate_err_view3() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
-        #[allow(clippy::reversed_empty_ranges)]
-        let v = DenseMatrixView::new(&x, 0..3, 4..3);
-        assert!(v.is_err());
-    }
-
     #[test]
     fn test_display() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
         println!("{}", &x);
     }
     #[test]
     fn test_get_row_col() {
-        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);

         assert_eq!(15.0, x.get_col(1).sum());
         assert_eq!(15.0, x.get_row(1).sum());
@@ -685,7 +561,7 @@ mod tests {
     #[test]
     fn test_row_major() {
-        let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();
+        let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false);

         assert_eq!(5, *x.get_col(1).get(1));
         assert_eq!(7, x.get_col(1).sum());
@@ -699,22 +575,21 @@ mod tests {
     #[test]
     fn test_get_slice() {
-        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
-            .unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]]);

         assert_eq!(
             vec![4, 5, 6],
             DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
         );
-        let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
+        let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).map(|x| *x).collect();
         assert_eq!(vec![4, 5, 6], second_row);
-        let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
+        let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).map(|x| *x).collect();
         assert_eq!(vec![2, 5, 8], second_col);
     }

     #[test]
     fn test_iter_mut() {
-        let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
+        let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);

         assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
         // add +2 to some elements
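The assertion in `test_iter_mut` above shows that the row-wise input `[[1,2,3],[4,5,6],[7,8,9]]` is stored as `[1, 4, 7, 2, 5, 8, 3, 6, 9]`: `from_2d_vec` flattens column by column. A standalone sketch of that flattening (hypothetical helper, not the library's code; assumes all rows have equal length):

```rust
// Flatten row-wise input into a column-major backing buffer,
// walking down each column in turn, as from_2d_vec does above.
fn to_column_major<T: Copy>(rows: &[Vec<T>]) -> Vec<T> {
    let nrows = rows.len();
    let ncols = rows[0].len();
    let mut values = Vec::with_capacity(nrows * ncols);
    for c in 0..ncols {
        for row in rows.iter() {
            values.push(row[c]); // element (row, c) of column c
        }
    }
    values
}

fn main() {
    let rows = vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
    assert_eq!(to_column_major(&rows), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]);
}
```

With this layout, element `(r, c)` lives at flat index `r + c * nrows`, which is the indexing the `column_major` branches in `get`/`set` use.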
@@ -750,8 +625,7 @@ mod tests {
     #[test]
     fn test_str_array() {
         let mut x =
-            DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]]);

         assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
         x.iterator_mut(0).for_each(|v| *v = "str");
@@ -763,20 +637,20 @@ mod tests {
     #[test]
     fn test_transpose() {
-        let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();
+        let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]);

         assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
-        assert!(x.column_major);
+        assert!(x.column_major == true);

         // transpose
         let x = x.transpose();
         assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
-        assert!(!x.column_major); // should change column_major
+        assert!(x.column_major == false); // should change column_major
     }

     #[test]
     fn test_from_iterator() {
-        let data = [1, 2, 3, 4, 5, 6];
+        let data = vec![1, 2, 3, 4, 5, 6];

         let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);

@@ -785,25 +659,25 @@ mod tests {
             vec![1, 2, 3, 4, 5, 6],
             m.values.iter().map(|e| **e).collect::<Vec<i32>>()
         );
-        assert!(!m.column_major);
+        assert!(m.column_major == false);
     }
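`test_transpose` above asserts that `values` is unchanged after transposing and only the `column_major` flag flips. That zero-copy trick can be sketched on its own (hypothetical `Mat` type, not smartcore's `DenseMatrix`): swapping the dimensions and reinterpreting the same flat buffer with the opposite layout is equivalent to transposing, so no element moves.

```rust
struct Mat<T> {
    nrows: usize,
    ncols: usize,
    values: Vec<T>,
    column_major: bool,
}

impl<T: Copy> Mat<T> {
    fn get(&self, r: usize, c: usize) -> T {
        if self.column_major {
            self.values[r + c * self.nrows]
        } else {
            self.values[r * self.ncols + c]
        }
    }

    // Swap the dimensions and flip the layout flag; the buffer is reused.
    fn transpose(self) -> Mat<T> {
        Mat {
            nrows: self.ncols,
            ncols: self.nrows,
            column_major: !self.column_major,
            values: self.values,
        }
    }
}

fn main() {
    // 2x3 matrix [[1,2,3],[4,5,6]] in column-major storage.
    let m = Mat { nrows: 2, ncols: 3, values: vec![1, 4, 2, 5, 3, 6], column_major: true };
    assert_eq!(m.get(0, 1), 2);
    let t = m.transpose(); // now 3x2, row-major, same buffer
    assert_eq!(t.get(1, 0), 2);
    assert_eq!(t.values, vec![1, 4, 2, 5, 3, 6]);
}
```

The cost is that downstream code must branch on `column_major`, which is exactly what the `get`/`set` implementations in the diff do.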
     #[test]
     fn test_take() {
-        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
-        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]);
+        let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]);

-        println!("{a}");
+        println!("{}", a);
         // take column 0 and 2
         assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
-        println!("{b}");
+        println!("{}", b);
         // take rows 0 and 2
         assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
     }

     #[test]
     fn test_mut() {
-        let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]);

         let a = a.abs();
         assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
@@ -814,29 +688,26 @@ mod tests {
     #[test]
     fn test_reshape() {
-        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
-            .unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]]);

         let a = a.reshape(2, 6, 0);
         assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
-        assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);
+        assert!(a.ncols == 6 && a.nrows == 2 && a.column_major == false);

         let a = a.reshape(3, 4, 1);
         assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
-        assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
+        assert!(a.ncols == 4 && a.nrows == 3 && a.column_major == true);
     }

     #[test]
     fn test_eq() {
-        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
-        let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
+        let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
         let c = DenseMatrix::from_2d_array(&[
             &[1. + f32::EPSILON, 2., 3.],
             &[4., 5., 6. + f32::EPSILON],
-        ])
-        .unwrap();
-        let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
-            .unwrap();
+        ]);
+        let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]]);

         assert!(!relative_eq!(a, b));
         assert!(!relative_eq!(a, d));
+10 -31
@@ -15,25 +15,6 @@ pub struct VecView<'a, T: Debug + Display + Copy + Sized> {
     ptr: &'a [T],
 }

-impl<T: Debug + Display + Copy + Sized> Array<T, usize> for &[T] {
-    fn get(&self, i: usize) -> &T {
-        &self[i]
-    }
-
-    fn shape(&self) -> usize {
-        self.len()
-    }
-
-    fn is_empty(&self) -> bool {
-        self.len() > 0
-    }
-
-    fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
-        assert!(axis == 0, "For one dimensional array `axis` should == 0");
-        Box::new(self.iter())
-    }
-}
-
 impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
     fn get(&self, i: usize) -> &T {
         &self[i]
@@ -55,7 +36,6 @@ impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
 impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
     fn set(&mut self, i: usize, x: T) {
-        // NOTE: this panics in case of out of bounds index
         self[i] = x
     }

@@ -66,7 +46,6 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
 }

 impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for Vec<T> {}
-impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for &[T] {}

 impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for Vec<T> {}
@@ -119,7 +98,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
} }
} }
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> { impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> {
fn get(&self, i: usize) -> &T { fn get(&self, i: usize) -> &T {
&self.ptr[i] &self.ptr[i]
} }
@@ -138,7 +117,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
} }
} }
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> { impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> {
fn set(&mut self, i: usize, x: T) { fn set(&mut self, i: usize, x: T) {
self.ptr[i] = x; self.ptr[i] = x;
} }
@@ -149,10 +128,10 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T>
} }
} }
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {} impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {} impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> { impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
fn get(&self, i: usize) -> &T { fn get(&self, i: usize) -> &T {
&self.ptr[i] &self.ptr[i]
} }
@@ -171,7 +150,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
} }
} }
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {} impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@@ -181,8 +160,8 @@ mod tests {
fn dot_product<T: Number, V: Array1<T>>(v: &V) -> T { fn dot_product<T: Number, V: Array1<T>>(v: &V) -> T {
let vv = V::zeros(10); let vv = V::zeros(10);
let v_s = vv.slice(0..3); let v_s = vv.slice(0..3);
let dot = v_s.dot(v);
v_s.dot(v) dot
} }
fn vector_ops<T: Number + PartialOrd, V: Array1<T>>(_: &V) -> T { fn vector_ops<T: Number + PartialOrd, V: Array1<T>>(_: &V) -> T {
@@ -212,7 +191,7 @@ mod tests {
#[test] #[test]
fn test_len() { fn test_len() {
let x = [1, 2, 3]; let x = vec![1, 2, 3];
assert_eq!(3, x.len()); assert_eq!(3, x.len());
} }
@@ -237,7 +216,7 @@ mod tests {
#[test] #[test]
fn test_mut_iterator() { fn test_mut_iterator() {
let mut x = vec![1, 2, 3]; let mut x = vec![1, 2, 3];
x.iterator_mut(0).for_each(|v| *v *= 2); x.iterator_mut(0).for_each(|v| *v = *v * 2);
assert_eq!(vec![2, 4, 6], x); assert_eq!(vec![2, 4, 6], x);
} }
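The hunks above implement the same view traits for `Vec<T>`, slices, and borrowed view structs, so algorithms can accept any of them. A minimal, self-contained sketch of that pattern (trait and method names here are hypothetical, not smartcore's actual `Array1`/`ArrayView1` API):

```rust
// A read-only 1-D view trait with a default method, implemented for both
// Vec<T> and slices — the blanket-impl pattern used in the diff above.
trait View1<T: Copy + std::iter::Sum> {
    fn at(&self, i: usize) -> T;
    fn length(&self) -> usize;
    // Default method: works for every implementor.
    fn sum(&self) -> T {
        (0..self.length()).map(|i| self.at(i)).sum()
    }
}

impl<T: Copy + std::iter::Sum> View1<T> for Vec<T> {
    fn at(&self, i: usize) -> T {
        self[i]
    }
    fn length(&self) -> usize {
        self.len()
    }
}

impl<T: Copy + std::iter::Sum> View1<T> for &[T] {
    fn at(&self, i: usize) -> T {
        self[i]
    }
    fn length(&self) -> usize {
        self.len()
    }
}

fn main() {
    let v = vec![1, 2, 3];
    let s: &[i32] = &[4, 5, 6];
    assert_eq!(View1::sum(&v), 6);
    assert_eq!(View1::sum(&s), 15);
    println!("ok");
}
```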
+16 -12
@@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>
 impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
+impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> {
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
     }
@@ -144,9 +144,11 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
 impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
 impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
-impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {}
+impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
+    for ArrayViewMut<'a, T, Ix2>
+{
     fn get(&self, pos: (usize, usize)) -> &T {
         &self[[pos.0, pos.1]]
     }
@@ -173,7 +175,9 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayVi
     }
 }
-impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
+impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
+    for ArrayViewMut<'a, T, Ix2>
+{
     fn set(&mut self, pos: (usize, usize), x: T) {
         self[[pos.0, pos.1]] = x
     }
@@ -191,9 +195,9 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayVie
     }
 }
-impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
-impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
+impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
 #[cfg(test)]
 mod tests {
@@ -213,7 +217,7 @@ mod tests {
     fn test_iterator() {
         let a = arr2(&[[1, 2, 3], [4, 5, 6]]);
-        let v: Vec<i32> = a.iterator(0).copied().collect();
+        let v: Vec<i32> = a.iterator(0).map(|&v| v).collect();
         assert_eq!(v, vec!(1, 2, 3, 4, 5, 6));
     }
@@ -232,7 +236,7 @@ mod tests {
         let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
         let x_slice = Array2::slice(&x, 0..2, 1..2);
         assert_eq!((2, 1), x_slice.shape());
-        let v: Vec<i32> = x_slice.iterator(0).copied().collect();
+        let v: Vec<i32> = x_slice.iterator(0).map(|&v| v).collect();
         assert_eq!(v, [2, 5]);
     }
@@ -241,11 +245,11 @@ mod tests {
         let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
         let x_slice = Array2::slice(&x, 0..2, 0..3);
         assert_eq!(
-            x_slice.iterator(0).copied().collect::<Vec<i32>>(),
+            x_slice.iterator(0).map(|&v| v).collect::<Vec<i32>>(),
             vec![1, 2, 3, 4, 5, 6]
         );
         assert_eq!(
-            x_slice.iterator(1).copied().collect::<Vec<i32>>(),
+            x_slice.iterator(1).map(|&v| v).collect::<Vec<i32>>(),
             vec![1, 4, 2, 5, 3, 6]
         );
     }
@@ -275,8 +279,8 @@ mod tests {
     fn test_c_from_iterator() {
         let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
         let a: NDArray2<i32> = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0);
-        println!("{a}");
+        println!("{}", a);
         let a: NDArray2<i32> = Array2::from_iterator(data.into_iter(), 4, 3, 1);
-        println!("{a}");
+        println!("{}", a);
     }
 }
+7 -7
@@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>
 impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
+impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> {
     fn get(&self, i: usize) -> &T {
         &self[i]
     }
@@ -60,9 +60,9 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T,
     }
 }
-impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
-impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {}
+impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
     fn get(&self, i: usize) -> &T {
         &self[i]
     }
@@ -81,7 +81,7 @@ impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_,
     }
 }
-impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
+impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
     fn set(&mut self, i: usize, x: T) {
         self[i] = x;
     }
@@ -92,8 +92,8 @@ impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_,
     }
 }
-impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
-impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
+impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
+impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
 impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
     fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
@@ -152,7 +152,7 @@ mod tests {
     fn test_iterator() {
         let a = arr1(&[1, 2, 3]);
-        let v: Vec<i32> = a.iterator(0).copied().collect();
+        let v: Vec<i32> = a.iterator(0).map(|&v| v).collect();
         assert_eq!(v, vec!(1, 2, 3));
     }
+7 -11
@@ -15,7 +15,7 @@
 //!     &[25., 15., -5.],
 //!     &[15., 18., 0.],
 //!     &[-5., 0., 11.]
-//! ]).unwrap();
+//! ]);
 //!
 //! let cholesky = A.cholesky().unwrap();
 //! let lower_triangular: DenseMatrix<f64> = cholesky.L();
@@ -175,14 +175,11 @@ mod tests {
     )]
     #[test]
     fn cholesky_decompose() {
-        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
-            .unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
         let l =
-            DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
         let u =
-            DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
         let cholesky = a.cholesky().unwrap();
         assert!(relative_eq!(cholesky.L().abs(), l.abs(), epsilon = 1e-4));
@@ -200,10 +197,9 @@ mod tests {
     )]
     #[test]
     fn cholesky_solve_mut() {
-        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
-            .unwrap();
-        let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]).unwrap();
-        let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
+        let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
+        let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
         let cholesky = a.cholesky().unwrap();
+22 -24
@@ -19,7 +19,7 @@
 //!     &[0.9000, 0.4000, 0.7000],
 //!     &[0.4000, 0.5000, 0.3000],
 //!     &[0.7000, 0.3000, 0.8000],
-//! ]).unwrap();
+//! ]);
 //!
 //! let evd = A.evd(true).unwrap();
 //! let eigenvectors: DenseMatrix<f64> = evd.V;
@@ -66,7 +66,7 @@ pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
     fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
         let (nrows, ncols) = self.shape();
         if ncols != nrows {
-            panic!("Matrix is not square: {nrows} x {ncols}");
+            panic!("Matrix is not square: {} x {}", nrows, ncols);
         }
         let n = nrows;
@@ -820,8 +820,7 @@ mod tests {
            &[0.9000, 0.4000, 0.7000],
            &[0.4000, 0.5000, 0.3000],
            &[0.7000, 0.3000, 0.8000],
-        ])
-        .unwrap();
+        ]);
         let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -829,8 +828,7 @@ mod tests {
            &[0.6881997, -0.07121225, 0.7220180],
            &[0.3700456, 0.89044952, -0.2648886],
            &[0.6240573, -0.44947578, -0.6391588],
-        ])
-        .unwrap();
+        ]);
         let evd = A.evd(true).unwrap();
@@ -839,9 +837,11 @@ mod tests {
             evd.V.abs(),
             epsilon = 1e-4
         ));
-        for (i, eigen_values_i) in eigen_values.iter().enumerate() {
-            assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
-            assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
+        for i in 0..eigen_values.len() {
+            assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
+        }
+        for i in 0..eigen_values.len() {
+            assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
         }
     }
     #[cfg_attr(
@@ -854,8 +854,7 @@ mod tests {
            &[0.9000, 0.4000, 0.7000],
            &[0.4000, 0.5000, 0.3000],
            &[0.8000, 0.3000, 0.8000],
-        ])
-        .unwrap();
+        ]);
         let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
@@ -863,8 +862,7 @@ mod tests {
            &[0.7178958, 0.05322098, 0.6812010],
            &[0.3837711, -0.84702111, -0.1494582],
            &[0.6952105, 0.43984484, -0.7036135],
-        ])
-        .unwrap();
+        ]);
         let evd = A.evd(false).unwrap();
@@ -873,9 +871,11 @@ mod tests {
             evd.V.abs(),
             epsilon = 1e-4
         ));
-        for (i, eigen_values_i) in eigen_values.iter().enumerate() {
-            assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
-            assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
+        for i in 0..eigen_values.len() {
+            assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
+        }
+        for i in 0..eigen_values.len() {
+            assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
         }
     }
     #[cfg_attr(
@@ -889,8 +889,7 @@ mod tests {
            &[4.0, -1.0, 1.0, 1.0],
            &[1.0, 1.0, 3.0, -2.0],
            &[1.0, 1.0, 4.0, -1.0],
-        ])
-        .unwrap();
+        ]);
         let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
         let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
@@ -900,8 +899,7 @@ mod tests {
            &[-0.6707, 0.1059, 0.901, 0.6289],
            &[0.9159, -0.1378, 0.3816, 0.0806],
            &[0.6707, 0.1059, 0.901, -0.6289],
-        ])
-        .unwrap();
+        ]);
         let evd = A.evd(false).unwrap();
@@ -910,11 +908,11 @@ mod tests {
             evd.V.abs(),
             epsilon = 1e-4
         ));
-        for (i, eigen_values_d_i) in eigen_values_d.iter().enumerate() {
-            assert!((eigen_values_d_i - evd.d[i]).abs() < 1e-4);
+        for i in 0..eigen_values_d.len() {
+            assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4);
         }
-        for (i, eigen_values_e_i) in eigen_values_e.iter().enumerate() {
-            assert!((eigen_values_e_i - evd.e[i]).abs() < 1e-4);
+        for i in 0..eigen_values_e.len() {
+            assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
         }
     }
 }
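The symmetric EVD test above expects a dominant eigenvalue of about 1.7498382 for its 3×3 matrix. This is not smartcore's EVD routine, but a power-iteration sketch recovers that value; since the test matrix is positive definite, `‖Av‖` for a unit `v` converges to the largest eigenvalue:

```rust
// Power iteration: repeatedly multiply a unit vector by A and renormalize.
// For a symmetric positive-definite A with a unique largest eigenvalue,
// the norm of A·v converges to that eigenvalue.
fn dominant_eigenvalue(a: &[[f64; 3]; 3], iters: usize) -> f64 {
    let mut v = [1.0 / 3f64.sqrt(); 3]; // unit starting vector
    let mut lambda = 0.0;
    for _ in 0..iters {
        // w = A v
        let mut w = [0.0; 3];
        for i in 0..3 {
            for j in 0..3 {
                w[i] += a[i][j] * v[j];
            }
        }
        lambda = w.iter().map(|x| x * x).sum::<f64>().sqrt();
        for i in 0..3 {
            v[i] = w[i] / lambda; // renormalize
        }
    }
    lambda
}

fn main() {
    // Same matrix as the evd_decompose_symmetric test; expected d[0] ≈ 1.7498382.
    let a = [[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]];
    let d0 = dominant_eigenvalue(&a, 500);
    assert!((d0 - 1.7498382).abs() < 1e-4);
    println!("{}", d0);
}
```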
+3 -3
@@ -12,9 +12,9 @@ pub trait HighOrderOperations<T: Number>: Array2<T> {
     /// use smartcore::linalg::traits::high_order::HighOrderOperations;
     /// use smartcore::linalg::basic::arrays::Array2;
     ///
-    /// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]).unwrap();
-    /// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]).unwrap();
-    /// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]).unwrap();
+    /// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
+    /// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
+    /// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
     ///
     /// assert_eq!(a.ab(true, &b, false), expected);
     /// ```
+12 -10
@@ -18,7 +18,7 @@
 //!     &[1., 2., 3.],
 //!     &[0., 1., 5.],
 //!     &[5., 6., 0.]
-//! ]).unwrap();
+//! ]);
 //!
 //! let lu = A.lu().unwrap();
 //! let lower: DenseMatrix<f64> = lu.L();
@@ -126,7 +126,7 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
         let (m, n) = self.LU.shape();
         if m != n {
-            panic!("Matrix is not square: {m}x{n}");
+            panic!("Matrix is not square: {}x{}", m, n);
         }
         let mut inv = M::zeros(n, n);
@@ -143,7 +143,10 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
         let (b_m, b_n) = b.shape();
         if b_m != m {
-            panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_m} x {b_n}");
+            panic!(
+                "Row dimensions do not agree: A is {} x {}, but B is {} x {}",
+                m, n, b_m, b_n
+            );
         }
         if self.singular {
@@ -263,13 +266,13 @@ mod tests {
     )]
     #[test]
    fn decompose() {
-        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
         let expected_L =
-            DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
         let expected_U =
-            DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
         let expected_pivot =
-            DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
         let lu = a.lu().unwrap();
         assert!(relative_eq!(lu.L(), expected_L, epsilon = 1e-4));
         assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4));
@@ -281,10 +284,9 @@ mod tests {
     )]
     #[test]
     fn inverse() {
-        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
         let expected =
-            DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
         let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
         assert!(relative_eq!(a_inv, expected, epsilon = 1e-4));
     }
+11 -13
@@ -13,7 +13,7 @@
 //!     &[0.9, 0.4, 0.7],
 //!     &[0.4, 0.5, 0.3],
 //!     &[0.7, 0.3, 0.8]
-//! ]).unwrap();
+//! ]);
 //!
 //! let qr = A.qr().unwrap();
 //! let orthogonal: DenseMatrix<f64> = qr.Q();
@@ -102,7 +102,10 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
         let (b_nrows, b_ncols) = b.shape();
         if b_nrows != m {
-            panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_nrows} x {b_ncols}");
+            panic!(
+                "Row dimensions do not agree: A is {} x {}, but B is {} x {}",
+                m, n, b_nrows, b_ncols
+            );
         }
         if self.singular {
@@ -201,20 +204,17 @@ mod tests {
     )]
     #[test]
     fn decompose() {
-        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
-            .unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
         let q = DenseMatrix::from_2d_array(&[
            &[-0.7448, 0.2436, 0.6212],
            &[-0.331, -0.9432, -0.027],
            &[-0.5793, 0.2257, -0.7832],
-        ])
-        .unwrap();
+        ]);
         let r = DenseMatrix::from_2d_array(&[
            &[-1.2083, -0.6373, -1.0842],
            &[0.0, -0.3064, 0.0682],
            &[0.0, 0.0, -0.1999],
-        ])
-        .unwrap();
+        ]);
         let qr = a.qr().unwrap();
         assert!(relative_eq!(qr.Q().abs(), q.abs(), epsilon = 1e-4));
         assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4));
@@ -226,15 +226,13 @@ mod tests {
     )]
     #[test]
     fn qr_solve_mut() {
-        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
-            .unwrap();
-        let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
+        let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
         let expected_w = DenseMatrix::from_2d_array(&[
            &[-0.2027027, -1.2837838],
            &[0.8783784, 2.2297297],
            &[0.4729730, 0.6621622],
-        ])
-        .unwrap();
+        ]);
         let w = a.qr_solve_mut(b).unwrap();
         assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
     }
+15 -18
@@ -136,12 +136,13 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
     /// ```rust
     /// use smartcore::linalg::basic::matrix::DenseMatrix;
     /// use smartcore::linalg::traits::stats::MatrixPreprocessing;
-    /// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
-    /// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
+    /// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]);
+    /// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]);
     /// a.binarize_mut(0.);
     ///
     /// assert_eq!(a, expected);
     /// ```
     fn binarize_mut(&mut self, threshold: T) {
         let (nrows, ncols) = self.shape();
         for row in 0..nrows {
@@ -158,8 +159,8 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
     /// ```rust
     /// use smartcore::linalg::basic::matrix::DenseMatrix;
     /// use smartcore::linalg::traits::stats::MatrixPreprocessing;
-    /// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
-    /// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
+    /// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]);
+    /// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]);
     ///
     /// assert_eq!(a.binarize(0.), expected);
     /// ```
@@ -185,8 +186,7 @@ mod tests {
            &[1., 2., 3., 1., 2.],
            &[4., 5., 6., 3., 4.],
            &[7., 8., 9., 5., 6.],
-        ])
-        .unwrap();
+        ]);
         let expected_0 = vec![4., 5., 6., 3., 4.];
         let expected_1 = vec![1.8, 4.4, 7.];
@@ -196,7 +196,7 @@ mod tests {
     #[test]
     fn test_var() {
-        let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
+        let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
         let expected_0 = vec![4., 4., 4., 4.];
         let expected_1 = vec![1.25, 1.25];
@@ -211,13 +211,12 @@ mod tests {
         let m = DenseMatrix::from_2d_array(&[
            &[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
            &[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
-        ])
-        .unwrap();
+        ]);
         let expected_0 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
         let expected_1 = vec![1.25, 1.25];
-        assert!(m.var(0).approximate_eq(&expected_0, f64::EPSILON));
-        assert!(m.var(1).approximate_eq(&expected_1, f64::EPSILON));
+        assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
+        assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
         assert_eq!(
             m.mean(0),
             vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
@@ -231,8 +230,7 @@ mod tests {
            &[1., 2., 3., 1., 2.],
            &[4., 5., 6., 3., 4.],
            &[7., 8., 9., 5., 6.],
-        ])
-        .unwrap();
+        ]);
         let expected_0 = vec![
             2.449489742783178,
             2.449489742783178,
@@ -253,10 +251,10 @@ mod tests {
     #[test]
     fn test_scale() {
         let m: DenseMatrix<f64> =
-            DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
         let expected_0: DenseMatrix<f64> =
-            DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]);
         let expected_1: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
            &[
                 -1.3416407864998738,
@@ -270,8 +268,7 @@ mod tests {
                 0.4472135954999579,
                 1.3416407864998738,
            ],
-        ])
-        .unwrap();
+        ]);
         assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
         assert_eq!(m.mean(1), vec![2.5, 6.5]);
@@ -289,7 +286,7 @@ mod tests {
         }
         {
-            let mut m = m;
+            let mut m = m.clone();
             m.standard_scale_mut(&m.mean(1), &m.std(1), 1);
             assert_eq!(&m, &expected_1);
         }
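The `binarize_mut` doc example above maps entries strictly greater than the threshold to 1 and everything else to 0 (`[[0,2,3],[-5,-6,-7]]` with threshold 0 becomes `[[0,1,1],[0,0,0]]`). A minimal standalone sketch of that rule, using plain nested `Vec`s instead of `DenseMatrix`:

```rust
// Binarize in place: entries strictly greater than the threshold become 1.0,
// all others (including entries equal to the threshold) become 0.0.
fn binarize_mut(m: &mut Vec<Vec<f64>>, threshold: f64) {
    for row in m.iter_mut() {
        for x in row.iter_mut() {
            *x = if *x > threshold { 1.0 } else { 0.0 };
        }
    }
}

fn main() {
    // Same input and expected output as the doc example above.
    let mut a = vec![vec![0., 2., 3.], vec![-5., -6., -7.]];
    binarize_mut(&mut a, 0.);
    assert_eq!(a, vec![vec![0., 1., 1.], vec![0., 0., 0.]]);
    println!("{:?}", a);
}
```

Note that 0 itself maps to 0, since the comparison is strict.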
+18 -24
@@ -17,7 +17,7 @@
 //!     &[0.9, 0.4, 0.7],
 //!     &[0.4, 0.5, 0.3],
 //!     &[0.7, 0.3, 0.8]
-//! ]).unwrap();
+//! ]);
 //!
 //! let svd = A.svd().unwrap();
 //! let u: DenseMatrix<f64> = svd.U;
@@ -48,9 +48,11 @@ pub struct SVD<T: Number + RealNumber, M: SVDDecomposable<T>> {
     pub V: M,
     /// Singular values of the original matrix
     pub s: Vec<T>,
+    ///
     m: usize,
+    ///
     n: usize,
-    /// Tolerance
+    ///
     tol: T,
 }
@@ -487,8 +489,7 @@ mod tests {
            &[0.9000, 0.4000, 0.7000],
            &[0.4000, 0.5000, 0.3000],
            &[0.7000, 0.3000, 0.8000],
-        ])
-        .unwrap();
+        ]);
         let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -496,22 +497,20 @@ mod tests {
            &[0.6881997, -0.07121225, 0.7220180],
            &[0.3700456, 0.89044952, -0.2648886],
            &[0.6240573, -0.44947578, -0.639158],
-        ])
-        .unwrap();
+        ]);
         let V = DenseMatrix::from_2d_array(&[
            &[0.6881997, -0.07121225, 0.7220180],
            &[0.3700456, 0.89044952, -0.2648886],
            &[0.6240573, -0.44947578, -0.6391588],
-        ])
-        .unwrap();
+        ]);
         let svd = A.svd().unwrap();
         assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
         assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
-        for (i, s_i) in s.iter().enumerate() {
-            assert!((s_i - svd.s[i]).abs() < 1e-4);
+        for i in 0..s.len() {
+            assert!((s[i] - svd.s[i]).abs() < 1e-4);
         }
     }
     #[cfg_attr(
@@ -578,8 +577,7 @@ mod tests {
                -0.2158704,
                -0.27529472,
            ],
-        ])
-        .unwrap();
+        ]);
         let s: Vec<f64> = vec![
             3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
@@ -649,8 +647,7 @@ mod tests {
                0.73034065,
                -0.43965505,
            ],
-        ])
-        .unwrap();
+        ]);
         let V = DenseMatrix::from_2d_array(&[
            &[
@@ -710,15 +707,14 @@ mod tests {
                0.1654796,
                -0.32346758,
            ],
-        ])
-        .unwrap();
+        ]);
         let svd = A.svd().unwrap();
         assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
         assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
-        for (i, s_i) in s.iter().enumerate() {
-            assert!((s_i - svd.s[i]).abs() < 1e-4);
+        for i in 0..s.len() {
+            assert!((s[i] - svd.s[i]).abs() < 1e-4);
         }
     }
     #[cfg_attr(
@@ -727,11 +723,10 @@ mod tests {
     )]
     #[test]
     fn solve() {
-        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
-            .unwrap();
-        let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
+        let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
         let expected_w =
-            DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]).unwrap();
+            DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
         let w = a.svd_solve_mut(b).unwrap();
         assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
     }
@@ -742,8 +737,7 @@ mod tests {
     )]
     #[test]
     fn decompose_restore() {
-        let a =
-            DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]).unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
         let svd = a.svd().unwrap();
         let u: &DenseMatrix<f32> = &svd.U; //U
         let v: &DenseMatrix<f32> = &svd.V; // V
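The `svd_decompose` test expects a largest singular value of about 1.7498382 for its 3×3 matrix. A full SVD is long, but the largest singular value alone can be sketched (this is not smartcore's algorithm): σ₁ is the square root of the dominant eigenvalue of AᵀA, which power iteration recovers:

```rust
// Largest singular value of a 3×3 matrix A: sqrt of the dominant
// eigenvalue of B = AᵀA, obtained by power iteration (B is symmetric
// positive semi-definite, so ‖Bv‖ for unit v converges to λ₁(B)).
fn largest_singular_value(a: &[[f64; 3]; 3], iters: usize) -> f64 {
    // B = AᵀA
    let mut b = [[0.0; 3]; 3];
    for i in 0..3 {
        for j in 0..3 {
            for k in 0..3 {
                b[i][j] += a[k][i] * a[k][j];
            }
        }
    }
    let mut v = [1.0 / 3f64.sqrt(); 3]; // unit starting vector
    let mut lambda = 0.0;
    for _ in 0..iters {
        let mut w = [0.0; 3];
        for i in 0..3 {
            for j in 0..3 {
                w[i] += b[i][j] * v[j];
            }
        }
        lambda = w.iter().map(|x| x * x).sum::<f64>().sqrt();
        for i in 0..3 {
            v[i] = w[i] / lambda;
        }
    }
    lambda.sqrt()
}

fn main() {
    // Same matrix as the svd_decompose test; expected s[0] ≈ 1.7498382.
    let a = [[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]];
    let s1 = largest_singular_value(&a, 500);
    assert!((s1 - 1.7498382).abs() < 1e-4);
    println!("{}", s1);
}
```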
+7 -9
@@ -12,8 +12,7 @@
 //! pub struct BGSolver {}
 //! impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X> for BGSolver {}
 //!
-//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0.,
-//! 11.]]).unwrap();
+//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
 //! let b = vec![40., 51., 28.];
 //! let expected = vec![1.0, 2.0, 3.0];
 //! let mut x = Vec::zeros(3);
@@ -27,9 +26,9 @@ use crate::error::Failed;
 use crate::linalg::basic::arrays::{Array, Array1, Array2, ArrayView1, MutArrayView1};
 use crate::numbers::floatnum::FloatNumber;
-/// Trait for Biconjugate Gradient Solver
+///
 pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
-    /// Solve Ax = b
+    ///
     fn solve_mut(
         &self,
         a: &'a X,
@@ -109,7 +108,7 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
         Ok(err)
     }
-    /// solve preconditioner
+    ///
     fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
         let diag = Self::diag(a);
         let n = diag.len();
@@ -133,7 +132,7 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
         y.copy_from(&x.xa(true, a));
     }
-    /// Extract the diagonal from a matrix
+    ///
     fn diag(a: &X) -> Vec<T> {
         let (nrows, ncols) = a.shape();
         let n = nrows.min(ncols);
@@ -159,10 +158,9 @@ mod tests {
     #[test]
     fn bg_solver() {
-        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
-            .unwrap();
+        let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
         let b = vec![40., 51., 28.];
-        let expected = [1.0, 2.0, 3.0];
+        let expected = vec![1.0, 2.0, 3.0];
         let mut x = Vec::zeros(3);
+8 -7
@@ -38,7 +38,7 @@
 //! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
 //! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
 //! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-//! ]).unwrap();
+//! ]);
 //!
 //! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
 //! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -425,7 +425,10 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         for (i, col_std_i) in col_std.iter().enumerate() {
             if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
-                return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
+                return Err(Failed::fit(&format!(
+                    "Cannot rescale constant column {}",
+                    i
+                )));
             }
         }
@@ -511,8 +514,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<f64> = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -563,8 +565,7 @@ mod tests {
             &[17.0, 1918.0, 1.4054969025700674],
             &[18.0, 1929.0, 1.3271699396384906],
             &[19.0, 1915.0, 1.1373332337674806],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<f64> = vec![
             1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
@@ -629,7 +630,7 @@ mod tests {
         // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
         // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
         // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        // ]).unwrap();
+        // ]);
         // let y = vec![
         // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+5 -3
@@ -356,7 +356,10 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
         for (i, col_std_i) in col_std.iter().enumerate() {
             if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
-                return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
+                return Err(Failed::fit(&format!(
+                    "Cannot rescale constant column {}",
+                    i
+                )));
             }
         }
@@ -418,8 +421,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<f64> = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+10 -4
@@ -16,7 +16,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1, MutArray, MutArra
 use crate::linear::bg_solver::BiconjugateGradientSolver;
 use crate::numbers::floatnum::FloatNumber;
-/// Interior Point Optimizer
+///
 pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
     ata: X,
     d1: Vec<T>,
@@ -25,8 +25,9 @@ pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
     prs: Vec<T>,
 }
+///
 impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
-    /// Initialize a new Interior Point Optimizer
+    ///
     pub fn new(a: &X, n: usize) -> InteriorPointOptimizer<T, X> {
         InteriorPointOptimizer {
             ata: a.ab(true, a, false),
@@ -37,7 +38,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
         }
     }
-    /// Run the optimization
+    ///
     pub fn optimize(
         &mut self,
         x: &X,
@@ -100,7 +101,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
         // CALCULATE DUALITY GAP
         let xnu = nu.xa(false, x);
-        let max_xnu = xnu.norm(f64::INFINITY);
+        let max_xnu = xnu.norm(std::f64::INFINITY);
         if max_xnu > lambda_f64 {
             let lnu = T::from_f64(lambda_f64 / max_xnu).unwrap();
             nu.mul_scalar_mut(lnu);
@@ -207,6 +208,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
         Ok(w)
     }
+    ///
     fn sumlogneg(f: &X) -> T {
         let (n, _) = f.shape();
         let mut sum = T::zero();
@@ -218,9 +220,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
         }
     }
+///
 impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
     for InteriorPointOptimizer<T, X>
 {
+    ///
     fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
         let (_, p) = a.shape();
@@ -230,6 +234,7 @@ impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
         }
     }
+    ///
     fn mat_vec_mul(&self, _: &X, x: &Vec<T>, y: &mut Vec<T>) {
         let (_, p) = self.ata.shape();
         let x_slice = Vec::from_slice(x.slice(0..p).as_ref());
@@ -241,6 +246,7 @@ impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
         }
     }
+    ///
     fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
         self.mat_vec_mul(a, x, y);
     }
+3 -4
@@ -40,7 +40,7 @@
 //! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
 //! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
 //! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-//! ]).unwrap();
+//! ]);
 //!
 //! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
 //! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -341,8 +341,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<f64> = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
@@ -394,7 +393,7 @@ mod tests {
         // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
         // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
         // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        // ]).unwrap();
+        // ]);
         // let y = vec![
         // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+66 -103
@@ -35,7 +35,7 @@
 //! &[4.9, 2.4, 3.3, 1.0],
 //! &[6.6, 2.9, 4.6, 1.3],
 //! &[5.2, 2.7, 3.9, 1.4],
-//! ]).unwrap();
+//! ]);
 //! let y: Vec<i32> = vec![
 //! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 //! ];
@@ -71,14 +71,19 @@ use crate::optimization::line_search::Backtracking;
 use crate::optimization::FunctionOrder;
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Eq, PartialEq, Default)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Solver options for Logistic regression. Right now only LBFGS solver is supported.
 pub enum LogisticRegressionSolverName {
     /// Limited-memory Broyden-Fletcher-Goldfarb-Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
-    #[default]
     LBFGS,
 }
+impl Default for LogisticRegressionSolverName {
+    fn default() -> Self {
+        LogisticRegressionSolverName::LBFGS
+    }
+}
 /// Logistic Regression parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -183,11 +188,14 @@ pub struct LogisticRegression<
 }
 trait ObjectiveFunction<T: Number + FloatNumber, X: Array2<T>> {
+    ///
     fn f(&self, w_bias: &[T]) -> T;
+    ///
     #[allow(clippy::ptr_arg)]
     fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>);
+    ///
     #[allow(clippy::ptr_arg)]
     fn partial_dot(w: &[T], x: &X, v_col: usize, m_row: usize) -> T {
         let mut sum = T::zero();
@@ -258,8 +266,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
     }
 }
-impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
-    for BinaryObjectiveFunction<'_, T, X>
+impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
+    for BinaryObjectiveFunction<'a, T, X>
 {
     fn f(&self, w_bias: &[T]) -> T {
         let mut f = T::zero();
@@ -313,8 +321,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
     _phantom_t: PhantomData<T>,
 }
-impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
-    for MultiClassObjectiveFunction<'_, T, X>
+impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
+    for MultiClassObjectiveFunction<'a, T, X>
 {
     fn f(&self, w_bias: &[T]) -> T {
         let mut f = T::zero();
@@ -441,7 +449,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
         match k.cmp(&2) {
             Ordering::Less => Err(Failed::fit(&format!(
-                "incorrect number of classes: {k}. Should be >= 2."
+                "incorrect number of classes: {}. Should be >= 2.",
+                k
             ))),
             Ordering::Equal => {
                 let x0 = Vec::zeros(num_attributes + 1);
@@ -608,8 +617,7 @@ mod tests {
             &[10., -2.],
             &[8., 2.],
             &[9., 0.],
-        ])
-        .unwrap();
+        ]);
         let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
@@ -626,21 +634,21 @@ mod tests {
         objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
         objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
-        assert!((g[0] + 33.000068218163484).abs() < f64::EPSILON);
-        let f = objective.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
-        assert!((f - 408.0052230582765).abs() < f64::EPSILON);
+        assert!((g[0] + 33.000068218163484).abs() < std::f64::EPSILON);
+        let f = objective.f(&vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
+        assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
         let objective_reg = MultiClassObjectiveFunction {
             x: &x,
-            y,
+            y: y.clone(),
             k: 3,
             alpha: 1.0,
             _phantom_t: PhantomData,
         };
-        let f = objective_reg.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
+        let f = objective_reg.f(&vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
         assert!((f - 487.5052).abs() < 1e-4);
         objective_reg.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
@@ -669,8 +677,7 @@ mod tests {
             &[10., -2.],
             &[8., 2.],
             &[9., 0.],
-        ])
-        .unwrap();
+        ]);
         let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];
@@ -686,22 +693,22 @@ mod tests {
         objective.df(&mut g, &vec![1., 2., 3.]);
         objective.df(&mut g, &vec![1., 2., 3.]);
-        assert!((g[0] - 26.051064349381285).abs() < f64::EPSILON);
-        assert!((g[1] - 10.239000702928523).abs() < f64::EPSILON);
-        assert!((g[2] - 3.869294270156324).abs() < f64::EPSILON);
-        let f = objective.f(&[1., 2., 3.]);
-        assert!((f - 59.76994756647412).abs() < f64::EPSILON);
+        assert!((g[0] - 26.051064349381285).abs() < std::f64::EPSILON);
+        assert!((g[1] - 10.239000702928523).abs() < std::f64::EPSILON);
+        assert!((g[2] - 3.869294270156324).abs() < std::f64::EPSILON);
+        let f = objective.f(&vec![1., 2., 3.]);
+        assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
         let objective_reg = BinaryObjectiveFunction {
             x: &x,
-            y,
+            y: y.clone(),
             alpha: 1.0,
             _phantom_t: PhantomData,
         };
-        let f = objective_reg.f(&[1., 2., 3.]);
+        let f = objective_reg.f(&vec![1., 2., 3.]);
         assert!((f - 62.2699).abs() < 1e-4);
         objective_reg.df(&mut g, &vec![1., 2., 3.]);
@@ -732,8 +739,7 @@ mod tests {
             &[10., -2.],
             &[8., 2.],
             &[9., 0.],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
         let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
@@ -818,41 +824,37 @@ mod tests {
         assert!(reg_coeff_sum < coeff);
     }
-    //TODO: serialization for the new DenseMatrix needs to be implemented
-    #[cfg_attr(
-        all(target_arch = "wasm32", not(target_os = "wasi")),
-        wasm_bindgen_test::wasm_bindgen_test
-    )]
-    #[test]
-    #[cfg(feature = "serde")]
-    fn serde() {
-        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
-            &[1., -5.],
-            &[2., 5.],
-            &[3., -2.],
-            &[1., 2.],
-            &[2., 0.],
-            &[6., -5.],
-            &[7., 5.],
-            &[6., -2.],
-            &[7., 2.],
-            &[6., 0.],
-            &[8., -5.],
-            &[9., 5.],
-            &[10., -2.],
-            &[8., 2.],
-            &[9., 0.],
-        ])
-        .unwrap();
-        let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
-        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
-        let deserialized_lr: LogisticRegression<f64, i32, DenseMatrix<f64>, Vec<i32>> =
-            serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
-        assert_eq!(lr, deserialized_lr);
-    }
+    // TODO: serialization for the new DenseMatrix needs to be implemented
+    // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
+    // #[test]
+    // #[cfg(feature = "serde")]
+    // fn serde() {
+    //     let x = DenseMatrix::from_2d_array(&[
+    //         &[1., -5.],
+    //         &[2., 5.],
+    //         &[3., -2.],
+    //         &[1., 2.],
+    //         &[2., 0.],
+    //         &[6., -5.],
+    //         &[7., 5.],
+    //         &[6., -2.],
+    //         &[7., 2.],
+    //         &[6., 0.],
+    //         &[8., -5.],
+    //         &[9., 5.],
+    //         &[10., -2.],
+    //         &[8., 2.],
+    //         &[9., 0.],
+    //     ]);
+    //     let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
+    //     let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
+    //     let deserialized_lr: LogisticRegression<f64, i32, DenseMatrix<f64>, Vec<i32>> =
+    //         serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
+    //     assert_eq!(lr, deserialized_lr);
+    // }
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -881,8 +883,7 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
         let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
@@ -895,7 +896,11 @@ mod tests {
         let y_hat = lr.predict(&x).unwrap();
-        let error: i32 = y.into_iter().zip(y_hat).map(|(a, b)| (a - b).abs()).sum();
+        let error: i32 = y
+            .into_iter()
+            .zip(y_hat.into_iter())
+            .map(|(a, b)| (a - b).abs())
+            .sum();
         assert!(error <= 1);
@@ -904,46 +909,4 @@ mod tests {
         assert!(reg_coeff_sum < coeff);
     }
-    #[cfg_attr(
-        all(target_arch = "wasm32", not(target_os = "wasi")),
-        wasm_bindgen_test::wasm_bindgen_test
-    )]
-    #[test]
-    fn lr_fit_predict_random() {
-        let x: DenseMatrix<f32> = DenseMatrix::rand(52181, 94);
-        let y1: Vec<i32> = vec![1; 2181];
-        let y2: Vec<i32> = vec![0; 50000];
-        let y: Vec<i32> = y1.into_iter().chain(y2).collect();
-        let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
-        let lr_reg = LogisticRegression::fit(
-            &x,
-            &y,
-            LogisticRegressionParameters::default().with_alpha(1.0),
-        )
-        .unwrap();
-        let y_hat = lr.predict(&x).unwrap();
-        let y_hat_reg = lr_reg.predict(&x).unwrap();
-        assert_eq!(y.len(), y_hat.len());
-        assert_eq!(y.len(), y_hat_reg.len());
-    }
-    #[test]
-    fn test_logit() {
-        let x: &DenseMatrix<f64> = &DenseMatrix::rand(52181, 94);
-        let y1: Vec<u32> = vec![1; 2181];
-        let y2: Vec<u32> = vec![0; 50000];
-        let y: &Vec<u32> = &(y1.into_iter().chain(y2).collect());
-        println!("y vec height: {:?}", y.len());
-        println!("x matrix shape: {:?}", x.shape());
-        let lr = LogisticRegression::fit(x, y, Default::default()).unwrap();
-        let y_hat = lr.predict(x).unwrap();
-        println!("y_hat shape: {:?}", y_hat.shape());
-        assert_eq!(y_hat.shape(), 52181);
-    }
 }
+14 -7
@@ -40,7 +40,7 @@
 //! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
 //! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
 //! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-//! ]).unwrap();
+//! ]);
 //!
 //! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
 //! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -71,16 +71,21 @@ use crate::numbers::basenum::Number;
 use crate::numbers::realnum::RealNumber;
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Eq, PartialEq, Default)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
 pub enum RidgeRegressionSolverName {
     /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
-    #[default]
     Cholesky,
     /// SVD decomposition, see [SVD](../../linalg/svd/index.html)
     SVD,
 }
+impl Default for RidgeRegressionSolverName {
+    fn default() -> Self {
+        RidgeRegressionSolverName::Cholesky
+    }
+}
 /// Ridge Regression parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -379,7 +384,10 @@ impl<
         for (i, col_std_i) in col_std.iter().enumerate() {
             if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
-                return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
+                return Err(Failed::fit(&format!(
+                    "Cannot rescale constant column {}",
+                    i
+                )));
             }
         }
@@ -455,8 +463,7 @@ mod tests {
             &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
             &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
             &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<f64> = vec![
             83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -514,7 +521,7 @@ mod tests {
         // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
         // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
         // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        // ]).unwrap();
+        // ]);
         // let y = vec![
         // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+3 -3
@@ -98,8 +98,8 @@ mod tests {
         let mut scores = HCVScore::new();
         scores.compute(&v1, &v2);
-        assert!((0.2548 - scores.homogeneity.unwrap()).abs() < 1e-4);
-        assert!((0.5440 - scores.completeness.unwrap()).abs() < 1e-4);
-        assert!((0.3471 - scores.v_measure.unwrap()).abs() < 1e-4);
+        assert!((0.2548 - scores.homogeneity.unwrap() as f64).abs() < 1e-4);
+        assert!((0.5440 - scores.completeness.unwrap() as f64).abs() < 1e-4);
+        assert!((0.3471 - scores.v_measure.unwrap() as f64).abs() < 1e-4);
     }
 }
+1 -1
@@ -125,7 +125,7 @@ mod tests {
     fn entropy_test() {
         let v1 = vec![0, 0, 1, 1, 2, 0, 4];
-        assert!((1.2770 - entropy(&v1).unwrap()).abs() < 1e-4);
+        assert!((1.2770 - entropy(&v1).unwrap() as f64).abs() < 1e-4);
     }
     #[cfg_attr(
+2 -3
@@ -25,7 +25,7 @@
 //! &[68., 590., 37.],
 //! &[69., 660., 46.],
 //! &[73., 600., 55.],
-//! ]).unwrap();
+//! ]);
 //!
 //! let a = data.mean_by(0);
 //! let b = vec![66., 640., 44.];
@@ -151,8 +151,7 @@ mod tests {
             &[68., 590., 37.],
             &[69., 660., 46.],
             &[73., 600., 55.],
-        ])
-        .unwrap();
+        ]);
         let a = data.mean_by(0);
         let b = vec![66., 640., 44.];
+2 -2
@@ -95,8 +95,8 @@ mod tests {
         let score1: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
         let score2: f64 = F1::new_with(beta).get_score(&y_true, &y_true);
-        println!("{score1:?}");
-        println!("{score2:?}");
+        println!("{:?}", score1);
+        println!("{:?}", score2);
         assert!((score1 - 0.57142857).abs() < 1e-8);
         assert!((score2 - 1.0).abs() < 1e-8);
+1 -1
@@ -37,7 +37,7 @@
 //! &[4.9, 2.4, 3.3, 1.0],
 //! &[6.6, 2.9, 4.6, 1.3],
 //! &[5.2, 2.7, 3.9, 1.4],
-//! ]).unwrap();
+//! ]);
 //! let y: Vec<i8> = vec![
 //! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
 //! ];
@@ -3,9 +3,9 @@
 use crate::{
     api::{Predictor, SupervisedEstimator},
     error::{Failed, FailedError},
-    linalg::basic::arrays::{Array1, Array2},
-    numbers::basenum::Number,
-    numbers::realnum::RealNumber,
+    linalg::basic::arrays::{Array2, Array1},
+    numbers::realnum::RealNumber,
+    numbers::basenum::Number,
 };
 use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+10 -6
@@ -213,17 +213,17 @@ mod tests {
         for t in &test_masks[0][0..11] {
             // TODO: this can be prob done better
-            assert!(*t)
+            assert_eq!(*t, true)
         }
         for t in &test_masks[0][11..22] {
-            assert!(!*t)
+            assert_eq!(*t, false)
         }
         for t in &test_masks[1][0..11] {
-            assert!(!*t)
+            assert_eq!(*t, false)
         }
         for t in &test_masks[1][11..22] {
-            assert!(*t)
+            assert_eq!(*t, true)
         }
     }
@@ -283,7 +283,9 @@ mod tests {
             (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
             (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
         ];
-        for ((train, test), (expected_train, expected_test)) in k.split(&x).zip(expected) {
+        for ((train, test), (expected_train, expected_test)) in
+            k.split(&x).into_iter().zip(expected)
+        {
             assert_eq!(test, expected_test);
             assert_eq!(train, expected_train);
         }
@@ -305,7 +307,9 @@ mod tests {
             (vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
             (vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
         ];
-        for ((train, test), (expected_train, expected_test)) in k.split(&x).zip(expected) {
+        for ((train, test), (expected_train, expected_test)) in
+            k.split(&x).into_iter().zip(expected)
+        {
             assert_eq!(test.len(), expected_test.len());
             assert_eq!(train.len(), expected_train.len());
         }
+8 -12
@@ -36,7 +36,7 @@
//! &[4.9, 2.4, 3.3, 1.0], //! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3], //! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4], //! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap(); //! ]);
//! let y: Vec<f64> = vec![ //! let y: Vec<f64> = vec![
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., //! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ]; //! ];
@@ -84,7 +84,7 @@
//! &[4.9, 2.4, 3.3, 1.0], //! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3], //! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4], //! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap(); //! ]);
//! let y: Vec<i32> = vec![ //! let y: Vec<i32> = vec![
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, //! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! ]; //! ];
@@ -169,7 +169,7 @@ pub fn train_test_split<
let n_test = ((n as f32) * test_size) as usize; let n_test = ((n as f32) * test_size) as usize;
if n_test < 1 { if n_test < 1 {
panic!("number of sample is too small {n}"); panic!("number of sample is too small {}", n);
} }
 let mut indices: Vec<usize> = (0..n).collect();
@@ -396,8 +396,7 @@ mod tests {
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
        let cv = KFold {
@@ -442,8 +441,7 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
@@ -491,8 +489,7 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
@@ -542,8 +539,7 @@ mod tests {
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
        let cv = KFold::default().with_n_splits(3);
@@ -557,6 +553,6 @@ mod tests {
            &accuracy,
        )
        .unwrap();
-        println!("{results:?}");
+        println!("{:?}", results);
    }
}
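The model_selection tests above run `KFold::default().with_n_splits(3)` over 20 samples. As a hedged aside, the fold boundaries such a splitter produces can be sketched in plain Rust (`kfold_test_ranges` is a hypothetical helper for illustration, not smartcore's API): `n` indices are carved into `k` contiguous folds whose sizes differ by at most one.

```rust
/// Hypothetical sketch of k-fold boundary computation: returns `(start, end)`
/// ranges for each test fold over `n` samples.
fn kfold_test_ranges(n: usize, k: usize) -> Vec<(usize, usize)> {
    let base = n / k; // minimum fold size
    let extra = n % k; // the first `extra` folds get one extra sample
    let mut start = 0;
    (0..k)
        .map(|i| {
            let len = base + usize::from(i < extra);
            let fold = (start, start + len);
            start += len;
            fold
        })
        .collect()
}

fn main() {
    // 20 samples into 3 folds -> sizes 7, 7, 6, covering every index exactly once.
    assert_eq!(kfold_test_ranges(20, 3), vec![(0, 7), (7, 14), (14, 20)]);
}
```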
+18 -17
@@ -19,14 +19,14 @@
//! &[0, 1, 0, 0, 1, 0],
//! &[0, 1, 0, 1, 0, 0],
//! &[0, 1, 1, 0, 0, 1],
-//! ]).unwrap();
+//! ]);
//! let y: Vec<u32> = vec![0, 0, 0, 1];
//!
//! let nb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
//!
//! // Testing data point is:
//! // Chinese Chinese Chinese Tokyo Japan
-//! let x_test = DenseMatrix::from_2d_array(&[&[0, 1, 1, 0, 0, 1]]).unwrap();
+//! let x_test = DenseMatrix::from_2d_array(&[&[0, 1, 1, 0, 0, 1]]);
//! let y_hat = nb.predict(&x_test).unwrap();
//! ```
//!
@@ -257,7 +257,8 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
    /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
    /// * `x` - training data.
    /// * `y` - vector with target values (classes) of length N.
-    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
+    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
+    ///   priors are adjusted according to the data.
    /// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
    /// * `binarize` - Threshold for binarizing.
    fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>(
@@ -270,18 +271,21 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
        let y_samples = y.shape();
        if y_samples != n_samples {
            return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{n_samples}], |y|=[{y_samples}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                n_samples, y_samples
            )));
        }
        if n_samples == 0 {
            return Err(Failed::fit(&format!(
-                "Size of x and y should greater than 0; |x|=[{n_samples}]"
+                "Size of x and y should greater than 0; |x|=[{}]",
+                n_samples
            )));
        }
        if alpha < 0f64 {
            return Err(Failed::fit(&format!(
-                "Alpha should be greater than 0; |alpha|=[{alpha}]"
+                "Alpha should be greater than 0; |alpha|=[{}]",
+                alpha
            )));
        }
@@ -314,7 +318,8 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
                feature_in_class_counter[class_index][idx] +=
                    row_i.to_usize().ok_or_else(|| {
                        Failed::fit(&format!(
-                            "Elements of the matrix should be 1.0 or 0.0 |found|=[{row_i}]"
+                            "Elements of the matrix should be 1.0 or 0.0 |found|=[{}]",
+                            row_i
                        ))
                    })?;
            }
@@ -401,10 +406,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
{
    /// Fits BernoulliNB with given data
    /// * `x` - training data of size NxM where N is the number of samples and M is the number of
    ///   features.
    /// * `y` - vector with target values (classes) of length N.
    /// * `parameters` - additional parameters like class priors, alpha for smoothing and
    ///   binarizing threshold.
    pub fn fit(x: &X, y: &Y, parameters: BernoulliNBParameters<TX>) -> Result<Self, Failed> {
        let distribution = if let Some(threshold) = parameters.binarize {
            BernoulliNBDistribution::fit(
@@ -426,7 +431,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
    /// Estimates the class labels for the provided data.
    /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
    /// Returns a vector of size N with class estimates.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        if let Some(threshold) = self.binarize {
@@ -527,8 +531,7 @@ mod tests {
            &[0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
            &[0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
            &[0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 0, 1];
        let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
@@ -559,7 +562,7 @@ mod tests {
        // Testing data point is:
        // Chinese Chinese Chinese Tokyo Japan
-        let x_test = DenseMatrix::from_2d_array(&[&[0.0, 1.0, 1.0, 0.0, 0.0, 1.0]]).unwrap();
+        let x_test = DenseMatrix::from_2d_array(&[&[0.0, 1.0, 1.0, 0.0, 0.0, 1.0]]);
        let y_hat = bnb.predict(&x_test).unwrap();
        assert_eq!(y_hat, &[1]);
@@ -587,8 +590,7 @@ mod tests {
            &[2, 0, 3, 3, 1, 2, 0, 2, 4, 1],
            &[2, 4, 0, 4, 2, 4, 1, 3, 1, 4],
            &[0, 2, 2, 3, 4, 0, 4, 4, 4, 4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 2];
        let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
@@ -645,8 +647,7 @@ mod tests {
            &[0, 1, 0, 0, 1, 0],
            &[0, 1, 0, 1, 0, 0],
            &[0, 1, 1, 0, 0, 1],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 0, 1];
        let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
+16 -15
@@ -24,7 +24,7 @@
//! &[3, 4, 2, 4],
//! &[0, 3, 1, 2],
//! &[0, 4, 1, 2],
-//! ]).unwrap();
+//! ]);
//! let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
//!
//! let nb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
@@ -95,7 +95,7 @@ impl<T: Number + Unsigned> PartialEq for CategoricalNBDistribution<T> {
                return false;
            }
            for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
-                if (*a_i_j - *b_i_j).abs() > f64::EPSILON {
+                if (*a_i_j - *b_i_j).abs() > std::f64::EPSILON {
                    return false;
                }
            }
@@ -158,7 +158,8 @@ impl<T: Number + Unsigned> CategoricalNBDistribution<T> {
    pub fn fit<X: Array2<T>, Y: Array1<T>>(x: &X, y: &Y, alpha: f64) -> Result<Self, Failed> {
        if alpha < 0f64 {
            return Err(Failed::fit(&format!(
-                "alpha should be >= 0, alpha=[{alpha}]"
+                "alpha should be >= 0, alpha=[{}]",
+                alpha
            )));
        }
@@ -166,13 +167,15 @@ impl<T: Number + Unsigned> CategoricalNBDistribution<T> {
        let y_samples = y.shape();
        if y_samples != n_samples {
            return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{n_samples}], |y|=[{y_samples}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                n_samples, y_samples
            )));
        }
        if n_samples == 0 {
            return Err(Failed::fit(&format!(
-                "Size of x and y should greater than 0; |x|=[{n_samples}]"
+                "Size of x and y should greater than 0; |x|=[{}]",
+                n_samples
            )));
        }
        let y: Vec<usize> = y.iterator(0).map(|y_i| y_i.to_usize().unwrap()).collect();
@@ -199,7 +202,8 @@ impl<T: Number + Unsigned> CategoricalNBDistribution<T> {
                .max()
                .ok_or_else(|| {
                    Failed::fit(&format!(
-                        "Failed to get the categories for feature = {feature}"
+                        "Failed to get the categories for feature = {}",
+                        feature
                    ))
                })?;
            n_categories.push(feature_max + 1);
@@ -363,7 +367,7 @@ impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> Predictor<X, Y> for Categ
impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> CategoricalNB<T, X, Y> {
    /// Fits CategoricalNB with given data
    /// * `x` - training data of size NxM where N is the number of samples and M is the number of
    ///   features.
    /// * `y` - vector with target values (classes) of length N.
    /// * `parameters` - additional parameters like alpha for smoothing
    pub fn fit(x: &X, y: &Y, parameters: CategoricalNBParameters) -> Result<Self, Failed> {
@@ -375,7 +379,6 @@ impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> CategoricalNB<T, X, Y> {
    /// Estimates the class labels for the provided data.
    /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
    /// Returns a vector of size N with class estimates.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.inner.as_ref().unwrap().predict(x)
@@ -426,6 +429,7 @@ mod tests {
    fn search_parameters() {
        let parameters = CategoricalNBSearchParameters {
            alpha: vec![1., 2.],
+            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        let next = iter.next().unwrap();
@@ -456,8 +460,7 @@ mod tests {
            &[1, 1, 1, 1],
            &[1, 2, 0, 0],
            &[2, 1, 1, 1],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
        let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
@@ -515,7 +518,7 @@ mod tests {
            ]
        );
-        let x_test = DenseMatrix::from_2d_array(&[&[0, 2, 1, 0], &[2, 2, 0, 0]]).unwrap();
+        let x_test = DenseMatrix::from_2d_array(&[&[0, 2, 1, 0], &[2, 2, 0, 0]]);
        let y_hat = cnb.predict(&x_test).unwrap();
        assert_eq!(y_hat, vec![0, 1]);
    }
@@ -541,8 +544,7 @@ mod tests {
            &[3, 4, 2, 4],
            &[0, 3, 1, 2],
            &[0, 4, 1, 2],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
        let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
@@ -574,8 +576,7 @@ mod tests {
            &[3, 4, 2, 4],
            &[0, 3, 1, 2],
            &[0, 4, 1, 2],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
        let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
+12 -12
@@ -16,7 +16,7 @@
//! &[ 1., 1.],
//! &[ 2., 1.],
//! &[ 3., 2.],
-//! ]).unwrap();
+//! ]);
//! let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
//!
//! let nb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
@@ -174,7 +174,8 @@ impl<TY: Number + Ord + Unsigned> GaussianNBDistribution<TY> {
    /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
    /// * `x` - training data.
    /// * `y` - vector with target values (classes) of length N.
-    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
+    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
+    ///   priors are adjusted according to the data.
    pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>(
        x: &X,
        y: &Y,
@@ -184,13 +185,15 @@ impl<TY: Number + Ord + Unsigned> GaussianNBDistribution<TY> {
        let y_samples = y.shape();
        if y_samples != n_samples {
            return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{n_samples}], |y|=[{y_samples}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                n_samples, y_samples
            )));
        }
        if n_samples == 0 {
            return Err(Failed::fit(&format!(
-                "Size of x and y should greater than 0; |x|=[{n_samples}]"
+                "Size of x and y should greater than 0; |x|=[{}]",
+                n_samples
            )));
        }
        let (class_labels, indices) = y.unique_with_indices();
@@ -316,7 +319,7 @@ impl<TX: Number + RealNumber, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
{
    /// Fits GaussianNB with given data
    /// * `x` - training data of size NxM where N is the number of samples and M is the number of
    ///   features.
    /// * `y` - vector with target values (classes) of length N.
    /// * `parameters` - additional parameters like class priors.
    pub fn fit(x: &X, y: &Y, parameters: GaussianNBParameters) -> Result<Self, Failed> {
@@ -327,7 +330,6 @@ impl<TX: Number + RealNumber, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
    /// Estimates the class labels for the provided data.
    /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
    /// Returns a vector of size N with class estimates.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.inner.as_ref().unwrap().predict(x)
@@ -373,6 +375,7 @@ mod tests {
    fn search_parameters() {
        let parameters = GaussianNBSearchParameters {
            priors: vec![Some(vec![1.]), Some(vec![2.])],
+            ..Default::default()
        };
        let mut iter = parameters.into_iter();
        let next = iter.next().unwrap();
@@ -395,8 +398,7 @@ mod tests {
            &[1., 1.],
            &[2., 1.],
            &[3., 2.],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
        let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
@@ -436,8 +438,7 @@ mod tests {
            &[1., 1.],
            &[2., 1.],
            &[3., 2.],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
        let priors = vec![0.3, 0.7];
@@ -464,8 +465,7 @@ mod tests {
            &[1., 1.],
            &[2., 1.],
            &[3., 2.],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
        let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
+20 -532
@@ -89,545 +89,33 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
    /// Estimates the class labels for the provided data.
    /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
    /// Returns a vector of size N with class estimates.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let y_classes = self.distribution.classes();
-        if y_classes.is_empty() {
-            return Err(Failed::predict("Failed to predict, no classes available"));
-        }
        let (rows, _) = x.shape();
-        let mut predictions = Vec::with_capacity(rows);
-        let mut all_probs_nan = true;
-        for row_index in 0..rows {
-            let row = x.get_row(row_index);
-            let mut max_log_prob = f64::NEG_INFINITY;
-            let mut max_class = None;
-            for (class_index, class) in y_classes.iter().enumerate() {
-                let log_likelihood = self.distribution.log_likelihood(class_index, &row);
-                let log_prob = log_likelihood + self.distribution.prior(class_index).ln();
-                if !log_prob.is_nan() && log_prob > max_log_prob {
-                    max_log_prob = log_prob;
-                    max_class = Some(*class);
-                    all_probs_nan = false;
-                }
-            }
-            predictions.push(max_class.unwrap_or(y_classes[0]));
-        }
-        if all_probs_nan {
-            Err(Failed::predict(
-                "Failed to predict, all probabilities were NaN",
-            ))
-        } else {
-            Ok(Y::from_vec_slice(&predictions))
-        }
+        let predictions = (0..rows)
+            .map(|row_index| {
+                let row = x.get_row(row_index);
+                let (prediction, _probability) = y_classes
+                    .iter()
+                    .enumerate()
+                    .map(|(class_index, class)| {
+                        (
+                            class,
+                            self.distribution.log_likelihood(class_index, &row)
+                                + self.distribution.prior(class_index).ln(),
+                        )
+                    })
+                    .max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
+                    .unwrap();
+                *prediction
+            })
+            .collect::<Vec<TY>>();
+        let y_hat = Y::from_vec_slice(&predictions);
+        Ok(y_hat)
    }
}
pub mod bernoulli;
pub mod categorical;
pub mod gaussian;
pub mod multinomial;
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use num_traits::float::Float;
type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
#[derive(Debug, PartialEq, Clone)]
struct TestDistribution<'d>(&'d Vec<i32>);
impl NBDistribution<i32, i32> for TestDistribution<'_> {
fn prior(&self, _class_index: usize) -> f64 {
1.
}
fn log_likelihood<'a>(
&'a self,
class_index: usize,
_j: &'a Box<dyn ArrayView1<i32> + 'a>,
) -> f64 {
match self.0.get(class_index) {
&v @ 2 | &v @ 10 | &v @ 20 => v as f64,
_ => f64::nan(),
}
}
fn classes(&self) -> &Vec<i32> {
self.0
}
}
#[test]
fn test_predict() {
let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
let val = vec![];
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
Ok(_) => panic!("Should return error in case of empty classes"),
Err(err) => assert_eq!(
err.to_string(),
"Predict failed: Failed to predict, no classes available"
),
}
let val = vec![1, 2, 3];
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
Ok(r) => assert_eq!(r, vec![2, 2, 2]),
Err(_) => panic!("Should succeed in the normal case with NaNs"),
}
let val = vec![20, 2, 10];
match Model::fit(TestDistribution(&val)).unwrap().predict(&matrix) {
Ok(r) => assert_eq!(r, vec![20, 20, 20]),
Err(_) => panic!("Should succeed in the normal case without NaNs"),
}
}
// A simple test distribution using float
#[derive(Debug, PartialEq, Clone)]
struct TestDistributionAgain {
classes: Vec<u32>,
probs: Vec<f64>,
}
impl NBDistribution<f64, u32> for TestDistributionAgain {
fn classes(&self) -> &Vec<u32> {
&self.classes
}
fn prior(&self, class_index: usize) -> f64 {
self.probs[class_index]
}
fn log_likelihood<'a>(
&'a self,
class_index: usize,
_j: &'a Box<dyn ArrayView1<f64> + 'a>,
) -> f64 {
self.probs[class_index].ln()
}
}
type TestNB = BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, TestDistributionAgain>;
#[test]
fn test_predict_empty_classes() {
let dist = TestDistributionAgain {
classes: vec![],
probs: vec![],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
assert!(nb.predict(&x).is_err());
}
#[test]
fn test_predict_single_class() {
let dist = TestDistributionAgain {
classes: vec![1],
probs: vec![1.0],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![1, 1]);
}
#[test]
fn test_predict_multiple_classes() {
let dist = TestDistributionAgain {
classes: vec![1, 2, 3],
probs: vec![0.2, 0.5, 0.3],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![2, 2, 2]);
}
#[test]
fn test_predict_with_nans() {
let dist = TestDistributionAgain {
classes: vec![1, 2],
probs: vec![f64::NAN, 0.5],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![2, 2]);
}
#[test]
fn test_predict_all_nans() {
let dist = TestDistributionAgain {
classes: vec![1, 2],
probs: vec![f64::NAN, f64::NAN],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
assert!(nb.predict(&x).is_err());
}
#[test]
fn test_predict_extreme_probabilities() {
let dist = TestDistributionAgain {
classes: vec![1, 2],
probs: vec![1e-300, 1e-301],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![1, 1]);
}
#[test]
fn test_predict_with_infinity() {
let dist = TestDistributionAgain {
classes: vec![1, 2, 3],
probs: vec![f64::INFINITY, 1.0, 2.0],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![1, 1]);
}
#[test]
fn test_predict_with_negative_infinity() {
let dist = TestDistributionAgain {
classes: vec![1, 2, 3],
probs: vec![f64::NEG_INFINITY, 1.0, 2.0],
};
let nb = TestNB::fit(dist).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = nb.predict(&x).unwrap();
assert_eq!(result, vec![3, 3]);
}
#[test]
fn test_gaussian_naive_bayes_numerical_stability() {
#[derive(Debug, PartialEq, Clone)]
struct GaussianTestDistribution {
classes: Vec<u32>,
means: Vec<Vec<f64>>,
variances: Vec<Vec<f64>>,
priors: Vec<f64>,
}
impl NBDistribution<f64, u32> for GaussianTestDistribution {
fn classes(&self) -> &Vec<u32> {
&self.classes
}
fn prior(&self, class_index: usize) -> f64 {
self.priors[class_index]
}
fn log_likelihood<'a>(
&'a self,
class_index: usize,
j: &'a Box<dyn ArrayView1<f64> + 'a>,
) -> f64 {
let means = &self.means[class_index];
let variances = &self.variances[class_index];
j.iterator(0)
.enumerate()
.map(|(i, &xi)| {
let mean = means[i];
let var = variances[i] + 1e-9; // Small smoothing for numerical stability
let coeff = -0.5 * (2.0 * std::f64::consts::PI * var).ln();
let exponent = -(xi - mean).powi(2) / (2.0 * var);
coeff + exponent
})
.sum()
}
}
fn train_distribution(x: &DenseMatrix<f64>, y: &[u32]) -> GaussianTestDistribution {
let mut classes: Vec<u32> = y
.iter()
.cloned()
.collect::<std::collections::HashSet<u32>>()
.into_iter()
.collect();
classes.sort();
let n_classes = classes.len();
let n_features = x.shape().1;
let mut means = vec![vec![0.0; n_features]; n_classes];
let mut variances = vec![vec![0.0; n_features]; n_classes];
let mut class_counts = vec![0; n_classes];
// Calculate means and count samples per class
for (sample, &class) in x.row_iter().zip(y.iter()) {
let class_idx = classes.iter().position(|&c| c == class).unwrap();
class_counts[class_idx] += 1;
for (i, &value) in sample.iterator(0).enumerate() {
means[class_idx][i] += value;
}
}
// Normalize means
for (class_idx, mean) in means.iter_mut().enumerate() {
for value in mean.iter_mut() {
*value /= class_counts[class_idx] as f64;
}
}
// Calculate variances
for (sample, &class) in x.row_iter().zip(y.iter()) {
let class_idx = classes.iter().position(|&c| c == class).unwrap();
for (i, &value) in sample.iterator(0).enumerate() {
let diff = value - means[class_idx][i];
variances[class_idx][i] += diff * diff;
}
}
// Normalize variances and add small epsilon to avoid zero variance
let epsilon = 1e-9;
for (class_idx, variance) in variances.iter_mut().enumerate() {
for value in variance.iter_mut() {
*value = *value / class_counts[class_idx] as f64 + epsilon;
}
}
// Calculate priors
let total_samples = y.len() as f64;
let priors: Vec<f64> = class_counts
.iter()
.map(|&count| count as f64 / total_samples)
.collect();
GaussianTestDistribution {
classes,
means,
variances,
priors,
}
}
type TestNBGaussian =
BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, GaussianTestDistribution>;
// Create a constant training dataset
let n_samples = 1000;
let n_features = 5;
let n_classes = 4;
let mut x_data = Vec::with_capacity(n_samples * n_features);
let mut y_data = Vec::with_capacity(n_samples);
for i in 0..n_samples {
for j in 0..n_features {
x_data.push((i * j) as f64 % 10.0);
}
y_data.push((i % n_classes) as u32);
}
let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
let y = y_data;
// Train the model
let dist = train_distribution(&x, &y);
let nb = TestNBGaussian::fit(dist).unwrap();
// Create constant test data
let n_test_samples = 100;
let mut test_x_data = Vec::with_capacity(n_test_samples * n_features);
for i in 0..n_test_samples {
for j in 0..n_features {
test_x_data.push((i * j * 2) as f64 % 15.0);
}
}
let test_x = DenseMatrix::new(n_test_samples, n_features, test_x_data, true).unwrap();
// Make predictions
let predictions = nb
.predict(&test_x)
.map_err(|e| format!("Prediction failed: {}", e))
.unwrap();
// Check numerical stability
assert_eq!(
predictions.len(),
n_test_samples,
"Number of predictions should match number of test samples"
);
// Check that all predictions are valid class labels
for &pred in predictions.iter() {
assert!(pred < n_classes as u32, "Predicted class should be valid");
}
// Check consistency of predictions
let repeated_predictions = nb
.predict(&test_x)
.map_err(|e| format!("Repeated prediction failed: {}", e))
.unwrap();
assert_eq!(
predictions, repeated_predictions,
"Predictions should be consistent when repeated"
);
// Check extreme values
let extreme_x =
DenseMatrix::new(2, n_features, vec![f64::MAX; n_features * 2], true).unwrap();
let extreme_predictions = nb.predict(&extreme_x);
assert!(
extreme_predictions.is_err(),
"Extreme value input should result in an error"
);
assert_eq!(
extreme_predictions.unwrap_err().to_string(),
"Predict failed: Failed to predict, all probabilities were NaN",
"Incorrect error message for extreme values"
);
// Check for NaN handling
let nan_x = DenseMatrix::new(2, n_features, vec![f64::NAN; n_features * 2], true).unwrap();
let nan_predictions = nb.predict(&nan_x);
assert!(
nan_predictions.is_err(),
"NaN input should result in an error"
);
// Check for very small values
let small_x =
DenseMatrix::new(2, n_features, vec![f64::MIN_POSITIVE; n_features * 2], true).unwrap();
let small_predictions = nb
.predict(&small_x)
.map_err(|e| format!("Small value prediction failed: {}", e))
.unwrap();
for &pred in small_predictions.iter() {
assert!(
pred < n_classes as u32,
"Predictions for very small values should be valid"
);
}
// Check for values close to zero
let near_zero_x =
DenseMatrix::new(2, n_features, vec![1e-300; n_features * 2], true).unwrap();
let near_zero_predictions = nb
.predict(&near_zero_x)
.map_err(|e| format!("Near-zero value prediction failed: {}", e))
.unwrap();
for &pred in near_zero_predictions.iter() {
assert!(
pred < n_classes as u32,
"Predictions for near-zero values should be valid"
);
}
println!("All numerical stability checks passed!");
}
#[test]
fn test_gaussian_naive_bayes_numerical_stability_random_data() {
#[derive(Debug)]
struct MySimpleRng {
state: u64,
}
impl MySimpleRng {
fn new(seed: u64) -> Self {
MySimpleRng { state: seed }
}
/// Get the next u64 in the sequence.
fn next_u64(&mut self) -> u64 {
// LCG parameters; these are somewhat arbitrary but commonly used.
// Feel free to tweak the multiplier/adder etc.
self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
self.state
}
/// Get an f64 in the range [min, max).
fn next_f64(&mut self, min: f64, max: f64) -> f64 {
let fraction = (self.next_u64() as f64) / (u64::MAX as f64);
min + fraction * (max - min)
}
/// Get a usize in the range [min, max). This floors the floating result.
fn gen_range_usize(&mut self, min: usize, max: usize) -> usize {
let v = self.next_f64(min as f64, max as f64);
// Truncate into the integer range. Because of floating inexactness,
// ensure we also clamp.
let int_v = v.floor() as isize;
// simple clamp to avoid any float rounding out of range
let clamped = int_v.max(min as isize).min((max - 1) as isize);
clamped as usize
}
}
use crate::naive_bayes::gaussian::GaussianNB;
// We will generate random data in a reproducible way (using a fixed seed).
// We will generate random data in a reproducible way:
let mut rng = MySimpleRng::new(42);
let n_samples = 1000;
let n_features = 5;
let n_classes = 4;
// Our feature matrix and label vector
let mut x_data = Vec::with_capacity(n_samples * n_features);
let mut y_data = Vec::with_capacity(n_samples);
// Fill x_data with random values and y_data with random class labels.
for _i in 0..n_samples {
for _j in 0..n_features {
// We'll pick random values in [-10, 10).
x_data.push(rng.next_f64(-10.0, 10.0));
}
let class = rng.gen_range_usize(0, n_classes) as u32;
y_data.push(class);
}
// Create DenseMatrix from x_data
let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
// Train GaussianNB
let gnb = GaussianNB::fit(&x, &y_data, Default::default())
.expect("Fitting GaussianNB with random data failed.");
// Predict on the same training data to verify no numerical instability
let predictions = gnb.predict(&x).expect("Prediction on random data failed.");
// Basic sanity checks
assert_eq!(
predictions.len(),
n_samples,
"Prediction size must match n_samples"
);
for &pred_class in &predictions {
assert!(
(pred_class as usize) < n_classes,
"Predicted class {} is out of range [0..n_classes).",
pred_class
);
}
// If you want to compare with scikit-learn, you can do something like:
// println!("X = {:?}", &x);
// println!("Y = {:?}", &y_data);
// println!("predictions = {:?}", &predictions);
// and then in Python:
// import numpy as np
// from sklearn.naive_bayes import GaussianNB
// X = np.reshape(np.array(x), (1000, 5), order='F')
// Y = np.array(y)
// gnb = GaussianNB().fit(X, Y)
// preds = gnb.predict(X)
// expected = np.array(predictions)
// assert (expected == preds).all()
// They should match closely (or exactly), depending on floating-point rounding.
}
}
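The `MySimpleRng` helper in the test above is self-contained, so its determinism and range mapping can be checked in isolation. A minimal sketch, duplicating the LCG constants from the test (`Lcg` is a name introduced here purely for illustration):

```rust
// Stand-alone copy of the test's LCG, showing that a fixed seed
// reproduces the same sequence and that next_f64 stays in range.
struct Lcg {
    state: u64,
}

impl Lcg {
    fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }
    fn next_u64(&mut self) -> u64 {
        // Same multiplier/increment as MySimpleRng in the test above.
        self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
        self.state
    }
    fn next_f64(&mut self, min: f64, max: f64) -> f64 {
        let fraction = (self.next_u64() as f64) / (u64::MAX as f64);
        min + fraction * (max - min)
    }
}

fn main() {
    let (mut a, mut b) = (Lcg::new(42), Lcg::new(42));
    for _ in 0..1000 {
        let (va, vb) = (a.next_f64(-10.0, 10.0), b.next_f64(-10.0, 10.0));
        assert_eq!(va, vb); // same seed => same sequence
        assert!((-10.0..=10.0).contains(&va)); // mapped into the requested range
    }
    println!("deterministic");
}
```

This is why the test can hard-code expectations against a run with seed 42: the sequence is fully determined by the seed, with no dependence on platform RNG state.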
+18 -17
@@ -20,13 +20,13 @@
 //! &[0, 2, 0, 0, 1, 0],
 //! &[0, 1, 0, 1, 0, 0],
 //! &[0, 1, 1, 0, 0, 1],
-//! ]).unwrap();
+//! ]);
 //! let y: Vec<u32> = vec![0, 0, 0, 1];
 //! let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
 //!
 //! // Testing data point is:
 //! // Chinese Chinese Chinese Tokyo Japan
-//! let x_test = DenseMatrix::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]).unwrap();
+//! let x_test = DenseMatrix::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]);
 //! let y_hat = nb.predict(&x_test).unwrap();
 //! ```
 //!
@@ -207,7 +207,8 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
     /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
     /// * `x` - training data.
     /// * `y` - vector with target values (classes) of length N.
-    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
+    /// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
+    /// priors are adjusted according to the data.
     /// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
     pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>(
         x: &X,
@@ -219,18 +220,21 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
         let y_samples = y.shape();
         if y_samples != n_samples {
             return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{n_samples}], |y|=[{y_samples}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                n_samples, y_samples
             )));
         }
         if n_samples == 0 {
             return Err(Failed::fit(&format!(
-                "Size of x and y should greater than 0; |x|=[{n_samples}]"
+                "Size of x and y should greater than 0; |x|=[{}]",
+                n_samples
             )));
         }
         if alpha < 0f64 {
             return Err(Failed::fit(&format!(
-                "Alpha should be greater than 0; |alpha|=[{alpha}]"
+                "Alpha should be greater than 0; |alpha|=[{}]",
+                alpha
             )));
         }
@@ -262,7 +266,8 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
                 feature_in_class_counter[class_index][idx] +=
                     row_i.to_usize().ok_or_else(|| {
                         Failed::fit(&format!(
-                            "Elements of the matrix should be convertible to usize |found|=[{row_i}]"
+                            "Elements of the matrix should be convertible to usize |found|=[{}]",
+                            row_i
                         ))
                     })?;
             }
@@ -344,10 +349,10 @@ impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array
 {
     /// Fits MultinomialNB with given data
     /// * `x` - training data of size NxM where N is the number of samples and M is the number of
     /// features.
     /// * `y` - vector with target values (classes) of length N.
     /// * `parameters` - additional parameters like class priors, alpha for smoothing and
     /// binarizing threshold.
     pub fn fit(x: &X, y: &Y, parameters: MultinomialNBParameters) -> Result<Self, Failed> {
         let distribution =
             MultinomialNBDistribution::fit(x, y, parameters.alpha, parameters.priors)?;
@@ -357,7 +362,6 @@ impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array
     /// Estimates the class labels for the provided data.
     /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
     /// Returns a vector of size N with class estimates.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         self.inner.as_ref().unwrap().predict(x)
@@ -433,8 +437,7 @@ mod tests {
             &[0, 2, 0, 0, 1, 0],
             &[0, 1, 0, 1, 0, 0],
             &[0, 1, 1, 0, 0, 1],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<u32> = vec![0, 0, 0, 1];
         let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
@@ -468,7 +471,7 @@ mod tests {
         // Testing data point is:
         // Chinese Chinese Chinese Tokyo Japan
-        let x_test = DenseMatrix::<u32>::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]).unwrap();
+        let x_test = DenseMatrix::<u32>::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]);
         let y_hat = mnb.predict(&x_test).unwrap();
         assert_eq!(y_hat, &[0]);
@@ -496,8 +499,7 @@ mod tests {
             &[2, 0, 3, 3, 1, 2, 0, 2, 4, 1],
             &[2, 4, 0, 4, 2, 4, 1, 3, 1, 4],
             &[0, 2, 2, 3, 4, 0, 4, 4, 4, 4],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<u32> = vec![2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 2];
         let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
@@ -556,8 +558,7 @@ mod tests {
             &[0, 1, 0, 0, 1, 0],
             &[0, 1, 0, 1, 0, 0],
             &[0, 1, 1, 0, 0, 1],
-        ])
-        .unwrap();
+        ]);
         let y = vec![0, 0, 0, 1];
         let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
+7 -11
@@ -22,7 +22,7 @@
 //! &[3., 4.],
 //! &[5., 6.],
 //! &[7., 8.],
-//! &[9., 10.]]).unwrap();
+//! &[9., 10.]]);
 //! let y = vec![2, 2, 2, 3, 3]; //your class labels
 //!
 //! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
@@ -236,7 +236,8 @@ impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec
         if x_n != y_n {
             return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{x_n}], |y|=[{y_n}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                x_n, y_n
             )));
         }
@@ -261,7 +262,6 @@ impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec
     /// Estimates the class labels for the provided data.
     /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
     /// Returns a vector of size N with class estimates.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         let mut result = Y::zeros(x.shape().0);
@@ -312,8 +312,7 @@ mod tests {
     #[test]
     fn knn_fit_predict() {
         let x =
-            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
         let y = vec![2, 2, 2, 3, 3];
         let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
         let y_hat = knn.predict(&x).unwrap();
@@ -327,7 +326,7 @@ mod tests {
     )]
     #[test]
     fn knn_fit_predict_weighted() {
-        let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]).unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
         let y = vec![2, 2, 2, 3, 3];
         let knn = KNNClassifier::fit(
             &x,
@@ -338,9 +337,7 @@ mod tests {
             .with_weight(KNNWeightFunction::Distance),
         )
         .unwrap();
-        let y_hat = knn
-            .predict(&DenseMatrix::from_2d_array(&[&[4.1]]).unwrap())
-            .unwrap();
+        let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
         assert_eq!(vec![3], y_hat);
     }
@@ -352,8 +349,7 @@ mod tests {
     #[cfg(feature = "serde")]
     fn serde() {
         let x =
-            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
         let y = vec![2, 2, 2, 3, 3];
         let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
+13 -12
@@ -24,7 +24,7 @@
 //! &[2., 2.],
 //! &[3., 3.],
 //! &[4., 4.],
-//! &[5., 5.]]).unwrap();
+//! &[5., 5.]]);
 //! let y = vec![1., 2., 3., 4., 5.]; //your target values
 //!
 //! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
@@ -88,21 +88,25 @@ pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D:
 impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
     KNNRegressor<TX, TY, X, Y, D>
 {
+    ///
     fn y(&self) -> &Y {
         self.y.as_ref().unwrap()
     }
+    ///
     fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
         self.knn_algorithm
             .as_ref()
            .expect("Missing parameter: KNNAlgorithm")
     }
+    ///
     fn weight(&self) -> &KNNWeightFunction {
         self.weight.as_ref().expect("Missing parameter: weight")
     }
     #[allow(dead_code)]
+    ///
     fn k(&self) -> usize {
         self.k.unwrap()
     }
@@ -220,7 +224,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
         if x_n != y_n {
             return Err(Failed::fit(&format!(
-                "Size of x should equal size of y; |x|=[{x_n}], |y|=[{y_n}]"
+                "Size of x should equal size of y; |x|=[{}], |y|=[{}]",
+                x_n, y_n
             )));
         }
@@ -246,7 +251,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
     /// Predict the target for the provided data.
     /// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
-    ///
     /// Returns a vector of size N with estimates.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
         let mut result = Y::zeros(x.shape().0);
@@ -292,10 +296,9 @@ mod tests {
     #[test]
     fn knn_fit_predict_weighted() {
         let x =
-            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
         let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
-        let y_exp = [1., 2., 3., 4., 5.];
+        let y_exp = vec![1., 2., 3., 4., 5.];
         let knn = KNNRegressor::fit(
             &x,
             &y,
@@ -309,7 +312,7 @@ mod tests {
         let y_hat = knn.predict(&x).unwrap();
         assert_eq!(5, Vec::len(&y_hat));
         for i in 0..y_hat.len() {
-            assert!((y_hat[i] - y_exp[i]).abs() < f64::EPSILON);
+            assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
         }
     }
@@ -320,10 +323,9 @@ mod tests {
     #[test]
     fn knn_fit_predict_uniform() {
         let x =
-            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
         let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
-        let y_exp = [2., 2., 3., 4., 4.];
+        let y_exp = vec![2., 2., 3., 4., 4.];
         let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
         let y_hat = knn.predict(&x).unwrap();
         assert_eq!(5, Vec::len(&y_hat));
@@ -340,8 +342,7 @@ mod tests {
     #[cfg(feature = "serde")]
     fn serde() {
         let x =
-            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
-                .unwrap();
+            DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
         let y = vec![1., 2., 3., 4., 5.];
         let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
+8 -3
@@ -49,22 +49,27 @@ pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
 /// Weight function that is used to determine estimated value.
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone)]
 pub enum KNNWeightFunction {
     /// All k nearest points are weighted equally
-    #[default]
     Uniform,
     /// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
     Distance,
 }
+
+impl Default for KNNWeightFunction {
+    fn default() -> Self {
+        KNNWeightFunction::Uniform
+    }
+}
 impl KNNWeightFunction {
     fn calc_weights(&self, distances: Vec<f64>) -> std::vec::Vec<f64> {
         match *self {
             KNNWeightFunction::Distance => {
                 // if there are any points that has zero distance from one or more training points,
                 // those training points are weighted as 1.0 and the other points as 0.0
-                if distances.contains(&0f64) {
+                if distances.iter().any(|&e| e == 0f64) {
                     distances
                         .iter()
                         .map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
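The hunk above changes how `calc_weights` detects a zero distance (`contains(&0f64)` versus an explicit `iter().any(...)`); the two predicates are equivalent. A simplified free-function sketch of the zero-distance rule described in the code comment (not the full smartcore implementation — the non-zero branch here assumes plain unnormalized inverse-distance weights):

```rust
// Simplified sketch of the zero-distance rule from the diff above:
// if any neighbor sits exactly on the query point, those neighbors get
// weight 1.0 and all others 0.0; otherwise fall back to 1/d weights.
fn calc_weights(distances: &[f64]) -> Vec<f64> {
    if distances.iter().any(|&e| e == 0f64) {
        distances
            .iter()
            .map(|&e| if e == 0f64 { 1.0 } else { 0.0 })
            .collect()
    } else {
        distances.iter().map(|&e| 1.0 / e).collect()
    }
}

fn main() {
    // An exact match dominates; the other neighbor is ignored.
    assert_eq!(calc_weights(&[0.0, 2.0]), vec![1.0, 0.0]);
    // Otherwise closer neighbors get larger weights.
    assert_eq!(calc_weights(&[1.0, 2.0]), vec![1.0, 0.5]);
    println!("ok");
}
```

The special case avoids a division by zero that a naive `1.0 / d` would hit when a test point coincides with a training point.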
+3 -26
@@ -2,13 +2,9 @@
 //! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, .
 //! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
-use rand::rngs::SmallRng;
-use rand::{Rng, SeedableRng};
 use num_traits::Float;
 use crate::numbers::basenum::Number;
-use crate::rand_custom::get_rng_impl;
 /// Defines real number
 /// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
@@ -67,12 +63,8 @@ impl RealNumber for f64 {
     }
     fn rand() -> f64 {
-        let mut small_rng = get_rng_impl(None);
-        let mut rngs: Vec<SmallRng> = (0..3)
-            .map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
-            .collect();
-        rngs[0].gen::<f64>()
+        // TODO: to be implemented, see issue smartcore#214
+        1.0
     }
     fn two() -> Self {
@@ -116,12 +108,7 @@ impl RealNumber for f32 {
     }
     fn rand() -> f32 {
-        let mut small_rng = get_rng_impl(None);
-        let mut rngs: Vec<SmallRng> = (0..3)
-            .map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
-            .collect();
-        rngs[0].gen::<f32>()
+        1.0
     }
     fn two() -> Self {
@@ -162,14 +149,4 @@ mod tests {
     fn f64_from_string() {
         assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
     }
-    #[test]
-    fn f64_rand() {
-        f64::rand();
-    }
-    #[test]
-    fn f32_rand() {
-        f32::rand();
-    }
 }
@@ -1,3 +1,5 @@
+// TODO: missing documentation
 use std::default::Default;
 use crate::linalg::basic::arrays::Array1;
@@ -6,27 +8,30 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
 use crate::optimization::line_search::LineSearchMethod;
 use crate::optimization::{DF, F};
-/// Gradient Descent optimization algorithm
+///
 pub struct GradientDescent {
-    /// Maximum number of iterations
+    ///
     pub max_iter: usize,
-    /// Relative tolerance for the gradient norm
+    ///
     pub g_rtol: f64,
-    /// Absolute tolerance for the gradient norm
+    ///
     pub g_atol: f64,
 }
+///
 impl Default for GradientDescent {
     fn default() -> Self {
         GradientDescent {
             max_iter: 10000,
-            g_rtol: f64::EPSILON.sqrt(),
-            g_atol: f64::EPSILON,
+            g_rtol: std::f64::EPSILON.sqrt(),
+            g_atol: std::f64::EPSILON,
         }
     }
 }
+///
 impl<T: FloatNumber> FirstOrderOptimizer<T> for GradientDescent {
+    ///
     fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
         &self,
         f: &'a F<'_, T, X>,
@@ -108,13 +113,12 @@ mod tests {
             g[1] = 200. * (x[1] - x[0].powf(2.));
         };
-        let ls: Backtracking<f64> = Backtracking::<f64> {
-            order: FunctionOrder::THIRD,
-            ..Default::default()
-        };
+        let mut ls: Backtracking<f64> = Default::default();
+        ls.order = FunctionOrder::THIRD;
         let optimizer: GradientDescent = Default::default();
         let result = optimizer.optimize(&f, &df, &x0, &ls);
+        println!("{:?}", result);
         assert!((result.f_x - 0.0).abs() < 1e-5);
         assert!((result.x[0] - 1.0).abs() < 1e-2);
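The test hunk above replaces struct-update syntax (`..Default::default()`) with a `mut` binding plus a field assignment; both initialization styles produce the same value. A minimal sketch with a hypothetical `Config` type (not the real `Backtracking` struct):

```rust
// Hypothetical config struct used only to illustrate the two styles.
#[derive(Debug, PartialEq)]
struct Config {
    order: u8,
    max_iterations: usize,
}

impl Default for Config {
    fn default() -> Self {
        Config { order: 2, max_iterations: 100 }
    }
}

fn main() {
    // Struct-update syntax (the style being removed in the diff):
    let a = Config { order: 3, ..Default::default() };
    // Explicit mutation after a default value (the style being added):
    let mut b: Config = Default::default();
    b.order = 3;
    assert_eq!(a, b); // both styles yield the same value
    println!("ok");
}
```

Struct-update syntax is the more idiomatic form (clippy's `field_reassign_with_default` lint flags the mutation style), so this change is a readability/compatibility trade-off rather than a behavioral one.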
+29 -20
@@ -11,29 +11,31 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
 use crate::optimization::line_search::LineSearchMethod;
 use crate::optimization::{DF, F};
-/// Limited-memory BFGS optimization algorithm
+///
 pub struct LBFGS {
-    /// Maximum number of iterations
+    ///
     pub max_iter: usize,
-    /// TODO: Add documentation
+    ///
     pub g_rtol: f64,
-    /// TODO: Add documentation
+    ///
     pub g_atol: f64,
-    /// TODO: Add documentation
+    ///
     pub x_atol: f64,
-    /// TODO: Add documentation
+    ///
     pub x_rtol: f64,
-    /// TODO: Add documentation
+    ///
     pub f_abstol: f64,
-    /// TODO: Add documentation
+    ///
     pub f_reltol: f64,
-    /// TODO: Add documentation
+    ///
     pub successive_f_tol: usize,
-    /// TODO: Add documentation
+    ///
     pub m: usize,
 }
+///
 impl Default for LBFGS {
+    ///
     fn default() -> Self {
         LBFGS {
             max_iter: 1000,
@@ -49,7 +51,9 @@ impl Default for LBFGS {
     }
 }
+///
 impl LBFGS {
+    ///
     fn two_loops<T: FloatNumber + RealNumber, X: Array1<T>>(&self, state: &mut LBFGSState<T, X>) {
         let lower = state.iteration.max(self.m) - self.m;
         let upper = state.iteration;
@@ -91,6 +95,7 @@ impl LBFGS {
         state.s.mul_scalar_mut(-T::one());
     }
+    ///
     fn init_state<T: FloatNumber + RealNumber, X: Array1<T>>(&self, x: &X) -> LBFGSState<T, X> {
         LBFGSState {
             x: x.clone(),
@@ -114,6 +119,7 @@ impl LBFGS {
         }
     }
+    ///
     fn update_state<'a, T: FloatNumber + RealNumber, X: Array1<T>, LS: LineSearchMethod<T>>(
         &self,
         f: &'a F<'_, T, X>,
@@ -155,6 +161,7 @@ impl LBFGS {
         df(&mut state.x_df, &state.x);
     }
+    ///
     fn assess_convergence<T: FloatNumber, X: Array1<T>>(
         &self,
         state: &mut LBFGSState<T, X>,
@@ -166,7 +173,7 @@ impl LBFGS {
         }
         if state.x.max_diff(&state.x_prev)
-            <= T::from_f64(self.x_rtol * state.x.norm(f64::INFINITY)).unwrap()
+            <= T::from_f64(self.x_rtol * state.x.norm(std::f64::INFINITY)).unwrap()
         {
             x_converged = true;
         }
@@ -181,16 +188,17 @@ impl LBFGS {
             state.counter_f_tol += 1;
         }
-        if state.x_df.norm(f64::INFINITY) <= self.g_atol {
+        if state.x_df.norm(std::f64::INFINITY) <= self.g_atol {
             g_converged = true;
         }
         g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
     }
-    fn update_hessian<T: FloatNumber, X: Array1<T>>(
+    ///
+    fn update_hessian<'a, T: FloatNumber, X: Array1<T>>(
         &self,
-        _: &DF<'_, X>,
+        _: &'a DF<'_, X>,
         state: &mut LBFGSState<T, X>,
     ) {
         state.dg = state.x_df.sub(&state.x_df_prev);
@@ -204,6 +212,7 @@ impl LBFGS {
     }
 }
+///
 #[derive(Debug)]
 struct LBFGSState<T: FloatNumber, X: Array1<T>> {
     x: X,
@@ -225,7 +234,9 @@ struct LBFGSState<T: FloatNumber, X: Array1<T>> {
     alpha: T,
 }
+///
 impl<T: FloatNumber + RealNumber> FirstOrderOptimizer<T> for LBFGS {
+    ///
     fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
         &self,
         f: &F<'_, T, X>,
@@ -237,7 +248,7 @@ impl<T: FloatNumber + RealNumber> FirstOrderOptimizer<T> for LBFGS {
         df(&mut state.x_df, x0);
-        let g_converged = state.x_df.norm(f64::INFINITY) < self.g_atol;
+        let g_converged = state.x_df.norm(std::f64::INFINITY) < self.g_atol;
         let mut converged = g_converged;
         let stopped = false;
@@ -280,15 +291,13 @@ mod tests {
             g[0] = -2. * (1. - x[0]) - 400. * (x[1] - x[0].powf(2.)) * x[0];
             g[1] = 200. * (x[1] - x[0].powf(2.));
         };
-        let ls: Backtracking<f64> = Backtracking::<f64> {
-            order: FunctionOrder::THIRD,
-            ..Default::default()
-        };
+        let mut ls: Backtracking<f64> = Default::default();
+        ls.order = FunctionOrder::THIRD;
         let optimizer: LBFGS = Default::default();
         let result = optimizer.optimize(&f, &df, &x0, &ls);
-        assert!((result.f_x - 0.0).abs() < f64::EPSILON);
+        assert!((result.f_x - 0.0).abs() < std::f64::EPSILON);
         assert!((result.x[0] - 1.0).abs() < 1e-8);
         assert!((result.x[1] - 1.0).abs() < 1e-8);
         assert!(result.iterations <= 24);
+8 -8
@@ -1,6 +1,6 @@
-/// Gradient descent optimization algorithm
+///
 pub mod gradient_descent;
-/// Limited-memory BFGS optimization algorithm
+///
 pub mod lbfgs;
 use std::clone::Clone;
@@ -11,9 +11,9 @@ use crate::numbers::floatnum::FloatNumber;
 use crate::optimization::line_search::LineSearchMethod;
 use crate::optimization::{DF, F};
-/// First-order optimization is a class of algorithms that use the first derivative of a function to find optimal solutions.
+///
 pub trait FirstOrderOptimizer<T: FloatNumber> {
-    /// run first order optimization
+    ///
     fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
         &self,
         f: &F<'_, T, X>,
@@ -23,13 +23,13 @@ pub trait FirstOrderOptimizer<T: FloatNumber> {
     ) -> OptimizerResult<T, X>;
 }
-/// Result of optimization
+///
 #[derive(Debug, Clone)]
 pub struct OptimizerResult<T: FloatNumber, X: Array1<T>> {
-    /// Solution
+    ///
     pub x: X,
-    /// f(x) value
+    ///
     pub f_x: T,
-    /// number of iterations
+    ///
     pub iterations: usize,
 }
+17 -12
@@ -1,9 +1,11 @@
+// TODO: missing documentation
 use crate::optimization::FunctionOrder;
 use num_traits::Float;
-/// Line search optimization.
+///
 pub trait LineSearchMethod<T: Float> {
-    /// Find alpha that satisfies strong Wolfe conditions.
+    ///
     fn search(
         &self,
         f: &(dyn Fn(T) -> T),
@@ -14,31 +16,32 @@ pub trait LineSearchMethod<T: Float> {
     ) -> LineSearchResult<T>;
 }
-/// Line search result
+///
 #[derive(Debug, Clone)]
 pub struct LineSearchResult<T: Float> {
-    /// Alpha value
+    ///
     pub alpha: T,
-    /// f(alpha) value
+    ///
     pub f_x: T,
 }
-/// Backtracking line search method.
+///
 pub struct Backtracking<T: Float> {
-    /// TODO: Add documentation
+    ///
     pub c1: T,
-    /// Maximum number of iterations for Backtracking single run
+    ///
     pub max_iterations: usize,
-    /// TODO: Add documentation
+    ///
     pub max_infinity_iterations: usize,
-    /// TODO: Add documentation
+    ///
     pub phi: T,
-    /// TODO: Add documentation
+    ///
     pub plo: T,
-    /// function order
+    ///
     pub order: FunctionOrder,
 }
+///
 impl<T: Float> Default for Backtracking<T> {
     fn default() -> Self {
         Backtracking {
@@ -52,7 +55,9 @@ impl<T: Float> Default for Backtracking<T> {
     }
 }
+///
 impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
+    ///
     fn search(
         &self,
         f: &(dyn Fn(T) -> T),
View File
@@ -1,19 +1,21 @@
/// first order optimization algorithms // TODO: missing documentation
///
pub mod first_order; pub mod first_order;
/// line search algorithms ///
pub mod line_search; pub mod line_search;
/// Function f(x) = y ///
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a; pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
/// Function df(x) ///
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a; pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
/// Function order ///
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum FunctionOrder { pub enum FunctionOrder {
/// Second order ///
SECOND, SECOND,
/// Third order ///
THIRD, THIRD,
} }
+23 -19
View File
@@ -12,7 +12,7 @@
//! &[1.5, 2.0, 1.5, 4.0], //! &[1.5, 2.0, 1.5, 4.0],
//! &[1.5, 1.0, 1.5, 5.0], //! &[1.5, 1.0, 1.5, 5.0],
//! &[1.5, 2.0, 1.5, 6.0], //! &[1.5, 2.0, 1.5, 6.0],
//! ]).unwrap(); //! ]);
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]); //! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
//! // Infer number of categories from data and return a reusable encoder //! // Infer number of categories from data and return a reusable encoder
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap(); //! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
@@ -24,7 +24,7 @@
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0] //! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0] //! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
//! ``` //! ```
use std::iter::repeat_n; use std::iter;
use crate::error::Failed; use crate::error::Failed;
use crate::linalg::basic::arrays::Array2; use crate::linalg::basic::arrays::Array2;
@@ -75,7 +75,11 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) ->
let offset = (0..1).chain(offset_); let offset = (0..1).chain(offset_);
let new_param_idxs: Vec<usize> = (0..num_params) let new_param_idxs: Vec<usize> = (0..num_params)
.zip(repeats.zip(offset).flat_map(|(r, o)| repeat_n(o, r))) .zip(
repeats
.zip(offset)
.flat_map(|(r, o)| iter::repeat(o).take(r)),
)
.map(|(idx, ofst)| idx + ofst) .map(|(idx, ofst)| idx + ofst)
.collect(); .collect();
new_param_idxs new_param_idxs
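The hunk above swaps `iter::repeat_n(o, r)` for the equivalent `iter::repeat(o).take(r)`. A small self-contained illustration of the pattern used in `find_new_idxs` (the data here is made up):

```rust
use std::iter;

fn main() {
    // Each offset `o` is emitted `r` times, mirroring the flat_map in the diff.
    let repeats = [2usize, 3];
    let offsets = [10usize, 20];
    let expanded: Vec<usize> = repeats
        .iter()
        .zip(offsets.iter())
        .flat_map(|(&r, &o)| iter::repeat(o).take(r))
        .collect();
    assert_eq!(expanded, vec![10, 10, 20, 20, 20]);
}
```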
@@ -120,7 +124,7 @@ impl OneHotEncoder {
let (nrows, _) = data.shape(); let (nrows, _) = data.shape();
// col buffer to avoid allocations // col buffer to avoid allocations
let mut col_buf: Vec<T> = repeat_n(T::zero(), nrows).collect(); let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len()); let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
@@ -128,7 +132,8 @@ impl OneHotEncoder {
data.copy_col_as_vec(idx, &mut col_buf); data.copy_col_as_vec(idx, &mut col_buf);
if !validate_col_is_categorical(&col_buf) { if !validate_col_is_categorical(&col_buf) {
let msg = format!( let msg = format!(
"Column {idx} of data matrix containts non categorizable (integer) values" "Column {} of data matrix containts non categorizable (integer) values",
idx
); );
return Err(Failed::fit(&msg[..])); return Err(Failed::fit(&msg[..]));
} }
@@ -177,7 +182,7 @@ impl OneHotEncoder {
match oh_vec { match oh_vec {
None => { None => {
// Since we support generic T types, a bad value in a series causes it to be invalid // Since we support generic T types, a bad value in a series causes it to be invalid
let msg = format!("At least one value in column {old_cidx} doesn't conform to category definition"); let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx);
return Err(Failed::transform(&msg[..])); return Err(Failed::transform(&msg[..]));
} }
Some(v) => { Some(v) => {
@@ -236,16 +241,14 @@ mod tests {
&[2.0, 1.5, 4.0], &[2.0, 1.5, 4.0],
&[1.0, 1.5, 5.0], &[1.0, 1.5, 5.0],
&[2.0, 1.5, 6.0], &[2.0, 1.5, 6.0],
]) ]);
.unwrap();
let oh_enc = DenseMatrix::from_2d_array(&[ let oh_enc = DenseMatrix::from_2d_array(&[
&[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
&[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0], &[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
&[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0], &[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
&[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0], &[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
]) ]);
.unwrap();
(orig, oh_enc) (orig, oh_enc)
} }
@@ -257,16 +260,14 @@ mod tests {
&[1.5, 2.0, 1.5, 4.0], &[1.5, 2.0, 1.5, 4.0],
&[1.5, 1.0, 1.5, 5.0], &[1.5, 1.0, 1.5, 5.0],
&[1.5, 2.0, 1.5, 6.0], &[1.5, 2.0, 1.5, 6.0],
]) ]);
.unwrap();
let oh_enc = DenseMatrix::from_2d_array(&[ let oh_enc = DenseMatrix::from_2d_array(&[
&[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0], &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
&[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0], &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
&[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0], &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
&[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0], &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
]) ]);
.unwrap();
(orig, oh_enc) (orig, oh_enc)
} }
@@ -277,7 +278,7 @@ mod tests {
)] )]
#[test] #[test]
fn hash_encode_f64_series() { fn hash_encode_f64_series() {
let series = [3.0, 1.0, 2.0, 1.0]; let series = vec![3.0, 1.0, 2.0, 1.0];
let hashable_series: Vec<CategoricalFloat> = let hashable_series: Vec<CategoricalFloat> =
series.iter().map(|v| v.to_category()).collect(); series.iter().map(|v| v.to_category()).collect();
let enc = CategoryMapper::from_positional_category_vec(hashable_series); let enc = CategoryMapper::from_positional_category_vec(hashable_series);
@@ -334,11 +335,14 @@ mod tests {
&[2.0, 1.5, 4.0], &[2.0, 1.5, 4.0],
&[1.0, 1.5, 5.0], &[1.0, 1.5, 5.0],
&[2.0, 1.5, 6.0], &[2.0, 1.5, 6.0],
]) ]);
.unwrap();
let params = OneHotEncoderParams::from_cat_idx(&[1]); let params = OneHotEncoderParams::from_cat_idx(&[1]);
let result = OneHotEncoder::fit(&m, params); match OneHotEncoder::fit(&m, params) {
assert!(result.is_err()); Err(_) => {
assert!(true);
}
_ => assert!(false),
}
} }
} }
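The one-hot tests above map each category to a positional indicator vector. Conceptually, `CategoryMapper::from_positional_category_vec` behaves like this standalone sketch (a simplification, not smartcore's API):

```rust
use std::collections::HashMap;

fn main() {
    // Positional one-hot encoding: category i maps to a vector with a 1 at index i.
    let categories = ["background", "dog", "cat"];
    let index: HashMap<&str, usize> = categories
        .iter()
        .enumerate()
        .map(|(i, &c)| (c, i))
        .collect();
    let one_hot = |c: &str| -> Vec<f64> {
        let mut v = vec![0.0; categories.len()];
        v[index[c]] = 1.0;
        v
    };
    assert_eq!(one_hot("dog"), vec![0.0, 1.0, 0.0]);
}
```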
+51 -56
@@ -11,7 +11,7 @@
//! vec![0.0, 0.0], //! vec![0.0, 0.0],
//! vec![1.0, 1.0], //! vec![1.0, 1.0],
//! vec![1.0, 1.0], //! vec![1.0, 1.0],
//! ]).unwrap(); //! ]);
//! //!
//! let standard_scaler = //! let standard_scaler =
//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default()) //! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default())
@@ -24,7 +24,7 @@
//! vec![-1.0, -1.0], //! vec![-1.0, -1.0],
//! vec![1.0, 1.0], //! vec![1.0, 1.0],
//! vec![1.0, 1.0], //! vec![1.0, 1.0],
//! ]).unwrap() //! ])
//! ); //! );
//! ``` //! ```
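Per column, the scaler computes `z = (x - mean) / std`. Using the population standard deviation reproduces the doc example above, where the column `[0, 0, 1, 1]` scales to `[-1, -1, 1, 1]` (a sketch of the arithmetic, not smartcore's code):

```rust
fn main() {
    // z = (x - mean) / std per column, with population (biased) std.
    let col = [0.0f64, 0.0, 1.0, 1.0];
    let n = col.len() as f64;
    let mean = col.iter().sum::<f64>() / n;
    let std = (col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n).sqrt();
    let scaled: Vec<f64> = col.iter().map(|x| (x - mean) / std).collect();
    assert_eq!(scaled, vec![-1.0, -1.0, 1.0, 1.0]);
}
```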
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -172,14 +172,18 @@ where
T: Number + RealNumber, T: Number + RealNumber,
M: Array2<T>, M: Array2<T>,
{ {
columns.first().cloned().map(|output_matrix| { if let Some(output_matrix) = columns.first().cloned() {
columns return Some(
.iter() columns
.skip(1) .iter()
.fold(output_matrix, |current_matrix, new_colum| { .skip(1)
current_matrix.h_stack(new_colum) .fold(output_matrix, |current_matrix, new_colum| {
}) current_matrix.h_stack(new_colum)
}) }),
);
} else {
None
}
} }
#[cfg(test)] #[cfg(test)]
@@ -193,18 +197,15 @@ mod tests {
fn combine_three_columns() { fn combine_three_columns() {
assert_eq!( assert_eq!(
build_matrix_from_columns(vec![ build_matrix_from_columns(vec![
DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]).unwrap(), DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]),
DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]).unwrap(), DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]),
DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],]).unwrap() DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],])
]), ]),
Some( Some(DenseMatrix::from_2d_vec(&vec![
DenseMatrix::from_2d_vec(&vec![ vec![1.0, 2.0, 3.0],
vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0],
vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0]
vec![1.0, 2.0, 3.0] ]))
])
.unwrap()
)
) )
} }
@@ -286,24 +287,21 @@ mod tests {
/// sklearn. /// sklearn.
#[test] #[test]
fn fit_transform_random_values() { fn fit_transform_random_values() {
let transformed_values = fit_transform_with_default_standard_scaler( let transformed_values =
&DenseMatrix::from_2d_array(&[ fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793], &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
]) ]));
.unwrap(), println!("{}", transformed_values);
);
println!("{transformed_values}");
assert!(transformed_values.approximate_eq( assert!(transformed_values.approximate_eq(
&DenseMatrix::from_2d_array(&[ &DenseMatrix::from_2d_array(&[
&[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866], &[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866],
&[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631], &[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
&[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257], &[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
&[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021], &[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
]) ]),
.unwrap(),
1.0 1.0
)) ))
} }
@@ -312,10 +310,13 @@ mod tests {
#[test] #[test]
fn fit_transform_with_zero_variance() { fn fit_transform_with_zero_variance() {
assert_eq!( assert_eq!(
fit_transform_with_default_standard_scaler( fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
&DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]).unwrap() &[1.0],
), &[1.0],
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]).unwrap(), &[1.0],
&[1.0]
])),
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]),
"When scaling values with zero variance, zero is expected as return value" "When scaling values with zero variance, zero is expected as return value"
) )
} }
@@ -330,8 +331,7 @@ mod tests {
&[1.0, 2.0, 5.0], &[1.0, 2.0, 5.0],
&[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0],
&[1.0, 2.0, 5.0] &[1.0, 2.0, 5.0]
]) ]),
.unwrap(),
StandardScalerParameters::default(), StandardScalerParameters::default(),
), ),
Ok(StandardScaler { Ok(StandardScaler {
@@ -354,8 +354,7 @@ mod tests {
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
]) ]),
.unwrap(),
StandardScalerParameters::default(), StandardScalerParameters::default(),
) )
.unwrap(); .unwrap();
@@ -365,18 +364,17 @@ mod tests {
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625], vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
); );
assert!(&DenseMatrix::<f64>::from_2d_vec(&vec![fitted_scaler.stds]) assert!(
.unwrap() &DenseMatrix::<f64>::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq(
.approximate_eq(
&DenseMatrix::from_2d_array(&[&[ &DenseMatrix::from_2d_array(&[&[
0.29426447500954, 0.29426447500954,
0.16758497615485, 0.16758497615485,
0.20820945786863, 0.20820945786863,
0.23329718831165 0.23329718831165
],]) ],]),
.unwrap(),
0.00000000000001 0.00000000000001
)) )
)
} }
/// If `with_std` is set to `false` the values should not be /// If `with_std` is set to `false` the values should not be
@@ -394,9 +392,8 @@ mod tests {
}; };
assert_eq!( assert_eq!(
standard_scaler standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])),
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]]).unwrap()), Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]))
Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]).unwrap())
) )
} }
@@ -416,8 +413,8 @@ mod tests {
assert_eq!( assert_eq!(
standard_scaler standard_scaler
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]]).unwrap()), .transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])),
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]).unwrap()) Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
) )
} }
@@ -436,8 +433,7 @@ mod tests {
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
]) ]),
.unwrap(),
StandardScalerParameters::default(), StandardScalerParameters::default(),
) )
.unwrap(); .unwrap();
@@ -450,18 +446,17 @@ mod tests {
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625], vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
); );
assert!(&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]) assert!(
.unwrap() &DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
.approximate_eq(
&DenseMatrix::from_2d_array(&[&[ &DenseMatrix::from_2d_array(&[&[
0.29426447500954, 0.29426447500954,
0.16758497615485, 0.16758497615485,
0.20820945786863, 0.20820945786863,
0.23329718831165 0.23329718831165
],]) ],]),
.unwrap(),
0.00000000000001 0.00000000000001
)) )
)
} }
} }
} }
+4 -4
@@ -206,7 +206,7 @@ mod tests {
#[test] #[test]
fn from_categories() { fn from_categories() {
let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
let it = fake_categories.iter().copied(); let it = fake_categories.iter().map(|&a| a);
let enc = CategoryMapper::<usize>::fit_to_iter(it); let enc = CategoryMapper::<usize>::fit_to_iter(it);
let oh_vec: Vec<f64> = match enc.get_one_hot(&1) { let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
None => panic!("Wrong categories"), None => panic!("Wrong categories"),
@@ -218,8 +218,8 @@ mod tests {
fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> { fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
let fake_category_pos = vec!["background", "dog", "cat"]; let fake_category_pos = vec!["background", "dog", "cat"];
let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos) enc
} }
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -275,7 +275,7 @@ mod tests {
let lab = enc.invert_one_hot(res).unwrap(); let lab = enc.invert_one_hot(res).unwrap();
assert_eq!(lab, "dog"); assert_eq!(lab, "dog");
if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) { if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
let pos_entries = "Expected a single positive entry, 0 entires found".to_string(); let pos_entries = format!("Expected a single positive entry, 0 entires found");
assert_eq!(e, Failed::transform(&pos_entries[..])); assert_eq!(e, Failed::transform(&pos_entries[..]));
}; };
} }
+14 -9
@@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> {
/// What separates the fields in your csv-file? /// What separates the fields in your csv-file?
field_seperator: &'a str, field_seperator: &'a str,
} }
impl Default for CSVDefinition<'_> { impl<'a> Default for CSVDefinition<'a> {
fn default() -> Self { fn default() -> Self {
Self { Self {
n_rows_header: 1, n_rows_header: 1,
@@ -83,7 +83,7 @@ where
Matrix: Array2<T>, Matrix: Array2<T>,
{ {
let csv_text = read_string_from_source(source)?; let csv_text = read_string_from_source(source)?;
let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text( let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
&csv_text, &csv_text,
&definition, &definition,
detect_row_format(&csv_text, &definition)?, detect_row_format(&csv_text, &definition)?,
@@ -103,7 +103,12 @@ where
/// Given a string containing the contents of a csv file, extract its value /// Given a string containing the contents of a csv file, extract its value
/// into row-vectors. /// into row-vectors.
fn extract_row_vectors_from_csv_text<'a, T: Number + RealNumber + std::str::FromStr>( fn extract_row_vectors_from_csv_text<
'a,
T: Number + RealNumber + std::str::FromStr,
RowVector: Array1<T>,
Matrix: Array2<T>,
>(
csv_text: &'a str, csv_text: &'a str,
definition: &'a CSVDefinition<'_>, definition: &'a CSVDefinition<'_>,
row_format: CSVRowFormat<'_>, row_format: CSVRowFormat<'_>,
@@ -162,7 +167,7 @@ where
} }
/// Ensure that a string containing a csv row conforms to a specified row format. /// Ensure that a string containing a csv row conforms to a specified row format.
fn validate_csv_row(row: &str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> { fn validate_csv_row<'a>(row: &'a str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> {
let actual_number_of_fields = row.split(row_format.field_seperator).count(); let actual_number_of_fields = row.split(row_format.field_seperator).count();
if row_format.n_fields == actual_number_of_fields { if row_format.n_fields == actual_number_of_fields {
Ok(()) Ok(())
@@ -203,7 +208,7 @@ where
match value_string.parse::<T>().ok() { match value_string.parse::<T>().ok() {
Some(value) => Ok(value), Some(value) => Ok(value),
None => Err(ReadingError::InvalidField { None => Err(ReadingError::InvalidField {
msg: format!("Value '{value_string}' could not be read.",), msg: format!("Value '{}' could not be read.", value_string,),
}), }),
} }
} }
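The hunk above reports an `InvalidField` error when a field fails to parse. A self-contained sketch of that per-field logic, with `ReadingError` simplified to a plain struct (names here are illustrative):

```rust
// Hedged sketch of per-field CSV parsing; `InvalidField` stands in for the
// crate's ReadingError::InvalidField variant.
#[derive(Debug, PartialEq)]
struct InvalidField {
    msg: String,
}

fn parse_field<T: std::str::FromStr>(value_string: &str) -> Result<T, InvalidField> {
    value_string.trim().parse::<T>().map_err(|_| InvalidField {
        msg: format!("Value '{}' could not be read.", value_string),
    })
}

fn main() {
    assert_eq!(parse_field::<f64>("5.1"), Ok(5.1));
    assert!(parse_field::<f64>("abc").is_err());
}
```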
@@ -238,8 +243,7 @@ mod tests {
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2], &[4.7, 3.2, 1.3, 0.2],
]) ]))
.unwrap())
) )
} }
#[test] #[test]
@@ -262,7 +266,7 @@ mod tests {
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2], &[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2], &[4.7, 3.2, 1.3, 0.2],
]).unwrap()) ]))
) )
} }
#[test] #[test]
@@ -301,11 +305,12 @@ mod tests {
} }
mod extract_row_vectors_from_csv_text { mod extract_row_vectors_from_csv_text {
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat}; use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
use crate::linalg::basic::matrix::DenseMatrix;
#[test] #[test]
fn read_default_csv() { fn read_default_csv() {
assert_eq!( assert_eq!(
extract_row_vectors_from_csv_text::<f64>( extract_row_vectors_from_csv_text::<f64, Vec<_>, DenseMatrix<_>>(
"column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0", "column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0",
&CSVDefinition::default(), &CSVDefinition::default(),
CSVRowFormat { CSVRowFormat {
+181 -285
@@ -25,18 +25,14 @@
/// search parameters /// search parameters
pub mod svc; pub mod svc;
pub mod svr; pub mod svr;
// search parameters space // /// search parameters space
pub mod search; // pub mod search;
use core::fmt::Debug; use core::fmt::Debug;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// Only import typetag if not compiling for wasm32 and serde is enabled
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
use typetag;
use crate::error::{Failed, FailedError}; use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, ArrayView1}; use crate::linalg::basic::arrays::{Array1, ArrayView1};
@@ -52,301 +48,205 @@ pub trait Kernel: Debug {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>; fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
} }
/// An enumerator for all the supported kernel types. /// Pre-defined kernel functions
/// This makes kernel selection and parameterization ergonomic, type-safe, and ready for use in parameter structs like SVRParameters.
/// You can construct kernels using the provided variants and builder-style methods.
///
/// # Examples
///
/// ```
/// use smartcore::svm::Kernels;
///
/// let linear = Kernels::linear();
/// let rbf = Kernels::rbf().with_gamma(0.5);
/// let poly = Kernels::polynomial().with_degree(3.0).with_gamma(0.5).with_coef0(1.0);
/// let sigmoid = Kernels::sigmoid().with_gamma(0.2).with_coef0(0.0);
/// ```
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone)]
pub enum Kernels { pub struct Kernels;
/// Linear kernel (default).
///
/// Computes the standard dot product between vectors.
Linear,
/// Radial Basis Function (RBF) kernel.
///
/// Formula: K(x, y) = exp(-gamma * ||x-y||²)
RBF {
/// Controls the width of the Gaussian RBF kernel.
///
/// Larger values of gamma lead to higher bias and lower variance.
/// This parameter is inversely proportional to the radius of influence
/// of samples selected by the model as support vectors.
gamma: Option<f64>,
},
/// Polynomial kernel.
///
/// Formula: K(x, y) = (gamma * <x, y> + coef0)^degree
Polynomial {
/// The degree of the polynomial kernel.
///
/// Integer values are typical (2 = quadratic, 3 = cubic), but any positive real value is valid.
/// Higher degree values create decision boundaries with higher complexity.
degree: Option<f64>,
/// Kernel coefficient for the dot product.
///
/// Controls the influence of higher-degree versus lower-degree terms in the polynomial.
/// If None, a default value will be used.
gamma: Option<f64>,
/// Independent term in the polynomial kernel.
///
/// Controls the influence of higher-degree versus lower-degree terms.
/// If None, a default value of 1.0 will be used.
coef0: Option<f64>,
},
/// Sigmoid kernel.
///
/// Formula: K(x, y) = tanh(gamma * <x, y> + coef0)
Sigmoid {
/// Kernel coefficient for the dot product.
///
/// Controls the scaling of the dot product in the sigmoid function.
/// If None, a default value will be used.
gamma: Option<f64>,
/// Independent term in the sigmoid kernel.
///
/// Acts as a threshold/bias term in the sigmoid function.
/// If None, a default value of 1.0 will be used.
coef0: Option<f64>,
},
}
impl Kernels { impl Kernels {
/// Create a linear kernel. /// Return a default linear
/// pub fn linear() -> LinearKernel {
/// The linear kernel computes the dot product between two vectors: LinearKernel::default()
/// K(x, y) = <x, y>
pub fn linear() -> Self {
Kernels::Linear
} }
/// Return a default RBF
/// Create an RBF kernel with unspecified gamma. pub fn rbf() -> RBFKernel {
/// RBFKernel::default()
/// The RBF kernel is defined as:
/// K(x, y) = exp(-gamma * ||x-y||²)
///
/// You should specify gamma using `with_gamma()` before using this kernel.
pub fn rbf() -> Self {
Kernels::RBF { gamma: None }
} }
/// Return a default polynomial
/// Create a polynomial kernel with default parameters. pub fn polynomial() -> PolynomialKernel {
/// PolynomialKernel::default()
/// The polynomial kernel is defined as:
/// K(x, y) = (gamma * <x, y> + coef0)^degree
///
/// Default values:
/// - gamma: None (must be specified)
/// - degree: None (must be specified)
/// - coef0: 1.0
pub fn polynomial() -> Self {
Kernels::Polynomial {
gamma: None,
degree: None,
coef0: Some(1.0),
}
} }
/// Return a default sigmoid
/// Create a sigmoid kernel with default parameters. pub fn sigmoid() -> SigmoidKernel {
/// SigmoidKernel::default()
/// The sigmoid kernel is defined as:
/// K(x, y) = tanh(gamma * <x, y> + coef0)
///
/// Default values:
/// - gamma: None (must be specified)
/// - coef0: 1.0
///
pub fn sigmoid() -> Self {
Kernels::Sigmoid {
gamma: None,
coef0: Some(1.0),
}
} }
}
/// Set the `gamma` parameter for RBF, polynomial, or sigmoid kernels. /// Linear Kernel
/// #[allow(clippy::derive_partial_eq_without_eq)]
/// The gamma parameter has different interpretations depending on the kernel: #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// - For RBF: Controls the width of the Gaussian. Larger values mean tighter fit. #[derive(Debug, Clone, PartialEq, Eq, Default)]
/// - For Polynomial: Scaling factor for the dot product. pub struct LinearKernel;
/// - For Sigmoid: Scaling factor for the dot product.
/// /// Radial basis function (Gaussian) kernel
pub fn with_gamma(self, gamma: f64) -> Self { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
match self { #[derive(Debug, Default, Clone, PartialEq)]
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) }, pub struct RBFKernel {
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial { /// kernel coefficient
gamma: Some(gamma), pub gamma: Option<f64>,
degree, }
coef0,
}, #[allow(dead_code)]
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid { impl RBFKernel {
gamma: Some(gamma), /// assign gamma parameter to kernel (required)
coef0, /// ```rust
}, /// use smartcore::svm::RBFKernel;
other => other, /// let knl = RBFKernel::default().with_gamma(0.7);
} /// ```
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.gamma = Some(gamma);
self
} }
}
/// Set the `degree` parameter for the polynomial kernel. /// Polynomial kernel
/// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// The degree parameter controls the flexibility of the decision boundary. #[derive(Debug, Clone, PartialEq)]
/// Higher degrees create more complex boundaries but may lead to overfitting. pub struct PolynomialKernel {
/// /// degree of the polynomial
pub fn with_degree(self, degree: f64) -> Self { pub degree: Option<f64>,
match self { /// kernel coefficient
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial { pub gamma: Option<f64>,
degree: Some(degree), /// independent term in kernel function
gamma, pub coef0: Option<f64>,
coef0, }
},
other => other,
}
}
/// Set the `coef0` parameter for polynomial or sigmoid kernels. impl Default for PolynomialKernel {
/// fn default() -> Self {
/// The coef0 parameter is the independent term in the kernel function: Self {
/// - For Polynomial: Controls the influence of higher-degree vs. lower-degree terms. gamma: Option::None,
/// - For Sigmoid: Acts as a threshold/bias term. degree: Option::None,
/// coef0: Some(1f64),
pub fn with_coef0(self, coef0: f64) -> Self {
match self {
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
degree,
gamma,
coef0: Some(coef0),
},
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
gamma,
coef0: Some(coef0),
},
other => other,
} }
} }
} }
/// Implementation of the [`Kernel`] trait for the [`Kernels`] enum in smartcore. impl PolynomialKernel {
/// /// set parameters for kernel
/// This method computes the value of the kernel function between two feature vectors `x_i` and `x_j`, /// ```rust
/// according to the variant and parameters of the [`Kernels`] enum. This enables flexible and type-safe /// use smartcore::svm::PolynomialKernel;
/// selection of kernel functions for SVM and SVR models in smartcore. /// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
/// /// ```
/// # Supported Kernels pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
/// self.degree = Some(degree);
/// - [`Kernels::Linear`]: Computes the standard dot product between `x_i` and `x_j`. self.gamma = Some(gamma);
/// - [`Kernels::RBF`]: Computes the Radial Basis Function (Gaussian) kernel. Requires `gamma`. self.coef0 = Some(coef0);
/// - [`Kernels::Polynomial`]: Computes the polynomial kernel. Requires `degree`, `gamma`, and `coef0`. self
/// - [`Kernels::Sigmoid`]: Computes the sigmoid kernel. Requires `gamma` and `coef0`. }
/// /// set gamma parameter for kernel
/// # Parameters /// ```rust
/// /// use smartcore::svm::PolynomialKernel;
/// - `x_i`: First input vector (feature vector). /// let knl = PolynomialKernel::default().with_gamma(0.7);
/// - `x_j`: Second input vector (feature vector). /// ```
/// pub fn with_gamma(mut self, gamma: f64) -> Self {
/// # Returns self.gamma = Some(gamma);
/// self
/// - `Ok(f64)`: The computed kernel value. }
/// - `Err(Failed)`: If any required kernel parameter is missing. /// set degree parameter for kernel
/// /// ```rust
/// # Errors /// use smartcore::svm::PolynomialKernel;
/// /// let knl = PolynomialKernel::default().with_degree(3.0, 100);
/// Returns `Err(Failed)` if a required parameter (such as `gamma`, `degree`, or `coef0`) /// ```
/// is `None` for the selected kernel variant. pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
/// self.with_params(degree, 1f64, 1f64 / n_features as f64)
/// # Example }
/// }
/// ```
/// use smartcore::svm::Kernels; /// Sigmoid (hyperbolic tangent) kernel
/// use smartcore::svm::Kernel; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// #[derive(Debug, Clone, PartialEq)]
/// let x = vec![1.0, 2.0, 3.0]; pub struct SigmoidKernel {
/// let y = vec![4.0, 5.0, 6.0]; /// kernel coefficient
/// let kernel = Kernels::rbf().with_gamma(0.5); pub gamma: Option<f64>,
/// let value = kernel.apply(&x, &y).unwrap(); /// independent term in kernel function
/// ``` pub coef0: Option<f64>,
/// }
/// # Notes
/// impl Default for SigmoidKernel {
/// - This implementation follows smartcore's philosophy: pure Rust, no macros, no unsafe code, fn default() -> Self {
/// and an accessible, pythonic API surface for both ML practitioners and Rust beginners. Self {
/// - All kernel parameters must be set before calling `apply`; missing parameters will result in an error. gamma: Option::None,
/// coef0: Some(1f64),
/// See the [`Kernels`] enum documentation for more details on each kernel type and its parameters.
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for Kernels {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
match self {
Kernels::Linear => Ok(x_i.dot(x_j)),
Kernels::RBF { gamma } => {
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set")
})?;
let v_diff = x_i.sub(x_j);
Ok((-gamma * v_diff.mul(&v_diff).sum()).exp())
}
Kernels::Polynomial {
degree,
gamma,
coef0,
} => {
let degree = degree.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "degree not set")
})?;
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set")
})?;
let coef0 = coef0.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "coef0 not set")
})?;
let dot = x_i.dot(x_j);
Ok((gamma * dot + coef0).powf(degree))
}
Kernels::Sigmoid { gamma, coef0 } => {
let gamma = gamma.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "gamma not set")
})?;
let coef0 = coef0.ok_or_else(|| {
Failed::because(FailedError::ParametersError, "coef0 not set")
})?;
let dot = x_i.dot(x_j);
Ok((gamma * dot + coef0).tanh())
}
} }
} }
} }
impl SigmoidKernel {
/// set parameters for kernel
/// ```rust
/// use smartcore::svm::SigmoidKernel;
/// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
/// ```
pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
self.gamma = Some(gamma);
self.coef0 = Some(coef0);
self
}
/// set gamma parameter for kernel
/// ```rust
/// use smartcore::svm::SigmoidKernel;
/// let knl = SigmoidKernel::default().with_gamma(0.7);
/// ```
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.gamma = Some(gamma);
self
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for LinearKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
Ok(x_i.dot(x_j))
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for RBFKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() {
return Err(Failed::because(
FailedError::ParametersError,
"gamma should be set, use {Kernel}::default().with_gamma(..)",
));
}
let v_diff = x_i.sub(x_j);
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for PolynomialKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
return Err(Failed::because(
FailedError::ParametersError, "gamma, coef0, degree should be set,
use {Kernel}::default().with_{parameter}(..)")
);
}
let dot = x_i.dot(x_j);
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
}
}
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
impl Kernel for SigmoidKernel {
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
if self.gamma.is_none() || self.coef0.is_none() {
            return Err(Failed::because(
                FailedError::ParametersError,
                "gamma, coef0 should be set, use {Kernel}::default().with_{parameter}(..)",
            ));
}
let dot = x_i.dot(x_j);
        Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).tanh())
}
}
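The kernels above reduce to short closed-form expressions: linear is `<x_i, x_j>`, RBF is `exp(-gamma * ||x_i - x_j||^2)`, polynomial is `(gamma * <x_i, x_j> + coef0)^degree`, and sigmoid is `tanh(gamma * <x_i, x_j> + coef0)`. As a sanity check, here is a dependency-free sketch of the same formulas on plain slices; the helper names are illustrative, not part of the smartcore API:

```rust
// Standalone versions of the kernel formulas, easy to check by hand.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// RBF: exp(-gamma * ||x_i - x_j||^2)
fn rbf(a: &[f64], b: &[f64], gamma: f64) -> f64 {
    let sq_dist: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    (-gamma * sq_dist).exp()
}

// Polynomial: (gamma * <x_i, x_j> + coef0)^degree
fn polynomial(a: &[f64], b: &[f64], gamma: f64, coef0: f64, degree: f64) -> f64 {
    (gamma * dot(a, b) + coef0).powf(degree)
}

// Sigmoid: tanh(gamma * <x_i, x_j> + coef0) -- the tanh wraps the whole sum.
fn sigmoid(a: &[f64], b: &[f64], gamma: f64, coef0: f64) -> f64 {
    (gamma * dot(a, b) + coef0).tanh()
}

fn main() {
    let (v1, v2) = (vec![1., 2., 3.], vec![4., 5., 6.]);
    // <v1, v2> = 4 + 10 + 18 = 32, so (0.5 * 32 + 1)^3 = 17^3 = 4913
    assert!((polynomial(&v1, &v2, 0.5, 1.0, 3.0) - 4913.0).abs() < 1e-9);
    // ||v1 - v2||^2 = 9 + 9 + 9 = 27, so exp(-0.055 * 27) ~ 0.2265
    assert!((rbf(&v1, &v2, 0.055) - 0.2265).abs() < 1e-4);
    // tanh is bounded, so the sigmoid kernel always lies in [-1, 1]
    assert!(sigmoid(&v1, &v2, 0.01, 0.1).abs() <= 1.0);
}
```

The same numbers appear in the unit tests for the smartcore kernels.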
#[cfg(test)]
mod tests {
    use super::*;
    use crate::svm::Kernels;
#[test]
fn rbf_kernel() {
let v1 = vec![1., 2., 3.];
let v2 = vec![4., 5., 6.];
let result = Kernels::rbf()
.with_gamma(0.055)
.apply(&v1, &v2)
.unwrap()
.abs();
        assert!((0.2265f64 - result).abs() < 1e-4);
}
    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn test_rbf_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        // (remainder of this test elided in this excerpt)
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn polynomial_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        let result = Kernels::polynomial()
            .with_gamma(0.5)
            .with_degree(3.0)
            .with_coef0(1.0)
            .apply(&v1, &v2)
            .unwrap()
            .abs();
        assert!((4913f64 - result).abs() < f64::EPSILON);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn sigmoid_kernel() {
        let v1 = vec![1., 2., 3.];
        let v2 = vec![4., 5., 6.];
        let result = Kernels::sigmoid()
            .with_gamma(0.01)
            .with_coef0(0.1)
            .apply(&v1, &v2)
            .unwrap()
            .abs();
        // (assertion elided in this excerpt)
    }
}
//! SVC and SVR Grid Search

/// SVC search parameters
pub mod svc_params;
/// SVR search parameters
pub mod svr_params;
//! # SVR Grid Search Parameters
//!
//! This module provides utilities for defining and iterating over grid search parameter spaces
//! for Support Vector Regression (SVR) models in [smartcore](https://github.com/smartcorelib/smartcore).
//!
//! The main struct, [`SVRSearchParameters`], allows users to specify multiple values for each
//! SVR hyperparameter (epsilon, regularization parameter C, tolerance, and kernel function).
//! The provided iterator yields all possible combinations (the Cartesian product) of these parameters,
//! enabling exhaustive grid search for hyperparameter tuning.
//!
//! ## Example
//! ```
//! use smartcore::svm::Kernels;
//! use smartcore::svm::search::svr_params::SVRSearchParameters;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//!
//! let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
//!     eps: vec![0.1, 0.2],
//!     c: vec![1.0, 10.0],
//!     tol: vec![1e-3],
//!     kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
//!     m: std::marker::PhantomData,
//! };
//!
//! // for param_set in params.into_iter() {
//! //     // Use param_set (of type svr::SVRParameters) to fit and evaluate your SVR model.
//! // }
//! ```
//!
//! ## Note
//! This module is intended for use with smartcore version 0.4 or later. The API is not compatible with older versions.
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::svm::{svr, Kernels};
use std::marker::PhantomData;
/// ## SVR grid search parameters
/// A struct representing a grid of hyperparameters for SVR grid search in smartcore.
///
/// Each field is a vector of possible values for the corresponding SVR hyperparameter.
/// The [`IntoIterator`] implementation yields every possible combination of these parameters
/// as an `svr::SVRParameters` struct, suitable for use in model selection routines.
///
/// # Type Parameters
/// - `T`: Numeric type for parameters (e.g., `f64`)
/// - `M`: Matrix type implementing [`Array2<T>`]
///
/// # Fields
/// - `eps`: Vector of epsilon values for the epsilon-insensitive loss in SVR.
/// - `c`: Vector of regularization parameters (C) for SVR.
/// - `tol`: Vector of tolerance values for the stopping criterion.
/// - `kernel`: Vector of kernel function variants (see [`Kernels`]).
/// - `m`: Phantom data for the matrix type parameter.
///
/// # Example
/// ```
/// use smartcore::svm::Kernels;
/// use smartcore::svm::search::svr_params::SVRSearchParameters;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
///
/// let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
///     eps: vec![0.1, 0.2],
///     c: vec![1.0, 10.0],
///     tol: vec![1e-3],
///     kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
///     m: std::marker::PhantomData,
/// };
/// ```
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVRSearchParameters<T: Number + RealNumber, M: Array2<T>> {
    /// Epsilon in the epsilon-SVR model.
    pub eps: Vec<T>,
    /// Regularization parameter.
    pub c: Vec<T>,
    /// Tolerance for the stopping criterion.
    pub tol: Vec<T>,
    /// The kernel function.
    pub kernel: Vec<Kernels>,
    /// Unused parameter.
    pub m: PhantomData<M>,
}

/// SVR grid search iterator
pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Array2<T>> {
    svr_search_parameters: SVRSearchParameters<T, M>,
    current_eps: usize,
    current_c: usize,
    current_tol: usize,
    current_kernel: usize,
}
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> IntoIterator
    for SVRSearchParameters<T, M>
{
    type Item = svr::SVRParameters<T>;
    type IntoIter = SVRSearchParametersIterator<T, M>;

    fn into_iter(self) -> Self::IntoIter {
        SVRSearchParametersIterator {
            svr_search_parameters: self,
            current_eps: 0,
            current_c: 0,
            current_tol: 0,
            current_kernel: 0,
        }
    }
}
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Iterator
    for SVRSearchParametersIterator<T, M>
{
    type Item = svr::SVRParameters<T>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_eps == self.svr_search_parameters.eps.len()
            && self.current_c == self.svr_search_parameters.c.len()
            && self.current_tol == self.svr_search_parameters.tol.len()
            && self.current_kernel == self.svr_search_parameters.kernel.len()
        {
            return None;
        }

        let next = svr::SVRParameters::<T> {
            eps: self.svr_search_parameters.eps[self.current_eps],
            c: self.svr_search_parameters.c[self.current_c],
            tol: self.svr_search_parameters.tol[self.current_tol],
            kernel: Some(self.svr_search_parameters.kernel[self.current_kernel].clone()),
        };

        if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
            self.current_eps += 1;
        } else if self.current_c + 1 < self.svr_search_parameters.c.len() {
            self.current_eps = 0;
            self.current_c += 1;
        } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
            self.current_eps = 0;
            self.current_c = 0;
            self.current_tol += 1;
        } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
            self.current_eps = 0;
            self.current_c = 0;
            self.current_tol = 0;
            self.current_kernel += 1;
        } else {
            self.current_eps += 1;
            self.current_c += 1;
            self.current_tol += 1;
            self.current_kernel += 1;
        }

        Some(next)
    }
}
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Default for SVRSearchParameters<T, M> {
    fn default() -> Self {
        let default_params: svr::SVRParameters<T> = svr::SVRParameters::default();

        SVRSearchParameters {
            eps: vec![default_params.eps],
            c: vec![default_params.c],
            tol: vec![default_params.tol],
            kernel: vec![default_params.kernel.unwrap_or_else(Kernels::linear)],
            m: PhantomData,
        }
    }
}
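The `next` implementation above walks the grid like an odometer: the `eps` index advances fastest, and when a counter runs out of room it resets and carries into the next one, so every combination is visited exactly once. A dependency-free sketch of the same traversal over plain index vectors (the `GridIter` name is illustrative, not part of smartcore):

```rust
// Odometer-style Cartesian product over parameter vectors of given lengths.
struct GridIter {
    dims: Vec<usize>, // number of candidate values per parameter
    idx: Vec<usize>,  // current position in each parameter vector
    done: bool,
}

impl GridIter {
    fn new(dims: Vec<usize>) -> Self {
        // An empty candidate list anywhere means an empty grid.
        let done = dims.iter().any(|&d| d == 0);
        let idx = vec![0; dims.len()];
        GridIter { dims, idx, done }
    }
}

impl Iterator for GridIter {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let current = self.idx.clone();
        // Bump the first counter with room, resetting everything before it;
        // if no counter has room, the whole grid has been visited.
        self.done = true;
        for i in 0..self.dims.len() {
            if self.idx[i] + 1 < self.dims[i] {
                self.idx[i] += 1;
                for j in 0..i {
                    self.idx[j] = 0;
                }
                self.done = false;
                break;
            }
        }
        Some(current)
    }
}

fn main() {
    // 2 eps values, 2 c values, 1 tol value, 2 kernels -> 8 combinations.
    let combos: Vec<_> = GridIter::new(vec![2, 2, 1, 2]).collect();
    assert_eq!(combos.len(), 8);
    assert_eq!(combos[0], vec![0, 0, 0, 0]);
    assert_eq!(combos[1], vec![1, 0, 0, 0]); // the first index advances fastest
}
```

A `done` flag sidesteps the "increment everything past the end" sentinel that the smartcore iterator uses to mark exhaustion.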
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::svm::Kernels;
type T = f64;
type M = DenseMatrix<T>;
#[test]
fn test_default_parameters() {
let params = SVRSearchParameters::<T, M>::default();
assert_eq!(params.eps.len(), 1);
assert_eq!(params.c.len(), 1);
assert_eq!(params.tol.len(), 1);
assert_eq!(params.kernel.len(), 1);
// Check that the default kernel is linear
assert_eq!(params.kernel[0], Kernels::linear());
}
#[test]
fn test_single_grid_iteration() {
let params = SVRSearchParameters::<T, M> {
eps: vec![0.1],
c: vec![1.0],
tol: vec![1e-3],
kernel: vec![Kernels::rbf().with_gamma(0.5)],
m: PhantomData,
};
let mut iter = params.into_iter();
let param = iter.next().unwrap();
assert_eq!(param.eps, 0.1);
assert_eq!(param.c, 1.0);
assert_eq!(param.tol, 1e-3);
assert_eq!(param.kernel, Some(Kernels::rbf().with_gamma(0.5)));
assert!(iter.next().is_none());
}
#[test]
fn test_cartesian_grid_iteration() {
let params = SVRSearchParameters::<T, M> {
eps: vec![0.1, 0.2],
c: vec![1.0, 2.0],
tol: vec![1e-3],
kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
m: PhantomData,
};
let expected_count =
params.eps.len() * params.c.len() * params.tol.len() * params.kernel.len();
let results: Vec<_> = params.into_iter().collect();
assert_eq!(results.len(), expected_count);
// Check that all parameter combinations are present
let mut seen = vec![];
for p in &results {
seen.push((p.eps, p.c, p.tol, p.kernel.clone().unwrap()));
}
for &eps in &[0.1, 0.2] {
for &c in &[1.0, 2.0] {
for &tol in &[1e-3] {
for kernel in &[Kernels::linear(), Kernels::rbf().with_gamma(0.5)] {
assert!(seen.contains(&(eps, c, tol, kernel.clone())));
}
}
}
}
}
#[test]
fn test_empty_grid() {
let params = SVRSearchParameters::<T, M> {
eps: vec![],
c: vec![],
tol: vec![],
kernel: vec![],
m: PhantomData,
};
let mut iter = params.into_iter();
assert!(iter.next().is_none());
}
#[test]
fn test_kernel_enum_variants() {
let lin = Kernels::linear();
let rbf = Kernels::rbf().with_gamma(0.2);
let poly = Kernels::polynomial()
.with_degree(2.0)
.with_gamma(1.0)
.with_coef0(0.5);
let sig = Kernels::sigmoid().with_gamma(0.3).with_coef0(0.1);
assert_eq!(lin, Kernels::Linear);
match rbf {
Kernels::RBF { gamma } => assert_eq!(gamma, Some(0.2)),
_ => panic!("Not RBF"),
}
match poly {
Kernels::Polynomial {
degree,
gamma,
coef0,
} => {
assert_eq!(degree, Some(2.0));
assert_eq!(gamma, Some(1.0));
assert_eq!(coef0, Some(0.5));
}
_ => panic!("Not Polynomial"),
}
match sig {
Kernels::Sigmoid { gamma, coef0 } => {
assert_eq!(gamma, Some(0.3));
assert_eq!(coef0, Some(0.1));
}
_ => panic!("Not Sigmoid"),
}
}
}
//!     &[4.9, 2.4, 3.3, 1.0],
//!     &[6.6, 2.9, 4.6, 1.3],
//!     &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! let y = vec![ -1, -1, -1, -1, -1, -1, -1, -1,
//!              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
//!
//! let knl = Kernels::linear();
//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl);
//! let svc = SVC::fit(&x, &y, parameters).unwrap();
//!
//! let y_hat = svc.predict(&x).unwrap();
//! ```
//!
//! ## References:
use serde::{Deserialize, Serialize};
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::rand_custom::get_rng_impl;
use crate::svm::Kernel;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Configuration for a multi-class Support Vector Machine (SVM) classifier.
/// This struct holds the indices of the data points relevant to a specific binary
/// classification problem within a multi-class context, and the two classes
/// being discriminated.
struct MultiClassConfig<TY: Number + Ord> {
/// The indices of the data points from the original dataset that belong to the two `classes`.
indices: Vec<usize>,
/// A tuple representing the two classes that this configuration is designed to distinguish.
classes: (TY, TY),
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>>
for MultiClassSVC<'a, TX, TY, X, Y>
{
/// Creates a new, empty `MultiClassSVC` instance.
fn new() -> Self {
Self {
classifiers: Option::None,
}
}
/// Fits the `MultiClassSVC` model to the provided data and parameters.
///
/// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array).
/// * `y` - A reference to the target labels (1D array).
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
///
/// # Returns
/// A `Result` indicating success (`Self`) or failure (`Failed`).
fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<Self, Failed> {
MultiClassSVC::fit(x, y, parameters)
}
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y>
{
/// Predicts the class labels for new data points.
///
/// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) for which to make predictions.
///
/// # Returns
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
        self.predict(x)
}
}
/// A multi-class Support Vector Machine (SVM) classifier.
///
/// This struct implements a multi-class SVM using the "one-vs-one" strategy,
/// where a separate binary SVC classifier is trained for every pair of classes.
///
/// # Type Parameters
/// * `'a` - Lifetime parameter for borrowed data.
/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
/// * `X` - The type representing the 2D array of input features (e.g., a matrix).
/// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
pub struct MultiClassSVC<
'a,
TX: Number + RealNumber,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
/// An optional vector of binary `SVC` classifiers.
classifiers: Option<Vec<SVC<'a, TX, TY, X, Y>>>,
}
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
MultiClassSVC<'a, TX, TY, X, Y>
{
/// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
///
/// This method identifies all unique classes in the target labels `y` and then
/// trains a binary `SVC` for every unique pair of classes. For each pair, it
/// extracts the relevant data points and their labels, and then trains a
/// specialized `SVC` for that binary classification task.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array).
/// * `y` - A reference to the target labels (1D array).
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
///
///
/// # Returns
/// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
pub fn fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
) -> Result<MultiClassSVC<'a, TX, TY, X, Y>, Failed> {
let unique_classes = y.unique();
let mut classifiers = Vec::new();
// Iterate through all unique pairs of classes (one-vs-one strategy)
for i in 0..unique_classes.len() {
            for j in (i + 1)..unique_classes.len() {
let class0 = unique_classes[j];
let class1 = unique_classes[i];
let mut indices = Vec::new();
// Collect indices of data points belonging to the current pair of classes
for (index, v) in y.iterator(0).enumerate() {
if *v == class0 || *v == class1 {
indices.push(index)
}
}
let classes = (class0, class1);
let multiclass_config = MultiClassConfig { classes, indices };
// Fit a binary SVC for the current pair of classes
let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap();
classifiers.push(svc);
}
}
Ok(Self {
classifiers: Some(classifiers),
})
}
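The nested loop above enumerates every unordered pair of classes, so for `k` classes the one-vs-one strategy fits `k * (k - 1) / 2` binary classifiers. A minimal sketch of the pair enumeration (the `class_pairs` helper is illustrative, not part of smartcore):

```rust
// Enumerate the unordered class pairs that one-vs-one training iterates over.
// Mirrors the fit loop above: pairs are (unique_classes[j], unique_classes[i]).
fn class_pairs<T: Copy>(classes: &[T]) -> Vec<(T, T)> {
    let mut pairs = Vec::new();
    for i in 0..classes.len() {
        for j in (i + 1)..classes.len() {
            pairs.push((classes[j], classes[i]));
        }
    }
    pairs
}

fn main() {
    // 4 classes -> 4 * 3 / 2 = 6 binary classifiers.
    let pairs = class_pairs(&[0, 1, 2, 3]);
    assert_eq!(pairs.len(), 6);
    assert_eq!(pairs[0], (1, 0)); // first pair drawn from classes 0 and 1
}
```

This quadratic growth in the number of classifiers is the usual cost of one-vs-one compared to one-vs-rest, which fits only `k` models.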
/// Predicts the class labels for new data points using the trained multi-class SVM.
///
/// This method uses a "voting" scheme (majority vote) among all the binary
/// classifiers to determine the final prediction for each data point.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) for which to make predictions.
///
/// # Returns
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
///
pub fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
// Initialize a HashMap for each data point to store votes for each class
let mut polls = vec![HashMap::new(); x.shape().0];
// Retrieve the trained binary classifiers
let classifiers = self.classifiers.as_ref().unwrap();
// Iterate through each binary classifier
for i in 0..classifiers.len() {
let svc = classifiers.get(i).unwrap();
let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier
// For each prediction from the current binary classifier
for (j, prediction) in predictions.iter().enumerate() {
let prediction = prediction.to_i32().unwrap();
let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point
// Increment the vote for the predicted class
if let Some(count) = poll.get_mut(&prediction) {
*count += 1
} else {
poll.insert(prediction, 1);
}
}
}
// Determine the final prediction for each data point based on majority vote
Ok(polls
.iter()
.map(|v| {
// Find the class with the maximum votes for each data point
                TX::from(*v.iter().max_by_key(|(_, count)| *count).unwrap().0).unwrap()
})
.collect())
}
}
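The voting scheme in `predict` can be sketched without any smartcore types: each binary classifier contributes one vote per sample, and the class with the most votes wins. The `majority_vote` helper below is illustrative, not part of the library:

```rust
use std::collections::HashMap;

// `votes_per_classifier[k][j]` is the class that classifier k predicts
// for sample j; the result is the majority-vote class per sample.
fn majority_vote(votes_per_classifier: &[Vec<i32>]) -> Vec<i32> {
    let n = votes_per_classifier[0].len();
    let mut polls: Vec<HashMap<i32, usize>> = vec![HashMap::new(); n];
    for predictions in votes_per_classifier {
        for (j, &class) in predictions.iter().enumerate() {
            *polls[j].entry(class).or_insert(0) += 1;
        }
    }
    polls
        .iter()
        .map(|poll| *poll.iter().max_by_key(|(_, &count)| count).unwrap().0)
        .collect()
}

fn main() {
    // Three pairwise classifiers voting on two samples.
    let votes = vec![vec![0, 1], vec![0, 2], vec![1, 1]];
    // Sample 0 receives votes {0: 2, 1: 1} -> 0; sample 1 receives {1: 2, 2: 1} -> 1.
    assert_eq!(majority_vote(&votes), vec![0, 1]);
}
```

Note that ties are broken arbitrarily by `max_by_key`, which matches the behavior of the HashMap-based poll above.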
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// SVC Parameters
)]
/// Support Vector Classifier
pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
    classes: Option<(TY, TY)>,
    instances: Option<Vec<Vec<TX>>>,
    #[cfg_attr(feature = "serde", serde(skip))]
    parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
    x: &'a X,
    y: &'a Y,
    indices: Option<Vec<usize>>,
    parameters: &'a SVCParameters<TX, TY, X, Y>,
    classes: &'a (TY, TY),
    svmin: usize,
    svmax: usize,
    gmin: TX,
        self.tol = tol;
        self
    }

    /// The kernel function.
    pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
        self.kernel = Some(Box::new(kernel));
        self
    }

    /// Seed for the pseudo random number generator.
    pub fn with_seed(mut self, seed: Option<u64>) -> Self {
        self.seed = seed;
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
    SVC<'a, TX, TY, X, Y>
{
    /// Fits a binary Support Vector Classifier (SVC) to the provided data.
    ///
    /// This is the primary `fit` method for a standalone binary SVC. It expects
    /// the target labels `y` to contain exactly two unique classes. If more or
    /// fewer than two classes are found, it returns an error. It then extracts
    /// these two classes and proceeds to optimize and fit the SVC model.
    ///
    /// # Arguments
    /// * `x` - A reference to the input features (2D array) of the training data.
    /// * `y` - A reference to the target labels (1D array) of the training data. `y` must contain exactly two unique class labels.
    /// * `parameters` - A reference to the `SVCParameters` controlling the training process.
    ///
    /// # Returns
    /// A `Result` which is:
    /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
    /// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails.
    pub fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
    ) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
        let classes = y.unique();
        // Validate that there are exactly two unique classes in the target labels.
        if classes.len() != 2 {
            return Err(Failed::fit(&format!(
                "Incorrect number of classes: {}. A binary SVC requires exactly two classes.",
                classes.len()
            )));
        }
        let classes = (classes[0], classes[1]);
        Self::optimize_and_fit(x, y, parameters, classes, None)
    }
/// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios.
///
/// This function is intended to be called by a multi-class strategy (e.g., one-vs-one)
/// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies
/// the two classes this SVC should discriminate and the subset of data indices
/// relevant to these classes. It then delegates the actual optimization and fitting
/// to `optimize_and_fit`.
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) of the training data.
/// * `y` - A reference to the target labels (1D array) of the training data.
/// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance).
/// * `multiclass_config` - A `MultiClassConfig` struct containing:
/// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish.
    /// - `indices`: A `Vec<usize>` containing the indices of the data points in `x` and `y` that belong to either `class0` or `class1`.
///
/// # Returns
/// A `Result` which is:
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
/// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters).
fn multiclass_fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
multiclass_config: MultiClassConfig<TY>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let classes = multiclass_config.classes;
let indices = multiclass_config.indices;
        Self::optimize_and_fit(x, y, parameters, classes, Some(indices))
}
/// Internal function to optimize and fit the Support Vector Classifier.
///
/// This is the core logic for training a binary SVC. It performs several checks
/// (e.g., kernel presence, data shape consistency) and then initializes an
/// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`).
///
/// # Arguments
/// * `x` - A reference to the input features (2D array) of the training data.
/// * `y` - A reference to the target labels (1D array) of the training data.
/// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration.
/// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate.
/// * `indices` - An `Option<Vec<usize>>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered.
/// # Returns
/// A `Result` which is:
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias).
/// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails.
fn optimize_and_fit(
x: &'a X,
y: &'a Y,
parameters: &'a SVCParameters<TX, TY, X, Y>,
classes: (TY, TY),
indices: Option<Vec<usize>>,
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
let (n_samples, _) = x.shape();
// Validate that a kernel has been defined in the parameters.
        if parameters.kernel.is_none() {
            return Err(Failed::because(
                FailedError::ParametersError,
                // (error message elided in this excerpt)
            ));
        }
        // Validate that the number of samples in X matches the number of labels in Y.
        if n_samples != y.shape() {
            return Err(Failed::fit(
                "Number of rows of X doesn't match number of rows of Y",
            ));
        }

        let optimizer: Optimizer<'_, TX, TY, X, Y> =
            Optimizer::new(x, y, indices, parameters, &classes);

        // Perform the optimization to find the support vectors, weight vector, and bias.
        // This is where the core SVM algorithm (e.g., SMO) would run.
        let (support_vectors, weight, b) = optimizer.optimize();

        // Construct and return the fitted SVC model.
        Ok(SVC::<'a> {
            classes: Some(classes), // Store the two classes the SVC was trained on.
            instances: Some(support_vectors), // Store the data points that are support vectors.
            parameters: Some(parameters), // Reference to the parameters used for fitting.
            w: Some(weight), // The learned weight vector (for linear kernels).
            b: Some(b), // The learned bias term.
            phantomdata: PhantomData, // Placeholder for type parameters not directly stored.
        })
    }
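The fitted components `w`, `b`, and the support vectors are everything needed at prediction time: the decision value for a sample is `f(x) = b + sum_i w_i * K(sv_i, x)`, and the sign of `f` picks the class. A dependency-free sketch with a linear kernel (names are illustrative, not the smartcore API):

```rust
// Plain dot product, standing in for the configured kernel.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// f(x) = b + sum_i w[i] * K(support_vectors[i], x), here with K = dot.
fn decision_function(x: &[f64], support_vectors: &[Vec<f64>], w: &[f64], b: f64) -> f64 {
    let mut f = b;
    for (sv, wi) in support_vectors.iter().zip(w) {
        f += wi * dot(sv, x);
    }
    f
}

fn main() {
    // Two support vectors with weights +1 / -1 and bias 0.5.
    let svs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let f = decision_function(&[2.0, 3.0], &svs, &[1.0, -1.0], 0.5);
    // f = 0.5 + 1*2 - 1*3 = -0.5, so this sample falls on the negative side.
    assert!((f - (-0.5)).abs() < 1e-12);
}
```

In the library code this sum is evaluated per row in `predict_for_row`, with the kernel taken from `parameters.kernel`.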
    /// Predicts estimated class labels from `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
        let mut y_hat: Vec<TX> = self.decision_function(x)?;
        for i in 0..y_hat.len() {
            let cls_idx = match *y_hat.get(i) > TX::zero() {
                false => TX::from(self.classes.as_ref().unwrap().0).unwrap(),
                true => TX::from(self.classes.as_ref().unwrap().1).unwrap(),
            };
            y_hat.set(i, cls_idx);
        let (n, _) = x.shape();
        let mut y_hat: Vec<TX> = Array1::zeros(n);
-        let mut row = Vec::with_capacity(n);
        for i in 0..n {
-            row.clear();
-            row.extend(x.get_row(i).iterator(0).copied());
-            let row_pred: TX = self.predict_for_row(&row);
+            let row_pred: TX =
+                self.predict_for_row(Vec::from_iterator(x.get_row(i).iterator(0).copied(), n));
            y_hat.set(i, row_pred);
        }
        Ok(y_hat)
    }
-    fn predict_for_row(&self, x: &[TX]) -> TX {
+    fn predict_for_row(&self, x: Vec<TX>) -> TX {
        let mut f = self.b.unwrap();
-        let xi: Vec<_> = x.iter().map(|e| e.to_f64().unwrap()).collect();
        for i in 0..self.instances.as_ref().unwrap().len() {
-            let xj: Vec<_> = self.instances.as_ref().unwrap()[i]
-                .iter()
-                .map(|e| e.to_f64().unwrap())
-                .collect();
            f += self.w.as_ref().unwrap()[i]
                * TX::from(
                    self.parameters
@@ -600,7 +343,13 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
                        .kernel
                        .as_ref()
                        .unwrap()
-                        .apply(&xi, &xj)
+                        .apply(
+                            &x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                            &self.instances.as_ref().unwrap()[i]
+                                .iter()
+                                .map(|e| e.to_f64().unwrap())
+                                .collect(),
+                        )
                        .unwrap(),
                )
                .unwrap();
@@ -610,8 +359,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
        }
    }
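The loop above evaluates the SVC decision value f(x) = b + Σ_i w_i · K(x, sv_i) and classifies by its sign. A minimal std-only sketch with a plain dot-product kernel (names such as `decision_function` are illustrative, not smartcore's actual API):

```rust
/// Linear kernel: plain dot product (a stand-in for `Kernels::linear()`).
fn linear_kernel(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// SVC-style decision value: bias plus weighted kernel evaluations
/// against every support vector; the sign picks the class.
fn decision_function(x: &[f64], support_vectors: &[Vec<f64>], w: &[f64], b: f64) -> f64 {
    let mut f = b;
    for (sv, wi) in support_vectors.iter().zip(w) {
        f += wi * linear_kernel(x, sv);
    }
    f
}

fn main() {
    // Two support vectors with opposite weights.
    let svs = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let w = [1.0, -1.0];
    let f = decision_function(&[2.0, 0.0], &svs, &w, 0.0);
    assert!(f > 0.0); // lands on the positive side
}
```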
-impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
-    for SVC<'_, TX, TY, X, Y>
+impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for SVC<'a, TX, TY, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two()
@@ -695,18 +444,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
    fn new(
        x: &'a X,
        y: &'a Y,
-        indices: Option<Vec<usize>>,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
-        classes: &'a (TY, TY),
    ) -> Optimizer<'a, TX, TY, X, Y> {
        let (n, _) = x.shape();
        Optimizer {
            x,
            y,
-            indices,
            parameters,
-            classes,
            svmin: 0,
            svmax: 0,
            gmin: <TX as Bounded>::max_value(),
@@ -727,17 +472,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        let tol = self.parameters.tol;
        let good_enough = TX::from_i32(1000).unwrap();
-        let mut x = Vec::with_capacity(n);
        for _ in 0..self.parameters.epoch {
            for i in self.permutate(n) {
-                x.clear();
-                x.extend(self.x.get_row(i).iterator(0).take(n).copied());
-                let y = if *self.y.get(i) == self.classes.1 {
-                    1
-                } else {
-                    -1
-                } as f64;
-                self.process(i, &x, y, &mut cache);
+                self.process(
+                    i,
+                    Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
+                    *self.y.get(i),
+                    &mut cache,
+                );
                loop {
                    self.reprocess(tol, &mut cache);
                    self.find_min_max_gradient();
@@ -769,20 +511,25 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        let mut cp = 0;
        let mut cn = 0;
-        let mut x = Vec::with_capacity(n);
        for i in self.permutate(n) {
-            x.clear();
-            x.extend(self.x.get_row(i).iterator(0).take(n).copied());
-            let y = if *self.y.get(i) == self.classes.1 {
-                1
-            } else {
-                -1
-            } as f64;
-            if y == 1.0 && cp < few {
-                if self.process(i, &x, y, cache) {
+            if *self.y.get(i) == TY::one() && cp < few {
+                if self.process(
+                    i,
+                    Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
+                    *self.y.get(i),
+                    cache,
+                ) {
                    cp += 1;
                }
-            } else if y == -1.0 && cn < few && self.process(i, &x, y, cache) {
+            } else if *self.y.get(i) == TY::from(-1).unwrap()
+                && cn < few
+                && self.process(
+                    i,
+                    Vec::from_iterator(self.x.get_row(i).iterator(0).copied(), n),
+                    *self.y.get(i),
+                    cache,
+                )
+            {
                cn += 1;
            }
@@ -792,26 +539,27 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        }
    }
-    fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache<TX, TY, X, Y>) -> bool {
+    fn process(&mut self, i: usize, x: Vec<TX>, y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
        for j in 0..self.sv.len() {
            if self.sv[j].index == i {
                return true;
            }
        }

-        let mut g = y;
+        let mut g: f64 = y.to_f64().unwrap();

        let mut cache_values: Vec<((usize, usize), TX)> = Vec::new();

        for v in self.sv.iter() {
-            let xi: Vec<_> = v.x.iter().map(|e| e.to_f64().unwrap()).collect();
-            let xj: Vec<_> = x.iter().map(|e| e.to_f64().unwrap()).collect();
            let k = self
                .parameters
                .kernel
                .as_ref()
                .unwrap()
-                .apply(&xi, &xj)
+                .apply(
+                    &v.x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                    &x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                )
                .unwrap();
            cache_values.push(((i, v.index), TX::from(k).unwrap()));
            g -= v.alpha * k;
@@ -820,8 +568,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        self.find_min_max_gradient();

        if self.gmin < self.gmax
-            && ((y > 0.0 && g < self.gmin.to_f64().unwrap())
-                || (y < 0.0 && g > self.gmax.to_f64().unwrap()))
+            && ((y > TY::zero() && g < self.gmin.to_f64().unwrap())
+                || (y < TY::zero() && g > self.gmax.to_f64().unwrap()))
        {
            return false;
        }
@@ -830,7 +578,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
            cache.insert(v.0, v.1.to_f64().unwrap());
        }

-        let x_f64: Vec<_> = x.iter().map(|e| e.to_f64().unwrap()).collect();
+        let x_f64 = x.iter().map(|e| e.to_f64().unwrap()).collect();
        let k_v = self
            .parameters
            .kernel
@@ -851,7 +599,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
            ),
        );

-        if y > 0.0 {
+        if y > TY::zero() {
            self.smo(None, Some(0), TX::zero(), cache);
        } else {
            self.smo(Some(0), None, TX::zero(), cache);
@@ -908,6 +656,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        let gmin = self.gmin;
        let mut idxs_to_drop: HashSet<usize> = HashSet::new();
        self.sv.retain(|v| {
            if v.alpha == 0f64
                && ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap())
@@ -926,11 +675,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
    fn permutate(&self, n: usize) -> Vec<usize> {
        let mut rng = get_rng_impl(self.parameters.seed);
-        let mut range = if let Some(indices) = self.indices.clone() {
-            indices
-        } else {
-            (0..n).collect::<Vec<usize>>()
-        };
+        let mut range: Vec<usize> = (0..n).collect();
        range.shuffle(&mut rng);
        range
    }
@@ -956,10 +701,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        let km = sv1.k;
        let gm = sv1.grad;
        let mut best = 0f64;
-        let xi: Vec<_> = sv1.x.iter().map(|e| e.to_f64().unwrap()).collect();
        for i in 0..self.sv.len() {
            let v = &self.sv[i];
-            let xj: Vec<_> = v.x.iter().map(|e| e.to_f64().unwrap()).collect();
            let z = v.grad - gm;
            let k = cache.get(
                sv1,
@@ -968,7 +711,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .kernel
                    .as_ref()
                    .unwrap()
-                    .apply(&xi, &xj)
+                    .apply(
+                        &sv1.x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                        &v.x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                    )
                    .unwrap(),
            );
            let mut curv = km + v.k - 2f64 * k;
@@ -986,12 +732,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
            }
        }

-        let xi: Vec<_> = self.sv[idx_1]
-            .x
-            .iter()
-            .map(|e| e.to_f64().unwrap())
-            .collect::<Vec<_>>();
        idx_2.map(|idx_2| {
            (
                idx_1,
@@ -1002,12 +742,16 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .as_ref()
                    .unwrap()
                    .apply(
-                        &xi,
+                        &self.sv[idx_1]
+                            .x
+                            .iter()
+                            .map(|e| e.to_f64().unwrap())
+                            .collect(),
                        &self.sv[idx_2]
                            .x
                            .iter()
                            .map(|e| e.to_f64().unwrap())
-                            .collect::<Vec<_>>(),
+                            .collect(),
                    )
                    .unwrap()
            }),
@@ -1021,11 +765,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        let km = sv2.k;
        let gm = sv2.grad;
        let mut best = 0f64;
-        let xi: Vec<_> = sv2.x.iter().map(|e| e.to_f64().unwrap()).collect();
        for i in 0..self.sv.len() {
            let v = &self.sv[i];
-            let xj: Vec<_> = v.x.iter().map(|e| e.to_f64().unwrap()).collect();
            let z = gm - v.grad;
            let k = cache.get(
                sv2,
@@ -1034,7 +775,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .kernel
                    .as_ref()
                    .unwrap()
-                    .apply(&xi, &xj)
+                    .apply(
+                        &sv2.x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                        &v.x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                    )
                    .unwrap(),
            );
            let mut curv = km + v.k - 2f64 * k;
@@ -1053,12 +797,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
            }
        }

-        let xj: Vec<_> = self.sv[idx_2]
-            .x
-            .iter()
-            .map(|e| e.to_f64().unwrap())
-            .collect();
        idx_1.map(|idx_1| {
            (
                idx_1,
@@ -1073,8 +811,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .x
                    .iter()
                    .map(|e| e.to_f64().unwrap())
-                    .collect::<Vec<_>>(),
-                &xj,
+                    .collect(),
+                &self.sv[idx_2]
+                    .x
+                    .iter()
+                    .map(|e| e.to_f64().unwrap())
+                    .collect(),
            )
            .unwrap()
        }),
@@ -1093,12 +835,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .x
                    .iter()
                    .map(|e| e.to_f64().unwrap())
-                    .collect::<Vec<_>>(),
+                    .collect(),
                &self.sv[idx_2]
                    .x
                    .iter()
                    .map(|e| e.to_f64().unwrap())
-                    .collect::<Vec<_>>(),
+                    .collect(),
            )
            .unwrap(),
        )),
@@ -1153,10 +895,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
        self.sv[v1].alpha -= step.to_f64().unwrap();
        self.sv[v2].alpha += step.to_f64().unwrap();

-        let xi_v1: Vec<_> = self.sv[v1].x.iter().map(|e| e.to_f64().unwrap()).collect();
-        let xi_v2: Vec<_> = self.sv[v2].x.iter().map(|e| e.to_f64().unwrap()).collect();
        for i in 0..self.sv.len() {
-            let xj: Vec<_> = self.sv[i].x.iter().map(|e| e.to_f64().unwrap()).collect();
            let k2 = cache.get(
                &self.sv[v2],
                &self.sv[i],
@@ -1164,7 +903,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .kernel
                    .as_ref()
                    .unwrap()
-                    .apply(&xi_v2, &xj)
+                    .apply(
+                        &self.sv[v2].x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                        &self.sv[i].x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                    )
                    .unwrap(),
            );
            let k1 = cache.get(
@@ -1174,7 +916,10 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
                    .kernel
                    .as_ref()
                    .unwrap()
-                    .apply(&xi_v1, &xj)
+                    .apply(
+                        &self.sv[v1].x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                        &self.sv[i].x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                    )
                    .unwrap(),
            );
            self.sv[i].grad -= step.to_f64().unwrap() * (k2 - k1);
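The step above is the heart of the pairwise SMO update: weight `step` moves from one support vector to another, and every cached gradient shifts by `step * (K(v2, i) - K(v1, i))`. A minimal sketch with a precomputed kernel matrix (`smo_step` is an illustrative simplification, not the crate's `smo`):

```rust
/// Move `step` of dual weight from support vector `v1` to `v2`
/// (alpha[v1] -= step, alpha[v2] += step) and keep the gradients
/// consistent: grad[i] -= step * (K(v2, i) - K(v1, i)).
fn smo_step(alpha: &mut [f64], grad: &mut [f64], k: &[Vec<f64>], v1: usize, v2: usize, step: f64) {
    alpha[v1] -= step;
    alpha[v2] += step;
    for i in 0..grad.len() {
        grad[i] -= step * (k[v2][i] - k[v1][i]);
    }
}

fn main() {
    // Toy 2x2 kernel matrix (identity: each point only "sees" itself).
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let mut alpha = [0.0, 0.0];
    let mut grad = [1.0, -1.0];
    smo_step(&mut alpha, &mut grad, &k, 0, 1, 0.5);
    assert_eq!(alpha, [-0.5, 0.5]);
    assert_eq!(grad, [1.5, -1.5]);
}
```

Updating all gradients incrementally avoids re-deriving them from scratch after each pair update, which is why the cache of kernel values matters for performance.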
@@ -1221,25 +966,28 @@ mod tests {
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<i32> = vec![
            -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];

        let knl = Kernels::linear();
-        let parameters = SVCParameters::default()
+        let params = SVCParameters::default()
            .with_c(200.0)
            .with_kernel(knl)
            .with_seed(Some(100));
-        let y_hat = SVC::fit(&x, &y, &parameters)
+        let y_hat = SVC::fit(&x, &y, &params)
            .and_then(|lr| lr.predict(&x))
            .unwrap();
        let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
-        assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9");
+        assert!(
+            acc >= 0.9,
+            "accuracy ({}) is not larger or equal to 0.9",
+            acc
+        );
    }
    #[cfg_attr(
@@ -1248,8 +996,7 @@ mod tests {
    )]
    #[test]
    fn svc_fit_decision_function() {
-        let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]])
-            .unwrap();
+        let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);

        let x2 = DenseMatrix::from_2d_array(&[
            &[3.0, 3.0],
@@ -1258,8 +1005,7 @@ mod tests {
            &[10.0, 10.0],
            &[1.0, 1.0],
            &[0.0, 0.0],
-        ])
-        .unwrap();
+        ]);

        let y: Vec<i32> = vec![-1, -1, 1, 1];
@@ -1312,8 +1058,7 @@ mod tests {
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<i32> = vec![
            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
@@ -1331,56 +1076,10 @@ mod tests {
        let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
-        assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9");
-    }
-    #[cfg_attr(
-        all(target_arch = "wasm32", not(target_os = "wasi")),
-        wasm_bindgen_test::wasm_bindgen_test
-    )]
-    #[test]
-    fn svc_multiclass_fit_predict() {
-        let x = DenseMatrix::from_2d_array(&[
-            &[5.1, 3.5, 1.4, 0.2],
-            &[4.9, 3.0, 1.4, 0.2],
-            &[4.7, 3.2, 1.3, 0.2],
-            &[4.6, 3.1, 1.5, 0.2],
-            &[5.0, 3.6, 1.4, 0.2],
-            &[5.4, 3.9, 1.7, 0.4],
-            &[4.6, 3.4, 1.4, 0.3],
-            &[5.0, 3.4, 1.5, 0.2],
-            &[4.4, 2.9, 1.4, 0.2],
-            &[4.9, 3.1, 1.5, 0.1],
-            &[7.0, 3.2, 4.7, 1.4],
-            &[6.4, 3.2, 4.5, 1.5],
-            &[6.9, 3.1, 4.9, 1.5],
-            &[5.5, 2.3, 4.0, 1.3],
-            &[6.5, 2.8, 4.6, 1.5],
-            &[5.7, 2.8, 4.5, 1.3],
-            &[6.3, 3.3, 4.7, 1.6],
-            &[4.9, 2.4, 3.3, 1.0],
-            &[6.6, 2.9, 4.6, 1.3],
-            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
-        let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2];
-        let knl = Kernels::linear();
-        let parameters = SVCParameters::default()
-            .with_c(200.0)
-            .with_kernel(knl)
-            .with_seed(Some(100));
-        let y_hat = MultiClassSVC::fit(&x, &y, &parameters)
-            .and_then(|lr| lr.predict(&x))
-            .unwrap();
-        let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
-        assert!(
-            acc >= 0.9,
-            "Multiclass accuracy ({acc}) is not larger or equal to 0.9"
-        );
+        assert!(
+            acc >= 0.9,
+            "accuracy ({}) is not larger or equal to 0.9",
+            acc
+        );
    }
@@ -1412,19 +1111,18 @@ mod tests {
            &[4.9, 2.4, 3.3, 1.0],
            &[6.6, 2.9, 4.6, 1.3],
            &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<i32> = vec![
            -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];

        let knl = Kernels::linear();
-        let parameters = SVCParameters::default().with_kernel(knl);
-        let svc = SVC::fit(&x, &y, &parameters).unwrap();
+        let params = SVCParameters::default().with_kernel(knl);
+        let svc = SVC::fit(&x, &y, &params).unwrap();
        // serialization
-        let deserialized_svc: SVC<'_, f64, i32, _, _> =
+        let deserialized_svc: SVC<f64, i32, _, _> =
            serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
        assert_eq!(svc, deserialized_svc);
@@ -44,16 +44,16 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-//! ]).unwrap();
+//! ]);
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//!                100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
//!
//! let knl = Kernels::linear();
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
-//! let svr = SVR::fit(&x, &y, params).unwrap();
+//! // let svr = SVR::fit(&x, &y, params).unwrap();
//!
-//! let y_hat = svr.predict(&x).unwrap();
+//! // let y_hat = svr.predict(&x).unwrap();
//! ```
//! //!
//! ## References: //! ## References:
@@ -80,12 +80,11 @@ use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
-use crate::svm::{Kernel, Kernels};
+use crate::svm::Kernel;

+/// SVR Parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
-/// SVR Parameters
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
    /// Epsilon in the epsilon-SVR model.
    pub eps: T,
@@ -98,7 +97,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
        all(feature = "serde", target_arch = "wasm32"),
        serde(skip_serializing, skip_deserializing)
    )]
-    pub kernel: Option<Kernels>,
+    pub kernel: Option<Box<dyn Kernel>>,
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -161,8 +160,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
        self
    }

    /// The kernel function.
-    pub fn with_kernel(mut self, kernel: Kernels) -> Self {
-        self.kernel = Some(kernel);
+    pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
+        self.kernel = Some(Box::new(kernel));
        self
    }
}
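The `eps` parameter above defines the "insensitive tube" of epsilon-SVR: residuals smaller than `eps` incur no loss, and larger residuals are penalized linearly. A one-function sketch of that loss (the function name is illustrative, not part of the crate):

```rust
/// Epsilon-insensitive loss used by epsilon-SVR:
/// zero inside the tube |y - f| <= eps, linear outside it.
fn eps_insensitive_loss(y: f64, f: f64, eps: f64) -> f64 {
    ((y - f).abs() - eps).max(0.0)
}

fn main() {
    assert_eq!(eps_insensitive_loss(1.0, 1.5, 2.0), 0.0); // inside the tube: free
    assert_eq!(eps_insensitive_loss(1.0, 5.0, 2.0), 2.0); // |1 - 5| - 2 = 2
}
```

A wider tube (larger `eps`) yields a flatter model with fewer support vectors; `c` then trades off flatness against the remaining violations.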
@@ -249,20 +248,19 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
        let mut y_hat: Vec<T> = Vec::<T>::zeros(n);
-        let mut x_i = Vec::with_capacity(n);
        for i in 0..n {
-            x_i.clear();
-            x_i.extend(x.get_row(i).iterator(0).copied());
-            y_hat.set(i, self.predict_for_row(&x_i));
+            y_hat.set(
+                i,
+                self.predict_for_row(Vec::from_iterator(x.get_row(i).iterator(0).copied(), n)),
+            );
        }
        Ok(y_hat)
    }

-    pub(crate) fn predict_for_row(&self, x: &[T]) -> T {
+    pub(crate) fn predict_for_row(&self, x: Vec<T>) -> T {
        let mut f = self.b;
-        let xi: Vec<_> = x.iter().map(|e| e.to_f64().unwrap()).collect();
        for i in 0..self.instances.as_ref().unwrap().len() {
            f += self.w.as_ref().unwrap()[i]
                * T::from(
@@ -272,7 +270,10 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
                    .kernel
                    .as_ref()
                    .unwrap()
-                    .apply(&xi, &self.instances.as_ref().unwrap()[i])
+                    .apply(
+                        &x.iter().map(|e| e.to_f64().unwrap()).collect(),
+                        &self.instances.as_ref().unwrap()[i],
+                    )
                    .unwrap(),
                )
                .unwrap()
@@ -282,8 +283,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
    }
}
-impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
-    for SVR<'_, T, X, Y>
+impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
+    for SVR<'a, T, X, Y>
{
    fn eq(&self, other: &Self) -> bool {
        if (self.b - other.b).abs() > T::epsilon() * T::two()
@@ -598,25 +599,25 @@ mod tests {
    use super::*;
    use crate::linalg::basic::matrix::DenseMatrix;
    use crate::metrics::mean_squared_error;
-    use crate::svm::search::svr_params::SVRSearchParameters;
    use crate::svm::Kernels;

-    #[test]
-    fn search_parameters() {
-        let parameters: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
-            eps: vec![0., 1.],
-            kernel: vec![Kernels::linear()],
-            ..Default::default()
-        };
-        let mut iter = parameters.into_iter();
-        let next = iter.next().unwrap();
-        assert_eq!(next.eps, 0.);
-        // assert_eq!(next.kernel, LinearKernel {});
-        // let next = iter.next().unwrap();
-        // assert_eq!(next.eps, 1.);
-        // assert_eq!(next.kernel, LinearKernel {});
-        // assert!(iter.next().is_none());
-    }
+    // #[test]
+    // fn search_parameters() {
+    //     let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
+    //         SVRSearchParameters {
+    //             eps: vec![0., 1.],
+    //             kernel: vec![LinearKernel {}],
+    //             ..Default::default()
+    //         };
+    //     let mut iter = parameters.into_iter();
+    //     let next = iter.next().unwrap();
+    //     assert_eq!(next.eps, 0.);
+    //     assert_eq!(next.kernel, LinearKernel {});
+    //     let next = iter.next().unwrap();
+    //     assert_eq!(next.eps, 1.);
+    //     assert_eq!(next.kernel, LinearKernel {});
+    //     assert!(iter.next().is_none());
+    // }
#[cfg_attr( #[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")), all(target_arch = "wasm32", not(target_os = "wasi")),
@@ -641,15 +642,14 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
        ];

-        let knl: Kernels = Kernels::linear();
+        let knl = Kernels::linear();
        let y_hat = SVR::fit(
            &x,
            &y,
@@ -662,7 +662,7 @@ mod tests {
        .unwrap();

        let t = mean_squared_error(&y_hat, &y);
-        println!("{t:?}");
+        println!("{:?}", t);
        assert!(t < 2.5);
    }
@@ -690,8 +690,7 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -703,7 +702,7 @@ mod tests {
        let svr = SVR::fit(&x, &y, &params).unwrap();

-        let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> =
+        let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
            serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();

        assert_eq!(svr, deserialized_svr);
@@ -1,551 +0,0 @@
use std::collections::LinkedList;
use std::default::Default;
use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum Splitter {
Random,
#[default]
Best,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Regression base_tree
pub struct BaseTreeRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum depth of the base_tree.
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node.
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node.
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the randomness of the estimator
pub seed: Option<u64>,
#[cfg_attr(feature = "serde", serde(default))]
/// Determines the strategy used to choose the split at each node.
pub splitter: Splitter,
}
/// Regression base_tree
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
nodes: Vec<Node>,
parameters: Option<BaseTreeRegressorParameters>,
depth: u16,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseTreeRegressor<TX, TY, X, Y>
{
/// Get nodes, return a shared reference
fn nodes(&self) -> &Vec<Node> {
self.nodes.as_ref()
}
/// Get parameters, return a shared reference
fn parameters(&self) -> &BaseTreeRegressorParameters {
self.parameters.as_ref().unwrap()
}
/// Get estimate of intercept, return value
fn depth(&self) -> u16 {
self.depth
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
output: f64,
split_feature: usize,
split_value: Option<f64>,
split_score: Option<f64>,
true_child: Option<usize>,
false_child: Option<usize>,
}
impl Node {
fn new(output: f64) -> Self {
Node {
output,
split_feature: 0,
split_value: Option::None,
split_score: Option::None,
true_child: Option::None,
false_child: Option::None,
}
}
}
impl PartialEq for Node {
fn eq(&self, other: &Self) -> bool {
(self.output - other.output).abs() < f64::EPSILON
&& self.split_feature == other.split_feature
&& match (self.split_value, other.split_value) {
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
(None, None) => true,
_ => false,
}
&& match (self.split_score, other.split_score) {
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
(None, None) => true,
_ => false,
}
}
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for BaseTreeRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
false
} else {
self.nodes()
.iter()
.zip(other.nodes().iter())
.all(|(a, b)| a == b)
}
}
}
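Both `Node::eq` and `BaseTreeRegressor::eq` above avoid `==` on floats and instead compare within `f64::EPSILON`, treating two `None`s as equal. A std-only sketch of that pattern (the helper name is illustrative):

```rust
/// Option-aware approximate equality, mirroring how `Node::eq`
/// compares `split_value`/`split_score`: equal when both are None,
/// or both are Some and differ by less than EPSILON.
fn approx_eq_opt(a: Option<f64>, b: Option<f64>) -> bool {
    match (a, b) {
        (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
        (None, None) => true,
        _ => false,
    }
}

fn main() {
    assert!(approx_eq_opt(None, None));
    assert!(approx_eq_opt(Some(0.1 + 0.2), Some(0.3))); // true despite rounding
    assert!(!approx_eq_opt(Some(1.0), None));
}
```

Note that a fixed `EPSILON` tolerance is only appropriate for values near 1.0; values far from 1.0 would need a relative tolerance, which is fine here because node outputs and split scores are compared against near-identical copies.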
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
x: &'a X,
y: &'a Y,
node: usize,
samples: Vec<usize>,
order: &'a [Vec<usize>],
true_child_output: f64,
false_child_output: f64,
level: u16,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
}
impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
NodeVisitor<'a, TX, TY, X, Y>
{
fn new(
node_id: usize,
samples: Vec<usize>,
order: &'a [Vec<usize>],
x: &'a X,
y: &'a Y,
level: u16,
) -> Self {
NodeVisitor {
x,
y,
node: node_id,
samples,
order,
true_child_output: 0f64,
false_child_output: 0f64,
level,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
}
}
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseTreeRegressor<TX, TY, X, Y>
{
/// Build a decision base_tree regressor from the training data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseTreeRegressorParameters,
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}
let samples = vec![1; x_nrows];
BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
pub(crate) fn fit_weak_learner(
x: &X,
y: &Y,
samples: Vec<usize>,
mtry: usize,
parameters: BaseTreeRegressorParameters,
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
let y_m = y.clone();
let y_ncols = y_m.shape();
let (_, num_attributes) = x.shape();
let mut nodes: Vec<Node> = Vec::new();
let mut rng = get_rng_impl(parameters.seed);
let mut n = 0;
let mut sum = 0f64;
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
n += *sample_i;
sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
}
let root = Node::new(sum / (n as f64));
nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new();
for i in 0..num_attributes {
let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
order.push(col_i.argsort_mut());
}
let mut base_tree = BaseTreeRegressor {
nodes,
parameters: Some(parameters),
depth: 0u16,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
};
let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
visitor_queue.push_back(visitor);
}
while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
match visitor_queue.pop_front() {
Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
None => break,
};
}
Ok(base_tree)
}
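`fit_weak_learner` above pre-sorts every feature column once via `argsort_mut`, so later split searches can sweep samples in value order without re-sorting. A std-only sketch of the argsort idea (the standalone `argsort` function is a hypothetical stand-in for the crate's `argsort_mut`):

```rust
/// Indices that would sort `values` in ascending order.
fn argsort(values: &[f64]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    // partial_cmp is safe here as long as the column contains no NaNs.
    idx.sort_by(|&a, &b| values[a].partial_cmp(&values[b]).unwrap());
    idx
}

fn main() {
    let column = [3.0, 1.0, 2.0];
    // values sorted ascending are 1.0 (idx 1), 2.0 (idx 2), 3.0 (idx 0)
    assert_eq!(argsort(&column), vec![1, 2, 0]);
}
```

Computing these orderings once per feature costs O(M · N log N) up front but makes each candidate-split scan linear, which dominates when the tree has many nodes.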
    /// Predict regression value for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let mut result = Y::zeros(x.shape().0);
        let (n, _) = x.shape();
        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }
        Ok(result)
    }

    pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let mut result = 0f64;
        let mut queue: LinkedList<usize> = LinkedList::new();
        queue.push_back(0);

        while !queue.is_empty() {
            match queue.pop_front() {
                Some(node_id) => {
                    let node = &self.nodes()[node_id];
                    if node.true_child.is_none() && node.false_child.is_none() {
                        result = node.output;
                    } else if x.get((row, node.split_feature)).to_f64().unwrap()
                        <= node.split_value.unwrap_or(f64::NAN)
                    {
                        queue.push_back(node.true_child.unwrap());
                    } else {
                        queue.push_back(node.false_child.unwrap());
                    }
                }
                None => break,
            };
        }

        TY::from_f64(result).unwrap()
    }
    fn find_best_cutoff(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        mtry: usize,
        rng: &mut impl Rng,
    ) -> bool {
        let (_, n_attr) = visitor.x.shape();
        let n: usize = visitor.samples.iter().sum();

        if n < self.parameters().min_samples_split {
            return false;
        }

        let sum = self.nodes()[visitor.node].output * n as f64;
        let mut variables = (0..n_attr).collect::<Vec<_>>();
        if mtry < n_attr {
            variables.shuffle(rng);
        }

        let parent_gain =
            n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
        let splitter = self.parameters().splitter.clone();

        for variable in variables.iter().take(mtry) {
            match splitter {
                Splitter::Random => {
                    self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
                }
                Splitter::Best => {
                    self.find_best_split(visitor, n, sum, parent_gain, *variable);
                }
            }
        }

        self.nodes()[visitor.node].split_score.is_some()
    }
    fn find_random_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
        rng: &mut impl Rng,
    ) {
        let (min_val, max_val) = {
            let mut min_opt = None;
            let mut max_opt = None;

            for &i in &visitor.order[j] {
                if visitor.samples[i] > 0 {
                    min_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }
            for &i in visitor.order[j].iter().rev() {
                if visitor.samples[i] > 0 {
                    max_opt = Some(*visitor.x.get((i, j)));
                    break;
                }
            }

            if min_opt.is_none() {
                return;
            }
            (min_opt.unwrap(), max_opt.unwrap())
        };

        if min_val >= max_val {
            return;
        }

        let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());

        let mut true_sum = 0f64;
        let mut true_count = 0;
        for &i in &visitor.order[j] {
            if visitor.samples[i] > 0 {
                if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
                    true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
                    true_count += visitor.samples[i];
                } else {
                    break;
                }
            }
        }

        let false_count = n - true_count;
        if true_count < self.parameters().min_samples_leaf
            || false_count < self.parameters().min_samples_leaf
        {
            return;
        }

        let true_mean = if true_count > 0 {
            true_sum / true_count as f64
        } else {
            0.0
        };
        let false_mean = if false_count > 0 {
            (sum - true_sum) / false_count as f64
        } else {
            0.0
        };

        let gain = (true_count as f64 * true_mean * true_mean
            + false_count as f64 * false_mean * false_mean)
            - parent_gain;

        if self.nodes[visitor.node].split_score.is_none()
            || gain > self.nodes[visitor.node].split_score.unwrap()
        {
            self.nodes[visitor.node].split_feature = j;
            self.nodes[visitor.node].split_value = Some(split_value);
            self.nodes[visitor.node].split_score = Some(gain);
            visitor.true_child_output = true_mean;
            visitor.false_child_output = false_mean;
        }
    }
    fn find_best_split(
        &mut self,
        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
        n: usize,
        sum: f64,
        parent_gain: f64,
        j: usize,
    ) {
        let mut true_sum = 0f64;
        let mut true_count = 0;
        let mut prevx = Option::None;

        for i in visitor.order[j].iter() {
            if visitor.samples[*i] > 0 {
                let x_ij = *visitor.x.get((*i, j));

                if prevx.is_none() || x_ij == prevx.unwrap() {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let false_count = n - true_count;

                if true_count < self.parameters().min_samples_leaf
                    || false_count < self.parameters().min_samples_leaf
                {
                    prevx = Some(x_ij);
                    true_count += visitor.samples[*i];
                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                    continue;
                }

                let true_mean = true_sum / true_count as f64;
                let false_mean = (sum - true_sum) / false_count as f64;

                let gain = (true_count as f64 * true_mean * true_mean
                    + false_count as f64 * false_mean * false_mean)
                    - parent_gain;

                if self.nodes()[visitor.node].split_score.is_none()
                    || gain > self.nodes()[visitor.node].split_score.unwrap()
                {
                    self.nodes[visitor.node].split_feature = j;
                    self.nodes[visitor.node].split_value =
                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
                    self.nodes[visitor.node].split_score = Option::Some(gain);
                    visitor.true_child_output = true_mean;
                    visitor.false_child_output = false_mean;
                }

                prevx = Some(x_ij);
                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                true_count += visitor.samples[*i];
            }
        }
    }
    fn split<'a>(
        &mut self,
        mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
        mtry: usize,
        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
        rng: &mut impl Rng,
    ) -> bool {
        let (n, _) = visitor.x.shape();
        let mut tc = 0;
        let mut fc = 0;
        let mut true_samples: Vec<usize> = vec![0; n];

        for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
            if visitor.samples[i] > 0 {
                if visitor
                    .x
                    .get((i, self.nodes()[visitor.node].split_feature))
                    .to_f64()
                    .unwrap()
                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
                {
                    *true_sample = visitor.samples[i];
                    tc += *true_sample;
                    visitor.samples[i] = 0;
                } else {
                    fc += visitor.samples[i];
                }
            }
        }

        if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
            self.nodes[visitor.node].split_feature = 0;
            self.nodes[visitor.node].split_value = Option::None;
            self.nodes[visitor.node].split_score = Option::None;
            return false;
        }

        let true_child_idx = self.nodes().len();
        self.nodes.push(Node::new(visitor.true_child_output));

        let false_child_idx = self.nodes().len();
        self.nodes.push(Node::new(visitor.false_child_output));

        self.nodes[visitor.node].true_child = Some(true_child_idx);
        self.nodes[visitor.node].false_child = Some(false_child_idx);
        self.depth = u16::max(self.depth, visitor.level + 1);

        let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            true_child_idx,
            true_samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
            visitor_queue.push_back(true_visitor);
        }

        let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
            false_child_idx,
            visitor.samples,
            visitor.order,
            visitor.x,
            visitor.y,
            visitor.level + 1,
        );

        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
            visitor_queue.push_back(false_visitor);
        }

        true
    }
}
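Both `find_random_split` and `find_best_split` above score a candidate cutoff with the same quantity: `tc * true_mean^2 + fc * false_mean^2 - n * parent_mean^2`, the between-group sum of squares (equivalently, the total reduction in squared error from replacing one leaf mean with two). The criterion can be factored out as a minimal standalone sketch; `split_gain` and its arguments are illustrative names, not part of the crate's API:

```rust
/// Gain of splitting `n` weighted samples whose targets add up to `sum`
/// into a "true" branch of `true_count` samples with target sum `true_sum`.
/// Mirrors the gain formula in `find_best_split` above.
fn split_gain(n: usize, sum: f64, true_count: usize, true_sum: f64) -> f64 {
    let false_count = n - true_count;
    let true_mean = if true_count > 0 { true_sum / true_count as f64 } else { 0.0 };
    let false_mean = if false_count > 0 { (sum - true_sum) / false_count as f64 } else { 0.0 };
    // parent_gain = n * parent_mean^2, as computed in `find_best_cutoff`
    let parent_mean = sum / n as f64;
    let parent_gain = n as f64 * parent_mean * parent_mean;
    (true_count as f64 * true_mean * true_mean
        + false_count as f64 * false_mean * false_mean)
        - parent_gain
}

fn main() {
    // Perfect split of targets [1, 1, 5, 5]: branch means 1 and 5, parent mean 3.
    // gain = 2*1 + 2*25 - 4*9 = 16
    let gain = split_gain(4, 12.0, 2, 2.0);
    assert!((gain - 16.0).abs() < 1e-12);

    // An unbalanced split of the same data scores strictly lower.
    assert!(split_gain(4, 12.0, 1, 1.0) < gain);
    println!("{gain}");
}
```

Because `parent_gain` is constant for a node, maximizing this gain over cutoffs is the same as maximizing `tc * true_mean^2 + fc * false_mean^2`, which is what the single pass over `visitor.order[j]` does incrementally.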
+51 -235
@@ -48,7 +48,7 @@
 //! &[4.9, 2.4, 3.3, 1.0],
 //! &[6.6, 2.9, 4.6, 1.3],
 //! &[5.2, 2.7, 3.9, 1.4],
-//! ]).unwrap();
+//! ]);
 //! let y = vec![ 0, 0, 0, 0, 0, 0, 0, 0,
 //!     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
 //!
@@ -77,9 +77,7 @@ use serde::{Deserialize, Serialize};
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
-use crate::linalg::basic::arrays::MutArray;
 use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
-use crate::linalg::basic::matrix::DenseMatrix;
 use crate::numbers::basenum::Number;
 use crate::rand_custom::get_rng_impl;
@@ -118,7 +116,6 @@ pub struct DecisionTreeClassifier<
     num_classes: usize,
     classes: Vec<TY>,
     depth: u16,
-    num_features: usize,
     _phantom_tx: PhantomData<TX>,
     _phantom_x: PhantomData<X>,
     _phantom_y: PhantomData<Y>,
@@ -140,17 +137,16 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         self.classes.as_ref()
     }
     /// Get depth of tree
-    pub fn depth(&self) -> u16 {
+    fn depth(&self) -> u16 {
         self.depth
     }
 }

 /// The function to measure the quality of a split.
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone)]
 pub enum SplitCriterion {
     /// [Gini index](../decision_tree_classifier/index.html)
-    #[default]
     Gini,
     /// [Entropy](../decision_tree_classifier/index.html)
     Entropy,
@@ -158,17 +154,21 @@ pub enum SplitCriterion {
     ClassificationError,
 }

+impl Default for SplitCriterion {
+    fn default() -> Self {
+        SplitCriterion::Gini
+    }
+}
+
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 struct Node {
     output: usize,
-    n_node_samples: usize,
     split_feature: usize,
     split_value: Option<f64>,
     split_score: Option<f64>,
     true_child: Option<usize>,
     false_child: Option<usize>,
-    impurity: Option<f64>,
 }

 impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
@@ -199,12 +199,12 @@ impl PartialEq for Node {
         self.output == other.output
             && self.split_feature == other.split_feature
             && match (self.split_value, other.split_value) {
-                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
+                (Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
                 (None, None) => true,
                 _ => false,
             }
             && match (self.split_score, other.split_score) {
-                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
+                (Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
                 (None, None) => true,
                 _ => false,
             }
@@ -405,16 +405,14 @@ impl Default for DecisionTreeClassifierSearchParameters {
 }

 impl Node {
-    fn new(output: usize, n_node_samples: usize) -> Self {
+    fn new(output: usize) -> Self {
         Node {
             output,
-            n_node_samples,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None,
-            impurity: Option::None,
         }
     }
 }
@@ -514,7 +512,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: 0usize,
             classes: vec![],
             depth: 0u16,
-            num_features: 0usize,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -546,10 +543,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         parameters: DecisionTreeClassifierParameters,
     ) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
         let (x_nrows, num_attributes) = x.shape();
-        if x_nrows != y.shape() {
-            return Err(Failed::fit("Size of x should equal size of y"));
-        }
         let samples = vec![1; x_nrows];
         DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
@@ -567,7 +560,8 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         let k = classes.len();
         if k < 2 {
             return Err(Failed::fit(&format!(
-                "Incorrect number of classes: {k}. Should be >= 2."
+                "Incorrect number of classes: {}. Should be >= 2.",
+                k
             )));
         }
@@ -586,7 +580,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             count[yi[i]] += samples[i];
         }

-        let root = Node::new(which_max(&count), y_ncols);
+        let root = Node::new(which_max(&count));
         change_nodes.push(root);

         let mut order: Vec<Vec<usize>> = Vec::new();
@@ -601,7 +595,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             num_classes: k,
             classes,
             depth: 0u16,
-            num_features: num_attributes,
             _phantom_tx: PhantomData,
             _phantom_x: PhantomData,
             _phantom_y: PhantomData,
@@ -615,7 +608,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             visitor_queue.push_back(visitor);
         }

-        while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
+        while tree.depth() < tree.parameters().max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
                 Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
                 None => break,
@@ -652,7 +645,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
                     if node.true_child.is_none() && node.false_child.is_none() {
                         result = node.output;
                     } else if x.get((row, node.split_feature)).to_f64().unwrap()
-                        <= node.split_value.unwrap_or(f64::NAN)
+                        <= node.split_value.unwrap_or(std::f64::NAN)
                     {
                         queue.push_back(node.true_child.unwrap());
                     } else {
@@ -687,7 +680,16 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }

+        if is_pure {
+            return false;
+        }
+
         let n = visitor.samples.iter().sum();
+
+        if n <= self.parameters().min_samples_split {
+            return false;
+        }
+
         let mut count = vec![0; self.num_classes];
         let mut false_count = vec![0; self.num_classes];
         for i in 0..n_rows {
@@ -696,15 +698,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             }
         }

-        self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
-
-        if is_pure {
-            return false;
-        }
-
-        if n <= self.parameters().min_samples_split {
-            return false;
-        }
+        let parent_impurity = impurity(&self.parameters().criterion, &count, n);

         let mut variables = (0..n_attr).collect::<Vec<_>>();
@@ -713,7 +707,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         }

         for variable in variables.iter().take(mtry) {
-            self.find_best_split(visitor, n, &count, &mut false_count, *variable);
+            self.find_best_split(
+                visitor,
+                n,
+                &count,
+                &mut false_count,
+                parent_impurity,
+                *variable,
+            );
         }

         self.nodes()[visitor.node].split_score.is_some()
@@ -725,6 +726,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         n: usize,
         count: &[usize],
         false_count: &mut [usize],
+        parent_impurity: f64,
         j: usize,
     ) {
         let mut true_count = vec![0; self.num_classes];
@@ -760,7 +762,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
                 let true_label = which_max(&true_count);
                 let false_label = which_max(false_count);
-                let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
                 let gain = parent_impurity
                     - tc as f64 / n as f64
                         * impurity(&self.parameters().criterion, &true_count, tc)
@@ -805,7 +806,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
                     .get((i, self.nodes()[visitor.node].split_feature))
                     .to_f64()
                     .unwrap()
-                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
+                    <= self.nodes()[visitor.node]
+                        .split_value
+                        .unwrap_or(std::f64::NAN)
                 {
                     *true_sample = visitor.samples[i];
                     tc += *true_sample;
@@ -826,9 +829,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         let true_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.true_child_output, tc));
+        self.nodes.push(Node::new(visitor.true_child_output));
         let false_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.false_child_output, fc));
+        self.nodes.push(Node::new(visitor.false_child_output));
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
@@ -862,104 +865,11 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         true
     }
-
-    /// Compute feature importances for the fitted tree.
-    pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
-        let mut importances = vec![0f64; self.num_features];
-        for node in self.nodes().iter() {
-            if node.true_child.is_none() && node.false_child.is_none() {
-                continue;
-            }
-            let left = &self.nodes()[node.true_child.unwrap()];
-            let right = &self.nodes()[node.false_child.unwrap()];
-            importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
-                - left.n_node_samples as f64 * left.impurity.unwrap()
-                - right.n_node_samples as f64 * right.impurity.unwrap();
-        }
-        for item in importances.iter_mut() {
-            *item /= self.nodes()[0].n_node_samples as f64;
-        }
-        if normalize {
-            let sum = importances.iter().sum::<f64>();
-            for importance in importances.iter_mut() {
-                *importance /= sum;
-            }
-        }
-        importances
-    }
-
-    /// Predict class probabilities for the input samples.
-    ///
-    /// # Arguments
-    ///
-    /// * `x` - The input samples as a matrix where each row is a sample and each column is a feature.
-    ///
-    /// # Returns
-    ///
-    /// A `Result` containing a `DenseMatrix<f64>` where each row corresponds to a sample and each column
-    /// corresponds to a class. The values represent the probability of the sample belonging to each class.
-    ///
-    /// # Errors
-    ///
-    /// Returns an error if at least one row prediction process fails.
-    pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
-        let (n_samples, _) = x.shape();
-        let n_classes = self.classes().len();
-        let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
-
-        for i in 0..n_samples {
-            let probs = self.predict_proba_for_row(x, i)?;
-            for (j, &prob) in probs.iter().enumerate() {
-                result.set((i, j), prob);
-            }
-        }
-
-        Ok(result)
-    }
-
-    /// Predict class probabilities for a single input sample.
-    ///
-    /// # Arguments
-    ///
-    /// * `x` - The input matrix containing all samples.
-    /// * `row` - The index of the row in `x` for which to predict probabilities.
-    ///
-    /// # Returns
-    ///
-    /// A vector of probabilities, one for each class, representing the probability
-    /// of the input sample belonging to each class.
-    fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
-        let mut node = 0;
-
-        while let Some(current_node) = self.nodes().get(node) {
-            if current_node.true_child.is_none() && current_node.false_child.is_none() {
-                // Leaf node reached
-                let mut probs = vec![0.0; self.classes().len()];
-                probs[current_node.output] = 1.0;
-                return Ok(probs);
-            }
-            let split_feature = current_node.split_feature;
-            let split_value = current_node.split_value.unwrap_or(f64::NAN);
-            if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
-                node = current_node.true_child.unwrap();
-            } else {
-                node = current_node.false_child.unwrap();
-            }
-        }
-        // This should never happen if the tree is properly constructed
-        Err(Failed::predict("Nodes iteration did not reach leaf"))
-    }
 }

 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::linalg::basic::arrays::Array;
     use crate::linalg::basic::matrix::DenseMatrix;

     #[test]
@@ -991,60 +901,17 @@ mod tests {
     )]
     #[test]
     fn gini_impurity() {
-        assert!((impurity(&SplitCriterion::Gini, &[7, 3], 10) - 0.42).abs() < f64::EPSILON);
-        assert!(
-            (impurity(&SplitCriterion::Entropy, &[7, 3], 10) - 0.8812908992306927).abs()
-                < f64::EPSILON
-        );
-        assert!(
-            (impurity(&SplitCriterion::ClassificationError, &[7, 3], 10) - 0.3).abs()
-                < f64::EPSILON
-        );
-    }
-
-    #[cfg_attr(
-        all(target_arch = "wasm32", not(target_os = "wasi")),
-        wasm_bindgen_test::wasm_bindgen_test
-    )]
-    #[test]
-    fn test_predict_proba() {
-        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
-            &[5.1, 3.5, 1.4, 0.2],
-            &[4.9, 3.0, 1.4, 0.2],
-            &[4.7, 3.2, 1.3, 0.2],
-            &[4.6, 3.1, 1.5, 0.2],
-            &[5.0, 3.6, 1.4, 0.2],
-            &[7.0, 3.2, 4.7, 1.4],
-            &[6.4, 3.2, 4.5, 1.5],
-            &[6.9, 3.1, 4.9, 1.5],
-            &[5.5, 2.3, 4.0, 1.3],
-            &[6.5, 2.8, 4.6, 1.5],
-        ])
-        .unwrap();
-        let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
-
-        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
-        let probabilities = tree.predict_proba(&x).unwrap();
-
-        assert_eq!(probabilities.shape(), (10, 2));
-        for row in 0..10 {
-            let row_sum: f64 = probabilities.get_row(row).sum();
-            assert!(
-                (row_sum - 1.0).abs() < 1e-6,
-                "Row probabilities should sum to 1"
-            );
-        }
-
-        // Check if the first 5 samples have higher probability for class 0
-        for i in 0..5 {
-            assert!(probabilities.get((i, 0)) > probabilities.get((i, 1)));
-        }
-
-        // Check if the last 5 samples have higher probability for class 1
-        for i in 5..10 {
-            assert!(probabilities.get((i, 1)) > probabilities.get((i, 0)));
-        }
+        assert!(
+            (impurity(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON
+        );
+        assert!(
+            (impurity(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs()
+                < std::f64::EPSILON
+        );
+        assert!(
+            (impurity(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs()
+                < std::f64::EPSILON
+        );
     }

     #[cfg_attr(
@@ -1075,8 +942,7 @@ mod tests {
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
             &[5.2, 2.7, 3.9, 1.4],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];

         assert_eq!(
@@ -1105,17 +971,6 @@ mod tests {
         );
     }

-    #[test]
-    fn test_random_matrix_with_wrong_rownum() {
-        let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
-        let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
-        let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
-        assert!(fail.is_err());
-    }
-
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
@@ -1143,8 +998,7 @@ mod tests {
             &[0., 0., 1., 1.],
             &[0., 0., 0., 0.],
             &[0., 0., 0., 1.],
-        ])
-        .unwrap();
+        ]);
         let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];

         assert_eq!(
@@ -1155,43 +1009,6 @@ mod tests {
         );
     }

-    #[test]
-    fn test_compute_feature_importances() {
-        let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
-            &[1., 1., 1., 0.],
-            &[1., 1., 1., 0.],
-            &[1., 1., 1., 1.],
-            &[1., 1., 0., 0.],
-            &[1., 1., 0., 1.],
-            &[1., 0., 1., 0.],
-            &[1., 0., 1., 0.],
-            &[1., 0., 1., 1.],
-            &[1., 0., 0., 0.],
-            &[1., 0., 0., 1.],
-            &[0., 1., 1., 0.],
-            &[0., 1., 1., 0.],
-            &[0., 1., 1., 1.],
-            &[0., 1., 0., 0.],
-            &[0., 1., 0., 1.],
-            &[0., 0., 1., 0.],
-            &[0., 0., 1., 0.],
-            &[0., 0., 1., 1.],
-            &[0., 0., 0., 0.],
-            &[0., 0., 0., 1.],
-        ])
-        .unwrap();
-        let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
-
-        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
-        assert_eq!(
-            tree.compute_feature_importances(false),
-            vec![0., 0., 0.21333333333333332, 0.26666666666666666]
-        );
-        assert_eq!(
-            tree.compute_feature_importances(true),
-            vec![0., 0., 0.4444444444444444, 0.5555555555555556]
-        );
-    }
-
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
@@ -1220,8 +1037,7 @@ mod tests {
             &[0., 0., 1., 1.],
             &[0., 0., 0., 0.],
             &[0., 0., 0., 1.],
-        ])
-        .unwrap();
+        ]);
         let y = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
         let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
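The constants asserted in the `gini_impurity` test above follow directly from the standard impurity definitions for a class histogram `[7, 3]` over 10 samples. A self-contained sketch reproduces them; the free functions here are illustrative and do not mirror smartcore's internal `impurity(criterion, count, n)` signature:

```rust
/// Gini index: 1 - sum(p_i^2)
fn gini(count: &[usize], n: usize) -> f64 {
    1.0 - count
        .iter()
        .map(|&c| {
            let p = c as f64 / n as f64;
            p * p
        })
        .sum::<f64>()
}

/// Shannon entropy in bits: -sum(p_i * log2(p_i)), skipping empty classes
fn entropy(count: &[usize], n: usize) -> f64 {
    -count
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / n as f64;
            p * p.log2()
        })
        .sum::<f64>()
}

/// Classification error: 1 - max(p_i)
fn classification_error(count: &[usize], n: usize) -> f64 {
    1.0 - count.iter().copied().max().unwrap_or(0) as f64 / n as f64
}

fn main() {
    // For [7, 3]: gini = 1 - (0.49 + 0.09) = 0.42
    assert!((gini(&[7, 3], 10) - 0.42).abs() < 1e-9);
    // entropy = -(0.7*log2(0.7) + 0.3*log2(0.3)) ≈ 0.88129...
    assert!((entropy(&[7, 3], 10) - 0.8812908992306927).abs() < 1e-9);
    // classification error = 1 - 0.7 = 0.3
    assert!((classification_error(&[7, 3], 10) - 0.3).abs() < 1e-9);
    println!("ok");
}
```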
+401 -23
@@ -18,6 +18,7 @@
 //! Example:
 //!
 //! ```
+//! use rand::thread_rng;
 //! use smartcore::linalg::basic::matrix::DenseMatrix;
 //! use smartcore::tree::decision_tree_regressor::*;
 //!
@@ -39,7 +40,7 @@
 //! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
 //! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
 //! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-//! ]).unwrap();
+//! ]);
 //! let y: Vec<f64> = vec![
 //!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0,
 //!     101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
@@ -58,17 +59,22 @@
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
+use std::collections::LinkedList;
 use std::default::Default;
 use std::fmt::Debug;
+use std::marker::PhantomData;
+
+use rand::seq::SliceRandom;
+use rand::Rng;

 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

-use super::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
-use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
 use crate::numbers::basenum::Number;
+use crate::rand_custom::get_rng_impl;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -93,7 +99,41 @@ pub struct DecisionTreeRegressorParameters {
 #[derive(Debug)]
 pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 {
-    tree_regressor: Option<BaseTreeRegressor<TX, TY, X, Y>>,
+    nodes: Vec<Node>,
+    parameters: Option<DecisionTreeRegressorParameters>,
+    depth: u16,
+    _phantom_tx: PhantomData<TX>,
+    _phantom_ty: PhantomData<TY>,
+    _phantom_x: PhantomData<X>,
+    _phantom_y: PhantomData<Y>,
+}
+
+impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    DecisionTreeRegressor<TX, TY, X, Y>
+{
+    /// Get nodes, return a shared reference
+    fn nodes(&self) -> &Vec<Node> {
+        self.nodes.as_ref()
+    }
+    /// Get parameters, return a shared reference
+    fn parameters(&self) -> &DecisionTreeRegressorParameters {
+        self.parameters.as_ref().unwrap()
+    }
+    /// Get estimate of intercept, return value
+    fn depth(&self) -> u16 {
+        self.depth
+    }
+}
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+struct Node {
+    output: f64,
+    split_feature: usize,
+    split_value: Option<f64>,
+    split_score: Option<f64>,
+    true_child: Option<usize>,
+    false_child: Option<usize>,
 }

 impl DecisionTreeRegressorParameters {
@@ -257,11 +297,87 @@ impl Default for DecisionTreeRegressorSearchParameters {
     }
 }

+impl Node {
+    fn new(output: f64) -> Self {
+        Node {
+            output,
+            split_feature: 0,
+            split_value: Option::None,
+            split_score: Option::None,
+            true_child: Option::None,
+            false_child: Option::None,
+        }
+    }
+}
+
+impl PartialEq for Node {
+    fn eq(&self, other: &Self) -> bool {
+        (self.output - other.output).abs() < std::f64::EPSILON
+            && self.split_feature == other.split_feature
+            && match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
+                (None, None) => true,
+                _ => false,
+            }
+            && match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+
 impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
     for DecisionTreeRegressor<TX, TY, X, Y>
 {
     fn eq(&self, other: &Self) -> bool {
-        self.tree_regressor == other.tree_regressor
+        if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
+            false
+        } else {
+            self.nodes()
+                .iter()
+                .zip(other.nodes().iter())
+                .all(|(a, b)| a == b)
+        }
     }
 }
+
+struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
+    x: &'a X,
+    y: &'a Y,
+    node: usize,
+    samples: Vec<usize>,
+    order: &'a [Vec<usize>],
+    true_child_output: f64,
+    false_child_output: f64,
+    level: u16,
+    _phantom_tx: PhantomData<TX>,
+    _phantom_ty: PhantomData<TY>,
+}
+
+impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    NodeVisitor<'a, TX, TY, X, Y>
+{
+    fn new(
+        node_id: usize,
+        samples: Vec<usize>,
+        order: &'a [Vec<usize>],
+        x: &'a X,
+        y: &'a Y,
+        level: u16,
+    ) -> Self {
+        NodeVisitor {
+            x,
+            y,
+            node: node_id,
+            samples,
+            order,
+            true_child_output: 0f64,
+            false_child_output: 0f64,
+            level,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+        }
+    }
+}
@@ -271,7 +387,13 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 {
     fn new() -> Self {
         Self {
-            tree_regressor: None,
+            nodes: vec![],
+            parameters: Option::None,
+            depth: 0u16,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+            _phantom_x: PhantomData,
+            _phantom_y: PhantomData,
         }
     }
@@ -299,23 +421,281 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
y: &Y, y: &Y,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> { ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
-        let tree_parameters = BaseTreeRegressorParameters {
-            max_depth: parameters.max_depth,
-            min_samples_leaf: parameters.min_samples_leaf,
-            min_samples_split: parameters.min_samples_split,
-            seed: parameters.seed,
-            splitter: Splitter::Best,
+        let (x_nrows, num_attributes) = x.shape();
+        let samples = vec![1; x_nrows];
+        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+    }
+
+    pub(crate) fn fit_weak_learner(
x: &X,
y: &Y,
samples: Vec<usize>,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let y_m = y.clone();
let y_ncols = y_m.shape();
let (_, num_attributes) = x.shape();
let mut nodes: Vec<Node> = Vec::new();
let mut rng = get_rng_impl(parameters.seed);
let mut n = 0;
let mut sum = 0f64;
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
n += *sample_i;
sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
}
let root = Node::new(sum / (n as f64));
nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new();
for i in 0..num_attributes {
let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
order.push(col_i.argsort_mut());
}
let mut tree = DecisionTreeRegressor {
nodes,
parameters: Some(parameters),
depth: 0u16,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
        };
-        let tree = BaseTreeRegressor::fit(x, y, tree_parameters)?;
-        Ok(Self {
-            tree_regressor: Some(tree),
-        })
+        let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
+
+        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
visitor_queue.push_back(visitor);
}
while tree.depth() < tree.parameters().max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
None => break,
};
}
Ok(tree)
    }
    /// Predict regression value for `x`.
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
-        self.tree_regressor.as_ref().unwrap().predict(x)
+        let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
let mut result = 0f64;
let mut queue: LinkedList<usize> = LinkedList::new();
queue.push_back(0);
while !queue.is_empty() {
match queue.pop_front() {
Some(node_id) => {
let node = &self.nodes()[node_id];
if node.true_child.is_none() && node.false_child.is_none() {
result = node.output;
} else if x.get((row, node.split_feature)).to_f64().unwrap()
<= node.split_value.unwrap_or(std::f64::NAN)
{
queue.push_back(node.true_child.unwrap());
} else {
queue.push_back(node.false_child.unwrap());
}
}
None => break,
};
}
TY::from_f64(result).unwrap()
}
fn find_best_cutoff(
&mut self,
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
mtry: usize,
rng: &mut impl Rng,
) -> bool {
let (_, n_attr) = visitor.x.shape();
let n: usize = visitor.samples.iter().sum();
if n < self.parameters().min_samples_split {
return false;
}
let sum = self.nodes()[visitor.node].output * n as f64;
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
variables.shuffle(rng);
}
let parent_gain =
n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
for variable in variables.iter().take(mtry) {
self.find_best_split(visitor, n, sum, parent_gain, *variable);
}
self.nodes()[visitor.node].split_score.is_some()
}
fn find_best_split(
&mut self,
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
n: usize,
sum: f64,
parent_gain: f64,
j: usize,
) {
let mut true_sum = 0f64;
let mut true_count = 0;
let mut prevx = Option::None;
for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 {
let x_ij = *visitor.x.get((*i, j));
if prevx.is_none() || x_ij == prevx.unwrap() {
prevx = Some(x_ij);
true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
continue;
}
let false_count = n - true_count;
if true_count < self.parameters().min_samples_leaf
|| false_count < self.parameters().min_samples_leaf
{
prevx = Some(x_ij);
true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
continue;
}
let true_mean = true_sum / true_count as f64;
let false_mean = (sum - true_sum) / false_count as f64;
let gain = (true_count as f64 * true_mean * true_mean
+ false_count as f64 * false_mean * false_mean)
- parent_gain;
if self.nodes()[visitor.node].split_score.is_none()
|| gain > self.nodes()[visitor.node].split_score.unwrap()
{
self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value =
Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
self.nodes[visitor.node].split_score = Option::Some(gain);
visitor.true_child_output = true_mean;
visitor.false_child_output = false_mean;
}
prevx = Some(x_ij);
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
true_count += visitor.samples[*i];
}
}
}
fn split<'a>(
&mut self,
mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
let mut fc = 0;
let mut true_samples: Vec<usize> = vec![0; n];
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
if visitor.samples[i] > 0 {
if visitor
.x
.get((i, self.nodes()[visitor.node].split_feature))
.to_f64()
.unwrap()
<= self.nodes()[visitor.node]
.split_value
.unwrap_or(std::f64::NAN)
{
*true_sample = visitor.samples[i];
tc += *true_sample;
visitor.samples[i] = 0;
} else {
fc += visitor.samples[i];
}
}
}
if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = Option::None;
self.nodes[visitor.node].split_score = Option::None;
return false;
}
let true_child_idx = self.nodes().len();
self.nodes.push(Node::new(visitor.true_child_output));
let false_child_idx = self.nodes().len();
self.nodes.push(Node::new(visitor.false_child_output));
self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx);
self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
true_child_idx,
true_samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
false_child_idx,
visitor.samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}
true
    }
}
@@ -370,8 +750,7 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
@@ -385,7 +764,7 @@ mod tests {
            assert!((y_hat[i] - y[i]).abs() < 0.1);
        }
-        let expected_y = [
+        let expected_y = vec![
            87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85,
            114.85, 114.85, 114.85,
        ];
@@ -406,7 +785,7 @@ mod tests {
            assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
        }
-        let expected_y = [
+        let expected_y = vec![
            83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4,
            113.4, 116.30, 116.30,
        ];
@@ -452,8 +831,7 @@ mod tests {
            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
-        ])
-        .unwrap();
+        ]);
        let y: Vec<f64> = vec![
            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
            114.2, 115.7, 116.9,
@@ -19,7 +19,6 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
-pub(crate) mod base_tree_regressor;
/// Classification tree for dependent variables that take a finite number of unordered values.
pub mod decision_tree_classifier;
/// Regression tree for dependent variables that take continuous or ordered discrete values.
@@ -1,16 +0,0 @@
//! # XGBoost
//!
//! XGBoost, which stands for Extreme Gradient Boosting, is a powerful and efficient implementation of the gradient boosting framework. Gradient boosting is a machine learning technique for regression and classification problems, which produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees.
//!
//! The core idea of boosting is to build the model in a stage-wise fashion. It learns from its mistakes by sequentially adding new models that correct the errors of the previous ones. Unlike bagging, which trains models in parallel, boosting is a sequential process. Each new tree is fit on a modified version of the original data set, specifically focusing on the instances where the previous models performed poorly.
//!
//! XGBoost enhances this process through several key innovations. It employs a more regularized model formalization to control over-fitting, which gives it better performance. It also has a highly optimized and parallelized tree construction process, making it significantly faster and more scalable than traditional gradient boosting implementations.
//!
//! ## References:
//!
//! * "Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 10. Boosting and Additive Trees
//! * XGBoost: A Scalable Tree Boosting System, Chen T., Guestrin C.
// xgboost implementation
pub mod xgb_regressor;
pub use xgb_regressor::{XGRegressor, XGRegressorParameters};
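The stage-wise boosting loop that the module documentation above describes can be sketched in a few lines. This is a minimal illustration, not the removed implementation: a trivial "weak learner" that predicts the mean residual stands in for a fitted decision tree, and `boost` is a hypothetical helper name.

```rust
// Minimal sketch of stage-wise gradient boosting for squared error.
// A real booster fits a decision tree to the residuals each round;
// here the weak learner simply predicts their mean.
fn boost(y: &[f64], rounds: usize, learning_rate: f64) -> Vec<f64> {
    let mut predictions = vec![0.0_f64; y.len()];
    for _ in 0..rounds {
        // For loss 0.5 * (y - p)^2 the negative gradient is the residual y - p.
        let residuals: Vec<f64> = y
            .iter()
            .zip(&predictions)
            .map(|(t, p)| t - p)
            .collect();
        // Trivial weak learner: a single constant output (the mean residual).
        let correction = residuals.iter().sum::<f64>() / residuals.len() as f64;
        for p in predictions.iter_mut() {
            *p += learning_rate * correction;
        }
    }
    predictions
}

fn main() {
    // Each round shrinks the remaining error by a factor of (1 - learning_rate),
    // so every prediction converges towards the mean of y (5.0 here).
    let preds = boost(&[2.0, 4.0, 6.0, 8.0], 50, 0.5);
    println!("{preds:?}");
}
```

Because the weak learner is constant, this toy converges only to the target mean; the sequential "correct the previous errors" structure is what carries over to real tree-based boosting.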
@@ -1,762 +0,0 @@
//! # Extreme Gradient Boosting (XGBoost)
//!
//! XGBoost is a highly efficient and effective implementation of the gradient boosting framework.
//! Like other boosting models, it builds an ensemble of sequential decision trees, where each new tree
//! is trained to correct the errors of the previous ones.
//!
//! What makes XGBoost powerful is its use of both the first and second derivatives (gradient and hessian)
//! of the loss function, which allows for more accurate approximations and faster convergence. It also
//! includes built-in regularization techniques (L1/`alpha` and L2/`lambda`) to prevent overfitting.
//!
//! This implementation was ported to Rust from the concepts and algorithm explained in the blog post
//! ["XGBoost from Scratch"](https://randomrealizations.com/posts/xgboost-from-scratch/). It is designed
//! to be a general-purpose regressor that can be used with any objective function that provides a gradient
//! and a hessian.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::xgboost::{XGRegressor, XGRegressorParameters};
//!
//! // Simple dataset: predict y = 2*x
//! let x = DenseMatrix::from_2d_array(&[
//! &[1.0], &[2.0], &[3.0], &[4.0], &[5.0]
//! ]).unwrap();
//! let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
//!
//! // Use default parameters, but set a few for demonstration
//! let parameters = XGRegressorParameters::default()
//! .with_n_estimators(50)
//! .with_max_depth(3)
//! .with_learning_rate(0.1);
//!
//! // Train the model
//! let model = XGRegressor::fit(&x, &y, parameters).unwrap();
//!
//! // Make predictions
//! let x_test = DenseMatrix::from_2d_array(&[&[6.0], &[7.0]]).unwrap();
//! let y_hat = model.predict(&x_test).unwrap();
//!
//! // y_hat should be close to [12.0, 14.0]
//! ```
//!
use rand::{seq::SliceRandom, Rng};
use std::{iter::zip, marker::PhantomData};
use crate::{
api::{PredictorBorrow, SupervisedEstimatorBorrow},
error::{Failed, FailedError},
linalg::basic::arrays::{Array1, Array2},
numbers::basenum::Number,
rand_custom::get_rng_impl,
};
/// Defines the objective function to be optimized.
/// The objective function provides the loss, gradient (first derivative), and
/// hessian (second derivative) required for the XGBoost algorithm.
#[derive(Clone, Debug)]
pub enum Objective {
/// The objective for regression tasks using Mean Squared Error.
/// Loss: 0.5 * (y_true - y_pred)^2
MeanSquaredError,
}
impl Objective {
/// Calculates the loss for each sample given the true and predicted values.
///
/// # Arguments
/// * `y_true` - A vector of the true target values.
/// * `y_pred` - A vector of the predicted values.
///
/// # Returns
/// The mean of the calculated loss values.
pub fn loss_function<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> f64 {
match self {
Objective::MeanSquaredError => {
zip(y_true.iterator(0), y_pred)
.map(|(true_val, pred_val)| {
0.5 * (true_val.to_f64().unwrap() - pred_val).powi(2)
})
.sum::<f64>()
/ y_true.shape() as f64
}
}
}
/// Calculates the gradient (first derivative) of the loss function.
///
/// # Arguments
/// * `y_true` - A vector of the true target values.
/// * `y_pred` - A vector of the predicted values.
///
/// # Returns
/// A vector of gradients for each sample.
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
match self {
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
.map(|(true_val, pred_val)| (*pred_val - true_val.to_f64().unwrap()))
.collect(),
}
}
/// Calculates the hessian (second derivative) of the loss function.
///
/// # Arguments
/// * `y_true` - A vector of the true target values.
/// * `y_pred` - A vector of the predicted values.
///
/// # Returns
/// A vector of hessians for each sample.
#[allow(unused_variables)]
pub fn hessian<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &[f64]) -> Vec<f64> {
match self {
Objective::MeanSquaredError => vec![1.0; y_true.shape()],
}
}
}
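As a quick sanity check on the derivatives above (an illustrative sketch, not part of the removed module; `mse_gradient` and `mse_hessian` are hypothetical names): for the loss 0.5 * (y - p)^2 the gradient with respect to the prediction is p - y and the hessian is the constant 1.

```rust
// Hand-computed MSE derivatives, mirroring Objective::MeanSquaredError:
// loss = 0.5 * (y - p)^2  =>  d loss / dp = p - y,  d^2 loss / dp^2 = 1.
fn mse_gradient(y_true: &[f64], y_pred: &[f64]) -> Vec<f64> {
    y_true.iter().zip(y_pred).map(|(t, p)| p - t).collect()
}

fn mse_hessian(n: usize) -> Vec<f64> {
    // The second derivative of the squared error is constant.
    vec![1.0; n]
}

fn main() {
    let g = mse_gradient(&[1.0, 2.0, 3.0], &[1.5, 2.5, 2.5]);
    let h = mse_hessian(3);
    println!("{g:?} {h:?}"); // gradients [0.5, 0.5, -0.5], hessians all 1.0
}
```

These are the same values exercised by `test_mse_objective` further down.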
/// Represents a single decision tree in the XGBoost ensemble.
///
/// This is a recursive data structure where each `TreeRegressor` is a node
/// that can have a left and a right child, also of type `TreeRegressor`.
#[allow(dead_code)]
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
/// The output value of this node. If it's a leaf, this is the final prediction.
value: f64,
/// The feature value threshold used to split this node.
threshold: f64,
/// The index of the feature used for splitting.
split_feature_idx: usize,
/// The gain in score achieved by this split.
split_score: f64,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
TreeRegressor<TX, TY, X, Y>
{
/// Recursively builds a decision tree (a `TreeRegressor` node).
///
/// This function determines the optimal split for the given set of samples (`idxs`)
/// and then recursively calls itself to build the left and right child nodes.
///
/// # Arguments
/// * `data` - The full training dataset.
/// * `g` - Gradients for all samples.
/// * `h` - Hessians for all samples.
/// * `idxs` - The indices of the samples belonging to the current node.
/// * `max_depth` - The maximum remaining depth for this branch.
/// * `min_child_weight` - The minimum sum of hessians required in a child node.
/// * `lambda` - L2 regularization term on weights.
/// * `gamma` - Minimum loss reduction required to make a further partition.
pub fn fit(
data: &X,
g: &Vec<f64>,
h: &Vec<f64>,
idxs: &[usize],
max_depth: u16,
min_child_weight: f64,
lambda: f64,
gamma: f64,
) -> Self {
let g_sum = idxs.iter().map(|&i| g[i]).sum::<f64>();
let h_sum = idxs.iter().map(|&i| h[i]).sum::<f64>();
let value = -g_sum / (h_sum + lambda);
let mut best_feature_idx = usize::MAX;
let mut best_split_score = 0.0;
let mut best_threshold = 0.0;
let mut left = Option::None;
let mut right = Option::None;
if max_depth > 0 {
Self::insert_child_nodes(
data,
g,
h,
idxs,
&mut best_feature_idx,
&mut best_split_score,
&mut best_threshold,
&mut left,
&mut right,
max_depth,
min_child_weight,
lambda,
gamma,
);
}
Self {
left,
right,
value,
threshold: best_threshold,
split_feature_idx: best_feature_idx,
split_score: best_split_score,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
}
}
/// Finds the best split and creates child nodes if a valid split is found.
fn insert_child_nodes(
data: &X,
g: &Vec<f64>,
h: &Vec<f64>,
idxs: &[usize],
best_feature_idx: &mut usize,
best_split_score: &mut f64,
best_threshold: &mut f64,
left: &mut Option<Box<Self>>,
right: &mut Option<Box<Self>>,
max_depth: u16,
min_child_weight: f64,
lambda: f64,
gamma: f64,
) {
let (_, n_features) = data.shape();
for i in 0..n_features {
Self::find_best_split(
data,
g,
h,
idxs,
i,
best_feature_idx,
best_split_score,
best_threshold,
min_child_weight,
lambda,
gamma,
);
}
// A split is only valid if it results in a positive gain.
if *best_split_score > 0.0 {
let mut left_idxs = Vec::new();
let mut right_idxs = Vec::new();
for idx in idxs.iter() {
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
left_idxs.push(*idx);
} else {
right_idxs.push(*idx);
}
}
*left = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&left_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
*right = Some(Box::new(TreeRegressor::fit(
data,
g,
h,
&right_idxs,
max_depth - 1,
min_child_weight,
lambda,
gamma,
)));
}
}
/// Iterates through a single feature to find the best possible split point.
fn find_best_split(
data: &X,
g: &[f64],
h: &[f64],
idxs: &[usize],
feature_idx: usize,
best_feature_idx: &mut usize,
best_split_score: &mut f64,
best_threshold: &mut f64,
min_child_weight: f64,
lambda: f64,
gamma: f64,
) {
let mut sorted_idxs = idxs.to_owned();
sorted_idxs.sort_by(|a, b| {
data.get((*a, feature_idx))
.partial_cmp(data.get((*b, feature_idx)))
.unwrap()
});
let sum_g = sorted_idxs.iter().map(|&i| g[i]).sum::<f64>();
let sum_h = sorted_idxs.iter().map(|&i| h[i]).sum::<f64>();
let mut sum_g_right = sum_g;
let mut sum_h_right = sum_h;
let mut sum_g_left = 0.0;
let mut sum_h_left = 0.0;
for i in 0..sorted_idxs.len() - 1 {
let idx = sorted_idxs[i];
let next_idx = sorted_idxs[i + 1];
let g_i = g[idx];
let h_i = h[idx];
let x_i = data.get((idx, feature_idx)).to_f64().unwrap();
let x_i_next = data.get((next_idx, feature_idx)).to_f64().unwrap();
sum_g_left += g_i;
sum_h_left += h_i;
sum_g_right -= g_i;
sum_h_right -= h_i;
if sum_h_left < min_child_weight || x_i == x_i_next {
continue;
}
if sum_h_right < min_child_weight {
break;
}
let gain = 0.5
* ((sum_g_left * sum_g_left / (sum_h_left + lambda))
+ (sum_g_right * sum_g_right / (sum_h_right + lambda))
- (sum_g * sum_g / (sum_h + lambda)))
- gamma;
if gain > *best_split_score {
*best_split_score = gain;
*best_threshold = (x_i + x_i_next) / 2.0;
*best_feature_idx = feature_idx;
}
}
}
/// Predicts the output values for a dataset.
pub fn predict(&self, data: &X) -> Vec<f64> {
let (n_samples, n_features) = data.shape();
(0..n_samples)
.map(|i| {
self.predict_for_row(&Vec::from_iterator(
data.get_row(i).iterator(0).copied(),
n_features,
))
})
.collect()
}
/// Predicts the output value for a single row of data by traversing the tree.
pub fn predict_for_row(&self, row: &Vec<TX>) -> f64 {
// A leaf node is identified by having no children.
if self.left.is_none() {
return self.value;
}
// Recurse down the appropriate branch.
let child = if row[self.split_feature_idx].to_f64().unwrap() <= self.threshold {
self.left.as_ref().unwrap()
} else {
self.right.as_ref().unwrap()
};
child.predict_for_row(row)
}
}
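The gain expression computed inside `find_best_split` can be written as a standalone helper for clarity. A sketch under the same λ/γ conventions as the code above (`split_gain` is a hypothetical name, not part of the module):

```rust
// gain = 0.5 * (G_L^2/(H_L+λ) + G_R^2/(H_R+λ) - G^2/(H+λ)) - γ
// where G/H are the summed gradients/hessians and L/R denote the
// candidate left and right partitions.
fn split_gain(g_left: f64, h_left: f64, g_right: f64, h_right: f64, lambda: f64, gamma: f64) -> f64 {
    let g_total = g_left + g_right;
    let h_total = h_left + h_right;
    0.5 * (g_left * g_left / (h_left + lambda)
        + g_right * g_right / (h_right + lambda)
        - g_total * g_total / (h_total + lambda))
        - gamma
}

fn main() {
    // Same numbers as the multidimensional split test further down:
    // G_L = -1.5, H_L = 2, G_R = 2.5, H_R = 2, λ = 1, γ = 0.
    let gain = split_gain(-1.5, 2.0, 2.5, 2.0, 1.0, 0.0);
    println!("{gain}");
}
```

A split is only accepted when this gain is positive, which is exactly the `*best_split_score > 0.0` check in `insert_child_nodes`.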
/// Parameters for the `XGRegressor` model.
///
/// This struct holds all the hyperparameters that control the training process.
#[derive(Clone, Debug)]
pub struct XGRegressorParameters {
/// The number of boosting rounds or trees to build.
pub n_estimators: usize,
/// The maximum depth of each tree.
pub max_depth: u16,
/// Step size shrinkage used to prevent overfitting.
pub learning_rate: f64,
/// Minimum sum of instance weight (hessian) needed in a child.
pub min_child_weight: usize,
/// L2 regularization term on weights.
pub lambda: f64,
/// Minimum loss reduction required to make a further partition on a leaf node.
pub gamma: f64,
/// The initial prediction score for all instances.
pub base_score: f64,
/// The fraction of samples to be used for fitting the individual base learners.
pub subsample: f64,
/// The seed for the random number generator for reproducibility.
pub seed: u64,
/// The objective function to be optimized.
pub objective: Objective,
}
impl Default for XGRegressorParameters {
/// Creates a new set of `XGRegressorParameters` with default values.
fn default() -> Self {
Self {
n_estimators: 100,
learning_rate: 0.3,
max_depth: 6,
min_child_weight: 1,
lambda: 1.0,
gamma: 0.0,
base_score: 0.5,
subsample: 1.0,
seed: 0,
objective: Objective::MeanSquaredError,
}
}
}
// Builder pattern for XGRegressorParameters
impl XGRegressorParameters {
/// Sets the number of boosting rounds or trees to build.
pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
self.n_estimators = n_estimators;
self
}
/// Sets the step size shrinkage used to prevent overfitting.
///
/// Also known as `eta`. A smaller value makes the model more robust by preventing
/// too much weight being given to any single tree.
pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
self.learning_rate = learning_rate;
self
}
/// Sets the maximum depth of each individual tree.
    /// A lower value helps prevent overfitting.
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
self.max_depth = max_depth;
self
}
/// Sets the minimum sum of instance weight (hessian) needed in a child node.
///
/// If the tree partition step results in a leaf node with the sum of
    /// instance weight less than `min_child_weight`, then the building process
/// will give up further partitioning.
pub fn with_min_child_weight(mut self, min_child_weight: usize) -> Self {
self.min_child_weight = min_child_weight;
self
}
/// Sets the L2 regularization term on weights (`lambda`).
///
/// Increasing this value will make the model more conservative.
pub fn with_lambda(mut self, lambda: f64) -> Self {
self.lambda = lambda;
self
}
/// Sets the minimum loss reduction required to make a further partition on a leaf node.
///
/// The larger `gamma` is, the more conservative the algorithm will be.
pub fn with_gamma(mut self, gamma: f64) -> Self {
self.gamma = gamma;
self
}
/// Sets the initial prediction score for all instances.
pub fn with_base_score(mut self, base_score: f64) -> Self {
self.base_score = base_score;
self
}
/// Sets the fraction of samples to be used for fitting individual base learners.
///
/// A value of less than 1.0 introduces randomness and helps prevent overfitting.
pub fn with_subsample(mut self, subsample: f64) -> Self {
self.subsample = subsample;
self
}
/// Sets the seed for the random number generator for reproducibility.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
/// Sets the objective function to be optimized during training.
pub fn with_objective(mut self, objective: Objective) -> Self {
self.objective = objective;
self
}
}
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
parameters: Option<XGRegressorParameters>,
_phantom_ty: PhantomData<TY>,
_phantom_tx: PhantomData<TX>,
_phantom_y: PhantomData<Y>,
_phantom_x: PhantomData<X>,
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> XGRegressor<TX, TY, X, Y> {
/// Fits the XGBoost model to the training data.
pub fn fit(data: &X, y: &Y, parameters: XGRegressorParameters) -> Result<Self, Failed> {
if parameters.subsample > 1.0 || parameters.subsample <= 0.0 {
return Err(Failed::because(
FailedError::ParametersError,
"Subsample ratio must be in (0, 1].",
));
}
let (n_samples, _) = data.shape();
let learning_rate = parameters.learning_rate;
let mut predictions = vec![parameters.base_score; n_samples];
let mut regressors = Vec::new();
let mut rng = get_rng_impl(Some(parameters.seed));
for _ in 0..parameters.n_estimators {
let gradients = parameters.objective.gradient(y, &predictions);
let hessians = parameters.objective.hessian(y, &predictions);
let sample_idxs = if parameters.subsample < 1.0 {
Self::sample_without_replacement(n_samples, parameters.subsample, &mut rng)
} else {
(0..n_samples).collect::<Vec<usize>>()
};
let regressor = TreeRegressor::fit(
data,
&gradients,
&hessians,
&sample_idxs,
parameters.max_depth,
parameters.min_child_weight as f64,
parameters.lambda,
parameters.gamma,
);
let corrections = regressor.predict(data);
predictions = zip(predictions, corrections)
.map(|(pred, correction)| pred + (learning_rate * correction))
.collect();
regressors.push(regressor);
}
Ok(Self {
regressors: Some(regressors),
parameters: Some(parameters),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
_phantom_tx: PhantomData,
_phantom_x: PhantomData,
})
}
/// Predicts target values for the given input data.
pub fn predict(&self, data: &X) -> Result<Vec<TX>, Failed> {
let (n_samples, _) = data.shape();
let parameters = self.parameters.as_ref().unwrap();
let mut predictions = vec![parameters.base_score; n_samples];
let regressors = self.regressors.as_ref().unwrap();
for regressor in regressors.iter() {
let corrections = regressor.predict(data);
predictions = zip(predictions, corrections)
.map(|(pred, correction)| pred + (parameters.learning_rate * correction))
.collect();
}
Ok(predictions
.into_iter()
.map(|p| TX::from_f64(p).unwrap())
.collect())
}
/// Creates a random sample of indices without replacement.
fn sample_without_replacement(
population_size: usize,
subsample_ratio: f64,
rng: &mut impl Rng,
) -> Vec<usize> {
let mut indices: Vec<usize> = (0..population_size).collect();
indices.shuffle(rng);
indices.truncate((population_size as f64 * subsample_ratio) as usize);
indices
}
}
// Boilerplate implementation for the smartcore traits
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimatorBorrow<'_, X, Y, XGRegressorParameters> for XGRegressor<TX, TY, X, Y>
{
fn new() -> Self {
Self {
regressors: None,
parameters: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
_phantom_tx: PhantomData,
_phantom_x: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: &XGRegressorParameters) -> Result<Self, Failed> {
XGRegressor::fit(x, y, parameters.clone())
}
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PredictorBorrow<'_, X, TX>
for XGRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
self.predict(x)
}
}
// ------------------- TESTS -------------------
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
/// Tests the gradient and hessian calculations for MeanSquaredError.
#[test]
fn test_mse_objective() {
let objective = Objective::MeanSquaredError;
let y_true = vec![1.0, 2.0, 3.0];
let y_pred = vec![1.5, 2.5, 2.5];
let gradients = objective.gradient(&y_true, &y_pred);
let hessians = objective.hessian(&y_true, &y_pred);
// Gradients should be (pred - true)
assert_eq!(gradients, vec![0.5, 0.5, -0.5]);
// Hessians should be all 1.0 for MSE
assert_eq!(hessians, vec![1.0, 1.0, 1.0]);
}
#[test]
fn test_find_best_split_multidimensional() {
// Data has two features. The second feature is a better predictor.
let data = vec![
vec![1.0, 10.0], // g = -0.5
vec![1.0, 20.0], // g = -1.0
vec![1.0, 30.0], // g = 1.0
vec![1.0, 40.0], // g = 1.5
];
let data = DenseMatrix::from_2d_vec(&data).unwrap();
let g = vec![-0.5, -1.0, 1.0, 1.5];
let h = vec![1.0, 1.0, 1.0, 1.0];
let idxs = (0..4).collect::<Vec<usize>>();
let mut best_feature_idx = usize::MAX;
let mut best_split_score = 0.0;
let mut best_threshold = 0.0;
// Manually calculated expected gain for the best split (on feature 1, with lambda=1.0).
// G_left = -1.5, H_left = 2.0
// G_right = 2.5, H_right = 2.0
// G_total = 1.0, H_total = 4.0
// Gain = 0.5 * (G_l^2/(H_l+λ) + G_r^2/(H_r+λ) - G_t^2/(H_t+λ))
// Gain = 0.5 * ((-1.5)^2/(2+1) + (2.5)^2/(2+1) - (1.0)^2/(4+1))
// Gain = 0.5 * (2.25/3 + 6.25/3 - 1.0/5) = 0.5 * (0.75 + 2.0833 - 0.2) = 1.3166...
let expected_gain = 1.3166666666666667;
// Search both features. The algorithm must find the best split on feature 1.
let (_, n_features) = data.shape();
for i in 0..n_features {
TreeRegressor::<f64, f64, DenseMatrix<f64>, Vec<f64>>::find_best_split(
&data,
&g,
&h,
&idxs,
i,
&mut best_feature_idx,
&mut best_split_score,
&mut best_threshold,
1.0,
1.0,
0.0,
);
}
assert_eq!(best_feature_idx, 1); // Should choose the second feature
assert!((best_split_score - expected_gain).abs() < 1e-9);
assert_eq!(best_threshold, 25.0); // (20 + 30) / 2
}
/// Tests that the TreeRegressor can build a simple one-level tree on multidimensional data.
#[test]
fn test_tree_regressor_fit_multidimensional() {
let data = vec![
vec![1.0, 10.0],
vec![1.0, 20.0],
vec![1.0, 30.0],
vec![1.0, 40.0],
];
let data = DenseMatrix::from_2d_vec(&data).unwrap();
let g = vec![-0.5, -1.0, 1.0, 1.5];
let h = vec![1.0, 1.0, 1.0, 1.0];
let idxs = (0..4).collect::<Vec<usize>>();
let tree = TreeRegressor::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&data, &g, &h, &idxs, 2, 1.0, 1.0, 0.0,
);
// Check that the root node was split on the correct feature
assert!(tree.left.is_some());
assert!(tree.right.is_some());
assert_eq!(tree.split_feature_idx, 1); // Should split on the second feature
assert_eq!(tree.threshold, 25.0);
        // Check leaf values: value = -G / (H + lambda)
// Left leaf: G = -1.5, H = 2.0 => value = -(-1.5)/(2+1) = 0.5
// Right leaf: G = 2.5, H = 2.0 => value = -(2.5)/(2+1) = -0.8333
assert!((tree.left.unwrap().value - 0.5).abs() < 1e-9);
assert!((tree.right.unwrap().value - (-0.833333333)).abs() < 1e-9);
}
/// A "smoke test" to ensure the main XGRegressor can fit and predict on multidimensional data.
#[test]
fn test_xgregressor_fit_predict_multidimensional() {
// Simple 2D data where y is roughly 2*x1 + 3*x2
let x_vec = vec![
vec![1.0, 1.0],
vec![2.0, 1.0],
vec![1.0, 2.0],
vec![2.0, 2.0],
];
let x = DenseMatrix::from_2d_vec(&x_vec).unwrap();
let y = vec![5.0, 7.0, 8.0, 10.0];
let params = XGRegressorParameters::default()
.with_n_estimators(10)
.with_max_depth(2);
let fit_result = XGRegressor::fit(&x, &y, params);
assert!(
fit_result.is_ok(),
"Fit failed with error: {:?}",
fit_result.err()
);
let model = fit_result.unwrap();
let predict_result = model.predict(&x);
assert!(
predict_result.is_ok(),
"Predict failed with error: {:?}",
predict_result.err()
);
let predictions = predict_result.unwrap();
assert_eq!(predictions.len(), 4);
}
}