132 Commits

Author SHA1 Message Date
dependabot[bot]
56648f1c15 Update rand_distr requirement from 0.4 to 0.5
Updates the requirements on [rand_distr](https://github.com/rust-random/rand_distr) to permit the latest version.
- [Release notes](https://github.com/rust-random/rand_distr/releases)
- [Changelog](https://github.com/rust-random/rand_distr/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-random/rand_distr/compare/0.4.0...0.5.1)

---
updated-dependencies:
- dependency-name: rand_distr
  dependency-version: 0.5.1
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-01-19 17:02:02 +00:00
Lorenzo Mec-iS
c57a4370ba bump version tp 0.4.9 2026-01-09 06:14:44 +00:00
Georeth Chow
78f18505b1 fix LASSO (#346)
* fix lasso doc typo
* fix lasso optimizer bug
2025-12-05 17:49:07 +09:00
Lorenzo
58a8624fa9 v0.4.8 (#345) 2025-11-29 02:54:35 +00:00
Georeth Chow
18de2aa244 add fit_intercept to LASSO (#344)
* add fit_intercept to LASSO
* lasso: intercept=None if fit_intercept is false
* update CHANGELOG.md to reflect lasso changes
* lasso: minor
2025-11-29 02:46:14 +00:00
Georeth Chow
2bf5f7a1a5 Fix LASSO (first two of #342) (#343)
* Fix LASSO (#342)
* change loss function in doc to match code
* allow `n == p` case
* lasso add test_full_rank_x

---------

Co-authored-by: Zhou Xiaozhou <zxz@jiweifund.com>
2025-11-28 12:15:43 +09:00
Lorenzo
0caa8306ff Modernise CI toolchain to avoid deprecation (#341)
* fix cache failing to find Cargo.toml
2025-11-24 02:25:36 +00:00
Lorenzo
2f63148de4 fix CI (#340)
* fix CI workflow
2025-11-24 02:07:49 +00:00
Lorenzo
f9e473c919 v0.4.7 (#339) 2025-11-24 01:57:25 +00:00
Charlie Martin
70d8a0f34b fix precision and recall calculations (#338)
* fix precision and recall calculations
2025-11-24 01:46:56 +00:00
Charlie Martin
0e42a97514 add serde support for XGRegressor (#337)
* add serde support for XGBoostRegressor
* add traits to dependent structs
2025-11-16 19:31:21 +09:00
Lorenzo
36efd582a5 Fix is_empty method logic in matrix.rs (#336)
* Fix is_empty method logic in matrix.rs
* bump to 0.4.6
* silence some clippy
2025-11-15 05:22:42 +00:00
Lorenzo
70212c71e0 Update Cargo.toml (#333) 2025-10-09 17:37:02 +01:00
Lorenzo
63f86f7bc9 Add with_top_k to CosineSimilarity (#332)
* Implement cosine similarity and cosinepair
* formatting
* fix clippy
* Add top k CosinePair
* fix distance computation
* set min similarity for constant zeros
* bump version to 0.4.5
2025-10-09 17:27:54 +01:00
Lorenzo
e633afa520 set min similarity for constant zeros (#331)
* set min similarity for constant zeros
* bump version
2025-10-02 15:41:18 +01:00
Lorenzo
b6e32fb328 Update README.md (#330) 2025-09-28 16:04:12 +01:00
Lorenzo
948d78a4d0 Create CITATION.cff (#329) 2025-09-28 15:50:50 +01:00
Lorenzo
448b6f77e3 Update README.md (#328) 2025-09-28 15:43:46 +01:00
Lorenzo
09be4681cf Implement cosine similarity and cosinepair (#327)
* Implement cosine similarity and cosinepair
2025-09-27 11:08:57 +01:00
Daniel Lacina
4841791b7e implemented extra trees (#320)
* implemented extra trees

* implemented extra trees
2025-07-12 18:37:11 +01:00
Daniel Lacina
9fef05ecc6 refactored random forest regressor into reusable compoennts (#318) 2025-07-12 15:56:49 +01:00
Daniel Lacina
c5816b0e1b refactored decision tree into reusable components (#316)
* refactored decision tree into reusable components

* got rid of api code from base tree because its an implementation detail

* got rid of api code from base tree because its an implementation detail

* changed name
2025-07-12 11:25:53 +01:00
Daniel Lacina
5cc5528367 implemented xgdboost_regression (#314)
* implemented xgd_regression
2025-07-09 15:25:45 +01:00
Daniel Lacina
d459c48372 implemented single linkage clustering (#313)
* implemented single linkage clustering

---------

Co-authored-by: Lorenzo Mec-iS <tunedconsulting@gmail.com>
2025-07-03 18:05:54 +01:00
Daniel Lacina
730c0d64df implemented multiclass for svc (#308)
* implemented multiclass for svc
* modified the multiclass svc so it doesnt modify the current api
2025-06-16 11:00:11 +01:00
Lorenzo
44424807a0 Implement SVR and SVR kernels with Enum. Add tests for argsort_mut (#303)
* Add tests for argsort_mut
* Add formatting and cleaning up .github directory
* fix clippy error. suggestion to use .contains()
* define type explicitly for variable jstack
* Implement kernel as enumerator
* basic svr and svr_params implementation
* Complete enum implementation for Kernels. Implement search grid for SVR. Add documentation.
* Fix serde configuration in cargo clippy
*  Implement search parameters (#304)
* Implement SVR kernels as enumerator
* basic svr and svr_params implementation
* Implement search grid for SVR. Add documentation.
* Fix serde configuration in cargo clippy
* Fix wasm32 typetag
* fix typetag
* Bump to version 0.4.2
2025-06-02 11:01:46 +01:00
morenol
76d1ef610d Update Cargo.toml (#299)
* Update Cargo.toml

* chore: fix clippy

* chore: bump actions

* chore: fix clippy

* chore: update target name

---------

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2025-04-24 23:24:29 -04:00
Lorenzo
4092e24c2a Update README.md 2025-02-04 14:26:53 +00:00
Lorenzo
17dc9f3bbf Add ordered pairs for FastPair (#252)
* Add ordered_pairs method to FastPair
* add tests to fastpair
2025-01-28 00:48:08 +00:00
Lorenzo
c8ec8fec00 Fix #245: return error for NaN in naive bayes (#246)
* Fix #245: return error for NaN in naive bayes
* Implement error handling for NaN values in NBayes predict:
* general behaviour has been kept unchanged according to original tests in `mod.rs`
* aka: error is returned only if all the predicted probabilities are NaN
* Add tests
* Add test with static values
* Add test for numerical stability with numpy
2025-01-27 23:17:55 +00:00
Lorenzo
3da433f757 Implement predict_proba for DecisionTreeClassifier (#287)
* Implement predict_proba for DecisionTreeClassifier
* Some automated fixes suggested by cargo clippy --fix
2025-01-20 18:50:00 +00:00
dependabot[bot]
4523ac73ff Update itertools requirement from 0.12.0 to 0.13.0 (#280)
Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version.
- [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-itertools/itertools/compare/v0.12.0...v0.13.0)

---
updated-dependencies:
- dependency-name: itertools
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-11-25 11:47:23 -04:00
morenol
ba75f9ffad chore: fix clippy (#283)
* chore: fix clippy


Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2024-11-25 11:34:29 -04:00
Lorenzo
239c00428f Patch to version 0.4.0 (#257)
* uncomment test

* Add random test for logistic regression

* linting

* Bump version

* Add test for logistic regression

* linting

* initial commit

* final

* final-clean

* Bump to 0.4.0

* Fix linter

* cleanup

* Update CHANDELOG with breaking changes

* Update CHANDELOG date

* Add functional methods to DenseMatrix implementation

* linting

* add type declaration in test

* Fix Wasm tests failing

* linting

* fix tests

* linting

* Add type annotations on BBDTree constructor

* fix clippy

* fix clippy

* fix tests

* bump version

* run fmt. fix changelog

---------

Co-authored-by: Edmund Cape <edmund@Edmunds-MacBook-Pro.local>
2024-03-04 08:51:27 -05:00
morenol
80a93c1a0e chore: fix clippy (#276)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2024-02-25 00:17:30 -05:00
Tushushu
4eadd16ce4 Implement the feature importance for Decision Tree Classifier (#275)
* store impurity in the node

* add number of features

* add a TODO

* draft feature importance

* feat

* n_samples of node

* compute_feature_importances

* unit tests

* always calculate impurity

* fix bug

* fix linter
2024-02-24 23:37:30 -05:00
Frédéric Meyer
886b5631b7 In Naive Bayes, avoid using Option::unwrap and so avoid panicking from NaN values (#274) 2024-01-10 14:59:10 -04:00
dependabot[bot]
9c07925d8a Update itertools requirement from 0.11.0 to 0.12.0 (#271)
Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version.
- [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-itertools/itertools/compare/v0.11.0...v0.12.0)

---
updated-dependencies:
- dependency-name: itertools
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-20 22:00:34 -04:00
morenol
6f22bbd150 chore: update clippy lints (#272)
* chore: fix clippy lints
---------

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2023-11-20 21:54:09 -04:00
dependabot[bot]
dbdc2b2a77 Update itertools requirement from 0.10.5 to 0.11.0 (#266)
Updates the requirements on [itertools](https://github.com/rust-itertools/itertools) to permit the latest version.
- [Changelog](https://github.com/rust-itertools/itertools/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-itertools/itertools/compare/v0.10.5...v0.11.0)

---
updated-dependencies:
- dependency-name: itertools
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-06-22 17:56:42 +01:00
Lorenzo
2d7c055154 Bump version 2023-05-01 13:20:17 +01:00
Ruben De Smet
545ed6ce2b Remove some allocations (#262)
* Remove some allocations

* Remove some more allocations
2023-04-26 21:46:26 +08:00
morenol
8939ed93b9 chore: fix clippy warnings from Rust release 1.69 (#263)
* chore: fix clippy warnings from Rust release 1.69

* chore: run `cargo fmt`

* refactor: remove unused type parameter

---------

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2023-04-26 01:35:58 +09:00
Lorenzo
9cd7348403 Update CONTRIBUTING.md 2023-04-10 15:13:27 +01:00
Hsiang-Cheng Yang
d52830a818 Update arrays.rs (#253)
fix a typo
2023-03-23 19:15:54 -04:00
Lorenzo
d15ea43975 Remove failure in case of failed upload to codecov.io 2023-03-20 15:08:30 +00:00
Lorenzo
f498f9629e Implement realnum::rand (#251)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Implement rand. Use the new derive [#default]
* Use custom range
* Use range seed
* Bump version
* Add array length checks for
2023-03-20 14:45:44 +00:00
Lorenzo
7d059c4fb1 Update README.md 2023-03-20 11:54:10 +00:00
morenol
c7353d0b57 Run cargo clippy --fix (#250)
* Run `cargo clippy --fix`
* Run `cargo clippy --all-features --fix`
* Fix other clippy warnings
* cargo fmt

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2023-01-27 10:41:18 +00:00
Lorenzo
83dcf9a8ac Delete iml file 2022-11-10 14:09:55 +00:00
Lorenzo (Mec-iS)
3126ee87d3 Pin deps version 2022-11-09 12:03:03 +00:00
morenol
8efb959b3c Handle kernel serialization (#232)
* Handle kernel serialization
* Do not use typetag in WASM
* enable tests for serialization
* Update serde feature deps

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 16:18:05 +00:00
morenol
9eaae9ef35 Fixes for release (#237)
* Fixes for release
* add new test
* Remove change applied in development branch
* Only add dependency for wasm32
* Update ci.yml

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>
2022-11-08 16:07:14 +00:00
Lorenzo (Mec-iS)
46b6285d05 Merge release-0.3 2022-11-08 15:37:11 +00:00
Lorenzo (Mec-iS)
c683073b14 make work cargo build --target wasm32-unknown-unknown 2022-11-08 15:35:04 +00:00
Lorenzo
161d249917 Release 0.3 (#235) 2022-11-08 15:22:34 +00:00
Lorenzo (Mec-iS)
4558be5f73 Merge branch 'release-0.3' of github.com:smartcorelib/smartcore into release-0.3 2022-11-08 15:17:48 +00:00
Lorenzo (Mec-iS)
6c03e6e0b3 update CHANGELOG 2022-11-08 15:17:31 +00:00
Lorenzo
c934f6b6cf update comment 2022-11-08 14:23:13 +00:00
Lorenzo (Mec-iS)
48f1d6b74d use getrandom/js 2022-11-08 14:19:40 +00:00
Lorenzo (Mec-iS)
dad0d01f6d Update CHANGELOG 2022-11-08 13:59:49 +00:00
Lorenzo (Mec-iS)
98b18c4dae Remove unused tests flags 2022-11-08 13:53:50 +00:00
Lorenzo (Mec-iS)
2418b24ff4 Merge branch 'release-0.3' of github.com:smartcorelib/smartcore into release-0.3 2022-11-08 12:22:06 +00:00
Lorenzo (Mec-iS)
6c6f92697f minor fixes to doc 2022-11-08 12:21:34 +00:00
Lorenzo
a4097fce15 minor fix 2022-11-08 12:18:35 +00:00
Lorenzo
b71c7b49cb minor fix 2022-11-08 12:18:03 +00:00
Lorenzo
78bf75b5d8 minor fix 2022-11-08 12:17:32 +00:00
Lorenzo
a60fdaf235 minor fix 2022-11-08 12:17:04 +00:00
Lorenzo
b4206c4b08 minor fix 2022-11-08 12:15:10 +00:00
Lorenzo (Mec-iS)
3c4a807be8 Fix std_rand feature 2022-11-08 12:04:39 +00:00
Lorenzo (Mec-iS)
c1af60cafb cleanup 2022-11-08 11:55:32 +00:00
Lorenzo (Mec-iS)
2fa454ea94 fmt 2022-11-08 11:48:14 +00:00
Lorenzo (Mec-iS)
8e6e5f9e68 Use getrandom as default (for no-std feature) 2022-11-08 11:47:31 +00:00
Lorenzo (Mec-iS)
bf7b714126 Add static analyzer to doc 2022-11-07 18:16:13 +00:00
Lorenzo (Mec-iS)
3ac6598951 Exclude datasets test for wasm/wasi 2022-11-07 13:56:29 +00:00
Lorenzo (Mec-iS)
cc91e31a0e minor fixes 2022-11-07 13:00:51 +00:00
Lorenzo (Mec-iS)
0ec89402e8 minor fix 2022-11-07 12:50:32 +00:00
Lorenzo (Mec-iS)
23b3699730 Release 0.3 2022-11-07 12:48:44 +00:00
Lorenzo
aab3817c58 Create DEVELOPERS.md 2022-11-04 22:23:36 +00:00
Lorenzo
d3a496419d Update README.md 2022-11-04 22:17:55 +00:00
Lorenzo
ab18f127a0 Update README.md 2022-11-04 22:11:54 +00:00
morenol
425c3c1d0b Use Box in SVM and remove lifetimes (#228)
* Do not change external API
Authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-04 22:08:30 +00:00
morenol
35fe68e024 Fix CI (#227)
* Update ci.yml
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-03 13:48:16 -05:00
Lorenzo
d592b628be Implement CSV reader with new traits (#209) 2022-11-03 15:49:00 +00:00
Lorenzo (Mec-iS)
b66afa9222 Improve options conditionals 2022-11-03 14:58:05 +00:00
Lorenzo (Mec-iS)
ba70bb941f Implement Display for NaiveBayes 2022-11-03 14:18:56 +00:00
Lorenzo (Mec-iS)
d298709040 cargo clippy 2022-11-03 13:44:27 +00:00
Lorenzo (Mec-iS)
e50b4e8637 Fix signature of metrics tests 2022-11-03 13:40:54 +00:00
Lorenzo (Mec-iS)
26b72b67f4 Add kernels' parameters to public interface 2022-11-03 12:30:43 +00:00
Lorenzo
1964424589 Fix svr tests (#222) 2022-11-03 11:48:40 +00:00
Lorenzo (Mec-iS)
deac31a2ab Refactor modules structure in src/svm 2022-11-02 15:28:50 +00:00
Lorenzo (Mec-iS)
4cff7da50d Merge branch 'development' of github.com:smartcorelib/smartcore into development 2022-11-02 15:24:06 +00:00
Lorenzo (Mec-iS)
df0ae907f7 clean up svm 2022-11-02 15:23:56 +00:00
Lorenzo
cfbd45bfc0 Support Wasi as target (#216)
* Improve features
* Add wasm32-wasi as a target
* Update .github/workflows/ci.yml
Co-authored-by: morenol <22335041+morenol@users.noreply.github.com>
2022-11-02 15:22:38 +00:00
Lorenzo
b60329ca5d Disambiguate distances. Implement Fastpair. (#220) 2022-11-02 14:53:28 +00:00
morenol
4b096ad558 build: fix compilation without default features (#218)
* build: fix compilation with optional features
* Remove unused config from Cargo.toml
* Fix cache keys
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-11-02 10:09:03 +00:00
Lorenzo
4cf7e4d7b7 Improve features (#215) 2022-11-01 13:56:20 +00:00
Lorenzo
c3093f11f1 Fix metrics::auc (#212)
* Fix metrics::auc
2022-11-01 12:50:46 +00:00
Lorenzo
083803c900 Port ensemble. Add Display to naive_bayes (#208) 2022-10-31 17:35:33 +00:00
Lorenzo
4f64f2e0ff Update README.md 2022-10-31 10:45:51 +00:00
Lorenzo
52eb6ce023 Merge potential next release v0.4 (#187) Breaking Changes
* First draft of the new n-dimensional arrays + NB use case
* Improves default implementation of multiple Array methods
* Refactors tree methods
* Adds matrix decomposition routines
* Adds matrix decomposition methods to ndarray and nalgebra bindings
* Refactoring + linear regression now uses array2
* Ridge & Linear regression
* LBFGS optimizer & logistic regression
* LBFGS optimizer & logistic regression
* Changes linear methods, metrics and model selection methods to new n-dimensional arrays
* Switches KNN and clustering algorithms to new n-d array layer
* Refactors distance metrics
* Optimizes knn and clustering methods
* Refactors metrics module
* Switches decomposition methods to n-dimensional arrays
* Linalg refactoring - cleanup rng merge (#172)
* Remove legacy DenseMatrix and BaseMatrix implementation. Port the new Number, FloatNumber and Array implementation into module structure.
* Exclude AUC metrics. Needs reimplementation
* Improve developers walkthrough

New traits system in place at `src/numbers` and `src/linalg`
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Provide SupervisedEstimator with a constructor to avoid explicit dynamical box allocation in 'cross_validate' and 'cross_validate_predict' as required by the use of 'dyn' as per Rust 2021
* Implement getters to use as_ref() in src/neighbors
* Implement getters to use as_ref() in src/naive_bayes
* Implement getters to use as_ref() in src/linear
* Add Clone to src/naive_bayes
* Change signature for cross_validate and other model_selection functions to abide to use of dyn in Rust 2021
* Implement ndarray-bindings. Remove FloatNumber from implementations
* Drop nalgebra-bindings support (as decided in conf-call to go for ndarray)
* Remove benches. Benches will have their own repo at smartcore-benches
* Implement SVC
* Implement SVC serialization. Move search parameters in dedicated module
* Implement SVR. Definitely too slow
* Fix compilation issues for wasm (#202)

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
* Fix tests (#203)

* Port linalg/traits/stats.rs
* Improve methods naming
* Improve Display for DenseMatrix

Co-authored-by: Montana Low <montanalow@users.noreply.github.com>
Co-authored-by: VolodymyrOrlov <volodymyr.orlov@gmail.com>
2022-10-31 10:44:57 +00:00
RJ Nowling
bb71656137 Dataset doc cleanup (#205)
* Update iris.rs

* Update mod.rs

* Update digits.rs
2022-10-30 09:32:41 +00:00
Lorenzo
edbac7e4c7 Update README.md 2022-10-18 15:44:38 +01:00
Lorenzo
8a2bdd5a75 Update README.md 2022-10-13 19:47:52 +01:00
Lorenzo
b823b55460 Update CONTRIBUTING.md 2022-10-12 12:21:09 +01:00
morenol
12df301f32 fix: fix issue with iterator for svc search (#182) 2022-10-02 06:15:28 -05:00
morenol
f8210d0af9 refactor: Try to follow similar pattern to other APIs (#180)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-10-01 16:44:08 -05:00
morenol
3c62686d6e feat: expose hyper tuning module in model_selection (#179)
* feat: expose hyper tuning module in model_selection

* Move to a folder

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-10-01 12:47:56 -05:00
Lorenzo
9c59e37a0f Update CONTRIBUTING.md 2022-09-27 14:27:27 +01:00
Lorenzo
0b619fe7eb Add contribution guidelines (#178) 2022-09-27 14:23:18 +01:00
Montana Low
764309e313 make default params available to serde (#167)
* add seed param to search params

* make default params available to serde

* lints

* create defaults for enums

* lint
2022-09-21 22:48:31 -04:00
Montana Low
403d3f2348 add seed param to search params (#168) 2022-09-22 00:15:26 +01:00
morenol
3a44161406 Lmm/add seeds in more algorithms (#164)
* Provide better output in flaky tests

* feat: add seed parameter to multiple algorithms

* Update changelog

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-09-21 20:35:22 +01:00
Montana Low
48514d1b15 Complete grid search params (#166)
* grid search draft

* hyperparam search for linear estimators

* grid search for ensembles

* support grid search for more algos

* grid search for unsupervised algos

* minor cleanup
2022-09-21 20:34:21 +01:00
morenol
69d8be35de Provide better output in flaky tests (#163) 2022-09-20 17:12:09 +01:00
morenol
c21e75276a feat: allocate first and then proceed to create matrix from Vec of Ro… (#159)
* feat: allocate first and then proceed to create matrix from Vec of RowVectors
2022-09-20 11:29:54 +01:00
morenol
6a2e10452f Make rand_distr optional (#161) 2022-09-20 11:21:02 +01:00
Lorenzo
436da104d7 Update LICENSE 2022-09-19 18:00:17 +01:00
morenol
2510ca4e9d fix: fix compilation warnings when running only with default features (#160)
* fix: fix compilation warnings when running only with default features
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-09-19 10:44:01 -04:00
Tim Toebrock
b6f585e60f Implement a generic read_csv method (#147)
* feat: Add interface to build `Matrix` from rows.
* feat: Add option to derive `RealNumber` from string.
To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string.
* feat: Implement `Matrix::read_csv`.
2022-09-19 10:38:01 +01:00
Montana Low
4685fc73e0 grid search (#154)
* grid search draft
* hyperparam search for linear estimators
2022-09-19 10:31:56 +01:00
Montana Low
2e5f88fad8 Handle multiclass precision/recall (#152)
* handle multiclass precision/recall
2022-09-13 16:23:45 +01:00
dependabot[bot]
e445f0d558 Update criterion requirement from 0.3 to 0.4 (#150)
* Update criterion requirement from 0.3 to 0.4

Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version.
- [Release notes](https://github.com/bheisler/criterion.rs/releases)
- [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md)
- [Commits](https://github.com/bheisler/criterion.rs/compare/0.3.0...0.4.0)

---
updated-dependencies:
- dependency-name: criterion
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* fix criterion

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-09-12 12:03:43 -04:00
Christos Katsakioris
4d5f64c758 Add serde for StandardScaler (#148)
* Derive `serde::Serialize` and `serde::Deserialize` for
  `StandardScaler`.
* Add relevant unit test.

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>

Signed-off-by: Christos Katsakioris <ckatsak@gmail.com>
2022-09-06 18:37:54 +01:00
Tim Toebrock
d305406dfd Implementation of Standard scaler (#143)
* docs: Fix typo in doc for categorical transformer.
* feat: Add option to take a column from Matrix.
I created the method `Matrix::take_column` that uses the `Matrix::take`-interface to extract a single column from a matrix. I need that feature in the implementation of  `StandardScaler`.
* feat: Add `StandardScaler`.
Authored-by: titoeb <timtoebrock@googlemail.com>
2022-08-26 15:20:20 +01:00
Lorenzo
3d2f4f71fa Add example for FastPair (#144)
* Add example

* Move to top

* Add imports to example

* Fix imports
2022-08-24 13:40:22 +01:00
Lorenzo
a1c56a859e Implement fastpair (#142)
* initial fastpair implementation
* FastPair initial implementation
* implement fastpair
* Add random test
* Add bench for fastpair
* Refactor with constructor for FastPair
* Add serialization for PairwiseDistance
* Add fp_bench feature for fastpair bench
2022-08-23 16:56:21 +01:00
Chris McComb
d905ebea15 Added additional doctest and fixed indices (#141) 2022-08-12 17:38:13 -04:00
morenol
b482acdc8d Fix clippy warnings (#139)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-07-13 21:06:05 -04:00
ferrouille
b4a807eb9f Add SVC::decision_function (#135) 2022-06-21 12:48:16 -04:00
dependabot[bot]
ff456df0a4 Update nalgebra requirement from 0.23.0 to 0.31.0 (#128)
Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.31.0)

---
updated-dependencies:
- dependency-name: nalgebra
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-05-11 13:14:14 -04:00
dependabot-preview[bot]
322610c7fb build(deps): update nalgebra requirement from 0.23.0 to 0.26.2 (#98)
* build(deps): update nalgebra requirement from 0.23.0 to 0.26.2

Updates the requirements on [nalgebra](https://github.com/dimforge/nalgebra) to permit the latest version.
- [Release notes](https://github.com/dimforge/nalgebra/releases)
- [Changelog](https://github.com/dimforge/nalgebra/blob/dev/CHANGELOG.md)
- [Commits](https://github.com/dimforge/nalgebra/compare/v0.23.0...v0.26.2)

Signed-off-by: dependabot-preview[bot] <support@dependabot.com>

* fix: updates for nalgebre

* test: explicitly call pow_mut from BaseVector since now it conflicts with nalgebra implementation

* Don't be strict with dependencies

Co-authored-by: dependabot-preview[bot] <27856297+dependabot-preview[bot]@users.noreply.github.com>
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
2022-05-11 13:04:27 -04:00
143 changed files with 21518 additions and 9536 deletions
+6
View File
@@ -0,0 +1,6 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# Developers in this list will be requested for
# review when someone opens a pull request.
* @morenol
* @Mec-iS
+22
View File
@@ -0,0 +1,22 @@
# Code of Conduct
As contributors and maintainers of this project, and in the interest of fostering an open and welcoming community, we pledge to respect all people who contribute through reporting issues, posting feature requests, updating documentation, submitting pull requests or patches, and other activities.
We are committed to making participation in this project a harassment-free experience for everyone, regardless of level of experience, gender, gender identity and expression, sexual orientation, disability, personal appearance, body size, race, ethnicity, age, religion, or nationality.
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery
* Personal attacks
* Trolling or insulting/derogatory comments
* Public or private harassment
* Publishing other's private information, such as physical or electronic addresses, without explicit permission
* Other unethical or unprofessional conduct.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct. By adopting this Code of Conduct, project maintainers commit themselves to fairly and consistently applying these principles to every aspect of managing this project. Project maintainers who do not follow or enforce the Code of Conduct may be permanently removed from the project team.
This code of conduct applies both within project spaces and in public spaces when an individual is representing the project or its community.
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by opening an issue or contacting one or more of the project maintainers.
This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/)
+72
View File
@@ -0,0 +1,72 @@
# **Contributing**
When contributing to this repository, please first discuss the change you wish to make via issue,
email, or any other method with the owners of this repository before making a change.
Please note we have a [code of conduct](CODE_OF_CONDUCT.md), please follow it in all your interactions with the project.
## Background
We try to follow these principles:
* follow as much as possible the sklearn API to give a frictionless user experience for practitioners already familiar with it
* use only pure-Rust implementations for safety and future-proofing (with some low-level limited exceptions)
* do not use macros in the library code to allow readability and transparent behavior
* priority is not on "big data" dataset, try to be fast for small/average dataset with limited memory footprint.
## Pull Request Process
1. Open a PR following the template (erase the part of the template you don't need).
2. Update the CHANGELOG.md with details of changes to the interface if they are breaking changes, this includes new environment variables, exposed ports useful file locations and container parameters.
3. Pull Request can be merged in once you have the sign-off of one other developer, or if you do not have permission to do that you may request the reviewer to merge it for you.
### generic guidelines
Take a look to the conventions established by existing code:
* Every module should come with some reference to scientific literature that allows relating the code to research. Use the `//!` comments at the top of the module to tell readers about the basics of the procedure you are implementing.
* Every module should provide a Rust doctest, a brief test embedded with the documentation that explains how to use the procedure implemented.
* Every module should provide comprehensive tests at the end, in its `mod tests {}` sub-module. These tests can be flagged or not with configuration flags to allow WebAssembly target.
* Run `cargo doc --no-deps --open` and read the generated documentation in the browser to be sure that your changes reflects in the documentation and new code is documented.
#### digging deeper
* a nice overview of the codebase is given by [static analyzer](https://mozilla.github.io/rust-code-analysis/metrics.html):
```
$ cargo install rust-code-analysis-cli
// print metrics for every module
$ rust-code-analysis-cli -m -O json -o . -p src/ --pr
// print full AST for a module
$ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 -d > ast.txt
```
* find more information about what happens in your binary with [`twiggy`](https://rustwasm.github.io/twiggy/install.html). This need a compiled binary so create a brief `main {}` function using `smartcore` and then point `twiggy` to that file.
* Please take a look to the output of a profiler to spot most evident performance problems, see [this guide about using a profiler](http://www.codeofview.com/fix-rs/2017/01/24/how-to-optimize-rust-programs-on-linux/).
## Issue Report Process
1. Go to the project's issues.
2. Select the template that better fits your issue.
3. Read carefully the instructions and write within the template guidelines.
4. Submit it and wait for support.
## Reviewing process
1. After a PR is opened maintainers are notified
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
* **Testing**: multiple test pipelines are run for different targets
3. When everything is OK, code is merged.
## Contribution Best Practices
* Read this [how-to about Github workflow here](https://guides.github.com/introduction/flow/) if you are not familiar with.
* Read all the texts related to [contributing for an OS community](https://github.com/HTTP-APIs/hydrus/tree/master/.github).
* Read this [how-to about writing a PR](https://github.com/blog/1943-how-to-write-the-perfect-pull-request) and this [other how-to about writing a issue](https://wiredcraft.com/blog/how-we-write-our-github-issues/)
* **read history**: search past open or closed issues for your problem before opening a new issue.
* **PRs on develop**: any change should be PRed first in `development`
* **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment.
+43
View File
@@ -0,0 +1,43 @@
# smartcore: Introduction to modules
Important source of information:
* [Rust API guidelines](https://rust-lang.github.io/api-guidelines/about.html)
## Walkthrough: traits system and basic structures
#### numbers
The library is founded on basic traits provided by `num-traits`. Basic traits are in `src/numbers`. These traits are used to define all the procedures in the library to make everything safer and provide constraints to what implementations can handle.
#### linalg
`numbers` are made at use in linear algebra structures in the **`src/linalg/basic`** module. These sub-modules define the traits used all over the code base.
* *arrays*: In particular data structures like `Array`, `Array1` (1-dimensional), `Array2` (matrix, 2-D); plus their "views" traits. Views are used to provide no-footprint access to data, they have composed traits to allow writing (mutable traits: `MutArray`, `ArrayViewMut`, ...).
* *matrix*: This provides the main entrypoint to matrices operations and currently the only structure provided in the shape of `struct DenseMatrix`. A matrix can be instantiated and automatically make available all the traits in "arrays" (sparse matrices implementation will be provided).
* *vector*: Convenience traits are implemented for `std::Vec` to allow extensive reuse.
These are all traits and by definition they do not allow instantiation. For instantiable structures see implementation like `DenseMatrix` with relative constructor.
#### linalg/traits
The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. These traits are used to specify characteristics and constraints for types accepted by various algorithms. For example these allow to define if a matrix is `QRDecomposable` and/or `SVDDecomposable`. See docstring for referencese to theoretical framework.
As above these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation for Linear Regression requires the input data `X` to be in `smartcore`'s trait system `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`, a 2-D matrix that is both QR and SVD decomposable; that is what the provided strucure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {};impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
#### metrics
Implementations for metrics (classification, regression, cluster, ...) and distance measure (Euclidean, Hamming, Manhattan, ...). For example: `Accuracy`, `F1`, `AUC`, `Precision`, `R2`. As everything else in the code base, these implementations reuse `numbers` and `linalg` traits and structures.
These are collected in structures like `pub struct ClassificationMetrics<T> {}` that implements `metrics::Metrics`, these are groups of functions (classification, regression, cluster, ...) that provide instantiation for the structures. Each of those instantiation can be passed around using the relative function, like `pub fn accuracy<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> T`. This provides a mechanism for metrics to be passed to higher interfaces like the `cross_validate`:
```rust
let results =
cross_validate(
BiasedEstimator::new(), // custom estimator
&x, &y, // input data
NoParameters {}, // extra parameters
cv, // type of cross validator
&accuracy // **metrics function** <--------
).unwrap();
```
TODO: complete for all modules
## Notebooks
Proceed to the [**notebooks**](https://github.com/smartcorelib/smartcore-jupyter/) to see these modules in action.
+25
View File
@@ -0,0 +1,25 @@
### I'm submitting a
- [ ] bug report.
- [ ] improvement.
- [ ] feature request.
### Current Behaviour:
<!-- Describe about the bug -->
### Expected Behaviour:
<!-- Describe what will happen if bug is removed -->
### Steps to reproduce:
<!-- If you can then please provide the steps to reproduce the bug -->
### Snapshot:
<!-- If you can then please provide the screenshot of the issue you are facing -->
### Environment:
<!-- Please provide the following environment details if relevant -->
* rustc version
* cargo version
* OS details
### Do you want to work on this issue?
<!-- yes/no -->
+29
View File
@@ -0,0 +1,29 @@
<!-- Please create (if there is not one yet) a issue before sending a PR -->
<!-- Add issue number (Eg: fixes #123) -->
<!-- Always provide changes in existing tests or new tests -->
Fixes #
### Checklist
- [ ] My branch is up-to-date with development branch.
- [ ] Everything works and tested on latest stable Rust.
- [ ] Coverage and Linting have been applied
### Current behaviour
<!-- Describe the code you are going to change and its behaviour -->
### New expected behaviour
<!-- Describe the new code and its expected behaviour -->
### Change logs
<!-- #### Added -->
<!-- Edit these points below to describe the new features added with this PR -->
<!-- - Feature 1 -->
<!-- - Feature 2 -->
<!-- #### Changed -->
<!-- Edit these points below to describe the changes made in existing functionality with this PR -->
<!-- - Change 1 -->
<!-- - Change 1 -->
+46 -29
View File
@@ -2,56 +2,73 @@ name: CI
on:
push:
branches: [ main, development ]
branches: [main, development]
pull_request:
branches: [ development ]
branches: [development]
jobs:
tests:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform: [
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
platform:
[
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Cache .cargo and target
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
target: ${{ matrix.platform.target }}
profile: minimal
default: true
targets: ${{ matrix.platform.target }}
- name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build
uses: actions-rs/cargo@v1
with:
command: build
args: --all-features --target ${{ matrix.platform.target }}
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build with all features
run: cargo build --all-features --target ${{ matrix.platform.target }}
- name: Stable Build without features
run: cargo build --target ${{ matrix.platform.target }}
- name: Tests
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
run: cargo test --all-features
- name: Tests in WASM
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: wasm-pack test --node -- --all-features
check_features:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform: [{ os: "ubuntu" }]
features: ["--features serde", "--features datasets", ""]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v4
- name: Cache .cargo and target
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-features
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Stable Build
run: cargo build --no-default-features ${{ matrix.features }}
+9 -20
View File
@@ -12,33 +12,22 @@ jobs:
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Cache .cargo
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
default: true
uses: dtolnay/rust-toolchain@nightly
- name: Install cargo-tarpaulin
uses: actions-rs/install@v0.1
with:
crate: cargo-tarpaulin
version: latest
use-tool-cache: true
run: cargo install cargo-tarpaulin
- name: Run cargo-tarpaulin
uses: actions-rs/cargo@v1
with:
command: tarpaulin
args: --out Lcov --all-features -- --test-threads 1
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
- name: Upload to codecov.io
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v4
with:
fail_ci_if_error: true
fail_ci_if_error: false
+10 -19
View File
@@ -6,36 +6,27 @@ on:
pull_request:
branches: [ development ]
jobs:
lint:
runs-on: ubuntu-latest
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Cache .cargo and target
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
profile: minimal
default: true
- run: rustup component add rustfmt
- name: Check formt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- run: rustup component add clippy
components: rustfmt, clippy
- name: Check format
run: cargo fmt --all -- --check
- name: Run clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -Drust-2018-idioms -Dwarnings
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
+12
View File
@@ -17,3 +17,15 @@ smartcore.code-workspace
# OS
.DS_Store
flamegraph.svg
perf.data
perf.data.old
src.dot
out.svg
FlameGraph/
out.stacks
*.json
*.txt
+34 -1
View File
@@ -4,7 +4,40 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [0.4.8] - 2025-11-29
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
## [0.4.0] - 2023-04-05
## Added
- WARNING: Breaking changes!
- `DenseMatrix` constructor now returns `Result` to avoid user instantiating inconsistent rows/cols count. Their return values need to be unwrapped with `unwrap()`, see tests
## [0.3.0] - 2022-11-09
## Added
- WARNING: Breaking changes!
- Complete refactoring with **extensive API changes** that includes:
* moving to a new traits system, less structs more traits
* adapting all the modules to the new traits system
* moving to Rust 2021, use of object-safe traits and `as_ref`
* reorganization of the code base, eliminate duplicates
- implements `readers` (needs "serde" feature) for read/write CSV file, extendible to other formats
- default feature is now Wasm-/Wasi-first
## Changed
- WARNING: Breaking changes!
- Seeds to multiple algorithims that depend on random number generation
- Added a new parameter to `train_test_split` to define the seed
- changed use of "serde" feature
## Dropped
- WARNING: Breaking changes!
- Drop `nalgebra-bindings` feature, only `ndarray` as supported library
## [0.2.1] - 2021-05-10
## Added
- L2 regularization penalty to the Logistic Regression
+41
View File
@@ -0,0 +1,41 @@
cff-version: 1.2.0
message: "If this software contributes to published work, please cite smartcore."
type: software
title: "smartcore: Machine Learning in Rust"
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration." [web:5][web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
url: "https://github.com/smartcorelib" [web:3]
license: "MIT" [web:13]
keywords:
- Rust
- machine learning
- numerical computing
- linear algebra
- classification
- regression
- clustering
- SVM
- Random Forest
- XGBoost [web:5]
authors:
- name: "smartcore Developers" [web:7]
- name: "Lorenzo (contributor)" [web:16]
- name: "Community contributors" [web:7]
version: "0.4.2" [attached_file:1]
date-released: "2025-09-14" [attached_file:1]
preferred-citation:
type: software
title: "smartcore: Machine Learning in Rust"
authors:
- name: "smartcore Developers" [web:7]
url: "https://github.com/smartcorelib" [web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
license: "MIT" [web:13]
references:
- type: manual
title: "smartcore Documentation"
url: "https://docs.rs/smartcore" [web:5]
- type: webpage
title: "smartcore Homepage"
url: "https://github.com/smartcorelib" [web:3]
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost." [web:5]
+44 -26
View File
@@ -1,48 +1,66 @@
[package]
name = "smartcore"
description = "The most advanced machine learning library in rust."
description = "Machine Learning in Rust."
homepage = "https://smartcorelib.org"
version = "0.2.1"
authors = ["SmartCore Developers"]
edition = "2018"
version = "0.4.9"
authors = ["smartcore Developers"]
edition = "2021"
license = "Apache-2.0"
documentation = "https://docs.rs/smartcore"
repository = "https://github.com/smartcorelib/smartcore"
readme = "README.md"
keywords = ["machine-learning", "statistical", "ai", "optimization", "linear-algebra"]
categories = ["science"]
[features]
default = ["datasets"]
ndarray-bindings = ["ndarray"]
nalgebra-bindings = ["nalgebra"]
datasets = []
exclude = [
".github",
".gitignore",
"smartcore.iml",
"smartcore.svg",
"tests/"
]
[dependencies]
approx = "0.5.1"
cfg-if = "1.0.0"
ndarray = { version = "0.15", optional = true }
nalgebra = { version = "0.23.0", optional = true }
num-traits = "0.2.12"
num = "0.4.0"
rand = "0.8.3"
rand_distr = "0.4.0"
serde = { version = "1.0.115", features = ["derive"], optional = true }
num = "0.4"
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.5", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
ordered-float = "5.1.0"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }
[features]
default = []
serde = ["dep:serde", "dep:typetag"]
ndarray-bindings = ["dep:ndarray"]
datasets = ["dep:rand_distr", "std_rand", "serde"]
std_rand = ["rand/std_rng", "rand/std"]
# used by wasm32-unknown-unknown for in-browser usage
js = ["getrandom/js"]
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
getrandom = { version = "0.2.8", optional = true }
[target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies]
wasm-bindgen-test = "0.3"
[dev-dependencies]
criterion = "0.3"
itertools = "0.13.0"
serde_json = "1.0"
bincode = "1.3.1"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"
[workspace]
[[bench]]
name = "distance"
harness = false
[profile.test]
debug = 1
opt-level = 3
[[bench]]
name = "naive_bayes"
harness = false
required-features = ["ndarray-bindings", "nalgebra-bindings"]
[profile.release]
strip = true
lto = true
codegen-units = 1
overflow-checks = true
+1 -1
View File
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright 2019-present at smartcore developers (smartcorelib.org)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
+133 -4
View File
@@ -1,18 +1,147 @@
<p align="center">
<a href="https://smartcorelib.org">
<img src="smartcore.svg" width="450" alt="SmartCore">
<img src="smartcore.svg" width="450" alt="smartcore">
</a>
</p>
<p align = "center">
<strong>
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-examples">Examples</a>
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-jupyter">Notebooks</a>
</strong>
</p>
-----
<p align = "center">
<b>The Most Advanced Machine Learning Library In Rust.</b>
<b>Machine Learning in Rust</b>
</p>
-----
-----
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17219259.svg)](https://doi.org/10.5281/zenodo.17219259)
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
## Highlights
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
## Install
Add to Cargo.toml:
```toml
[dependencies]
smartcore = "^0.4.3"
```
For the latest development branch:
```toml
[dependencies]
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
```
Optional features (examples):
- datasets
- serde
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
Check Cargo.toml for available features and compatibility notes.
## Quick start
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
```rust
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::neighbors::knn_classifier::KNNClassifier;
// Turn vector slices into a matrix
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
]).unwrap;
// Class labels
let y = vec![2, 2, 2, 3, 3];
// Train classifier
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
// Predict
let yhat = knn.predict(&x).unwrap();
```
This example mirrors the “First Example” section of the crate docs and demonstrates smartcores ergonomic API surface.
## Algorithms
smartcore organizes algorithms into clear modules with consistent traits:
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
- Preprocessing: encoders, split utilities, and common transforms.
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
## Data access and readers
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
## WebAssembly and portability
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
## Notebooks
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
## Roadmap and recent changes
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
- Move to Rust 2021 edition and cleanup of duplicate code paths.
- Seeds and deterministic controls across algorithms using RNG plumbing.
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
- Tree and forest components refactored for reuse; Extra Trees added.
- SVM multiclass support; SVR kernel enum and related improvements.
- XGBoost-style regression introduced; single-linkage clustering implemented.
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
## Contributing
Contributions are welcome:
- Open an issue describing the change and link it in the PR.
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
## License
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
## Acknowledgments
smartcores design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
-18
View File
@@ -1,18 +0,0 @@
#[macro_use]
extern crate criterion;
extern crate smartcore;
use criterion::black_box;
use criterion::Criterion;
use smartcore::math::distance::*;
fn criterion_benchmark(c: &mut Criterion) {
let a = vec![1., 2., 3.];
c.bench_function("Euclidean Distance", move |b| {
b.iter(|| Distances::euclidian().distance(black_box(&a), black_box(&a)))
});
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
-73
View File
@@ -1,73 +0,0 @@
use criterion::BenchmarkId;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use nalgebra::DMatrix;
use ndarray::Array2;
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linalg::BaseMatrix;
use smartcore::linalg::BaseVector;
use smartcore::naive_bayes::gaussian::GaussianNB;
pub fn gaussian_naive_bayes_fit_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("GaussianNB::fit");
for n_samples in [100_usize, 1000_usize, 10000_usize].iter() {
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
let x = DenseMatrix::<f64>::rand(*n_samples, *n_features);
let y: Vec<f64> = (0..*n_samples)
.map(|i| (i % *n_samples / 5_usize) as f64)
.collect::<Vec<f64>>();
group.bench_with_input(
BenchmarkId::from_parameter(format!(
"n_samples: {}, n_features: {}",
n_samples, n_features
)),
n_samples,
|b, _| {
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
},
);
}
}
group.finish();
}
pub fn gaussian_naive_matrix_datastructure(c: &mut Criterion) {
let mut group = c.benchmark_group("GaussianNB");
let classes = (0..10000).map(|i| (i % 25) as f64).collect::<Vec<f64>>();
group.bench_function("DenseMatrix", |b| {
let x = DenseMatrix::<f64>::rand(10000, 500);
let y = <DenseMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
group.bench_function("ndarray", |b| {
let x = Array2::<f64>::rand(10000, 500);
let y = <Array2<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
group.bench_function("ndalgebra", |b| {
let x = DMatrix::<f64>::rand(10000, 500);
let y = <DMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
}
criterion_group!(
benches,
gaussian_naive_bayes_fit_benchmark,
gaussian_naive_matrix_datastructure
);
criterion_main!(benches);
-15
View File
@@ -1,15 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="RUST_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/benches" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+1 -1
View File
@@ -76,5 +76,5 @@
y="81.876823"
x="91.861809"
id="tspan842"
sodipodi:role="line">SmartCore</tspan></text>
sodipodi:role="line">smartcore</tspan></text>
</svg>

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

+66 -60
View File
@@ -1,50 +1,50 @@
use std::fmt::Debug;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
#[derive(Debug)]
pub struct BBDTree<T: RealNumber> {
nodes: Vec<BBDTreeNode<T>>,
pub struct BBDTree {
nodes: Vec<BBDTreeNode>,
index: Vec<usize>,
root: usize,
}
#[derive(Debug)]
struct BBDTreeNode<T: RealNumber> {
struct BBDTreeNode {
count: usize,
index: usize,
center: Vec<T>,
radius: Vec<T>,
sum: Vec<T>,
cost: T,
center: Vec<f64>,
radius: Vec<f64>,
sum: Vec<f64>,
cost: f64,
lower: Option<usize>,
upper: Option<usize>,
}
impl<T: RealNumber> BBDTreeNode<T> {
fn new(d: usize) -> BBDTreeNode<T> {
impl BBDTreeNode {
fn new(d: usize) -> BBDTreeNode {
BBDTreeNode {
count: 0,
index: 0,
center: vec![T::zero(); d],
radius: vec![T::zero(); d],
sum: vec![T::zero(); d],
cost: T::zero(),
center: vec![0f64; d],
radius: vec![0f64; d],
sum: vec![0f64; d],
cost: 0f64,
lower: Option::None,
upper: Option::None,
}
}
}
impl<T: RealNumber> BBDTree<T> {
pub fn new<M: Matrix<T>>(data: &M) -> BBDTree<T> {
let nodes = Vec::new();
impl BBDTree {
pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
let nodes: Vec<BBDTreeNode> = Vec::new();
let (n, _) = data.shape();
let index = (0..n).collect::<Vec<_>>();
let index = (0..n).collect::<Vec<usize>>();
let mut tree = BBDTree {
nodes,
@@ -59,20 +59,20 @@ impl<T: RealNumber> BBDTree<T> {
tree
}
pub(in crate) fn clustering(
pub(crate) fn clustering(
&self,
centroids: &[Vec<T>],
sums: &mut Vec<Vec<T>>,
centroids: &[Vec<f64>],
sums: &mut Vec<Vec<f64>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> T {
) -> f64 {
let k = centroids.len();
counts.iter_mut().for_each(|v| *v = 0);
let mut candidates = vec![0; k];
for i in 0..k {
candidates[i] = i;
sums[i].iter_mut().for_each(|v| *v = T::zero());
sums[i].iter_mut().for_each(|v| *v = 0f64);
}
self.filter(
@@ -89,13 +89,13 @@ impl<T: RealNumber> BBDTree<T> {
fn filter(
&self,
node: usize,
centroids: &[Vec<T>],
centroids: &[Vec<f64>],
candidates: &[usize],
k: usize,
sums: &mut Vec<Vec<T>>,
sums: &mut Vec<Vec<f64>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> T {
) -> f64 {
let d = centroids[0].len();
let mut min_dist =
@@ -163,9 +163,9 @@ impl<T: RealNumber> BBDTree<T> {
}
fn prune(
center: &[T],
radius: &[T],
centroids: &[Vec<T>],
center: &[f64],
radius: &[f64],
centroids: &[Vec<f64>],
best_index: usize,
test_index: usize,
) -> bool {
@@ -177,22 +177,22 @@ impl<T: RealNumber> BBDTree<T> {
let best = &centroids[best_index];
let test = &centroids[test_index];
let mut lhs = T::zero();
let mut rhs = T::zero();
let mut lhs = 0f64;
let mut rhs = 0f64;
for i in 0..d {
let diff = test[i] - best[i];
lhs += diff * diff;
if diff > T::zero() {
if diff > 0f64 {
rhs += (center[i] + radius[i] - best[i]) * diff;
} else {
rhs += (center[i] - radius[i] - best[i]) * diff;
}
}
lhs >= T::two() * rhs
lhs >= 2f64 * rhs
}
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
let (_, d) = data.shape();
let mut node = BBDTreeNode::new(d);
@@ -200,17 +200,17 @@ impl<T: RealNumber> BBDTree<T> {
node.count = end - begin;
node.index = begin;
let mut lower_bound = vec![T::zero(); d];
let mut upper_bound = vec![T::zero(); d];
let mut lower_bound = vec![0f64; d];
let mut upper_bound = vec![0f64; d];
for i in 0..d {
lower_bound[i] = data.get(self.index[begin], i);
upper_bound[i] = data.get(self.index[begin], i);
lower_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
upper_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
}
for i in begin..end {
for j in 0..d {
let c = data.get(self.index[i], j);
let c = data.get((self.index[i], j)).to_f64().unwrap();
if lower_bound[j] > c {
lower_bound[j] = c;
}
@@ -220,32 +220,32 @@ impl<T: RealNumber> BBDTree<T> {
}
}
let mut max_radius = T::from(-1.).unwrap();
let mut max_radius = -1f64;
let mut split_index = 0;
for i in 0..d {
node.center[i] = (lower_bound[i] + upper_bound[i]) / T::two();
node.radius[i] = (upper_bound[i] - lower_bound[i]) / T::two();
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2f64;
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2f64;
if node.radius[i] > max_radius {
max_radius = node.radius[i];
split_index = i;
}
}
if max_radius < T::from(1E-10).unwrap() {
if max_radius < 1E-10 {
node.lower = Option::None;
node.upper = Option::None;
for i in 0..d {
node.sum[i] = data.get(self.index[begin], i);
node.sum[i] = data.get((self.index[begin], i)).to_f64().unwrap();
}
if end > begin + 1 {
let len = end - begin;
for i in 0..d {
node.sum[i] *= T::from(len).unwrap();
node.sum[i] *= len as f64;
}
}
node.cost = T::zero();
node.cost = 0f64;
return self.add_node(node);
}
@@ -254,8 +254,10 @@ impl<T: RealNumber> BBDTree<T> {
let mut i2 = end - 1;
let mut size = 0;
while i1 <= i2 {
let mut i1_good = data.get(self.index[i1], split_index) < split_cutoff;
let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff;
let mut i1_good =
data.get((self.index[i1], split_index)).to_f64().unwrap() < split_cutoff;
let mut i2_good =
data.get((self.index[i2], split_index)).to_f64().unwrap() >= split_cutoff;
if !i1_good && !i2_good {
self.index.swap(i1, i2);
@@ -281,9 +283,9 @@ impl<T: RealNumber> BBDTree<T> {
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
}
let mut mean = vec![T::zero(); d];
let mut mean = vec![0f64; d];
for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
*mean_i = node.sum[i] / T::from(node.count).unwrap();
*mean_i = node.sum[i] / node.count as f64;
}
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
@@ -292,17 +294,17 @@ impl<T: RealNumber> BBDTree<T> {
self.add_node(node)
}
fn node_cost(node: &BBDTreeNode<T>, center: &[T]) -> T {
fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
let d = center.len();
let mut scatter = T::zero();
let mut scatter = 0f64;
for (i, center_i) in center.iter().enumerate().take(d) {
let x = (node.sum[i] / T::from(node.count).unwrap()) - *center_i;
let x = (node.sum[i] / node.count as f64) - *center_i;
scatter += x * x;
}
node.cost + T::from(node.count).unwrap() * scatter
node.cost + node.count as f64 * scatter
}
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize {
fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
let idx = self.nodes.len();
self.nodes.push(new_node);
idx
@@ -312,9 +314,12 @@ impl<T: RealNumber> BBDTree<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn bbdtree_iris() {
let data = DenseMatrix::from_2d_array(&[
@@ -338,7 +343,8 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
])
.unwrap();
let tree = BBDTree::new(&data);
File diff suppressed because it is too large Load Diff
+83 -75
View File
@@ -4,12 +4,12 @@
//!
//! ```
//! use smartcore::algorithm::neighbour::cover_tree::*;
//! use smartcore::math::distance::Distance;
//! use smartcore::metrics::distance::Distance;
//!
//! #[derive(Clone)]
//! struct SimpleDistance {} // Our distance function
//!
//! impl Distance<i32, f64> for SimpleDistance {
//! impl Distance<i32> for SimpleDistance {
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
//! (a - b).abs() as f64
//! }
@@ -29,28 +29,27 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
use crate::metrics::distance::Distance;
/// Implements Cover Tree algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
base: F,
inv_log_base: F,
pub struct CoverTree<T, D: Distance<T>> {
base: f64,
inv_log_base: f64,
distance: D,
root: Node<F>,
root: Node,
data: Vec<T>,
identical_excluded: bool,
}
impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
impl<T, D: Distance<T>> PartialEq for CoverTree<T, D> {
fn eq(&self, other: &Self) -> bool {
if self.data.len() != other.data.len() {
return false;
}
for i in 0..self.data.len() {
if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() {
if self.distance.distance(&self.data[i], &other.data[i]) != 0f64 {
return false;
}
}
@@ -60,36 +59,36 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<F: RealNumber> {
struct Node {
idx: usize,
max_dist: F,
parent_dist: F,
children: Vec<Node<F>>,
max_dist: f64,
parent_dist: f64,
children: Vec<Node>,
_scale: i64,
}
#[derive(Debug)]
struct DistanceSet<F: RealNumber> {
struct DistanceSet {
idx: usize,
dist: Vec<F>,
dist: Vec<f64>,
}
impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
/// Construct a cover tree.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
let base = F::from_f64(1.3).unwrap();
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, D>, Failed> {
let base = 1.3f64;
let root = Node {
idx: 0,
max_dist: F::zero(),
parent_dist: F::zero(),
max_dist: 0f64,
parent_dist: 0f64,
children: Vec::new(),
_scale: 0,
};
let mut tree = CoverTree {
base,
inv_log_base: F::one() / base.ln(),
inv_log_base: 1f64 / base.ln(),
distance,
root,
data,
@@ -104,7 +103,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
/// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
if k == 0 {
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
}
@@ -119,13 +118,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(e, p);
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
current_cover_set.push((d, &self.root));
let mut heap = HeapSelection::with_capacity(k);
heap.add(F::max_value());
heap.add(f64::MAX);
let mut empty_heap = true;
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
@@ -134,7 +133,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
}
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
@@ -146,7 +145,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
}
let upper_bound = if empty_heap {
F::infinity()
f64::INFINITY
} else {
*heap.peek()
};
@@ -169,7 +168,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
current_cover_set = next_cover_set;
}
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
let upper_bound = *heap.peek();
for ds in zero_set {
if ds.0 <= upper_bound {
@@ -189,25 +188,25 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
pub fn find_radius(&self, p: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
if radius <= 0f64 {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(e, p);
current_cover_set.push((d, &self.root));
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
@@ -240,23 +239,23 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
Ok(neighbors)
}
fn new_leaf(&self, idx: usize) -> Node<F> {
fn new_leaf(&self, idx: usize) -> Node {
Node {
idx,
max_dist: F::zero(),
parent_dist: F::zero(),
max_dist: 0f64,
parent_dist: 0f64,
children: Vec::new(),
_scale: 100,
}
}
fn build_cover_tree(&mut self) {
let mut point_set: Vec<DistanceSet<F>> = Vec::new();
let mut consumed_set: Vec<DistanceSet<F>> = Vec::new();
let mut point_set: Vec<DistanceSet> = Vec::new();
let mut consumed_set: Vec<DistanceSet> = Vec::new();
let point = &self.data[0];
let idx = 0;
let mut max_dist = -F::one();
let mut max_dist = -1f64;
for i in 1..self.data.len() {
let dist = self.distance.distance(point, &self.data[i]);
@@ -284,16 +283,16 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
p: usize,
max_scale: i64,
top_scale: i64,
point_set: &mut Vec<DistanceSet<F>>,
consumed_set: &mut Vec<DistanceSet<F>>,
) -> Node<F> {
point_set: &mut Vec<DistanceSet>,
consumed_set: &mut Vec<DistanceSet>,
) -> Node {
if point_set.is_empty() {
self.new_leaf(p)
} else {
let max_dist = self.max(point_set);
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
if next_scale == std::i64::MIN {
let mut children: Vec<Node<F>> = Vec::new();
if next_scale == i64::MIN {
let mut children: Vec<Node> = Vec::new();
let mut leaf = self.new_leaf(p);
children.push(leaf);
while !point_set.is_empty() {
@@ -304,13 +303,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
}
Node {
idx: p,
max_dist: F::zero(),
parent_dist: F::zero(),
max_dist: 0f64,
parent_dist: 0f64,
children,
_scale: 100,
}
} else {
let mut far: Vec<DistanceSet<F>> = Vec::new();
let mut far: Vec<DistanceSet> = Vec::new();
self.split(point_set, &mut far, max_scale);
let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set);
@@ -319,14 +318,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
point_set.append(&mut far);
child
} else {
let mut children: Vec<Node<F>> = vec![child];
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
let mut children: Vec<Node> = vec![child];
let mut new_point_set: Vec<DistanceSet> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet> = Vec::new();
while !point_set.is_empty() {
let set: DistanceSet<F> = point_set.remove(point_set.len() - 1);
let set: DistanceSet = point_set.remove(point_set.len() - 1);
let new_dist: F = set.dist[set.dist.len() - 1];
let new_dist = set.dist[set.dist.len() - 1];
self.dist_split(
point_set,
@@ -374,7 +373,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
Node {
idx: p,
max_dist: self.max(consumed_set),
parent_dist: F::zero(),
parent_dist: 0f64,
children,
_scale: (top_scale - max_scale),
}
@@ -385,12 +384,12 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
fn split(
&self,
point_set: &mut Vec<DistanceSet<F>>,
far_set: &mut Vec<DistanceSet<F>>,
point_set: &mut Vec<DistanceSet>,
far_set: &mut Vec<DistanceSet>,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_set: Vec<DistanceSet> = Vec::new();
for n in point_set.drain(0..) {
if n.dist[n.dist.len() - 1] <= fmax {
new_set.push(n);
@@ -404,13 +403,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
fn dist_split(
&self,
point_set: &mut Vec<DistanceSet<F>>,
new_point_set: &mut Vec<DistanceSet<F>>,
point_set: &mut Vec<DistanceSet>,
new_point_set: &mut Vec<DistanceSet>,
new_point: &T,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_set: Vec<DistanceSet> = Vec::new();
for mut n in point_set.drain(0..) {
let new_dist = self
.distance
@@ -426,24 +425,24 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
point_set.append(&mut new_set);
}
fn get_cover_radius(&self, s: i64) -> F {
self.base.powf(F::from_i64(s).unwrap())
fn get_cover_radius(&self, s: i64) -> f64 {
self.base.powf(s as f64)
}
fn get_data_value(&self, idx: usize) -> &T {
&self.data[idx]
}
fn get_scale(&self, d: F) -> i64 {
if d == F::zero() {
std::i64::MIN
fn get_scale(&self, d: f64) -> i64 {
if d == 0f64 {
i64::MIN
} else {
(self.inv_log_base * d.ln()).ceil().to_i64().unwrap()
(self.inv_log_base * d.ln()).ceil() as i64
}
}
fn max(&self, distance_set: &[DistanceSet<F>]) -> F {
let mut max = F::zero();
fn max(&self, distance_set: &[DistanceSet]) -> f64 {
let mut max = 0f64;
for n in distance_set {
if max < n.dist[n.dist.len() - 1] {
max = n.dist[n.dist.len() - 1];
@@ -457,19 +456,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
mod tests {
use super::*;
use crate::math::distance::Distances;
use crate::metrics::distance::Distances;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance {
impl Distance<i32> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
@@ -486,7 +488,10 @@ mod tests {
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cover_tree_test1() {
let data = vec![
@@ -505,7 +510,10 @@ mod tests {
assert_eq!(vec!(0, 1, 2), knn);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -513,7 +521,7 @@ mod tests {
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
let deserialized_tree: CoverTree<i32, SimpleDistance> =
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree);
+705
View File
@@ -0,0 +1,705 @@
///
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
///
/// Reference:
/// Eppstein, David: Fast hierarchical clustering and other applications of
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
///
/// Example:
/// ```
/// use smartcore::metrics::distance::PairwiseDistance;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
/// let x = DenseMatrix::<f64>::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
/// &[4.9, 3.0, 1.4, 0.2],
/// &[4.7, 3.2, 1.3, 0.2],
/// &[4.6, 3.1, 1.5, 0.2],
/// &[5.0, 3.6, 1.4, 0.2],
/// &[5.4, 3.9, 1.7, 0.4],
/// ]).unwrap();
/// let fastpair = FastPair::new(&x);
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
/// ```
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap;
use num::Bounded;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
///
/// Inspired by Python implementation:
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
///
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
///
#[derive(Debug, Clone)]
pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
/// initial matrix
samples: &'a M,
/// closest pair hashmap (connectivity matrix for closest pairs)
pub distances: HashMap<usize, PairwiseDistance<T>>,
/// conga line used to keep track of the closest pair
pub neighbours: Vec<usize>,
}
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
/// Constructor
/// Instantiate and initialize the algorithm
pub fn new(m: &'a M) -> Result<Self, Failed> {
if m.shape().0 < 3 {
return Err(Failed::because(
FailedError::FindFailed,
"min number of rows should be 3",
));
}
let mut init = Self {
samples: m,
// to be computed in init(..)
distances: HashMap::with_capacity(m.shape().0),
neighbours: Vec::with_capacity(m.shape().0 + 1),
};
init.init();
Ok(init)
}
/// Initialise `FastPair` by passing a `Array2`.
/// Build a FastPairs data-structure from a set of (new) points.
fn init(&mut self) {
// basic measures
let len = self.samples.shape().0;
let max_index = self.samples.shape().0 - 1;
// Store all closest neighbors
let _distances = Box::new(HashMap::with_capacity(len));
let _neighbours = Box::new(Vec::with_capacity(len));
let mut distances = *_distances;
let mut neighbours = *_neighbours;
// fill neighbours with -1 values
neighbours.extend(0..len);
// init closest neighbour pairwise data
for index_row_i in 0..(max_index) {
distances.insert(
index_row_i,
PairwiseDistance {
node: index_row_i,
neighbour: Option::None,
distance: Some(<T as Bounded>::max_value()),
},
);
}
// loop through indeces and neighbours
for index_row_i in 0..(len) {
// start looking for the neighbour in the second element
let mut index_closest = index_row_i + 1; // closest neighbour index
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
for index_row_j in (index_row_i + 1)..len {
distances.insert(
index_row_j,
PairwiseDistance {
node: index_row_j,
neighbour: Some(index_row_i),
distance: nbd,
},
);
let d = Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
);
if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour
index_closest = index_row_j;
nbd = Some(T::from(d).unwrap());
}
}
// Add that edge
distances.entry(index_row_i).and_modify(|e| {
e.distance = nbd;
e.neighbour = Some(index_closest);
});
}
// No more neighbors, terminate conga line.
// Last person on the line has no neigbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len);
for (_, p) in distances.iter() {
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
}
self.distances = distances;
self.neighbours = neighbours;
}
/// Find closest pair by scanning list of nearest neighbors.
#[allow(dead_code)]
pub fn closest_pair(&self) -> PairwiseDistance<T> {
let mut a = self.neighbours[0]; // Start with first point
let mut d = self.distances[&a].distance;
for p in self.neighbours.iter() {
if self.distances[p].distance < d {
a = *p; // Update `a` and distance `d`
d = self.distances[p].distance;
}
}
let b = self.distances[&a].neighbour;
PairwiseDistance {
node: a,
neighbour: b,
distance: d,
}
}
///
/// Return order dissimilarities from closest to furthest
///
#[allow(dead_code)]
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
let mut distances = self
.distances
.values()
.collect::<Vec<&PairwiseDistance<T>>>();
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
distances.into_iter()
}
//
// Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix
//
#[allow(dead_code)]
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
for other in self.neighbours.iter() {
if index_row != *other {
distances.push(PairwiseDistance {
node: index_row,
neighbour: Some(*other),
distance: Some(
T::from(Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(*other).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap(),
),
})
}
}
distances
}
}
#[cfg(test)]
mod tests_fastpair {
use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
/// Brute force algorithm, used only for comparison and testing
pub fn closest_pair_brute(
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
) -> PairwiseDistance<f64> {
use itertools::Itertools;
let m = fastpair.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&Vec::from_iterator(
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
fastpair.samples.shape().1,
),
&Vec::from_iterator(
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
fastpair.samples.shape().1,
),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
#[test]
fn fastpair_init() {
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
let _fastpair = FastPair::new(&x);
assert!(_fastpair.is_ok());
let fastpair = _fastpair.unwrap();
let distances = fastpair.distances;
let neighbours = fastpair.neighbours;
assert!(!distances.is_empty());
assert!(!neighbours.is_empty());
assert_eq!(10, neighbours.len());
assert_eq!(10, distances.len());
}
#[test]
fn dataset_has_at_least_three_points() {
// Create a dataset which consists of only two points:
// A(0.0, 0.0) and B(1.0, 1.0).
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();
// We expect an error when we run `FastPair` on this dataset,
// becuase `FastPair` currently only works on a minimum of 3
// points.
let fastpair = FastPair::new(&dataset);
assert!(fastpair.is_err());
if let Err(e) = fastpair {
let expected_error =
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
assert_eq!(e, expected_error)
}
}
#[test]
fn one_dimensional_dataset_minimal() {
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();
let result = FastPair::new(&dataset);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
let expected_closest_pair = PairwiseDistance {
node: 0,
neighbour: Some(1),
distance: Some(4.0),
};
assert_eq!(closest_pair, expected_closest_pair);
let closest_pair_brute = closest_pair_brute(&fastpair);
assert_eq!(closest_pair_brute, expected_closest_pair);
}
#[test]
fn one_dimensional_dataset_2() {
let dataset =
DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();
let result = FastPair::new(&dataset);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
let expected_closest_pair = PairwiseDistance {
node: 1,
neighbour: Some(3),
distance: Some(4.0),
};
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
assert_eq!(closest_pair, expected_closest_pair);
}
#[test]
fn fastpair_new() {
// compute
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
// unwrap results
let result = fastpair.unwrap();
// list of minimal pairwise dissimilarities
let dissimilarities = vec![
(
1,
PairwiseDistance {
node: 1,
neighbour: Some(9),
distance: Some(0.030000000000000037),
},
),
(
10,
PairwiseDistance {
node: 10,
neighbour: Some(12),
distance: Some(0.07000000000000003),
},
),
(
11,
PairwiseDistance {
node: 11,
neighbour: Some(14),
distance: Some(0.18000000000000013),
},
),
(
12,
PairwiseDistance {
node: 12,
neighbour: Some(14),
distance: Some(0.34000000000000086),
},
),
(
13,
PairwiseDistance {
node: 13,
neighbour: Some(14),
distance: Some(1.6499999999999997),
},
),
(
14,
PairwiseDistance {
node: 14,
neighbour: Some(14),
distance: Some(f64::MAX),
},
),
(
6,
PairwiseDistance {
node: 6,
neighbour: Some(7),
distance: Some(0.18000000000000027),
},
),
(
0,
PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
},
),
(
8,
PairwiseDistance {
node: 8,
neighbour: Some(9),
distance: Some(0.3100000000000001),
},
),
(
2,
PairwiseDistance {
node: 2,
neighbour: Some(3),
distance: Some(0.0600000000000001),
},
),
(
3,
PairwiseDistance {
node: 3,
neighbour: Some(8),
distance: Some(0.08999999999999982),
},
),
(
7,
PairwiseDistance {
node: 7,
neighbour: Some(9),
distance: Some(0.10999999999999982),
},
),
(
9,
PairwiseDistance {
node: 9,
neighbour: Some(13),
distance: Some(8.69),
},
),
(
4,
PairwiseDistance {
node: 4,
neighbour: Some(7),
distance: Some(0.050000000000000086),
},
),
(
5,
PairwiseDistance {
node: 5,
neighbour: Some(7),
distance: Some(0.4900000000000002),
},
),
];
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
for i in 0..(x.shape().0 - 1) {
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
let distance = Euclidian::squared_distance(
&Vec::from_iterator(
result.samples.get_row(i).iterator(0).copied(),
result.samples.shape().1,
),
&Vec::from_iterator(
result.samples.get_row(input_neighbour).iterator(0).copied(),
result.samples.shape().1,
),
);
assert_eq!(i, expected.get(&i).unwrap().node);
assert_eq!(
input_neighbour,
expected.get(&i).unwrap().neighbour.unwrap()
);
assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap());
}
}
#[test]
fn fastpair_closest_pair() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let dissimilarity = fastpair.unwrap().closest_pair();
let closest = PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
};
assert_eq!(closest, dissimilarity);
}
#[test]
fn fastpair_closest_pair_random_matrix() {
let x = DenseMatrix::<f64>::rand(200, 25);
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let result = fastpair.unwrap();
let dissimilarity1 = result.closest_pair();
let dissimilarity2 = closest_pair_brute(&result);
assert_eq!(dissimilarity1, dissimilarity2);
}
#[test]
fn fastpair_distances() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let dissimilarities = fastpair.unwrap().distances_from(0);
let mut min_dissimilarity = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::MAX),
};
for p in dissimilarities.iter() {
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
min_dissimilarity = *p
}
}
let closest = PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
};
assert_eq!(closest, min_dissimilarity);
}
#[test]
fn fastpair_ordered_pairs() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
])
.unwrap();
let fastpair = FastPair::new(&x).unwrap();
let ordered = fastpair.ordered_pairs();
let mut previous: f64 = -1.0;
for p in ordered {
if previous == -1.0 {
previous = p.distance.unwrap();
} else {
let current = p.distance.unwrap();
assert!(current >= previous);
previous = current;
}
}
}
#[test]
fn test_empty_set() {
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
let result = FastPair::new(&empty_matrix);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_single_point() {
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
let result = FastPair::new(&single_point);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_two_points() {
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = FastPair::new(&two_points);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_three_identical_points() {
let identical_points =
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
let result = FastPair::new(&identical_points);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
assert_eq!(closest_pair.distance, Some(0.0));
}
#[test]
fn test_result_unwrapping() {
let valid_matrix =
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
.unwrap();
let result = FastPair::new(&valid_matrix);
assert!(result.is_ok());
// This should not panic
let _fastpair = result.unwrap();
}
}
+29 -30
View File
@@ -3,12 +3,12 @@
//! see [KNN algorithms](../index.html)
//! ```
//! use smartcore::algorithm::neighbour::linear_search::*;
//! use smartcore::math::distance::Distance;
//! use smartcore::metrics::distance::Distance;
//!
//! #[derive(Clone)]
//! struct SimpleDistance {} // Our distance function
//!
//! impl Distance<i32, f64> for SimpleDistance {
//! impl Distance<i32> for SimpleDistance {
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
//! (a - b).abs() as f64
//! }
@@ -25,38 +25,31 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
use crate::metrics::distance::Distance;
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
pub struct LinearKNNSearch<T, D: Distance<T>> {
distance: D,
data: Vec<T>,
f: PhantomData<F>,
}
impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
/// Initializes algorithm.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
Ok(LinearKNNSearch {
data,
distance,
f: PhantomData,
})
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, D>, Failed> {
Ok(LinearKNNSearch { data, distance })
}
/// Find k nearest neighbors
/// * `from` - look for k nearest points to `from`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
if k < 1 || k > self.data.len() {
return Err(Failed::because(
FailedError::FindFailed,
@@ -64,11 +57,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
));
}
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
let mut heap = HeapSelection::<KNNPoint>::with_capacity(k);
for _ in 0..k {
heap.add(KNNPoint {
distance: F::infinity(),
distance: f64::INFINITY,
index: None,
});
}
@@ -93,15 +86,15 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
pub fn find_radius(&self, from: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
if radius <= 0f64 {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
for i in 0..self.data.len() {
let d = self.distance.distance(from, &self.data[i]);
@@ -116,41 +109,44 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
}
#[derive(Debug)]
struct KNNPoint<F: RealNumber> {
distance: F,
struct KNNPoint {
distance: f64,
index: Option<usize>,
}
impl<F: RealNumber> PartialOrd for KNNPoint<F> {
impl PartialOrd for KNNPoint {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
impl<F: RealNumber> PartialEq for KNNPoint<F> {
impl PartialEq for KNNPoint {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl<F: RealNumber> Eq for KNNPoint<F> {}
impl Eq for KNNPoint {}
#[cfg(test)]
mod tests {
use super::*;
use crate::math::distance::Distances;
use crate::metrics::distance::Distances;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance {
impl Distance<i32> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn knn_find() {
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
@@ -197,7 +193,10 @@ mod tests {
assert_eq!(vec!(1, 2, 3), found_idxs2);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn knn_point_eq() {
let point1 = KNNPoint {
@@ -216,7 +215,7 @@ mod tests {
};
let point_inf = KNNPoint {
distance: std::f64::INFINITY,
distance: f64::INFINITY,
index: Some(3),
};
+18 -12
View File
@@ -1,4 +1,4 @@
#![allow(clippy::ptr_arg)]
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! # Nearest Neighbors Search Algorithms and Data Structures
//!
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
@@ -33,37 +33,43 @@
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
use crate::metrics::distance::Distance;
use crate::numbers::basenum::Number;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree;
/// a variant of fastpair using cosine distance
pub mod cosinepair;
/// tree data structure for fast nearest neighbor search
pub mod cover_tree;
/// fastpair closest neighbour algorithm
pub mod fastpair;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search;
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
#[default]
CoverTree,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>),
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
LinearSearch(LinearKNNSearch<Vec<T>, D>),
CoverTree(CoverTree<Vec<T>, D>),
}
// TODO: missing documentation
impl KNNAlgorithmName {
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
pub(crate) fn fit<T: Number, D: Distance<Vec<T>>>(
&self,
data: Vec<Vec<T>>,
distance: D,
@@ -79,8 +85,8 @@ impl KNNAlgorithmName {
}
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
impl<T: Number, D: Distance<Vec<T>>> KNNAlgorithm<T, D> {
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
@@ -90,8 +96,8 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
pub fn find_radius(
&self,
from: &Vec<T>,
radius: T,
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
radius: f64,
) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
+23 -8
View File
@@ -12,7 +12,7 @@ pub struct HeapSelection<T: PartialOrd + Debug> {
heap: Vec<T>,
}
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
impl<T: PartialOrd + Debug> HeapSelection<T> {
pub fn with_capacity(k: usize) -> HeapSelection<T> {
HeapSelection {
k,
@@ -95,14 +95,20 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn with_capacity() {
let heap = HeapSelection::<i32>::with_capacity(3);
assert_eq!(3, heap.k);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_add() {
let mut heap = HeapSelection::with_capacity(3);
@@ -120,11 +126,14 @@ mod tests {
assert_eq!(vec![2, 0, -5], heap.get());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_add1() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(std::f64::INFINITY);
heap.add(f64::INFINITY);
heap.add(-5f64);
heap.add(4f64);
heap.add(-1f64);
@@ -135,11 +144,14 @@ mod tests {
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_add2() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(std::f64::INFINITY);
heap.add(f64::INFINITY);
heap.add(0.0);
heap.add(8.4852);
heap.add(5.6568);
@@ -148,7 +160,10 @@ mod tests {
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_add_ordered() {
let mut heap = HeapSelection::with_capacity(3);
+8 -3
View File
@@ -1,12 +1,14 @@
use num_traits::Float;
use num_traits::Num;
pub trait QuickArgSort {
#[allow(dead_code)]
fn quick_argsort_mut(&mut self) -> Vec<usize>;
#[allow(dead_code)]
fn quick_argsort(&self) -> Vec<usize>;
}
impl<T: Float> QuickArgSort for Vec<T> {
impl<T: Num + PartialOrd + Copy> QuickArgSort for Vec<T> {
fn quick_argsort(&self) -> Vec<usize> {
let mut v = self.clone();
v.quick_argsort_mut()
@@ -113,7 +115,10 @@ impl<T: Float> QuickArgSort for Vec<T> {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn with_capacity() {
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
+34 -2
View File
@@ -16,8 +16,12 @@ pub trait UnsupervisedEstimator<X, P> {
P: Clone;
}
/// An estimator for supervised learning, , that provides method `fit` to learn from data and training values
pub trait SupervisedEstimator<X, Y, P> {
/// An estimator for supervised learning, that provides method `fit` to learn from data and training values
pub trait SupervisedEstimator<X, Y, P>: Predictor<X, Y> {
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
/// used to pass around the correct `fit()` implementation.
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
fn new() -> Self;
/// Fit a model to a training dataset, estimate model's parameters.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target training values of size _N_.
@@ -28,6 +32,24 @@ pub trait SupervisedEstimator<X, Y, P> {
P: Clone;
}
/// An estimator for supervised learning.
/// In this one parameters are borrowed instead of moved, this is useful for parameters that carry
/// references. Also to be used when there is no predictor attached to the estimator.
pub trait SupervisedEstimatorBorrow<'a, X, Y, P> {
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
/// used to pass around the correct `fit()` implementation.
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
fn new() -> Self;
/// Fit a model to a training dataset, estimate model's parameters.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target training values of size _N_.
/// * `&parameters` - hyperparameters of an algorithm
fn fit(x: &'a X, y: &'a Y, parameters: &'a P) -> Result<Self, Failed>
where
Self: Sized,
P: Clone;
}
/// Implements method predict that estimates target value from new data
pub trait Predictor<X, Y> {
/// Estimate target values from new data.
@@ -35,9 +57,19 @@ pub trait Predictor<X, Y> {
fn predict(&self, x: &X) -> Result<Y, Failed>;
}
/// Implements method predict that estimates target value from new data, with borrowing
pub trait PredictorBorrow<'a, X, T> {
/// Estimate target values from new data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
fn predict(&self, x: &'a X) -> Result<Vec<T>, Failed>;
}
/// Implements method transform that filters or modifies input data
pub trait Transformer<X> {
/// Transform data by modifying or filtering it
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
fn transform(&self, x: &X) -> Result<X, Failed>;
}
/// empty parameters for an estimator, see `BiasedEstimator`
pub trait NoParameters {}
+317
View File
@@ -0,0 +1,317 @@
//! # Agglomerative Hierarchical Clustering
//!
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
//!
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
//!
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
//!
//! ## Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
//!
//! // A dataset with 2 distinct groups of points.
//! let x = DenseMatrix::from_2d_array(&[
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
//! ]).unwrap();
//!
//! // Set parameters to find 2 clusters.
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
//!
//! // Fit the model to the data.
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
//!
//! // Get the cluster assignments.
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
//! ```
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::api::UnsupervisedEstimator;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
/// The number of clusters to find.
pub n_clusters: usize,
}
impl AgglomerativeClusteringParameters {
/// Sets the number of clusters.
///
/// # Arguments
/// * `n_clusters` - The desired number of clusters.
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
self.n_clusters = n_clusters;
self
}
}
impl Default for AgglomerativeClusteringParameters {
fn default() -> Self {
AgglomerativeClusteringParameters { n_clusters: 2 }
}
}
/// Agglomerative Clustering model.
///
/// This implementation uses single-linkage clustering, which is mathematically
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
/// The core logic is an efficient implementation of Kruskal's algorithm, which
/// processes all pairwise distances in increasing order and uses a Disjoint
/// Set Union (DSU) data structure to track cluster membership.
#[derive(Debug)]
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
/// The cluster label assigned to each sample.
pub labels: Vec<usize>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
/// Fits the agglomerative clustering model to the data.
///
/// # Arguments
/// * `data` - A reference to the input data matrix.
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
///
/// # Returns
/// A `Result` containing the fitted model with cluster labels, or an error if
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
let (num_samples, _) = data.shape();
let n_clusters = parameters.n_clusters;
if n_clusters > num_samples {
return Err(Failed::because(
FailedError::ParametersError,
&format!(
"n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"
),
));
}
let mut distance_pairs = Vec::new();
for i in 0..num_samples {
for j in (i + 1)..num_samples {
let distance: f64 = data
.get_row(i)
.iterator(0)
.zip(data.get_row(j).iterator(0))
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
.sum::<f64>();
distance_pairs.push((distance, i, j));
}
}
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
let mut parent = HashMap::new();
let mut children = HashMap::new();
for i in 0..num_samples {
parent.insert(i, i);
children.insert(i, vec![i]);
}
let mut merge_history = Vec::new();
let num_merges_needed = num_samples - 1;
while merge_history.len() < num_merges_needed {
let (_, p1, p2) = distance_pairs.pop().unwrap();
let root1 = parent[&p1];
let root2 = parent[&p2];
if root1 != root2 {
let root2_children = children.remove(&root2).unwrap();
for child in root2_children.iter() {
parent.insert(*child, root1);
}
let root1_children = children.get_mut(&root1).unwrap();
root1_children.extend(root2_children);
merge_history.push((root1, root2));
}
}
let mut clusters = HashMap::new();
let mut assignments = HashMap::new();
for i in 0..num_samples {
clusters.insert(i, vec![i]);
assignments.insert(i, i);
}
let merges_to_apply = num_samples - n_clusters;
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
let root1_cluster = assignments[root1];
let root2_cluster = assignments[root2];
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
for assignment in root2_assignments.iter() {
assignments.insert(*assignment, root1_cluster);
}
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
root1_assignments.extend(root2_assignments);
}
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
cluster_keys.sort();
for (i, key) in cluster_keys.into_iter().enumerate() {
for index in clusters[key].iter() {
labels[*index] = i;
}
}
Ok(AgglomerativeClustering {
labels,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
for AgglomerativeClustering<TX, TY, X, Y>
{
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
AgglomerativeClustering::fit(x, parameters)
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::matrix::DenseMatrix;
use std::collections::HashSet;
use super::*;
#[test]
fn test_simple_clustering() {
// Two distinct clusters, far apart.
let data = vec![
0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
];
let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
// Using f64 for TY as usize doesn't satisfy the Number trait bound.
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Check that all points in the first group have the same label.
let first_group_label = labels[0];
assert!(labels[0..3].iter().all(|&l| l == first_group_label));
// Check that all points in the second group have the same label.
let second_group_label = labels[3];
assert!(labels[3..6].iter().all(|&l| l == second_group_label));
// Check that the two groups have different labels.
assert_ne!(first_group_label, second_group_label);
}
#[test]
fn test_four_clusters() {
// Four distinct clusters in the corners of a square.
let data = vec![
0.0, 0.0, 1.0, 1.0, // Cluster A
100.0, 100.0, 101.0, 101.0, // Cluster B
0.0, 100.0, 1.0, 101.0, // Cluster C
100.0, 0.0, 101.0, 1.0, // Cluster D
];
let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Verify that there are exactly 4 unique labels produced.
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
assert_eq!(unique_labels.len(), 4);
// Verify that points within each original group were assigned the same cluster label.
let label_a = labels[0];
assert_eq!(label_a, labels[1]);
let label_b = labels[2];
assert_eq!(label_b, labels[3]);
let label_c = labels[4];
assert_eq!(label_c, labels[5]);
let label_d = labels[6];
assert_eq!(label_d, labels[7]);
// Verify that all four groups received different labels.
assert_ne!(label_a, label_b);
assert_ne!(label_a, label_c);
assert_ne!(label_a, label_d);
assert_ne!(label_b, label_c);
assert_ne!(label_b, label_d);
assert_ne!(label_c, label_d);
}
#[test]
fn test_n_clusters_equal_to_samples() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// Each point should be its own cluster. Sorting makes the test deterministic.
let mut labels = clustering.labels;
labels.sort();
assert_eq!(labels, vec![0, 1, 2]);
}
#[test]
fn test_one_cluster() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// All points should be in the same cluster.
assert_eq!(clustering.labels, vec![0, 0, 0]);
}
#[test]
fn test_error_on_too_many_clusters() {
let data = vec![0.0, 0.0, 5.0, 5.0];
let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
);
assert!(result.is_err());
}
}
+239 -59
View File
@@ -18,19 +18,20 @@
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! ```ignore
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::Array2;
//! use smartcore::cluster::dbscan::*;
//! use smartcore::math::distance::Distances;
//! use smartcore::metrics::distance::Distances;
//! use smartcore::neighbors::KNNAlgorithmName;
//! use smartcore::dataset::generator;
//!
//! // Generate three blobs
//! let blobs = generator::make_blobs(100, 2, 3);
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
//! let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
//! // Fit the algorithm and predict cluster labels
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
//! and_then(|dbscan| dbscan.predict(&x));
//! let labels: Vec<u32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
//! and_then(|dbscan| dbscan.predict(&x)).unwrap();
//!
//! println!("{:?}", labels);
//! ```
@@ -41,7 +42,7 @@
//! * ["Density-Based Clustering in Spatial Databases: The Algorithm GDBSCAN and its Applications", Sander J., Ester M., Kriegel HP., Xu X.](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1629&rep=rep1&type=pdf)
use std::fmt::Debug;
use std::iter::Sum;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -49,47 +50,58 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::{row_iter, Matrix};
use crate::math::distance::euclidian::Euclidian;
use crate::math::distance::{Distance, Distances};
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::{Distance, Distances};
use crate::numbers::basenum::Number;
use crate::tree::decision_tree_classifier::which_max;
/// DBSCAN clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
pub struct DBSCAN<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> {
cluster_labels: Vec<i16>,
num_classes: usize,
knn_algorithm: KNNAlgorithm<T, D>,
eps: T,
knn_algorithm: KNNAlgorithm<TX, D>,
eps: f64,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// DBSCAN clustering algorithm parameters
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
pub struct DBSCANParameters<T: Number, D: Distance<Vec<T>>> {
#[cfg_attr(feature = "serde", serde(default))]
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: D,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
pub min_samples: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub eps: T,
pub eps: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// KNN algorithm to use.
pub algorithm: KNNAlgorithmName,
#[cfg_attr(feature = "serde", serde(default))]
_phantom_t: PhantomData<T>,
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
pub fn with_distance<DD: Distance<Vec<T>>>(self, distance: DD) -> DBSCANParameters<T, DD> {
DBSCANParameters {
distance,
min_samples: self.min_samples,
eps: self.eps,
algorithm: self.algorithm,
_phantom_t: PhantomData,
}
}
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
@@ -98,7 +110,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
self
}
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub fn with_eps(mut self, eps: T) -> Self {
pub fn with_eps(mut self, eps: f64) -> Self {
self.eps = eps;
self
}
@@ -109,7 +121,113 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
}
}
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
/// DBSCAN grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DBSCANSearchParameters<T: Number, D: Distance<Vec<T>>> {
#[cfg_attr(feature = "serde", serde(default))]
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: Vec<D>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
pub min_samples: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub eps: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// KNN algorithm to use.
pub algorithm: Vec<KNNAlgorithmName>,
_phantom_t: PhantomData<T>,
}
/// DBSCAN grid search iterator
pub struct DBSCANSearchParametersIterator<T: Number, D: Distance<Vec<T>>> {
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
current_distance: usize,
current_min_samples: usize,
current_eps: usize,
current_algorithm: usize,
}
impl<T: Number, D: Distance<Vec<T>>> IntoIterator for DBSCANSearchParameters<T, D> {
type Item = DBSCANParameters<T, D>;
type IntoIter = DBSCANSearchParametersIterator<T, D>;
fn into_iter(self) -> Self::IntoIter {
DBSCANSearchParametersIterator {
dbscan_search_parameters: self,
current_distance: 0,
current_min_samples: 0,
current_eps: 0,
current_algorithm: 0,
}
}
}
impl<T: Number, D: Distance<Vec<T>>> Iterator for DBSCANSearchParametersIterator<T, D> {
type Item = DBSCANParameters<T, D>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_distance == self.dbscan_search_parameters.distance.len()
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
&& self.current_eps == self.dbscan_search_parameters.eps.len()
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
{
return None;
}
let next = DBSCANParameters {
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
eps: self.dbscan_search_parameters.eps[self.current_eps],
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
_phantom_t: PhantomData,
};
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
self.current_distance += 1;
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
self.current_distance = 0;
self.current_min_samples += 1;
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
self.current_distance = 0;
self.current_min_samples = 0;
self.current_eps += 1;
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
self.current_distance = 0;
self.current_min_samples = 0;
self.current_eps = 0;
self.current_algorithm += 1;
} else {
self.current_distance += 1;
self.current_min_samples += 1;
self.current_eps += 1;
self.current_algorithm += 1;
}
Some(next)
}
}
impl<T: Number> Default for DBSCANSearchParameters<T, Euclidian<T>> {
fn default() -> Self {
let default_params = DBSCANParameters::default();
DBSCANSearchParameters {
distance: vec![default_params.distance],
min_samples: vec![default_params.min_samples],
eps: vec![default_params.eps],
algorithm: vec![default_params.algorithm],
_phantom_t: PhantomData,
}
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
for DBSCAN<TX, TY, X, Y, D>
{
fn eq(&self, other: &Self) -> bool {
self.cluster_labels.len() == other.cluster_labels.len()
&& self.num_classes == other.num_classes
@@ -118,47 +236,50 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
}
}
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
impl<T: Number> Default for DBSCANParameters<T, Euclidian<T>> {
fn default() -> Self {
DBSCANParameters {
distance: Distances::euclidian(),
min_samples: 5,
eps: T::half(),
algorithm: KNNAlgorithmName::CoverTree,
eps: 0.5f64,
algorithm: KNNAlgorithmName::default(),
_phantom_t: PhantomData,
}
}
}
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
UnsupervisedEstimator<X, DBSCANParameters<TX, D>> for DBSCAN<TX, TY, X, Y, D>
{
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
fn fit(x: &X, parameters: DBSCANParameters<TX, D>) -> Result<Self, Failed> {
DBSCAN::fit(x, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
for DBSCAN<T, D>
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> Predictor<X, Y>
for DBSCAN<TX, TY, X, Y, D>
{
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
DBSCAN<TX, TY, X, Y, D>
{
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
/// * `data` - training instances to cluster
/// * `k` - number of clusters
/// * `parameters` - cluster parameters
pub fn fit<M: Matrix<T>>(
x: &M,
parameters: DBSCANParameters<T, D>,
) -> Result<DBSCAN<T, D>, Failed> {
pub fn fit(
x: &X,
parameters: DBSCANParameters<TX, D>,
) -> Result<DBSCAN<TX, TY, X, Y, D>, Failed> {
if parameters.min_samples < 1 {
return Err(Failed::fit("Invalid minPts"));
}
if parameters.eps <= T::zero() {
if parameters.eps <= 0f64 {
return Err(Failed::fit("Invalid radius: "));
}
@@ -170,13 +291,19 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
let n = x.shape().0;
let mut y = vec![undefined; n];
let algo = parameters
.algorithm
.fit(row_iter(x).collect(), parameters.distance)?;
let algo = parameters.algorithm.fit(
x.row_iter()
.map(|row| row.iterator(0).cloned().collect())
.collect(),
parameters.distance,
)?;
for (i, e) in row_iter(x).enumerate() {
let mut row = vec![TX::zero(); x.shape().1];
for (i, e) in x.row_iter().enumerate() {
if y[i] == undefined {
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
e.iterator(0).zip(row.iter_mut()).for_each(|(&x, r)| *r = x);
let mut neighbors = algo.find_radius(&row, parameters.eps)?;
if neighbors.len() < parameters.min_samples {
y[i] = outlier;
} else {
@@ -188,8 +315,7 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
}
}
while !neighbors.is_empty() {
let neighbor = neighbors.pop().unwrap();
while let Some(neighbor) = neighbors.pop() {
let index = neighbor.0;
if y[index] == outlier {
@@ -227,18 +353,25 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
num_classes: k as usize,
knn_algorithm: algo,
eps: parameters.eps,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, m) = x.shape();
let mut result = M::zeros(1, n);
let mut row = vec![T::zero(); m];
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
let mut result = Y::zeros(n);
let mut row = vec![TX::zero(); x.shape().1];
for i in 0..n {
x.copy_row_as_vec(i, &mut row);
x.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
let neighbors = self.knn_algorithm.find_radius(&row, self.eps)?;
let mut label = vec![0usize; self.num_classes + 1];
for neighbor in neighbors {
@@ -251,24 +384,50 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
}
let class = which_max(&label);
if class != self.num_classes {
result.set(0, i, T::from(class).unwrap());
result.set(i, TY::from(class + 1).unwrap());
} else {
result.set(0, i, -T::one());
result.set(i, TY::zero());
}
}
Ok(result.to_row_vector())
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
#[cfg(feature = "serde")]
use crate::math::distance::euclidian::Euclidian;
use crate::metrics::distance::euclidian::Euclidian;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters: DBSCANSearchParameters<f64, Euclidian<f64>> = DBSCANSearchParameters {
min_samples: vec![10, 100],
eps: vec![1., 2.],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 10);
assert_eq!(next.eps, 1.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 100);
assert_eq!(next.eps, 1.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 10);
assert_eq!(next.eps, 2.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 100);
assert_eq!(next.eps, 2.);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict_dbscan() {
let x = DenseMatrix::from_2d_array(&[
@@ -283,9 +442,10 @@ mod tests {
&[2.2, 1.2],
&[1.8, 0.8],
&[3.0, 5.0],
]);
])
.unwrap();
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
let expected_labels = vec![1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0];
let dbscan = DBSCAN::fit(
&x,
@@ -295,12 +455,15 @@ mod tests {
)
.unwrap();
let predicted_labels = dbscan.predict(&x).unwrap();
let predicted_labels: Vec<i32> = dbscan.predict(&x).unwrap();
assert_eq!(expected_labels, predicted_labels);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -325,13 +488,30 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
])
.unwrap();
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
let deserialized_dbscan: DBSCAN<f32, f32, DenseMatrix<f32>, Vec<f32>, Euclidian<f32>> =
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();
assert_eq!(dbscan, deserialized_dbscan);
}
#[cfg(feature = "datasets")]
#[test]
fn from_vec() {
use crate::dataset::generator;
// Generate three blobs
let blobs = generator::make_blobs(100, 2, 3);
let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
// Fit the algorithm and predict cluster labels
let labels: Vec<i32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0))
.and_then(|dbscan| dbscan.predict(&x))
.unwrap();
println!("{labels:?}");
}
}
+227 -66
View File
@@ -11,12 +11,12 @@
//! these re-calculated centroids becoming the new centers of their respective clusters. Next all instances of the training set are re-assigned to their closest cluster again.
//! This iterative process continues until convergence is achieved and the clusters are considered settled.
//!
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. SmartCore uses k-means++ algorithm to initialize cluster centers.
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. `smartcore` uses k-means++ algorithm to initialize cluster centers.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::kmeans::*;
//!
//! // Iris data
@@ -41,10 +41,10 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! ]).unwrap();
//!
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
//! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
//! ```
//!
//! ## References:
@@ -52,32 +52,37 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
use rand::Rng;
use std::fmt::Debug;
use std::iter::Sum;
use std::marker::PhantomData;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
/// K-Means clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KMeans<T: RealNumber> {
pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
k: usize,
_y: Vec<usize>,
size: Vec<usize>,
_distortion: T,
centroids: Vec<Vec<T>>,
_distortion: f64,
centroids: Vec<Vec<f64>>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<T: RealNumber> PartialEq for KMeans<T> {
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<TX, TY, X, Y> {
fn eq(&self, other: &Self) -> bool {
if self.k != other.k
|| self.size != other.size
@@ -91,7 +96,7 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
return false;
}
for j in 0..self.centroids[i].len() {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
return false;
}
}
@@ -101,13 +106,20 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// K-Means clustering algorithm parameters
pub struct KMeansParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of clusters.
pub k: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Determines random number generation for centroid initialization.
/// Use an int to make the randomness deterministic
pub seed: Option<u64>,
}
impl KMeansParameters {
@@ -128,27 +140,118 @@ impl Default for KMeansParameters {
KMeansParameters {
k: 2,
max_iter: 100,
seed: Option::None,
}
}
}
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
/// KMeans grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of clusters.
pub k: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Determines random number generation for centroid initialization.
/// Use an int to make the randomness deterministic
pub seed: Vec<Option<u64>>,
}
/// KMeans grid search iterator
pub struct KMeansSearchParametersIterator {
kmeans_search_parameters: KMeansSearchParameters,
current_k: usize,
current_max_iter: usize,
current_seed: usize,
}
impl IntoIterator for KMeansSearchParameters {
type Item = KMeansParameters;
type IntoIter = KMeansSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
KMeansSearchParametersIterator {
kmeans_search_parameters: self,
current_k: 0,
current_max_iter: 0,
current_seed: 0,
}
}
}
impl Iterator for KMeansSearchParametersIterator {
type Item = KMeansParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_k == self.kmeans_search_parameters.k.len()
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
&& self.current_seed == self.kmeans_search_parameters.seed.len()
{
return None;
}
let next = KMeansParameters {
k: self.kmeans_search_parameters.k[self.current_k],
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
seed: self.kmeans_search_parameters.seed[self.current_seed],
};
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
self.current_k += 1;
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
self.current_k = 0;
self.current_max_iter += 1;
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
self.current_k = 0;
self.current_max_iter = 0;
self.current_seed += 1;
} else {
self.current_k += 1;
self.current_max_iter += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for KMeansSearchParameters {
fn default() -> Self {
let default_params = KMeansParameters::default();
KMeansSearchParameters {
k: vec![default_params.k],
max_iter: vec![default_params.max_iter],
seed: vec![default_params.seed],
}
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, KMeansParameters> for KMeans<TX, TY, X, Y>
{
fn fit(x: &X, parameters: KMeansParameters) -> Result<Self, Failed> {
KMeans::fit(x, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for KMeans<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber + Sum> KMeans<T> {
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
/// * `data` - training instances to cluster
/// * `data` - training instances to cluster
/// * `parameters` - cluster parameters
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
let bbd = BBDTree::new(data);
if parameters.k < 2 {
@@ -167,10 +270,10 @@ impl<T: RealNumber + Sum> KMeans<T> {
let (n, d) = data.shape();
let mut distortion = T::max_value();
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
let mut distortion = f64::MAX;
let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
let mut size = vec![0; parameters.k];
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
let mut centroids = vec![vec![0f64; d]; parameters.k];
for i in 0..n {
size[y[i]] += 1;
@@ -178,23 +281,23 @@ impl<T: RealNumber + Sum> KMeans<T> {
for i in 0..n {
for j in 0..d {
centroids[y[i]][j] += data.get(i, j);
centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
}
}
for i in 0..parameters.k {
for j in 0..d {
centroids[i][j] /= T::from(size[i]).unwrap();
centroids[i][j] /= size[i] as f64;
}
}
let mut sums = vec![vec![T::zero(); d]; parameters.k];
let mut sums = vec![vec![0f64; d]; parameters.k];
for _ in 1..=parameters.max_iter {
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
for i in 0..parameters.k {
if size[i] > 0 {
for j in 0..d {
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
centroids[i][j] = sums[i][j] / size[i] as f64;
}
}
}
@@ -212,48 +315,61 @@ impl<T: RealNumber + Sum> KMeans<T> {
size,
_distortion: distortion,
centroids,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, m) = x.shape();
let mut result = M::zeros(1, n);
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
let mut result = Y::zeros(n);
let mut row = vec![T::zero(); m];
let mut row = vec![0f64; x.shape().1];
for i in 0..n {
let mut min_dist = T::max_value();
let mut min_dist = f64::MAX;
let mut best_cluster = 0;
for j in 0..self.k {
x.copy_row_as_vec(i, &mut row);
x.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x.to_f64().unwrap());
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
if dist < min_dist {
min_dist = dist;
best_cluster = j;
}
}
result.set(0, i, T::from(best_cluster).unwrap());
result.set(i, TY::from_usize(best_cluster).unwrap());
}
Ok(result.to_row_vector())
Ok(result)
}
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
let mut rng = rand::thread_rng();
let (n, m) = data.shape();
fn kmeans_plus_plus(data: &X, k: usize, seed: Option<u64>) -> Vec<usize> {
let mut rng = get_rng_impl(seed);
let (n, _) = data.shape();
let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
let mut centroid: Vec<TX> = data
.get_row(rng.gen_range(0..n))
.iterator(0)
.cloned()
.collect();
let mut d = vec![T::max_value(); n];
let mut row = vec![T::zero(); m];
let mut d = vec![f64::MAX; n];
let mut row = vec![TX::zero(); data.shape().1];
for j in 1..k {
for i in 0..n {
data.copy_row_as_vec(i, &mut row);
data.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
@@ -262,12 +378,12 @@ impl<T: RealNumber + Sum> KMeans<T> {
}
}
let mut sum: T = T::zero();
let mut sum = 0f64;
for i in d.iter() {
sum += *i;
}
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
let mut cost = T::zero();
let cutoff = rng.gen::<f64>() * sum;
let mut cost = 0f64;
let mut index = 0;
while index < n {
cost += d[index];
@@ -277,11 +393,14 @@ impl<T: RealNumber + Sum> KMeans<T> {
index += 1;
}
data.copy_row_as_vec(index, &mut centroid);
centroid = data.get_row(index).iterator(0).cloned().collect();
}
for i in 0..n {
data.copy_row_as_vec(i, &mut row);
data.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
@@ -297,25 +416,61 @@ impl<T: RealNumber + Sum> KMeans<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn invalid_k() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
&x,
KMeansParameters::default().with_k(0)
)
.is_err());
assert_eq!(
"Fit failed: invalid number of clusters: 1",
KMeans::fit(&x, KMeansParameters::default().with_k(1))
.unwrap_err()
.to_string()
KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
&x,
KMeansParameters::default().with_k(1)
)
.unwrap_err()
.to_string()
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_iris() {
fn search_parameters() {
let parameters = KMeansSearchParameters {
k: vec![2, 4],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
@@ -337,18 +492,22 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
])
.unwrap();
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
let y = kmeans.predict(&x).unwrap();
let y: Vec<usize> = kmeans.predict(&x).unwrap();
for i in 0..y.len() {
assert_eq!(y[i] as usize, kmeans._y[i]);
for (i, _y_i) in y.iter().enumerate() {
assert_eq!({ y[i] }, kmeans._y[i]);
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -373,11 +532,13 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
])
.unwrap();
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
KMeans::fit(&x, Default::default()).unwrap();
let deserialized_kmeans: KMeans<f64> =
let deserialized_kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
assert_eq!(kmeans, deserialized_kmeans);
+2
View File
@@ -1,8 +1,10 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! # Clustering
//!
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
pub mod agglomerative;
pub mod dbscan;
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
pub mod kmeans;
+5 -2
View File
@@ -31,7 +31,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("boston.xy"))
{
Err(why) => panic!("Can't deserialize boston.xy. {}", why),
Err(why) => panic!("Can't deserialize boston.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
@@ -69,7 +69,10 @@ mod tests {
assert!(serialize_data(&dataset, "boston.xy").is_ok());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn boston_dataset() {
let dataset = load_dataset();
+21 -14
View File
@@ -30,11 +30,16 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, f32> {
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("breast_cancer.xy")) {
Err(why) => panic!("Can't deserialize breast_cancer.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
Err(why) => panic!("Can't deserialize breast_cancer.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
};
Dataset {
@@ -66,20 +71,22 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)]
mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
use super::*;
#[test]
#[ignore]
#[cfg(not(target_arch = "wasm32"))]
fn refresh_cancer_dataset() {
// run this test to generate breast_cancer.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
}
// TODO: implement serialization
// #[test]
// #[ignore]
// #[cfg(not(target_arch = "wasm32"))]
// fn refresh_cancer_dataset() {
// // run this test to generate breast_cancer.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
// }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cancer_dataset() {
let dataset = load_dataset();
+22 -15
View File
@@ -23,11 +23,16 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, f32> {
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("diabetes.xy")) {
Err(why) => panic!("Can't deserialize diabetes.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
Err(why) => panic!("Can't deserialize diabetes.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
};
Dataset {
@@ -35,7 +40,7 @@ pub fn load_dataset() -> Dataset<f32, f32> {
target: y,
num_samples,
num_features,
feature_names: vec![
feature_names: [
"Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
]
.iter()
@@ -50,20 +55,22 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)]
mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test]
#[ignore]
fn refresh_diabetes_dataset() {
// run this test to generate diabetes.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
}
// TODO: fix serialization
// #[cfg(not(target_arch = "wasm32"))]
// #[test]
// #[ignore]
// fn refresh_diabetes_dataset() {
// // run this test to generate diabetes.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
// }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn boston_dataset() {
let dataset = load_dataset();
+9 -8
View File
@@ -1,4 +1,4 @@
//! # Optical Recognition of Handwritten Digits Data Set
//! # Optical Recognition of Handwritten Digits Dataset
//!
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
//! |-|-|-|-|
@@ -16,7 +16,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
{
Err(why) => panic!("Can't deserialize digits.xy. {}", why),
Err(why) => panic!("Can't deserialize digits.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
@@ -25,16 +25,14 @@ pub fn load_dataset() -> Dataset<f32, f32> {
target: y,
num_samples,
num_features,
feature_names: vec![
"sepal length (cm)",
feature_names: ["sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)",
]
"petal width (cm)"]
.iter()
.map(|s| s.to_string())
.collect(),
target_names: vec!["setosa", "versicolor", "virginica"]
target_names: ["setosa", "versicolor", "virginica"]
.iter()
.map(|s| s.to_string())
.collect(),
@@ -57,7 +55,10 @@ mod tests {
let dataset = load_dataset();
assert!(serialize_data(&dataset, "digits.xy").is_ok());
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn digits_dataset() {
let dataset = load_dataset();
+16 -7
View File
@@ -48,7 +48,7 @@ pub fn make_blobs(
}
/// Make a large circle containing a smaller circle in 2d.
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, f32> {
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, u32> {
if !(0.0..1.0).contains(&factor) {
panic!("'factor' has to be between 0 and 1.");
}
@@ -79,7 +79,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
Dataset {
data: x,
target: y,
target: y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features: 2,
feature_names: (0..2).map(|n| n.to_string()).collect(),
@@ -89,7 +89,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
}
/// Make two interleaving half circles in 2d
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, u32> {
let num_samples_out = num_samples / 2;
let num_samples_in = num_samples - num_samples_out;
@@ -116,7 +116,7 @@ pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
Dataset {
data: x,
target: y,
target: y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features: 2,
feature_names: (0..2).map(|n| n.to_string()).collect(),
@@ -137,7 +137,10 @@ mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_make_blobs() {
let dataset = make_blobs(10, 2, 3);
@@ -150,7 +153,10 @@ mod tests {
assert_eq!(dataset.num_samples, 10);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_make_circles() {
let dataset = make_circles(10, 0.5, 0.05);
@@ -163,7 +169,10 @@ mod tests {
assert_eq!(dataset.num_samples, 10);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_make_moons() {
let dataset = make_moons(10, 0.05);
+29 -19
View File
@@ -1,4 +1,4 @@
//! # The Iris Dataset flower
//! # The Iris flower dataset
//!
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
//! |-|-|-|-|
@@ -19,18 +19,24 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("iris.xy")) {
Err(why) => panic!("Can't deserialize iris.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features): (Vec<f32>, Vec<u32>, usize, usize) =
match deserialize_data(std::include_bytes!("iris.xy")) {
Err(why) => panic!("Can't deserialize iris.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
};
Dataset {
data: x,
target: y,
num_samples,
num_features,
feature_names: vec![
feature_names: [
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
@@ -39,7 +45,7 @@ pub fn load_dataset() -> Dataset<f32, f32> {
.iter()
.map(|s| s.to_string())
.collect(),
target_names: vec!["setosa", "versicolor", "virginica"]
target_names: ["setosa", "versicolor", "virginica"]
.iter()
.map(|s| s.to_string())
.collect(),
@@ -50,20 +56,24 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)]
mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
// #[cfg(not(target_arch = "wasm32"))]
// use super::super::*;
use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test]
#[ignore]
fn refresh_iris_dataset() {
// run this test to generate iris.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "iris.xy").is_ok());
}
// TODO: fix serialization
// #[cfg(not(target_arch = "wasm32"))]
// #[test]
// #[ignore]
// fn refresh_iris_dataset() {
// // run this test to generate iris.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "iris.xy").is_ok());
// }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn iris_dataset() {
let dataset = load_dataset();
+9 -5
View File
@@ -1,6 +1,7 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! Datasets
//!
//! In this module you will find small datasets that are used in SmartCore for demonstration purpose mostly.
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
pub mod boston;
pub mod breast_cancer;
pub mod diabetes;
@@ -9,7 +10,7 @@ pub mod generator;
pub mod iris;
#[cfg(not(target_arch = "wasm32"))]
use crate::math::num::RealNumber;
use crate::numbers::{basenum::Number, realnum::RealNumber};
#[cfg(not(target_arch = "wasm32"))]
use std::fs::File;
use std::io;
@@ -55,7 +56,7 @@ impl<X, Y> Dataset<X, Y> {
// Running this in wasm throws: operation not supported on this platform.
#[cfg(not(target_arch = "wasm32"))]
#[allow(dead_code)]
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
dataset: &Dataset<X, Y>,
filename: &str,
) -> Result<(), io::Error> {
@@ -78,7 +79,7 @@ pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
.collect();
file.write_all(&y)?;
}
Err(why) => panic!("couldn't create {}: {}", filename, why),
Err(why) => panic!("couldn't create {filename}: {why}"),
}
Ok(())
}
@@ -121,7 +122,10 @@ pub(crate) fn deserialize_data(
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn as_matrix() {
let dataset = Dataset {
+240 -94
View File
@@ -10,7 +10,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::decomposition::pca::*;
//!
//! // Iris data
@@ -35,7 +35,7 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! ]).unwrap();
//!
//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
//!
@@ -52,24 +52,33 @@ use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
/// Principal components analysis algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct PCA<T: RealNumber, M: Matrix<T>> {
eigenvectors: M,
pub struct PCA<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
eigenvectors: X,
eigenvalues: Vec<T>,
projection: M,
projection: X,
mu: Vec<T>,
pmu: Vec<T>,
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
for PCA<T, X>
{
fn eq(&self, other: &Self) -> bool {
if self.eigenvectors != other.eigenvectors
|| self.eigenvalues.len() != other.eigenvalues.len()
if self.eigenvalues.len() != other.eigenvalues.len()
|| self
.eigenvectors
.iterator(0)
.zip(other.eigenvectors.iterator(0))
.any(|(&a, &b)| (a - b).abs() > T::epsilon())
{
false
} else {
@@ -83,11 +92,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// PCA parameters
pub struct PCAParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// By default, covariance matrix is used to compute principal components.
/// Enable this flag if you want to use correlation matrix instead.
pub use_correlation_matrix: bool,
@@ -116,40 +128,124 @@ impl Default for PCAParameters {
}
}
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
/// PCA grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct PCASearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// By default, covariance matrix is used to compute principal components.
/// Enable this flag if you want to use correlation matrix instead.
pub use_correlation_matrix: Vec<bool>,
}
/// PCA grid search iterator
pub struct PCASearchParametersIterator {
pca_search_parameters: PCASearchParameters,
current_k: usize,
current_use_correlation_matrix: usize,
}
impl IntoIterator for PCASearchParameters {
type Item = PCAParameters;
type IntoIter = PCASearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
PCASearchParametersIterator {
pca_search_parameters: self,
current_k: 0,
current_use_correlation_matrix: 0,
}
}
}
impl Iterator for PCASearchParametersIterator {
type Item = PCAParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_k == self.pca_search_parameters.n_components.len()
&& self.current_use_correlation_matrix
== self.pca_search_parameters.use_correlation_matrix.len()
{
return None;
}
let next = PCAParameters {
n_components: self.pca_search_parameters.n_components[self.current_k],
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
[self.current_use_correlation_matrix],
};
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
self.current_k += 1;
} else if self.current_use_correlation_matrix + 1
< self.pca_search_parameters.use_correlation_matrix.len()
{
self.current_k = 0;
self.current_use_correlation_matrix += 1;
} else {
self.current_k += 1;
self.current_use_correlation_matrix += 1;
}
Some(next)
}
}
impl Default for PCASearchParameters {
fn default() -> Self {
let default_params = PCAParameters::default();
PCASearchParameters {
n_components: vec![default_params.n_components],
use_correlation_matrix: vec![default_params.use_correlation_matrix],
}
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
UnsupervisedEstimator<X, PCAParameters> for PCA<T, X>
{
fn fit(x: &X, parameters: PCAParameters) -> Result<Self, Failed> {
PCA::fit(x, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for PCA<T, M> {
fn transform(&self, x: &M) -> Result<M, Failed> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
for PCA<T, X>
{
fn transform(&self, x: &X) -> Result<X, Failed> {
self.transform(x)
}
}
impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PCA<T, X> {
/// Fits PCA to your data.
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `n_components` - number of components to keep.
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(data: &M, parameters: PCAParameters) -> Result<PCA<T, M>, Failed> {
pub fn fit(data: &X, parameters: PCAParameters) -> Result<PCA<T, X>, Failed> {
let (m, n) = data.shape();
if parameters.n_components > n {
return Err(Failed::fit(&format!(
"Number of components, n_components should be <= number of attributes ({})",
n
"Number of components, n_components should be <= number of attributes ({n})"
)));
}
let mu = data.column_mean();
let mu: Vec<T> = data
.mean_by(0)
.iter()
.map(|&v| T::from_f64(v).unwrap())
.collect();
let mut x = data.clone();
for (c, mu_c) in mu.iter().enumerate().take(n) {
for (c, &mu_c) in mu.iter().enumerate().take(n) {
for r in 0..m {
x.sub_element_mut(r, c, *mu_c);
x.sub_element_mut((r, c), mu_c);
}
}
@@ -165,33 +261,33 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
eigenvectors = svd.V;
} else {
let mut cov = M::zeros(n, n);
let mut cov = X::zeros(n, n);
for k in 0..m {
for i in 0..n {
for j in 0..=i {
cov.add_element_mut(i, j, x.get(k, i) * x.get(k, j));
cov.add_element_mut((i, j), *x.get((k, i)) * *x.get((k, j)));
}
}
}
for i in 0..n {
for j in 0..=i {
cov.div_element_mut(i, j, T::from(m).unwrap());
cov.set(j, i, cov.get(i, j));
cov.div_element_mut((i, j), T::from(m).unwrap());
cov.set((j, i), *cov.get((i, j)));
}
}
if parameters.use_correlation_matrix {
let mut sd = vec![T::zero(); n];
for (i, sd_i) in sd.iter_mut().enumerate().take(n) {
*sd_i = cov.get(i, i).sqrt();
*sd_i = cov.get((i, i)).sqrt();
}
for i in 0..n {
for j in 0..=i {
cov.div_element_mut(i, j, sd[i] * sd[j]);
cov.set(j, i, cov.get(i, j));
cov.div_element_mut((i, j), sd[i] * sd[j]);
cov.set((j, i), *cov.get((i, j)));
}
}
@@ -203,7 +299,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
for (i, sd_i) in sd.iter().enumerate().take(n) {
for j in 0..n {
eigenvectors.div_element_mut(i, j, *sd_i);
eigenvectors.div_element_mut((i, j), *sd_i);
}
}
} else {
@@ -215,17 +311,17 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
}
}
let mut projection = M::zeros(parameters.n_components, n);
let mut projection = X::zeros(parameters.n_components, n);
for i in 0..n {
for j in 0..parameters.n_components {
projection.set(j, i, eigenvectors.get(i, j));
projection.set((j, i), *eigenvectors.get((i, j)));
}
}
let mut pmu = vec![T::zero(); parameters.n_components];
for (k, mu_k) in mu.iter().enumerate().take(n) {
for (i, pmu_i) in pmu.iter_mut().enumerate().take(parameters.n_components) {
*pmu_i += projection.get(i, k) * (*mu_k);
*pmu_i += *projection.get((i, k)) * (*mu_k);
}
}
@@ -240,7 +336,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
/// Run dimensionality reduction for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn transform(&self, x: &M) -> Result<M, Failed> {
pub fn transform(&self, x: &X) -> Result<X, Failed> {
let (nrows, ncols) = x.shape();
let (_, n_components) = self.projection.shape();
if ncols != self.mu.len() {
@@ -254,14 +350,14 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
let mut x_transformed = x.matmul(&self.projection);
for r in 0..nrows {
for c in 0..n_components {
x_transformed.sub_element_mut(r, c, self.pmu[c]);
x_transformed.sub_element_mut((r, c), self.pmu[c]);
}
}
Ok(x_transformed)
}
/// Get a projection matrix
pub fn components(&self) -> &M {
pub fn components(&self) -> &X {
&self.projection
}
}
@@ -269,7 +365,30 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[test]
fn search_parameters() {
let parameters = PCASearchParameters {
n_components: vec![2, 4],
use_correlation_matrix: vec![true, false],
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert!(!next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert!(!next.use_correlation_matrix);
assert!(iter.next().is_none());
}
fn us_arrests_data() -> DenseMatrix<f64> {
DenseMatrix::from_2d_array(&[
@@ -324,8 +443,12 @@ mod tests {
&[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6],
])
.unwrap()
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn pca_components() {
let us_arrests = us_arrests_data();
@@ -335,13 +458,21 @@ mod tests {
&[0.9952, 0.0588],
&[0.0463, 0.9769],
&[0.0752, 0.2007],
]);
])
.unwrap();
let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
assert!(relative_eq!(
expected,
pca.components().abs(),
epsilon = 1e-3
));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_covariance() {
let us_arrests = us_arrests_data();
@@ -371,7 +502,8 @@ mod tests {
-0.974080592182491,
0.0723250196376097,
],
]);
])
.unwrap();
let expected_projection = DenseMatrix::from_2d_array(&[
&[-64.8022, -11.448, 2.4949, -2.4079],
@@ -424,7 +556,8 @@ mod tests {
&[91.5446, -22.9529, 0.402, -0.7369],
&[118.1763, 5.5076, 2.7113, -0.205],
&[10.4345, -5.9245, 3.7944, 0.5179],
]);
])
.unwrap();
let expected_eigenvalues: Vec<f64> = vec![
343544.6277001563,
@@ -435,23 +568,29 @@ mod tests {
let pca = PCA::fit(&us_arrests, PCAParameters::default().with_n_components(4)).unwrap();
assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
assert!(relative_eq!(
pca.eigenvectors.abs(),
&expected_eigenvectors.abs(),
epsilon = 1e-4
));
for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}
let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
assert!(relative_eq!(
us_arrests_t.abs(),
&expected_projection.abs(),
epsilon = 1e-4
));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_correlation() {
let us_arrests = us_arrests_data();
@@ -481,7 +620,8 @@ mod tests {
-0.0881962972508558,
-0.0096011588898465,
],
]);
])
.unwrap();
let expected_projection = DenseMatrix::from_2d_array(&[
&[0.9856, -1.1334, 0.4443, -0.1563],
@@ -534,7 +674,8 @@ mod tests {
&[-2.1086, -1.4248, -0.1048, -0.1319],
&[-2.0797, 0.6113, 0.1389, -0.1841],
&[-0.6294, -0.321, 0.2407, 0.1667],
]);
])
.unwrap();
let expected_eigenvalues: Vec<f64> = vec![
2.480241579149493,
@@ -551,54 +692,59 @@ mod tests {
)
.unwrap();
assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
assert!(relative_eq!(
pca.eigenvectors.abs(),
&expected_eigenvectors.abs(),
epsilon = 1e-4
));
for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}
let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
assert!(relative_eq!(
us_arrests_t.abs(),
&expected_projection.abs(),
epsilon = 1e-4
));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
// Disable this test for now
// TODO: implement deserialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn pca_serde() {
// let iris = DenseMatrix::from_2d_array(&[
// &[5.1, 3.5, 1.4, 0.2],
// &[4.9, 3.0, 1.4, 0.2],
// &[4.7, 3.2, 1.3, 0.2],
// &[4.6, 3.1, 1.5, 0.2],
// &[5.0, 3.6, 1.4, 0.2],
// &[5.4, 3.9, 1.7, 0.4],
// &[4.6, 3.4, 1.4, 0.3],
// &[5.0, 3.4, 1.5, 0.2],
// &[4.4, 2.9, 1.4, 0.2],
// &[4.9, 3.1, 1.5, 0.1],
// &[7.0, 3.2, 4.7, 1.4],
// &[6.4, 3.2, 4.5, 1.5],
// &[6.9, 3.1, 4.9, 1.5],
// &[5.5, 2.3, 4.0, 1.3],
// &[6.5, 2.8, 4.6, 1.5],
// &[5.7, 2.8, 4.5, 1.3],
// &[6.3, 3.3, 4.7, 1.6],
// &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap();
let pca = PCA::fit(&iris, Default::default()).unwrap();
// let pca = PCA::fit(&iris, Default::default()).unwrap();
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
// let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
// serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
assert_eq!(pca, deserialized_pca);
}
// assert_eq!(pca, deserialized_pca);
// }
}
+149 -59
View File
@@ -7,7 +7,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::decomposition::svd::*;
//!
//! // Iris data
@@ -32,7 +32,7 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! ]).unwrap();
//!
//! let svd = SVD::fit(&iris, SVDParameters::default().
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
@@ -51,27 +51,36 @@ use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
/// SVD
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct SVD<T: RealNumber, M: Matrix<T>> {
components: M,
pub struct SVD<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
components: X,
phantom: PhantomData<T>,
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
for SVD<T, X>
{
fn eq(&self, other: &Self) -> bool {
self.components
.approximate_eq(&other.components, T::from_f64(1e-8).unwrap())
.iterator(0)
.zip(other.components.iterator(0))
.all(|(&a, &b)| (a - b).abs() <= T::epsilon())
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVD parameters
pub struct SVDParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: usize,
}
@@ -90,36 +99,94 @@ impl SVDParameters {
}
}
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
/// SVD grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVDSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub n_components: Vec<usize>,
}
/// SVD grid search iterator
pub struct SVDSearchParametersIterator {
svd_search_parameters: SVDSearchParameters,
current_n_components: usize,
}
impl IntoIterator for SVDSearchParameters {
type Item = SVDParameters;
type IntoIter = SVDSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
SVDSearchParametersIterator {
svd_search_parameters: self,
current_n_components: 0,
}
}
}
impl Iterator for SVDSearchParametersIterator {
type Item = SVDParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_n_components == self.svd_search_parameters.n_components.len() {
return None;
}
let next = SVDParameters {
n_components: self.svd_search_parameters.n_components[self.current_n_components],
};
self.current_n_components += 1;
Some(next)
}
}
impl Default for SVDSearchParameters {
fn default() -> Self {
let default_params = SVDParameters::default();
SVDSearchParameters {
n_components: vec![default_params.n_components],
}
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
UnsupervisedEstimator<X, SVDParameters> for SVD<T, X>
{
fn fit(x: &X, parameters: SVDParameters) -> Result<Self, Failed> {
SVD::fit(x, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for SVD<T, M> {
fn transform(&self, x: &M) -> Result<M, Failed> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
for SVD<T, X>
{
fn transform(&self, x: &X) -> Result<X, Failed> {
self.transform(x)
}
}
impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> SVD<T, X> {
/// Fits SVD to your data.
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `n_components` - number of components to keep.
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(x: &M, parameters: SVDParameters) -> Result<SVD<T, M>, Failed> {
pub fn fit(x: &X, parameters: SVDParameters) -> Result<SVD<T, X>, Failed> {
let (_, p) = x.shape();
if parameters.n_components >= p {
return Err(Failed::fit(&format!(
"Number of components, n_components should be < number of attributes ({})",
p
"Number of components, n_components should be < number of attributes ({p})"
)));
}
let svd = x.svd()?;
let components = svd.V.slice(0..p, 0..parameters.n_components);
let components = X::from_slice(svd.V.slice(0..p, 0..parameters.n_components).as_ref());
Ok(SVD {
components,
@@ -129,13 +196,12 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
/// Run dimensionality reduction for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn transform(&self, x: &M) -> Result<M, Failed> {
pub fn transform(&self, x: &X) -> Result<X, Failed> {
let (n, p) = x.shape();
let (p_c, k) = self.components.shape();
if p_c != p {
return Err(Failed::transform(&format!(
"Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions",
n, p, n, k
"Can not transform a {n}x{p} matrix into {n}x{k} matrix, incorrect input dimentions"
)));
}
@@ -143,7 +209,7 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
}
/// Get a projection matrix
pub fn components(&self) -> &M {
pub fn components(&self) -> &X {
&self.components
}
}
@@ -151,9 +217,27 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = SVDSearchParameters {
n_components: vec![10, 100],
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_components, 10);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svd_decompose() {
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
@@ -208,7 +292,8 @@ mod tests {
&[5.7, 81.0, 39.0, 9.3],
&[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6],
]);
])
.unwrap();
let expected = DenseMatrix::from_2d_array(&[
&[243.54655757, -18.76673788],
@@ -216,50 +301,55 @@ mod tests {
&[305.93972467, -15.39087376],
&[197.28420365, -11.66808306],
&[293.43187394, 1.91163633],
]);
])
.unwrap();
let svd = SVD::fit(&x, Default::default()).unwrap();
let x_transformed = svd.transform(&x).unwrap();
assert_eq!(svd.components.shape(), (x.shape().1, 2));
assert!(x_transformed
.slice(0..5, 0..2)
.approximate_eq(&expected, 1e-4));
assert!(relative_eq!(
DenseMatrix::from_slice(x_transformed.slice(0..5, 0..2).as_ref()),
&expected,
epsilon = 1e-4
));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
// Disable this test for now
// TODO: implement deserialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let iris = DenseMatrix::from_2d_array(&[
// &[5.1, 3.5, 1.4, 0.2],
// &[4.9, 3.0, 1.4, 0.2],
// &[4.7, 3.2, 1.3, 0.2],
// &[4.6, 3.1, 1.5, 0.2],
// &[5.0, 3.6, 1.4, 0.2],
// &[5.4, 3.9, 1.7, 0.4],
// &[4.6, 3.4, 1.4, 0.3],
// &[5.0, 3.4, 1.5, 0.2],
// &[4.4, 2.9, 1.4, 0.2],
// &[4.9, 3.1, 1.5, 0.1],
// &[7.0, 3.2, 4.7, 1.4],
// &[6.4, 3.2, 4.5, 1.5],
// &[6.9, 3.1, 4.9, 1.5],
// &[5.5, 2.3, 4.0, 1.3],
// &[6.5, 2.8, 4.6, 1.5],
// &[5.7, 2.8, 4.5, 1.3],
// &[6.3, 3.3, 4.7, 1.6],
// &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap();
let svd = SVD::fit(&iris, Default::default()).unwrap();
// let svd = SVD::fit(&iris, Default::default()).unwrap();
let deserialized_svd: SVD<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
// let deserialized_svd: SVD<f32, DenseMatrix<f32>> =
// serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
assert_eq!(svd, deserialized_svd);
}
// assert_eq!(svd, deserialized_svd);
// }
}
+214
View File
@@ -0,0 +1,214 @@
use rand::Rng;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct BaseForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
#[cfg_attr(feature = "serde", serde(default))]
pub bootstrap: bool,
#[cfg_attr(feature = "serde", serde(default))]
pub splitter: Splitter,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for BaseForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
}
}
/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
samples: Option<Vec<Vec<bool>>>,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseForestRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseForestRegressorParameters,
) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();
if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
let mut samples: Vec<usize> = (0..n_rows).map(|_| 1).collect();
for _ in 0..parameters.n_trees {
if parameters.bootstrap {
samples =
BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
}
// keep samples is flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
splitter: parameters.splitter.clone(),
};
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
trees.push(tree);
}
Ok(BaseForestRegressor {
trees: Some(trees),
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
}
}
+318
View File
@@ -0,0 +1,318 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset ([https://www.statsmodels.org/stable/datasets/generated/longley.html](https://www.statsmodels.org/stable/datasets/generated/longley.html))
//! let x = DenseMatrix::from_2d_array(&[
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::default::Default;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor
/// Some parameters here are passed directly into base estimator.
pub struct ExtraTreesRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
impl ExtraTreesRegressorParameters {
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
self.max_depth = Some(max_depth);
self
}
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
self.min_samples_leaf = min_samples_leaf;
self
}
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
self.min_samples_split = min_samples_split;
self
}
/// The number of trees in the forest.
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
self.n_trees = n_trees;
self
}
/// Number of random sample of predictors to use as split candidates.
pub fn with_m(mut self, m: usize) -> Self {
self.m = Some(m);
self
}
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
/// Seed used for bootstrap sampling and feature selection for each tree.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
}
impl Default for ExtraTreesRegressorParameters {
fn default() -> Self {
ExtraTreesRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
m: Option::None,
keep_samples: false,
seed: 0,
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn new() -> Self {
Self {
forest_regressor: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
ExtraTreesRegressor::fit(x, y, parameters)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ExtraTreesRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
parameters: ExtraTreesRegressorParameters,
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: false,
splitter: Splitter::Random,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
Ok(ExtraTreesRegressor {
forest_regressor: Some(forest_regressor),
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_squared_error;
#[test]
fn test_extra_trees_regressor_fit_predict() {
// Use a simpler, more predictable dataset for unit testing.
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
&[13., 14.],
&[15., 16.],
])
.unwrap();
let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
// A basic check to ensure the model is learning something.
// The error should be significantly less than the variance of y.
let mse = mean_squared_error(&y, &y_hat);
// With this simple dataset, the error should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_fit_predict_higher_dims() {
// Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
let x = DenseMatrix::from_2d_array(&[
// The 3rd column is the important one. The rest are noise.
&[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
&[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
&[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
&[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
&[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
&[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
])
.unwrap();
let y = vec![10., 20., 30., 40., 55., 65.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
let mse = mean_squared_error(&y, &y_hat);
// The model should be able to learn this simple relationship perfectly,
// ignoring the noise features. The MSE should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_reproducibility() {
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
])
.unwrap();
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let params = ExtraTreesRegressorParameters::default().with_seed(42);
let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat1 = regressor1.predict(&x).unwrap();
let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat2 = regressor2.predict(&x).unwrap();
assert_eq!(y_hat1, y_hat2);
}
}
+3 -1
View File
@@ -7,7 +7,7 @@
//! set and then aggregate their individual predictions to form a final prediction. In classification setting the overall prediction is the most commonly
//! occurring majority class among the individual predictions.
//!
//! In SmartCore you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
//! In `smartcore` you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
//! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
//! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
//! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
@@ -16,6 +16,8 @@
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
mod base_forest_regressor;
pub mod extra_trees_regressor;
/// Random forest classifier
pub mod random_forest_classifier;
/// Random forest regressor
+418 -101
View File
@@ -8,7 +8,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
//!
//! // Iris dataset
@@ -33,10 +33,10 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! ]).unwrap();
//! let y = vec![
//! 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! 0, 0, 0, 0, 0, 0, 0, 0,
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! ];
//!
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
@@ -45,8 +45,8 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::Rng;
use std::default::Default;
use std::fmt::Debug;
@@ -55,8 +55,11 @@ use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};
@@ -66,20 +69,28 @@ use crate::tree::decision_tree_classifier::{
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: SplitCriterion,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: u16,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
@@ -87,10 +98,14 @@ pub struct RandomForestClassifierParameters {
/// Random Forest Classifier
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestClassifier<T: RealNumber> {
_parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<T>,
pub struct RandomForestClassifier<
TX: Number + FloatNumber + PartialOrd,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<DecisionTreeClassifier<TX, TY, X, Y>>>,
classes: Option<Vec<TY>>,
samples: Option<Vec<Vec<bool>>>,
}
@@ -139,22 +154,24 @@ impl RandomForestClassifierParameters {
}
}
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
PartialEq for RandomForestClassifier<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
if self.classes.as_ref().unwrap().len() != other.classes.as_ref().unwrap().len()
|| self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len()
{
false
} else {
for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false;
}
}
for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] {
return false;
}
}
true
self.classes
.iter()
.zip(other.classes.iter())
.all(|(a, b)| a == b)
&& self
.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
}
}
@@ -163,7 +180,7 @@ impl Default for RandomForestClassifierParameters {
fn default() -> Self {
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -174,65 +191,302 @@ impl Default for RandomForestClassifierParameters {
}
}
impl<T: RealNumber, M: Matrix<T>>
SupervisedEstimator<M, M::RowVector, RandomForestClassifierParameters>
for RandomForestClassifier<T>
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, RandomForestClassifierParameters>
for RandomForestClassifier<TX, TY, X, Y>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: RandomForestClassifierParameters,
) -> Result<Self, Failed> {
fn new() -> Self {
Self {
trees: Option::None,
classes: Option::None,
samples: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: RandomForestClassifierParameters) -> Result<Self, Failed> {
RandomForestClassifier::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for RandomForestClassifier<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber> RandomForestClassifier<T> {
/// RandomForestClassifier grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: Vec<SplitCriterion>,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: Vec<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Vec<Option<usize>>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: Vec<u64>,
}
/// RandomForestClassifier grid search iterator
pub struct RandomForestClassifierSearchParametersIterator {
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
current_criterion: usize,
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_n_trees: usize,
current_m: usize,
current_keep_samples: usize,
current_seed: usize,
}
impl IntoIterator for RandomForestClassifierSearchParameters {
type Item = RandomForestClassifierParameters;
type IntoIter = RandomForestClassifierSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
RandomForestClassifierSearchParametersIterator {
random_forest_classifier_search_parameters: self,
current_criterion: 0,
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_n_trees: 0,
current_m: 0,
current_keep_samples: 0,
current_seed: 0,
}
}
}
impl Iterator for RandomForestClassifierSearchParametersIterator {
type Item = RandomForestClassifierParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_criterion
== self
.random_forest_classifier_search_parameters
.criterion
.len()
&& self.current_max_depth
== self
.random_forest_classifier_search_parameters
.max_depth
.len()
&& self.current_min_samples_leaf
== self
.random_forest_classifier_search_parameters
.min_samples_leaf
.len()
&& self.current_min_samples_split
== self
.random_forest_classifier_search_parameters
.min_samples_split
.len()
&& self.current_n_trees
== self
.random_forest_classifier_search_parameters
.n_trees
.len()
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
&& self.current_keep_samples
== self
.random_forest_classifier_search_parameters
.keep_samples
.len()
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
{
return None;
}
let next = RandomForestClassifierParameters {
criterion: self.random_forest_classifier_search_parameters.criterion
[self.current_criterion]
.clone(),
max_depth: self.random_forest_classifier_search_parameters.max_depth
[self.current_max_depth],
min_samples_leaf: self
.random_forest_classifier_search_parameters
.min_samples_leaf[self.current_min_samples_leaf],
min_samples_split: self
.random_forest_classifier_search_parameters
.min_samples_split[self.current_min_samples_split],
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
m: self.random_forest_classifier_search_parameters.m[self.current_m],
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
[self.current_keep_samples],
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
};
if self.current_criterion + 1
< self
.random_forest_classifier_search_parameters
.criterion
.len()
{
self.current_criterion += 1;
} else if self.current_max_depth + 1
< self
.random_forest_classifier_search_parameters
.max_depth
.len()
{
self.current_criterion = 0;
self.current_max_depth += 1;
} else if self.current_min_samples_leaf + 1
< self
.random_forest_classifier_search_parameters
.min_samples_leaf
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf += 1;
} else if self.current_min_samples_split + 1
< self
.random_forest_classifier_search_parameters
.min_samples_split
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_n_trees + 1
< self
.random_forest_classifier_search_parameters
.n_trees
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees += 1;
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m += 1;
} else if self.current_keep_samples + 1
< self
.random_forest_classifier_search_parameters
.keep_samples
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples += 1;
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples = 0;
self.current_seed += 1;
} else {
self.current_criterion += 1;
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_n_trees += 1;
self.current_m += 1;
self.current_keep_samples += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for RandomForestClassifierSearchParameters {
fn default() -> Self {
let default_params = RandomForestClassifierParameters::default();
RandomForestClassifierSearchParameters {
criterion: vec![default_params.criterion],
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
RandomForestClassifier<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
pub fn fit(
x: &X,
y: &Y,
parameters: RandomForestClassifierParameters,
) -> Result<RandomForestClassifier<T>, Failed> {
let (_, num_attributes) = x.shape();
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y_m.unique();
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
let yc = y_m.get(0, i);
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let y_ncols = y.shape();
if x_nrows != y_ncols {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mtry = parameters.m.unwrap_or_else(|| {
(T::from(num_attributes).unwrap())
.sqrt()
.floor()
.to_usize()
.unwrap()
});
let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y.unique();
let mut rng = StdRng::seed_from_u64(parameters.seed);
let classes = y_m.unique();
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
let yc = y.get(i);
*yi_i = classes.iter().position(|c| yc == c).unwrap();
}
let mtry = parameters
.m
.unwrap_or_else(|| ((num_attributes as f64).sqrt().floor()) as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let classes = y.unique();
let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
// TODO: use with_capacity here
let mut trees: Vec<DecisionTreeClassifier<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
let samples: Vec<usize> =
RandomForestClassifier::<TX, TY, X, Y>::sample_with_replacement(&yi, k, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
@@ -242,38 +496,40 @@ impl<T: RealNumber> RandomForestClassifier<T> {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
};
let tree =
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree);
}
Ok(RandomForestClassifier {
_parameters: parameters,
trees,
classes,
trees: Some(trees),
classes: Some(classes),
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0);
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row(x, i)],
);
}
Ok(result.to_row_vector())
Ok(result)
}
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()];
fn predict_for_row(&self, x: &X, row: usize) -> usize {
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
for tree in self.trees.iter() {
for tree in self.trees.as_ref().unwrap().iter() {
result[tree.predict_for_row(x, row)] += 1;
}
@@ -281,7 +537,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
@@ -294,20 +550,28 @@ impl<T: RealNumber> RandomForestClassifier<T> {
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = M::zeros(1, n);
let mut result = Y::zeros(n);
for i in 0..n {
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
);
}
Ok(result.to_row_vector())
Ok(result)
}
}
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()];
fn predict_for_row_oob(&self, x: &X, row: usize) -> usize {
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result[tree.predict_for_row(x, row)] += 1;
}
@@ -343,12 +607,38 @@ impl<T: RealNumber> RandomForestClassifier<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_iris() {
fn search_parameters() {
let parameters = RandomForestClassifierSearchParameters {
n_trees: vec![10, 100],
m: vec![None, Some(1)],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, Some(1));
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, Some(1));
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
@@ -370,17 +660,16 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -394,7 +683,34 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let fail = RandomForestClassifier::fit(
&x_rand,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
);
assert!(fail.is_err());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict_iris_oob() {
let x = DenseMatrix::from_2d_array(&[
@@ -418,17 +734,16 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -445,7 +760,10 @@ mod tests {
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -470,14 +788,13 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestClassifier<f64> =
let deserialized_forest: RandomForestClassifier<f64, i64, DenseMatrix<f64>, Vec<i64>> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest);
+333 -149
View File
@@ -8,7 +8,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::random_forest_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -29,7 +29,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]);
//! ]).unwrap();
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
@@ -43,8 +43,6 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::default::Default;
use std::fmt::Debug;
@@ -52,30 +50,37 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Random Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct RandomForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
@@ -83,10 +88,13 @@ pub struct RandomForestRegressorParameters {
/// Random Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestRegressor<T: RealNumber> {
_parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor<T>>,
samples: Option<Vec<Vec<bool>>>,
pub struct RandomForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
impl RandomForestRegressorParameters {
@@ -131,7 +139,7 @@ impl RandomForestRegressorParameters {
impl Default for RandomForestRegressorParameters {
fn default() -> Self {
RandomForestRegressorParameters {
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
@@ -142,167 +150,305 @@ impl Default for RandomForestRegressorParameters {
}
}
impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for RandomForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.trees.len() != other.trees.len() {
false
} else {
for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] {
return false;
}
}
true
}
self.forest_regressor == other.forest_regressor
}
}
impl<T: RealNumber, M: Matrix<T>>
SupervisedEstimator<M, M::RowVector, RandomForestRegressorParameters>
for RandomForestRegressor<T>
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, RandomForestRegressorParameters>
for RandomForestRegressor<TX, TY, X, Y>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: RandomForestRegressorParameters,
) -> Result<Self, Failed> {
fn new() -> Self {
Self {
forest_regressor: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result<Self, Failed> {
RandomForestRegressor::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for RandomForestRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber> RandomForestRegressor<T> {
/// RandomForestRegressor grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Vec<Option<usize>>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: Vec<u64>,
}
/// RandomForestRegressor grid search iterator
pub struct RandomForestRegressorSearchParametersIterator {
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_n_trees: usize,
current_m: usize,
current_keep_samples: usize,
current_seed: usize,
}
impl IntoIterator for RandomForestRegressorSearchParameters {
type Item = RandomForestRegressorParameters;
type IntoIter = RandomForestRegressorSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
RandomForestRegressorSearchParametersIterator {
random_forest_regressor_search_parameters: self,
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_n_trees: 0,
current_m: 0,
current_keep_samples: 0,
current_seed: 0,
}
}
}
impl Iterator for RandomForestRegressorSearchParametersIterator {
type Item = RandomForestRegressorParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_max_depth
== self
.random_forest_regressor_search_parameters
.max_depth
.len()
&& self.current_min_samples_leaf
== self
.random_forest_regressor_search_parameters
.min_samples_leaf
.len()
&& self.current_min_samples_split
== self
.random_forest_regressor_search_parameters
.min_samples_split
.len()
&& self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
&& self.current_m == self.random_forest_regressor_search_parameters.m.len()
&& self.current_keep_samples
== self
.random_forest_regressor_search_parameters
.keep_samples
.len()
&& self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
{
return None;
}
let next = RandomForestRegressorParameters {
max_depth: self.random_forest_regressor_search_parameters.max_depth
[self.current_max_depth],
min_samples_leaf: self
.random_forest_regressor_search_parameters
.min_samples_leaf[self.current_min_samples_leaf],
min_samples_split: self
.random_forest_regressor_search_parameters
.min_samples_split[self.current_min_samples_split],
n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
m: self.random_forest_regressor_search_parameters.m[self.current_m],
keep_samples: self.random_forest_regressor_search_parameters.keep_samples
[self.current_keep_samples],
seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
};
if self.current_max_depth + 1
< self
.random_forest_regressor_search_parameters
.max_depth
.len()
{
self.current_max_depth += 1;
} else if self.current_min_samples_leaf + 1
< self
.random_forest_regressor_search_parameters
.min_samples_leaf
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf += 1;
} else if self.current_min_samples_split + 1
< self
.random_forest_regressor_search_parameters
.min_samples_split
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_n_trees + 1
< self.random_forest_regressor_search_parameters.n_trees.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees += 1;
} else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m += 1;
} else if self.current_keep_samples + 1
< self
.random_forest_regressor_search_parameters
.keep_samples
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples += 1;
} else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples = 0;
self.current_seed += 1;
} else {
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_n_trees += 1;
self.current_m += 1;
self.current_keep_samples += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for RandomForestRegressorSearchParameters {
fn default() -> Self {
let default_params = RandomForestRegressorParameters::default();
RandomForestRegressorSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
RandomForestRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
pub fn fit(
x: &X,
y: &Y,
parameters: RandomForestRegressorParameters,
) -> Result<RandomForestRegressor<T>, Failed> {
let (n_rows, num_attributes) = x.shape();
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
};
let tree =
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree);
}
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: true,
splitter: Splitter::Best,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
Ok(RandomForestRegressor {
_parameters: parameters,
trees,
samples: maybe_all_samples,
forest_regressor: Some(forest_regressor),
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(0, i, self.predict_for_row(x, i));
}
Ok(result.to_row_vector())
}
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let n_trees = self.trees.len();
let mut result = T::zero();
for tree in self.trees.iter() {
result += tree.predict_for_row(x, row);
}
result / T::from(n_trees).unwrap()
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = M::zeros(1, n);
for i in 0..n {
result.set(0, i, self.predict_for_row_oob(x, i));
}
Ok(result.to_row_vector())
}
}
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let mut n_trees = 0;
let mut result = T::zero();
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / T::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = RandomForestRegressorSearchParameters {
n_trees: vec![10, 100],
m: vec![None, Some(1)],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, Some(1));
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, Some(1));
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_longley() {
let x = DenseMatrix::from_2d_array(&[
@@ -322,7 +468,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -332,7 +479,7 @@ mod tests {
&x,
&y,
RandomForestRegressorParameters {
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
@@ -347,7 +494,36 @@ mod tests {
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let fail = RandomForestRegressor::fit(
&x_rand,
&y,
RandomForestRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
m: Option::None,
keep_samples: false,
seed: 87,
},
);
assert!(fail.is_err());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict_longley_oob() {
let x = DenseMatrix::from_2d_array(&[
@@ -367,7 +543,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -377,7 +554,7 @@ mod tests {
&x,
&y,
RandomForestRegressorParameters {
max_depth: None,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
@@ -391,10 +568,16 @@ mod tests {
let y_hat = regressor.predict(&x).unwrap();
let y_hat_oob = regressor.predict_oob(&x).unwrap();
println!("{:?}", mean_absolute_error(&y, &y_hat));
println!("{:?}", mean_absolute_error(&y, &y_hat_oob));
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -415,7 +598,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -423,7 +607,7 @@ mod tests {
let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestRegressor<f64> =
let deserialized_forest: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest);
+23 -1
View File
@@ -30,6 +30,10 @@ pub enum FailedError {
DecompositionFailed,
/// Can't solve for x
SolutionFailed,
/// Error in input parameters
ParametersError,
/// Invalid state error (should never happen)
InvalidStateError,
}
impl Failed {
@@ -62,6 +66,22 @@ impl Failed {
}
}
/// new instance of `FailedError::ParametersError`
pub fn input(msg: &str) -> Self {
Failed {
err: FailedError::ParametersError,
msg: msg.to_string(),
}
}
/// new instance of `FailedError::InvalidStateError`
pub fn invalid_state(msg: &str) -> Self {
Failed {
err: FailedError::InvalidStateError,
msg: msg.to_string(),
}
}
/// new instance of `err`
pub fn because(err: FailedError, msg: &str) -> Self {
Failed {
@@ -94,8 +114,10 @@ impl fmt::Display for FailedError {
FailedError::FindFailed => "Find failed",
FailedError::DecompositionFailed => "Decomposition failed",
FailedError::SolutionFailed => "Can't find solution",
FailedError::ParametersError => "Error in input, check parameters",
FailedError::InvalidStateError => "Invalid state, this should never happen", // useful in development phase of lib
};
write!(f, "{}", failed_err_str)
write!(f, "{failed_err_str}")
}
}
+78 -44
View File
@@ -3,32 +3,81 @@
clippy::too_many_arguments,
clippy::many_single_char_names,
clippy::unnecessary_wraps,
clippy::upper_case_acronyms
clippy::upper_case_acronyms,
clippy::approx_constant
)]
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]
//! # SmartCore
//! # smartcore
//!
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
//! Welcome to `smartcore`, machine learning in Rust!
//!
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! `smartcore` features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! as well as tools for model selection and model evaluation.
//!
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
//! * [ndarray](https://docs.rs/ndarray)
//! * [nalgebra](https://docs.rs/nalgebra/)
//! `smartcore` provides its own traits system that extends Rust standard library, to deal with linear algebra and common
//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
//! structures) is available via optional features.
//!
//! ## Getting Started
//!
//! To start using SmartCore simply add the following to your Cargo.toml file:
//! To start using `smartcore` latest stable version simply add the following to your `Cargo.toml` file:
//! ```ignore
//! [dependencies]
//! smartcore = "0.2.0"
//! smartcore = "*"
//! ```
//!
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! To start using smartcore development version with latest unstable additions:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
//! ```
//!
//! There are different features that can be added to the base library, for example to add sample datasets:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", features = ["datasets"] }
//! ```
//! Check `smartcore`'s `Cargo.toml` for available features.
//!
//! ## Using Jupyter
//! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
//!
//!
//! ## First Example
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//!
//! ```
//! // DenseMatrix definition
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! // KNNClassifier
//! use smartcore::neighbors::knn_classifier::*;
//! // Various distance metrics
//! use smartcore::metrics::distance::*;
//!
//! // Turn Rust vector-slices with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]).unwrap();
//! // Our classes are defined as a vector
//! let y = vec![2, 2, 2, 3, 3];
//!
//! // Train classifier
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
//!
//! // Predict classes
//! let y_hat = knn.predict(&x).unwrap();
//! ```
//!
//! ## Overview
//!
//! ### Supported algorithms
//! All machine learning algorithms are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
@@ -38,37 +87,16 @@
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
//!
//!
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//!
//! ```
//! // DenseMatrix defenition
//! use smartcore::linalg::naive::dense_matrix::*;
//! // KNNClassifier
//! use smartcore::neighbors::knn_classifier::*;
//! // Various distance metrics
//! use smartcore::math::distance::*;
//!
//! // Turn Rust vectors with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]);
//! // Our classes are defined as a Vector
//! let y = vec![2., 2., 2., 3., 3.];
//!
//! // Train classifier
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
//!
//! // Predict classes
//! let y_hat = knn.predict(&x).unwrap();
//! ```
//! ### Linear Algebra traits system
//! For an introduction to `smartcore`'s traits system see [this notebook](https://github.com/smartcorelib/smartcore-jupyter/blob/5523993c53c6ec1fd72eea130ef4e7883121c1ea/notebooks/01-A-little-bit-about-numbers.ipynb)
/// Various algorithms and helper methods that are used elsewhere in SmartCore
/// Foundamental numbers traits
pub mod numbers;
/// Various algorithms and helper methods that are used elsewhere in smartcore
pub mod algorithm;
pub mod api;
/// Algorithms for clustering of unlabeled data
pub mod cluster;
/// Various datasets
@@ -79,23 +107,29 @@ pub mod decomposition;
/// Ensemble methods, including Random Forest classifier and regressor
pub mod ensemble;
pub mod error;
/// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms
/// Diverse collection of linear algebra abstractions and methods that power smartcore algorithms
pub mod linalg;
/// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables.
pub mod linear;
/// Helper methods and classes, including definitions of distance metrics
pub mod math;
/// Functions for assessing prediction error.
pub mod metrics;
/// TODO: add docstring for model_selection
pub mod model_selection;
/// Supervised learning algorithms based on applying the Bayes theorem with the independence assumptions between predictors
pub mod naive_bayes;
/// Supervised neighbors-based learning methods
pub mod neighbors;
pub(crate) mod optimization;
/// Optimization procedures
pub mod optimization;
/// Preprocessing utilities
pub mod preprocessing;
/// Reading in data from serialized formats
#[cfg(feature = "serde")]
pub mod readers;
/// Support Vector Machines
pub mod svm;
/// Supervised tree-based learning methods
pub mod tree;
pub mod xgboost;
pub(crate) mod rand_custom;
File diff suppressed because it is too large Load Diff
+845
View File
@@ -0,0 +1,845 @@
use std::fmt;
use std::fmt::{Debug, Display};
use std::ops::Range;
use std::slice::Iter;
use approx::{AbsDiffEq, RelativeEq};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::{
Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::lu::LUDecomposable;
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::stats::{MatrixPreprocessing, MatrixStats};
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::error::Failed;
/// Dense matrix
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DenseMatrix<T> {
ncols: usize,
nrows: usize,
values: Vec<T>,
column_major: bool,
}
/// View on dense matrix
#[derive(Debug, Clone)]
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
values: &'a [T],
stride: usize,
nrows: usize,
ncols: usize,
column_major: bool,
}
/// Mutable view on dense matrix
#[derive(Debug)]
pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
values: &'a mut [T],
stride: usize,
nrows: usize,
ncols: usize,
column_major: bool,
}
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
fn new(
m: &'a DenseMatrix<T>,
vrows: Range<usize>,
vcols: Range<usize>,
) -> Result<Self, Failed> {
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
Err(Failed::input(
"The specified view is outside of the matrix range",
))
} else {
let (start, end, stride) =
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
Ok(DenseMatrixView {
values: &m.values[start..end],
stride,
nrows: vrows.end - vrows.start,
ncols: vcols.end - vcols.start,
column_major: m.column_major,
})
}
}
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
fn new(
m: &'a mut DenseMatrix<T>,
vrows: Range<usize>,
vcols: Range<usize>,
) -> Result<Self, Failed> {
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
Err(Failed::input(
"The specified view is outside of the matrix range",
))
} else {
let (start, end, stride) =
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
Ok(DenseMatrixMutView {
values: &mut m.values[start..end],
stride,
nrows: vrows.end - vrows.start,
ncols: vcols.end - vcols.start,
column_major: m.column_major,
})
}
}
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let column_major = self.column_major;
let stride = self.stride;
let ptr = self.values.as_mut_ptr();
match axis {
0 => Box::new((0..self.nrows).flat_map(move |r| {
(0..self.ncols).map(move |c| unsafe {
&mut *ptr.add(if column_major {
r + c * stride
} else {
r * stride + c
})
})
})),
_ => Box::new((0..self.ncols).flat_map(move |c| {
(0..self.nrows).map(move |r| unsafe {
&mut *ptr.add(if column_major {
r + c * stride
} else {
r * stride + c
})
})
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
/// Create new instance of `DenseMatrix` without copying data.
/// `values` should be in column-major order.
pub fn new(
nrows: usize,
ncols: usize,
values: Vec<T>,
column_major: bool,
) -> Result<Self, Failed> {
let data_len = values.len();
if nrows * ncols != values.len() {
Err(Failed::input(&format!(
"The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
)))
} else {
Ok(DenseMatrix {
ncols,
nrows,
values,
column_major,
})
}
}
/// New instance of `DenseMatrix` from 2d array.
pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
}
/// New instance of `DenseMatrix` from 2d vector.
#[allow(clippy::ptr_arg)]
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
if values.is_empty() || values[0].is_empty() {
Err(Failed::input(
"The 2d vec provided is empty; cannot instantiate the matrix",
))
} else {
let nrows = values.len();
let ncols = values
.first()
.unwrap_or_else(|| {
panic!("Invalid state: Cannot create 2d matrix from an empty vector")
})
.len();
let mut m_values = Vec::with_capacity(nrows * ncols);
for c in 0..ncols {
for r in values.iter().take(nrows) {
m_values.push(r[c])
}
}
DenseMatrix::new(nrows, ncols, m_values, true)
}
}
/// Iterate over values of matrix
pub fn iter(&self) -> Iter<'_, T> {
self.values.iter()
}
/// Check if the size of the requested view is bounded to matrix rows/cols count
fn is_valid_view(
&self,
n_rows: usize,
n_cols: usize,
vrows: &Range<usize>,
vcols: &Range<usize>,
) -> bool {
!(vrows.end <= n_rows
&& vcols.end <= n_cols
&& vrows.start <= n_rows
&& vcols.start <= n_cols)
}
/// Compute the range of the requested view: start, end, size of the slice
fn stride_range(
&self,
n_rows: usize,
n_cols: usize,
vrows: &Range<usize>,
vcols: &Range<usize>,
column_major: bool,
) -> (usize, usize, usize) {
let (start, end, stride) = if column_major {
(
vrows.start + vcols.start * n_rows,
vrows.end + (vcols.end - 1) * n_rows,
n_rows,
)
} else {
(
vrows.start * n_cols + vcols.start,
(vrows.end - 1) * n_cols + vcols.end,
n_cols,
)
};
(start, end, stride)
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<T: Debug + Display + Copy + Sized + PartialEq> PartialEq for DenseMatrix<T> {
fn eq(&self, other: &Self) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
return false;
}
let len = self.values.len();
let other_len = other.values.len();
if len != other_len {
return false;
}
match self.column_major == other.column_major {
true => self
.values
.iter()
.zip(other.values.iter())
.all(|(&v1, v2)| v1.eq(v2)),
false => self
.iterator(0)
.zip(other.iterator(0))
.all(|(&v1, v2)| v1.eq(v2)),
}
}
}
impl<T: Number + RealNumber + AbsDiffEq> AbsDiffEq for DenseMatrix<T>
where
T::Epsilon: Copy,
{
type Epsilon = T::Epsilon;
fn default_epsilon() -> T::Epsilon {
T::default_epsilon()
}
// equality in differences in absolute values, according to an epsilon
fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
false
} else {
self.values
.iter()
.zip(other.values.iter())
.all(|(v1, v2)| T::abs_diff_eq(v1, v2, epsilon))
}
}
}
impl<T: Number + RealNumber + RelativeEq> RelativeEq for DenseMatrix<T>
where
T::Epsilon: Copy,
{
fn default_max_relative() -> T::Epsilon {
T::default_max_relative()
}
fn relative_eq(&self, other: &Self, epsilon: T::Epsilon, max_relative: T::Epsilon) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
false
} else {
self.iterator(0)
.zip(other.iterator(0))
.all(|(v1, v2)| T::relative_eq(v1, v2, epsilon, max_relative))
}
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
fn get(&self, pos: (usize, usize)) -> &T {
let (row, col) = pos;
if row >= self.nrows || col >= self.ncols {
panic!(
"Invalid index ({},{}) for {}x{} matrix",
row, col, self.nrows, self.ncols
);
}
if self.column_major {
&self.values[col * self.nrows + row]
} else {
&self.values[col + self.ncols * row]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.ncols < 1 || self.nrows < 1
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrix<T> {
fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major {
self.values[pos.1 * self.nrows + pos.0] = x;
} else {
self.values[pos.1 + pos.0 * self.ncols] = x;
}
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.values.as_mut_ptr();
let column_major = self.column_major;
let (nrows, ncols) = self.shape();
match axis {
0 => Box::new((0..self.nrows).flat_map(move |r| {
(0..self.ncols).map(move |c| unsafe {
&mut *ptr.add(if column_major {
r + c * nrows
} else {
r * ncols + c
})
})
})),
_ => Box::new((0..self.ncols).flat_map(move |c| {
(0..self.nrows).map(move |r| unsafe {
&mut *ptr.add(if column_major {
r + c * nrows
} else {
r * ncols + c
})
})
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
}
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
}
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
}
fn slice_mut<'a>(
&'a mut self,
rows: Range<usize>,
cols: Range<usize>,
) -> Box<dyn MutArrayView2<T> + 'a>
where
Self: Sized,
{
Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
}
// private function so for now assume infalible
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
}
// private function so for now assume infalible
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
}
fn transpose(&self) -> Self {
let mut m = self.clone();
m.ncols = self.nrows;
m.nrows = self.ncols;
m.column_major = !self.column_major;
m
}
}
impl<T: Number + RealNumber> QRDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> CholeskyDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
} else {
&self.values[pos.0 * self.stride + pos.1]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
fn get(&self, i: usize) -> &T {
if self.nrows == 1 {
if self.column_major {
&self.values[i * self.stride]
} else {
&self.values[i]
}
} else if self.ncols == 1 || (!self.column_major && self.nrows == 1) {
if self.column_major {
&self.values[i]
} else {
&self.values[i * self.stride]
}
} else {
panic!("This is neither a column nor a row");
}
}
fn shape(&self) -> usize {
if self.nrows == 1 {
self.ncols
} else if self.ncols == 1 {
self.nrows
} else {
panic!("This is neither a column nor a row");
}
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
} else {
&self.values[pos.0 * self.stride + pos.1]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major {
self.values[pos.0 + pos.1 * self.stride] = x;
} else {
self.values[pos.0 * self.stride + pos.1] = x;
}
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
self.iter_mut(axis)
}
}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
#[cfg(test)]
#[warn(clippy::reversed_empty_ranges)]
mod tests {
use super::*;
use approx::relative_eq;
#[test]
fn test_instantiate_from_2d() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
assert!(x.is_ok());
}
#[test]
fn test_instantiate_from_2d_empty() {
let input: &[&[f64]] = &[&[]];
let x = DenseMatrix::from_2d_array(input);
assert!(x.is_err());
}
#[test]
fn test_instantiate_from_2d_empty2() {
let input: &[&[f64]] = &[&[], &[]];
let x = DenseMatrix::from_2d_array(input);
assert!(x.is_err());
}
#[test]
fn test_instantiate_ok_view1() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..2, 0..2);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view2() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view3() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 2..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view4() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 3..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_err_view1() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 3..4, 0..3);
assert!(v.is_err());
}
#[test]
fn test_instantiate_err_view2() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..3, 3..4);
assert!(v.is_err());
}
#[test]
fn test_instantiate_err_view3() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
#[allow(clippy::reversed_empty_ranges)]
let v = DenseMatrixView::new(&x, 0..3, 4..3);
assert!(v.is_err());
}
#[test]
fn test_display() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
println!("{}", &x);
}
#[test]
fn test_get_row_col() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
assert_eq!(15.0, x.get_col(1).sum());
assert_eq!(15.0, x.get_row(1).sum());
assert_eq!(81.0, x.get_col(1).dot(&(*x.get_row(1))));
}
#[test]
fn test_row_major() {
let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();
assert_eq!(5, *x.get_col(1).get(1));
assert_eq!(7, x.get_col(1).sum());
assert_eq!(5, *x.get_row(1).get(1));
assert_eq!(15, x.get_row(1).sum());
x.slice_mut(0..2, 1..2)
.iterator_mut(0)
.for_each(|v| *v += 2);
assert_eq!(vec![1, 4, 3, 4, 7, 6], *x.values);
}
#[test]
fn test_get_slice() {
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
.unwrap();
assert_eq!(
vec![4, 5, 6],
DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
);
let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
assert_eq!(vec![4, 5, 6], second_row);
let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
assert_eq!(vec![2, 5, 8], second_col);
}
#[test]
fn test_iter_mut() {
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
// add +2 to some elements
x.slice_mut(1..2, 0..3)
.iterator_mut(0)
.for_each(|v| *v += 2);
assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
// add +1 to some others
x.slice_mut(0..3, 1..2)
.iterator_mut(0)
.for_each(|v| *v += 1);
assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);
// rewrite matrix as indices of values per axis 1 (row-wise)
x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
// rewrite matrix as indices of values per axis 0 (column-wise)
x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
// rewrite some by slice
x.slice_mut(0..3, 0..2)
.iterator_mut(0)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
x.slice_mut(0..2, 0..3)
.iterator_mut(1)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
}
#[test]
fn test_str_array() {
let mut x =
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
.unwrap();
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
x.iterator_mut(0).for_each(|v| *v = "str");
assert_eq!(
vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
x.values
);
}
#[test]
fn test_transpose() {
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert!(x.column_major);
// transpose
let x = x.transpose();
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert!(!x.column_major); // should change column_major
}
#[test]
fn test_from_iterator() {
let data = [1, 2, 3, 4, 5, 6];
let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);
// make a vector into a 2x3 matrix.
assert_eq!(
vec![1, 2, 3, 4, 5, 6],
m.values.iter().map(|e| **e).collect::<Vec<i32>>()
);
assert!(!m.column_major);
}
#[test]
fn test_take() {
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
println!("{a}");
// take column 0 and 2
assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
println!("{b}");
// take rows 0 and 2
assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
}
#[test]
fn test_mut() {
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();
let a = a.abs();
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
let a = a.neg();
assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
}
#[test]
fn test_reshape() {
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
.unwrap();
let a = a.reshape(2, 6, 0);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);
let a = a.reshape(3, 4, 1);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
}
#[test]
fn test_eq() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let c = DenseMatrix::from_2d_array(&[
&[1. + f32::EPSILON, 2., 3.],
&[4., 5., 6. + f32::EPSILON],
])
.unwrap();
let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
.unwrap();
assert!(!relative_eq!(a, b));
assert!(!relative_eq!(a, d));
assert!(relative_eq!(a, c));
}
}
+8
View File
@@ -0,0 +1,8 @@
/// `Array`, `ArrayView` and related multidimensional
pub mod arrays;
/// foundamental implementation for a `DenseMatrix` construct
pub mod matrix;
/// foundamental implementation for 1D constructs
pub mod vector;
+348
View File
@@ -0,0 +1,348 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{Array, Array1, ArrayView1, MutArray, MutArrayView1};
/// Provide mutable window on array
#[derive(Debug)]
pub struct VecMutView<'a, T: Debug + Display + Copy + Sized> {
ptr: &'a mut [T],
}
/// Provide window on array
#[derive(Debug, Clone)]
pub struct VecView<'a, T: Debug + Display + Copy + Sized> {
ptr: &'a [T],
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for &[T] {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
fn set(&mut self, i: usize, x: T) {
// NOTE: this panics in case of out of bounds index
self[i] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for Vec<T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for &[T] {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for Vec<T> {}
impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
let view = VecView { ptr: &self[range] };
Box::new(view)
}
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
let view = VecMutView {
ptr: &mut self[range],
};
Box::new(view)
}
fn fill(len: usize, value: T) -> Self {
vec![value; len]
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
where
Self: Sized,
{
let mut v: Vec<T> = Vec::with_capacity(len);
iter.take(len).for_each(|i| v.push(i));
v
}
fn from_vec_slice(slice: &[T]) -> Self {
let mut v: Vec<T> = Vec::with_capacity(slice.len());
slice.iter().for_each(|i| v.push(*i));
v
}
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
let mut v: Vec<T> = Vec::with_capacity(slice.shape());
slice.iterator(0).for_each(|i| v.push(*i));
v
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
fn shape(&self) -> usize {
self.ptr.len()
}
fn is_empty(&self) -> bool {
self.ptr.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
fn set(&mut self, i: usize, x: T) {
self.ptr[i] = x;
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
fn shape(&self) -> usize {
self.ptr.len()
}
fn is_empty(&self) -> bool {
self.ptr.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::numbers::basenum::Number;
fn dot_product<T: Number, V: Array1<T>>(v: &V) -> T {
let vv = V::zeros(10);
let v_s = vv.slice(0..3);
v_s.dot(v)
}
fn vector_ops<T: Number + PartialOrd, V: Array1<T>>(_: &V) -> T {
let v = V::zeros(10);
v.max()
}
#[test]
fn test_get_set() {
let mut x = vec![1, 2, 3];
assert_eq!(3, *x.get(2));
x.set(1, 1);
assert_eq!(1, *x.get(1));
}
#[test]
#[should_panic]
fn test_failed_set() {
vec![1, 2, 3].set(3, 1);
}
#[test]
#[should_panic]
fn test_failed_get() {
vec![1, 2, 3].get(3);
}
#[test]
fn test_len() {
let x = [1, 2, 3];
assert_eq!(3, x.len());
}
#[test]
fn test_is_empty() {
assert!(vec![1; 0].is_empty());
assert!(!vec![1, 2, 3].is_empty());
}
#[test]
fn test_iterator() {
let v: Vec<i32> = vec![1, 2, 3].iterator(0).map(|&v| v * 2).collect();
assert_eq!(vec![2, 4, 6], v);
}
#[test]
#[should_panic]
fn test_failed_iterator() {
let _ = vec![1, 2, 3].iterator(1);
}
#[test]
fn test_mut_iterator() {
let mut x = vec![1, 2, 3];
x.iterator_mut(0).for_each(|v| *v *= 2);
assert_eq!(vec![2, 4, 6], x);
}
#[test]
#[should_panic]
fn test_failed_mut_iterator() {
let _ = vec![1, 2, 3].iterator_mut(1);
}
#[test]
fn test_slice() {
let x = vec![1, 2, 3, 4, 5];
let x_slice = x.slice(2..3);
assert_eq!(1, x_slice.shape());
assert_eq!(3, *x_slice.get(0));
}
#[test]
#[should_panic]
fn test_failed_slice() {
vec![1, 2, 3].slice(0..4);
}
#[test]
fn test_mut_slice() {
let mut x = vec![1, 2, 3, 4, 5];
let mut x_slice = x.slice_mut(2..4);
x_slice.set(0, 9);
assert_eq!(2, x_slice.shape());
assert_eq!(9, *x_slice.get(0));
assert_eq!(4, *x_slice.get(1));
}
#[test]
#[should_panic]
fn test_failed_mut_slice() {
vec![1, 2, 3].slice_mut(0..4);
}
#[test]
fn test_init() {
assert_eq!(Vec::fill(3, 0), vec![0, 0, 0]);
assert_eq!(
Vec::from_iterator([0, 1, 2, 3].iter().cloned(), 3),
vec![0, 1, 2]
);
assert_eq!(Vec::from_vec_slice(&[0, 1, 2]), vec![0, 1, 2]);
assert_eq!(Vec::from_vec_slice(&[0, 1, 2, 3, 4][2..]), vec![2, 3, 4]);
assert_eq!(Vec::from_slice(&vec![1, 2, 3, 4, 5]), vec![1, 2, 3, 4, 5]);
assert_eq!(
Vec::from_slice(vec![1, 2, 3, 4, 5].slice(0..3).as_ref()),
vec![1, 2, 3]
);
}
#[test]
fn test_mul_scalar() {
let mut x = vec![1., 2., 3.];
let mut y = Vec::<f32>::zeros(10);
y.slice_mut(0..2).add_scalar_mut(1.0);
y.sub_scalar(1.0);
x.slice_mut(0..2).sub_scalar_mut(2.);
assert_eq!(vec![-1.0, 0.0, 3.0], x);
}
#[test]
fn test_dot() {
let y_i = vec![1, 2, 3];
let y = vec![1.0, 2.0, 3.0];
println!("Regular dot1: {:?}", dot_product(&y));
let x = vec![4.0, 5.0, 6.0];
assert_eq!(32.0, y.slice(0..3).dot(&(*x.slice(0..3))));
assert_eq!(32.0, y.slice(0..3).dot(&x));
assert_eq!(32.0, y.dot(&x));
assert_eq!(14, y_i.dot(&y_i));
}
#[test]
fn test_operators() {
let mut x: Vec<f32> = Vec::zeros(10);
x.add_scalar(15.0);
{
let mut x_s = x.slice_mut(0..5);
x_s.add_scalar_mut(1.0);
assert_eq!(
vec![1.0, 1.0, 1.0, 1.0, 1.0],
x_s.iterator(0).copied().collect::<Vec<f32>>()
);
}
assert_eq!(1.0, x.slice(2..3).min());
assert_eq!(vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], x);
}
#[test]
fn test_vector_ops() {
let x = vec![1., 2., 3.];
vector_ops(&x);
}
}
+7 -762
View File
@@ -1,764 +1,9 @@
#![allow(clippy::wrong_self_convention)]
//! # Linear Algebra and Matrix Decomposition
//!
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
//!
//! Traits [`BaseMatrix`](trait.BaseMatrix.html), [`Matrix`](trait.Matrix.html) and [`BaseVector`](trait.BaseVector.html) define
//! abstract methods that can be implemented for any two-dimensional and one-dimentional arrays (matrix and vector).
//! Functions from these traits are designed for SmartCore machine learning algorithms and should not be used directly in your code.
//! If you still want to use functions from `BaseMatrix`, `Matrix` and `BaseVector` please be aware that methods defined in these
//! traits might change in the future.
//!
//! One reason why linear algebra traits are public is to allow for different types of matrices and vectors to be plugged into SmartCore.
//! Once all methods defined in `BaseMatrix`, `Matrix` and `BaseVector` are implemented for your favourite type of matrix and vector you
//! should be able to run SmartCore algorithms on it. Please see `nalgebra_bindings` and `ndarray_bindings` modules for an example of how
//! it is done for other libraries.
//!
//! You will also find verious matrix decomposition methods that work for any matrix that extends [`Matrix`](trait.Matrix.html).
//! For example, to decompose matrix defined as [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html):
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::svd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! let svd = A.svd().unwrap();
//!
//! let s: Vec<f64> = svd.s;
//! let v: DenseMatrix<f64> = svd.V;
//! let u: DenseMatrix<f64> = svd.U;
//! ```
/// basic data structures for linear algebra constructs: arrays and views
pub mod basic;
/// traits associated to algebraic constructs
pub mod traits;
pub mod cholesky;
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
pub mod evd;
pub mod high_order;
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
pub mod lu;
/// Dense matrix with column-major order that wraps [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
pub mod naive;
/// [nalgebra](https://docs.rs/nalgebra/) bindings.
#[cfg(feature = "nalgebra-bindings")]
pub mod nalgebra_bindings;
/// [ndarray](https://docs.rs/ndarray) bindings.
#[cfg(feature = "ndarray-bindings")]
pub mod ndarray_bindings;
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
pub mod qr;
pub mod stats;
/// Singular value decomposition.
pub mod svd;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::Range;
use crate::math::num::RealNumber;
use cholesky::CholeskyDecomposableMatrix;
use evd::EVDDecomposableMatrix;
use high_order::HighOrderOperations;
use lu::LUDecomposableMatrix;
use qr::QRDecomposableMatrix;
use stats::{MatrixPreprocessing, MatrixStats};
use svd::SVDDecomposableMatrix;
/// Column or row vector
pub trait BaseVector<T: RealNumber>: Clone + Debug {
/// Get an element of a vector
/// * `i` - index of an element
fn get(&self, i: usize) -> T;
/// Set an element at `i` to `x`
/// * `i` - index of an element
/// * `x` - new value
fn set(&mut self, i: usize, x: T);
/// Get number of elevemnt in the vector
fn len(&self) -> usize;
/// Returns true if the vector is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Create a new vector from a &[T]
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a: [f64; 5] = [0., 0.5, 2., 3., 4.];
/// let v: Vec<f64> = BaseVector::from_array(&a);
/// assert_eq!(v, vec![0., 0.5, 2., 3., 4.]);
/// ```
fn from_array(f: &[T]) -> Self {
let mut v = Self::zeros(f.len());
for (i, elem) in f.iter().enumerate() {
v.set(i, *elem);
}
v
}
/// Return a vector with the elements of the one-dimensional array.
fn to_vec(&self) -> Vec<T>;
/// Create new vector with zeros of size `len`.
fn zeros(len: usize) -> Self;
/// Create new vector with ones of size `len`.
fn ones(len: usize) -> Self;
/// Create new vector of size `len` where each element is set to `value`.
fn fill(len: usize, value: T) -> Self;
/// Vector dot product
fn dot(&self, other: &Self) -> T;
/// Returns True if matrices are element-wise equal within a tolerance `error`.
fn approximate_eq(&self, other: &Self, error: T) -> bool;
/// Returns [L2 norm] of the vector(https://en.wikipedia.org/wiki/Matrix_norm).
fn norm2(&self) -> T;
/// Returns [vectors norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
fn norm(&self, p: T) -> T;
/// Divide single element of the vector by `x`, write result to original vector.
fn div_element_mut(&mut self, pos: usize, x: T);
/// Multiply single element of the vector by `x`, write result to original vector.
fn mul_element_mut(&mut self, pos: usize, x: T);
/// Add single element of the vector to `x`, write result to original vector.
fn add_element_mut(&mut self, pos: usize, x: T);
/// Subtract `x` from single element of the vector, write result to original vector.
fn sub_element_mut(&mut self, pos: usize, x: T);
/// Subtract scalar
fn sub_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) - x);
}
self
}
/// Subtract scalar
fn add_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) + x);
}
self
}
/// Subtract scalar
fn mul_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) * x);
}
self
}
/// Subtract scalar
fn div_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) / x);
}
self
}
/// Add vectors, element-wise
fn add_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.add_scalar_mut(x);
r
}
/// Subtract vectors, element-wise
fn sub_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.sub_scalar_mut(x);
r
}
/// Multiply vectors, element-wise
fn mul_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.mul_scalar_mut(x);
r
}
/// Divide vectors, element-wise
fn div_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.div_scalar_mut(x);
r
}
/// Add vectors, element-wise, overriding original vector with result.
fn add_mut(&mut self, other: &Self) -> &Self;
/// Subtract vectors, element-wise, overriding original vector with result.
fn sub_mut(&mut self, other: &Self) -> &Self;
/// Multiply vectors, element-wise, overriding original vector with result.
fn mul_mut(&mut self, other: &Self) -> &Self;
/// Divide vectors, element-wise, overriding original vector with result.
fn div_mut(&mut self, other: &Self) -> &Self;
/// Add vectors, element-wise
fn add(&self, other: &Self) -> Self {
let mut r = self.clone();
r.add_mut(other);
r
}
/// Subtract vectors, element-wise
fn sub(&self, other: &Self) -> Self {
let mut r = self.clone();
r.sub_mut(other);
r
}
/// Multiply vectors, element-wise
fn mul(&self, other: &Self) -> Self {
let mut r = self.clone();
r.mul_mut(other);
r
}
/// Divide vectors, element-wise
fn div(&self, other: &Self) -> Self {
let mut r = self.clone();
r.div_mut(other);
r
}
/// Calculates sum of all elements of the vector.
fn sum(&self) -> T;
/// Returns unique values from the vector.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = vec!(1., 2., 2., -2., -6., -7., 2., 3., 4.);
///
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
/// ```
fn unique(&self) -> Vec<T>;
/// Computes the arithmetic mean.
fn mean(&self) -> T {
self.sum() / T::from_usize(self.len()).unwrap()
}
/// Computes variance.
fn var(&self) -> T {
let n = self.len();
let mut mu = T::zero();
let mut sum = T::zero();
let div = T::from_usize(n).unwrap();
for i in 0..n {
let xi = self.get(i);
mu += xi;
sum += xi * xi;
}
mu /= div;
sum / div - mu.powi(2)
}
/// Computes the standard deviation.
fn std(&self) -> T {
self.var().sqrt()
}
/// Copies content of `other` vector.
fn copy_from(&mut self, other: &Self);
/// Take elements from an array.
fn take(&self, index: &[usize]) -> Self {
let n = index.len();
let mut result = Self::zeros(n);
for (i, idx) in index.iter().enumerate() {
result.set(i, self.get(*idx));
}
result
}
}
/// Generic matrix type.
pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
/// Row vector that is associated with this matrix type,
/// e.g. if we have an implementation of sparce matrix
/// we should have an associated sparce vector type that
/// represents a row in this matrix.
type RowVector: BaseVector<T> + Clone + Debug;
/// Transforms row vector `vec` into a 1xM matrix.
fn from_row_vector(vec: Self::RowVector) -> Self;
/// Transforms 1-d matrix of 1xM into a row vector.
fn to_row_vector(self) -> Self::RowVector;
/// Get an element of the matrix.
/// * `row` - row number
/// * `col` - column number
fn get(&self, row: usize, col: usize) -> T;
/// Get a vector with elements of the `row`'th row
/// * `row` - row number
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
/// Get the `row`'th row
/// * `row` - row number
fn get_row(&self, row: usize) -> Self::RowVector;
/// Copies a vector with elements of the `row`'th row into `result`
/// * `row` - row number
/// * `result` - receiver for the row
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
/// Get a vector with elements of the `col`'th column
/// * `col` - column number
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
/// Copies a vector with elements of the `col`'th column into `result`
/// * `col` - column number
/// * `result` - receiver for the col
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>);
/// Set an element at `col`, `row` to `x`
fn set(&mut self, row: usize, col: usize, x: T);
/// Create an identity matrix of size `size`
fn eye(size: usize) -> Self;
/// Create new matrix with zeros of size `nrows` by `ncols`.
fn zeros(nrows: usize, ncols: usize) -> Self;
/// Create new matrix with ones of size `nrows` by `ncols`.
fn ones(nrows: usize, ncols: usize) -> Self;
/// Create new matrix of size `nrows` by `ncols` where each element is set to `value`.
fn fill(nrows: usize, ncols: usize, value: T) -> Self;
/// Return the shape of an array.
fn shape(&self) -> (usize, usize);
/// Stack arrays in sequence vertically (row wise).
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
/// let b = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3., 1., 2.],
/// &[4., 5., 6., 3., 4.]
/// ]);
///
/// assert_eq!(a.h_stack(&b), expected);
/// ```
fn h_stack(&self, other: &Self) -> Self;
/// Stack arrays in sequence horizontally (column wise).
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3.],
/// &[4., 5., 6.]
/// ]);
///
/// assert_eq!(a.v_stack(&b), expected);
/// ```
fn v_stack(&self, other: &Self) -> Self;
/// Matrix product.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[7., 10.],
/// &[15., 22.]
/// ]);
///
/// assert_eq!(a.matmul(&a), expected);
/// ```
fn matmul(&self, other: &Self) -> Self;
/// Vector dot product
/// Both matrices should be of size _1xM_
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
///
/// assert_eq!(a.dot(&b), 32.);
/// ```
fn dot(&self, other: &Self) -> T;
/// Return a slice of the matrix.
/// * `rows` - range of rows to return
/// * `cols` - range of columns to return
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let m = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3., 1.],
/// &[4., 5., 6., 3.],
/// &[7., 8., 9., 5.]
/// ]);
/// let expected = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
/// let result = m.slice(0..2, 1..3);
/// assert_eq!(result, expected);
/// ```
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
/// Returns True if matrices are element-wise equal within a tolerance `error`.
fn approximate_eq(&self, other: &Self, error: T) -> bool;
/// Add matrices, element-wise, overriding original matrix with result.
fn add_mut(&mut self, other: &Self) -> &Self;
/// Subtract matrices, element-wise, overriding original matrix with result.
fn sub_mut(&mut self, other: &Self) -> &Self;
/// Multiply matrices, element-wise, overriding original matrix with result.
fn mul_mut(&mut self, other: &Self) -> &Self;
/// Divide matrices, element-wise, overriding original matrix with result.
fn div_mut(&mut self, other: &Self) -> &Self;
/// Divide single element of the matrix by `x`, write result to original matrix.
fn div_element_mut(&mut self, row: usize, col: usize, x: T);
/// Multiply single element of the matrix by `x`, write result to original matrix.
fn mul_element_mut(&mut self, row: usize, col: usize, x: T);
/// Add single element of the matrix to `x`, write result to original matrix.
fn add_element_mut(&mut self, row: usize, col: usize, x: T);
/// Subtract `x` from single element of the matrix, write result to original matrix.
fn sub_element_mut(&mut self, row: usize, col: usize, x: T);
/// Add matrices, element-wise
fn add(&self, other: &Self) -> Self {
let mut r = self.clone();
r.add_mut(other);
r
}
/// Subtract matrices, element-wise
fn sub(&self, other: &Self) -> Self {
let mut r = self.clone();
r.sub_mut(other);
r
}
/// Multiply matrices, element-wise
fn mul(&self, other: &Self) -> Self {
let mut r = self.clone();
r.mul_mut(other);
r
}
/// Divide matrices, element-wise
fn div(&self, other: &Self) -> Self {
let mut r = self.clone();
r.div_mut(other);
r
}
/// Add `scalar` to the matrix, override original matrix with result.
fn add_scalar_mut(&mut self, scalar: T) -> &Self;
/// Subtract `scalar` from the elements of matrix, override original matrix with result.
fn sub_scalar_mut(&mut self, scalar: T) -> &Self;
/// Multiply `scalar` by the elements of matrix, override original matrix with result.
fn mul_scalar_mut(&mut self, scalar: T) -> &Self;
/// Divide elements of the matrix by `scalar`, override original matrix with result.
fn div_scalar_mut(&mut self, scalar: T) -> &Self;
/// Add `scalar` to the matrix.
fn add_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.add_scalar_mut(scalar);
r
}
/// Subtract `scalar` from the elements of matrix.
fn sub_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.sub_scalar_mut(scalar);
r
}
/// Multiply `scalar` by the elements of matrix.
fn mul_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.mul_scalar_mut(scalar);
r
}
/// Divide elements of the matrix by `scalar`.
fn div_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.div_scalar_mut(scalar);
r
}
/// Reverse or permute the axes of the matrix, return new matrix.
fn transpose(&self) -> Self;
/// Create new `nrows` by `ncols` matrix and populate it with random samples from a uniform distribution over [0, 1).
fn rand(nrows: usize, ncols: usize) -> Self;
/// Returns [L2 norm](https://en.wikipedia.org/wiki/Matrix_norm).
fn norm2(&self) -> T;
/// Returns [matrix norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
fn norm(&self, p: T) -> T;
/// Returns the average of the matrix columns.
fn column_mean(&self) -> Vec<T>;
/// Numerical negative, element-wise. Overrides original matrix.
fn negative_mut(&mut self);
/// Numerical negative, element-wise.
fn negative(&self) -> Self {
let mut result = self.clone();
result.negative_mut();
result
}
/// Returns new matrix of shape `nrows` by `ncols` with data copied from original matrix.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 6, &[1., 2., 3., 4., 5., 6.]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3.],
/// &[4., 5., 6.]
/// ]);
///
/// assert_eq!(a.reshape(2, 3), expected);
/// ```
fn reshape(&self, nrows: usize, ncols: usize) -> Self;
/// Copies content of `other` matrix.
fn copy_from(&mut self, other: &Self);
/// Calculate the absolute value element-wise. Overrides original matrix.
fn abs_mut(&mut self) -> &Self;
/// Calculate the absolute value element-wise.
fn abs(&self) -> Self {
let mut result = self.clone();
result.abs_mut();
result
}
/// Calculates sum of all elements of the matrix.
fn sum(&self) -> T;
/// Calculates max of all elements of the matrix.
fn max(&self) -> T;
/// Calculates min of all elements of the matrix.
fn min(&self) -> T;
/// Calculates max(|a - b|) of two matrices
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., 4., -5., 6.]);
/// let b = DenseMatrix::from_array(2, 3, &[2., 3., 4., 1., 0., -12.]);
///
/// assert_eq!(a.max_diff(&b), 18.);
/// assert_eq!(b.max_diff(&b), 0.);
/// ```
fn max_diff(&self, other: &Self) -> T {
self.sub(other).abs().max()
}
/// Calculates [Softmax function](https://en.wikipedia.org/wiki/Softmax_function). Overrides the matrix with result.
fn softmax_mut(&mut self);
/// Raises elements of the matrix to the power of `p`
fn pow_mut(&mut self, p: T) -> &Self;
/// Returns new matrix with elements raised to the power of `p`
fn pow(&mut self, p: T) -> Self {
let mut result = self.clone();
result.pow_mut(p);
result
}
/// Returns the indices of the maximum values in each row.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., -5., -6., -7.]);
///
/// assert_eq!(a.argmax(), vec![2, 0]);
/// ```
fn argmax(&self) -> Vec<usize>;
/// Returns vector with unique values from the matrix.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = DenseMatrix::from_array(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
///
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
/// ```
fn unique(&self) -> Vec<T>;
/// Calculates the covariance matrix
fn cov(&self) -> Self;
/// Take elements from an array along an axis.
fn take(&self, index: &[usize], axis: u8) -> Self {
let (n, p) = self.shape();
let k = match axis {
0 => p,
_ => n,
};
let mut result = match axis {
0 => Self::zeros(index.len(), p),
_ => Self::zeros(n, index.len()),
};
for (i, idx) in index.iter().enumerate() {
for j in 0..k {
match axis {
0 => result.set(i, j, self.get(*idx, j)),
_ => result.set(j, i, self.get(j, *idx)),
};
}
}
result
}
}
/// Generic matrix with additional mixins like various factorization methods.
pub trait Matrix<T: RealNumber>:
BaseMatrix<T>
+ SVDDecomposableMatrix<T>
+ EVDDecomposableMatrix<T>
+ QRDecomposableMatrix<T>
+ LUDecomposableMatrix<T>
+ CholeskyDecomposableMatrix<T>
+ MatrixStats<T>
+ MatrixPreprocessing<T>
+ HighOrderOperations<T>
+ PartialEq
+ Display
{
}
pub(crate) fn row_iter<F: RealNumber, M: BaseMatrix<F>>(m: &M) -> RowIter<'_, F, M> {
RowIter {
m,
pos: 0,
max_pos: m.shape().0,
phantom: PhantomData,
}
}
pub(crate) struct RowIter<'a, T: RealNumber, M: BaseMatrix<T>> {
m: &'a M,
pos: usize,
max_pos: usize,
phantom: PhantomData<&'a T>,
}
impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<T>> {
let res = if self.pos < self.max_pos {
Some(self.m.get_row_as_vec(self.pos))
} else {
None
};
self.pos += 1;
res
}
}
#[cfg(test)]
mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseMatrix;
use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean() {
let m = vec![1., 2., 3.];
assert_eq!(m.mean(), 2.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn std() {
let m = vec![1., 2., 3.];
assert!((m.std() - 0.81f64).abs() < 1e-2);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn var() {
let m = vec![1., 2., 3., 4.];
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn vec_take() {
let m = vec![1., 2., 3., 4., 5.];
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn take() {
let m = DenseMatrix::from_2d_array(&[
&[1.0, 2.0],
&[3.0, 4.0],
&[5.0, 6.0],
&[7.0, 8.0],
&[9.0, 10.0],
]);
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
let expected_1 = DenseMatrix::from_2d_array(&[
&[2.0, 1.0],
&[4.0, 3.0],
&[6.0, 5.0],
&[8.0, 7.0],
&[10.0, 9.0],
]);
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
}
}
/// ndarray bindings
pub mod ndarray;
File diff suppressed because it is too large Load Diff
-26
View File
@@ -1,26 +0,0 @@
//! # Simple Dense Matrix
//!
//! Implements [`BaseMatrix`](../../trait.BaseMatrix.html) and [`BaseVector`](../../trait.BaseVector.html) for [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
//! Data is stored in dense format with [column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order).
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//!
//! // 3x3 matrix
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! // row vector
//! let B = DenseMatrix::from_array(1, 3, &[0.9, 0.4, 0.7]);
//!
//! // column vector
//! let C = DenseMatrix::from_vec(3, 1, &vec!(0.9, 0.4, 0.7));
//! ```
/// Add this module to use Dense Matrix
pub mod dense_matrix;
File diff suppressed because it is too large Load Diff
+282
View File
@@ -0,0 +1,282 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{
Array as BaseArray, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::lu::LUDecomposable;
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr};
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
for ArrayBase<OwnedRepr<T>, Ix2>
{
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
for ArrayBase<OwnedRepr<T>, Ix2>
{
fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.as_mut_ptr();
let stride = self.strides();
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
match axis {
0 => Box::new(self.iter_mut()),
_ => Box::new((0..self.ncols()).flat_map(move |c| {
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> Array2<T> for ArrayBase<OwnedRepr<T>, Ix2> {
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(self.row(row))
}
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(self.column(col))
}
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
Box::new(self.slice(s![rows, cols]))
}
fn slice_mut<'a>(
&'a mut self,
rows: Range<usize>,
cols: Range<usize>,
) -> Box<dyn MutArrayView2<T> + 'a>
where
Self: Sized,
{
Box::new(self.slice_mut(s![rows, cols]))
}
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
Array::from_elem([nrows, ncols], value)
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
let a = Array::from_iter(iter.take(nrows * ncols))
.into_shape((nrows, ncols))
.unwrap();
match axis {
0 => a,
_ => a.reversed_axes().into_shape((nrows, ncols)).unwrap(),
}
}
fn transpose(&self) -> Self {
self.t().to_owned()
}
}
impl<T: Number + RealNumber> QRDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> CholeskyDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.as_mut_ptr();
let stride = self.strides();
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
match axis {
0 => Box::new(self.iter_mut()),
_ => Box::new((0..self.ncols()).flat_map(move |c| {
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array2 as NDArray2};
#[test]
fn test_get_set() {
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
assert_eq!(*BaseArray::get(&a, (1, 1)), 5);
a.set((1, 1), 9);
assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]]));
}
#[test]
fn test_iterator() {
let a = arr2(&[[1, 2, 3], [4, 5, 6]]);
let v: Vec<i32> = a.iterator(0).copied().collect();
assert_eq!(v, vec!(1, 2, 3, 4, 5, 6));
}
#[test]
fn test_mut_iterator() {
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i);
assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]]));
a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i);
assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]]));
}
#[test]
fn test_slice() {
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
let x_slice = Array2::slice(&x, 0..2, 1..2);
assert_eq!((2, 1), x_slice.shape());
let v: Vec<i32> = x_slice.iterator(0).copied().collect();
assert_eq!(v, [2, 5]);
}
#[test]
fn test_slice_iter() {
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
let x_slice = Array2::slice(&x, 0..2, 0..3);
assert_eq!(
x_slice.iterator(0).copied().collect::<Vec<i32>>(),
vec![1, 2, 3, 4, 5, 6]
);
assert_eq!(
x_slice.iterator(1).copied().collect::<Vec<i32>>(),
vec![1, 4, 2, 5, 3, 6]
);
}
#[test]
fn test_slice_mut_iter() {
let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]);
{
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
x_slice
.iterator_mut(0)
.enumerate()
.for_each(|(i, v)| *v = i);
}
assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]]));
{
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
x_slice
.iterator_mut(1)
.enumerate()
.for_each(|(i, v)| *v = i);
}
assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]]));
}
#[test]
fn test_c_from_iterator() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let a: NDArray2<i32> = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0);
println!("{a}");
let a: NDArray2<i32> = Array2::from_iterator(data.into_iter(), 4, 3, 1);
println!("{a}");
}
}
+4
View File
@@ -0,0 +1,4 @@
/// matrix bindings
pub mod matrix;
/// vector bindings
pub mod vector;
+184
View File
@@ -0,0 +1,184 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{
Array as BaseArray, Array1, ArrayView1, MutArray, MutArrayView1,
};
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix1, OwnedRepr};
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
fn set(&mut self, i: usize, x: T) {
self[i] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn set(&mut self, i: usize, x: T) {
self[i] = x;
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
Box::new(self.slice(s![range]))
}
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
Box::new(self.slice_mut(s![range]))
}
fn fill(len: usize, value: T) -> Self {
Array::from_elem(len, value)
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
where
Self: Sized,
{
Array::from_iter(iter.take(len))
}
fn from_vec_slice(slice: &[T]) -> Self {
Array::from_iter(slice.iter().copied())
}
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
Array::from_iter(slice.iterator(0).copied())
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::arr1;
#[test]
fn test_get_set() {
let mut a = arr1(&[1, 2, 3]);
assert_eq!(*BaseArray::get(&a, 1), 2);
a.set(1, 9);
assert_eq!(a, arr1(&[1, 9, 3]));
}
#[test]
fn test_iterator() {
let a = arr1(&[1, 2, 3]);
let v: Vec<i32> = a.iterator(0).copied().collect();
assert_eq!(v, vec!(1, 2, 3));
}
#[test]
fn test_mut_iterator() {
let mut a = arr1(&[1, 2, 3]);
a.iterator_mut(0).for_each(|v| *v = 1);
assert_eq!(a, arr1(&[1, 1, 1]));
}
#[test]
fn test_slice() {
let x = arr1(&[1, 2, 3, 4, 5]);
let x_slice = Array1::slice(&x, 2..3);
assert_eq!(1, x_slice.shape());
assert_eq!(3, *x_slice.get(0));
}
#[test]
fn test_mut_slice() {
let mut x = arr1(&[1, 2, 3, 4, 5]);
let mut x_slice = Array1::slice_mut(&mut x, 2..4);
x_slice.set(0, 9);
assert_eq!(2, x_slice.shape());
assert_eq!(9, *x_slice.get(0));
assert_eq!(4, *x_slice.get(1));
}
}
File diff suppressed because it is too large Load Diff
-207
View File
@@ -1,207 +0,0 @@
//! # Various Statistical Methods
//!
//! This module provides reference implementations for various statistical functions.
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
/// Defines baseline implementations for various statistical functions
pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
/// Computes the arithmetic mean along the specified axis.
fn mean(&self, axis: u8) -> Vec<T> {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
let div = T::from_usize(m).unwrap();
for (i, x_i) in x.iter_mut().enumerate().take(n) {
for j in 0..m {
*x_i += match axis {
0 => self.get(j, i),
_ => self.get(i, j),
};
}
*x_i /= div;
}
x
}
/// Computes variance along the specified axis.
fn var(&self, axis: u8) -> Vec<T> {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
let div = T::from_usize(m).unwrap();
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let mut mu = T::zero();
let mut sum = T::zero();
for j in 0..m {
let a = match axis {
0 => self.get(j, i),
_ => self.get(i, j),
};
mu += a;
sum += a * a;
}
mu /= div;
*x_i = sum / div - mu.powi(2);
}
x
}
/// Computes the standard deviation along the specified axis.
fn std(&self, axis: u8) -> Vec<T> {
let mut x = self.var(axis);
let n = match axis {
0 => self.shape().1,
_ => self.shape().0,
};
for x_i in x.iter_mut().take(n) {
*x_i = x_i.sqrt();
}
x
}
/// standardize values by removing the mean and scaling to unit variance
fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
for i in 0..n {
for j in 0..m {
match axis {
0 => self.set(j, i, (self.get(j, i) - mean[i]) / std[i]),
_ => self.set(i, j, (self.get(i, j) - mean[i]) / std[i]),
}
}
}
}
}
/// Defines baseline implementations for various matrix processing functions
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
/// a.binarize_mut(0.);
///
/// assert_eq!(a, expected);
/// ```
fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
for col in 0..ncols {
if self.get(row, col) > threshold {
self.set(row, col, T::one());
} else {
self.set(row, col, T::zero());
}
}
}
}
/// Returns new matrix where elements are binarized according to a given threshold.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
///
/// assert_eq!(a.binarize(0.), expected);
/// ```
fn binarize(&self, threshold: T) -> Self {
let mut m = self.clone();
m.binarize_mut(threshold);
m
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
]);
let expected_0 = vec![4., 5., 6., 3., 4.];
let expected_1 = vec![1.8, 4.4, 7.];
assert_eq!(m.mean(0), expected_0);
assert_eq!(m.mean(1), expected_1);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn std() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
]);
let expected_0 = vec![2.44, 2.44, 2.44, 1.63, 1.63];
let expected_1 = vec![0.74, 1.01, 1.41];
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn var() {
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
let expected_0 = vec![4., 4., 4., 4.];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn scale() {
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]);
let expected_1 = DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]);
{
let mut m = m.clone();
m.scale_mut(&m.mean(0), &m.std(0), 0);
assert!(m.approximate_eq(&expected_0, std::f32::EPSILON));
}
m.scale_mut(&m.mean(1), &m.std(1), 1);
assert!(m.approximate_eq(&expected_1, 1e-2));
}
}
@@ -8,14 +8,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use crate::smartcore::linalg::cholesky::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::cholesky::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[25., 15., -5.],
//! &[15., 18., 0.],
//! &[-5., 0., 11.]
//! ]);
//! ]).unwrap();
//!
//! let cholesky = A.cholesky().unwrap();
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
@@ -34,17 +34,18 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{Failed, FailedError};
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[derive(Debug, Clone)]
/// Results of Cholesky decomposition.
pub struct Cholesky<T: RealNumber, M: BaseMatrix<T>> {
pub struct Cholesky<T: Number + RealNumber, M: Array2<T>> {
R: M,
t: PhantomData<T>,
}
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
pub(crate) fn new(R: M) -> Cholesky<T, M> {
Cholesky { R, t: PhantomData }
}
@@ -57,7 +58,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
for i in 0..n {
for j in 0..n {
if j <= i {
R.set(i, j, self.R.get(i, j));
R.set((i, j), *self.R.get((i, j)));
}
}
}
@@ -72,7 +73,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
for i in 0..n {
for j in 0..n {
if j <= i {
R.set(j, i, self.R.get(i, j));
R.set((j, i), *self.R.get((i, j)));
}
}
}
@@ -87,25 +88,25 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
if bn != rn {
return Err(Failed::because(
FailedError::SolutionFailed,
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
"Can\'t solve Ax = b for x. FloatNumber of rows in b != number of rows in R.",
));
}
for k in 0..bn {
for j in 0..m {
for i in 0..k {
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i));
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((k, i)));
}
b.div_element_mut(k, j, self.R.get(k, k));
b.div_element_mut((k, j), *self.R.get((k, k)));
}
}
for k in (0..bn).rev() {
for j in 0..m {
for i in k + 1..bn {
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k));
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((i, k)));
}
b.div_element_mut(k, j, self.R.get(k, k));
b.div_element_mut((k, j), *self.R.get((k, k)));
}
}
Ok(b)
@@ -113,7 +114,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
}
/// Trait that implements Cholesky decomposition routine for any matrix.
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
/// Compute the Cholesky decomposition of a matrix.
fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
self.clone().cholesky_mut()
@@ -136,13 +137,13 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for k in 0..j {
let mut s = T::zero();
for i in 0..k {
s += self.get(k, i) * self.get(j, i);
s += *self.get((k, i)) * *self.get((j, i));
}
s = (self.get(j, k) - s) / self.get(k, k);
self.set(j, k, s);
s = (*self.get((j, k)) - s) / *self.get((k, k));
self.set((j, k), s);
d += s * s;
}
d = self.get(j, j) - d;
d = *self.get((j, j)) - d;
if d < T::zero() {
return Err(Failed::because(
@@ -151,7 +152,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
));
}
self.set(j, j, d.sqrt());
self.set((j, j), d.sqrt());
}
Ok(Cholesky::new(self))
@@ -166,39 +167,50 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cholesky_decompose() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let l =
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]])
.unwrap();
let u =
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]])
.unwrap();
let cholesky = a.cholesky().unwrap();
assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
assert!(cholesky
.L()
.matmul(&cholesky.U())
.abs()
.approximate_eq(&a.abs(), 1e-4));
assert!(relative_eq!(cholesky.L().abs(), l.abs(), epsilon = 1e-4));
assert!(relative_eq!(cholesky.U().abs(), u.abs(), epsilon = 1e-4));
assert!(relative_eq!(
cholesky.L().matmul(&cholesky.U()).abs(),
a.abs(),
epsilon = 1e-4
));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cholesky_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]).unwrap();
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
let cholesky = a.cholesky().unwrap();
assert!(cholesky
.solve(b.transpose())
.unwrap()
.transpose()
.approximate_eq(&expected, 1e-4));
assert!(relative_eq!(
cholesky.solve(b.transpose()).unwrap().transpose(),
expected,
epsilon = 1e-4
));
}
}
+217 -195
View File
@@ -12,14 +12,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::evd::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::evd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//! ]).unwrap();
//!
//! let evd = A.evd(true).unwrap();
//! let eigenvectors: DenseMatrix<f64> = evd.V;
@@ -35,14 +35,15 @@
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use num::complex::Complex;
use std::fmt::Debug;
#[derive(Debug, Clone)]
/// Results of eigen decomposition
pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
pub struct EVD<T: Number + RealNumber, M: Array2<T>> {
/// Real part of eigenvalues.
pub d: Vec<T>,
/// Imaginary part of eigenvalues.
@@ -52,7 +53,7 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
}
/// Trait that implements EVD decomposition routine for any matrix.
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
/// Compute the eigen decomposition of a square matrix.
/// * `symmetric` - whether the matrix is symmetric
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
@@ -65,7 +66,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
let (nrows, ncols) = self.shape();
if ncols != nrows {
panic!("Matrix is not square: {} x {}", nrows, ncols);
panic!("Matrix is not square: {nrows} x {ncols}");
}
let n = nrows;
@@ -93,14 +94,14 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
sort(&mut d, &mut e, &mut V);
}
Ok(EVD { d, e, V })
Ok(EVD { V, d, e })
}
}
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for (i, d_i) in d.iter_mut().enumerate().take(n) {
*d_i = V.get(n - 1, i);
*d_i = *V.get((n - 1, i));
}
for i in (1..n).rev() {
@@ -112,9 +113,9 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
if scale == T::zero() {
e[i] = d[i - 1];
for (j, d_j) in d.iter_mut().enumerate().take(i) {
*d_j = V.get(i - 1, j);
V.set(i, j, T::zero());
V.set(j, i, T::zero());
*d_j = *V.get((i - 1, j));
V.set((i, j), T::zero());
V.set((j, i), T::zero());
}
} else {
for d_k in d.iter_mut().take(i) {
@@ -135,11 +136,11 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
for j in 0..i {
f = d[j];
V.set(j, i, f);
g = e[j] + V.get(j, j) * f;
V.set((j, i), f);
g = e[j] + *V.get((j, j)) * f;
for k in j + 1..=i - 1 {
g += V.get(k, j) * d[k];
e[k] += V.get(k, j) * f;
g += *V.get((k, j)) * d[k];
e[k] += *V.get((k, j)) * f;
}
e[j] = g;
}
@@ -156,46 +157,46 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
f = d[j];
g = e[j];
for k in j..=i - 1 {
V.sub_element_mut(k, j, f * e[k] + g * d[k]);
V.sub_element_mut((k, j), f * e[k] + g * d[k]);
}
d[j] = V.get(i - 1, j);
V.set(i, j, T::zero());
d[j] = *V.get((i - 1, j));
V.set((i, j), T::zero());
}
}
d[i] = h;
}
for i in 0..n - 1 {
V.set(n - 1, i, V.get(i, i));
V.set(i, i, T::one());
V.set((n - 1, i), *V.get((i, i)));
V.set((i, i), T::one());
let h = d[i + 1];
if h != T::zero() {
for (k, d_k) in d.iter_mut().enumerate().take(i + 1) {
*d_k = V.get(k, i + 1) / h;
*d_k = *V.get((k, i + 1)) / h;
}
for j in 0..=i {
let mut g = T::zero();
for k in 0..=i {
g += V.get(k, i + 1) * V.get(k, j);
g += *V.get((k, i + 1)) * *V.get((k, j));
}
for (k, d_k) in d.iter().enumerate().take(i + 1) {
V.sub_element_mut(k, j, g * (*d_k));
V.sub_element_mut((k, j), g * (*d_k));
}
}
}
for k in 0..=i {
V.set(k, i + 1, T::zero());
V.set((k, i + 1), T::zero());
}
}
for (j, d_j) in d.iter_mut().enumerate().take(n) {
*d_j = V.get(n - 1, j);
V.set(n - 1, j, T::zero());
*d_j = *V.get((n - 1, j));
V.set((n - 1, j), T::zero());
}
V.set(n - 1, n - 1, T::one());
V.set((n - 1, n - 1), T::one());
e[0] = T::zero();
}
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
fn tql2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for i in 1..n {
e[i - 1] = e[i];
@@ -264,9 +265,9 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
d[i + 1] = h + s * (c * g + s * d[i]);
for k in 0..n {
h = V.get(k, i + 1);
V.set(k, i + 1, s * V.get(k, i) + c * h);
V.set(k, i, c * V.get(k, i) - s * h);
h = *V.get((k, i + 1));
V.set((k, i + 1), s * *V.get((k, i)) + c * h);
V.set((k, i), c * *V.get((k, i)) - s * h);
}
}
p = -s * s2 * c3 * el1 * e[l] / dl1;
@@ -295,15 +296,15 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
d[k] = d[i];
d[i] = p;
for j in 0..n {
p = V.get(j, i);
V.set(j, i, V.get(j, k));
V.set(j, k, p);
p = *V.get((j, i));
V.set((j, i), *V.get((j, k)));
V.set((j, k), p);
}
}
}
}
fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
let radix = T::two();
let sqrdx = radix * radix;
@@ -321,8 +322,8 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
let mut c = T::zero();
for j in 0..n {
if j != i {
c += A.get(j, i).abs();
r += A.get(i, j).abs();
c += A.get((j, i)).abs();
r += A.get((i, j)).abs();
}
}
if c != T::zero() && r != T::zero() {
@@ -343,10 +344,10 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
g = T::one() / f;
*scale_i *= f;
for j in 0..n {
A.mul_element_mut(i, j, g);
A.mul_element_mut((i, j), g);
}
for j in 0..n {
A.mul_element_mut(j, i, f);
A.mul_element_mut((j, i), f);
}
}
}
@@ -356,7 +357,7 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
scale
}
fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
fn elmhes<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<usize> {
let (n, _) = A.shape();
let mut perm = vec![0; n];
@@ -364,35 +365,31 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
let mut x = T::zero();
let mut i = m;
for j in m..n {
if A.get(j, m - 1).abs() > x.abs() {
x = A.get(j, m - 1);
if A.get((j, m - 1)).abs() > x.abs() {
x = *A.get((j, m - 1));
i = j;
}
}
*perm_m = i;
if i != m {
for j in (m - 1)..n {
let swap = A.get(i, j);
A.set(i, j, A.get(m, j));
A.set(m, j, swap);
A.swap((i, j), (m, j));
}
for j in 0..n {
let swap = A.get(j, i);
A.set(j, i, A.get(j, m));
A.set(j, m, swap);
A.swap((j, i), (j, m));
}
}
if x != T::zero() {
for i in (m + 1)..n {
let mut y = A.get(i, m - 1);
let mut y = *A.get((i, m - 1));
if y != T::zero() {
y /= x;
A.set(i, m - 1, y);
A.set((i, m - 1), y);
for j in m..n {
A.sub_element_mut(i, j, y * A.get(m, j));
A.sub_element_mut((i, j), y * *A.get((m, j)));
}
for j in 0..n {
A.add_element_mut(j, m, y * A.get(j, i));
A.add_element_mut((j, m), y * *A.get((j, i)));
}
}
}
@@ -402,24 +399,24 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
perm
}
fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
fn eltran<T: Number + RealNumber, M: Array2<T>>(A: &M, V: &mut M, perm: &[usize]) {
let (n, _) = A.shape();
for mp in (1..n - 1).rev() {
for k in mp + 1..n {
V.set(k, mp, A.get(k, mp - 1));
V.set((k, mp), *A.get((k, mp - 1)));
}
let i = perm[mp];
if i != mp {
for j in mp..n {
V.set(mp, j, V.get(i, j));
V.set(i, j, T::zero());
V.set((mp, j), *V.get((i, j)));
V.set((i, j), T::zero());
}
V.set(i, mp, T::one());
V.set((i, mp), T::one());
}
}
}
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = A.shape();
let mut z = T::zero();
let mut s = T::zero();
@@ -430,7 +427,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
for i in 0..n {
for j in i32::max(i as i32 - 1, 0)..n as i32 {
anorm += A.get(i, j as usize).abs();
anorm += A.get((i, j as usize)).abs();
}
}
@@ -441,43 +438,43 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
loop {
let mut l = nn;
while l > 0 {
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
s = A.get((l - 1, l - 1)).abs() + A.get((l, l)).abs();
if s == T::zero() {
s = anorm;
}
if A.get(l, l - 1).abs() <= T::epsilon() * s {
A.set(l, l - 1, T::zero());
if A.get((l, l - 1)).abs() <= T::epsilon() * s {
A.set((l, l - 1), T::zero());
break;
}
l -= 1;
}
let mut x = A.get(nn, nn);
let mut x = *A.get((nn, nn));
if l == nn {
d[nn] = x + t;
A.set(nn, nn, x + t);
A.set((nn, nn), x + t);
if nn == 0 {
break 'outer;
} else {
nn -= 1;
}
} else {
let mut y = A.get(nn - 1, nn - 1);
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
let mut y = *A.get((nn - 1, nn - 1));
let mut w = *A.get((nn, nn - 1)) * *A.get((nn - 1, nn));
if l == nn - 1 {
p = T::half() * (y - x);
q = p * p + w;
z = q.abs().sqrt();
x += t;
A.set(nn, nn, x);
A.set(nn - 1, nn - 1, y + t);
A.set((nn, nn), x);
A.set((nn - 1, nn - 1), y + t);
if q >= T::zero() {
z = p + RealNumber::copysign(z, p);
z = p + <T as RealNumber>::copysign(z, p);
d[nn - 1] = x + z;
d[nn] = x + z;
if z != T::zero() {
d[nn] = x - w / z;
}
x = A.get(nn, nn - 1);
x = *A.get((nn, nn - 1));
s = x.abs() + z.abs();
p = x / s;
q = z / s;
@@ -485,19 +482,19 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
p /= r;
q /= r;
for j in nn - 1..n {
z = A.get(nn - 1, j);
A.set(nn - 1, j, q * z + p * A.get(nn, j));
A.set(nn, j, q * A.get(nn, j) - p * z);
z = *A.get((nn - 1, j));
A.set((nn - 1, j), q * z + p * *A.get((nn, j)));
A.set((nn, j), q * *A.get((nn, j)) - p * z);
}
for i in 0..=nn {
z = A.get(i, nn - 1);
A.set(i, nn - 1, q * z + p * A.get(i, nn));
A.set(i, nn, q * A.get(i, nn) - p * z);
z = *A.get((i, nn - 1));
A.set((i, nn - 1), q * z + p * *A.get((i, nn)));
A.set((i, nn), q * *A.get((i, nn)) - p * z);
}
for i in 0..n {
z = V.get(i, nn - 1);
V.set(i, nn - 1, q * z + p * V.get(i, nn));
V.set(i, nn, q * V.get(i, nn) - p * z);
z = *V.get((i, nn - 1));
V.set((i, nn - 1), q * z + p * *V.get((i, nn)));
V.set((i, nn), q * *V.get((i, nn)) - p * z);
}
} else {
d[nn] = x + p;
@@ -518,22 +515,22 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
if its == 10 || its == 20 {
t += x;
for i in 0..nn + 1 {
A.sub_element_mut(i, i, x);
A.sub_element_mut((i, i), x);
}
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
y = T::from(0.75).unwrap() * s;
x = T::from(0.75).unwrap() * s;
w = T::from(-0.4375).unwrap() * s * s;
s = A.get((nn, nn - 1)).abs() + A.get((nn - 1, nn - 2)).abs();
y = T::from_f64(0.75).unwrap() * s;
x = T::from_f64(0.75).unwrap() * s;
w = T::from_f64(-0.4375).unwrap() * s * s;
}
its += 1;
let mut m = nn - 2;
while m >= l {
z = A.get(m, m);
z = *A.get((m, m));
r = x - z;
s = y - z;
p = (r * s - w) / A.get(m + 1, m) + A.get(m, m + 1);
q = A.get(m + 1, m + 1) - z - r - s;
r = A.get(m + 2, m + 1);
p = (r * s - w) / *A.get((m + 1, m)) + *A.get((m, m + 1));
q = *A.get((m + 1, m + 1)) - z - r - s;
r = *A.get((m + 2, m + 1));
s = p.abs() + q.abs() + r.abs();
p /= s;
q /= s;
@@ -541,27 +538,27 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
if m == l {
break;
}
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
let u = A.get((m, m - 1)).abs() * (q.abs() + r.abs());
let v = p.abs()
* (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
* (A.get((m - 1, m - 1)).abs() + z.abs() + A.get((m + 1, m + 1)).abs());
if u <= T::epsilon() * v {
break;
}
m -= 1;
}
for i in m..nn - 1 {
A.set(i + 2, i, T::zero());
A.set((i + 2, i), T::zero());
if i != m {
A.set(i + 2, i - 1, T::zero());
A.set((i + 2, i - 1), T::zero());
}
}
for k in m..nn {
if k != m {
p = A.get(k, k - 1);
q = A.get(k + 1, k - 1);
p = *A.get((k, k - 1));
q = *A.get((k + 1, k - 1));
r = T::zero();
if k + 1 != nn {
r = A.get(k + 2, k - 1);
r = *A.get((k + 2, k - 1));
}
x = p.abs() + q.abs() + r.abs();
if x != T::zero() {
@@ -570,14 +567,14 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
r /= x;
}
}
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
let s = <T as RealNumber>::copysign((p * p + q * q + r * r).sqrt(), p);
if s != T::zero() {
if k == m {
if l != m {
A.set(k, k - 1, -A.get(k, k - 1));
A.set((k, k - 1), -*A.get((k, k - 1)));
}
} else {
A.set(k, k - 1, -s * x);
A.set((k, k - 1), -s * x);
}
p += s;
x = p / s;
@@ -586,32 +583,33 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
q /= p;
r /= p;
for j in k..n {
p = A.get(k, j) + q * A.get(k + 1, j);
p = *A.get((k, j)) + q * *A.get((k + 1, j));
if k + 1 != nn {
p += r * A.get(k + 2, j);
A.sub_element_mut(k + 2, j, p * z);
p += r * *A.get((k + 2, j));
A.sub_element_mut((k + 2, j), p * z);
}
A.sub_element_mut(k + 1, j, p * y);
A.sub_element_mut(k, j, p * x);
A.sub_element_mut((k + 1, j), p * y);
A.sub_element_mut((k, j), p * x);
}
let mmin = if nn < k + 3 { nn } else { k + 3 };
for i in 0..mmin + 1 {
p = x * A.get(i, k) + y * A.get(i, k + 1);
for i in 0..(mmin + 1) {
p = x * *A.get((i, k)) + y * *A.get((i, k + 1));
if k + 1 != nn {
p += z * A.get(i, k + 2);
A.sub_element_mut(i, k + 2, p * r);
p += z * *A.get((i, k + 2));
A.sub_element_mut((i, k + 2), p * r);
}
A.sub_element_mut(i, k + 1, p * q);
A.sub_element_mut(i, k, p);
A.sub_element_mut((i, k + 1), p * q);
A.sub_element_mut((i, k), p);
}
for i in 0..n {
p = x * V.get(i, k) + y * V.get(i, k + 1);
p = x * *V.get((i, k)) + y * *V.get((i, k + 1));
if k + 1 != nn {
p += z * V.get(i, k + 2);
V.sub_element_mut(i, k + 2, p * r);
p += z * *V.get((i, k + 2));
V.sub_element_mut((i, k + 2), p * r);
}
V.sub_element_mut(i, k + 1, p * q);
V.sub_element_mut(i, k, p);
V.sub_element_mut((i, k + 1), p * q);
V.sub_element_mut((i, k), p);
}
}
}
@@ -630,14 +628,14 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
let na = nn.wrapping_sub(1);
if q == T::zero() {
let mut m = nn;
A.set(nn, nn, T::one());
A.set((nn, nn), T::one());
if nn > 0 {
let mut i = nn - 1;
loop {
let w = A.get(i, i) - p;
let w = *A.get((i, i)) - p;
r = T::zero();
for j in m..=nn {
r += A.get(i, j) * A.get(j, nn);
r += *A.get((i, j)) * *A.get((j, nn));
}
if e[i] < T::zero() {
z = w;
@@ -650,23 +648,23 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
if t == T::zero() {
t = T::epsilon() * anorm;
}
A.set(i, nn, -r / t);
A.set((i, nn), -r / t);
} else {
let x = A.get(i, i + 1);
let y = A.get(i + 1, i);
let x = *A.get((i, i + 1));
let y = *A.get((i + 1, i));
q = (d[i] - p).powf(T::two()) + e[i].powf(T::two());
t = (x * s - z * r) / q;
A.set(i, nn, t);
A.set((i, nn), t);
if x.abs() > z.abs() {
A.set(i + 1, nn, (-r - w * t) / x);
A.set((i + 1, nn), (-r - w * t) / x);
} else {
A.set(i + 1, nn, (-s - y * t) / z);
A.set((i + 1, nn), (-s - y * t) / z);
}
}
t = A.get(i, nn).abs();
t = A.get((i, nn)).abs();
if T::epsilon() * t * t > T::one() {
for j in i..=nn {
A.div_element_mut(j, nn, t);
A.div_element_mut((j, nn), t);
}
}
}
@@ -679,25 +677,25 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
}
} else if q < T::zero() {
let mut m = na;
if A.get(nn, na).abs() > A.get(na, nn).abs() {
A.set(na, na, q / A.get(nn, na));
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
if A.get((nn, na)).abs() > A.get((na, nn)).abs() {
A.set((na, na), q / *A.get((nn, na)));
A.set((na, nn), -(*A.get((nn, nn)) - p) / *A.get((nn, na)));
} else {
let temp = Complex::new(T::zero(), -A.get(na, nn))
/ Complex::new(A.get(na, na) - p, q);
A.set(na, na, temp.re);
A.set(na, nn, temp.im);
let temp = Complex::new(T::zero(), -*A.get((na, nn)))
/ Complex::new(*A.get((na, na)) - p, q);
A.set((na, na), temp.re);
A.set((na, nn), temp.im);
}
A.set(nn, na, T::zero());
A.set(nn, nn, T::one());
A.set((nn, na), T::zero());
A.set((nn, nn), T::one());
if nn >= 2 {
for i in (0..nn - 1).rev() {
let w = A.get(i, i) - p;
let w = *A.get((i, i)) - p;
let mut ra = T::zero();
let mut sa = T::zero();
for j in m..=nn {
ra += A.get(i, j) * A.get(j, na);
sa += A.get(i, j) * A.get(j, nn);
ra += *A.get((i, j)) * *A.get((j, na));
sa += *A.get((i, j)) * *A.get((j, nn));
}
if e[i] < T::zero() {
z = w;
@@ -707,11 +705,11 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
m = i;
if e[i] == T::zero() {
let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
A.set(i, na, temp.re);
A.set(i, nn, temp.im);
A.set((i, na), temp.re);
A.set((i, nn), temp.im);
} else {
let x = A.get(i, i + 1);
let y = A.get(i + 1, i);
let x = *A.get((i, i + 1));
let y = *A.get((i + 1, i));
let mut vr =
(d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
let vi = T::two() * q * (d[i] - p);
@@ -723,33 +721,32 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
let temp =
Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra)
/ Complex::new(vr, vi);
A.set(i, na, temp.re);
A.set(i, nn, temp.im);
A.set((i, na), temp.re);
A.set((i, nn), temp.im);
if x.abs() > z.abs() + q.abs() {
A.set(
i + 1,
na,
(-ra - w * A.get(i, na) + q * A.get(i, nn)) / x,
(i + 1, na),
(-ra - w * *A.get((i, na)) + q * *A.get((i, nn))) / x,
);
A.set(
i + 1,
nn,
(-sa - w * A.get(i, nn) - q * A.get(i, na)) / x,
(i + 1, nn),
(-sa - w * *A.get((i, nn)) - q * *A.get((i, na))) / x,
);
} else {
let temp =
Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn))
/ Complex::new(z, q);
A.set(i + 1, na, temp.re);
A.set(i + 1, nn, temp.im);
let temp = Complex::new(
-r - y * *A.get((i, na)),
-s - y * *A.get((i, nn)),
) / Complex::new(z, q);
A.set((i + 1, na), temp.re);
A.set((i + 1, nn), temp.im);
}
}
}
t = T::max(A.get(i, na).abs(), A.get(i, nn).abs());
t = T::max(A.get((i, na)).abs(), A.get((i, nn)).abs());
if T::epsilon() * t * t > T::one() {
for j in i..=nn {
A.div_element_mut(j, na, t);
A.div_element_mut(j, nn, t);
A.div_element_mut((j, na), t);
A.div_element_mut((j, nn), t);
}
}
}
@@ -761,31 +758,31 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
for i in 0..n {
z = T::zero();
for k in 0..=j {
z += V.get(i, k) * A.get(k, j);
z += *V.get((i, k)) * *A.get((k, j));
}
V.set(i, j, z);
V.set((i, j), z);
}
}
}
}
fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
fn balbak<T: Number + RealNumber, M: Array2<T>>(V: &mut M, scale: &[T]) {
let (n, _) = V.shape();
for (i, scale_i) in scale.iter().enumerate().take(n) {
for j in 0..n {
V.mul_element_mut(i, j, *scale_i);
V.mul_element_mut((i, j), *scale_i);
}
}
}
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
fn sort<T: Number + RealNumber, M: Array2<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
let n = d.len();
let mut temp = vec![T::zero(); n];
for j in 1..n {
let real = d[j];
let img = e[j];
for (k, temp_k) in temp.iter_mut().enumerate().take(n) {
*temp_k = V.get(k, j);
*temp_k = *V.get((k, j));
}
let mut i = j as i32 - 1;
while i >= 0 {
@@ -795,14 +792,14 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
d[i as usize + 1] = d[i as usize];
e[i as usize + 1] = e[i as usize];
for k in 0..n {
V.set(k, i as usize + 1, V.get(k, i as usize));
V.set((k, i as usize + 1), *V.get((k, i as usize)));
}
i -= 1;
}
d[i as usize + 1] = real;
e[i as usize + 1] = img;
for (k, temp_k) in temp.iter().enumerate().take(n) {
V.set(k, i as usize + 1, *temp_k);
V.set((k, i as usize + 1), *temp_k);
}
}
}
@@ -810,15 +807,21 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000],
]);
])
.unwrap();
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -826,26 +829,33 @@ mod tests {
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588],
]);
])
.unwrap();
let evd = A.evd(true).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
}
for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.8000, 0.3000, 0.8000],
]);
])
.unwrap();
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
@@ -853,19 +863,25 @@ mod tests {
&[0.7178958, 0.05322098, 0.6812010],
&[0.3837711, -0.84702111, -0.1494582],
&[0.6952105, 0.43984484, -0.7036135],
]);
])
.unwrap();
let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
}
for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_complex() {
let A = DenseMatrix::from_2d_array(&[
@@ -873,7 +889,8 @@ mod tests {
&[4.0, -1.0, 1.0, 1.0],
&[1.0, 1.0, 3.0, -2.0],
&[1.0, 1.0, 4.0, -1.0],
]);
])
.unwrap();
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
@@ -883,16 +900,21 @@ mod tests {
&[-0.6707, 0.1059, 0.901, 0.6289],
&[0.9159, -0.1378, 0.3816, 0.0806],
&[0.6707, 0.1059, 0.901, -0.6289],
]);
])
.unwrap();
let evd = A.evd(false).unwrap();
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values_d.len() {
assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4);
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_d_i) in eigen_values_d.iter().enumerate() {
assert!((eigen_values_d_i - evd.d[i]).abs() < 1e-4);
}
for i in 0..eigen_values_e.len() {
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
for (i, eigen_values_e_i) in eigen_values_e.iter().enumerate() {
assert!((eigen_values_e_i - evd.e[i]).abs() < 1e-4);
}
}
}
@@ -1,19 +1,20 @@
//! In this module you will find composite of matrix operations that are used elsewhere
//! for improved efficiency.
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
/// High order matrix operations.
pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
pub trait HighOrderOperations<T: Number>: Array2<T> {
/// Y = AB
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use smartcore::linalg::high_order::HighOrderOperations;
/// use smartcore::linalg::basic::matrix::*;
/// use smartcore::linalg::traits::high_order::HighOrderOperations;
/// use smartcore::linalg::basic::arrays::Array2;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]).unwrap();
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]).unwrap();
///
/// assert_eq!(a.ab(true, &b, false), expected);
/// ```
@@ -26,3 +27,7 @@ pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
}
}
}
mod tests {
/* TODO: Add tests */
}
+55 -51
View File
@@ -11,14 +11,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::lu::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::lu::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[1., 2., 3.],
//! &[0., 1., 5.],
//! &[5., 6., 0.]
//! ]);
//! ]).unwrap();
//!
//! let lu = A.lu().unwrap();
//! let lower: DenseMatrix<f64> = lu.L();
@@ -38,26 +38,27 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[derive(Debug, Clone)]
/// Result of LU decomposition.
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
pub struct LU<T: Number + RealNumber, M: Array2<T>> {
LU: M,
pivot: Vec<usize>,
_pivot_sign: i8,
#[allow(dead_code)]
pivot_sign: i8,
singular: bool,
phantom: PhantomData<T>,
}
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
let (_, n) = LU.shape();
let mut singular = false;
for j in 0..n {
if LU.get(j, j) == T::zero() {
if LU.get((j, j)) == &T::zero() {
singular = true;
break;
}
@@ -66,7 +67,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
LU {
LU,
pivot,
_pivot_sign,
pivot_sign,
singular,
phantom: PhantomData,
}
@@ -80,9 +81,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
for i in 0..n_rows {
for j in 0..n_cols {
match i.cmp(&j) {
Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
Ordering::Equal => L.set(i, j, T::one()),
Ordering::Less => L.set(i, j, T::zero()),
Ordering::Greater => L.set((i, j), *self.LU.get((i, j))),
Ordering::Equal => L.set((i, j), T::one()),
Ordering::Less => L.set((i, j), T::zero()),
}
}
}
@@ -98,9 +99,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
for i in 0..n_rows {
for j in 0..n_cols {
if i <= j {
U.set(i, j, self.LU.get(i, j));
U.set((i, j), *self.LU.get((i, j)));
} else {
U.set(i, j, T::zero());
U.set((i, j), T::zero());
}
}
}
@@ -114,7 +115,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
let mut piv = M::zeros(n, n);
for i in 0..n {
piv.set(i, self.pivot[i], T::one());
piv.set((i, self.pivot[i]), T::one());
}
piv
@@ -125,13 +126,13 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
let (m, n) = self.LU.shape();
if m != n {
panic!("Matrix is not square: {}x{}", m, n);
panic!("Matrix is not square: {m}x{n}");
}
let mut inv = M::zeros(n, n);
for i in 0..n {
inv.set(i, i, T::one());
inv.set((i, i), T::one());
}
self.solve(inv)
@@ -142,10 +143,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
let (b_m, b_n) = b.shape();
if b_m != m {
panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_m, b_n
);
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_m} x {b_n}");
}
if self.singular {
@@ -156,33 +154,33 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
for j in 0..b_n {
for i in 0..m {
X.set(i, j, b.get(self.pivot[i], j));
X.set((i, j), *b.get((self.pivot[i], j)));
}
}
for k in 0..n {
for i in k + 1..n {
for j in 0..b_n {
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
}
}
}
for k in (0..n).rev() {
for j in 0..b_n {
X.div_element_mut(k, j, self.LU.get(k, k));
X.div_element_mut((k, j), *self.LU.get((k, k)));
}
for i in 0..k {
for j in 0..b_n {
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
}
}
}
for j in 0..b_n {
for i in 0..m {
b.set(i, j, X.get(i, j));
b.set((i, j), *X.get((i, j)));
}
}
@@ -191,7 +189,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
}
/// Trait that implements LU decomposition routine for any matrix.
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
/// Compute the LU decomposition of a square matrix.
fn lu(&self) -> Result<LU<T, Self>, Failed> {
self.clone().lu_mut()
@@ -209,18 +207,18 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for j in 0..n {
for (i, LUcolj_i) in LUcolj.iter_mut().enumerate().take(m) {
*LUcolj_i = self.get(i, j);
*LUcolj_i = *self.get((i, j));
}
for i in 0..m {
let kmax = usize::min(i, j);
let mut s = T::zero();
for (k, LUcolj_k) in LUcolj.iter().enumerate().take(kmax) {
s += self.get(i, k) * (*LUcolj_k);
s += *self.get((i, k)) * (*LUcolj_k);
}
LUcolj[i] -= s;
self.set(i, j, LUcolj[i]);
self.set((i, j), LUcolj[i]);
}
let mut p = j;
@@ -231,17 +229,15 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
if p != j {
for k in 0..n {
let t = self.get(p, k);
self.set(p, k, self.get(j, k));
self.set(j, k, t);
self.swap((p, k), (j, k));
}
piv.swap(p, j);
pivsign = -pivsign;
}
if j < m && self.get(j, j) != T::zero() {
if j < m && self.get((j, j)) != &T::zero() {
for i in j + 1..m {
self.div_element_mut(i, j, self.get(j, j));
self.div_element_mut((i, j), *self.get((j, j)));
}
}
}
@@ -258,30 +254,38 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
let expected_L =
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]).unwrap();
let expected_U =
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]).unwrap();
let expected_pivot =
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]).unwrap();
let lu = a.lu().unwrap();
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
assert!(relative_eq!(lu.L(), expected_L, epsilon = 1e-4));
assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4));
assert!(relative_eq!(lu.pivot(), expected_pivot, epsilon = 1e-4));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn inverse() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
let expected =
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]])
.unwrap();
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
assert!(a_inv.approximate_eq(&expected, 1e-4));
assert!(relative_eq!(a_inv, expected, epsilon = 1e-4));
}
}
+15
View File
@@ -0,0 +1,15 @@
#![allow(clippy::wrong_self_convention)]
pub mod cholesky;
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
pub mod evd;
pub mod high_order;
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
pub mod lu;
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
pub mod qr;
/// statistacal tools for DenseMatrix
pub mod stats;
/// Singular value decomposition.
pub mod svd;
+55 -44
View File
@@ -6,14 +6,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::qr::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::qr::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9, 0.4, 0.7],
//! &[0.4, 0.5, 0.3],
//! &[0.7, 0.3, 0.8]
//! ]);
//! ]).unwrap();
//!
//! let qr = A.qr().unwrap();
//! let orthogonal: DenseMatrix<f64> = qr.Q();
@@ -28,20 +28,22 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[derive(Debug, Clone)]
/// Results of QR decomposition.
pub struct QR<T: RealNumber, M: BaseMatrix<T>> {
pub struct QR<T: Number + RealNumber, M: Array2<T>> {
QR: M,
tau: Vec<T>,
singular: bool,
}
impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
pub(crate) fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
let mut singular = false;
for tau_elem in tau.iter() {
@@ -59,9 +61,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
let (_, n) = self.QR.shape();
let mut R = M::zeros(n, n);
for i in 0..n {
R.set(i, i, self.tau[i]);
R.set((i, i), self.tau[i]);
for j in i + 1..n {
R.set(i, j, self.QR.get(i, j));
R.set((i, j), *self.QR.get((i, j)));
}
}
R
@@ -73,16 +75,16 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
let mut Q = M::zeros(m, n);
let mut k = n - 1;
loop {
Q.set(k, k, T::one());
Q.set((k, k), T::one());
for j in k..n {
if self.QR.get(k, k) != T::zero() {
if self.QR.get((k, k)) != &T::zero() {
let mut s = T::zero();
for i in k..m {
s += self.QR.get(i, k) * Q.get(i, j);
s += *self.QR.get((i, k)) * *Q.get((i, j));
}
s = -s / self.QR.get(k, k);
s = -s / *self.QR.get((k, k));
for i in k..m {
Q.add_element_mut(i, j, s * self.QR.get(i, k));
Q.add_element_mut((i, j), s * *self.QR.get((i, k)));
}
}
}
@@ -100,10 +102,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
let (b_nrows, b_ncols) = b.shape();
if b_nrows != m {
panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_nrows, b_ncols
);
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_nrows} x {b_ncols}");
}
if self.singular {
@@ -114,23 +113,23 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
for j in 0..b_ncols {
let mut s = T::zero();
for i in k..m {
s += self.QR.get(i, k) * b.get(i, j);
s += *self.QR.get((i, k)) * *b.get((i, j));
}
s = -s / self.QR.get(k, k);
s = -s / *self.QR.get((k, k));
for i in k..m {
b.add_element_mut(i, j, s * self.QR.get(i, k));
b.add_element_mut((i, j), s * *self.QR.get((i, k)));
}
}
}
for k in (0..n).rev() {
for j in 0..b_ncols {
b.set(k, j, b.get(k, j) / self.tau[k]);
b.set((k, j), *b.get((k, j)) / self.tau[k]);
}
for i in 0..k {
for j in 0..b_ncols {
b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k));
b.sub_element_mut((i, j), *b.get((k, j)) * *self.QR.get((i, k)));
}
}
}
@@ -140,7 +139,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
}
/// Trait that implements QR decomposition routine for any matrix.
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
pub trait QRDecomposable<T: Number + RealNumber>: Array2<T> {
/// Compute the QR decomposition of a matrix.
fn qr(&self) -> Result<QR<T, Self>, Failed> {
self.clone().qr_mut()
@@ -156,26 +155,26 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for (k, r_diagonal_k) in r_diagonal.iter_mut().enumerate().take(n) {
let mut nrm = T::zero();
for i in k..m {
nrm = nrm.hypot(self.get(i, k));
nrm = nrm.hypot(*self.get((i, k)));
}
if nrm.abs() > T::epsilon() {
if self.get(k, k) < T::zero() {
if self.get((k, k)) < &T::zero() {
nrm = -nrm;
}
for i in k..m {
self.div_element_mut(i, k, nrm);
self.div_element_mut((i, k), nrm);
}
self.add_element_mut(k, k, T::one());
self.add_element_mut((k, k), T::one());
for j in k + 1..n {
let mut s = T::zero();
for i in k..m {
s += self.get(i, k) * self.get(i, j);
s += *self.get((i, k)) * *self.get((i, j));
}
s = -s / self.get(k, k);
s = -s / *self.get((k, k));
for i in k..m {
self.add_element_mut(i, j, s * self.get(i, k));
self.add_element_mut((i, j), s * *self.get((i, k)));
}
}
}
@@ -194,37 +193,49 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let q = DenseMatrix::from_2d_array(&[
&[-0.7448, 0.2436, 0.6212],
&[-0.331, -0.9432, -0.027],
&[-0.5793, 0.2257, -0.7832],
]);
])
.unwrap();
let r = DenseMatrix::from_2d_array(&[
&[-1.2083, -0.6373, -1.0842],
&[0.0, -0.3064, 0.0682],
&[0.0, 0.0, -0.1999],
]);
])
.unwrap();
let qr = a.qr().unwrap();
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
assert!(relative_eq!(qr.Q().abs(), q.abs(), epsilon = 1e-4));
assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn qr_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
let expected_w = DenseMatrix::from_2d_array(&[
&[-0.2027027, -1.2837838],
&[0.8783784, 2.2297297],
&[0.4729730, 0.6621622],
]);
])
.unwrap();
let w = a.qr_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
}
}
+297
View File
@@ -0,0 +1,297 @@
//! # Various Statistical Methods
//!
//! This module provides reference implementations for various statistical functions.
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
//! This methods shall be used when dealing with `DenseMatrix`. Use the ones in `linalg::arrays` for `Array` types.
use crate::linalg::basic::arrays::{Array2, ArrayView2, MutArrayView2};
use crate::numbers::realnum::RealNumber;
/// Defines baseline implementations for various statistical functions
pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
/// Computes the arithmetic mean along the specified axis.
fn mean(&self, axis: u8) -> Vec<T> {
let (n, _m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let vec = match axis {
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
};
*x_i = Self::_mean_of_vector(&vec[..]);
}
x
}
/// Computes variance along the specified axis.
fn var(&self, axis: u8) -> Vec<T> {
let (n, _m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let vec = match axis {
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
};
*x_i = Self::_var_of_vec(&vec[..], Option::None);
}
x
}
/// Computes the standard deviation along the specified axis.
fn std(&self, axis: u8) -> Vec<T> {
let mut x = Self::var(self, axis);
let n = match axis {
0 => self.shape().1,
_ => self.shape().0,
};
for x_i in x.iter_mut().take(n) {
*x_i = x_i.sqrt();
}
x
}
/// <http://en.wikipedia.org/wiki/Arithmetic_mean>
/// Taken from `statistical`
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _mean_of_vector(v: &[T]) -> T {
let len = num::cast(v.len()).unwrap();
v.iter().fold(T::zero(), |acc: T, elem| acc + *elem) / len
}
/// Taken from statistical
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _sum_square_deviations_vec(v: &[T], c: Option<T>) -> T {
let c = match c {
Some(c) => c,
None => Self::_mean_of_vector(v),
};
let sum = v
.iter()
.map(|x| (*x - c) * (*x - c))
.fold(T::zero(), |acc, elem| acc + elem);
assert!(sum >= T::zero(), "negative sum of square root deviations");
sum
}
/// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
/// Taken from statistical
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _var_of_vec(v: &[T], xbar: Option<T>) -> T {
assert!(v.len() > 1, "variance requires at least two data points");
let len: T = num::cast(v.len()).unwrap();
let sum = Self::_sum_square_deviations_vec(v, xbar);
sum / len
}
/// standardize values by removing the mean and scaling to unit variance
fn standard_scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
for i in 0..n {
for j in 0..m {
match axis {
0 => self.set((j, i), (*self.get((j, i)) - mean[i]) / std[i]),
_ => self.set((i, j), (*self.get((i, j)) - mean[i]) / std[i]),
}
}
}
}
}
//TODO: this is processing. Should have its own "processing.rs" module
/// Defines baseline implementations for various matrix processing functions
pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
/// ```rust
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
/// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
/// a.binarize_mut(0.);
///
/// assert_eq!(a, expected);
/// ```
fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
for col in 0..ncols {
if *self.get((row, col)) > threshold {
self.set((row, col), T::one());
} else {
self.set((row, col), T::zero());
}
}
}
}
/// Returns new matrix where elements are binarized according to a given threshold.
/// ```rust
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
/// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
///
/// assert_eq!(a.binarize(0.), expected);
/// ```
fn binarize(self, threshold: T) -> Self
where
Self: Sized,
{
let mut m = self;
m.binarize_mut(threshold);
m
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::arrays::Array1;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::traits::stats::MatrixStats;
#[test]
fn test_mean() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
])
.unwrap();
let expected_0 = vec![4., 5., 6., 3., 4.];
let expected_1 = vec![1.8, 4.4, 7.];
assert_eq!(m.mean(0), expected_0);
assert_eq!(m.mean(1), expected_1);
}
#[test]
fn test_var() {
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
let expected_0 = vec![4., 4., 4., 4.];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, 1e-6));
assert!(m.var(1).approximate_eq(&expected_1, 1e-6));
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
assert_eq!(m.mean(1), vec![2.5, 6.5]);
}
#[test]
fn test_var_other() {
let m = DenseMatrix::from_2d_array(&[
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
])
.unwrap();
let expected_0 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, f64::EPSILON));
assert!(m.var(1).approximate_eq(&expected_1, f64::EPSILON));
assert_eq!(
m.mean(0),
vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
);
assert_eq!(m.mean(1), vec![1.375, 1.375]);
}
#[test]
fn test_std() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
])
.unwrap();
let expected_0 = vec![
2.449489742783178,
2.449489742783178,
2.449489742783178,
1.632993161855452,
1.632993161855452,
];
let expected_1 = vec![0.7483314773547883, 1.019803902718557, 1.4142135623730951];
println!("{:?}", m.var(0));
assert!(m.std(0).approximate_eq(&expected_0, f64::EPSILON));
assert!(m.std(1).approximate_eq(&expected_1, f64::EPSILON));
assert_eq!(m.mean(0), vec![4.0, 5.0, 6.0, 3.0, 4.0]);
assert_eq!(m.mean(1), vec![1.8, 4.4, 7.0]);
}
#[test]
fn test_scale() {
let m: DenseMatrix<f64> =
DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
let expected_0: DenseMatrix<f64> =
DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]).unwrap();
let expected_1: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[
-1.3416407864998738,
-0.4472135954999579,
0.4472135954999579,
1.3416407864998738,
],
&[
-1.3416407864998738,
-0.4472135954999579,
0.4472135954999579,
1.3416407864998738,
],
])
.unwrap();
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
assert_eq!(m.mean(1), vec![2.5, 6.5]);
assert_eq!(m.var(0), vec![4., 4., 4., 4.]);
assert_eq!(m.var(1), vec![1.25, 1.25]);
assert_eq!(m.std(0), vec![2., 2., 2., 2.]);
assert_eq!(m.std(1), vec![1.118033988749895, 1.118033988749895]);
{
let mut m = m.clone();
m.standard_scale_mut(&m.mean(0), &m.std(0), 0);
assert_eq!(&m, &expected_0);
}
{
let mut m = m;
m.standard_scale_mut(&m.mean(1), &m.std(1), 1);
assert_eq!(&m, &expected_1);
}
}
}
+120 -107
View File
@@ -10,14 +10,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::svd::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::svd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9, 0.4, 0.7],
//! &[0.4, 0.5, 0.3],
//! &[0.7, 0.3, 0.8]
//! ]);
//! ]).unwrap();
//!
//! let svd = A.svd().unwrap();
//! let u: DenseMatrix<f64> = svd.U;
@@ -34,32 +34,33 @@
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use std::fmt::Debug;
/// Results of SVD decomposition
#[derive(Debug, Clone)]
pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
pub struct SVD<T: Number + RealNumber, M: SVDDecomposable<T>> {
/// Left-singular vectors of _A_
pub U: M,
/// Right-singular vectors of _A_
pub V: M,
/// Singular values of the original matrix
pub s: Vec<T>,
_full: bool,
m: usize,
n: usize,
/// Tolerance
tol: T,
}
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
/// Diagonal matrix with singular values
pub fn S(&self) -> M {
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
for i in 0..self.s.len() {
s.set(i, i, self.s[i]);
s.set((i, i), self.s[i]);
}
s
@@ -67,7 +68,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
}
/// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
/// Solves Ax = b. Overrides original matrix in the process.
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.svd_mut().and_then(|svd| svd.solve(b))
@@ -106,31 +107,31 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if i < m {
for k in i..m {
scale += U.get(k, i).abs();
scale += U.get((k, i)).abs();
}
if scale.abs() > T::epsilon() {
for k in i..m {
U.div_element_mut(k, i, scale);
s += U.get(k, i) * U.get(k, i);
U.div_element_mut((k, i), scale);
s += *U.get((k, i)) * *U.get((k, i));
}
let mut f = U.get(i, i);
g = -RealNumber::copysign(s.sqrt(), f);
let mut f = *U.get((i, i));
g = -<T as RealNumber>::copysign(s.sqrt(), f);
let h = f * g - s;
U.set(i, i, f - g);
U.set((i, i), f - g);
for j in l - 1..n {
s = T::zero();
for k in i..m {
s += U.get(k, i) * U.get(k, j);
s += *U.get((k, i)) * *U.get((k, j));
}
f = s / h;
for k in i..m {
U.add_element_mut(k, j, f * U.get(k, i));
U.add_element_mut((k, j), f * *U.get((k, i)));
}
}
for k in i..m {
U.mul_element_mut(k, i, scale);
U.mul_element_mut((k, i), scale);
}
}
}
@@ -142,37 +143,37 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if i < m && i + 1 != n {
for k in l - 1..n {
scale += U.get(i, k).abs();
scale += U.get((i, k)).abs();
}
if scale.abs() > T::epsilon() {
for k in l - 1..n {
U.div_element_mut(i, k, scale);
s += U.get(i, k) * U.get(i, k);
U.div_element_mut((i, k), scale);
s += *U.get((i, k)) * *U.get((i, k));
}
let f = U.get(i, l - 1);
g = -RealNumber::copysign(s.sqrt(), f);
let f = *U.get((i, l - 1));
g = -<T as RealNumber>::copysign(s.sqrt(), f);
let h = f * g - s;
U.set(i, l - 1, f - g);
U.set((i, l - 1), f - g);
for (k, rv1_k) in rv1.iter_mut().enumerate().take(n).skip(l - 1) {
*rv1_k = U.get(i, k) / h;
*rv1_k = *U.get((i, k)) / h;
}
for j in l - 1..m {
s = T::zero();
for k in l - 1..n {
s += U.get(j, k) * U.get(i, k);
s += *U.get((j, k)) * *U.get((i, k));
}
for (k, rv1_k) in rv1.iter().enumerate().take(n).skip(l - 1) {
U.add_element_mut(j, k, s * (*rv1_k));
U.add_element_mut((j, k), s * (*rv1_k));
}
}
for k in l - 1..n {
U.mul_element_mut(i, k, scale);
U.mul_element_mut((i, k), scale);
}
}
}
@@ -184,24 +185,24 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if i < n - 1 {
if g != T::zero() {
for j in l..n {
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
v.set((j, i), (*U.get((i, j)) / *U.get((i, l))) / g);
}
for j in l..n {
let mut s = T::zero();
for k in l..n {
s += U.get(i, k) * v.get(k, j);
s += *U.get((i, k)) * *v.get((k, j));
}
for k in l..n {
v.add_element_mut(k, j, s * v.get(k, i));
v.add_element_mut((k, j), s * *v.get((k, i)));
}
}
}
for j in l..n {
v.set(i, j, T::zero());
v.set(j, i, T::zero());
v.set((i, j), T::zero());
v.set((j, i), T::zero());
}
}
v.set(i, i, T::one());
v.set((i, i), T::one());
g = rv1[i];
l = i;
}
@@ -210,7 +211,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
l = i + 1;
g = w[i];
for j in l..n {
U.set(i, j, T::zero());
U.set((i, j), T::zero());
}
if g.abs() > T::epsilon() {
@@ -218,23 +219,23 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for j in l..n {
let mut s = T::zero();
for k in l..m {
s += U.get(k, i) * U.get(k, j);
s += *U.get((k, i)) * *U.get((k, j));
}
let f = (s / U.get(i, i)) * g;
let f = (s / *U.get((i, i))) * g;
for k in i..m {
U.add_element_mut(k, j, f * U.get(k, i));
U.add_element_mut((k, j), f * *U.get((k, i)));
}
}
for j in i..m {
U.mul_element_mut(j, i, g);
U.mul_element_mut((j, i), g);
}
} else {
for j in i..m {
U.set(j, i, T::zero());
U.set((j, i), T::zero());
}
}
U.add_element_mut(i, i, T::one());
U.add_element_mut((i, i), T::one());
}
for k in (0..n).rev() {
@@ -269,10 +270,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
c = g * h;
s = -f * h;
for j in 0..m {
let y = U.get(j, nm);
let z = U.get(j, i);
U.set(j, nm, y * c + z * s);
U.set(j, i, z * c - y * s);
let y = *U.get((j, nm));
let z = *U.get((j, i));
U.set((j, nm), y * c + z * s);
U.set((j, i), z * c - y * s);
}
}
}
@@ -282,7 +283,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if z < T::zero() {
w[k] = -z;
for j in 0..n {
v.set(j, k, -v.get(j, k));
v.set((j, k), -*v.get((j, k)));
}
}
break;
@@ -299,7 +300,8 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
g = f.hypot(T::one());
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
f = ((x - z) * (x + z) + h * ((y / (f + <T as RealNumber>::copysign(g, f))) - h))
/ x;
let mut c = T::one();
let mut s = T::one();
@@ -319,10 +321,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
y *= c;
for jj in 0..n {
x = v.get(jj, j);
z = v.get(jj, i);
v.set(jj, j, x * c + z * s);
v.set(jj, i, z * c - x * s);
x = *v.get((jj, j));
z = *v.get((jj, i));
v.set((jj, j), x * c + z * s);
v.set((jj, i), z * c - x * s);
}
z = f.hypot(h);
@@ -336,10 +338,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
f = c * g + s * y;
x = c * y - s * g;
for jj in 0..m {
y = U.get(jj, j);
z = U.get(jj, i);
U.set(jj, j, y * c + z * s);
U.set(jj, i, z * c - y * s);
y = *U.get((jj, j));
z = *U.get((jj, i));
U.set((jj, j), y * c + z * s);
U.set((jj, i), z * c - y * s);
}
}
@@ -366,19 +368,19 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for i in inc..n {
let sw = w[i];
for (k, su_k) in su.iter_mut().enumerate().take(m) {
*su_k = U.get(k, i);
*su_k = *U.get((k, i));
}
for (k, sv_k) in sv.iter_mut().enumerate().take(n) {
*sv_k = v.get(k, i);
*sv_k = *v.get((k, i));
}
let mut j = i;
while w[j - inc] < sw {
w[j] = w[j - inc];
for k in 0..m {
U.set(k, j, U.get(k, j - inc));
U.set((k, j), *U.get((k, j - inc)));
}
for k in 0..n {
v.set(k, j, v.get(k, j - inc));
v.set((k, j), *v.get((k, j - inc)));
}
j -= inc;
if j < inc {
@@ -387,10 +389,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
w[j] = sw;
for (k, su_k) in su.iter().enumerate().take(m) {
U.set(k, j, *su_k);
U.set((k, j), *su_k);
}
for (k, sv_k) in sv.iter().enumerate().take(n) {
v.set(k, j, *sv_k);
v.set((k, j), *sv_k);
}
}
if inc <= 1 {
@@ -401,21 +403,21 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
for k in 0..n {
let mut s = 0.;
for i in 0..m {
if U.get(i, k) < T::zero() {
if U.get((i, k)) < &T::zero() {
s += 1.;
}
}
for j in 0..n {
if v.get(j, k) < T::zero() {
if v.get((j, k)) < &T::zero() {
s += 1.;
}
}
if s > (m + n) as f64 / 2. {
for i in 0..m {
U.set(i, k, -U.get(i, k));
U.set((i, k), -*U.get((i, k)));
}
for j in 0..n {
v.set(j, k, -v.get(j, k));
v.set((j, k), -*v.get((j, k)));
}
}
}
@@ -424,21 +426,12 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
}
}
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0;
let n = V.shape().0;
let _full = s.len() == m.min(n);
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD {
U,
V,
s,
_full,
m,
n,
tol,
}
SVD { U, V, s, m, n, tol }
}
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
@@ -458,7 +451,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
let mut r = T::zero();
if self.s[j] > self.tol {
for i in 0..self.m {
r += self.U.get(i, j) * b.get(i, k);
r += *self.U.get((i, j)) * *b.get((i, k));
}
r /= self.s[j];
}
@@ -468,9 +461,9 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
for j in 0..self.n {
let mut r = T::zero();
for (jj, tmp_jj) in tmp.iter().enumerate().take(self.n) {
r += self.V.get(j, jj) * (*tmp_jj);
r += *self.V.get((j, jj)) * (*tmp_jj);
}
b.set(j, k, r);
b.set((j, k), r);
}
}
@@ -481,15 +474,21 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000],
]);
])
.unwrap();
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -497,23 +496,28 @@ mod tests {
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.639158],
]);
])
.unwrap();
let V = DenseMatrix::from_2d_array(&[
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588],
]);
])
.unwrap();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4);
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
for (i, s_i) in s.iter().enumerate() {
assert!((s_i - svd.s[i]).abs() < 1e-4);
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[
@@ -574,7 +578,8 @@ mod tests {
-0.2158704,
-0.27529472,
],
]);
])
.unwrap();
let s: Vec<f64> = vec![
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
@@ -644,7 +649,8 @@ mod tests {
0.73034065,
-0.43965505,
],
]);
])
.unwrap();
let V = DenseMatrix::from_2d_array(&[
&[
@@ -704,31 +710,40 @@ mod tests {
0.1654796,
-0.32346758,
],
]);
])
.unwrap();
let svd = A.svd().unwrap();
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4);
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
for (i, s_i) in s.iter().enumerate() {
assert!((s_i - svd.s[i]).abs() < 1e-4);
}
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn solve() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
let expected_w =
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]).unwrap();
let w = a.svd_solve_mut(b).unwrap();
assert!(w.approximate_eq(&expected_w, 1e-2));
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn decompose_restore() {
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
let a =
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]).unwrap();
let svd = a.svd().unwrap();
let u: &DenseMatrix<f32> = &svd.U; //U
let v: &DenseMatrix<f32> = &svd.V; // V
@@ -736,8 +751,6 @@ mod tests {
let a_hat = u.matmul(s).matmul(&v.transpose());
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
assert!((a - a_hat).abs() < 1e-3)
}
assert!(relative_eq!(a, a_hat, epsilon = 1e-3));
}
}
+81 -49
View File
@@ -1,13 +1,43 @@
//! This is a generic solver for Ax = b type of equation
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::arrays::Array1;
//! use smartcore::linalg::basic::arrays::Array2;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::bg_solver::*;
//! use smartcore::numbers::floatnum::FloatNumber;
//! use smartcore::linear::bg_solver::BiconjugateGradientSolver;
//!
//! pub struct BGSolver {}
//! impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X> for BGSolver {}
//!
//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0.,
//! 11.]]).unwrap();
//! let b = vec![40., 51., 28.];
//! let expected = vec![1.0, 2.0, 3.0];
//! let mut x = Vec::zeros(3);
//! let solver = BGSolver {};
//! let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
//! ```
//!
//! for more information take a look at [this Wikipedia article](https://en.wikipedia.org/wiki/Biconjugate_gradient_method)
//! and [this paper](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf)
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array, Array1, Array2, ArrayView1, MutArrayView1};
use crate::numbers::floatnum::FloatNumber;
pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
fn solve_mut(&self, a: &M, b: &M, x: &mut M, tol: T, max_iter: usize) -> Result<T, Failed> {
/// Trait for Biconjugate Gradient Solver
pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
/// Solve Ax = b
fn solve_mut(
&self,
a: &'a X,
b: &Vec<T>,
x: &mut Vec<T>,
tol: T,
max_iter: usize,
) -> Result<T, Failed> {
if tol <= T::zero() {
return Err(Failed::fit("tolerance shoud be > 0"));
}
@@ -16,25 +46,25 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
return Err(Failed::fit("maximum number of iterations should be > 0"));
}
let (n, _) = b.shape();
let n = b.shape();
let mut r = M::zeros(n, 1);
let mut rr = M::zeros(n, 1);
let mut z = M::zeros(n, 1);
let mut zz = M::zeros(n, 1);
let mut r = Vec::zeros(n);
let mut rr = Vec::zeros(n);
let mut z = Vec::zeros(n);
let mut zz = Vec::zeros(n);
self.mat_vec_mul(a, x, &mut r);
for j in 0..n {
r.set(j, 0, b.get(j, 0) - r.get(j, 0));
rr.set(j, 0, r.get(j, 0));
r[j] = b[j] - r[j];
rr[j] = r[j];
}
let bnrm = b.norm(T::two());
self.solve_preconditioner(a, &r, &mut z);
let bnrm = b.norm(2f64);
self.solve_preconditioner(a, &r[..], &mut z[..]);
let mut p = M::zeros(n, 1);
let mut pp = M::zeros(n, 1);
let mut p = Vec::zeros(n);
let mut pp = Vec::zeros(n);
let mut bkden = T::zero();
let mut err = T::zero();
@@ -43,35 +73,33 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
self.solve_preconditioner(a, &rr, &mut zz);
for j in 0..n {
bknum += z.get(j, 0) * rr.get(j, 0);
bknum += z[j] * rr[j];
}
if iter == 1 {
for j in 0..n {
p.set(j, 0, z.get(j, 0));
pp.set(j, 0, zz.get(j, 0));
}
p[..n].copy_from_slice(&z[..n]);
pp[..n].copy_from_slice(&zz[..n]);
} else {
let bk = bknum / bkden;
for j in 0..n {
p.set(j, 0, bk * p.get(j, 0) + z.get(j, 0));
pp.set(j, 0, bk * pp.get(j, 0) + zz.get(j, 0));
p[j] = bk * pp[j] + z[j];
pp[j] = bk * pp[j] + zz[j];
}
}
bkden = bknum;
self.mat_vec_mul(a, &p, &mut z);
let mut akden = T::zero();
for j in 0..n {
akden += z.get(j, 0) * pp.get(j, 0);
akden += z[j] * pp[j];
}
let ak = bknum / akden;
self.mat_t_vec_mul(a, &pp, &mut zz);
for j in 0..n {
x.set(j, 0, x.get(j, 0) + ak * p.get(j, 0));
r.set(j, 0, r.get(j, 0) - ak * z.get(j, 0));
rr.set(j, 0, rr.get(j, 0) - ak * zz.get(j, 0));
x[j] += ak * p[j];
r[j] -= ak * z[j];
rr[j] -= ak * zz[j];
}
self.solve_preconditioner(a, &r, &mut z);
err = r.norm(T::two()) / bnrm;
err = T::from_f64(r.norm(2f64) / bnrm).unwrap();
if err <= tol {
break;
@@ -81,36 +109,38 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
Ok(err)
}
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
/// solve preconditioner
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
let diag = Self::diag(a);
let n = diag.len();
for (i, diag_i) in diag.iter().enumerate().take(n) {
if *diag_i != T::zero() {
x.set(i, 0, b.get(i, 0) / *diag_i);
x[i] = b[i] / *diag_i;
} else {
x.set(i, 0, b.get(i, 0));
x[i] = b[i];
}
}
}
// y = Ax
fn mat_vec_mul(&self, a: &M, x: &M, y: &mut M) {
y.copy_from(&a.matmul(x));
/// y = Ax
fn mat_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
y.copy_from(&x.xa(false, a));
}
// y = Atx
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
y.copy_from(&a.ab(true, x, false));
/// y = Atx
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
y.copy_from(&x.xa(true, a));
}
fn diag(a: &M) -> Vec<T> {
/// Extract the diagonal from a matrix
fn diag(a: &X) -> Vec<T> {
let (nrows, ncols) = a.shape();
let n = nrows.min(ncols);
let mut d = Vec::with_capacity(n);
for i in 0..n {
d.push(a.get(i, i));
d.push(*a.get((i, i)));
}
d
@@ -120,28 +150,30 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::basic::matrix::DenseMatrix;
pub struct BGSolver {}
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
impl<T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'_, T, X> for BGSolver {}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn bg_solver() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let b = vec![40., 51., 28.];
let expected = [1.0, 2.0, 3.0];
let mut x = DenseMatrix::zeros(3, 1);
let mut x = Vec::zeros(3);
let solver = BGSolver {};
let err: f64 = solver
.solve_mut(&a, &b.transpose(), &mut x, 1e-6, 6)
.unwrap();
let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
assert!(x.transpose().approximate_eq(&expected, 1e-4));
assert!(x
.iter()
.zip(expected.iter())
.all(|(&a, &b)| (a - b).abs() < 1e-4));
assert!((err - 0.0).abs() < 1e-4);
}
}
+322 -114
View File
@@ -17,7 +17,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::elastic_net::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -38,7 +38,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]);
//! ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -55,32 +55,39 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
/// Elastic net parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters<T: RealNumber> {
pub struct ElasticNetParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: T,
pub alpha: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: T,
pub l1_ratio: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: T,
pub tol: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
}
@@ -88,21 +95,23 @@ pub struct ElasticNetParameters<T: RealNumber> {
/// Elastic net
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
pub struct ElasticNet<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
}
impl<T: RealNumber> ElasticNetParameters<T> {
impl ElasticNetParameters {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: T) -> Self {
pub fn with_alpha(mut self, alpha: f64) -> Self {
self.alpha = alpha;
self
}
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub fn with_l1_ratio(mut self, l1_ratio: T) -> Self {
pub fn with_l1_ratio(mut self, l1_ratio: f64) -> Self {
self.l1_ratio = l1_ratio;
self
}
@@ -112,7 +121,7 @@ impl<T: RealNumber> ElasticNetParameters<T> {
self
}
/// The tolerance for the optimization
pub fn with_tol(mut self, tol: T) -> Self {
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
@@ -123,61 +132,205 @@ impl<T: RealNumber> ElasticNetParameters<T> {
}
}
impl<T: RealNumber> Default for ElasticNetParameters<T> {
impl Default for ElasticNetParameters {
fn default() -> Self {
ElasticNetParameters {
alpha: T::one(),
l1_ratio: T::half(),
alpha: 1.0,
l1_ratio: 0.5,
normalize: true,
tol: T::from_f64(1e-4).unwrap(),
tol: 1e-4,
max_iter: 1000,
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
/// ElasticNet grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
/// ElasticNet grid search iterator
pub struct ElasticNetSearchParametersIterator {
lasso_regression_search_parameters: ElasticNetSearchParameters,
current_alpha: usize,
current_l1_ratio: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}
impl IntoIterator for ElasticNetSearchParameters {
type Item = ElasticNetParameters;
type IntoIter = ElasticNetSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
ElasticNetSearchParametersIterator {
lasso_regression_search_parameters: self,
current_alpha: 0,
current_l1_ratio: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, ElasticNetParameters<T>>
for ElasticNet<T, M>
impl Iterator for ElasticNetSearchParametersIterator {
type Item = ElasticNetParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
{
return None;
}
let next = ElasticNetParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
};
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
{
self.current_alpha = 0;
self.current_l1_ratio += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
self.current_alpha += 1;
self.current_l1_ratio += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}
Some(next)
}
}
impl Default for ElasticNetSearchParameters {
fn default() -> Self {
let default_params = ElasticNetParameters::default();
ElasticNetSearchParameters {
alpha: vec![default_params.alpha],
l1_ratio: vec![default_params.l1_ratio],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for ElasticNet<TX, TY, X, Y>
{
fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters<T>) -> Result<Self, Failed> {
fn eq(&self, other: &Self) -> bool {
if self.intercept() != other.intercept() {
return false;
}
if self.coefficients().shape() != other.coefficients().shape() {
return false;
}
self.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, ElasticNetParameters> for ElasticNet<TX, TY, X, Y>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: ElasticNetParameters) -> Result<Self, Failed> {
ElasticNet::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for ElasticNet<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ElasticNet<TX, TY, X, Y>
{
/// Fits elastic net regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &M,
y: &M::RowVector,
parameters: ElasticNetParameters<T>,
) -> Result<ElasticNet<T, M>, Failed> {
x: &X,
y: &Y,
parameters: ElasticNetParameters,
) -> Result<ElasticNet<TX, TY, X, Y>, Failed> {
let (n, p) = x.shape();
if y.len() != n {
if y.shape() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let n_float = T::from_usize(n).unwrap();
let n_float = n as f64;
let l1_reg = parameters.alpha * parameters.l1_ratio * n_float;
let l2_reg = parameters.alpha * (T::one() - parameters.l1_ratio) * n_float;
let l1_reg = TX::from_f64(parameters.alpha * parameters.l1_ratio * n_float).unwrap();
let l2_reg =
TX::from_f64(parameters.alpha * (1.0 - parameters.l1_ratio) * n_float).unwrap();
let y_mean = y.mean();
let y_mean = TX::from_f64(y.mean_by()).unwrap();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
@@ -186,72 +339,95 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
let mut optimizer = InteriorPointOptimizer::new(&x, p);
let mut w =
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
let mut w = optimizer.optimize(
&x,
&y,
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
w.set(i, 0, gamma * w.get(i, 0) / col_std[i]);
w.set(i, gamma * *w.get(i) / col_std[i]);
}
let mut b = T::zero();
let mut b = TX::zero();
for i in 0..p {
b += w.get(i, 0) * col_mean[i];
b += *w.get(i) * col_mean[i];
}
b = y_mean - b;
(w, b)
(X::from_column(&w), b)
} else {
let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);
let mut optimizer = InteriorPointOptimizer::new(&x, p);
let mut w =
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
let mut w = optimizer.optimize(
&x,
&y,
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
for i in 0..p {
w.set(i, 0, gamma * w.get(i, 0));
w.set(i, gamma * *w.get(i));
}
(w, y_mean)
(X::from_column(&w), y_mean)
};
Ok(ElasticNet {
intercept: b,
coefficients: w,
intercept: Some(b),
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
let mut y_hat = x.matmul(self.coefficients.as_ref().unwrap());
let bias = X::fill(nrows, 1, self.intercept.unwrap());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &M {
&self.coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
}
/// Get estimate of intercept
pub fn intercept(&self) -> T {
self.intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
}
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
for i in 0..col_std.len() {
if (col_std[i] - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
}
}
@@ -260,25 +436,25 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
Ok((scaled_x, col_mean, col_std))
}
fn augment_x_and_y(x: &M, y: &M::RowVector, l2_reg: T) -> (M, M::RowVector, T) {
fn augment_x_and_y(x: &X, y: &Y, l2_reg: TX) -> (X, Vec<TX>, TX) {
let (n, p) = x.shape();
let gamma = T::one() / (T::one() + l2_reg).sqrt();
let gamma = TX::one() / (TX::one() + l2_reg).sqrt();
let padding = gamma * l2_reg.sqrt();
let mut y2 = M::RowVector::zeros(n + p);
for i in 0..y.len() {
y2.set(i, y.get(i));
let mut y2 = Vec::<TX>::zeros(n + p);
for i in 0..y.shape() {
y2.set(i, TX::from(*y.get(i)).unwrap());
}
let mut x2 = M::zeros(n + p, p);
let mut x2 = X::zeros(n + p, p);
for j in 0..p {
for i in 0..n {
x2.set(i, j, gamma * x.get(i, j));
x2.set((i, j), gamma * *x.get((i, j)));
}
x2.set(j + n, j, padding);
x2.set((j + n, j), padding);
}
(x2, y2, gamma)
@@ -288,10 +464,36 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = ElasticNetSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn elasticnet_longley() {
let x = DenseMatrix::from_2d_array(&[
@@ -311,7 +513,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -335,7 +538,10 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn elasticnet_fit_predict1() {
let x = DenseMatrix::from_2d_array(&[
@@ -359,7 +565,8 @@ mod tests {
&[17.0, 1918.0, 1.4054969025700674],
&[18.0, 1929.0, 1.3271699396384906],
&[19.0, 1915.0, 1.1373332337674806],
]);
])
.unwrap();
let y: Vec<f64> = vec![
1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
@@ -398,43 +605,44 @@ mod tests {
assert!(mae_l1 < 2.0);
assert!(mae_l2 < 2.0);
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(1, 0));
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((1, 0)));
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((2, 0)));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
// let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: ElasticNet<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// let deserialized_lr: ElasticNet<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
// assert_eq!(lr, deserialized_lr);
// }
}
+362 -98
View File
@@ -9,7 +9,7 @@
//!
//! Lasso coefficient estimates solve the problem:
//!
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//!
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
//! but is able to solve them with high accuracy with relatively small additional computational cost.
@@ -23,43 +23,54 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
use crate::math::num::RealNumber;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
/// Lasso regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoParameters<T: RealNumber> {
pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the strength of the penalty to the loss function.
pub alpha: T,
pub alpha: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: T,
pub tol: f64,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Lasso regressor
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
pub struct Lasso<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
}
impl<T: RealNumber> LassoParameters<T> {
impl LassoParameters {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: T) -> Self {
pub fn with_alpha(mut self, alpha: f64) -> Self {
self.alpha = alpha;
self
}
@@ -69,7 +80,7 @@ impl<T: RealNumber> LassoParameters<T> {
self
}
/// The tolerance for the optimization
pub fn with_tol(mut self, tol: T) -> Self {
pub fn with_tol(mut self, tol: f64) -> Self {
self.tol = tol;
self
}
@@ -78,63 +89,200 @@ impl<T: RealNumber> LassoParameters<T> {
self.max_iter = max_iter;
self
}
/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
}
impl<T: RealNumber> Default for LassoParameters<T> {
impl Default for LassoParameters {
fn default() -> Self {
LassoParameters {
alpha: T::one(),
alpha: 1f64,
normalize: true,
tol: T::from_f64(1e-4).unwrap(),
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for Lasso<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
self.intercept == other.intercept
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LassoParameters<T>>
for Lasso<T, M>
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, LassoParameters> for Lasso<TX, TY, X, Y>
{
fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters<T>) -> Result<Self, Failed> {
fn new() -> Self {
Self {
coefficients: None,
intercept: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Self, Failed> {
Lasso::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for Lasso<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
/// Lasso grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the strength of the penalty to the loss function.
pub alpha: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: Vec<bool>,
}
/// Lasso grid search iterator
pub struct LassoSearchParametersIterator {
lasso_search_parameters: LassoSearchParameters,
current_alpha: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
current_fit_intercept: usize,
}
impl IntoIterator for LassoSearchParameters {
type Item = LassoParameters;
type IntoIter = LassoSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator {
lasso_search_parameters: self,
current_alpha: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
current_fit_intercept: 0,
}
}
}
impl Iterator for LassoSearchParametersIterator {
type Item = LassoParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{
return None;
}
let next = LassoParameters {
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
};
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
self.current_alpha = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
self.current_fit_intercept += 1;
}
Some(next)
}
}
impl Default for LassoSearchParameters {
fn default() -> Self {
let default_params = LassoParameters::default();
LassoSearchParameters {
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
}
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Lasso<TX, TY, X, Y> {
/// Fits Lasso regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &M,
y: &M::RowVector,
parameters: LassoParameters<T>,
) -> Result<Lasso<T, M>, Failed> {
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
let (n, p) = x.shape();
if n <= p {
if n < p {
return Err(Failed::fit(
"Number of rows in X should be >= number of columns in X",
));
}
if parameters.alpha < T::zero() {
if parameters.alpha < 0f64 {
return Err(Failed::fit("alpha should be >= 0"));
}
if parameters.tol <= T::zero() {
if parameters.tol <= 0f64 {
return Err(Failed::fit("tol should be > 0"));
}
@@ -142,75 +290,111 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
return Err(Failed::fit("max_iter should be > 0"));
}
if y.len() != n {
if y.shape() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
let y: Vec<TX> = y.iterator(0).map(|&v| TX::from(v).unwrap()).collect();
let l1_reg = TX::from_f64(parameters.alpha * n as f64).unwrap();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
let mut w =
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
let mut w = optimizer.optimize(
&scaled_x,
&y,
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w.set(j, 0, w.get(j, 0) / *col_std_j);
w[j] /= *col_std_j;
}
let mut b = T::zero();
let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w.get(i, 0) * *col_mean_i;
}
b = y.mean() - b;
(w, b)
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
} else {
None
};
(X::from_column(&w), b)
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
let w = optimizer.optimize(
x,
&y,
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
(w, y.mean())
(
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
};
Ok(Lasso {
intercept: b,
coefficients: w,
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
let mut y_hat = x.matmul(self.coefficients());
let bias = X::fill(nrows, 1, self.intercept.unwrap());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &M {
&self.coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
}
/// Get estimate of intercept
pub fn intercept(&self) -> T {
self.intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
}
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
}
}
@@ -223,12 +407,37 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn lasso_fit_predict() {
fn search_parameters() {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default()
};
let mut iter = parameters.clone().into_iter();
for current_fit_intercept in 0..parameters.fit_intercept.len() {
for current_max_iter in 0..parameters.max_iter.len() {
for current_alpha in 0..parameters.alpha.len() {
let next = iter.next().unwrap();
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(
next.fit_intercept,
parameters.fit_intercept[current_fit_intercept]
);
}
}
}
assert!(iter.next().is_none());
}
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -246,13 +455,25 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
(x, y)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();
let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
@@ -267,6 +488,7 @@ mod tests {
normalize: false,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
},
)
.and_then(|lr| lr.predict(&x))
@@ -275,39 +497,81 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
fn test_full_rank_x() {
// x: randn(3,3) * 10, demean, then round to 2 decimal points
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
let param = LassoParameters::default()
.with_normalize(false)
.with_alpha(200.0);
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
&[-8.9, -2.24, 8.89],
&[-4.02, 8.89, 12.33],
&[12.92, -6.65, -21.22],
])
.unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let y = vec![-116.12, -75.41, 191.53];
let w = Lasso::fit(&x, &y, param)
.unwrap()
.coefficients()
.iterator(0)
.copied()
.collect();
let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: Lasso<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
let (x, y) = get_example_x_y();
let fit_result = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: false,
tol: 1e-8,
max_iter: 1000,
fit_intercept: false,
},
)
.unwrap();
let w = fit_result.coefficients().iterator(0).copied().collect();
// by sklearn LassoLars. coordinate descent doesn't converge well
let expected_w = vec![
0.18335684,
0.02106526,
0.00703214,
-1.35952542,
0.09295222,
0.,
];
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
assert_eq!(fit_result.intercept, None);
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let (x, y) = get_lasso_sample_x_y();
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// assert_eq!(lr, deserialized_lr);
// }
}
+75 -78
View File
@@ -12,21 +12,22 @@
//!
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1, MutArray, MutArrayView1};
use crate::linear::bg_solver::BiconjugateGradientSolver;
use crate::math::num::RealNumber;
use crate::numbers::floatnum::FloatNumber;
pub struct InteriorPointOptimizer<T: RealNumber, M: Matrix<T>> {
ata: M,
/// Interior Point Optimizer
pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
ata: X,
d1: Vec<T>,
d2: Vec<T>,
prb: Vec<T>,
prs: Vec<T>,
}
impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
pub fn new(a: &M, n: usize) -> InteriorPointOptimizer<T, M> {
impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
/// Initialize a new Interior Point Optimizer
pub fn new(a: &X, n: usize) -> InteriorPointOptimizer<T, X> {
InteriorPointOptimizer {
ata: a.ab(true, a, false),
d1: vec![T::zero(); n],
@@ -36,20 +37,23 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
}
}
/// Run the optimization
pub fn optimize(
&mut self,
x: &M,
y: &M::RowVector,
x: &X,
y: &Vec<T>,
lambda: T,
max_iter: usize,
tol: T,
) -> Result<M, Failed> {
fit_intercept: bool,
) -> Result<Vec<T>, Failed> {
let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap();
let lambda = lambda.max(T::epsilon());
//parameters
let max_ls_iter = 100;
let pcgmaxi = 5000;
let min_pcgtol = T::from_f64(0.1).unwrap();
let eta = T::from_f64(1E-3).unwrap();
@@ -58,50 +62,56 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
let gamma = T::from_f64(-0.25).unwrap();
let mu = T::two();
let y = M::from_row_vector(y.sub_scalar(y.mean())).transpose();
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};
let mut max_ls_iter = 100;
let mut pitr = 0;
let mut w = M::zeros(p, 1);
let mut w = Vec::zeros(p);
let mut neww = w.clone();
let mut u = M::ones(p, 1);
let mut u = Vec::ones(p);
let mut newu = u.clone();
let mut f = M::fill(p, 2, -T::one());
let mut f = X::fill(p, 2, -T::one());
let mut newf = f.clone();
let mut q1 = vec![T::zero(); p];
let mut q2 = vec![T::zero(); p];
let mut dx = M::zeros(p, 1);
let mut du = M::zeros(p, 1);
let mut dxu = M::zeros(2 * p, 1);
let mut grad = M::zeros(2 * p, 1);
let mut dx = Vec::zeros(p);
let mut du = Vec::zeros(p);
let mut dxu = Vec::zeros(2 * p);
let mut grad = Vec::zeros(2 * p);
let mut nu = M::zeros(n, 1);
let mut nu = Vec::zeros(n);
let mut dobj = T::zero();
let mut s = T::infinity();
let mut t = T::one()
.max(T::one() / lambda)
.min(T::two() * p_f64 / T::from(1e-3).unwrap());
let lambda_f64 = lambda.to_f64().unwrap();
for ntiter in 0..max_iter {
let mut z = x.matmul(&w);
let mut z = w.xa(true, x);
for i in 0..n {
z.set(i, 0, z.get(i, 0) - y.get(i, 0));
nu.set(i, 0, T::two() * z.get(i, 0));
z[i] -= y[i];
nu[i] = T::two() * z[i];
}
// CALCULATE DUALITY GAP
let xnu = x.ab(true, &nu, false);
let max_xnu = xnu.norm(T::infinity());
if max_xnu > lambda {
let lnu = lambda / max_xnu;
let xnu = nu.xa(false, x);
let max_xnu = xnu.norm(f64::INFINITY);
if max_xnu > lambda_f64 {
let lnu = T::from_f64(lambda_f64 / max_xnu).unwrap();
nu.mul_scalar_mut(lnu);
}
let pobj = z.dot(&z) + lambda * w.norm(T::one());
let pobj = z.dot(&z) + lambda * T::from_f64(w.norm(1f64)).unwrap();
dobj = dobj.max(gamma * nu.dot(&nu) - nu.dot(&y));
let gap = pobj - dobj;
@@ -118,22 +128,22 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
// CALCULATE NEWTON STEP
for i in 0..p {
let q1i = T::one() / (u.get(i, 0) + w.get(i, 0));
let q2i = T::one() / (u.get(i, 0) - w.get(i, 0));
let q1i = T::one() / (u[i] + w[i]);
let q2i = T::one() / (u[i] - w[i]);
q1[i] = q1i;
q2[i] = q2i;
self.d1[i] = (q1i * q1i + q2i * q2i) / t;
self.d2[i] = (q1i * q1i - q2i * q2i) / t;
}
let mut gradphi = x.ab(true, &z, false);
let mut gradphi = z.xa(false, x);
for i in 0..p {
let g1 = T::two() * gradphi.get(i, 0) - (q1[i] - q2[i]) / t;
let g1 = T::two() * gradphi[i] - (q1[i] - q2[i]) / t;
let g2 = lambda - (q1[i] + q2[i]) / t;
gradphi.set(i, 0, g1);
grad.set(i, 0, -g1);
grad.set(i + p, 0, -g2);
gradphi[i] = g1;
grad[i] = -g1;
grad[i + p] = -g2;
}
for i in 0..p {
@@ -141,7 +151,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
}
let normg = grad.norm2();
let normg = T::from_f64(grad.norm2()).unwrap();
let mut pcgtol = min_pcgtol.min(eta * gap / T::one().min(normg));
if ntiter != 0 && pitr == 0 {
pcgtol *= min_pcgtol;
@@ -152,29 +162,31 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
pitr = pcgmaxi;
}
for i in 0..p {
dx.set(i, 0, dxu.get(i, 0));
du.set(i, 0, dxu.get(i + p, 0));
}
dx[..p].copy_from_slice(&dxu[..p]);
du[..p].copy_from_slice(&dxu[p..(p + p)]);
// BACKTRACKING LINE SEARCH
let phi = z.dot(&z) + lambda * u.sum() - Self::sumlogneg(&f) / t;
s = T::one();
let gdx = grad.dot(&dxu);
let lsiter = 0;
let mut lsiter = 0;
while lsiter < max_ls_iter {
for i in 0..p {
neww.set(i, 0, w.get(i, 0) + s * dx.get(i, 0));
newu.set(i, 0, u.get(i, 0) + s * du.get(i, 0));
newf.set(i, 0, neww.get(i, 0) - newu.get(i, 0));
newf.set(i, 1, -neww.get(i, 0) - newu.get(i, 0));
neww[i] = w[i] + s * dx[i];
newu[i] = u[i] + s * du[i];
newf.set((i, 0), neww[i] - newu[i]);
newf.set((i, 1), -neww[i] - newu[i]);
}
if newf.max() < T::zero() {
let mut newz = x.matmul(&neww);
if newf
.iterator(0)
.fold(T::neg_infinity(), |max, v| v.max(max))
< T::zero()
{
let mut newz = neww.xa(true, x);
for i in 0..n {
newz.set(i, 0, newz.get(i, 0) - y.get(i, 0));
newz[i] -= y[i];
}
let newphi = newz.dot(&newz) + lambda * newu.sum() - Self::sumlogneg(&newf) / t;
@@ -183,7 +195,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
}
}
s = beta * s;
max_ls_iter += 1;
lsiter += 1;
}
if lsiter == max_ls_iter {
@@ -200,56 +212,41 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
Ok(w)
}
fn sumlogneg(f: &M) -> T {
fn sumlogneg(f: &X) -> T {
let (n, _) = f.shape();
let mut sum = T::zero();
for i in 0..n {
sum += (-f.get(i, 0)).ln();
sum += (-f.get(i, 1)).ln();
sum += (-*f.get((i, 0))).ln();
sum += (-*f.get((i, 1))).ln();
}
sum
}
}
impl<'a, T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M>
for InteriorPointOptimizer<T, M>
impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
for InteriorPointOptimizer<T, X>
{
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
let (_, p) = a.shape();
for i in 0..p {
x.set(
i,
0,
(self.d1[i] * b.get(i, 0) - self.d2[i] * b.get(i + p, 0)) / self.prs[i],
);
x.set(
i + p,
0,
(-self.d2[i] * b.get(i, 0) + self.prb[i] * b.get(i + p, 0)) / self.prs[i],
);
x[i] = (self.d1[i] * b[i] - self.d2[i] * b[i + p]) / self.prs[i];
x[i + p] = (-self.d2[i] * b[i] + self.prb[i] * b[i + p]) / self.prs[i];
}
}
fn mat_vec_mul(&self, _: &M, x: &M, y: &mut M) {
fn mat_vec_mul(&self, _: &X, x: &Vec<T>, y: &mut Vec<T>) {
let (_, p) = self.ata.shape();
let atax = self.ata.matmul(&x.slice(0..p, 0..1));
let x_slice = Vec::from_slice(x.slice(0..p).as_ref());
let atax = x_slice.xa(true, &self.ata);
for i in 0..p {
y.set(
i,
0,
T::two() * atax.get(i, 0) + self.d1[i] * x.get(i, 0) + self.d2[i] * x.get(i + p, 0),
);
y.set(
i + p,
0,
self.d2[i] * x.get(i, 0) + self.d1[i] * x.get(i + p, 0),
);
y[i] = T::two() * atax[i] + self.d1[i] * x[i] + self.d2[i] * x[i + p];
y[i + p] = self.d2[i] * x[i] + self.d1[i] * x[i + p];
}
}
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
self.mat_vec_mul(a, x, y);
}
}
+221 -95
View File
@@ -12,14 +12,14 @@
//! \\[\hat{\beta} = (X^TX)^{-1}X^Ty \\]
//!
//! the \\((X^TX)^{-1}\\) term is both computationally expensive and numerically unstable. An alternative approach is to use a matrix decomposition to avoid this operation.
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The QR decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the QR decomposition, all matrices have an SVD decomposition.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::linear_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -40,7 +40,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]);
//! ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -61,21 +61,26 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
QR,
#[default]
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
}
@@ -84,27 +89,11 @@ pub enum LinearRegressionSolverName {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName,
}
/// Linear Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
_solver: LinearRegressionSolverName,
}
impl LinearRegressionParameters {
/// Solver to use for estimation of regression coefficients.
pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
self.solver = solver;
self
}
}
impl Default for LinearRegressionParameters {
fn default() -> Self {
LinearRegressionParameters {
@@ -113,43 +102,157 @@ impl Default for LinearRegressionParameters {
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
/// Linear Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
}
impl LinearRegressionParameters {
/// Solver to use for estimation of regression coefficients.
pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
self.solver = solver;
self
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LinearRegressionParameters>
for LinearRegression<T, M>
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: LinearRegressionSearchParameters,
current_solver: usize,
}
impl IntoIterator for LinearRegressionSearchParameters {
type Item = LinearRegressionParameters;
type IntoIter = LinearRegressionSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: self,
current_solver: 0,
}
}
}
impl Iterator for LinearRegressionSearchParametersIterator {
type Item = LinearRegressionParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
return None;
}
let next = LinearRegressionParameters {
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
};
self.current_solver += 1;
Some(next)
}
}
impl Default for LinearRegressionSearchParameters {
fn default() -> Self {
let default_params = LinearRegressionParameters::default();
LinearRegressionSearchParameters {
solver: vec![default_params.solver],
}
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> PartialEq for LinearRegression<TX, TY, X, Y>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: LinearRegressionParameters,
) -> Result<Self, Failed> {
fn eq(&self, other: &Self) -> bool {
self.intercept == other.intercept
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> SupervisedEstimator<X, Y, LinearRegressionParameters> for LinearRegression<TX, TY, X, Y>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: LinearRegressionParameters) -> Result<Self, Failed> {
LinearRegression::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> Predictor<X, Y> for LinearRegression<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> LinearRegression<TX, TY, X, Y>
{
/// Fits Linear Regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &M,
y: &M::RowVector,
x: &X,
y: &Y,
parameters: LinearRegressionParameters,
) -> Result<LinearRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone());
let b = y_m.transpose();
) -> Result<LinearRegression<TX, TY, X, Y>, Failed> {
let b = X::from_iterator(
y.iterator(0).map(|&v| TX::from(v).unwrap()),
y.shape(),
1,
0,
);
let (x_nrows, num_attributes) = x.shape();
let (y_nrows, _) = b.shape();
@@ -159,59 +262,77 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
));
}
let a = x.h_stack(&M::ones(x_nrows, 1));
let a = x.h_stack(&X::ones(x_nrows, 1));
let w = match parameters.solver {
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
};
let wights = w.slice(0..num_attributes, 0..1);
let weights = X::from_slice(w.slice(0..num_attributes, 0..1).as_ref());
Ok(LinearRegression {
intercept: w.get(num_attributes, 0),
coefficients: wights,
_solver: parameters.solver,
intercept: Some(*w.get((num_attributes, 0))),
coefficients: Some(weights),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
let bias = X::fill(nrows, 1, *self.intercept());
let mut y_hat = x.matmul(self.coefficients());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &M {
&self.coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
}
/// Get estimate of intercept
pub fn intercept(&self) -> T {
self.intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = LinearRegressionSearchParameters {
solver: vec![
LinearRegressionSolverName::QR,
LinearRegressionSolverName::SVD,
],
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn ols_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
@@ -220,11 +341,11 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
];
let y_hat_qr = LinearRegression::fit(
@@ -251,39 +372,44 @@ mod tests {
.all(|(&a, &b)| (a - b).abs() <= 5.0));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
// let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// let deserialized_lr: LinearRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
// assert_eq!(lr, deserialized_lr);
// let default = LinearRegressionParameters::default();
// let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
// assert_eq!(parameters.solver, default.solver);
// }
}
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -20,10 +20,10 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
pub(crate) mod bg_solver;
pub mod bg_solver;
pub mod elastic_net;
pub mod lasso;
pub(crate) mod lasso_optimizer;
pub mod lasso_optimizer;
pub mod linear_regression;
pub mod logistic_regression;
pub mod ridge_regression;
+258 -95
View File
@@ -12,14 +12,14 @@
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression will produce the least squares estimates.
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates will approach zero.
//!
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::ridge_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -40,7 +40,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]);
//! ]).unwrap();
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -57,21 +57,25 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName {
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
#[default]
Cholesky,
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
@@ -80,7 +84,7 @@ pub enum RidgeRegressionSolverName {
/// Ridge Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: RealNumber> {
pub struct RidgeRegressionParameters<T: Number + RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: RidgeRegressionSolverName,
/// Controls the strength of the penalty to the loss function.
@@ -90,16 +94,109 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
pub normalize: bool,
}
/// Ridge Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: Number + RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<RidgeRegressionSolverName>,
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
}
/// Ridge Regression grid search iterator
pub struct RidgeRegressionSearchParametersIterator<T: Number + RealNumber> {
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
current_solver: usize,
current_alpha: usize,
current_normalize: usize,
}
impl<T: Number + RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
type Item = RidgeRegressionParameters<T>;
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
RidgeRegressionSearchParametersIterator {
ridge_regression_search_parameters: self,
current_solver: 0,
current_alpha: 0,
current_normalize: 0,
}
}
}
impl<T: Number + RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
type Item = RidgeRegressionParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
{
return None;
}
let next = RidgeRegressionParameters {
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
};
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
self.current_alpha = 0;
self.current_solver += 1;
} else if self.current_normalize + 1
< self.ridge_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_solver = 0;
self.current_normalize += 1;
} else {
self.current_alpha += 1;
self.current_solver += 1;
self.current_normalize += 1;
}
Some(next)
}
}
impl<T: Number + RealNumber> Default for RidgeRegressionSearchParameters<T> {
fn default() -> Self {
let default_params = RidgeRegressionParameters::default();
RidgeRegressionSearchParameters {
solver: vec![default_params.solver],
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
}
}
}
/// Ridge regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
_solver: RidgeRegressionSolverName,
pub struct RidgeRegression<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
}
impl<T: RealNumber> RidgeRegressionParameters<T> {
impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: T) -> Self {
self.alpha = alpha;
@@ -117,51 +214,83 @@ impl<T: RealNumber> RidgeRegressionParameters<T> {
}
}
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
impl<T: Number + RealNumber> Default for RidgeRegressionParameters<T> {
fn default() -> Self {
RidgeRegressionParameters {
solver: RidgeRegressionSolverName::Cholesky,
alpha: T::one(),
solver: RidgeRegressionSolverName::default(),
alpha: T::from_f64(1.0).unwrap(),
normalize: true,
}
}
}
impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> PartialEq for RidgeRegression<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
self.intercept() == other.intercept()
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
}
}
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, RidgeRegressionParameters<T>>
for RidgeRegression<T, M>
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> SupervisedEstimator<X, Y, RidgeRegressionParameters<TX>> for RidgeRegression<TX, TY, X, Y>
{
fn fit(
x: &M,
y: &M::RowVector,
parameters: RidgeRegressionParameters<T>,
) -> Result<Self, Failed> {
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: RidgeRegressionParameters<TX>) -> Result<Self, Failed> {
RidgeRegression::fit(x, y, parameters)
}
}
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> Predictor<X, Y> for RidgeRegression<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> RidgeRegression<TX, TY, X, Y>
{
/// Fits ridge regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &M,
y: &M::RowVector,
parameters: RidgeRegressionParameters<T>,
) -> Result<RidgeRegression<T, M>, Failed> {
x: &X,
y: &Y,
parameters: RidgeRegressionParameters<TX>,
) -> Result<RidgeRegression<TX, TY, X, Y>, Failed> {
//w = inv(X^t X + alpha*Id) * X.T y
let (n, p) = x.shape();
@@ -172,11 +301,16 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
));
}
if y.len() != n {
if y.shape() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let y_column = M::from_row_vector(y.clone()).transpose();
let y_column = X::from_iterator(
y.iterator(0).map(|&v| TX::from(v).unwrap()),
y.shape(),
1,
0,
);
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
@@ -185,7 +319,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
let mut x_t_x = x_t.matmul(&scaled_x);
for i in 0..p {
x_t_x.add_element_mut(i, i, parameters.alpha);
x_t_x.add_element_mut((i, i), parameters.alpha);
}
let mut w = match parameters.solver {
@@ -194,16 +328,16 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
};
for (i, col_std_i) in col_std.iter().enumerate().take(p) {
w.set(i, 0, w.get(i, 0) / *col_std_i);
w.set((i, 0), *w.get((i, 0)) / *col_std_i);
}
let mut b = T::zero();
let mut b = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w.get(i, 0) * *col_mean_i;
b += *w.get((i, 0)) * *col_mean_i;
}
let b = y.mean() - b;
let b = TX::from_f64(y.mean_by()).unwrap() - b;
(w, b)
} else {
@@ -212,7 +346,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
let mut x_t_x = x_t.matmul(x);
for i in 0..p {
x_t_x.add_element_mut(i, i, parameters.alpha);
x_t_x.add_element_mut((i, i), parameters.alpha);
}
let w = match parameters.solver {
@@ -220,26 +354,32 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
};
(w, T::zero())
(w, TX::zero())
};
Ok(RidgeRegression {
intercept: b,
coefficients: w,
_solver: parameters.solver,
intercept: Some(b),
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
})
}
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
}
}
@@ -250,31 +390,52 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
let mut y_hat = x.matmul(self.coefficients());
y_hat.add_mut(&X::fill(nrows, 1, self.intercept.unwrap()));
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &M {
&self.coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
}
/// Get estimate of intercept
pub fn intercept(&self) -> T {
self.intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = RidgeRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().alpha, 0.);
assert_eq!(
iter.next().unwrap().solver,
RidgeRegressionSolverName::Cholesky
);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn ridge_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
@@ -294,7 +455,8 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
])
.unwrap();
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -330,39 +492,40 @@ mod tests {
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// TODO: implement serialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
// let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: RidgeRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// let deserialized_lr: RidgeRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
// assert_eq!(lr, deserialized_lr);
// }
}
-70
View File
@@ -1,70 +0,0 @@
//! # Euclidian Metric Distance
//!
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
//!
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
//!
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::euclidian::Euclidian;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l2: f64 = Euclidian{}.distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian {}
impl Euclidian {
#[inline]
pub(crate) fn squared_distance<T: RealNumber>(x: &[T], y: &[T]) -> T {
if x.len() != y.len() {
panic!("Input vector sizes are different.");
}
let mut sum = T::zero();
for i in 0..x.len() {
let d = x[i] - y[i];
sum += d * d;
}
sum
}
}
impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
Euclidian::squared_distance(x, y).sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squared_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l2: f64 = Euclidian {}.distance(&a, &b);
assert!((l2 - 5.19615242).abs() < 1e-8);
}
}
-61
View File
@@ -1,61 +0,0 @@
//! # Manhattan Distance
//!
//! The Manhattan distance between two points \\(x \in ^n \\) and \\( y \in ^n \\) in n-dimensional space is the sum of the distances in each dimension.
//!
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
//!
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::manhattan::Manhattan;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Manhattan {}.distance(&x, &y);
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Manhattan distance
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan {}
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() {
panic!("Input vector sizes are different");
}
let mut dist = T::zero();
for i in 0..x.len() {
dist += (x[i] - y[i]).abs();
}
dist
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn manhattan_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Manhattan {}.distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
}
}
-65
View File
@@ -1,65 +0,0 @@
//! # Collection of Distance Functions
//!
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
//!
//! for all \\(x, y, z \in Z \\)
//!
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
pub mod euclidian;
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
pub mod hamming;
/// The Mahalanobis distance is the distance between two points in multivariate space.
pub mod mahalanobis;
/// Also known as rectilinear distance, city block distance, taxicab metric.
pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T, F: RealNumber>: Clone {
/// Calculates distance between _a_ and _b_
fn distance(&self, a: &T, b: &T) -> F;
}
/// Multitude of distance metric functions
pub struct Distances {}
impl Distances {
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
pub fn euclidian() -> euclidian::Euclidian {
euclidian::Euclidian {}
}
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
/// * `p` - function order. Should be >= 1
pub fn minkowski(p: u16) -> minkowski::Minkowski {
minkowski::Minkowski { p }
}
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
pub fn manhattan() -> manhattan::Manhattan {
manhattan::Manhattan {}
}
/// Hamming distance, see [`Hamming`](hamming/index.html)
pub fn hamming() -> hamming::Hamming {
hamming::Hamming {}
}
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
pub fn mahalanobis<T: RealNumber, M: Matrix<T>>(data: &M) -> mahalanobis::Mahalanobis<T, M> {
mahalanobis::Mahalanobis::new(data)
}
}
-4
View File
@@ -1,4 +0,0 @@
/// Multitude of distance metrics are defined here
pub mod distance;
pub mod num;
pub(crate) mod vector;
-42
View File
@@ -1,42 +0,0 @@
use crate::math::num::RealNumber;
use std::collections::HashMap;
use crate::linalg::BaseVector;
pub trait RealNumberVector<T: RealNumber> {
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>);
}
impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>) {
let mut unique = self.to_vec();
unique.sort_by(|a, b| a.partial_cmp(b).unwrap());
unique.dedup();
let mut index = HashMap::with_capacity(unique.len());
for (i, u) in unique.iter().enumerate() {
index.insert(u.to_i64().unwrap(), i);
}
let mut unique_index = Vec::with_capacity(self.len());
for idx in 0..self.len() {
unique_index.push(index[&self.get(idx).to_i64().unwrap()]);
}
(unique, unique_index)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn unique_with_indices() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
assert_eq!(
(vec!(0.0, 1.0, 2.0, 4.0), vec!(0, 0, 1, 1, 2, 0, 3)),
v1.unique_with_indices()
);
}
}
+62 -17
View File
@@ -8,10 +8,20 @@
//!
//! ```
//! use smartcore::metrics::accuracy::Accuracy;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
//! let y_true: Vec<f64> = vec![0., 1., 2., 3.];
//!
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
//! ```
//! With integers:
//! ```
//! use smartcore::metrics::accuracy::Accuracy;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<i64> = vec![0, 2, 1, 3];
//! let y_true: Vec<i64> = vec![0, 1, 2, 3];
//!
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -19,37 +29,53 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use std::marker::PhantomData;
use crate::metrics::Metrics;
/// Accuracy metric.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Accuracy {}
pub struct Accuracy<T> {
_phantom: PhantomData<T>,
}
impl Accuracy {
impl<T: Number> Metrics<T> for Accuracy<T> {
/// create a typed object to call Accuracy functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Function that calculated accuracy score.
/// * `y_true` - cround truth (correct) labels
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let n = y_true.len();
let n = y_true.shape();
let mut positive = 0;
let mut positive: i32 = 0;
for i in 0..n {
if y_true.get(i) == y_pred.get(i) {
if *y_true.get(i) == *y_pred.get(i) {
positive += 1;
}
}
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
positive as f64 / n as f64
}
}
@@ -57,16 +83,35 @@ impl Accuracy {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn accuracy() {
fn accuracy_float() {
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
let y_true: Vec<f64> = vec![0., 1., 2., 3.];
let score1: f64 = Accuracy {}.get_score(&y_pred, &y_true);
let score2: f64 = Accuracy {}.get_score(&y_true, &y_true);
let score1: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_pred);
let score2: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_true);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn accuracy_int() {
let y_pred: Vec<i32> = vec![0, 2, 1, 3];
let y_true: Vec<i32> = vec![0, 1, 2, 3];
let score1: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_pred);
let score2: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_true);
assert_eq!(score1, 0.5);
assert_eq!(score2, 1.0);
}
}
+50 -27
View File
@@ -2,16 +2,17 @@
//! Computes the area under the receiver operating characteristic (ROC) curve that is equal to the probability that a classifier will rank a
//! randomly chosen positive instance higher than a randomly chosen negative one.
//!
//! SmartCore calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
//! `smartcore` calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
//!
//! Example:
//! ```
//! use smartcore::metrics::auc::AUC;
//! use smartcore::metrics::Metrics;
//!
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//! let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
//!
//! let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
//! let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
//! ```
//!
//! ## References:
@@ -20,32 +21,48 @@
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
#![allow(non_snake_case)]
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct AUC {}
pub struct AUC<T> {
_phantom: PhantomData<T>,
}
impl AUC {
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
/// create a typed object to call AUC functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// AUC score.
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred_prob: &V) -> T {
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
let mut pos = T::zero();
let mut neg = T::zero();
let n = y_true.len();
let n = y_true.shape();
for i in 0..n {
if y_true.get(i) == T::zero() {
if y_true.get(i) == &T::zero() {
neg += T::one();
} else if y_true.get(i) == T::one() {
} else if y_true.get(i) == &T::one() {
pos += T::one();
} else {
panic!(
@@ -55,21 +72,22 @@ impl AUC {
}
}
let mut y_pred = y_pred_prob.to_vec();
let y_pred: Vec<T> =
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
// TODO: try to use `crate::algorithm::sort::quick_sort` here
let label_idx: Vec<usize> = y_pred.argsort();
let label_idx = y_pred.quick_argsort_mut();
let mut rank = vec![T::zero(); n];
let mut rank = vec![0f64; n];
let mut i = 0;
while i < n {
if i == n - 1 || y_pred[i] != y_pred[i + 1] {
rank[i] = T::from_usize(i + 1).unwrap();
if i == n - 1 || y_pred.get(i) != y_pred.get(i + 1) {
rank[i] = (i + 1) as f64;
} else {
let mut j = i + 1;
while j < n && y_pred[j] == y_pred[i] {
while j < n && y_pred.get(j) == y_pred.get(i) {
j += 1;
}
let r = T::from_usize(i + 1 + j).unwrap() / T::two();
let r = (i + 1 + j) as f64 / 2f64;
for rank_k in rank.iter_mut().take(j).skip(i) {
*rank_k = r;
}
@@ -78,14 +96,16 @@ impl AUC {
i += 1;
}
let mut auc = T::zero();
let mut auc = 0f64;
for i in 0..n {
if y_true.get(label_idx[i]) == T::one() {
if y_true.get(label_idx[i]) == &T::one() {
auc += rank[i];
}
}
let pos = pos.to_f64().unwrap();
let neg = neg.to_f64().unwrap();
(auc - (pos * (pos + T::one()) / T::two())) / (pos * neg)
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
}
}
@@ -93,14 +113,17 @@ impl AUC {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn auc() {
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
let score2: f64 = AUC {}.get_score(&y_true, &y_true);
let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
let score2: f64 = AUC::new().get_score(&y_true, &y_true);
assert!((score1 - 0.75).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
+79 -31
View File
@@ -1,41 +1,85 @@
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::metrics::cluster_helpers::*;
use crate::numbers::basenum::Number;
use crate::metrics::Metrics;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Homogeneity, completeness and V-Measure scores.
pub struct HCVScore {}
pub struct HCVScore<T> {
_phantom: PhantomData<T>,
homogeneity: Option<f64>,
completeness: Option<f64>,
v_measure: Option<f64>,
}
impl HCVScore {
/// Computes Homogeneity, completeness and V-Measure scores at once.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(
&self,
labels_true: &V,
labels_pred: &V,
) -> (T, T, T) {
let labels_true = labels_true.to_vec();
let labels_pred = labels_pred.to_vec();
let entropy_c = entropy(&labels_true);
let entropy_k = entropy(&labels_pred);
let contingency = contingency_matrix(&labels_true, &labels_pred);
let mi: T = mutual_info_score(&contingency);
impl<T: Number + Ord> HCVScore<T> {
/// return homogenity score
pub fn homogeneity(&self) -> Option<f64> {
self.homogeneity
}
/// return completeness score
pub fn completeness(&self) -> Option<f64> {
self.completeness
}
/// return v_measure score
pub fn v_measure(&self) -> Option<f64> {
self.v_measure
}
/// run computation for measures
pub fn compute(&mut self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) {
let entropy_c: Option<f64> = entropy(y_true);
let entropy_k: Option<f64> = entropy(y_pred);
let contingency = contingency_matrix(y_true, y_pred);
let mi = mutual_info_score(&contingency);
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or_else(T::one);
let completeness = entropy_k.map(|e| mi / e).unwrap_or_else(T::one);
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or(0f64);
let completeness = entropy_k.map(|e| mi / e).unwrap_or(0f64);
let v_measure_score = if homogeneity + completeness == T::zero() {
T::zero()
let v_measure_score = if homogeneity + completeness == 0f64 {
0f64
} else {
T::two() * homogeneity * completeness / (T::one() * homogeneity + completeness)
2.0f64 * homogeneity * completeness / (1.0f64 * homogeneity + completeness)
};
(homogeneity, completeness, v_measure_score)
self.homogeneity = Some(homogeneity);
self.completeness = Some(completeness);
self.v_measure = Some(v_measure_score);
}
}
impl<T: Number + Ord> Metrics<T> for HCVScore<T> {
/// create a typed object to call HCVScore functions
fn new() -> Self {
Self {
_phantom: PhantomData,
homogeneity: Option::None,
completeness: Option::None,
v_measure: Option::None,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
homogeneity: Option::None,
completeness: Option::None,
v_measure: Option::None,
}
}
/// Computes Homogeneity, completeness and V-Measure scores at once.
/// * `y_true` - ground truth class labels to be used as a reference.
/// * `y_pred` - cluster labels to evaluate.
fn get_score(&self, _y_true: &dyn ArrayView1<T>, _y_pred: &dyn ArrayView1<T>) -> f64 {
// this functions should not be used for this struct
// use homogeneity(), completeness(), v_measure()
// TODO: implement Metrics -> Result<T, Failed>
0f64
}
}
@@ -43,15 +87,19 @@ impl HCVScore {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn homogeneity_score() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let scores = HCVScore {}.get_score(&v1, &v2);
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
let mut scores = HCVScore::new();
scores.compute(&v1, &v2);
assert!((0.2548f32 - scores.0).abs() < 1e-4);
assert!((0.5440f32 - scores.1).abs() < 1e-4);
assert!((0.3471f32 - scores.2).abs() < 1e-4);
assert!((0.2548 - scores.homogeneity.unwrap()).abs() < 1e-4);
assert!((0.5440 - scores.completeness.unwrap()).abs() < 1e-4);
assert!((0.3471 - scores.v_measure.unwrap()).abs() < 1e-4);
}
}
+46 -36
View File
@@ -1,12 +1,12 @@
#![allow(clippy::ptr_arg)]
use std::collections::HashMap;
use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
pub fn contingency_matrix<T: RealNumber>(
labels_true: &Vec<T>,
labels_pred: &Vec<T>,
pub fn contingency_matrix<T: Number + Ord, V: ArrayView1<T> + ?Sized>(
labels_true: &V,
labels_pred: &V,
) -> Vec<Vec<usize>> {
let (classes, class_idx) = labels_true.unique_with_indices();
let (clusters, cluster_idx) = labels_pred.unique_with_indices();
@@ -24,28 +24,30 @@ pub fn contingency_matrix<T: RealNumber>(
contingency_matrix
}
pub fn entropy<T: RealNumber>(data: &[T]) -> Option<T> {
let mut bincounts = HashMap::with_capacity(data.len());
pub fn entropy<T: Number + Ord, V: ArrayView1<T> + ?Sized>(data: &V) -> Option<f64> {
let mut bincounts = HashMap::with_capacity(data.shape());
for e in data.iter() {
for e in data.iterator(0) {
let k = e.to_i64().unwrap();
bincounts.insert(k, bincounts.get(&k).unwrap_or(&0) + 1);
}
let mut entropy = T::zero();
let sum = T::from_usize(bincounts.values().sum()).unwrap();
let mut entropy = 0f64;
let sum: i64 = bincounts.values().sum();
for &c in bincounts.values() {
if c > 0 {
let pi = T::from_usize(c).unwrap();
entropy -= (pi / sum) * (pi.ln() - sum.ln());
let pi = c as f64;
let pi_ln = pi.ln();
let sum_ln = (sum as f64).ln();
entropy -= (pi / sum as f64) * (pi_ln - sum_ln);
}
}
Some(entropy)
}
pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
pub fn mutual_info_score(contingency: &[Vec<usize>]) -> f64 {
let mut contingency_sum = 0;
let mut pi = vec![0; contingency.len()];
let mut pj = vec![0; contingency[0].len()];
@@ -64,48 +66,50 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
}
}
let contingency_sum = T::from_usize(contingency_sum).unwrap();
let contingency_sum = contingency_sum as f64;
let contingency_sum_ln = contingency_sum.ln();
let pi_sum_l = T::from_usize(pi.iter().sum()).unwrap().ln();
let pj_sum_l = T::from_usize(pj.iter().sum()).unwrap().ln();
let pi_sum: usize = pi.iter().sum();
let pj_sum: usize = pj.iter().sum();
let pi_sum_l = (pi_sum as f64).ln();
let pj_sum_l = (pj_sum as f64).ln();
let log_contingency_nm: Vec<T> = nz_val
let log_contingency_nm: Vec<f64> = nz_val.iter().map(|v| (*v as f64).ln()).collect();
let contingency_nm: Vec<f64> = nz_val
.iter()
.map(|v| T::from_usize(*v).unwrap().ln())
.collect();
let contingency_nm: Vec<T> = nz_val
.iter()
.map(|v| T::from_usize(*v).unwrap() / contingency_sum)
.map(|v| (*v as f64) / contingency_sum)
.collect();
let outer: Vec<usize> = nzx
.iter()
.zip(nzy.iter())
.map(|(&x, &y)| pi[x] * pj[y])
.collect();
let log_outer: Vec<T> = outer
let log_outer: Vec<f64> = outer
.iter()
.map(|&o| -T::from_usize(o).unwrap().ln() + pi_sum_l + pj_sum_l)
.map(|&o| -(o as f64).ln() + pi_sum_l + pj_sum_l)
.collect();
let mut result = T::zero();
let mut result = 0f64;
for i in 0..log_outer.len() {
result += (contingency_nm[i] * (log_contingency_nm[i] - contingency_sum_ln))
+ contingency_nm[i] * log_outer[i]
}
result.max(T::zero())
result.max(0f64)
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn contingency_matrix_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
assert_eq!(
vec!(vec!(1, 2), vec!(2, 0), vec!(1, 0), vec!(1, 0)),
@@ -113,20 +117,26 @@ mod tests {
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn entropy_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
assert!((1.2770 - entropy(&v1).unwrap()).abs() < 1e-4);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn mutual_info_score_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let s: f32 = mutual_info_score(&contingency_matrix(&v1, &v2));
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
let s = mutual_info_score(&contingency_matrix(&v1, &v2));
assert!((0.3254 - s).abs() < 1e-4);
}
+219
View File
@@ -0,0 +1,219 @@
//! # Cosine Distance Metric
//!
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
//!
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
//!
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
//! are their respective magnitudes (Euclidean norms).
//!
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
//! greater angular separation.
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::cosine::Cosine;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
/// It is defined as 1 minus the cosine similarity of the vectors.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Cosine<T> {
_t: PhantomData<T>,
}
impl<T: Number> Default for Cosine<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number> Cosine<T> {
/// Instantiate the initial structure
pub fn new() -> Cosine<T> {
Cosine { _t: PhantomData }
}
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
#[inline]
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different.");
}
// Use the built-in dot product method from ArrayView1 trait
x.dot(y).to_f64().unwrap()
}
/// Calculate the squared magnitude (norm squared) of a vector
#[inline]
#[allow(dead_code)]
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
x.iterator(0)
.map(|&a| {
let val = a.to_f64().unwrap();
val * val
})
.sum()
}
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
#[inline]
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
// Use the built-in norm2 method from ArrayView1 trait
x.norm2()
}
/// Calculate cosine similarity between two vectors
#[inline]
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
let dot_product = Self::dot_product(x, y);
let magnitude_x = Self::magnitude(x);
let magnitude_y = Self::magnitude(y);
if magnitude_x == 0.0 || magnitude_y == 0.0 {
return f64::MIN;
}
dot_product / (magnitude_x * magnitude_y)
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
let similarity = Cosine::cosine_similarity(x, y);
1.0 - similarity
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_identical_vectors() {
let a = vec![1, 2, 3];
let b = vec![1, 2, 3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 0.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_orthogonal_vectors() {
let a = vec![1, 0];
let b = vec![0, 1];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_opposite_vectors() {
let a = vec![1, 2, 3];
let b = vec![-1, -2, -3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 2.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_general_case() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![2.0, 1.0, 3.0];
let dist: f64 = Cosine::new().distance(&a, &b);
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
let expected_dist = 1.0 - (13.0 / 14.0);
assert!((dist - expected_dist).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[should_panic(expected = "Input vector sizes are different.")]
fn cosine_distance_different_sizes() {
let a = vec![1, 2];
let b = vec![1, 2, 3];
let _dist: f64 = Cosine::new().distance(&a, &b);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_zero_vector() {
let a = vec![0, 0, 0];
let b = vec![1, 2, 3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!(dist > 1e300)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_float_precision() {
let a = vec![1.0f32, 2.0, 3.0];
let b = vec![4.0f32, 5.0, 6.0];
let dist: f64 = Cosine::new().distance(&a, &b);
// Calculate expected value manually
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
let expected_similarity = dot_product / (mag_a * mag_b);
let expected_distance = 1.0 - expected_similarity;
assert!((dist - expected_distance).abs() < 1e-6);
}
}
+92
View File
@@ -0,0 +1,92 @@
//! # Euclidian Metric Distance
//!
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
//!
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::euclidian::Euclidian;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l2: f64 = Euclidian::new().distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian<T> {
_t: PhantomData<T>,
}
impl<T: Number> Default for Euclidian<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number> Euclidian<T> {
/// instatiate the initial structure
pub fn new() -> Euclidian<T> {
Euclidian { _t: PhantomData }
}
/// return sum of squared distances
#[inline]
pub(crate) fn squared_distance<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different.");
}
let sum: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| {
let r = a - b;
(r * r).to_f64().unwrap()
})
.sum();
sum
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Euclidian<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
Euclidian::squared_distance(x, y).sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn squared_distance() {
let a = vec![1, 2, 3];
let b = vec![4, 5, 6];
let l2: f64 = Euclidian::new().distance(&a, &b);
assert!((l2 - 5.19615242).abs() < 1e-8);
}
}
@@ -6,13 +6,13 @@
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::hamming::Hamming;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::hamming::Hamming;
//!
//! let a = vec![1, 0, 0, 1, 0, 0, 1];
//! let b = vec![1, 1, 0, 0, 1, 0, 1];
//!
//! let h: f64 = Hamming {}.distance(&a, &b);
//! let h: f64 = Hamming::new().distance(&a, &b);
//!
//! ```
//!
@@ -21,30 +21,48 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use std::marker::PhantomData;
use super::Distance;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Hamming {}
pub struct Hamming<T: Number> {
_t: PhantomData<T>,
}
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
if x.len() != y.len() {
impl<T: Number> Hamming<T> {
/// instatiate the initial structure
pub fn new() -> Hamming<T> {
Hamming { _t: PhantomData }
}
}
impl<T: Number> Default for Hamming<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Hamming<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different");
}
let mut dist = 0;
for i in 0..x.len() {
if x[i] != y[i] {
dist += 1;
}
}
let dist: usize = x
.iterator(0)
.zip(y.iterator(0))
.map(|(a, b)| match a != b {
true => 1,
false => 0,
})
.sum();
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
dist as f64 / x.shape() as f64
}
}
@@ -52,13 +70,16 @@ impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn hamming_distance() {
let a = vec![1, 0, 0, 1, 0, 0, 1];
let b = vec![1, 1, 0, 0, 1, 0, 1];
let h: f64 = Hamming {}.distance(&a, &b);
let h: f64 = Hamming::new().distance(&a, &b);
assert!((h - 0.42857142).abs() < 1e-8);
}
@@ -14,9 +14,10 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::mahalanobis::Mahalanobis;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::ArrayView2;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::mahalanobis::Mahalanobis;
//!
//! let data = DenseMatrix::from_2d_array(&[
//! &[64., 580., 29.],
@@ -24,9 +25,9 @@
//! &[68., 590., 37.],
//! &[69., 660., 46.],
//! &[73., 600., 55.],
//! ]);
//! ]).unwrap();
//!
//! let a = data.column_mean();
//! let a = data.mean_by(0);
//! let b = vec![66., 640., 44.];
//!
//! let mahalanobis = Mahalanobis::new(&data);
@@ -42,85 +43,89 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use std::marker::PhantomData;
use super::Distance;
use crate::linalg::Matrix;
use crate::linalg::basic::arrays::{Array, Array2, ArrayView1};
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
/// Mahalanobis distance.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
pub struct Mahalanobis<T: Number, M: Array2<f64>> {
/// covariance matrix of the dataset
pub sigma: M,
/// inverse of the covariance matrix
pub sigmaInv: M,
t: PhantomData<T>,
_t: PhantomData<T>,
}
impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
impl<T: Number, M: Array2<f64> + LUDecomposable<f64>> Mahalanobis<T, M> {
/// Constructs new instance of `Mahalanobis` from given dataset
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
pub fn new(data: &M) -> Mahalanobis<T, M> {
let sigma = data.cov();
pub fn new<X: Array2<T>>(data: &X) -> Mahalanobis<T, M> {
let (_, m) = data.shape();
let mut sigma = M::zeros(m, m);
data.cov(&mut sigma);
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis {
sigma,
sigmaInv,
t: PhantomData,
_t: PhantomData,
}
}
/// Constructs new instance of `Mahalanobis` from given covariance matrix
/// * `cov` - a covariance matrix
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
pub fn new_from_covariance<X: Array2<f64> + LUDecomposable<f64>>(cov: &X) -> Mahalanobis<T, X> {
let sigma = cov.clone();
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis {
sigma,
sigmaInv,
t: PhantomData,
_t: PhantomData,
}
}
}
impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
impl<T: Number, A: ArrayView1<T>> Distance<A> for Mahalanobis<T, DenseMatrix<f64>> {
fn distance(&self, x: &A, y: &A) -> f64 {
let (nrows, ncols) = self.sigma.shape();
if x.len() != nrows {
if x.shape() != nrows {
panic!(
"Array x[{}] has different dimension with Sigma[{}][{}].",
x.len(),
x.shape(),
nrows,
ncols
);
}
if y.len() != nrows {
if y.shape() != nrows {
panic!(
"Array y[{}] has different dimension with Sigma[{}][{}].",
y.len(),
y.shape(),
nrows,
ncols
);
}
let n = x.len();
let mut z = vec![T::zero(); n];
for i in 0..n {
z[i] = x[i] - y[i];
}
let n = x.shape();
let z: Vec<f64> = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap())
.collect();
// np.dot(np.dot((a-b),VI),(a-b).T)
let mut s = T::zero();
let mut s = 0f64;
for j in 0..n {
for i in 0..n {
s += self.sigmaInv.get(i, j) * z[i] * z[j];
s += *self.sigmaInv.get((i, j)) * z[i] * z[j];
}
}
@@ -131,9 +136,13 @@ impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;
use crate::linalg::basic::arrays::ArrayView2;
use crate::linalg::basic::matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn mahalanobis_distance() {
let data = DenseMatrix::from_2d_array(&[
@@ -142,9 +151,10 @@ mod tests {
&[68., 590., 37.],
&[69., 660., 46.],
&[73., 600., 55.],
]);
])
.unwrap();
let a = data.column_mean();
let a = data.mean_by(0);
let b = vec![66., 640., 44.];
let mahalanobis = Mahalanobis::new(&data);
+82
View File
@@ -0,0 +1,82 @@
//! # Manhattan Distance
//!
//! The Manhattan distance between two points \\(x \in ^n \\) and \\( y \in ^n \\) in n-dimensional space is the sum of the distances in each dimension.
//!
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::manhattan::Manhattan;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Manhattan::new().distance(&x, &y);
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Manhattan distance
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan<T: Number> {
_t: PhantomData<T>,
}
impl<T: Number> Manhattan<T> {
/// instatiate the initial structure
pub fn new() -> Manhattan<T> {
Manhattan { _t: PhantomData }
}
}
impl<T: Number> Default for Manhattan<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Manhattan<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different");
}
let dist: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs())
.sum();
dist
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn manhattan_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Manhattan::new().distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
}
}
@@ -8,14 +8,14 @@
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::minkowski::Minkowski;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::minkowski::Minkowski;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Minkowski { p: 1 }.distance(&x, &y);
//! let l2: f64 = Minkowski { p: 2 }.distance(&x, &y);
//! let l1: f64 = Minkowski::new(1).distance(&x, &y);
//! let l2: f64 = Minkowski::new(2).distance(&x, &y);
//!
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -23,37 +23,47 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Defines the Minkowski distance of order `p`
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Minkowski {
pub struct Minkowski<T: Number> {
/// order, integer
pub p: u16,
_t: PhantomData<T>,
}
impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() {
impl<T: Number> Minkowski<T> {
/// instatiate the initial structure
pub fn new(p: u16) -> Minkowski<T> {
Minkowski { p, _t: PhantomData }
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Minkowski<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different");
}
if self.p < 1 {
panic!("p must be at least 1");
}
let mut dist = T::zero();
let p_t = T::from_u16(self.p).unwrap();
let p_t = self.p as f64;
for i in 0..x.len() {
let d = (x[i] - y[i]).abs();
dist += d.powf(p_t);
}
let dist: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs().powf(p_t))
.sum();
dist.powf(T::one() / p_t)
dist.powf(1f64 / p_t)
}
}
@@ -61,15 +71,18 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn minkowski_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Minkowski { p: 1 }.distance(&a, &b);
let l2: f64 = Minkowski { p: 2 }.distance(&a, &b);
let l3: f64 = Minkowski { p: 3 }.distance(&a, &b);
let l1: f64 = Minkowski::new(1).distance(&a, &b);
let l2: f64 = Minkowski::new(2).distance(&a, &b);
let l3: f64 = Minkowski::new(3).distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
assert!((l2 - 5.19615242).abs() < 1e-8);
@@ -82,6 +95,6 @@ mod tests {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let _: f64 = Minkowski { p: 0 }.distance(&a, &b);
let _: f64 = Minkowski::new(0).distance(&a, &b);
}
}
+118
View File
@@ -0,0 +1,118 @@
//! # Collection of Distance Functions
//!
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
//!
//! for all \\(x, y, z \in Z \\)
//!
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Cosine distance
pub mod cosine;
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
pub mod euclidian;
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
pub mod hamming;
/// The Mahalanobis distance is the distance between two points in multivariate space.
pub mod mahalanobis;
/// Also known as rectilinear distance, city block distance, taxicab metric.
pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use std::cmp::{Eq, Ordering, PartialOrd};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T>: Clone {
/// Calculates distance between _a_ and _b_
fn distance(&self, a: &T, b: &T) -> f64;
}
/// Multitude of distance metric functions
pub struct Distances {}
impl Distances {
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
pub fn euclidian<T: Number>() -> euclidian::Euclidian<T> {
euclidian::Euclidian::new()
}
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
/// * `p` - function order. Should be >= 1
pub fn minkowski<T: Number>(p: u16) -> minkowski::Minkowski<T> {
minkowski::Minkowski::new(p)
}
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
pub fn manhattan<T: Number>() -> manhattan::Manhattan<T> {
manhattan::Manhattan::new()
}
/// Hamming distance, see [`Hamming`](hamming/index.html)
pub fn hamming<T: Number>() -> hamming::Hamming<T> {
hamming::Hamming::new()
}
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
pub fn mahalanobis<T: Number, M: Array2<T>, C: Array2<f64> + LUDecomposable<f64>>(
data: &M,
) -> mahalanobis::Mahalanobis<T, C> {
mahalanobis::Mahalanobis::new(data)
}
}
///
/// ### Pairwise dissimilarities.
///
/// Representing distances as pairwise dissimilarities, so to build a
/// graph of closest neighbours. This representation can be reused for
/// different implementations
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distances as
/// a list of these structures.
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
/// index of the vector in the original `Matrix` or list
pub node: usize,
/// index of the closest neighbor in the original `Matrix` or same list
pub neighbour: Option<usize>,
/// measure of distance, according to the algorithm distance function
/// if the distance is None, the edge has value "infinite" or max distance
/// each algorithm has to match
pub distance: Option<T>,
}
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.neighbour == other.neighbour
&& self.distance == other.distance
}
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
+46 -16
View File
@@ -10,48 +10,71 @@
//!
//! ```
//! use smartcore::metrics::f1::F1;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
//! let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
//!
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
//! let beta = 1.0; // beta default is equal 1.0 anyway
//! let score: f64 = F1::new_with(beta).get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::metrics::precision::Precision;
use crate::metrics::recall::Recall;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::metrics::Metrics;
/// F-measure
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct F1<T: RealNumber> {
pub struct F1<T> {
/// a positive real factor
pub beta: T,
pub beta: f64,
_phantom: PhantomData<T>,
}
impl<T: RealNumber> F1<T> {
impl<T: Number + RealNumber + FloatNumber> Metrics<T> for F1<T> {
fn new() -> Self {
let beta: f64 = 1f64;
Self {
beta,
_phantom: PhantomData,
}
}
/// create a typed object to call Recall functions
fn new_with(beta: f64) -> Self {
Self {
beta,
_phantom: PhantomData,
}
}
/// Computes F1 score
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn get_score<V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let beta2 = self.beta * self.beta;
let p = Precision {}.get_score(y_true, y_pred);
let r = Recall {}.get_score(y_true, y_pred);
let p = Precision::new().get_score(y_true, y_pred);
let r = Recall::new().get_score(y_true, y_pred);
(T::one() + beta2) * (p * r) / (beta2 * p + r)
(1f64 + beta2) * (p * r) / ((beta2 * p) + r)
}
}
@@ -59,14 +82,21 @@ impl<T: RealNumber> F1<T> {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn f1() {
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
let score1: f64 = F1 { beta: 1.0 }.get_score(&y_pred, &y_true);
let score2: f64 = F1 { beta: 1.0 }.get_score(&y_true, &y_true);
let beta = 1.0;
let score1: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
let score2: f64 = F1::new_with(beta).get_score(&y_true, &y_true);
println!("{score1:?}");
println!("{score2:?}");
assert!((score1 - 0.57142857).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
+39 -16
View File
@@ -10,45 +10,65 @@
//!
//! ```
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! let mse: f64 = MeanAbsoluteError::new().get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Absolute Error
pub struct MeanAbsoluteError {}
pub struct MeanAbsoluteError<T> {
_phantom: PhantomData<T>,
}
impl MeanAbsoluteError {
impl<T: Number + FloatNumber> Metrics<T> for MeanAbsoluteError<T> {
/// create a typed object to call MeanAbsoluteError functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Computes mean absolute error
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let n = y_true.len();
let mut ras = T::zero();
let n = y_true.shape();
let mut ras: T = T::zero();
for i in 0..n {
ras += (y_true.get(i) - y_pred.get(i)).abs();
let res: T = *y_true.get(i) - *y_pred.get(i);
ras += res.abs();
}
ras / T::from_usize(n).unwrap()
ras.to_f64().unwrap() / n as f64
}
}
@@ -56,14 +76,17 @@ impl MeanAbsoluteError {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn mean_absolute_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
let score2: f64 = MeanAbsoluteError {}.get_score(&y_true, &y_true);
let score1: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
let score2: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_true);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 0.0).abs() < 1e-8);
+38 -15
View File
@@ -10,45 +10,65 @@
//!
//! ```
//! use smartcore::metrics::mean_squared_error::MeanSquareError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
//! let mse: f64 = MeanSquareError::new().get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Squared Error
pub struct MeanSquareError {}
pub struct MeanSquareError<T> {
_phantom: PhantomData<T>,
}
impl MeanSquareError {
impl<T: Number + FloatNumber> Metrics<T> for MeanSquareError<T> {
/// create a typed object to call MeanSquareError functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Computes mean squared error
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let n = y_true.len();
let n = y_true.shape();
let mut rss = T::zero();
for i in 0..n {
rss += (y_true.get(i) - y_pred.get(i)).square();
let res = *y_true.get(i) - *y_pred.get(i);
rss += res * res;
}
rss / T::from_usize(n).unwrap()
rss.to_f64().unwrap() / n as f64
}
}
@@ -56,14 +76,17 @@ impl MeanSquareError {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn mean_squared_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
let score2: f64 = MeanSquareError {}.get_score(&y_true, &y_true);
let score1: f64 = MeanSquareError::new().get_score(&y_true, &y_pred);
let score2: f64 = MeanSquareError::new().get_score(&y_true, &y_true);
assert!((score1 - 0.375).abs() < 1e-8);
assert!((score2 - 0.0).abs() < 1e-8);
+142 -64
View File
@@ -4,7 +4,7 @@
//! In a feedback loop you build your model first, then you get feedback from metrics, improve it and repeat until your model achieve desirable performance.
//! Evaluation metrics helps to explain the performance of a model and compare models based on an objective criterion.
//!
//! Choosing the right metric is crucial while evaluating machine learning models. In SmartCore you will find metrics for these classes of ML models:
//! Choosing the right metric is crucial while evaluating machine learning models. In `smartcore` you will find metrics for these classes of ML models:
//!
//! * [Classification metrics](struct.ClassificationMetrics.html)
//! * [Regression metrics](struct.RegressionMetrics.html)
@@ -12,7 +12,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::logistic_regression::LogisticRegression;
//! use smartcore::metrics::*;
//!
@@ -37,27 +37,30 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]);
//! let y: Vec<f64> = vec![
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ]).unwrap();
//! let y: Vec<i8> = vec![
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! ];
//!
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = lr.predict(&x).unwrap();
//!
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
//! let acc = ClassificationMetricsOrd::accuracy().get_score(&y, &y_hat);
//! // or
//! let acc = accuracy(&y, &y_hat);
//! ```
/// Accuracy score.
pub mod accuracy;
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
// TODO: reimplement AUC
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
pub mod auc;
/// Compute the homogeneity, completeness and V-Measure scores.
pub mod cluster_hcv;
pub(crate) mod cluster_helpers;
/// Multitude of distance metrics are defined here
pub mod distance;
/// F1 score, also known as balanced F-score or F-measure.
pub mod f1;
/// Mean absolute error regression loss.
@@ -71,150 +74,225 @@ pub mod r2;
/// Computes the recall.
pub mod recall;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use std::marker::PhantomData;
/// A trait to be implemented by all metrics
pub trait Metrics<T> {
/// instantiate a new Metrics trait-object
/// <https://doc.rust-lang.org/error-index.html#E0038>
fn new() -> Self
where
Self: Sized;
/// used to instantiate metric with a paramenter
fn new_with(_parameter: f64) -> Self
where
Self: Sized;
/// compute score realated to this metric
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64;
}
/// Use these metrics to compare classification models.
pub struct ClassificationMetrics {}
pub struct ClassificationMetrics<T> {
phantom: PhantomData<T>,
}
/// Use these metrics to compare classification models for
/// numbers that require `Ord`.
pub struct ClassificationMetricsOrd<T> {
phantom: PhantomData<T>,
}
/// Metrics for regression models.
pub struct RegressionMetrics {}
pub struct RegressionMetrics<T> {
phantom: PhantomData<T>,
}
/// Cluster metrics.
pub struct ClusterMetrics {}
impl ClassificationMetrics {
/// Accuracy score, see [accuracy](accuracy/index.html).
pub fn accuracy() -> accuracy::Accuracy {
accuracy::Accuracy {}
}
pub struct ClusterMetrics<T> {
phantom: PhantomData<T>,
}
impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
/// Recall, see [recall](recall/index.html).
pub fn recall() -> recall::Recall {
recall::Recall {}
pub fn recall() -> recall::Recall<T> {
recall::Recall::new()
}
/// Precision, see [precision](precision/index.html).
pub fn precision() -> precision::Precision {
precision::Precision {}
pub fn precision() -> precision::Precision<T> {
precision::Precision::new()
}
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
f1::F1 { beta }
pub fn f1(beta: f64) -> f1::F1<T> {
f1::F1::new_with(beta)
}
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
pub fn roc_auc_score() -> auc::AUC {
auc::AUC {}
pub fn roc_auc_score() -> auc::AUC<T> {
auc::AUC::<T>::new()
}
}
impl RegressionMetrics {
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
/// Accuracy score, see [accuracy](accuracy/index.html).
pub fn accuracy() -> accuracy::Accuracy<T> {
accuracy::Accuracy::new()
}
}
impl<T: Number + FloatNumber> RegressionMetrics<T> {
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
mean_squared_error::MeanSquareError {}
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError<T> {
mean_squared_error::MeanSquareError::new()
}
/// Mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
mean_absolute_error::MeanAbsoluteError {}
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError<T> {
mean_absolute_error::MeanAbsoluteError::new()
}
/// Coefficient of determination (R2), see [R2](r2/index.html).
pub fn r2() -> r2::R2 {
r2::R2 {}
pub fn r2() -> r2::R2<T> {
r2::R2::<T>::new()
}
}
impl ClusterMetrics {
impl<T: Number + Ord> ClusterMetrics<T> {
/// Homogeneity and completeness and V-Measure scores at once.
pub fn hcv_score() -> cluster_hcv::HCVScore {
cluster_hcv::HCVScore {}
pub fn hcv_score() -> cluster_hcv::HCVScore<T> {
cluster_hcv::HCVScore::<T>::new()
}
}
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
/// * `y_true` - cround truth (correct) labels
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn accuracy<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
pub fn accuracy<T: Number + Ord, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
let obj = ClassificationMetricsOrd::<T>::accuracy();
obj.get_score(y_true, y_pred)
}
/// Calculated recall score, see [recall](recall/index.html)
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn recall<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::recall().get_score(y_true, y_pred)
pub fn recall<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::recall();
obj.get_score(y_true, y_pred)
}
/// Calculated precision score, see [precision](precision/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn precision<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::precision().get_score(y_true, y_pred)
pub fn precision<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::precision();
obj.get_score(y_true, y_pred)
}
/// Computes F1 score, see [F1](f1/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V, beta: T) -> T {
ClassificationMetrics::f1(beta).get_score(y_true, y_pred)
pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
beta: f64,
) -> f64 {
let obj = ClassificationMetrics::<T>::f1(beta);
obj.get_score(y_true, y_pred)
}
/// AUC score, see [AUC](auc/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
pub fn roc_auc_score<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
pub fn roc_auc_score<
T: Number + RealNumber + FloatNumber + PartialOrd,
V: ArrayView1<T> + Array1<T> + Array1<T>,
>(
y_true: &V,
y_pred_probabilities: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::roc_auc_score();
obj.get_score(y_true, y_pred_probabilities)
}
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn mean_squared_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
pub fn mean_squared_error<T: Number + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
RegressionMetrics::<T>::mean_squared_error().get_score(y_true, y_pred)
}
/// Computes mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
pub fn mean_absolute_error<T: Number + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
RegressionMetrics::<T>::mean_absolute_error().get_score(y_true, y_pred)
}
/// Computes R2 score, see [R2](r2/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::r2().get_score(y_true, y_pred)
pub fn r2<T: Number + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
RegressionMetrics::<T>::r2().get_score(y_true, y_pred)
}
/// Homogeneity metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
/// A cluster result satisfies homogeneity if all of its clusters contain only data points which are members of a single class.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn homogeneity_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.0
pub fn homogeneity_score<
T: Number + FloatNumber + RealNumber + Ord,
V: ArrayView1<T> + Array1<T>,
>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.homogeneity().unwrap()
}
///
/// Completeness metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn completeness_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.1
pub fn completeness_score<
T: Number + FloatNumber + RealNumber + Ord,
V: ArrayView1<T> + Array1<T>,
>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.completeness().unwrap()
}
/// The harmonic mean between homogeneity and completeness.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn v_measure_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.2
pub fn v_measure_score<T: Number + FloatNumber + RealNumber + Ord, V: ArrayView1<T> + Array1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.v_measure().unwrap()
}
+145 -37
View File
@@ -4,72 +4,123 @@
//!
//! \\[precision = \frac{tp}{tp + fp}\\]
//!
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::precision::Precision;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//!
//! let score: f64 = Precision {}.get_score(&y_pred, &y_true);
//! let score: f64 = Precision::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::realnum::RealNumber;
use crate::metrics::Metrics;
/// Precision metric.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Precision {}
pub struct Precision<T> {
_phantom: PhantomData<T>,
}
impl Precision {
impl<T: RealNumber> Metrics<T> for Precision<T> {
/// create a typed object to call Precision functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Calculated precision score
/// * `y_true` - cround truth (correct) labels.
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let mut tp = 0;
let mut p = 0;
let n = y_true.len();
let n = y_true.shape();
let mut classes_set: HashSet<u64> = HashSet::new();
for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_true.get(i)
);
}
classes_set.insert(y_true.get(i).to_f64_bits());
}
let classes: usize = classes_set.len();
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_pred.get(i)
);
}
if y_pred.get(i) == T::one() {
p += 1;
if y_true.get(i) == T::one() {
tp += 1;
if classes == 2 {
// Binary case: precision for positive class (assumed T::one())
let positive = T::one();
let mut tp: usize = 0;
let mut fp_count: usize = 0;
for i in 0..n {
let t = *y_true.get(i);
let p = *y_pred.get(i);
if p == t {
if t == positive {
tp += 1;
}
} else if t != positive {
fp_count += 1;
}
}
if tp + fp_count == 0 {
0.0
} else {
tp as f64 / (tp + fp_count) as f64
}
} else {
// Multiclass case: macro-averaged precision
let mut predicted: HashMap<u64, usize> = HashMap::new();
let mut tp_map: HashMap<u64, usize> = HashMap::new();
for i in 0..n {
let p_bits = y_pred.get(i).to_f64_bits();
*predicted.entry(p_bits).or_insert(0) += 1;
if *y_true.get(i) == *y_pred.get(i) {
*tp_map.entry(p_bits).or_insert(0) += 1;
}
}
let mut precision_sum = 0.0;
for &bits in &classes_set {
let pred_count = *predicted.get(&bits).unwrap_or(&0);
let tp = *tp_map.get(&bits).unwrap_or(&0);
let prec = if pred_count > 0 {
tp as f64 / pred_count as f64
} else {
0.0
};
precision_sum += prec;
}
if classes == 0 {
0.0
} else {
precision_sum / classes as f64
}
}
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
}
}
@@ -77,16 +128,73 @@ impl Precision {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1.];
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
assert!((score3 - 0.5).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass() {
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);
assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_imbalanced() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
let expected = (0.5 + 0.5 + 1.0) / 3.0;
assert!((score - expected).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_unpredicted_class() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
// Class 1: pred=2, tp=1 -> 0.5
// Class 2: pred=2, tp=2 -> 1.0
// Class 3: pred=0, tp=0 -> 0.0
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
assert!((score - expected).abs() < 1e-8);
}
}
+40 -26
View File
@@ -10,59 +10,70 @@
//!
//! ```
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! let mse: f64 = MeanAbsoluteError::new().get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::metrics::Metrics;
/// Coefficient of Determination (R2)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct R2 {}
pub struct R2<T> {
_phantom: PhantomData<T>,
}
impl R2 {
impl<T: Number> Metrics<T> for R2<T> {
/// create a typed object to call R2 functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Computes R2 score
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let n = y_true.len();
let mut mean = T::zero();
for i in 0..n {
mean += y_true.get(i);
}
mean /= T::from_usize(n).unwrap();
let n = y_true.shape();
let mean: f64 = y_true.mean_by();
let mut ss_tot = T::zero();
let mut ss_res = T::zero();
for i in 0..n {
let y_i = y_true.get(i);
let f_i = y_pred.get(i);
ss_tot += (y_i - mean).square();
ss_res += (y_i - f_i).square();
let y_i = *y_true.get(i);
let f_i = *y_pred.get(i);
ss_tot += (y_i - T::from(mean).unwrap()) * (y_i - T::from(mean).unwrap());
ss_res += (y_i - f_i) * (y_i - f_i);
}
T::one() - (ss_res / ss_tot)
(T::one() - ss_res / ss_tot).to_f64().unwrap()
}
}
@@ -70,14 +81,17 @@ impl R2 {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn r2() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = R2 {}.get_score(&y_true, &y_pred);
let score2: f64 = R2 {}.get_score(&y_true, &y_true);
let score1: f64 = R2::new().get_score(&y_true, &y_pred);
let score2: f64 = R2::new().get_score(&y_true, &y_true);
assert!((score1 - 0.948608137).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
+121 -37
View File
@@ -4,72 +4,117 @@
//!
//! \\[recall = \frac{tp}{tp + fn}\\]
//!
//! where tp (true positive) - correct result, fn (false negative) - missing result
//! where tp (true positive) - correct result, fn (false negative) - missing result.
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::recall::Recall;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//!
//! let score: f64 = Recall {}.get_score(&y_pred, &y_true);
//! let score: f64 = Recall::new().get_score( &y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::realnum::RealNumber;
use crate::metrics::Metrics;
/// Recall metric.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Recall {}
pub struct Recall<T> {
_phantom: PhantomData<T>,
}
impl Recall {
impl<T: RealNumber> Metrics<T> for Recall<T> {
/// create a typed object to call Recall functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
/// Calculated recall score
/// * `y_true` - cround truth (correct) labels.
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.len(),
y_pred.len()
y_true.shape(),
y_pred.shape()
);
}
let mut tp = 0;
let mut p = 0;
let n = y_true.len();
let n = y_true.shape();
let mut classes_set = HashSet::new();
for i in 0..n {
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!(
"Recall can only be applied to binary classification: {}",
y_true.get(i)
);
}
classes_set.insert(y_true.get(i).to_f64_bits());
}
let classes: usize = classes_set.len();
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
panic!(
"Recall can only be applied to binary classification: {}",
y_pred.get(i)
);
}
if y_true.get(i) == T::one() {
p += 1;
if y_pred.get(i) == T::one() {
tp += 1;
if classes == 2 {
// Binary case: recall for positive class (assumed T::one())
let positive = T::one();
let mut tp: usize = 0;
let mut fn_count: usize = 0;
for i in 0..n {
let t = *y_true.get(i);
let p = *y_pred.get(i);
if p == t {
if t == positive {
tp += 1;
}
} else if t == positive {
fn_count += 1;
}
}
if tp + fn_count == 0 {
0.0
} else {
tp as f64 / (tp + fn_count) as f64
}
} else {
// Multiclass case: macro-averaged recall
let mut support: HashMap<u64, usize> = HashMap::new();
let mut tp_map: HashMap<u64, usize> = HashMap::new();
for i in 0..n {
let t_bits = y_true.get(i).to_f64_bits();
*support.entry(t_bits).or_insert(0) += 1;
if *y_true.get(i) == *y_pred.get(i) {
*tp_map.entry(t_bits).or_insert(0) += 1;
}
}
let mut recall_sum = 0.0;
for (&bits, &sup) in &support {
let tp = *tp_map.get(&bits).unwrap_or(&0);
recall_sum += tp as f64 / sup as f64;
}
if support.is_empty() {
0.0
} else {
recall_sum / support.len() as f64
}
}
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
}
}
@@ -77,16 +122,55 @@ impl Recall {
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn recall() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1.];
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
let score1: f64 = Recall::new().get_score(&y_true, &y_pred);
let score2: f64 = Recall::new().get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn recall_multiclass() {
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
let score1: f64 = Recall::new().get_score(&y_true, &y_pred);
let score2: f64 = Recall::new().get_score(&y_pred, &y_pred);
assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn recall_multiclass_imbalanced() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
assert!((score - expected).abs() < 1e-8);
}
}

Some files were not shown because too many files have changed in this diff Show More