Compare commits
112 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70212c71e0 | ||
|
|
63f86f7bc9 | ||
|
|
e633afa520 | ||
|
|
b6e32fb328 | ||
|
|
948d78a4d0 | ||
|
|
448b6f77e3 | ||
|
|
09be4681cf | ||
|
|
4841791b7e | ||
|
|
9fef05ecc6 | ||
|
|
c5816b0e1b | ||
|
|
5cc5528367 | ||
|
|
d459c48372 | ||
|
|
730c0d64df | ||
|
|
44424807a0 | ||
|
|
76d1ef610d | ||
|
|
4092e24c2a | ||
|
|
17dc9f3bbf | ||
|
|
c8ec8fec00 | ||
|
|
3da433f757 | ||
|
|
4523ac73ff | ||
|
|
ba75f9ffad | ||
|
|
239c00428f | ||
|
|
80a93c1a0e | ||
|
|
4eadd16ce4 | ||
|
|
886b5631b7 | ||
|
|
9c07925d8a | ||
|
|
6f22bbd150 | ||
|
|
dbdc2b2a77 | ||
|
|
2d7c055154 | ||
|
|
545ed6ce2b | ||
|
|
8939ed93b9 | ||
|
|
9cd7348403 | ||
|
|
d52830a818 | ||
|
|
d15ea43975 | ||
|
|
f498f9629e | ||
|
|
7d059c4fb1 | ||
|
|
c7353d0b57 | ||
|
|
83dcf9a8ac | ||
|
|
3126ee87d3 | ||
|
|
8efb959b3c | ||
|
|
9eaae9ef35 | ||
|
|
46b6285d05 | ||
|
|
c683073b14 | ||
|
|
161d249917 | ||
|
|
4558be5f73 | ||
|
|
6c03e6e0b3 | ||
|
|
c934f6b6cf | ||
|
|
48f1d6b74d | ||
|
|
dad0d01f6d | ||
|
|
98b18c4dae | ||
|
|
2418b24ff4 | ||
|
|
6c6f92697f | ||
|
|
a4097fce15 | ||
|
|
b71c7b49cb | ||
|
|
78bf75b5d8 | ||
|
|
a60fdaf235 | ||
|
|
b4206c4b08 | ||
|
|
3c4a807be8 | ||
|
|
c1af60cafb | ||
|
|
2fa454ea94 | ||
|
|
8e6e5f9e68 | ||
|
|
bf7b714126 | ||
|
|
3ac6598951 | ||
|
|
cc91e31a0e | ||
|
|
0ec89402e8 | ||
|
|
23b3699730 | ||
|
|
aab3817c58 | ||
|
|
d3a496419d | ||
|
|
ab18f127a0 | ||
|
|
425c3c1d0b | ||
|
|
35fe68e024 | ||
|
|
d592b628be | ||
|
|
b66afa9222 | ||
|
|
ba70bb941f | ||
|
|
d298709040 | ||
|
|
e50b4e8637 | ||
|
|
26b72b67f4 | ||
|
|
1964424589 | ||
|
|
deac31a2ab | ||
|
|
4cff7da50d | ||
|
|
df0ae907f7 | ||
|
|
cfbd45bfc0 | ||
|
|
b60329ca5d | ||
|
|
4b096ad558 | ||
|
|
4cf7e4d7b7 | ||
|
|
c3093f11f1 | ||
|
|
083803c900 | ||
|
|
4f64f2e0ff | ||
|
|
52eb6ce023 | ||
|
|
bb71656137 | ||
|
|
edbac7e4c7 | ||
|
|
8a2bdd5a75 | ||
|
|
b823b55460 | ||
|
|
12df301f32 | ||
|
|
f8210d0af9 | ||
|
|
3c62686d6e | ||
|
|
9c59e37a0f | ||
|
|
0b619fe7eb | ||
|
|
764309e313 | ||
|
|
403d3f2348 | ||
|
|
3a44161406 | ||
|
|
48514d1b15 | ||
|
|
69d8be35de | ||
|
|
c21e75276a | ||
|
|
6a2e10452f | ||
|
|
436da104d7 | ||
|
|
2510ca4e9d | ||
|
|
b6f585e60f | ||
|
|
4685fc73e0 | ||
|
|
2e5f88fad8 | ||
|
|
e445f0d558 | ||
|
|
4d5f64c758 |
@@ -0,0 +1,6 @@
|
||||
# These owners will be the default owners for everything in
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# Developers in this list will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @morenol
|
||||
* @Mec-iS
|
||||
@@ -0,0 +1,22 @@
|
||||
# Code of Conduct
|
||||
|
||||
As contributors and maintainers of this project, and in the interest of fostering an open and welcoming community, we pledge to respect all people who contribute through reporting issues, posting feature requests, updating documentation, submitting pull requests or patches, and other activities.
|
||||
|
||||
We are committed to making participation in this project a harassment-free experience for everyone, regardless of level of experience, gender, gender identity and expression, sexual orientation, disability, personal appearance, body size, race, ethnicity, age, religion, or nationality.
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery
|
||||
* Personal attacks
|
||||
* Trolling or insulting/derogatory comments
|
||||
* Public or private harassment
|
||||
* Publishing other's private information, such as physical or electronic addresses, without explicit permission
|
||||
* Other unethical or unprofessional conduct.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct. By adopting this Code of Conduct, project maintainers commit themselves to fairly and consistently applying these principles to every aspect of managing this project. Project maintainers who do not follow or enforce the Code of Conduct may be permanently removed from the project team.
|
||||
|
||||
This code of conduct applies both within project spaces and in public spaces when an individual is representing the project or its community.
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by opening an issue or contacting one or more of the project maintainers.
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/)
|
||||
@@ -0,0 +1,72 @@
|
||||
# **Contributing**
|
||||
|
||||
When contributing to this repository, please first discuss the change you wish to make via issue,
|
||||
email, or any other method with the owners of this repository before making a change.
|
||||
|
||||
Please note we have a [code of conduct](CODE_OF_CONDUCT.md), please follow it in all your interactions with the project.
|
||||
|
||||
## Background
|
||||
|
||||
We try to follow these principles:
|
||||
* follow as much as possible the sklearn API to give a frictionless user experience for practitioners already familiar with it
|
||||
* use only pure-Rust implementations for safety and future-proofing (with some low-level limited exceptions)
|
||||
* do not use macros in the library code to allow readability and transparent behavior
|
||||
* priority is not on "big data" dataset, try to be fast for small/average dataset with limited memory footprint.
|
||||
|
||||
## Pull Request Process
|
||||
|
||||
1. Open a PR following the template (erase the part of the template you don't need).
|
||||
2. Update the CHANGELOG.md with details of changes to the interface if they are breaking changes, this includes new environment variables, exposed ports useful file locations and container parameters.
|
||||
3. Pull Request can be merged in once you have the sign-off of one other developer, or if you do not have permission to do that you may request the reviewer to merge it for you.
|
||||
|
||||
### generic guidelines
|
||||
Take a look to the conventions established by existing code:
|
||||
* Every module should come with some reference to scientific literature that allows relating the code to research. Use the `//!` comments at the top of the module to tell readers about the basics of the procedure you are implementing.
|
||||
* Every module should provide a Rust doctest, a brief test embedded with the documentation that explains how to use the procedure implemented.
|
||||
* Every module should provide comprehensive tests at the end, in its `mod tests {}` sub-module. These tests can be flagged or not with configuration flags to allow WebAssembly target.
|
||||
* Run `cargo doc --no-deps --open` and read the generated documentation in the browser to be sure that your changes reflects in the documentation and new code is documented.
|
||||
|
||||
#### digging deeper
|
||||
* a nice overview of the codebase is given by [static analyzer](https://mozilla.github.io/rust-code-analysis/metrics.html):
|
||||
```
|
||||
$ cargo install rust-code-analysis-cli
|
||||
// print metrics for every module
|
||||
$ rust-code-analysis-cli -m -O json -o . -p src/ --pr
|
||||
// print full AST for a module
|
||||
$ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 -d > ast.txt
|
||||
```
|
||||
* find more information about what happens in your binary with [`twiggy`](https://rustwasm.github.io/twiggy/install.html). This need a compiled binary so create a brief `main {}` function using `smartcore` and then point `twiggy` to that file.
|
||||
|
||||
* Please take a look to the output of a profiler to spot most evident performance problems, see [this guide about using a profiler](http://www.codeofview.com/fix-rs/2017/01/24/how-to-optimize-rust-programs-on-linux/).
|
||||
|
||||
## Issue Report Process
|
||||
|
||||
1. Go to the project's issues.
|
||||
2. Select the template that better fits your issue.
|
||||
3. Read carefully the instructions and write within the template guidelines.
|
||||
4. Submit it and wait for support.
|
||||
|
||||
## Reviewing process
|
||||
|
||||
1. After a PR is opened maintainers are notified
|
||||
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
|
||||
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
|
||||
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
|
||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||
* **Testing**: multiple test pipelines are run for different targets
|
||||
3. When everything is OK, code is merged.
|
||||
|
||||
|
||||
## Contribution Best Practices
|
||||
|
||||
* Read this [how-to about Github workflow here](https://guides.github.com/introduction/flow/) if you are not familiar with.
|
||||
|
||||
* Read all the texts related to [contributing for an OS community](https://github.com/HTTP-APIs/hydrus/tree/master/.github).
|
||||
|
||||
* Read this [how-to about writing a PR](https://github.com/blog/1943-how-to-write-the-perfect-pull-request) and this [other how-to about writing a issue](https://wiredcraft.com/blog/how-we-write-our-github-issues/)
|
||||
|
||||
* **read history**: search past open or closed issues for your problem before opening a new issue.
|
||||
|
||||
* **PRs on develop**: any change should be PRed first in `development`
|
||||
|
||||
* **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment.
|
||||
@@ -0,0 +1,43 @@
|
||||
# smartcore: Introduction to modules
|
||||
|
||||
Important source of information:
|
||||
* [Rust API guidelines](https://rust-lang.github.io/api-guidelines/about.html)
|
||||
|
||||
## Walkthrough: traits system and basic structures
|
||||
|
||||
#### numbers
|
||||
The library is founded on basic traits provided by `num-traits`. Basic traits are in `src/numbers`. These traits are used to define all the procedures in the library to make everything safer and provide constraints to what implementations can handle.
|
||||
|
||||
#### linalg
|
||||
`numbers` are made at use in linear algebra structures in the **`src/linalg/basic`** module. These sub-modules define the traits used all over the code base.
|
||||
|
||||
* *arrays*: In particular data structures like `Array`, `Array1` (1-dimensional), `Array2` (matrix, 2-D); plus their "views" traits. Views are used to provide no-footprint access to data, they have composed traits to allow writing (mutable traits: `MutArray`, `ArrayViewMut`, ...).
|
||||
* *matrix*: This provides the main entrypoint to matrices operations and currently the only structure provided in the shape of `struct DenseMatrix`. A matrix can be instantiated and automatically make available all the traits in "arrays" (sparse matrices implementation will be provided).
|
||||
* *vector*: Convenience traits are implemented for `std::Vec` to allow extensive reuse.
|
||||
|
||||
These are all traits and by definition they do not allow instantiation. For instantiable structures see implementation like `DenseMatrix` with relative constructor.
|
||||
|
||||
#### linalg/traits
|
||||
The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. These traits are used to specify characteristics and constraints for types accepted by various algorithms. For example these allow to define if a matrix is `QRDecomposable` and/or `SVDDecomposable`. See docstring for referencese to theoretical framework.
|
||||
|
||||
As above these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation for Linear Regression requires the input data `X` to be in `smartcore`'s trait system `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`, a 2-D matrix that is both QR and SVD decomposable; that is what the provided strucure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {};impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
|
||||
|
||||
#### metrics
|
||||
Implementations for metrics (classification, regression, cluster, ...) and distance measure (Euclidean, Hamming, Manhattan, ...). For example: `Accuracy`, `F1`, `AUC`, `Precision`, `R2`. As everything else in the code base, these implementations reuse `numbers` and `linalg` traits and structures.
|
||||
|
||||
These are collected in structures like `pub struct ClassificationMetrics<T> {}` that implements `metrics::Metrics`, these are groups of functions (classification, regression, cluster, ...) that provide instantiation for the structures. Each of those instantiation can be passed around using the relative function, like `pub fn accuracy<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> T`. This provides a mechanism for metrics to be passed to higher interfaces like the `cross_validate`:
|
||||
```rust
|
||||
let results =
|
||||
cross_validate(
|
||||
BiasedEstimator::new(), // custom estimator
|
||||
&x, &y, // input data
|
||||
NoParameters {}, // extra parameters
|
||||
cv, // type of cross validator
|
||||
&accuracy // **metrics function** <--------
|
||||
).unwrap();
|
||||
```
|
||||
|
||||
TODO: complete for all modules
|
||||
|
||||
## Notebooks
|
||||
Proceed to the [**notebooks**](https://github.com/smartcorelib/smartcore-jupyter/) to see these modules in action.
|
||||
@@ -0,0 +1,25 @@
|
||||
### I'm submitting a
|
||||
- [ ] bug report.
|
||||
- [ ] improvement.
|
||||
- [ ] feature request.
|
||||
|
||||
### Current Behaviour:
|
||||
<!-- Describe about the bug -->
|
||||
|
||||
### Expected Behaviour:
|
||||
<!-- Describe what will happen if bug is removed -->
|
||||
|
||||
### Steps to reproduce:
|
||||
<!-- If you can then please provide the steps to reproduce the bug -->
|
||||
|
||||
### Snapshot:
|
||||
<!-- If you can then please provide the screenshot of the issue you are facing -->
|
||||
|
||||
### Environment:
|
||||
<!-- Please provide the following environment details if relevant -->
|
||||
* rustc version
|
||||
* cargo version
|
||||
* OS details
|
||||
|
||||
### Do you want to work on this issue?
|
||||
<!-- yes/no -->
|
||||
@@ -0,0 +1,29 @@
|
||||
<!-- Please create (if there is not one yet) a issue before sending a PR -->
|
||||
<!-- Add issue number (Eg: fixes #123) -->
|
||||
<!-- Always provide changes in existing tests or new tests -->
|
||||
|
||||
Fixes #
|
||||
|
||||
### Checklist
|
||||
- [ ] My branch is up-to-date with development branch.
|
||||
- [ ] Everything works and tested on latest stable Rust.
|
||||
- [ ] Coverage and Linting have been applied
|
||||
|
||||
### Current behaviour
|
||||
<!-- Describe the code you are going to change and its behaviour -->
|
||||
|
||||
### New expected behaviour
|
||||
<!-- Describe the new code and its expected behaviour -->
|
||||
|
||||
### Change logs
|
||||
|
||||
<!-- #### Added -->
|
||||
<!-- Edit these points below to describe the new features added with this PR -->
|
||||
<!-- - Feature 1 -->
|
||||
<!-- - Feature 2 -->
|
||||
|
||||
|
||||
<!-- #### Changed -->
|
||||
<!-- Edit these points below to describe the changes made in existing functionality with this PR -->
|
||||
<!-- - Change 1 -->
|
||||
<!-- - Change 1 -->
|
||||
+53
-16
@@ -2,35 +2,36 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, development ]
|
||||
branches: [main, development]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
branches: [development]
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
]
|
||||
platform:
|
||||
[
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
]
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
@@ -40,12 +41,17 @@ jobs:
|
||||
default: true
|
||||
- name: Install test runner for wasm
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Stable Build
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Stable Build with all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --all-features --target ${{ matrix.platform.target }}
|
||||
- name: Stable Build without features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --target ${{ matrix.platform.target }}
|
||||
- name: Tests
|
||||
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||
uses: actions-rs/cargo@v1
|
||||
@@ -55,3 +61,34 @@ jobs:
|
||||
- name: Tests in WASM
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: wasm-pack test --node -- --all-features
|
||||
|
||||
check_features:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [{ os: "ubuntu" }]
|
||||
features: ["--features serde", "--features datasets", ""]
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
target: ${{ matrix.platform.target }}
|
||||
profile: minimal
|
||||
default: true
|
||||
- name: Stable Build
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --no-default-features ${{ matrix.features }}
|
||||
|
||||
@@ -12,9 +12,9 @@ jobs:
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
@@ -39,6 +39,6 @@ jobs:
|
||||
command: tarpaulin
|
||||
args: --out Lcov --all-features -- --test-threads 1
|
||||
- name: Upload to codecov.io
|
||||
uses: codecov/codecov-action@v1
|
||||
uses: codecov/codecov-action@v2
|
||||
with:
|
||||
fail_ci_if_error: true
|
||||
fail_ci_if_error: false
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
|
||||
+12
@@ -17,3 +17,15 @@ smartcore.code-workspace
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
|
||||
flamegraph.svg
|
||||
perf.data
|
||||
perf.data.old
|
||||
src.dot
|
||||
out.svg
|
||||
|
||||
FlameGraph/
|
||||
out.stacks
|
||||
*.json
|
||||
*.txt
|
||||
+29
-1
@@ -4,7 +4,35 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.4.0] - 2023-04-05
|
||||
|
||||
## Added
|
||||
- WARNING: Breaking changes!
|
||||
- `DenseMatrix` constructor now returns `Result` to avoid user instantiating inconsistent rows/cols count. Their return values need to be unwrapped with `unwrap()`, see tests
|
||||
|
||||
## [0.3.0] - 2022-11-09
|
||||
|
||||
## Added
|
||||
- WARNING: Breaking changes!
|
||||
- Complete refactoring with **extensive API changes** that includes:
|
||||
* moving to a new traits system, less structs more traits
|
||||
* adapting all the modules to the new traits system
|
||||
* moving to Rust 2021, use of object-safe traits and `as_ref`
|
||||
* reorganization of the code base, eliminate duplicates
|
||||
- implements `readers` (needs "serde" feature) for read/write CSV file, extendible to other formats
|
||||
- default feature is now Wasm-/Wasi-first
|
||||
|
||||
## Changed
|
||||
- WARNING: Breaking changes!
|
||||
- Seeds to multiple algorithims that depend on random number generation
|
||||
- Added a new parameter to `train_test_split` to define the seed
|
||||
- changed use of "serde" feature
|
||||
|
||||
## Dropped
|
||||
- WARNING: Breaking changes!
|
||||
- Drop `nalgebra-bindings` feature, only `ndarray` as supported library
|
||||
|
||||
## [0.2.1] - 2021-05-10
|
||||
|
||||
## Added
|
||||
- L2 regularization penalty to the Logistic Regression
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
cff-version: 1.2.0
|
||||
message: "If this software contributes to published work, please cite smartcore."
|
||||
type: software
|
||||
title: "smartcore: Machine Learning in Rust"
|
||||
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration." [web:5][web:3]
|
||||
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
|
||||
url: "https://github.com/smartcorelib" [web:3]
|
||||
license: "MIT" [web:13]
|
||||
keywords:
|
||||
- Rust
|
||||
- machine learning
|
||||
- numerical computing
|
||||
- linear algebra
|
||||
- classification
|
||||
- regression
|
||||
- clustering
|
||||
- SVM
|
||||
- Random Forest
|
||||
- XGBoost [web:5]
|
||||
authors:
|
||||
- name: "smartcore Developers" [web:7]
|
||||
- name: "Lorenzo (contributor)" [web:16]
|
||||
- name: "Community contributors" [web:7]
|
||||
version: "0.4.2" [attached_file:1]
|
||||
date-released: "2025-09-14" [attached_file:1]
|
||||
preferred-citation:
|
||||
type: software
|
||||
title: "smartcore: Machine Learning in Rust"
|
||||
authors:
|
||||
- name: "smartcore Developers" [web:7]
|
||||
url: "https://github.com/smartcorelib" [web:3]
|
||||
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
|
||||
license: "MIT" [web:13]
|
||||
references:
|
||||
- type: manual
|
||||
title: "smartcore Documentation"
|
||||
url: "https://docs.rs/smartcore" [web:5]
|
||||
- type: webpage
|
||||
title: "smartcore Homepage"
|
||||
url: "https://github.com/smartcorelib" [web:3]
|
||||
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost." [web:5]
|
||||
+43
-32
@@ -1,55 +1,66 @@
|
||||
[package]
|
||||
name = "smartcore"
|
||||
description = "The most advanced machine learning library in rust."
|
||||
description = "Machine Learning in Rust."
|
||||
homepage = "https://smartcorelib.org"
|
||||
version = "0.2.1"
|
||||
authors = ["SmartCore Developers"]
|
||||
edition = "2018"
|
||||
version = "0.4.5"
|
||||
authors = ["smartcore Developers"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
documentation = "https://docs.rs/smartcore"
|
||||
repository = "https://github.com/smartcorelib/smartcore"
|
||||
readme = "README.md"
|
||||
keywords = ["machine-learning", "statistical", "ai", "optimization", "linear-algebra"]
|
||||
categories = ["science"]
|
||||
|
||||
[features]
|
||||
default = ["datasets"]
|
||||
ndarray-bindings = ["ndarray"]
|
||||
nalgebra-bindings = ["nalgebra"]
|
||||
datasets = []
|
||||
fp_bench = []
|
||||
exclude = [
|
||||
".github",
|
||||
".gitignore",
|
||||
"smartcore.iml",
|
||||
"smartcore.svg",
|
||||
"tests/"
|
||||
]
|
||||
|
||||
[dependencies]
|
||||
approx = "0.5.1"
|
||||
cfg-if = "1.0.0"
|
||||
ndarray = { version = "0.15", optional = true }
|
||||
nalgebra = { version = "0.31", optional = true }
|
||||
num-traits = "0.2"
|
||||
num-traits = "0.2.12"
|
||||
num = "0.4"
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
||||
rand_distr = { version = "0.4", optional = true }
|
||||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
itertools = "0.10.3"
|
||||
ordered-float = "5.1.0"
|
||||
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
typetag = { version = "0.2", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
serde = ["dep:serde", "dep:typetag"]
|
||||
ndarray-bindings = ["dep:ndarray"]
|
||||
datasets = ["dep:rand_distr", "std_rand", "serde"]
|
||||
std_rand = ["rand/std_rng", "rand/std"]
|
||||
# used by wasm32-unknown-unknown for in-browser usage
|
||||
js = ["getrandom/js"]
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
getrandom = { version = "0.2.8", optional = true }
|
||||
|
||||
[target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
itertools = "0.13.0"
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.1"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
[workspace]
|
||||
|
||||
[[bench]]
|
||||
name = "distance"
|
||||
harness = false
|
||||
[profile.test]
|
||||
debug = 1
|
||||
opt-level = 3
|
||||
|
||||
[[bench]]
|
||||
name = "naive_bayes"
|
||||
harness = false
|
||||
required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
||||
|
||||
[[bench]]
|
||||
name = "fastpair"
|
||||
harness = false
|
||||
required-features = ["fp_bench"]
|
||||
[profile.release]
|
||||
strip = true
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
overflow-checks = true
|
||||
|
||||
@@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2019-present at smartcore developers (smartcorelib.org)
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,18 +1,147 @@
|
||||
<p align="center">
|
||||
<a href="https://smartcorelib.org">
|
||||
<img src="smartcore.svg" width="450" alt="SmartCore">
|
||||
<img src="smartcore.svg" width="450" alt="smartcore">
|
||||
</a>
|
||||
</p>
|
||||
<p align = "center">
|
||||
<strong>
|
||||
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-examples">Examples</a>
|
||||
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-jupyter">Notebooks</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
-----
|
||||
|
||||
<p align = "center">
|
||||
<b>The Most Advanced Machine Learning Library In Rust.</b>
|
||||
<b>Machine Learning in Rust</b>
|
||||
</p>
|
||||
|
||||
-----
|
||||
-----
|
||||
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [](https://doi.org/10.5281/zenodo.17219259)
|
||||
|
||||
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||
|
||||
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
|
||||
|
||||
|
||||
## Highlights
|
||||
|
||||
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
|
||||
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
|
||||
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
|
||||
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
|
||||
|
||||
|
||||
## Install
|
||||
|
||||
Add to Cargo.toml:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
smartcore = "^0.4.3"
|
||||
```
|
||||
|
||||
For the latest development branch:
|
||||
|
||||
```toml
|
||||
[dependencies]
|
||||
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
|
||||
```
|
||||
|
||||
Optional features (examples):
|
||||
|
||||
- datasets
|
||||
- serde
|
||||
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
|
||||
|
||||
Check Cargo.toml for available features and compatibility notes.
|
||||
|
||||
## Quick start
|
||||
|
||||
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
|
||||
|
||||
```rust
|
||||
use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
use smartcore::neighbors::knn_classifier::KNNClassifier;
|
||||
|
||||
// Turn vector slices into a matrix
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2.],
|
||||
&[3., 4.],
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.],
|
||||
]).unwrap;
|
||||
|
||||
// Class labels
|
||||
let y = vec![2, 2, 2, 3, 3];
|
||||
|
||||
// Train classifier
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
// Predict
|
||||
let yhat = knn.predict(&x).unwrap();
|
||||
```
|
||||
|
||||
This example mirrors the “First Example” section of the crate docs and demonstrates smartcore’s ergonomic API surface.
|
||||
|
||||
## Algorithms
|
||||
|
||||
smartcore organizes algorithms into clear modules with consistent traits:
|
||||
|
||||
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
|
||||
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
|
||||
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
|
||||
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
|
||||
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
|
||||
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
|
||||
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
|
||||
- Preprocessing: encoders, split utilities, and common transforms.
|
||||
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
|
||||
|
||||
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
|
||||
|
||||
## Data access and readers
|
||||
|
||||
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
|
||||
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
|
||||
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
|
||||
|
||||
|
||||
## WebAssembly and portability
|
||||
|
||||
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
|
||||
|
||||
## Notebooks
|
||||
|
||||
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
|
||||
|
||||
## Roadmap and recent changes
|
||||
|
||||
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
|
||||
- Move to Rust 2021 edition and cleanup of duplicate code paths.
|
||||
- Seeds and deterministic controls across algorithms using RNG plumbing.
|
||||
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
|
||||
- Tree and forest components refactored for reuse; Extra Trees added.
|
||||
- SVM multiclass support; SVR kernel enum and related improvements.
|
||||
- XGBoost-style regression introduced; single-linkage clustering implemented.
|
||||
|
||||
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome:
|
||||
|
||||
- Open an issue describing the change and link it in the PR.
|
||||
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
|
||||
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
|
||||
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
|
||||
|
||||
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
|
||||
|
||||
## License
|
||||
|
||||
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
|
||||
|
||||
## Acknowledgments
|
||||
|
||||
smartcore’s design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
#[macro_use]
|
||||
extern crate criterion;
|
||||
extern crate smartcore;
|
||||
|
||||
use criterion::black_box;
|
||||
use criterion::Criterion;
|
||||
use smartcore::math::distance::*;
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let a = vec![1., 2., 3.];
|
||||
|
||||
c.bench_function("Euclidean Distance", move |b| {
|
||||
b.iter(|| Distances::euclidian().distance(black_box(&a), black_box(&a)))
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
||||
criterion_main!(benches);
|
||||
@@ -1,56 +0,0 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
|
||||
// to run this bench you have to change the declaraion in mod.rs ---> pub mod fastpair;
|
||||
use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
use smartcore::linalg::naive::dense_matrix::*;
|
||||
use std::time::Duration;
|
||||
|
||||
fn closest_pair_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair();
|
||||
}
|
||||
|
||||
fn closest_pair_brute_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair_brute();
|
||||
}
|
||||
|
||||
fn bench_fastpair(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("FastPair");
|
||||
|
||||
// with full samples size (100) the test will take too long
|
||||
group.significance_level(0.1).sample_size(30);
|
||||
// increase from default 5.0 secs
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
|
||||
for n_samples in [100_usize, 1000_usize].iter() {
|
||||
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"fastpair --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
|
||||
);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"brute --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
|
||||
);
|
||||
}
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_fastpair);
|
||||
criterion_main!(benches);
|
||||
@@ -1,73 +0,0 @@
|
||||
use criterion::BenchmarkId;
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
use nalgebra::DMatrix;
|
||||
use ndarray::Array2;
|
||||
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use smartcore::linalg::BaseMatrix;
|
||||
use smartcore::linalg::BaseVector;
|
||||
use smartcore::naive_bayes::gaussian::GaussianNB;
|
||||
|
||||
pub fn gaussian_naive_bayes_fit_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("GaussianNB::fit");
|
||||
|
||||
for n_samples in [100_usize, 1000_usize, 10000_usize].iter() {
|
||||
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||
let x = DenseMatrix::<f64>::rand(*n_samples, *n_features);
|
||||
let y: Vec<f64> = (0..*n_samples)
|
||||
.map(|i| (i % *n_samples / 5_usize) as f64)
|
||||
.collect::<Vec<f64>>();
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn gaussian_naive_matrix_datastructure(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("GaussianNB");
|
||||
let classes = (0..10000).map(|i| (i % 25) as f64).collect::<Vec<f64>>();
|
||||
|
||||
group.bench_function("DenseMatrix", |b| {
|
||||
let x = DenseMatrix::<f64>::rand(10000, 500);
|
||||
let y = <DenseMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("ndarray", |b| {
|
||||
let x = Array2::<f64>::rand(10000, 500);
|
||||
let y = <Array2<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("ndalgebra", |b| {
|
||||
let x = DMatrix::<f64>::rand(10000, 500);
|
||||
let y = <DMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
criterion_group!(
|
||||
benches,
|
||||
gaussian_naive_bayes_fit_benchmark,
|
||||
gaussian_naive_matrix_datastructure
|
||||
);
|
||||
criterion_main!(benches);
|
||||
@@ -1,15 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="RUST_MODULE" version="4">
|
||||
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
||||
<exclude-output />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/benches" isTestSource="true" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
+1
-1
@@ -76,5 +76,5 @@
|
||||
y="81.876823"
|
||||
x="91.861809"
|
||||
id="tspan842"
|
||||
sodipodi:role="line">SmartCore</tspan></text>
|
||||
sodipodi:role="line">smartcore</tspan></text>
|
||||
</svg>
|
||||
|
||||
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
@@ -1,50 +1,50 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::metrics::distance::euclidian::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BBDTree<T: RealNumber> {
|
||||
nodes: Vec<BBDTreeNode<T>>,
|
||||
pub struct BBDTree {
|
||||
nodes: Vec<BBDTreeNode>,
|
||||
index: Vec<usize>,
|
||||
root: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct BBDTreeNode<T: RealNumber> {
|
||||
struct BBDTreeNode {
|
||||
count: usize,
|
||||
index: usize,
|
||||
center: Vec<T>,
|
||||
radius: Vec<T>,
|
||||
sum: Vec<T>,
|
||||
cost: T,
|
||||
center: Vec<f64>,
|
||||
radius: Vec<f64>,
|
||||
sum: Vec<f64>,
|
||||
cost: f64,
|
||||
lower: Option<usize>,
|
||||
upper: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BBDTreeNode<T> {
|
||||
fn new(d: usize) -> BBDTreeNode<T> {
|
||||
impl BBDTreeNode {
|
||||
fn new(d: usize) -> BBDTreeNode {
|
||||
BBDTreeNode {
|
||||
count: 0,
|
||||
index: 0,
|
||||
center: vec![T::zero(); d],
|
||||
radius: vec![T::zero(); d],
|
||||
sum: vec![T::zero(); d],
|
||||
cost: T::zero(),
|
||||
center: vec![0f64; d],
|
||||
radius: vec![0f64; d],
|
||||
sum: vec![0f64; d],
|
||||
cost: 0f64,
|
||||
lower: Option::None,
|
||||
upper: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BBDTree<T> {
|
||||
pub fn new<M: Matrix<T>>(data: &M) -> BBDTree<T> {
|
||||
let nodes = Vec::new();
|
||||
impl BBDTree {
|
||||
pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
|
||||
let nodes: Vec<BBDTreeNode> = Vec::new();
|
||||
|
||||
let (n, _) = data.shape();
|
||||
|
||||
let index = (0..n).collect::<Vec<_>>();
|
||||
let index = (0..n).collect::<Vec<usize>>();
|
||||
|
||||
let mut tree = BBDTree {
|
||||
nodes,
|
||||
@@ -61,18 +61,18 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
|
||||
pub(crate) fn clustering(
|
||||
&self,
|
||||
centroids: &[Vec<T>],
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
centroids: &[Vec<f64>],
|
||||
sums: &mut Vec<Vec<f64>>,
|
||||
counts: &mut Vec<usize>,
|
||||
membership: &mut Vec<usize>,
|
||||
) -> T {
|
||||
) -> f64 {
|
||||
let k = centroids.len();
|
||||
|
||||
counts.iter_mut().for_each(|v| *v = 0);
|
||||
let mut candidates = vec![0; k];
|
||||
for i in 0..k {
|
||||
candidates[i] = i;
|
||||
sums[i].iter_mut().for_each(|v| *v = T::zero());
|
||||
sums[i].iter_mut().for_each(|v| *v = 0f64);
|
||||
}
|
||||
|
||||
self.filter(
|
||||
@@ -89,13 +89,13 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
fn filter(
|
||||
&self,
|
||||
node: usize,
|
||||
centroids: &[Vec<T>],
|
||||
centroids: &[Vec<f64>],
|
||||
candidates: &[usize],
|
||||
k: usize,
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
sums: &mut Vec<Vec<f64>>,
|
||||
counts: &mut Vec<usize>,
|
||||
membership: &mut Vec<usize>,
|
||||
) -> T {
|
||||
) -> f64 {
|
||||
let d = centroids[0].len();
|
||||
|
||||
let mut min_dist =
|
||||
@@ -163,9 +163,9 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
|
||||
fn prune(
|
||||
center: &[T],
|
||||
radius: &[T],
|
||||
centroids: &[Vec<T>],
|
||||
center: &[f64],
|
||||
radius: &[f64],
|
||||
centroids: &[Vec<f64>],
|
||||
best_index: usize,
|
||||
test_index: usize,
|
||||
) -> bool {
|
||||
@@ -177,22 +177,22 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
|
||||
let best = ¢roids[best_index];
|
||||
let test = ¢roids[test_index];
|
||||
let mut lhs = T::zero();
|
||||
let mut rhs = T::zero();
|
||||
let mut lhs = 0f64;
|
||||
let mut rhs = 0f64;
|
||||
for i in 0..d {
|
||||
let diff = test[i] - best[i];
|
||||
lhs += diff * diff;
|
||||
if diff > T::zero() {
|
||||
if diff > 0f64 {
|
||||
rhs += (center[i] + radius[i] - best[i]) * diff;
|
||||
} else {
|
||||
rhs += (center[i] - radius[i] - best[i]) * diff;
|
||||
}
|
||||
}
|
||||
|
||||
lhs >= T::two() * rhs
|
||||
lhs >= 2f64 * rhs
|
||||
}
|
||||
|
||||
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
|
||||
let (_, d) = data.shape();
|
||||
|
||||
let mut node = BBDTreeNode::new(d);
|
||||
@@ -200,17 +200,17 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
node.count = end - begin;
|
||||
node.index = begin;
|
||||
|
||||
let mut lower_bound = vec![T::zero(); d];
|
||||
let mut upper_bound = vec![T::zero(); d];
|
||||
let mut lower_bound = vec![0f64; d];
|
||||
let mut upper_bound = vec![0f64; d];
|
||||
|
||||
for i in 0..d {
|
||||
lower_bound[i] = data.get(self.index[begin], i);
|
||||
upper_bound[i] = data.get(self.index[begin], i);
|
||||
lower_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
|
||||
upper_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
|
||||
}
|
||||
|
||||
for i in begin..end {
|
||||
for j in 0..d {
|
||||
let c = data.get(self.index[i], j);
|
||||
let c = data.get((self.index[i], j)).to_f64().unwrap();
|
||||
if lower_bound[j] > c {
|
||||
lower_bound[j] = c;
|
||||
}
|
||||
@@ -220,32 +220,32 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
}
|
||||
|
||||
let mut max_radius = T::from(-1.).unwrap();
|
||||
let mut max_radius = -1f64;
|
||||
let mut split_index = 0;
|
||||
for i in 0..d {
|
||||
node.center[i] = (lower_bound[i] + upper_bound[i]) / T::two();
|
||||
node.radius[i] = (upper_bound[i] - lower_bound[i]) / T::two();
|
||||
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2f64;
|
||||
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2f64;
|
||||
if node.radius[i] > max_radius {
|
||||
max_radius = node.radius[i];
|
||||
split_index = i;
|
||||
}
|
||||
}
|
||||
|
||||
if max_radius < T::from(1E-10).unwrap() {
|
||||
if max_radius < 1E-10 {
|
||||
node.lower = Option::None;
|
||||
node.upper = Option::None;
|
||||
for i in 0..d {
|
||||
node.sum[i] = data.get(self.index[begin], i);
|
||||
node.sum[i] = data.get((self.index[begin], i)).to_f64().unwrap();
|
||||
}
|
||||
|
||||
if end > begin + 1 {
|
||||
let len = end - begin;
|
||||
for i in 0..d {
|
||||
node.sum[i] *= T::from(len).unwrap();
|
||||
node.sum[i] *= len as f64;
|
||||
}
|
||||
}
|
||||
|
||||
node.cost = T::zero();
|
||||
node.cost = 0f64;
|
||||
return self.add_node(node);
|
||||
}
|
||||
|
||||
@@ -254,8 +254,10 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
let mut i2 = end - 1;
|
||||
let mut size = 0;
|
||||
while i1 <= i2 {
|
||||
let mut i1_good = data.get(self.index[i1], split_index) < split_cutoff;
|
||||
let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff;
|
||||
let mut i1_good =
|
||||
data.get((self.index[i1], split_index)).to_f64().unwrap() < split_cutoff;
|
||||
let mut i2_good =
|
||||
data.get((self.index[i2], split_index)).to_f64().unwrap() >= split_cutoff;
|
||||
|
||||
if !i1_good && !i2_good {
|
||||
self.index.swap(i1, i2);
|
||||
@@ -281,9 +283,9 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
|
||||
}
|
||||
|
||||
let mut mean = vec![T::zero(); d];
|
||||
let mut mean = vec![0f64; d];
|
||||
for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
|
||||
*mean_i = node.sum[i] / T::from(node.count).unwrap();
|
||||
*mean_i = node.sum[i] / node.count as f64;
|
||||
}
|
||||
|
||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
|
||||
@@ -292,17 +294,17 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
self.add_node(node)
|
||||
}
|
||||
|
||||
fn node_cost(node: &BBDTreeNode<T>, center: &[T]) -> T {
|
||||
fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
|
||||
let d = center.len();
|
||||
let mut scatter = T::zero();
|
||||
let mut scatter = 0f64;
|
||||
for (i, center_i) in center.iter().enumerate().take(d) {
|
||||
let x = (node.sum[i] / T::from(node.count).unwrap()) - *center_i;
|
||||
let x = (node.sum[i] / node.count as f64) - *center_i;
|
||||
scatter += x * x;
|
||||
}
|
||||
node.cost + T::from(node.count).unwrap() * scatter
|
||||
node.cost + node.count as f64 * scatter
|
||||
}
|
||||
|
||||
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize {
|
||||
fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
|
||||
let idx = self.nodes.len();
|
||||
self.nodes.push(new_node);
|
||||
idx
|
||||
@@ -312,9 +314,12 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn bbdtree_iris() {
|
||||
let data = DenseMatrix::from_2d_array(&[
|
||||
@@ -338,7 +343,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let tree = BBDTree::new(&data);
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,12 +4,12 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::algorithm::neighbour::cover_tree::*;
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct SimpleDistance {} // Our distance function
|
||||
//!
|
||||
//! impl Distance<i32, f64> for SimpleDistance {
|
||||
//! impl Distance<i32> for SimpleDistance {
|
||||
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
|
||||
//! (a - b).abs() as f64
|
||||
//! }
|
||||
@@ -29,28 +29,27 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::metrics::distance::Distance;
|
||||
|
||||
/// Implements Cover Tree algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||
base: F,
|
||||
inv_log_base: F,
|
||||
pub struct CoverTree<T, D: Distance<T>> {
|
||||
base: f64,
|
||||
inv_log_base: f64,
|
||||
distance: D,
|
||||
root: Node<F>,
|
||||
root: Node,
|
||||
data: Vec<T>,
|
||||
identical_excluded: bool,
|
||||
}
|
||||
|
||||
impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
||||
impl<T, D: Distance<T>> PartialEq for CoverTree<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.data.len() != other.data.len() {
|
||||
return false;
|
||||
}
|
||||
for i in 0..self.data.len() {
|
||||
if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() {
|
||||
if self.distance.distance(&self.data[i], &other.data[i]) != 0f64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -60,36 +59,36 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<F: RealNumber> {
|
||||
struct Node {
|
||||
idx: usize,
|
||||
max_dist: F,
|
||||
parent_dist: F,
|
||||
children: Vec<Node<F>>,
|
||||
max_dist: f64,
|
||||
parent_dist: f64,
|
||||
children: Vec<Node>,
|
||||
_scale: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct DistanceSet<F: RealNumber> {
|
||||
struct DistanceSet {
|
||||
idx: usize,
|
||||
dist: Vec<F>,
|
||||
dist: Vec<f64>,
|
||||
}
|
||||
|
||||
impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
|
||||
impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
|
||||
/// Construct a cover tree.
|
||||
/// * `data` - vector of data points to search for.
|
||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
|
||||
let base = F::from_f64(1.3).unwrap();
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, D>, Failed> {
|
||||
let base = 1.3f64;
|
||||
let root = Node {
|
||||
idx: 0,
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
max_dist: 0f64,
|
||||
parent_dist: 0f64,
|
||||
children: Vec::new(),
|
||||
_scale: 0,
|
||||
};
|
||||
let mut tree = CoverTree {
|
||||
base,
|
||||
inv_log_base: F::one() / base.ln(),
|
||||
inv_log_base: 1f64 / base.ln(),
|
||||
distance,
|
||||
root,
|
||||
data,
|
||||
@@ -104,7 +103,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
/// Find k nearest neighbors of `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
|
||||
if k == 0 {
|
||||
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
|
||||
}
|
||||
@@ -119,13 +118,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
let mut d = self.distance.distance(e, p);
|
||||
|
||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
|
||||
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
|
||||
|
||||
current_cover_set.push((d, &self.root));
|
||||
|
||||
let mut heap = HeapSelection::with_capacity(k);
|
||||
heap.add(F::max_value());
|
||||
heap.add(f64::MAX);
|
||||
|
||||
let mut empty_heap = true;
|
||||
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
|
||||
@@ -134,7 +133,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
|
||||
while !current_cover_set.is_empty() {
|
||||
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
|
||||
for par in current_cover_set {
|
||||
let parent = par.1;
|
||||
for c in 0..parent.children.len() {
|
||||
@@ -146,7 +145,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
|
||||
let upper_bound = if empty_heap {
|
||||
F::infinity()
|
||||
f64::INFINITY
|
||||
} else {
|
||||
*heap.peek()
|
||||
};
|
||||
@@ -169,7 +168,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
current_cover_set = next_cover_set;
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
|
||||
let upper_bound = *heap.peek();
|
||||
for ds in zero_set {
|
||||
if ds.0 <= upper_bound {
|
||||
@@ -189,25 +188,25 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
pub fn find_radius(&self, p: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
|
||||
if radius <= 0f64 {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
|
||||
|
||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
|
||||
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
|
||||
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
let mut d = self.distance.distance(e, p);
|
||||
current_cover_set.push((d, &self.root));
|
||||
|
||||
while !current_cover_set.is_empty() {
|
||||
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
|
||||
for par in current_cover_set {
|
||||
let parent = par.1;
|
||||
for c in 0..parent.children.len() {
|
||||
@@ -240,23 +239,23 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
Ok(neighbors)
|
||||
}
|
||||
|
||||
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||
fn new_leaf(&self, idx: usize) -> Node {
|
||||
Node {
|
||||
idx,
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
max_dist: 0f64,
|
||||
parent_dist: 0f64,
|
||||
children: Vec::new(),
|
||||
_scale: 100,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_cover_tree(&mut self) {
|
||||
let mut point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut point_set: Vec<DistanceSet> = Vec::new();
|
||||
let mut consumed_set: Vec<DistanceSet> = Vec::new();
|
||||
|
||||
let point = &self.data[0];
|
||||
let idx = 0;
|
||||
let mut max_dist = -F::one();
|
||||
let mut max_dist = -1f64;
|
||||
|
||||
for i in 1..self.data.len() {
|
||||
let dist = self.distance.distance(point, &self.data[i]);
|
||||
@@ -284,16 +283,16 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
p: usize,
|
||||
max_scale: i64,
|
||||
top_scale: i64,
|
||||
point_set: &mut Vec<DistanceSet<F>>,
|
||||
consumed_set: &mut Vec<DistanceSet<F>>,
|
||||
) -> Node<F> {
|
||||
point_set: &mut Vec<DistanceSet>,
|
||||
consumed_set: &mut Vec<DistanceSet>,
|
||||
) -> Node {
|
||||
if point_set.is_empty() {
|
||||
self.new_leaf(p)
|
||||
} else {
|
||||
let max_dist = self.max(point_set);
|
||||
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
||||
if next_scale == std::i64::MIN {
|
||||
let mut children: Vec<Node<F>> = Vec::new();
|
||||
if next_scale == i64::MIN {
|
||||
let mut children: Vec<Node> = Vec::new();
|
||||
let mut leaf = self.new_leaf(p);
|
||||
children.push(leaf);
|
||||
while !point_set.is_empty() {
|
||||
@@ -304,13 +303,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
Node {
|
||||
idx: p,
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
max_dist: 0f64,
|
||||
parent_dist: 0f64,
|
||||
children,
|
||||
_scale: 100,
|
||||
}
|
||||
} else {
|
||||
let mut far: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut far: Vec<DistanceSet> = Vec::new();
|
||||
self.split(point_set, &mut far, max_scale);
|
||||
|
||||
let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set);
|
||||
@@ -319,14 +318,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
point_set.append(&mut far);
|
||||
child
|
||||
} else {
|
||||
let mut children: Vec<Node<F>> = vec![child];
|
||||
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut children: Vec<Node> = vec![child];
|
||||
let mut new_point_set: Vec<DistanceSet> = Vec::new();
|
||||
let mut new_consumed_set: Vec<DistanceSet> = Vec::new();
|
||||
|
||||
while !point_set.is_empty() {
|
||||
let set: DistanceSet<F> = point_set.remove(point_set.len() - 1);
|
||||
let set: DistanceSet = point_set.remove(point_set.len() - 1);
|
||||
|
||||
let new_dist: F = set.dist[set.dist.len() - 1];
|
||||
let new_dist = set.dist[set.dist.len() - 1];
|
||||
|
||||
self.dist_split(
|
||||
point_set,
|
||||
@@ -374,7 +373,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
Node {
|
||||
idx: p,
|
||||
max_dist: self.max(consumed_set),
|
||||
parent_dist: F::zero(),
|
||||
parent_dist: 0f64,
|
||||
children,
|
||||
_scale: (top_scale - max_scale),
|
||||
}
|
||||
@@ -385,12 +384,12 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
|
||||
fn split(
|
||||
&self,
|
||||
point_set: &mut Vec<DistanceSet<F>>,
|
||||
far_set: &mut Vec<DistanceSet<F>>,
|
||||
point_set: &mut Vec<DistanceSet>,
|
||||
far_set: &mut Vec<DistanceSet>,
|
||||
max_scale: i64,
|
||||
) {
|
||||
let fmax = self.get_cover_radius(max_scale);
|
||||
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut new_set: Vec<DistanceSet> = Vec::new();
|
||||
for n in point_set.drain(0..) {
|
||||
if n.dist[n.dist.len() - 1] <= fmax {
|
||||
new_set.push(n);
|
||||
@@ -404,13 +403,13 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
|
||||
fn dist_split(
|
||||
&self,
|
||||
point_set: &mut Vec<DistanceSet<F>>,
|
||||
new_point_set: &mut Vec<DistanceSet<F>>,
|
||||
point_set: &mut Vec<DistanceSet>,
|
||||
new_point_set: &mut Vec<DistanceSet>,
|
||||
new_point: &T,
|
||||
max_scale: i64,
|
||||
) {
|
||||
let fmax = self.get_cover_radius(max_scale);
|
||||
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut new_set: Vec<DistanceSet> = Vec::new();
|
||||
for mut n in point_set.drain(0..) {
|
||||
let new_dist = self
|
||||
.distance
|
||||
@@ -426,24 +425,24 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
point_set.append(&mut new_set);
|
||||
}
|
||||
|
||||
fn get_cover_radius(&self, s: i64) -> F {
|
||||
self.base.powf(F::from_i64(s).unwrap())
|
||||
fn get_cover_radius(&self, s: i64) -> f64 {
|
||||
self.base.powf(s as f64)
|
||||
}
|
||||
|
||||
fn get_data_value(&self, idx: usize) -> &T {
|
||||
&self.data[idx]
|
||||
}
|
||||
|
||||
fn get_scale(&self, d: F) -> i64 {
|
||||
if d == F::zero() {
|
||||
std::i64::MIN
|
||||
fn get_scale(&self, d: f64) -> i64 {
|
||||
if d == 0f64 {
|
||||
i64::MIN
|
||||
} else {
|
||||
(self.inv_log_base * d.ln()).ceil().to_i64().unwrap()
|
||||
(self.inv_log_base * d.ln()).ceil() as i64
|
||||
}
|
||||
}
|
||||
|
||||
fn max(&self, distance_set: &[DistanceSet<F>]) -> F {
|
||||
let mut max = F::zero();
|
||||
fn max(&self, distance_set: &[DistanceSet]) -> f64 {
|
||||
let mut max = 0f64;
|
||||
for n in distance_set {
|
||||
if max < n.dist[n.dist.len() - 1] {
|
||||
max = n.dist[n.dist.len() - 1];
|
||||
@@ -457,19 +456,22 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
use crate::metrics::distance::Distances;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
impl Distance<i32> for SimpleDistance {
|
||||
fn distance(&self, a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cover_tree_test() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
@@ -486,7 +488,10 @@ mod tests {
|
||||
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
|
||||
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cover_tree_test1() {
|
||||
let data = vec![
|
||||
@@ -505,7 +510,10 @@ mod tests {
|
||||
|
||||
assert_eq!(vec!(0, 1, 2), knn);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
@@ -513,7 +521,7 @@ mod tests {
|
||||
|
||||
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
|
||||
|
||||
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
|
||||
let deserialized_tree: CoverTree<i32, SimpleDistance> =
|
||||
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(tree, deserialized_tree);
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
//!
|
||||
//! Dissimilarities for vector-vector distance
|
||||
//!
|
||||
//! Representing distances as pairwise dissimilarities, so to build a
|
||||
//! graph of closest neighbours. This representation can be reused for
|
||||
//! different implementations (initially used in this library for FastPair).
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
///
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distsances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// index of the vector in the original `Matrix` or list
|
||||
pub node: usize,
|
||||
|
||||
/// index of the closest neighbor in the original `Matrix` or same list
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// measure of distance, according to the algorithm distance function
|
||||
/// if the distance is None, the edge has value "infinite" or max distance
|
||||
/// each algorithm has to match
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.node == other.node
|
||||
&& self.neighbour == other.neighbour
|
||||
&& self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,5 @@
|
||||
#![allow(non_snake_case)]
|
||||
use itertools::Itertools;
|
||||
///
|
||||
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
///
|
||||
/// Reference:
|
||||
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||
@@ -9,8 +7,8 @@ use itertools::Itertools;
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::metrics::distance::PairwiseDistance;
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
/// &[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -19,7 +17,7 @@ use itertools::Itertools;
|
||||
/// &[4.6, 3.1, 1.5, 0.2],
|
||||
/// &[5.0, 3.6, 1.4, 0.2],
|
||||
/// &[5.4, 3.9, 1.7, 0.4],
|
||||
/// ]);
|
||||
/// ]).unwrap();
|
||||
/// let fastpair = FastPair::new(&x);
|
||||
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
|
||||
/// ```
|
||||
@@ -27,11 +25,14 @@ use itertools::Itertools;
|
||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
||||
use num::Bounded;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::Euclidian;
|
||||
use crate::metrics::distance::PairwiseDistance;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
///
|
||||
/// Inspired by Python implementation:
|
||||
@@ -41,7 +42,7 @@ use crate::math::num::RealNumber;
|
||||
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FastPair<'a, T: RealNumber, M: Matrix<T>> {
|
||||
pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
||||
/// initial matrix
|
||||
samples: &'a M,
|
||||
/// closest pair hashmap (connectivity matrix for closest pairs)
|
||||
@@ -50,11 +51,9 @@ pub struct FastPair<'a, T: RealNumber, M: Matrix<T>> {
|
||||
pub neighbours: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
///
|
||||
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
/// Constructor
|
||||
/// Instantiate and inizialise the algorithm
|
||||
///
|
||||
/// Instantiate and initialize the algorithm
|
||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||
if m.shape().0 < 3 {
|
||||
return Err(Failed::because(
|
||||
@@ -73,10 +72,8 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
Ok(init)
|
||||
}
|
||||
|
||||
///
|
||||
/// Initialise `FastPair` by passing a `Matrix`.
|
||||
/// Initialise `FastPair` by passing a `Array2`.
|
||||
/// Build a FastPairs data-structure from a set of (new) points.
|
||||
///
|
||||
fn init(&mut self) {
|
||||
// basic measures
|
||||
let len = self.samples.shape().0;
|
||||
@@ -98,8 +95,8 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
index_row_i,
|
||||
PairwiseDistance {
|
||||
node: index_row_i,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
neighbour: Option::None,
|
||||
distance: Some(<T as Bounded>::max_value()),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -120,13 +117,19 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
);
|
||||
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row_i)),
|
||||
&(self.samples.get_row_as_vec(index_row_j)),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row_i).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row_j).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
);
|
||||
if d < nbd.unwrap() {
|
||||
if d < nbd.unwrap().to_f64().unwrap() {
|
||||
// set this j-value to be the closest neighbour
|
||||
index_closest = index_row_j;
|
||||
nbd = Some(d);
|
||||
nbd = Some(T::from(d).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,21 +142,19 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
// No more neighbors, terminate conga line.
|
||||
// Last person on the line has no neigbors
|
||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
||||
|
||||
// compute sparse matrix (connectivity matrix)
|
||||
let mut sparse_matrix = M::zeros(len, len);
|
||||
for (_, p) in distances.iter() {
|
||||
sparse_matrix.set(p.node, p.neighbour.unwrap(), p.distance.unwrap());
|
||||
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
|
||||
}
|
||||
|
||||
self.distances = distances;
|
||||
self.neighbours = neighbours;
|
||||
}
|
||||
|
||||
///
|
||||
/// Find closest pair by scanning list of nearest neighbors.
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||
let mut a = self.neighbours[0]; // Start with first point
|
||||
@@ -173,29 +174,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
}
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
/// Return order dissimilarities from closest to furthest
|
||||
///
|
||||
#[cfg(feature = "fp_bench")]
|
||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||
let m = self.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(pair[0])),
|
||||
&(self.samples.get_row_as_vec(pair[1])),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
#[allow(dead_code)]
|
||||
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||
let mut distances = self
|
||||
.distances
|
||||
.values()
|
||||
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
distances.into_iter()
|
||||
}
|
||||
|
||||
//
|
||||
@@ -210,10 +200,19 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
distances.push(PairwiseDistance {
|
||||
node: index_row,
|
||||
neighbour: Some(*other),
|
||||
distance: Some(Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row)),
|
||||
&(self.samples.get_row_as_vec(*other)),
|
||||
)),
|
||||
distance: Some(
|
||||
T::from(Euclidian::squared_distance(
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(index_row).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
self.samples.get_row(*other).iterator(0).copied(),
|
||||
self.samples.shape().1,
|
||||
),
|
||||
))
|
||||
.unwrap(),
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -225,7 +224,39 @@ impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
mod tests_fastpair {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
pub fn closest_pair_brute(
|
||||
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
|
||||
) -> PairwiseDistance<f64> {
|
||||
use itertools::Itertools;
|
||||
let m = fastpair.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Option::None,
|
||||
distance: Some(f64::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&Vec::from_iterator(
|
||||
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
|
||||
fastpair.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
|
||||
fastpair.samples.shape().1,
|
||||
),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_init() {
|
||||
@@ -238,8 +269,8 @@ mod tests_fastpair {
|
||||
let distances = fastpair.distances;
|
||||
let neighbours = fastpair.neighbours;
|
||||
|
||||
assert!(distances.len() != 0);
|
||||
assert!(neighbours.len() != 0);
|
||||
assert!(!distances.is_empty());
|
||||
assert!(!neighbours.is_empty());
|
||||
|
||||
assert_eq!(10, neighbours.len());
|
||||
assert_eq!(10, distances.len());
|
||||
@@ -249,28 +280,24 @@ mod tests_fastpair {
|
||||
fn dataset_has_at_least_three_points() {
|
||||
// Create a dataset which consists of only two points:
|
||||
// A(0.0, 0.0) and B(1.0, 1.0).
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();
|
||||
|
||||
// We expect an error when we run `FastPair` on this dataset,
|
||||
// becuase `FastPair` currently only works on a minimum of 3
|
||||
// points.
|
||||
let _fastpair = FastPair::new(&dataset);
|
||||
let fastpair = FastPair::new(&dataset);
|
||||
assert!(fastpair.is_err());
|
||||
|
||||
match _fastpair {
|
||||
Err(e) => {
|
||||
let expected_error =
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
|
||||
assert_eq!(e, expected_error)
|
||||
}
|
||||
_ => {
|
||||
assert!(false);
|
||||
}
|
||||
if let Err(e) = fastpair {
|
||||
let expected_error =
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
|
||||
assert_eq!(e, expected_error)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_minimal() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
@@ -284,13 +311,14 @@ mod tests_fastpair {
|
||||
};
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
|
||||
let closest_pair_brute = fastpair.closest_pair_brute();
|
||||
let closest_pair_brute = closest_pair_brute(&fastpair);
|
||||
assert_eq!(closest_pair_brute, expected_closest_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_2() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);
|
||||
let dataset =
|
||||
DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
@@ -302,7 +330,7 @@ mod tests_fastpair {
|
||||
neighbour: Some(3),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, fastpair.closest_pair_brute());
|
||||
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
}
|
||||
|
||||
@@ -325,7 +353,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
@@ -459,11 +488,16 @@ mod tests_fastpair {
|
||||
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
||||
|
||||
for i in 0..(x.shape().0 - 1) {
|
||||
let input_node = result.samples.get_row_as_vec(i);
|
||||
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
||||
let distance = Euclidian::squared_distance(
|
||||
&input_node,
|
||||
&result.samples.get_row_as_vec(input_neighbour),
|
||||
&Vec::from_iterator(
|
||||
result.samples.get_row(i).iterator(0).copied(),
|
||||
result.samples.shape().1,
|
||||
),
|
||||
&Vec::from_iterator(
|
||||
result.samples.get_row(input_neighbour).iterator(0).copied(),
|
||||
result.samples.shape().1,
|
||||
),
|
||||
);
|
||||
|
||||
assert_eq!(i, expected.get(&i).unwrap().node);
|
||||
@@ -493,7 +527,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
@@ -518,7 +553,7 @@ mod tests_fastpair {
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
let dissimilarity1 = result.closest_pair();
|
||||
let dissimilarity2 = result.closest_pair_brute();
|
||||
let dissimilarity2 = closest_pair_brute(&result);
|
||||
|
||||
assert_eq!(dissimilarity1, dissimilarity2);
|
||||
}
|
||||
@@ -541,7 +576,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
@@ -550,12 +586,12 @@ mod tests_fastpair {
|
||||
|
||||
let mut min_dissimilarity = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
neighbour: Option::None,
|
||||
distance: Some(f64::MAX),
|
||||
};
|
||||
for p in dissimilarities.iter() {
|
||||
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
|
||||
min_dissimilarity = p.clone()
|
||||
min_dissimilarity = *p
|
||||
}
|
||||
}
|
||||
|
||||
@@ -567,4 +603,103 @@ mod tests_fastpair {
|
||||
|
||||
assert_eq!(closest, min_dissimilarity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_ordered_pairs() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x).unwrap();
|
||||
|
||||
let ordered = fastpair.ordered_pairs();
|
||||
|
||||
let mut previous: f64 = -1.0;
|
||||
for p in ordered {
|
||||
if previous == -1.0 {
|
||||
previous = p.distance.unwrap();
|
||||
} else {
|
||||
let current = p.distance.unwrap();
|
||||
assert!(current >= previous);
|
||||
previous = current;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_set() {
|
||||
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
|
||||
let result = FastPair::new(&empty_matrix);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_point() {
|
||||
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
let result = FastPair::new(&single_point);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_points() {
|
||||
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||
let result = FastPair::new(&two_points);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_identical_points() {
|
||||
let identical_points =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
|
||||
let result = FastPair::new(&identical_points);
|
||||
assert!(result.is_ok());
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
assert_eq!(closest_pair.distance, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_unwrapping() {
|
||||
let valid_matrix =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
|
||||
.unwrap();
|
||||
|
||||
let result = FastPair::new(&valid_matrix);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should not panic
|
||||
let _fastpair = result.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
//! see [KNN algorithms](../index.html)
|
||||
//! ```
|
||||
//! use smartcore::algorithm::neighbour::linear_search::*;
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct SimpleDistance {} // Our distance function
|
||||
//!
|
||||
//! impl Distance<i32, f64> for SimpleDistance {
|
||||
//! impl Distance<i32> for SimpleDistance {
|
||||
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
|
||||
//! (a - b).abs() as f64
|
||||
//! }
|
||||
@@ -25,38 +25,31 @@
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::metrics::distance::Distance;
|
||||
|
||||
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
||||
pub struct LinearKNNSearch<T, D: Distance<T>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
f: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
|
||||
/// Initializes algorithm.
|
||||
/// * `data` - vector of data points to search for.
|
||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
|
||||
Ok(LinearKNNSearch {
|
||||
data,
|
||||
distance,
|
||||
f: PhantomData,
|
||||
})
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, D>, Failed> {
|
||||
Ok(LinearKNNSearch { data, distance })
|
||||
}
|
||||
|
||||
/// Find k nearest neighbors
|
||||
/// * `from` - look for k nearest points to `from`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
@@ -64,11 +57,11 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
));
|
||||
}
|
||||
|
||||
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
|
||||
let mut heap = HeapSelection::<KNNPoint>::with_capacity(k);
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint {
|
||||
distance: F::infinity(),
|
||||
distance: f64::INFINITY,
|
||||
index: None,
|
||||
});
|
||||
}
|
||||
@@ -93,15 +86,15 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
pub fn find_radius(&self, from: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
|
||||
if radius <= 0f64 {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
let d = self.distance.distance(from, &self.data[i]);
|
||||
@@ -116,41 +109,44 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct KNNPoint<F: RealNumber> {
|
||||
distance: F,
|
||||
struct KNNPoint {
|
||||
distance: f64,
|
||||
index: Option<usize>,
|
||||
}
|
||||
|
||||
impl<F: RealNumber> PartialOrd for KNNPoint<F> {
|
||||
impl PartialOrd for KNNPoint {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RealNumber> PartialEq for KNNPoint<F> {
|
||||
impl PartialEq for KNNPoint {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: RealNumber> Eq for KNNPoint<F> {}
|
||||
impl Eq for KNNPoint {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
use crate::metrics::distance::Distances;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
impl Distance<i32> for SimpleDistance {
|
||||
fn distance(&self, a: &i32, b: &i32) -> f64 {
|
||||
(a - b).abs() as f64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn knn_find() {
|
||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
@@ -197,7 +193,10 @@ mod tests {
|
||||
|
||||
assert_eq!(vec!(1, 2, 3), found_idxs2);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn knn_point_eq() {
|
||||
let point1 = KNNPoint {
|
||||
@@ -216,7 +215,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint {
|
||||
distance: std::f64::INFINITY,
|
||||
distance: f64::INFINITY,
|
||||
index: Some(3),
|
||||
};
|
||||
|
||||
|
||||
@@ -33,16 +33,16 @@
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::metrics::distance::Distance;
|
||||
use crate::numbers::basenum::Number;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) mod bbd_tree;
|
||||
/// a variant of fastpair using cosine distance
|
||||
pub mod cosinepair;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
||||
pub mod distances;
|
||||
/// fastpair closest neighbour algorithm
|
||||
pub mod fastpair;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
@@ -51,23 +51,25 @@ pub mod linear_search;
|
||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum KNNAlgorithmName {
|
||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||
LinearSearch,
|
||||
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
||||
#[default]
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, D>),
|
||||
CoverTree(CoverTree<Vec<T>, D>),
|
||||
}
|
||||
|
||||
// TODO: missing documentation
|
||||
impl KNNAlgorithmName {
|
||||
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
|
||||
pub(crate) fn fit<T: Number, D: Distance<Vec<T>>>(
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
@@ -83,8 +85,8 @@ impl KNNAlgorithmName {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
impl<T: Number, D: Distance<Vec<T>>> KNNAlgorithm<T, D> {
|
||||
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
@@ -94,8 +96,8 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
pub fn find_radius(
|
||||
&self,
|
||||
from: &Vec<T>,
|
||||
radius: T,
|
||||
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
radius: f64,
|
||||
) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
|
||||
|
||||
@@ -95,14 +95,20 @@ impl<T: PartialOrd + Debug> HeapSelection<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let heap = HeapSelection::<i32>::with_capacity(3);
|
||||
assert_eq!(3, heap.k);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
@@ -120,11 +126,14 @@ mod tests {
|
||||
assert_eq!(vec![2, 0, -5], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_add1() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
heap.add(std::f64::INFINITY);
|
||||
heap.add(f64::INFINITY);
|
||||
heap.add(-5f64);
|
||||
heap.add(4f64);
|
||||
heap.add(-1f64);
|
||||
@@ -135,11 +144,14 @@ mod tests {
|
||||
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_add2() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
heap.add(std::f64::INFINITY);
|
||||
heap.add(f64::INFINITY);
|
||||
heap.add(0.0);
|
||||
heap.add(8.4852);
|
||||
heap.add(5.6568);
|
||||
@@ -148,7 +160,10 @@ mod tests {
|
||||
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_add_ordered() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
use num_traits::Float;
|
||||
use num_traits::Num;
|
||||
|
||||
pub trait QuickArgSort {
|
||||
#[allow(dead_code)]
|
||||
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn quick_argsort(&self) -> Vec<usize>;
|
||||
}
|
||||
|
||||
impl<T: Float> QuickArgSort for Vec<T> {
|
||||
impl<T: Num + PartialOrd + Copy> QuickArgSort for Vec<T> {
|
||||
fn quick_argsort(&self) -> Vec<usize> {
|
||||
let mut v = self.clone();
|
||||
v.quick_argsort_mut()
|
||||
@@ -113,7 +115,10 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||
|
||||
+34
-2
@@ -16,8 +16,12 @@ pub trait UnsupervisedEstimator<X, P> {
|
||||
P: Clone;
|
||||
}
|
||||
|
||||
/// An estimator for supervised learning, , that provides method `fit` to learn from data and training values
|
||||
pub trait SupervisedEstimator<X, Y, P> {
|
||||
/// An estimator for supervised learning, that provides method `fit` to learn from data and training values
|
||||
pub trait SupervisedEstimator<X, Y, P>: Predictor<X, Y> {
|
||||
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
|
||||
/// used to pass around the correct `fit()` implementation.
|
||||
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
|
||||
fn new() -> Self;
|
||||
/// Fit a model to a training dataset, estimate model's parameters.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target training values of size _N_.
|
||||
@@ -28,6 +32,24 @@ pub trait SupervisedEstimator<X, Y, P> {
|
||||
P: Clone;
|
||||
}
|
||||
|
||||
/// An estimator for supervised learning.
|
||||
/// In this one parameters are borrowed instead of moved, this is useful for parameters that carry
|
||||
/// references. Also to be used when there is no predictor attached to the estimator.
|
||||
pub trait SupervisedEstimatorBorrow<'a, X, Y, P> {
|
||||
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
|
||||
/// used to pass around the correct `fit()` implementation.
|
||||
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
|
||||
fn new() -> Self;
|
||||
/// Fit a model to a training dataset, estimate model's parameters.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target training values of size _N_.
|
||||
/// * `¶meters` - hyperparameters of an algorithm
|
||||
fn fit(x: &'a X, y: &'a Y, parameters: &'a P) -> Result<Self, Failed>
|
||||
where
|
||||
Self: Sized,
|
||||
P: Clone;
|
||||
}
|
||||
|
||||
/// Implements method predict that estimates target value from new data
|
||||
pub trait Predictor<X, Y> {
|
||||
/// Estimate target values from new data.
|
||||
@@ -35,9 +57,19 @@ pub trait Predictor<X, Y> {
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed>;
|
||||
}
|
||||
|
||||
/// Implements method predict that estimates target value from new data, with borrowing
|
||||
pub trait PredictorBorrow<'a, X, T> {
|
||||
/// Estimate target values from new data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
fn predict(&self, x: &'a X) -> Result<Vec<T>, Failed>;
|
||||
}
|
||||
|
||||
/// Implements method transform that filters or modifies input data
|
||||
pub trait Transformer<X> {
|
||||
/// Transform data by modifying or filtering it
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
fn transform(&self, x: &X) -> Result<X, Failed>;
|
||||
}
|
||||
|
||||
/// empty parameters for an estimator, see `BiasedEstimator`
|
||||
pub trait NoParameters {}
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
//! # Agglomerative Hierarchical Clustering
|
||||
//!
|
||||
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
|
||||
//!
|
||||
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
|
||||
//!
|
||||
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
|
||||
//!
|
||||
//! ## Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
|
||||
//!
|
||||
//! // A dataset with 2 distinct groups of points.
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
|
||||
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! // Set parameters to find 2 clusters.
|
||||
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
|
||||
//!
|
||||
//! // Fit the model to the data.
|
||||
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
|
||||
//!
|
||||
//! // Get the cluster assignments.
|
||||
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::api::UnsupervisedEstimator;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
/// Parameters for the Agglomerative Clustering algorithm.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AgglomerativeClusteringParameters {
|
||||
/// The number of clusters to find.
|
||||
pub n_clusters: usize,
|
||||
}
|
||||
|
||||
impl AgglomerativeClusteringParameters {
|
||||
/// Sets the number of clusters.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `n_clusters` - The desired number of clusters.
|
||||
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
|
||||
self.n_clusters = n_clusters;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgglomerativeClusteringParameters {
|
||||
fn default() -> Self {
|
||||
AgglomerativeClusteringParameters { n_clusters: 2 }
|
||||
}
|
||||
}
|
||||
|
||||
/// Agglomerative Clustering model.
|
||||
///
|
||||
/// This implementation uses single-linkage clustering, which is mathematically
|
||||
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
|
||||
/// The core logic is an efficient implementation of Kruskal's algorithm, which
|
||||
/// processes all pairwise distances in increasing order and uses a Disjoint
|
||||
/// Set Union (DSU) data structure to track cluster membership.
|
||||
#[derive(Debug)]
|
||||
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
/// The cluster label assigned to each sample.
|
||||
pub labels: Vec<usize>,
|
||||
_phantom_tx: PhantomData<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
|
||||
/// Fits the agglomerative clustering model to the data.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - A reference to the input data matrix.
|
||||
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing the fitted model with cluster labels, or an error if
|
||||
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
|
||||
let (num_samples, _) = data.shape();
|
||||
let n_clusters = parameters.n_clusters;
|
||||
if n_clusters > num_samples {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
&format!(
|
||||
"n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let mut distance_pairs = Vec::new();
|
||||
for i in 0..num_samples {
|
||||
for j in (i + 1)..num_samples {
|
||||
let distance: f64 = data
|
||||
.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(data.get_row(j).iterator(0))
|
||||
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
|
||||
.sum::<f64>();
|
||||
|
||||
distance_pairs.push((distance, i, j));
|
||||
}
|
||||
}
|
||||
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
|
||||
let mut parent = HashMap::new();
|
||||
let mut children = HashMap::new();
|
||||
for i in 0..num_samples {
|
||||
parent.insert(i, i);
|
||||
children.insert(i, vec![i]);
|
||||
}
|
||||
|
||||
let mut merge_history = Vec::new();
|
||||
let num_merges_needed = num_samples - 1;
|
||||
|
||||
while merge_history.len() < num_merges_needed {
|
||||
let (_, p1, p2) = distance_pairs.pop().unwrap();
|
||||
|
||||
let root1 = parent[&p1];
|
||||
let root2 = parent[&p2];
|
||||
|
||||
if root1 != root2 {
|
||||
let root2_children = children.remove(&root2).unwrap();
|
||||
for child in root2_children.iter() {
|
||||
parent.insert(*child, root1);
|
||||
}
|
||||
let root1_children = children.get_mut(&root1).unwrap();
|
||||
root1_children.extend(root2_children);
|
||||
merge_history.push((root1, root2));
|
||||
}
|
||||
}
|
||||
|
||||
let mut clusters = HashMap::new();
|
||||
let mut assignments = HashMap::new();
|
||||
|
||||
for i in 0..num_samples {
|
||||
clusters.insert(i, vec![i]);
|
||||
assignments.insert(i, i);
|
||||
}
|
||||
|
||||
let merges_to_apply = num_samples - n_clusters;
|
||||
|
||||
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
|
||||
let root1_cluster = assignments[root1];
|
||||
let root2_cluster = assignments[root2];
|
||||
|
||||
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
|
||||
for assignment in root2_assignments.iter() {
|
||||
assignments.insert(*assignment, root1_cluster);
|
||||
}
|
||||
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
|
||||
root1_assignments.extend(root2_assignments);
|
||||
}
|
||||
|
||||
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
|
||||
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
|
||||
cluster_keys.sort();
|
||||
for (i, key) in cluster_keys.into_iter().enumerate() {
|
||||
for index in clusters[key].iter() {
|
||||
labels[*index] = i;
|
||||
}
|
||||
}
|
||||
Ok(AgglomerativeClustering {
|
||||
labels,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
|
||||
for AgglomerativeClustering<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
|
||||
AgglomerativeClustering::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simple_clustering() {
|
||||
// Two distinct clusters, far apart.
|
||||
let data = vec![
|
||||
0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
|
||||
10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
|
||||
];
|
||||
let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
|
||||
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
|
||||
// Using f64 for TY as usize doesn't satisfy the Number trait bound.
|
||||
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
|
||||
&matrix, parameters,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let labels = clustering.labels;
|
||||
|
||||
// Check that all points in the first group have the same label.
|
||||
let first_group_label = labels[0];
|
||||
assert!(labels[0..3].iter().all(|&l| l == first_group_label));
|
||||
|
||||
// Check that all points in the second group have the same label.
|
||||
let second_group_label = labels[3];
|
||||
assert!(labels[3..6].iter().all(|&l| l == second_group_label));
|
||||
|
||||
// Check that the two groups have different labels.
|
||||
assert_ne!(first_group_label, second_group_label);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_four_clusters() {
|
||||
// Four distinct clusters in the corners of a square.
|
||||
let data = vec![
|
||||
0.0, 0.0, 1.0, 1.0, // Cluster A
|
||||
100.0, 100.0, 101.0, 101.0, // Cluster B
|
||||
0.0, 100.0, 1.0, 101.0, // Cluster C
|
||||
100.0, 0.0, 101.0, 1.0, // Cluster D
|
||||
];
|
||||
let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
|
||||
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
|
||||
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
|
||||
&matrix, parameters,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let labels = clustering.labels;
|
||||
|
||||
// Verify that there are exactly 4 unique labels produced.
|
||||
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
|
||||
assert_eq!(unique_labels.len(), 4);
|
||||
|
||||
// Verify that points within each original group were assigned the same cluster label.
|
||||
let label_a = labels[0];
|
||||
assert_eq!(label_a, labels[1]);
|
||||
|
||||
let label_b = labels[2];
|
||||
assert_eq!(label_b, labels[3]);
|
||||
|
||||
let label_c = labels[4];
|
||||
assert_eq!(label_c, labels[5]);
|
||||
|
||||
let label_d = labels[6];
|
||||
assert_eq!(label_d, labels[7]);
|
||||
|
||||
// Verify that all four groups received different labels.
|
||||
assert_ne!(label_a, label_b);
|
||||
assert_ne!(label_a, label_c);
|
||||
assert_ne!(label_a, label_d);
|
||||
assert_ne!(label_b, label_c);
|
||||
assert_ne!(label_b, label_d);
|
||||
assert_ne!(label_c, label_d);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_n_clusters_equal_to_samples() {
|
||||
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
|
||||
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
|
||||
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
|
||||
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
|
||||
&matrix, parameters,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Each point should be its own cluster. Sorting makes the test deterministic.
|
||||
let mut labels = clustering.labels;
|
||||
labels.sort();
|
||||
assert_eq!(labels, vec![0, 1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_one_cluster() {
|
||||
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
|
||||
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
|
||||
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
|
||||
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
|
||||
&matrix, parameters,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// All points should be in the same cluster.
|
||||
assert_eq!(clustering.labels, vec![0, 0, 0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_on_too_many_clusters() {
|
||||
let data = vec![0.0, 0.0, 5.0, 5.0];
|
||||
let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
|
||||
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
|
||||
let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
|
||||
&matrix, parameters,
|
||||
);
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
+239
-59
@@ -18,19 +18,20 @@
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! ```ignore
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::basic::arrays::Array2;
|
||||
//! use smartcore::cluster::dbscan::*;
|
||||
//! use smartcore::math::distance::Distances;
|
||||
//! use smartcore::metrics::distance::Distances;
|
||||
//! use smartcore::neighbors::KNNAlgorithmName;
|
||||
//! use smartcore::dataset::generator;
|
||||
//!
|
||||
//! // Generate three blobs
|
||||
//! let blobs = generator::make_blobs(100, 2, 3);
|
||||
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
|
||||
//! let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
|
||||
//! // Fit the algorithm and predict cluster labels
|
||||
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
|
||||
//! and_then(|dbscan| dbscan.predict(&x));
|
||||
//! let labels: Vec<u32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
|
||||
//! and_then(|dbscan| dbscan.predict(&x)).unwrap();
|
||||
//!
|
||||
//! println!("{:?}", labels);
|
||||
//! ```
|
||||
@@ -41,7 +42,7 @@
|
||||
//! * ["Density-Based Clustering in Spatial Databases: The Algorithm GDBSCAN and its Applications", Sander J., Ester M., Kriegel HP., Xu X.](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1629&rep=rep1&type=pdf)
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -49,47 +50,58 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::Euclidian;
|
||||
use crate::metrics::distance::{Distance, Distances};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::tree::decision_tree_classifier::which_max;
|
||||
|
||||
/// DBSCAN clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
pub struct DBSCAN<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> {
|
||||
cluster_labels: Vec<i16>,
|
||||
num_classes: usize,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
eps: T,
|
||||
knn_algorithm: KNNAlgorithm<TX, D>,
|
||||
eps: f64,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// DBSCAN clustering algorithm parameters
|
||||
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
pub struct DBSCANParameters<T: Number, D: Distance<Vec<T>>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: T,
|
||||
pub eps: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
_phantom_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
|
||||
pub fn with_distance<DD: Distance<Vec<T>>>(self, distance: DD) -> DBSCANParameters<T, DD> {
|
||||
DBSCANParameters {
|
||||
distance,
|
||||
min_samples: self.min_samples,
|
||||
eps: self.eps,
|
||||
algorithm: self.algorithm,
|
||||
_phantom_t: PhantomData,
|
||||
}
|
||||
}
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
@@ -98,7 +110,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
self
|
||||
}
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub fn with_eps(mut self, eps: T) -> Self {
|
||||
pub fn with_eps(mut self, eps: f64) -> Self {
|
||||
self.eps = eps;
|
||||
self
|
||||
}
|
||||
@@ -109,7 +121,113 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
/// DBSCAN grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DBSCANSearchParameters<T: Number, D: Distance<Vec<T>>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: Vec<D>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: Vec<KNNAlgorithmName>,
|
||||
_phantom_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// DBSCAN grid search iterator
|
||||
pub struct DBSCANSearchParametersIterator<T: Number, D: Distance<Vec<T>>> {
|
||||
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
|
||||
current_distance: usize,
|
||||
current_min_samples: usize,
|
||||
current_eps: usize,
|
||||
current_algorithm: usize,
|
||||
}
|
||||
|
||||
impl<T: Number, D: Distance<Vec<T>>> IntoIterator for DBSCANSearchParameters<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
type IntoIter = DBSCANSearchParametersIterator<T, D>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DBSCANSearchParametersIterator {
|
||||
dbscan_search_parameters: self,
|
||||
current_distance: 0,
|
||||
current_min_samples: 0,
|
||||
current_eps: 0,
|
||||
current_algorithm: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, D: Distance<Vec<T>>> Iterator for DBSCANSearchParametersIterator<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_distance == self.dbscan_search_parameters.distance.len()
|
||||
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
|
||||
&& self.current_eps == self.dbscan_search_parameters.eps.len()
|
||||
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DBSCANParameters {
|
||||
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
|
||||
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
|
||||
eps: self.dbscan_search_parameters.eps[self.current_eps],
|
||||
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
|
||||
_phantom_t: PhantomData,
|
||||
};
|
||||
|
||||
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
|
||||
self.current_distance += 1;
|
||||
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples += 1;
|
||||
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps += 1;
|
||||
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps = 0;
|
||||
self.current_algorithm += 1;
|
||||
} else {
|
||||
self.current_distance += 1;
|
||||
self.current_min_samples += 1;
|
||||
self.current_eps += 1;
|
||||
self.current_algorithm += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number> Default for DBSCANSearchParameters<T, Euclidian<T>> {
|
||||
fn default() -> Self {
|
||||
let default_params = DBSCANParameters::default();
|
||||
|
||||
DBSCANSearchParameters {
|
||||
distance: vec![default_params.distance],
|
||||
min_samples: vec![default_params.min_samples],
|
||||
eps: vec![default_params.eps],
|
||||
algorithm: vec![default_params.algorithm],
|
||||
_phantom_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
|
||||
for DBSCAN<TX, TY, X, Y, D>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cluster_labels.len() == other.cluster_labels.len()
|
||||
&& self.num_classes == other.num_classes
|
||||
@@ -118,47 +236,50 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
|
||||
impl<T: Number> Default for DBSCANParameters<T, Euclidian<T>> {
|
||||
fn default() -> Self {
|
||||
DBSCANParameters {
|
||||
distance: Distances::euclidian(),
|
||||
min_samples: 5,
|
||||
eps: T::half(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
eps: 0.5f64,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
_phantom_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
|
||||
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
UnsupervisedEstimator<X, DBSCANParameters<TX, D>> for DBSCAN<TX, TY, X, Y, D>
|
||||
{
|
||||
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
|
||||
fn fit(x: &X, parameters: DBSCANParameters<TX, D>) -> Result<Self, Failed> {
|
||||
DBSCAN::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for DBSCAN<T, D>
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> Predictor<X, Y>
|
||||
for DBSCAN<TX, TY, X, Y, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
DBSCAN<TX, TY, X, Y, D>
|
||||
{
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `k` - number of clusters
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
parameters: DBSCANParameters<T, D>,
|
||||
) -> Result<DBSCAN<T, D>, Failed> {
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
parameters: DBSCANParameters<TX, D>,
|
||||
) -> Result<DBSCAN<TX, TY, X, Y, D>, Failed> {
|
||||
if parameters.min_samples < 1 {
|
||||
return Err(Failed::fit("Invalid minPts"));
|
||||
}
|
||||
|
||||
if parameters.eps <= T::zero() {
|
||||
if parameters.eps <= 0f64 {
|
||||
return Err(Failed::fit("Invalid radius: "));
|
||||
}
|
||||
|
||||
@@ -170,13 +291,19 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
let n = x.shape().0;
|
||||
let mut y = vec![undefined; n];
|
||||
|
||||
let algo = parameters
|
||||
.algorithm
|
||||
.fit(row_iter(x).collect(), parameters.distance)?;
|
||||
let algo = parameters.algorithm.fit(
|
||||
x.row_iter()
|
||||
.map(|row| row.iterator(0).cloned().collect())
|
||||
.collect(),
|
||||
parameters.distance,
|
||||
)?;
|
||||
|
||||
for (i, e) in row_iter(x).enumerate() {
|
||||
let mut row = vec![TX::zero(); x.shape().1];
|
||||
|
||||
for (i, e) in x.row_iter().enumerate() {
|
||||
if y[i] == undefined {
|
||||
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
|
||||
e.iterator(0).zip(row.iter_mut()).for_each(|(&x, r)| *r = x);
|
||||
let mut neighbors = algo.find_radius(&row, parameters.eps)?;
|
||||
if neighbors.len() < parameters.min_samples {
|
||||
y[i] = outlier;
|
||||
} else {
|
||||
@@ -188,8 +315,7 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
while !neighbors.is_empty() {
|
||||
let neighbor = neighbors.pop().unwrap();
|
||||
while let Some(neighbor) = neighbors.pop() {
|
||||
let index = neighbor.0;
|
||||
|
||||
if y[index] == outlier {
|
||||
@@ -227,18 +353,25 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
num_classes: k as usize,
|
||||
knn_algorithm: algo,
|
||||
eps: parameters.eps,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, m) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
let mut row = vec![T::zero(); m];
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let mut result = Y::zeros(n);
|
||||
|
||||
let mut row = vec![TX::zero(); x.shape().1];
|
||||
|
||||
for i in 0..n {
|
||||
x.copy_row_as_vec(i, &mut row);
|
||||
x.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(row.iter_mut())
|
||||
.for_each(|(&x, r)| *r = x);
|
||||
let neighbors = self.knn_algorithm.find_radius(&row, self.eps)?;
|
||||
let mut label = vec![0usize; self.num_classes + 1];
|
||||
for neighbor in neighbors {
|
||||
@@ -251,24 +384,50 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
}
|
||||
let class = which_max(&label);
|
||||
if class != self.num_classes {
|
||||
result.set(0, i, T::from(class).unwrap());
|
||||
result.set(i, TY::from(class + 1).unwrap());
|
||||
} else {
|
||||
result.set(0, i, -T::one());
|
||||
result.set(i, TY::zero());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::metrics::distance::euclidian::Euclidian;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters: DBSCANSearchParameters<f64, Euclidian<f64>> = DBSCANSearchParameters {
|
||||
min_samples: vec![10, 100],
|
||||
eps: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 2.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -283,9 +442,10 @@ mod tests {
|
||||
&[2.2, 1.2],
|
||||
&[1.8, 0.8],
|
||||
&[3.0, 5.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
|
||||
let expected_labels = vec![1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0];
|
||||
|
||||
let dbscan = DBSCAN::fit(
|
||||
&x,
|
||||
@@ -295,12 +455,15 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let predicted_labels = dbscan.predict(&x).unwrap();
|
||||
let predicted_labels: Vec<i32> = dbscan.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(expected_labels, predicted_labels);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
@@ -325,13 +488,30 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
|
||||
let deserialized_dbscan: DBSCAN<f32, f32, DenseMatrix<f32>, Vec<f32>, Euclidian<f32>> =
|
||||
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(dbscan, deserialized_dbscan);
|
||||
}
|
||||
|
||||
#[cfg(feature = "datasets")]
|
||||
#[test]
|
||||
fn from_vec() {
|
||||
use crate::dataset::generator;
|
||||
|
||||
// Generate three blobs
|
||||
let blobs = generator::make_blobs(100, 2, 3);
|
||||
let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
|
||||
// Fit the algorithm and predict cluster labels
|
||||
let labels: Vec<i32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0))
|
||||
.and_then(|dbscan| dbscan.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
println!("{labels:?}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
/// # Hierarchical clustering
|
||||
///
|
||||
/// Implement hierarchical clustering methods:
|
||||
/// * Agglomerative clustering (current)
|
||||
/// * Bisecting K-Means (future)
|
||||
/// * Fastcluster (future)
|
||||
///
|
||||
|
||||
/*
|
||||
class AgglomerativeClustering():
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
n_clusters : int or None, default=2
|
||||
The number of clusters to find. It must be ``None`` if
|
||||
``distance_threshold`` is not ``None``.
|
||||
affinity : str or callable, default='euclidean'
|
||||
If linkage is "ward", only "euclidean" is accepted.
|
||||
linkage : {'ward',}, default='ward'
|
||||
Which linkage criterion to use. The linkage criterion determines which
|
||||
distance to use between sets of observation. The algorithm will merge
|
||||
the pairs of cluster that minimize this criterion.
|
||||
- 'ward' minimizes the variance of the clusters being merged.
|
||||
compute_distances : bool, default=False
|
||||
Computes distances between clusters even if `distance_threshold` is not
|
||||
used. This can be used to make dendrogram visualization, but introduces
|
||||
a computational and memory overhead.
|
||||
"""
|
||||
|
||||
def fit(X):
|
||||
# compute tree
|
||||
# <https://github.com/scikit-learn/scikit-learn/blob/02ebf9e68fe1fc7687d9e1047b9e465ae0fd945e/sklearn/cluster/_agglomerative.py#L172>
|
||||
parents, childern = ward_tree(X, ....)
|
||||
# compute clusters
|
||||
# <https://github.com/scikit-learn/scikit-learn/blob/70c495250fea7fa3c8c1a4631e6ddcddc9f22451/sklearn/cluster/_hierarchical_fast.pyx#L98>
|
||||
labels = _hierarchical.hc_get_heads(parents)
|
||||
# assign cluster numbers
|
||||
self.labels_ = np.searchsorted(np.unique(labels), labels)
|
||||
|
||||
*/
|
||||
|
||||
// implement ward tree
|
||||
// use scipy.cluster.hierarchy.ward
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L738>
|
||||
// use linkage
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L837>
|
||||
// use nn_chain
|
||||
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/_hierarchy.pyx#L906>
|
||||
|
||||
// implement hc_get_heads
|
||||
|
||||
|
||||
mod tests {
|
||||
// >>> from sklearn.cluster import AgglomerativeClustering
|
||||
// >>> import numpy as np
|
||||
// >>> X = np.array([[1, 2], [1, 4], [1, 0],
|
||||
// ... [4, 2], [4, 4], [4, 0]])
|
||||
// >>> clustering = AgglomerativeClustering().fit(X)
|
||||
// >>> clustering
|
||||
// AgglomerativeClustering()
|
||||
// >>> clustering.labels_
|
||||
// array([1, 1, 1, 0, 0, 0])
|
||||
}
|
||||
+227
-66
@@ -11,12 +11,12 @@
|
||||
//! these re-calculated centroids becoming the new centers of their respective clusters. Next all instances of the training set are re-assigned to their closest cluster again.
|
||||
//! This iterative process continues until convergence is achieved and the clusters are considered settled.
|
||||
//!
|
||||
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. SmartCore uses k-means++ algorithm to initialize cluster centers.
|
||||
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. `smartcore` uses k-means++ algorithm to initialize cluster centers.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::cluster::kmeans::*;
|
||||
//!
|
||||
//! // Iris data
|
||||
@@ -41,10 +41,10 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
|
||||
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||
//! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -52,32 +52,37 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KMeans<T: RealNumber> {
|
||||
pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
k: usize,
|
||||
_y: Vec<usize>,
|
||||
size: Vec<usize>,
|
||||
_distortion: T,
|
||||
centroids: Vec<Vec<T>>,
|
||||
_distortion: f64,
|
||||
centroids: Vec<Vec<f64>>,
|
||||
_phantom_tx: PhantomData<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<TX, TY, X, Y> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.k != other.k
|
||||
|| self.size != other.size
|
||||
@@ -91,7 +96,7 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
return false;
|
||||
}
|
||||
for j in 0..self.centroids[i].len() {
|
||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
|
||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -101,13 +106,20 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// K-Means clustering algorithm parameters
|
||||
pub struct KMeansParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of clusters.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
@@ -128,27 +140,118 @@ impl Default for KMeansParameters {
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
seed: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
/// KMeans grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KMeansSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of clusters.
|
||||
pub k: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Vec<Option<u64>>,
|
||||
}
|
||||
|
||||
/// KMeans grid search iterator
|
||||
pub struct KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: KMeansSearchParameters,
|
||||
current_k: usize,
|
||||
current_max_iter: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for KMeansSearchParameters {
|
||||
type Item = KMeansParameters;
|
||||
type IntoIter = KMeansSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_max_iter: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for KMeansSearchParametersIterator {
|
||||
type Item = KMeansParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||
&& self.current_seed == self.kmeans_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = KMeansParameters {
|
||||
k: self.kmeans_search_parameters.k[self.current_k],
|
||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||
seed: self.kmeans_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_max_iter += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KMeansSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = KMeansParameters::default();
|
||||
|
||||
KMeansSearchParameters {
|
||||
k: vec![default_params.k],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
UnsupervisedEstimator<X, KMeansParameters> for KMeans<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(x: &X, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
KMeans::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
|
||||
for KMeans<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum> KMeans<T> {
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
|
||||
pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if parameters.k < 2 {
|
||||
@@ -167,10 +270,10 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut distortion = f64::MAX;
|
||||
let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
let mut centroids = vec![vec![0f64; d]; parameters.k];
|
||||
|
||||
for i in 0..n {
|
||||
size[y[i]] += 1;
|
||||
@@ -178,23 +281,23 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..d {
|
||||
centroids[y[i]][j] += data.get(i, j);
|
||||
centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..parameters.k {
|
||||
for j in 0..d {
|
||||
centroids[i][j] /= T::from(size[i]).unwrap();
|
||||
centroids[i][j] /= size[i] as f64;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![T::zero(); d]; parameters.k];
|
||||
let mut sums = vec![vec![0f64; d]; parameters.k];
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..parameters.k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..d {
|
||||
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
|
||||
centroids[i][j] = sums[i][j] / size[i] as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -212,48 +315,61 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
size,
|
||||
_distortion: distortion,
|
||||
centroids,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, m) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let mut result = Y::zeros(n);
|
||||
|
||||
let mut row = vec![T::zero(); m];
|
||||
let mut row = vec![0f64; x.shape().1];
|
||||
|
||||
for i in 0..n {
|
||||
let mut min_dist = T::max_value();
|
||||
let mut min_dist = f64::MAX;
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
x.copy_row_as_vec(i, &mut row);
|
||||
x.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(row.iter_mut())
|
||||
.for_each(|(&x, r)| *r = x.to_f64().unwrap());
|
||||
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
|
||||
if dist < min_dist {
|
||||
min_dist = dist;
|
||||
best_cluster = j;
|
||||
}
|
||||
}
|
||||
result.set(0, i, T::from(best_cluster).unwrap());
|
||||
result.set(i, TY::from_usize(best_cluster).unwrap());
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let (n, m) = data.shape();
|
||||
fn kmeans_plus_plus(data: &X, k: usize, seed: Option<u64>) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(seed);
|
||||
let (n, _) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||
let mut centroid: Vec<TX> = data
|
||||
.get_row(rng.gen_range(0..n))
|
||||
.iterator(0)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut d = vec![T::max_value(); n];
|
||||
|
||||
let mut row = vec![T::zero(); m];
|
||||
let mut d = vec![f64::MAX; n];
|
||||
let mut row = vec![TX::zero(); data.shape().1];
|
||||
|
||||
for j in 1..k {
|
||||
for i in 0..n {
|
||||
data.copy_row_as_vec(i, &mut row);
|
||||
data.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(row.iter_mut())
|
||||
.for_each(|(&x, r)| *r = x);
|
||||
let dist = Euclidian::squared_distance(&row, ¢roid);
|
||||
|
||||
if dist < d[i] {
|
||||
@@ -262,12 +378,12 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
}
|
||||
}
|
||||
|
||||
let mut sum: T = T::zero();
|
||||
let mut sum = 0f64;
|
||||
for i in d.iter() {
|
||||
sum += *i;
|
||||
}
|
||||
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
|
||||
let mut cost = T::zero();
|
||||
let cutoff = rng.gen::<f64>() * sum;
|
||||
let mut cost = 0f64;
|
||||
let mut index = 0;
|
||||
while index < n {
|
||||
cost += d[index];
|
||||
@@ -277,11 +393,14 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
index += 1;
|
||||
}
|
||||
|
||||
data.copy_row_as_vec(index, &mut centroid);
|
||||
centroid = data.get_row(index).iterator(0).cloned().collect();
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
data.copy_row_as_vec(i, &mut row);
|
||||
data.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(row.iter_mut())
|
||||
.for_each(|(&x, r)| *r = x);
|
||||
let dist = Euclidian::squared_distance(&row, ¢roid);
|
||||
|
||||
if dist < d[i] {
|
||||
@@ -297,25 +416,61 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn invalid_k() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
|
||||
|
||||
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
|
||||
assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
|
||||
&x,
|
||||
KMeansParameters::default().with_k(0)
|
||||
)
|
||||
.is_err());
|
||||
assert_eq!(
|
||||
"Fit failed: invalid number of clusters: 1",
|
||||
KMeans::fit(&x, KMeansParameters::default().with_k(1))
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
|
||||
&x,
|
||||
KMeansParameters::default().with_k(1)
|
||||
)
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
fn search_parameters() {
|
||||
let parameters = KMeansSearchParameters {
|
||||
k: vec![2, 4],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -337,18 +492,22 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let y = kmeans.predict(&x).unwrap();
|
||||
let y: Vec<usize> = kmeans.predict(&x).unwrap();
|
||||
|
||||
for i in 0..y.len() {
|
||||
assert_eq!(y[i] as usize, kmeans._y[i]);
|
||||
for (i, _y_i) in y.iter().enumerate() {
|
||||
assert_eq!({ y[i] }, kmeans._y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
@@ -373,11 +532,13 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
|
||||
KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_kmeans: KMeans<f64> =
|
||||
let deserialized_kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
|
||||
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(kmeans, deserialized_kmeans);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
|
||||
|
||||
pub mod agglomerative;
|
||||
pub mod dbscan;
|
||||
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
|
||||
pub mod kmeans;
|
||||
|
||||
@@ -31,7 +31,7 @@ use crate::dataset::Dataset;
|
||||
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("boston.xy"))
|
||||
{
|
||||
Err(why) => panic!("Can't deserialize boston.xy. {}", why),
|
||||
Err(why) => panic!("Can't deserialize boston.xy. {why}"),
|
||||
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
|
||||
};
|
||||
|
||||
@@ -69,7 +69,10 @@ mod tests {
|
||||
assert!(serialize_data(&dataset, "boston.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn boston_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -30,11 +30,16 @@ use crate::dataset::deserialize_data;
|
||||
use crate::dataset::Dataset;
|
||||
|
||||
/// Get dataset
|
||||
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
let (x, y, num_samples, num_features) =
|
||||
match deserialize_data(std::include_bytes!("breast_cancer.xy")) {
|
||||
Err(why) => panic!("Can't deserialize breast_cancer.xy. {}", why),
|
||||
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
|
||||
Err(why) => panic!("Can't deserialize breast_cancer.xy. {why}"),
|
||||
Ok((x, y, num_samples, num_features)) => (
|
||||
x,
|
||||
y.into_iter().map(|x| x as u32).collect(),
|
||||
num_samples,
|
||||
num_features,
|
||||
),
|
||||
};
|
||||
|
||||
Dataset {
|
||||
@@ -66,20 +71,22 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn refresh_cancer_dataset() {
|
||||
// run this test to generate breast_cancer.xy file.
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
|
||||
}
|
||||
// TODO: implement serialization
|
||||
// #[test]
|
||||
// #[ignore]
|
||||
// #[cfg(not(target_arch = "wasm32"))]
|
||||
// fn refresh_cancer_dataset() {
|
||||
// // run this test to generate breast_cancer.xy file.
|
||||
// let dataset = load_dataset();
|
||||
// assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
|
||||
// }
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cancer_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
+22
-15
@@ -23,11 +23,16 @@ use crate::dataset::deserialize_data;
|
||||
use crate::dataset::Dataset;
|
||||
|
||||
/// Get dataset
|
||||
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
let (x, y, num_samples, num_features) =
|
||||
match deserialize_data(std::include_bytes!("diabetes.xy")) {
|
||||
Err(why) => panic!("Can't deserialize diabetes.xy. {}", why),
|
||||
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
|
||||
Err(why) => panic!("Can't deserialize diabetes.xy. {why}"),
|
||||
Ok((x, y, num_samples, num_features)) => (
|
||||
x,
|
||||
y.into_iter().map(|x| x as u32).collect(),
|
||||
num_samples,
|
||||
num_features,
|
||||
),
|
||||
};
|
||||
|
||||
Dataset {
|
||||
@@ -35,7 +40,7 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
feature_names: [
|
||||
"Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
|
||||
]
|
||||
.iter()
|
||||
@@ -50,20 +55,22 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_diabetes_dataset() {
|
||||
// run this test to generate diabetes.xy file.
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
|
||||
}
|
||||
// TODO: fix serialization
|
||||
// #[cfg(not(target_arch = "wasm32"))]
|
||||
// #[test]
|
||||
// #[ignore]
|
||||
// fn refresh_diabetes_dataset() {
|
||||
// // run this test to generate diabetes.xy file.
|
||||
// let dataset = load_dataset();
|
||||
// assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
|
||||
// }
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn boston_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//! # Optical Recognition of Handwritten Digits Data Set
|
||||
//! # Optical Recognition of Handwritten Digits Dataset
|
||||
//!
|
||||
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
|
||||
//! |-|-|-|-|
|
||||
@@ -16,7 +16,7 @@ use crate::dataset::Dataset;
|
||||
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
|
||||
{
|
||||
Err(why) => panic!("Can't deserialize digits.xy. {}", why),
|
||||
Err(why) => panic!("Can't deserialize digits.xy. {why}"),
|
||||
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
|
||||
};
|
||||
|
||||
@@ -25,16 +25,14 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"sepal length (cm)",
|
||||
feature_names: ["sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
"petal length (cm)",
|
||||
"petal width (cm)",
|
||||
]
|
||||
"petal width (cm)"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
target_names: vec!["setosa", "versicolor", "virginica"]
|
||||
target_names: ["setosa", "versicolor", "virginica"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
@@ -57,7 +55,10 @@ mod tests {
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "digits.xy").is_ok());
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn digits_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -48,7 +48,7 @@ pub fn make_blobs(
|
||||
}
|
||||
|
||||
/// Make a large circle containing a smaller circle in 2d.
|
||||
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, f32> {
|
||||
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, u32> {
|
||||
if !(0.0..1.0).contains(&factor) {
|
||||
panic!("'factor' has to be between 0 and 1.");
|
||||
}
|
||||
@@ -79,7 +79,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
target: y.into_iter().map(|x| x as u32).collect(),
|
||||
num_samples,
|
||||
num_features: 2,
|
||||
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||
@@ -89,7 +89,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
|
||||
}
|
||||
|
||||
/// Make two interleaving half circles in 2d
|
||||
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
|
||||
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, u32> {
|
||||
let num_samples_out = num_samples / 2;
|
||||
let num_samples_in = num_samples - num_samples_out;
|
||||
|
||||
@@ -116,7 +116,7 @@ pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
target: y.into_iter().map(|x| x as u32).collect(),
|
||||
num_samples,
|
||||
num_features: 2,
|
||||
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||
@@ -137,7 +137,10 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_make_blobs() {
|
||||
let dataset = make_blobs(10, 2, 3);
|
||||
@@ -150,7 +153,10 @@ mod tests {
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_make_circles() {
|
||||
let dataset = make_circles(10, 0.5, 0.05);
|
||||
@@ -163,7 +169,10 @@ mod tests {
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_make_moons() {
|
||||
let dataset = make_moons(10, 0.05);
|
||||
|
||||
+29
-19
@@ -1,4 +1,4 @@
|
||||
//! # The Iris Dataset flower
|
||||
//! # The Iris flower dataset
|
||||
//!
|
||||
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
|
||||
//! |-|-|-|-|
|
||||
@@ -19,18 +19,24 @@ use crate::dataset::deserialize_data;
|
||||
use crate::dataset::Dataset;
|
||||
|
||||
/// Get dataset
|
||||
pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("iris.xy")) {
|
||||
Err(why) => panic!("Can't deserialize iris.xy. {}", why),
|
||||
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
|
||||
};
|
||||
pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
let (x, y, num_samples, num_features): (Vec<f32>, Vec<u32>, usize, usize) =
|
||||
match deserialize_data(std::include_bytes!("iris.xy")) {
|
||||
Err(why) => panic!("Can't deserialize iris.xy. {why}"),
|
||||
Ok((x, y, num_samples, num_features)) => (
|
||||
x,
|
||||
y.into_iter().map(|x| x as u32).collect(),
|
||||
num_samples,
|
||||
num_features,
|
||||
),
|
||||
};
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
feature_names: [
|
||||
"sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
"petal length (cm)",
|
||||
@@ -39,7 +45,7 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
target_names: vec!["setosa", "versicolor", "virginica"]
|
||||
target_names: ["setosa", "versicolor", "virginica"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
@@ -50,20 +56,24 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
// #[cfg(not(target_arch = "wasm32"))]
|
||||
// use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_iris_dataset() {
|
||||
// run this test to generate iris.xy file.
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "iris.xy").is_ok());
|
||||
}
|
||||
// TODO: fix serialization
|
||||
// #[cfg(not(target_arch = "wasm32"))]
|
||||
// #[test]
|
||||
// #[ignore]
|
||||
// fn refresh_iris_dataset() {
|
||||
// // run this test to generate iris.xy file.
|
||||
// let dataset = load_dataset();
|
||||
// assert!(serialize_data(&dataset, "iris.xy").is_ok());
|
||||
// }
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn iris_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
+8
-5
@@ -1,6 +1,6 @@
|
||||
//! Datasets
|
||||
//!
|
||||
//! In this module you will find small datasets that are used in SmartCore for demonstration purpose mostly.
|
||||
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
|
||||
pub mod boston;
|
||||
pub mod breast_cancer;
|
||||
pub mod diabetes;
|
||||
@@ -9,7 +9,7 @@ pub mod generator;
|
||||
pub mod iris;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::numbers::{basenum::Number, realnum::RealNumber};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
@@ -55,7 +55,7 @@ impl<X, Y> Dataset<X, Y> {
|
||||
// Running this in wasm throws: operation not supported on this platform.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
||||
pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
|
||||
dataset: &Dataset<X, Y>,
|
||||
filename: &str,
|
||||
) -> Result<(), io::Error> {
|
||||
@@ -78,7 +78,7 @@ pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
||||
.collect();
|
||||
file.write_all(&y)?;
|
||||
}
|
||||
Err(why) => panic!("couldn't create {}: {}", filename, why),
|
||||
Err(why) => panic!("couldn't create {filename}: {why}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -121,7 +121,10 @@ pub(crate) fn deserialize_data(
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn as_matrix() {
|
||||
let dataset = Dataset {
|
||||
|
||||
+240
-94
@@ -10,7 +10,7 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::decomposition::pca::*;
|
||||
//!
|
||||
//! // Iris data
|
||||
@@ -35,7 +35,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
@@ -52,24 +52,33 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::traits::evd::EVDDecomposable;
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
/// Principal components analysis algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
||||
eigenvectors: M,
|
||||
pub struct PCA<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
|
||||
eigenvectors: X,
|
||||
eigenvalues: Vec<T>,
|
||||
projection: M,
|
||||
projection: X,
|
||||
mu: Vec<T>,
|
||||
pmu: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
|
||||
for PCA<T, X>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.eigenvectors != other.eigenvectors
|
||||
|| self.eigenvalues.len() != other.eigenvalues.len()
|
||||
if self.eigenvalues.len() != other.eigenvalues.len()
|
||||
|| self
|
||||
.eigenvectors
|
||||
.iterator(0)
|
||||
.zip(other.eigenvectors.iterator(0))
|
||||
.any(|(&a, &b)| (a - b).abs() > T::epsilon())
|
||||
{
|
||||
false
|
||||
} else {
|
||||
@@ -83,11 +92,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// PCA parameters
|
||||
pub struct PCAParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: bool,
|
||||
@@ -116,40 +128,124 @@ impl Default for PCAParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
/// PCA grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PCASearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: Vec<bool>,
|
||||
}
|
||||
|
||||
/// PCA grid search iterator
|
||||
pub struct PCASearchParametersIterator {
|
||||
pca_search_parameters: PCASearchParameters,
|
||||
current_k: usize,
|
||||
current_use_correlation_matrix: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for PCASearchParameters {
|
||||
type Item = PCAParameters;
|
||||
type IntoIter = PCASearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
PCASearchParametersIterator {
|
||||
pca_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_use_correlation_matrix: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for PCASearchParametersIterator {
|
||||
type Item = PCAParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.pca_search_parameters.n_components.len()
|
||||
&& self.current_use_correlation_matrix
|
||||
== self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = PCAParameters {
|
||||
n_components: self.pca_search_parameters.n_components[self.current_k],
|
||||
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
|
||||
[self.current_use_correlation_matrix],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_use_correlation_matrix + 1
|
||||
< self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
self.current_k = 0;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PCASearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = PCAParameters::default();
|
||||
|
||||
PCASearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
use_correlation_matrix: vec![default_params.use_correlation_matrix],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
|
||||
UnsupervisedEstimator<X, PCAParameters> for PCA<T, X>
|
||||
{
|
||||
fn fit(x: &X, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
PCA::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for PCA<T, M> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
|
||||
for PCA<T, X>
|
||||
{
|
||||
fn transform(&self, x: &X) -> Result<X, Failed> {
|
||||
self.transform(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PCA<T, X> {
|
||||
/// Fits PCA to your data.
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(data: &M, parameters: PCAParameters) -> Result<PCA<T, M>, Failed> {
|
||||
pub fn fit(data: &X, parameters: PCAParameters) -> Result<PCA<T, X>, Failed> {
|
||||
let (m, n) = data.shape();
|
||||
|
||||
if parameters.n_components > n {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of components, n_components should be <= number of attributes ({})",
|
||||
n
|
||||
"Number of components, n_components should be <= number of attributes ({n})"
|
||||
)));
|
||||
}
|
||||
|
||||
let mu = data.column_mean();
|
||||
let mu: Vec<T> = data
|
||||
.mean_by(0)
|
||||
.iter()
|
||||
.map(|&v| T::from_f64(v).unwrap())
|
||||
.collect();
|
||||
|
||||
let mut x = data.clone();
|
||||
|
||||
for (c, mu_c) in mu.iter().enumerate().take(n) {
|
||||
for (c, &mu_c) in mu.iter().enumerate().take(n) {
|
||||
for r in 0..m {
|
||||
x.sub_element_mut(r, c, *mu_c);
|
||||
x.sub_element_mut((r, c), mu_c);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,33 +261,33 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
|
||||
eigenvectors = svd.V;
|
||||
} else {
|
||||
let mut cov = M::zeros(n, n);
|
||||
let mut cov = X::zeros(n, n);
|
||||
|
||||
for k in 0..m {
|
||||
for i in 0..n {
|
||||
for j in 0..=i {
|
||||
cov.add_element_mut(i, j, x.get(k, i) * x.get(k, j));
|
||||
cov.add_element_mut((i, j), *x.get((k, i)) * *x.get((k, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..=i {
|
||||
cov.div_element_mut(i, j, T::from(m).unwrap());
|
||||
cov.set(j, i, cov.get(i, j));
|
||||
cov.div_element_mut((i, j), T::from(m).unwrap());
|
||||
cov.set((j, i), *cov.get((i, j)));
|
||||
}
|
||||
}
|
||||
|
||||
if parameters.use_correlation_matrix {
|
||||
let mut sd = vec![T::zero(); n];
|
||||
for (i, sd_i) in sd.iter_mut().enumerate().take(n) {
|
||||
*sd_i = cov.get(i, i).sqrt();
|
||||
*sd_i = cov.get((i, i)).sqrt();
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..=i {
|
||||
cov.div_element_mut(i, j, sd[i] * sd[j]);
|
||||
cov.set(j, i, cov.get(i, j));
|
||||
cov.div_element_mut((i, j), sd[i] * sd[j]);
|
||||
cov.set((j, i), *cov.get((i, j)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +299,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
|
||||
for (i, sd_i) in sd.iter().enumerate().take(n) {
|
||||
for j in 0..n {
|
||||
eigenvectors.div_element_mut(i, j, *sd_i);
|
||||
eigenvectors.div_element_mut((i, j), *sd_i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -215,17 +311,17 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
let mut projection = M::zeros(parameters.n_components, n);
|
||||
let mut projection = X::zeros(parameters.n_components, n);
|
||||
for i in 0..n {
|
||||
for j in 0..parameters.n_components {
|
||||
projection.set(j, i, eigenvectors.get(i, j));
|
||||
projection.set((j, i), *eigenvectors.get((i, j)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut pmu = vec![T::zero(); parameters.n_components];
|
||||
for (k, mu_k) in mu.iter().enumerate().take(n) {
|
||||
for (i, pmu_i) in pmu.iter_mut().enumerate().take(parameters.n_components) {
|
||||
*pmu_i += projection.get(i, k) * (*mu_k);
|
||||
*pmu_i += *projection.get((i, k)) * (*mu_k);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +336,7 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
|
||||
/// Run dimensionality reduction for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
pub fn transform(&self, x: &X) -> Result<X, Failed> {
|
||||
let (nrows, ncols) = x.shape();
|
||||
let (_, n_components) = self.projection.shape();
|
||||
if ncols != self.mu.len() {
|
||||
@@ -254,14 +350,14 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
let mut x_transformed = x.matmul(&self.projection);
|
||||
for r in 0..nrows {
|
||||
for c in 0..n_components {
|
||||
x_transformed.sub_element_mut(r, c, self.pmu[c]);
|
||||
x_transformed.sub_element_mut((r, c), self.pmu[c]);
|
||||
}
|
||||
}
|
||||
Ok(x_transformed)
|
||||
}
|
||||
|
||||
/// Get a projection matrix
|
||||
pub fn components(&self) -> &M {
|
||||
pub fn components(&self) -> &X {
|
||||
&self.projection
|
||||
}
|
||||
}
|
||||
@@ -269,7 +365,30 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = PCASearchParameters {
|
||||
n_components: vec![2, 4],
|
||||
use_correlation_matrix: vec![true, false],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert!(next.use_correlation_matrix);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert!(next.use_correlation_matrix);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert!(!next.use_correlation_matrix);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert!(!next.use_correlation_matrix);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
fn us_arrests_data() -> DenseMatrix<f64> {
|
||||
DenseMatrix::from_2d_array(&[
|
||||
@@ -324,8 +443,12 @@ mod tests {
|
||||
&[2.6, 53.0, 66.0, 10.8],
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
])
|
||||
.unwrap()
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn pca_components() {
|
||||
let us_arrests = us_arrests_data();
|
||||
@@ -335,13 +458,21 @@ mod tests {
|
||||
&[0.9952, 0.0588],
|
||||
&[0.0463, 0.9769],
|
||||
&[0.0752, 0.2007],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
|
||||
|
||||
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
|
||||
assert!(relative_eq!(
|
||||
expected,
|
||||
pca.components().abs(),
|
||||
epsilon = 1e-3
|
||||
));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_covariance() {
|
||||
let us_arrests = us_arrests_data();
|
||||
@@ -371,7 +502,8 @@ mod tests {
|
||||
-0.974080592182491,
|
||||
0.0723250196376097,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_projection = DenseMatrix::from_2d_array(&[
|
||||
&[-64.8022, -11.448, 2.4949, -2.4079],
|
||||
@@ -424,7 +556,8 @@ mod tests {
|
||||
&[91.5446, -22.9529, 0.402, -0.7369],
|
||||
&[118.1763, 5.5076, 2.7113, -0.205],
|
||||
&[10.4345, -5.9245, 3.7944, 0.5179],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_eigenvalues: Vec<f64> = vec![
|
||||
343544.6277001563,
|
||||
@@ -435,23 +568,29 @@ mod tests {
|
||||
|
||||
let pca = PCA::fit(&us_arrests, PCAParameters::default().with_n_components(4)).unwrap();
|
||||
|
||||
assert!(pca
|
||||
.eigenvectors
|
||||
.abs()
|
||||
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
||||
assert!(relative_eq!(
|
||||
pca.eigenvectors.abs(),
|
||||
&expected_eigenvectors.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
|
||||
for i in 0..pca.eigenvalues.len() {
|
||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
|
||||
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
}
|
||||
|
||||
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||
|
||||
assert!(us_arrests_t
|
||||
.abs()
|
||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||
assert!(relative_eq!(
|
||||
us_arrests_t.abs(),
|
||||
&expected_projection.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_correlation() {
|
||||
let us_arrests = us_arrests_data();
|
||||
@@ -481,7 +620,8 @@ mod tests {
|
||||
-0.0881962972508558,
|
||||
-0.0096011588898465,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_projection = DenseMatrix::from_2d_array(&[
|
||||
&[0.9856, -1.1334, 0.4443, -0.1563],
|
||||
@@ -534,7 +674,8 @@ mod tests {
|
||||
&[-2.1086, -1.4248, -0.1048, -0.1319],
|
||||
&[-2.0797, 0.6113, 0.1389, -0.1841],
|
||||
&[-0.6294, -0.321, 0.2407, 0.1667],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_eigenvalues: Vec<f64> = vec![
|
||||
2.480241579149493,
|
||||
@@ -551,54 +692,59 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(pca
|
||||
.eigenvectors
|
||||
.abs()
|
||||
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
|
||||
assert!(relative_eq!(
|
||||
pca.eigenvectors.abs(),
|
||||
&expected_eigenvectors.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
|
||||
for i in 0..pca.eigenvalues.len() {
|
||||
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
|
||||
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
|
||||
}
|
||||
|
||||
let us_arrests_t = pca.transform(&us_arrests).unwrap();
|
||||
|
||||
assert!(us_arrests_t
|
||||
.abs()
|
||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||
assert!(relative_eq!(
|
||||
us_arrests_t.abs(),
|
||||
&expected_projection.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
// Disable this test for now
|
||||
// TODO: implement deserialization for new DenseMatrix
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn pca_serde() {
|
||||
// let iris = DenseMatrix::from_2d_array(&[
|
||||
// &[5.1, 3.5, 1.4, 0.2],
|
||||
// &[4.9, 3.0, 1.4, 0.2],
|
||||
// &[4.7, 3.2, 1.3, 0.2],
|
||||
// &[4.6, 3.1, 1.5, 0.2],
|
||||
// &[5.0, 3.6, 1.4, 0.2],
|
||||
// &[5.4, 3.9, 1.7, 0.4],
|
||||
// &[4.6, 3.4, 1.4, 0.3],
|
||||
// &[5.0, 3.4, 1.5, 0.2],
|
||||
// &[4.4, 2.9, 1.4, 0.2],
|
||||
// &[4.9, 3.1, 1.5, 0.1],
|
||||
// &[7.0, 3.2, 4.7, 1.4],
|
||||
// &[6.4, 3.2, 4.5, 1.5],
|
||||
// &[6.9, 3.1, 4.9, 1.5],
|
||||
// &[5.5, 2.3, 4.0, 1.3],
|
||||
// &[6.5, 2.8, 4.6, 1.5],
|
||||
// &[5.7, 2.8, 4.5, 1.3],
|
||||
// &[6.3, 3.3, 4.7, 1.6],
|
||||
// &[4.9, 2.4, 3.3, 1.0],
|
||||
// &[6.6, 2.9, 4.6, 1.3],
|
||||
// &[5.2, 2.7, 3.9, 1.4],
|
||||
// ]).unwrap();
|
||||
|
||||
let pca = PCA::fit(&iris, Default::default()).unwrap();
|
||||
// let pca = PCA::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||
// let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(pca, deserialized_pca);
|
||||
}
|
||||
// assert_eq!(pca, deserialized_pca);
|
||||
// }
|
||||
}
|
||||
|
||||
+149
-59
@@ -7,7 +7,7 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::decomposition::svd::*;
|
||||
//!
|
||||
//! // Iris data
|
||||
@@ -32,7 +32,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let svd = SVD::fit(&iris, SVDParameters::default().
|
||||
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
@@ -51,27 +51,36 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::traits::evd::EVDDecomposable;
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
/// SVD
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
||||
components: M,
|
||||
pub struct SVD<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
|
||||
components: X,
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
|
||||
for SVD<T, X>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.components
|
||||
.approximate_eq(&other.components, T::from_f64(1e-8).unwrap())
|
||||
.iterator(0)
|
||||
.zip(other.components.iterator(0))
|
||||
.all(|(&a, &b)| (a - b).abs() <= T::epsilon())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVD parameters
|
||||
pub struct SVDParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
}
|
||||
@@ -90,36 +99,94 @@ impl SVDParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
/// SVD grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVDSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub n_components: Vec<usize>,
|
||||
}
|
||||
|
||||
/// SVD grid search iterator
|
||||
pub struct SVDSearchParametersIterator {
|
||||
svd_search_parameters: SVDSearchParameters,
|
||||
current_n_components: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for SVDSearchParameters {
|
||||
type Item = SVDParameters;
|
||||
type IntoIter = SVDSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVDSearchParametersIterator {
|
||||
svd_search_parameters: self,
|
||||
current_n_components: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for SVDSearchParametersIterator {
|
||||
type Item = SVDParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_n_components == self.svd_search_parameters.n_components.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVDParameters {
|
||||
n_components: self.svd_search_parameters.n_components[self.current_n_components],
|
||||
};
|
||||
|
||||
self.current_n_components += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SVDSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = SVDParameters::default();
|
||||
|
||||
SVDSearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
|
||||
UnsupervisedEstimator<X, SVDParameters> for SVD<T, X>
|
||||
{
|
||||
fn fit(x: &X, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
SVD::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for SVD<T, M> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
|
||||
for SVD<T, X>
|
||||
{
|
||||
fn transform(&self, x: &X) -> Result<X, Failed> {
|
||||
self.transform(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> SVD<T, X> {
|
||||
/// Fits SVD to your data.
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(x: &M, parameters: SVDParameters) -> Result<SVD<T, M>, Failed> {
|
||||
pub fn fit(x: &X, parameters: SVDParameters) -> Result<SVD<T, X>, Failed> {
|
||||
let (_, p) = x.shape();
|
||||
|
||||
if parameters.n_components >= p {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of components, n_components should be < number of attributes ({})",
|
||||
p
|
||||
"Number of components, n_components should be < number of attributes ({p})"
|
||||
)));
|
||||
}
|
||||
|
||||
let svd = x.svd()?;
|
||||
|
||||
let components = svd.V.slice(0..p, 0..parameters.n_components);
|
||||
let components = X::from_slice(svd.V.slice(0..p, 0..parameters.n_components).as_ref());
|
||||
|
||||
Ok(SVD {
|
||||
components,
|
||||
@@ -129,13 +196,12 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
|
||||
/// Run dimensionality reduction for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
pub fn transform(&self, x: &X) -> Result<X, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
let (p_c, k) = self.components.shape();
|
||||
if p_c != p {
|
||||
return Err(Failed::transform(&format!(
|
||||
"Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions",
|
||||
n, p, n, k
|
||||
"Can not transform a {n}x{p} matrix into {n}x{k} matrix, incorrect input dimentions"
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -143,7 +209,7 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
}
|
||||
|
||||
/// Get a projection matrix
|
||||
pub fn components(&self) -> &M {
|
||||
pub fn components(&self) -> &X {
|
||||
&self.components
|
||||
}
|
||||
}
|
||||
@@ -151,9 +217,27 @@ impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = SVDSearchParameters {
|
||||
n_components: vec![10, 100],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn svd_decompose() {
|
||||
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
|
||||
@@ -208,7 +292,8 @@ mod tests {
|
||||
&[5.7, 81.0, 39.0, 9.3],
|
||||
&[2.6, 53.0, 66.0, 10.8],
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected = DenseMatrix::from_2d_array(&[
|
||||
&[243.54655757, -18.76673788],
|
||||
@@ -216,50 +301,55 @@ mod tests {
|
||||
&[305.93972467, -15.39087376],
|
||||
&[197.28420365, -11.66808306],
|
||||
&[293.43187394, 1.91163633],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let svd = SVD::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let x_transformed = svd.transform(&x).unwrap();
|
||||
|
||||
assert_eq!(svd.components.shape(), (x.shape().1, 2));
|
||||
|
||||
assert!(x_transformed
|
||||
.slice(0..5, 0..2)
|
||||
.approximate_eq(&expected, 1e-4));
|
||||
assert!(relative_eq!(
|
||||
DenseMatrix::from_slice(x_transformed.slice(0..5, 0..2).as_ref()),
|
||||
&expected,
|
||||
epsilon = 1e-4
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
// Disable this test for now
|
||||
// TODO: implement deserialization for new DenseMatrix
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let iris = DenseMatrix::from_2d_array(&[
|
||||
// &[5.1, 3.5, 1.4, 0.2],
|
||||
// &[4.9, 3.0, 1.4, 0.2],
|
||||
// &[4.7, 3.2, 1.3, 0.2],
|
||||
// &[4.6, 3.1, 1.5, 0.2],
|
||||
// &[5.0, 3.6, 1.4, 0.2],
|
||||
// &[5.4, 3.9, 1.7, 0.4],
|
||||
// &[4.6, 3.4, 1.4, 0.3],
|
||||
// &[5.0, 3.4, 1.5, 0.2],
|
||||
// &[4.4, 2.9, 1.4, 0.2],
|
||||
// &[4.9, 3.1, 1.5, 0.1],
|
||||
// &[7.0, 3.2, 4.7, 1.4],
|
||||
// &[6.4, 3.2, 4.5, 1.5],
|
||||
// &[6.9, 3.1, 4.9, 1.5],
|
||||
// &[5.5, 2.3, 4.0, 1.3],
|
||||
// &[6.5, 2.8, 4.6, 1.5],
|
||||
// &[5.7, 2.8, 4.5, 1.3],
|
||||
// &[6.3, 3.3, 4.7, 1.6],
|
||||
// &[4.9, 2.4, 3.3, 1.0],
|
||||
// &[6.6, 2.9, 4.6, 1.3],
|
||||
// &[5.2, 2.7, 3.9, 1.4],
|
||||
// ]).unwrap();
|
||||
|
||||
let svd = SVD::fit(&iris, Default::default()).unwrap();
|
||||
// let svd = SVD::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
let deserialized_svd: SVD<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
|
||||
// let deserialized_svd: SVD<f32, DenseMatrix<f32>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svd, deserialized_svd);
|
||||
}
|
||||
// assert_eq!(svd, deserialized_svd);
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of the Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct BaseForestRegressorParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
pub bootstrap: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
pub splitter: Splitter,
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for BaseForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
|
||||
false
|
||||
} else {
|
||||
self.trees
|
||||
.iter()
|
||||
.zip(other.trees.iter())
|
||||
.all(|(a, b)| a == b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Forest Regressor
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct BaseForestRegressor<
|
||||
TX: Number + FloatNumber + PartialOrd,
|
||||
TY: Number,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
BaseForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: BaseForestRegressorParameters,
|
||||
) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
|
||||
let (n_rows, num_attributes) = x.shape();
|
||||
|
||||
if n_rows != y.shape() {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let mtry = parameters
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
// TODO: use with_capacity here
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut samples: Vec<usize> = (0..n_rows).map(|_| 1).collect();
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
if parameters.bootstrap {
|
||||
samples =
|
||||
BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
|
||||
}
|
||||
|
||||
// keep samples is flag is on
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
|
||||
let params = BaseTreeRegressorParameters {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
splitter: parameters.splitter.clone(),
|
||||
};
|
||||
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
Ok(BaseForestRegressor {
|
||||
trees: Some(trees),
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let mut result = Y::zeros(x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
result.set(i, self.predict_for_row(x, i));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: &X, row: usize) -> TY {
|
||||
let n_trees = self.trees.as_ref().unwrap().len();
|
||||
|
||||
let mut result = TY::zero();
|
||||
|
||||
for tree in self.trees.as_ref().unwrap().iter() {
|
||||
result += tree.predict_for_row(x, row);
|
||||
}
|
||||
|
||||
result / TY::from_usize(n_trees).unwrap()
|
||||
}
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Need samples=true for OOB predictions.",
|
||||
))
|
||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = Y::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(i, self.predict_for_row_oob(x, i));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
|
||||
let mut n_trees = 0;
|
||||
let mut result = TY::zero();
|
||||
|
||||
for (tree, samples) in self
|
||||
.trees
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.zip(self.samples.as_ref().unwrap())
|
||||
{
|
||||
if !samples[row] {
|
||||
result += tree.predict_for_row(x, row);
|
||||
n_trees += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: What to do if there are no oob trees?
|
||||
result / TY::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let mut samples = vec![0; nrows];
|
||||
for _ in 0..nrows {
|
||||
let xi = rng.gen_range(0..nrows);
|
||||
samples[xi] += 1;
|
||||
}
|
||||
samples
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,318 @@
|
||||
//! # Extra Trees Regressor
|
||||
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
|
||||
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
|
||||
//!
|
||||
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
|
||||
//! reduce the variance of the model and often make the training process faster.
|
||||
//!
|
||||
//! The two key differences from a standard Random Forest are:
|
||||
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
|
||||
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
|
||||
//!
|
||||
//! See [ensemble models](../index.html) for more details.
|
||||
//!
|
||||
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
|
||||
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::ensemble::extra_trees_regressor::*;
|
||||
//!
|
||||
//! // Longley dataset ([https://www.statsmodels.org/stable/datasets/generated/longley.html](https://www.statsmodels.org/stable/datasets/generated/longley.html))
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
|
||||
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||
//! ];
|
||||
//!
|
||||
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::tree::base_tree_regressor::Splitter;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of the Extra Trees Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct ExtraTreesRegressorParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Extra Trees Regressor
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct ExtraTreesRegressor<
|
||||
TX: Number + FloatNumber + PartialOrd,
|
||||
TY: Number,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
|
||||
}
|
||||
|
||||
impl ExtraTreesRegressorParameters {
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||
self.max_depth = Some(max_depth);
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
|
||||
self.min_samples_leaf = min_samples_leaf;
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
|
||||
self.min_samples_split = min_samples_split;
|
||||
self
|
||||
}
|
||||
/// The number of trees in the forest.
|
||||
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
|
||||
self.n_trees = n_trees;
|
||||
self
|
||||
}
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub fn with_m(mut self, m: usize) -> Self {
|
||||
self.m = Some(m);
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Default for ExtraTreesRegressorParameters {
|
||||
fn default() -> Self {
|
||||
ExtraTreesRegressorParameters {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 10,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
forest_regressor: Option::None,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
|
||||
ExtraTreesRegressor::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
ExtraTreesRegressor<TX, TY, X, Y>
|
||||
{
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: ExtraTreesRegressorParameters,
|
||||
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
|
||||
let regressor_params = BaseForestRegressorParameters {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
n_trees: parameters.n_trees,
|
||||
m: parameters.m,
|
||||
keep_samples: parameters.keep_samples,
|
||||
seed: parameters.seed,
|
||||
bootstrap: false,
|
||||
splitter: Splitter::Random,
|
||||
};
|
||||
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
|
||||
|
||||
Ok(ExtraTreesRegressor {
|
||||
forest_regressor: Some(forest_regressor),
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||
forest_regressor.predict(x)
|
||||
}
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||
forest_regressor.predict_oob(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_squared_error;
|
||||
|
||||
#[test]
|
||||
fn test_extra_trees_regressor_fit_predict() {
|
||||
// Use a simpler, more predictable dataset for unit testing.
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2.],
|
||||
&[3., 4.],
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.],
|
||||
&[11., 12.],
|
||||
&[13., 14.],
|
||||
&[15., 16.],
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
|
||||
|
||||
let parameters = ExtraTreesRegressorParameters::default()
|
||||
.with_n_trees(100)
|
||||
.with_seed(42);
|
||||
|
||||
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
|
||||
let y_hat = regressor.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(y_hat.len(), y.len());
|
||||
// A basic check to ensure the model is learning something.
|
||||
// The error should be significantly less than the variance of y.
|
||||
let mse = mean_squared_error(&y, &y_hat);
|
||||
// With this simple dataset, the error should be very low.
|
||||
assert!(mse < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fit_predict_higher_dims() {
|
||||
// Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
// The 3rd column is the important one. The rest are noise.
|
||||
&[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
|
||||
&[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
|
||||
&[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
|
||||
&[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
|
||||
&[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
|
||||
&[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![10., 20., 30., 40., 55., 65.];
|
||||
|
||||
let parameters = ExtraTreesRegressorParameters::default()
|
||||
.with_n_trees(100)
|
||||
.with_seed(42);
|
||||
|
||||
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
|
||||
let y_hat = regressor.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(y_hat.len(), y.len());
|
||||
|
||||
let mse = mean_squared_error(&y, &y_hat);
|
||||
|
||||
// The model should be able to learn this simple relationship perfectly,
|
||||
// ignoring the noise features. The MSE should be very low.
|
||||
assert!(mse < 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reproducibility() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2.],
|
||||
&[3., 4.],
|
||||
&[5., 6.],
|
||||
&[7., 8.],
|
||||
&[9., 10.],
|
||||
&[11., 12.],
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
|
||||
let params = ExtraTreesRegressorParameters::default().with_seed(42);
|
||||
|
||||
let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
|
||||
let y_hat1 = regressor1.predict(&x).unwrap();
|
||||
|
||||
let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
|
||||
let y_hat2 = regressor2.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(y_hat1, y_hat2);
|
||||
}
|
||||
}
|
||||
+3
-1
@@ -7,7 +7,7 @@
|
||||
//! set and then aggregate their individual predictions to form a final prediction. In classification setting the overall prediction is the most commonly
|
||||
//! occurring majority class among the individual predictions.
|
||||
//!
|
||||
//! In SmartCore you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
|
||||
//! In `smartcore` you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
|
||||
//! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
|
||||
//! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
|
||||
//! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
|
||||
@@ -16,6 +16,8 @@
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
|
||||
mod base_forest_regressor;
|
||||
pub mod extra_trees_regressor;
|
||||
/// Random forest classifier
|
||||
pub mod random_forest_classifier;
|
||||
/// Random forest regressor
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
|
||||
//!
|
||||
//! // Iris dataset
|
||||
@@ -33,10 +33,10 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
//!
|
||||
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -45,8 +45,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -55,8 +55,11 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
};
|
||||
@@ -66,20 +69,28 @@ use crate::tree::decision_tree_classifier::{
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: SplitCriterion,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: u16,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
@@ -87,10 +98,14 @@ pub struct RandomForestClassifierParameters {
|
||||
/// Random Forest Classifier
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestClassifier<T: RealNumber> {
|
||||
_parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
classes: Vec<T>,
|
||||
pub struct RandomForestClassifier<
|
||||
TX: Number + FloatNumber + PartialOrd,
|
||||
TY: Number + Ord,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
trees: Option<Vec<DecisionTreeClassifier<TX, TY, X, Y>>>,
|
||||
classes: Option<Vec<TY>>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
}
|
||||
|
||||
@@ -139,22 +154,24 @@ impl RandomForestClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
PartialEq for RandomForestClassifier<TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
|
||||
if self.classes.as_ref().unwrap().len() != other.classes.as_ref().unwrap().len()
|
||||
|| self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len()
|
||||
{
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
self.classes
|
||||
.iter()
|
||||
.zip(other.classes.iter())
|
||||
.all(|(a, b)| a == b)
|
||||
&& self
|
||||
.trees
|
||||
.iter()
|
||||
.zip(other.trees.iter())
|
||||
.all(|(a, b)| a == b)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -163,7 +180,7 @@ impl Default for RandomForestClassifierParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
@@ -174,65 +191,302 @@ impl Default for RandomForestClassifierParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, RandomForestClassifierParameters>
|
||||
for RandomForestClassifier<T>
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimator<X, Y, RandomForestClassifierParameters>
|
||||
for RandomForestClassifier<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RandomForestClassifierParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
trees: Option::None,
|
||||
classes: Option::None,
|
||||
samples: Option::None,
|
||||
}
|
||||
}
|
||||
fn fit(x: &X, y: &Y, parameters: RandomForestClassifierParameters) -> Result<Self, Failed> {
|
||||
RandomForestClassifier::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
Predictor<X, Y> for RandomForestClassifier<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
/// RandomForestClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestClassifier grid search iterator
|
||||
pub struct RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestClassifierSearchParameters {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
type IntoIter = RandomForestClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestClassifierSearchParametersIterator {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestClassifierParameters {
|
||||
criterion: self.random_forest_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.random_forest_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_classifier_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestClassifierParameters::default();
|
||||
|
||||
RandomForestClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
RandomForestClassifier<TX, TY, X, Y>
|
||||
{
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: RandomForestClassifierParameters,
|
||||
) -> Result<RandomForestClassifier<T>, Failed> {
|
||||
let (_, num_attributes) = x.shape();
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
let classes = y_m.unique();
|
||||
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
let yc = y_m.get(0, i);
|
||||
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
|
||||
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let y_ncols = y.shape();
|
||||
if x_nrows != y_ncols {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let mtry = parameters.m.unwrap_or_else(|| {
|
||||
(T::from(num_attributes).unwrap())
|
||||
.sqrt()
|
||||
.floor()
|
||||
.to_usize()
|
||||
.unwrap()
|
||||
});
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
let classes = y.unique();
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let classes = y_m.unique();
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
let yc = y.get(i);
|
||||
*yi_i = classes.iter().position(|c| yc == c).unwrap();
|
||||
}
|
||||
|
||||
let mtry = parameters
|
||||
.m
|
||||
.unwrap_or_else(|| ((num_attributes as f64).sqrt().floor()) as usize);
|
||||
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let classes = y.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
// TODO: use with_capacity here
|
||||
let mut trees: Vec<DecisionTreeClassifier<TX, TY, X, Y>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
// TODO: use with_capacity here
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
|
||||
let samples: Vec<usize> =
|
||||
RandomForestClassifier::<TX, TY, X, Y>::sample_with_replacement(&yi, k, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
@@ -242,38 +496,40 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
Ok(RandomForestClassifier {
|
||||
_parameters: parameters,
|
||||
trees,
|
||||
classes,
|
||||
trees: Some(trees),
|
||||
classes: Some(classes),
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let mut result = Y::zeros(x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||
result.set(
|
||||
i,
|
||||
self.classes.as_ref().unwrap()[self.predict_for_row(x, i)],
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
fn predict_for_row(&self, x: &X, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
|
||||
|
||||
for tree in self.trees.iter() {
|
||||
for tree in self.trees.as_ref().unwrap().iter() {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
@@ -281,7 +537,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
@@ -294,20 +550,28 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = M::zeros(1, n);
|
||||
let mut result = Y::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
|
||||
result.set(
|
||||
i,
|
||||
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
fn predict_for_row_oob(&self, x: &X, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
|
||||
|
||||
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||
for (tree, samples) in self
|
||||
.trees
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.zip(self.samples.as_ref().unwrap())
|
||||
{
|
||||
if !samples[row] {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
@@ -343,12 +607,38 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
fn search_parameters() {
|
||||
let parameters = RandomForestClassifierSearchParameters {
|
||||
n_trees: vec![10, 100],
|
||||
m: vec![None, Some(1)],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, Some(1));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, Some(1));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
@@ -370,17 +660,16 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
@@ -394,7 +683,34 @@ mod tests {
|
||||
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_random_matrix_with_wrong_rownum() {
|
||||
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
|
||||
|
||||
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let fail = RandomForestClassifier::fit(
|
||||
&x_rand,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(fail.is_err());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_predict_iris_oob() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -418,17 +734,16 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
@@ -445,7 +760,10 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
@@ -470,14 +788,13 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_forest: RandomForestClassifier<f64> =
|
||||
let deserialized_forest: RandomForestClassifier<f64, i64, DenseMatrix<f64>, Vec<i64>> =
|
||||
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::ensemble::random_forest_regressor::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
@@ -29,7 +29,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
|
||||
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||
@@ -43,8 +43,6 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -52,30 +50,37 @@ use std::fmt::Debug;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_regressor::{
|
||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||
};
|
||||
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::tree::base_tree_regressor::Splitter;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of the Random Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct RandomForestRegressorParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
@@ -83,10 +88,13 @@ pub struct RandomForestRegressorParameters {
|
||||
/// Random Forest Regressor
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestRegressor<T: RealNumber> {
|
||||
_parameters: RandomForestRegressorParameters,
|
||||
trees: Vec<DecisionTreeRegressor<T>>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
pub struct RandomForestRegressor<
|
||||
TX: Number + FloatNumber + PartialOrd,
|
||||
TY: Number,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
|
||||
}
|
||||
|
||||
impl RandomForestRegressorParameters {
|
||||
@@ -131,7 +139,7 @@ impl RandomForestRegressorParameters {
|
||||
impl Default for RandomForestRegressorParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestRegressorParameters {
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 10,
|
||||
@@ -142,167 +150,305 @@ impl Default for RandomForestRegressorParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for RandomForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.trees.len() != other.trees.len() {
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
self.forest_regressor == other.forest_regressor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, RandomForestRegressorParameters>
|
||||
for RandomForestRegressor<T>
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimator<X, Y, RandomForestRegressorParameters>
|
||||
for RandomForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RandomForestRegressorParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
forest_regressor: Option::None,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result<Self, Failed> {
|
||||
RandomForestRegressor::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
Predictor<X, Y> for RandomForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
/// RandomForestRegressor grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestRegressorSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestRegressor grid search iterator
|
||||
pub struct RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestRegressorSearchParameters {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
type IntoIter = RandomForestRegressorSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: self,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestRegressorSearchParametersIterator {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_max_depth
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
&& self.current_m == self.random_forest_regressor_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestRegressorParameters {
|
||||
max_depth: self.random_forest_regressor_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_regressor_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_regressor_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestRegressorSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestRegressorParameters::default();
|
||||
|
||||
RandomForestRegressorSearchParameters {
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
RandomForestRegressor<TX, TY, X, Y>
|
||||
{
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - the target class values
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: RandomForestRegressorParameters,
|
||||
) -> Result<RandomForestRegressor<T>, Failed> {
|
||||
let (n_rows, num_attributes) = x.shape();
|
||||
|
||||
let mtry = parameters
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
let params = DecisionTreeRegressorParameters {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
|
||||
let regressor_params = BaseForestRegressorParameters {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
n_trees: parameters.n_trees,
|
||||
m: parameters.m,
|
||||
keep_samples: parameters.keep_samples,
|
||||
seed: parameters.seed,
|
||||
bootstrap: true,
|
||||
splitter: Splitter::Best,
|
||||
};
|
||||
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
|
||||
|
||||
Ok(RandomForestRegressor {
|
||||
_parameters: parameters,
|
||||
trees,
|
||||
samples: maybe_all_samples,
|
||||
forest_regressor: Some(forest_regressor),
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict class for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.predict_for_row(x, i));
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
let n_trees = self.trees.len();
|
||||
|
||||
let mut result = T::zero();
|
||||
|
||||
for tree in self.trees.iter() {
|
||||
result += tree.predict_for_row(x, row);
|
||||
}
|
||||
|
||||
result / T::from(n_trees).unwrap()
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||
forest_regressor.predict(x)
|
||||
}
|
||||
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Need samples=true for OOB predictions.",
|
||||
))
|
||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.predict_for_row_oob(x, i));
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
let mut n_trees = 0;
|
||||
let mut result = T::zero();
|
||||
|
||||
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||
if !samples[row] {
|
||||
result += tree.predict_for_row(x, row);
|
||||
n_trees += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: What to do if there are no oob trees?
|
||||
result / T::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let mut samples = vec![0; nrows];
|
||||
for _ in 0..nrows {
|
||||
let xi = rng.gen_range(0..nrows);
|
||||
samples[xi] += 1;
|
||||
}
|
||||
samples
|
||||
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||
forest_regressor.predict_oob(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RandomForestRegressorSearchParameters {
|
||||
n_trees: vec![10, 100],
|
||||
m: vec![None, Some(1)],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, Some(1));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, Some(1));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -322,7 +468,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -332,7 +479,7 @@ mod tests {
|
||||
&x,
|
||||
&y,
|
||||
RandomForestRegressorParameters {
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
@@ -347,7 +494,36 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_random_matrix_with_wrong_rownum() {
|
||||
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let fail = RandomForestRegressor::fit(
|
||||
&x_rand,
|
||||
&y,
|
||||
RandomForestRegressorParameters {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(fail.is_err());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_predict_longley_oob() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -367,7 +543,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -377,7 +554,7 @@ mod tests {
|
||||
&x,
|
||||
&y,
|
||||
RandomForestRegressorParameters {
|
||||
max_depth: None,
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
@@ -391,10 +568,16 @@ mod tests {
|
||||
let y_hat = regressor.predict(&x).unwrap();
|
||||
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
||||
|
||||
println!("{:?}", mean_absolute_error(&y, &y_hat));
|
||||
println!("{:?}", mean_absolute_error(&y, &y_hat_oob));
|
||||
|
||||
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
@@ -415,7 +598,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -423,7 +607,7 @@ mod tests {
|
||||
|
||||
let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_forest: RandomForestRegressor<f64> =
|
||||
let deserialized_forest: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
+23
-1
@@ -30,6 +30,10 @@ pub enum FailedError {
|
||||
DecompositionFailed,
|
||||
/// Can't solve for x
|
||||
SolutionFailed,
|
||||
/// Error in input parameters
|
||||
ParametersError,
|
||||
/// Invalid state error (should never happen)
|
||||
InvalidStateError,
|
||||
}
|
||||
|
||||
impl Failed {
|
||||
@@ -62,6 +66,22 @@ impl Failed {
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `FailedError::ParametersError`
|
||||
pub fn input(msg: &str) -> Self {
|
||||
Failed {
|
||||
err: FailedError::ParametersError,
|
||||
msg: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `FailedError::InvalidStateError`
|
||||
pub fn invalid_state(msg: &str) -> Self {
|
||||
Failed {
|
||||
err: FailedError::InvalidStateError,
|
||||
msg: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `err`
|
||||
pub fn because(err: FailedError, msg: &str) -> Self {
|
||||
Failed {
|
||||
@@ -94,8 +114,10 @@ impl fmt::Display for FailedError {
|
||||
FailedError::FindFailed => "Find failed",
|
||||
FailedError::DecompositionFailed => "Decomposition failed",
|
||||
FailedError::SolutionFailed => "Can't find solution",
|
||||
FailedError::ParametersError => "Error in input, check parameters",
|
||||
FailedError::InvalidStateError => "Invalid state, this should never happen", // useful in development phase of lib
|
||||
};
|
||||
write!(f, "{}", failed_err_str)
|
||||
write!(f, "{failed_err_str}")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+78
-44
@@ -3,32 +3,81 @@
|
||||
clippy::too_many_arguments,
|
||||
clippy::many_single_char_names,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::upper_case_acronyms
|
||||
clippy::upper_case_acronyms,
|
||||
clippy::approx_constant
|
||||
)]
|
||||
#![warn(missing_docs)]
|
||||
#![warn(rustdoc::missing_doc_code_examples)]
|
||||
|
||||
//! # SmartCore
|
||||
//! # smartcore
|
||||
//!
|
||||
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
|
||||
//! Welcome to `smartcore`, machine learning in Rust!
|
||||
//!
|
||||
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
|
||||
//! `smartcore` features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
|
||||
//! as well as tools for model selection and model evaluation.
|
||||
//!
|
||||
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
|
||||
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
|
||||
//! * [ndarray](https://docs.rs/ndarray)
|
||||
//! * [nalgebra](https://docs.rs/nalgebra/)
|
||||
//! `smartcore` provides its own traits system that extends Rust standard library, to deal with linear algebra and common
|
||||
//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
|
||||
//! structures) is available via optional features.
|
||||
//!
|
||||
//! ## Getting Started
|
||||
//!
|
||||
//! To start using SmartCore simply add the following to your Cargo.toml file:
|
||||
//! To start using `smartcore` latest stable version simply add the following to your `Cargo.toml` file:
|
||||
//! ```ignore
|
||||
//! [dependencies]
|
||||
//! smartcore = "0.2.0"
|
||||
//! smartcore = "*"
|
||||
//! ```
|
||||
//!
|
||||
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
||||
//! To start using smartcore development version with latest unstable additions:
|
||||
//! ```ignore
|
||||
//! [dependencies]
|
||||
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
|
||||
//! ```
|
||||
//!
|
||||
//! There are different features that can be added to the base library, for example to add sample datasets:
|
||||
//! ```ignore
|
||||
//! [dependencies]
|
||||
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", features = ["datasets"] }
|
||||
//! ```
|
||||
//! Check `smartcore`'s `Cargo.toml` for available features.
|
||||
//!
|
||||
//! ## Using Jupyter
|
||||
//! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
|
||||
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
|
||||
//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
|
||||
//!
|
||||
//!
|
||||
//! ## First Example
|
||||
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
|
||||
//!
|
||||
//! ```
|
||||
//! // DenseMatrix definition
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! // KNNClassifier
|
||||
//! use smartcore::neighbors::knn_classifier::*;
|
||||
//! // Various distance metrics
|
||||
//! use smartcore::metrics::distance::*;
|
||||
//!
|
||||
//! // Turn Rust vector-slices with samples into a matrix
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[1., 2.],
|
||||
//! &[3., 4.],
|
||||
//! &[5., 6.],
|
||||
//! &[7., 8.],
|
||||
//! &[9., 10.]]).unwrap();
|
||||
//! // Our classes are defined as a vector
|
||||
//! let y = vec![2, 2, 2, 3, 3];
|
||||
//!
|
||||
//! // Train classifier
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Predict classes
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## Overview
|
||||
//!
|
||||
//! ### Supported algorithms
|
||||
//! All machine learning algorithms are grouped into these broad categories:
|
||||
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
||||
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
||||
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
||||
@@ -38,37 +87,16 @@
|
||||
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
|
||||
//! * [SVM](svm/index.html), support vector machines
|
||||
//!
|
||||
//!
|
||||
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
|
||||
//!
|
||||
//! ```
|
||||
//! // DenseMatrix defenition
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! // KNNClassifier
|
||||
//! use smartcore::neighbors::knn_classifier::*;
|
||||
//! // Various distance metrics
|
||||
//! use smartcore::math::distance::*;
|
||||
//!
|
||||
//! // Turn Rust vectors with samples into a matrix
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[1., 2.],
|
||||
//! &[3., 4.],
|
||||
//! &[5., 6.],
|
||||
//! &[7., 8.],
|
||||
//! &[9., 10.]]);
|
||||
//! // Our classes are defined as a Vector
|
||||
//! let y = vec![2., 2., 2., 3., 3.];
|
||||
//!
|
||||
//! // Train classifier
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Predict classes
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//! ### Linear Algebra traits system
|
||||
//! For an introduction to `smartcore`'s traits system see [this notebook](https://github.com/smartcorelib/smartcore-jupyter/blob/5523993c53c6ec1fd72eea130ef4e7883121c1ea/notebooks/01-A-little-bit-about-numbers.ipynb)
|
||||
|
||||
/// Various algorithms and helper methods that are used elsewhere in SmartCore
|
||||
/// Foundamental numbers traits
|
||||
pub mod numbers;
|
||||
|
||||
/// Various algorithms and helper methods that are used elsewhere in smartcore
|
||||
pub mod algorithm;
|
||||
pub mod api;
|
||||
|
||||
/// Algorithms for clustering of unlabeled data
|
||||
pub mod cluster;
|
||||
/// Various datasets
|
||||
@@ -79,23 +107,29 @@ pub mod decomposition;
|
||||
/// Ensemble methods, including Random Forest classifier and regressor
|
||||
pub mod ensemble;
|
||||
pub mod error;
|
||||
/// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms
|
||||
/// Diverse collection of linear algebra abstractions and methods that power smartcore algorithms
|
||||
pub mod linalg;
|
||||
/// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables.
|
||||
pub mod linear;
|
||||
/// Helper methods and classes, including definitions of distance metrics
|
||||
pub mod math;
|
||||
/// Functions for assessing prediction error.
|
||||
pub mod metrics;
|
||||
/// TODO: add docstring for model_selection
|
||||
pub mod model_selection;
|
||||
/// Supervised learning algorithms based on applying the Bayes theorem with the independence assumptions between predictors
|
||||
pub mod naive_bayes;
|
||||
/// Supervised neighbors-based learning methods
|
||||
pub mod neighbors;
|
||||
pub(crate) mod optimization;
|
||||
/// Optimization procedures
|
||||
pub mod optimization;
|
||||
/// Preprocessing utilities
|
||||
pub mod preprocessing;
|
||||
/// Reading in data from serialized formats
|
||||
#[cfg(feature = "serde")]
|
||||
pub mod readers;
|
||||
/// Support Vector Machines
|
||||
pub mod svm;
|
||||
/// Supervised tree-based learning methods
|
||||
pub mod tree;
|
||||
pub mod xgboost;
|
||||
|
||||
pub(crate) mod rand_custom;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,845 @@
|
||||
use std::fmt;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::Range;
|
||||
use std::slice::Iter;
|
||||
|
||||
use approx::{AbsDiffEq, RelativeEq};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::basic::arrays::{
|
||||
Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
|
||||
};
|
||||
use crate::linalg::traits::cholesky::CholeskyDecomposable;
|
||||
use crate::linalg::traits::evd::EVDDecomposable;
|
||||
use crate::linalg::traits::lu::LUDecomposable;
|
||||
use crate::linalg::traits::qr::QRDecomposable;
|
||||
use crate::linalg::traits::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use crate::error::Failed;
|
||||
|
||||
/// Dense matrix
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DenseMatrix<T> {
|
||||
ncols: usize,
|
||||
nrows: usize,
|
||||
values: Vec<T>,
|
||||
column_major: bool,
|
||||
}
|
||||
|
||||
/// View on dense matrix
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
|
||||
values: &'a [T],
|
||||
stride: usize,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
column_major: bool,
|
||||
}
|
||||
|
||||
/// Mutable view on dense matrix
|
||||
#[derive(Debug)]
|
||||
pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
|
||||
values: &'a mut [T],
|
||||
stride: usize,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
column_major: bool,
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
|
||||
fn new(
|
||||
m: &'a DenseMatrix<T>,
|
||||
vrows: Range<usize>,
|
||||
vcols: Range<usize>,
|
||||
) -> Result<Self, Failed> {
|
||||
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
|
||||
Err(Failed::input(
|
||||
"The specified view is outside of the matrix range",
|
||||
))
|
||||
} else {
|
||||
let (start, end, stride) =
|
||||
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
|
||||
|
||||
Ok(DenseMatrixView {
|
||||
values: &m.values[start..end],
|
||||
stride,
|
||||
nrows: vrows.end - vrows.start,
|
||||
ncols: vcols.end - vcols.start,
|
||||
column_major: m.column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(
|
||||
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
|
||||
),
|
||||
_ => Box::new(
|
||||
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"DenseMatrix: nrows: {:?}, ncols: {:?}",
|
||||
self.nrows, self.ncols
|
||||
)?;
|
||||
writeln!(f, "column_major: {:?}", self.column_major)?;
|
||||
self.display(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
|
||||
fn new(
|
||||
m: &'a mut DenseMatrix<T>,
|
||||
vrows: Range<usize>,
|
||||
vcols: Range<usize>,
|
||||
) -> Result<Self, Failed> {
|
||||
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
|
||||
Err(Failed::input(
|
||||
"The specified view is outside of the matrix range",
|
||||
))
|
||||
} else {
|
||||
let (start, end, stride) =
|
||||
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
|
||||
|
||||
Ok(DenseMatrixMutView {
|
||||
values: &mut m.values[start..end],
|
||||
stride,
|
||||
nrows: vrows.end - vrows.start,
|
||||
ncols: vcols.end - vcols.start,
|
||||
column_major: m.column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(
|
||||
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
|
||||
),
|
||||
_ => Box::new(
|
||||
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
let column_major = self.column_major;
|
||||
let stride = self.stride;
|
||||
let ptr = self.values.as_mut_ptr();
|
||||
match axis {
|
||||
0 => Box::new((0..self.nrows).flat_map(move |r| {
|
||||
(0..self.ncols).map(move |c| unsafe {
|
||||
&mut *ptr.add(if column_major {
|
||||
r + c * stride
|
||||
} else {
|
||||
r * stride + c
|
||||
})
|
||||
})
|
||||
})),
|
||||
_ => Box::new((0..self.ncols).flat_map(move |c| {
|
||||
(0..self.nrows).map(move |r| unsafe {
|
||||
&mut *ptr.add(if column_major {
|
||||
r + c * stride
|
||||
} else {
|
||||
r * stride + c
|
||||
})
|
||||
})
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"DenseMatrix: nrows: {:?}, ncols: {:?}",
|
||||
self.nrows, self.ncols
|
||||
)?;
|
||||
writeln!(f, "column_major: {:?}", self.column_major)?;
|
||||
self.display(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
|
||||
/// Create new instance of `DenseMatrix` without copying data.
|
||||
/// `values` should be in column-major order.
|
||||
pub fn new(
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
values: Vec<T>,
|
||||
column_major: bool,
|
||||
) -> Result<Self, Failed> {
|
||||
let data_len = values.len();
|
||||
if nrows * ncols != values.len() {
|
||||
Err(Failed::input(&format!(
|
||||
"The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
|
||||
)))
|
||||
} else {
|
||||
Ok(DenseMatrix {
|
||||
ncols,
|
||||
nrows,
|
||||
values,
|
||||
column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d array.
|
||||
pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
|
||||
DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d vector.
|
||||
#[allow(clippy::ptr_arg)]
|
||||
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
|
||||
if values.is_empty() || values[0].is_empty() {
|
||||
Err(Failed::input(
|
||||
"The 2d vec provided is empty; cannot instantiate the matrix",
|
||||
))
|
||||
} else {
|
||||
let nrows = values.len();
|
||||
let ncols = values
|
||||
.first()
|
||||
.unwrap_or_else(|| {
|
||||
panic!("Invalid state: Cannot create 2d matrix from an empty vector")
|
||||
})
|
||||
.len();
|
||||
let mut m_values = Vec::with_capacity(nrows * ncols);
|
||||
|
||||
for c in 0..ncols {
|
||||
for r in values.iter().take(nrows) {
|
||||
m_values.push(r[c])
|
||||
}
|
||||
}
|
||||
|
||||
DenseMatrix::new(nrows, ncols, m_values, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterate over values of matrix
|
||||
pub fn iter(&self) -> Iter<'_, T> {
|
||||
self.values.iter()
|
||||
}
|
||||
|
||||
/// Check if the size of the requested view is bounded to matrix rows/cols count
|
||||
fn is_valid_view(
|
||||
&self,
|
||||
n_rows: usize,
|
||||
n_cols: usize,
|
||||
vrows: &Range<usize>,
|
||||
vcols: &Range<usize>,
|
||||
) -> bool {
|
||||
!(vrows.end <= n_rows
|
||||
&& vcols.end <= n_cols
|
||||
&& vrows.start <= n_rows
|
||||
&& vcols.start <= n_cols)
|
||||
}
|
||||
|
||||
/// Compute the range of the requested view: start, end, size of the slice
|
||||
fn stride_range(
|
||||
&self,
|
||||
n_rows: usize,
|
||||
n_cols: usize,
|
||||
vrows: &Range<usize>,
|
||||
vcols: &Range<usize>,
|
||||
column_major: bool,
|
||||
) -> (usize, usize, usize) {
|
||||
let (start, end, stride) = if column_major {
|
||||
(
|
||||
vrows.start + vcols.start * n_rows,
|
||||
vrows.end + (vcols.end - 1) * n_rows,
|
||||
n_rows,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
vrows.start * n_cols + vcols.start,
|
||||
(vrows.end - 1) * n_cols + vcols.end,
|
||||
n_cols,
|
||||
)
|
||||
};
|
||||
(start, end, stride)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"DenseMatrix: nrows: {:?}, ncols: {:?}",
|
||||
self.nrows, self.ncols
|
||||
)?;
|
||||
writeln!(f, "column_major: {:?}", self.column_major)?;
|
||||
self.display(f)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized + PartialEq> PartialEq for DenseMatrix<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||
return false;
|
||||
}
|
||||
|
||||
let len = self.values.len();
|
||||
let other_len = other.values.len();
|
||||
|
||||
if len != other_len {
|
||||
return false;
|
||||
}
|
||||
|
||||
match self.column_major == other.column_major {
|
||||
true => self
|
||||
.values
|
||||
.iter()
|
||||
.zip(other.values.iter())
|
||||
.all(|(&v1, v2)| v1.eq(v2)),
|
||||
false => self
|
||||
.iterator(0)
|
||||
.zip(other.iterator(0))
|
||||
.all(|(&v1, v2)| v1.eq(v2)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber + AbsDiffEq> AbsDiffEq for DenseMatrix<T>
|
||||
where
|
||||
T::Epsilon: Copy,
|
||||
{
|
||||
type Epsilon = T::Epsilon;
|
||||
|
||||
fn default_epsilon() -> T::Epsilon {
|
||||
T::default_epsilon()
|
||||
}
|
||||
|
||||
// equality in differences in absolute values, according to an epsilon
|
||||
fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool {
|
||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||
false
|
||||
} else {
|
||||
self.values
|
||||
.iter()
|
||||
.zip(other.values.iter())
|
||||
.all(|(v1, v2)| T::abs_diff_eq(v1, v2, epsilon))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber + RelativeEq> RelativeEq for DenseMatrix<T>
|
||||
where
|
||||
T::Epsilon: Copy,
|
||||
{
|
||||
fn default_max_relative() -> T::Epsilon {
|
||||
T::default_max_relative()
|
||||
}
|
||||
|
||||
fn relative_eq(&self, other: &Self, epsilon: T::Epsilon, max_relative: T::Epsilon) -> bool {
|
||||
if self.ncols != other.ncols || self.nrows != other.nrows {
|
||||
false
|
||||
} else {
|
||||
self.iterator(0)
|
||||
.zip(other.iterator(0))
|
||||
.all(|(v1, v2)| T::relative_eq(v1, v2, epsilon, max_relative))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
let (row, col) = pos;
|
||||
|
||||
if row >= self.nrows || col >= self.ncols {
|
||||
panic!(
|
||||
"Invalid index ({},{}) for {}x{} matrix",
|
||||
row, col, self.nrows, self.ncols
|
||||
);
|
||||
}
|
||||
if self.column_major {
|
||||
&self.values[col * self.nrows + row]
|
||||
} else {
|
||||
&self.values[col + self.ncols * row]
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows, self.ncols)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.ncols > 0 && self.nrows > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(
|
||||
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
|
||||
),
|
||||
_ => Box::new(
|
||||
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrix<T> {
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
if self.column_major {
|
||||
self.values[pos.1 * self.nrows + pos.0] = x;
|
||||
} else {
|
||||
self.values[pos.1 + pos.0 * self.ncols] = x;
|
||||
}
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
let ptr = self.values.as_mut_ptr();
|
||||
let column_major = self.column_major;
|
||||
let (nrows, ncols) = self.shape();
|
||||
match axis {
|
||||
0 => Box::new((0..self.nrows).flat_map(move |r| {
|
||||
(0..self.ncols).map(move |c| unsafe {
|
||||
&mut *ptr.add(if column_major {
|
||||
r + c * nrows
|
||||
} else {
|
||||
r * ncols + c
|
||||
})
|
||||
})
|
||||
})),
|
||||
_ => Box::new((0..self.ncols).flat_map(move |c| {
|
||||
(0..self.nrows).map(move |r| unsafe {
|
||||
&mut *ptr.add(if column_major {
|
||||
r + c * nrows
|
||||
} else {
|
||||
r * ncols + c
|
||||
})
|
||||
})
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
|
||||
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
|
||||
}
|
||||
|
||||
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
|
||||
}
|
||||
|
||||
fn slice_mut<'a>(
|
||||
&'a mut self,
|
||||
rows: Range<usize>,
|
||||
cols: Range<usize>,
|
||||
) -> Box<dyn MutArrayView2<T> + 'a>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
|
||||
}
|
||||
|
||||
// private function so for now assume infalible
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
|
||||
DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
|
||||
}
|
||||
|
||||
// private function so for now assume infalible
|
||||
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
|
||||
DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
|
||||
}
|
||||
|
||||
fn transpose(&self) -> Self {
|
||||
let mut m = self.clone();
|
||||
m.ncols = self.nrows;
|
||||
m.nrows = self.ncols;
|
||||
m.column_major = !self.column_major;
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber> QRDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> CholeskyDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
if self.column_major {
|
||||
&self.values[pos.0 + pos.1 * self.stride]
|
||||
} else {
|
||||
&self.values[pos.0 * self.stride + pos.1]
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows, self.ncols)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.nrows * self.ncols > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
self.iter(axis)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
if self.nrows == 1 {
|
||||
if self.column_major {
|
||||
&self.values[i * self.stride]
|
||||
} else {
|
||||
&self.values[i]
|
||||
}
|
||||
} else if self.ncols == 1 || (!self.column_major && self.nrows == 1) {
|
||||
if self.column_major {
|
||||
&self.values[i]
|
||||
} else {
|
||||
&self.values[i * self.stride]
|
||||
}
|
||||
} else {
|
||||
panic!("This is neither a column nor a row");
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
if self.nrows == 1 {
|
||||
self.ncols
|
||||
} else if self.ncols == 1 {
|
||||
self.nrows
|
||||
} else {
|
||||
panic!("This is neither a column nor a row");
|
||||
}
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.nrows * self.ncols > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
self.iter(axis)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
if self.column_major {
|
||||
&self.values[pos.0 + pos.1 * self.stride]
|
||||
} else {
|
||||
&self.values[pos.0 * self.stride + pos.1]
|
||||
}
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows, self.ncols)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.nrows * self.ncols > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
self.iter(axis)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
if self.column_major {
|
||||
self.values[pos.0 + pos.1 * self.stride] = x;
|
||||
} else {
|
||||
self.values[pos.0 * self.stride + pos.1] = x;
|
||||
}
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
self.iter_mut(axis)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
|
||||
|
||||
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
|
||||
|
||||
#[cfg(test)]
|
||||
#[warn(clippy::reversed_empty_ranges)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[test]
|
||||
fn test_instantiate_from_2d() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
assert!(x.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_from_2d_empty() {
|
||||
let input: &[&[f64]] = &[&[]];
|
||||
let x = DenseMatrix::from_2d_array(input);
|
||||
assert!(x.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_from_2d_empty2() {
|
||||
let input: &[&[f64]] = &[&[], &[]];
|
||||
let x = DenseMatrix::from_2d_array(input);
|
||||
assert!(x.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view1() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..2, 0..2);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view2() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view3() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 2..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view4() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 3..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view1() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 3..4, 0..3);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view2() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..3, 3..4);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view3() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
let v = DenseMatrixView::new(&x, 0..3, 4..3);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_display() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
|
||||
println!("{}", &x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_row_col() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
|
||||
assert_eq!(15.0, x.get_col(1).sum());
|
||||
assert_eq!(15.0, x.get_row(1).sum());
|
||||
assert_eq!(81.0, x.get_col(1).dot(&(*x.get_row(1))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_row_major() {
|
||||
let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();
|
||||
|
||||
assert_eq!(5, *x.get_col(1).get(1));
|
||||
assert_eq!(7, x.get_col(1).sum());
|
||||
assert_eq!(5, *x.get_row(1).get(1));
|
||||
assert_eq!(15, x.get_row(1).sum());
|
||||
x.slice_mut(0..2, 1..2)
|
||||
.iterator_mut(0)
|
||||
.for_each(|v| *v += 2);
|
||||
assert_eq!(vec![1, 4, 3, 4, 7, 6], *x.values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_slice() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
vec![4, 5, 6],
|
||||
DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
|
||||
);
|
||||
let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
|
||||
assert_eq!(vec![4, 5, 6], second_row);
|
||||
let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
|
||||
assert_eq!(vec![2, 5, 8], second_col);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iter_mut() {
|
||||
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
|
||||
|
||||
assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
|
||||
// add +2 to some elements
|
||||
x.slice_mut(1..2, 0..3)
|
||||
.iterator_mut(0)
|
||||
.for_each(|v| *v += 2);
|
||||
assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
|
||||
// add +1 to some others
|
||||
x.slice_mut(0..3, 1..2)
|
||||
.iterator_mut(0)
|
||||
.for_each(|v| *v += 1);
|
||||
assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);
|
||||
|
||||
// rewrite matrix as indices of values per axis 1 (row-wise)
|
||||
x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
|
||||
assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
|
||||
// rewrite matrix as indices of values per axis 0 (column-wise)
|
||||
x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
|
||||
assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
|
||||
// rewrite some by slice
|
||||
x.slice_mut(0..3, 0..2)
|
||||
.iterator_mut(0)
|
||||
.enumerate()
|
||||
.for_each(|(a, b)| *b = a);
|
||||
assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
|
||||
x.slice_mut(0..2, 0..3)
|
||||
.iterator_mut(1)
|
||||
.enumerate()
|
||||
.for_each(|(a, b)| *b = a);
|
||||
assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_str_array() {
|
||||
let mut x =
|
||||
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
|
||||
x.iterator_mut(0).for_each(|v| *v = "str");
|
||||
assert_eq!(
|
||||
vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
|
||||
x.values
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpose() {
|
||||
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();
|
||||
|
||||
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
|
||||
assert!(x.column_major);
|
||||
|
||||
// transpose
|
||||
let x = x.transpose();
|
||||
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
|
||||
assert!(!x.column_major); // should change column_major
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_iterator() {
|
||||
let data = [1, 2, 3, 4, 5, 6];
|
||||
|
||||
let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);
|
||||
|
||||
// make a vector into a 2x3 matrix.
|
||||
assert_eq!(
|
||||
vec![1, 2, 3, 4, 5, 6],
|
||||
m.values.iter().map(|e| **e).collect::<Vec<i32>>()
|
||||
);
|
||||
assert!(!m.column_major);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_take() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
|
||||
|
||||
println!("{a}");
|
||||
// take column 0 and 2
|
||||
assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
|
||||
println!("{b}");
|
||||
// take rows 0 and 2
|
||||
assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();
|
||||
|
||||
let a = a.abs();
|
||||
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
|
||||
|
||||
let a = a.neg();
|
||||
assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reshape() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
|
||||
.unwrap();
|
||||
|
||||
let a = a.reshape(2, 6, 0);
|
||||
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
|
||||
assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);
|
||||
|
||||
let a = a.reshape(3, 4, 1);
|
||||
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
|
||||
assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eq() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let c = DenseMatrix::from_2d_array(&[
|
||||
&[1. + f32::EPSILON, 2., 3.],
|
||||
&[4., 5., 6. + f32::EPSILON],
|
||||
])
|
||||
.unwrap();
|
||||
let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
|
||||
.unwrap();
|
||||
|
||||
assert!(!relative_eq!(a, b));
|
||||
assert!(!relative_eq!(a, d));
|
||||
assert!(relative_eq!(a, c));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
/// `Array`, `ArrayView` and related multidimensional
|
||||
pub mod arrays;
|
||||
|
||||
/// foundamental implementation for a `DenseMatrix` construct
|
||||
pub mod matrix;
|
||||
|
||||
/// foundamental implementation for 1D constructs
|
||||
pub mod vector;
|
||||
@@ -0,0 +1,348 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::linalg::basic::arrays::{Array, Array1, ArrayView1, MutArray, MutArrayView1};
|
||||
|
||||
/// Provide mutable window on array
|
||||
#[derive(Debug)]
|
||||
pub struct VecMutView<'a, T: Debug + Display + Copy + Sized> {
|
||||
ptr: &'a mut [T],
|
||||
}
|
||||
|
||||
/// Provide window on array
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VecView<'a, T: Debug + Display + Copy + Sized> {
|
||||
ptr: &'a [T],
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for &[T] {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
// NOTE: this panics in case of out of bounds index
|
||||
self[i] = x
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for Vec<T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for &[T] {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for Vec<T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
assert!(
|
||||
range.end <= self.len(),
|
||||
"`range` should be <= {}",
|
||||
self.len()
|
||||
);
|
||||
let view = VecView { ptr: &self[range] };
|
||||
Box::new(view)
|
||||
}
|
||||
|
||||
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
|
||||
assert!(
|
||||
range.end <= self.len(),
|
||||
"`range` should be <= {}",
|
||||
self.len()
|
||||
);
|
||||
let view = VecMutView {
|
||||
ptr: &mut self[range],
|
||||
};
|
||||
Box::new(view)
|
||||
}
|
||||
|
||||
fn fill(len: usize, value: T) -> Self {
|
||||
vec![value; len]
|
||||
}
|
||||
|
||||
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let mut v: Vec<T> = Vec::with_capacity(len);
|
||||
iter.take(len).for_each(|i| v.push(i));
|
||||
v
|
||||
}
|
||||
|
||||
fn from_vec_slice(slice: &[T]) -> Self {
|
||||
let mut v: Vec<T> = Vec::with_capacity(slice.len());
|
||||
slice.iter().for_each(|i| v.push(*i));
|
||||
v
|
||||
}
|
||||
|
||||
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
|
||||
let mut v: Vec<T> = Vec::with_capacity(slice.shape());
|
||||
slice.iterator(0).for_each(|i| v.push(*i));
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self.ptr[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.ptr.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.ptr.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.ptr.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
self.ptr[i] = x;
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.ptr.iter_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self.ptr[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.ptr.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.ptr.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.ptr.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
fn dot_product<T: Number, V: Array1<T>>(v: &V) -> T {
|
||||
let vv = V::zeros(10);
|
||||
let v_s = vv.slice(0..3);
|
||||
|
||||
v_s.dot(v)
|
||||
}
|
||||
|
||||
fn vector_ops<T: Number + PartialOrd, V: Array1<T>>(_: &V) -> T {
|
||||
let v = V::zeros(10);
|
||||
v.max()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_set() {
|
||||
let mut x = vec![1, 2, 3];
|
||||
assert_eq!(3, *x.get(2));
|
||||
x.set(1, 1);
|
||||
assert_eq!(1, *x.get(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_set() {
|
||||
vec![1, 2, 3].set(3, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_get() {
|
||||
vec![1, 2, 3].get(3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_len() {
|
||||
let x = [1, 2, 3];
|
||||
assert_eq!(3, x.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_empty() {
|
||||
assert!(vec![1; 0].is_empty());
|
||||
assert!(!vec![1, 2, 3].is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iterator() {
|
||||
let v: Vec<i32> = vec![1, 2, 3].iterator(0).map(|&v| v * 2).collect();
|
||||
assert_eq!(vec![2, 4, 6], v);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_iterator() {
|
||||
let _ = vec![1, 2, 3].iterator(1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut_iterator() {
|
||||
let mut x = vec![1, 2, 3];
|
||||
x.iterator_mut(0).for_each(|v| *v *= 2);
|
||||
assert_eq!(vec![2, 4, 6], x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_mut_iterator() {
|
||||
let _ = vec![1, 2, 3].iterator_mut(1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice() {
|
||||
let x = vec![1, 2, 3, 4, 5];
|
||||
let x_slice = x.slice(2..3);
|
||||
assert_eq!(1, x_slice.shape());
|
||||
assert_eq!(3, *x_slice.get(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_slice() {
|
||||
vec![1, 2, 3].slice(0..4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut_slice() {
|
||||
let mut x = vec![1, 2, 3, 4, 5];
|
||||
let mut x_slice = x.slice_mut(2..4);
|
||||
x_slice.set(0, 9);
|
||||
assert_eq!(2, x_slice.shape());
|
||||
assert_eq!(9, *x_slice.get(0));
|
||||
assert_eq!(4, *x_slice.get(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_failed_mut_slice() {
|
||||
vec![1, 2, 3].slice_mut(0..4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_init() {
|
||||
assert_eq!(Vec::fill(3, 0), vec![0, 0, 0]);
|
||||
assert_eq!(
|
||||
Vec::from_iterator([0, 1, 2, 3].iter().cloned(), 3),
|
||||
vec![0, 1, 2]
|
||||
);
|
||||
assert_eq!(Vec::from_vec_slice(&[0, 1, 2]), vec![0, 1, 2]);
|
||||
assert_eq!(Vec::from_vec_slice(&[0, 1, 2, 3, 4][2..]), vec![2, 3, 4]);
|
||||
assert_eq!(Vec::from_slice(&vec![1, 2, 3, 4, 5]), vec![1, 2, 3, 4, 5]);
|
||||
assert_eq!(
|
||||
Vec::from_slice(vec![1, 2, 3, 4, 5].slice(0..3).as_ref()),
|
||||
vec![1, 2, 3]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_scalar() {
|
||||
let mut x = vec![1., 2., 3.];
|
||||
|
||||
let mut y = Vec::<f32>::zeros(10);
|
||||
|
||||
y.slice_mut(0..2).add_scalar_mut(1.0);
|
||||
y.sub_scalar(1.0);
|
||||
x.slice_mut(0..2).sub_scalar_mut(2.);
|
||||
|
||||
assert_eq!(vec![-1.0, 0.0, 3.0], x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dot() {
|
||||
let y_i = vec![1, 2, 3];
|
||||
let y = vec![1.0, 2.0, 3.0];
|
||||
|
||||
println!("Regular dot1: {:?}", dot_product(&y));
|
||||
|
||||
let x = vec![4.0, 5.0, 6.0];
|
||||
assert_eq!(32.0, y.slice(0..3).dot(&(*x.slice(0..3))));
|
||||
assert_eq!(32.0, y.slice(0..3).dot(&x));
|
||||
assert_eq!(32.0, y.dot(&x));
|
||||
assert_eq!(14, y_i.dot(&y_i));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_operators() {
|
||||
let mut x: Vec<f32> = Vec::zeros(10);
|
||||
|
||||
x.add_scalar(15.0);
|
||||
{
|
||||
let mut x_s = x.slice_mut(0..5);
|
||||
x_s.add_scalar_mut(1.0);
|
||||
assert_eq!(
|
||||
vec![1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
x_s.iterator(0).copied().collect::<Vec<f32>>()
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(1.0, x.slice(2..3).min());
|
||||
|
||||
assert_eq!(vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vector_ops() {
|
||||
let x = vec![1., 2., 3.];
|
||||
|
||||
vector_ops(&x);
|
||||
}
|
||||
}
|
||||
+7
-783
@@ -1,785 +1,9 @@
|
||||
#![allow(clippy::wrong_self_convention)]
|
||||
//! # Linear Algebra and Matrix Decomposition
|
||||
//!
|
||||
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
|
||||
//!
|
||||
//! Traits [`BaseMatrix`](trait.BaseMatrix.html), [`Matrix`](trait.Matrix.html) and [`BaseVector`](trait.BaseVector.html) define
|
||||
//! abstract methods that can be implemented for any two-dimensional and one-dimentional arrays (matrix and vector).
|
||||
//! Functions from these traits are designed for SmartCore machine learning algorithms and should not be used directly in your code.
|
||||
//! If you still want to use functions from `BaseMatrix`, `Matrix` and `BaseVector` please be aware that methods defined in these
|
||||
//! traits might change in the future.
|
||||
//!
|
||||
//! One reason why linear algebra traits are public is to allow for different types of matrices and vectors to be plugged into SmartCore.
|
||||
//! Once all methods defined in `BaseMatrix`, `Matrix` and `BaseVector` are implemented for your favourite type of matrix and vector you
|
||||
//! should be able to run SmartCore algorithms on it. Please see `nalgebra_bindings` and `ndarray_bindings` modules for an example of how
|
||||
//! it is done for other libraries.
|
||||
//!
|
||||
//! You will also find verious matrix decomposition methods that work for any matrix that extends [`Matrix`](trait.Matrix.html).
|
||||
//! For example, to decompose matrix defined as [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html):
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::svd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.9000, 0.4000, 0.7000],
|
||||
//! &[0.4000, 0.5000, 0.3000],
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = A.svd().unwrap();
|
||||
//!
|
||||
//! let s: Vec<f64> = svd.s;
|
||||
//! let v: DenseMatrix<f64> = svd.V;
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
//! ```
|
||||
/// basic data structures for linear algebra constructs: arrays and views
|
||||
pub mod basic;
|
||||
|
||||
/// traits associated to algebraic constructs
|
||||
pub mod traits;
|
||||
|
||||
pub mod cholesky;
|
||||
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
|
||||
pub mod evd;
|
||||
pub mod high_order;
|
||||
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
|
||||
pub mod lu;
|
||||
/// Dense matrix with column-major order that wraps [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
|
||||
pub mod naive;
|
||||
/// [nalgebra](https://docs.rs/nalgebra/) bindings.
|
||||
#[cfg(feature = "nalgebra-bindings")]
|
||||
pub mod nalgebra_bindings;
|
||||
/// [ndarray](https://docs.rs/ndarray) bindings.
|
||||
#[cfg(feature = "ndarray-bindings")]
|
||||
pub mod ndarray_bindings;
|
||||
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
|
||||
pub mod qr;
|
||||
pub mod stats;
|
||||
/// Singular value decomposition.
|
||||
pub mod svd;
|
||||
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use cholesky::CholeskyDecomposableMatrix;
|
||||
use evd::EVDDecomposableMatrix;
|
||||
use high_order::HighOrderOperations;
|
||||
use lu::LUDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
use stats::{MatrixPreprocessing, MatrixStats};
|
||||
use svd::SVDDecomposableMatrix;
|
||||
|
||||
/// Column or row vector
|
||||
pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||
/// Get an element of a vector
|
||||
/// * `i` - index of an element
|
||||
fn get(&self, i: usize) -> T;
|
||||
|
||||
/// Set an element at `i` to `x`
|
||||
/// * `i` - index of an element
|
||||
/// * `x` - new value
|
||||
fn set(&mut self, i: usize, x: T);
|
||||
|
||||
/// Get number of elevemnt in the vector
|
||||
fn len(&self) -> usize;
|
||||
|
||||
/// Returns true if the vector is empty.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Create a new vector from a &[T]
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a: [f64; 5] = [0., 0.5, 2., 3., 4.];
|
||||
/// let v: Vec<f64> = BaseVector::from_array(&a);
|
||||
/// assert_eq!(v, vec![0., 0.5, 2., 3., 4.]);
|
||||
/// ```
|
||||
fn from_array(f: &[T]) -> Self {
|
||||
let mut v = Self::zeros(f.len());
|
||||
for (i, elem) in f.iter().enumerate() {
|
||||
v.set(i, *elem);
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
/// Return a vector with the elements of the one-dimensional array.
|
||||
fn to_vec(&self) -> Vec<T>;
|
||||
|
||||
/// Create new vector with zeros of size `len`.
|
||||
fn zeros(len: usize) -> Self;
|
||||
|
||||
/// Create new vector with ones of size `len`.
|
||||
fn ones(len: usize) -> Self;
|
||||
|
||||
/// Create new vector of size `len` where each element is set to `value`.
|
||||
fn fill(len: usize, value: T) -> Self;
|
||||
|
||||
/// Vector dot product
|
||||
fn dot(&self, other: &Self) -> T;
|
||||
|
||||
/// Returns True if matrices are element-wise equal within a tolerance `error`.
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool;
|
||||
|
||||
/// Returns [L2 norm] of the vector(https://en.wikipedia.org/wiki/Matrix_norm).
|
||||
fn norm2(&self) -> T;
|
||||
|
||||
/// Returns [vectors norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
|
||||
fn norm(&self, p: T) -> T;
|
||||
|
||||
/// Divide single element of the vector by `x`, write result to original vector.
|
||||
fn div_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Multiply single element of the vector by `x`, write result to original vector.
|
||||
fn mul_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Add single element of the vector to `x`, write result to original vector.
|
||||
fn add_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Subtract `x` from single element of the vector, write result to original vector.
|
||||
fn sub_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Subtract scalar
|
||||
fn sub_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) - x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn add_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) + x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn mul_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) * x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn div_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) / x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Add vectors, element-wise
|
||||
fn add_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract vectors, element-wise
|
||||
fn sub_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply vectors, element-wise
|
||||
fn mul_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide vectors, element-wise
|
||||
fn div_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Add vectors, element-wise, overriding original vector with result.
|
||||
fn add_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Subtract vectors, element-wise, overriding original vector with result.
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Multiply vectors, element-wise, overriding original vector with result.
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Divide vectors, element-wise, overriding original vector with result.
|
||||
fn div_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Add vectors, element-wise
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract vectors, element-wise
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply vectors, element-wise
|
||||
fn mul(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide vectors, element-wise
|
||||
fn div(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Calculates sum of all elements of the vector.
|
||||
fn sum(&self) -> T;
|
||||
|
||||
/// Returns unique values from the vector.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a = vec!(1., 2., 2., -2., -6., -7., 2., 3., 4.);
|
||||
///
|
||||
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
/// ```
|
||||
fn unique(&self) -> Vec<T>;
|
||||
|
||||
/// Computes the arithmetic mean.
|
||||
fn mean(&self) -> T {
|
||||
self.sum() / T::from_usize(self.len()).unwrap()
|
||||
}
|
||||
/// Computes variance.
|
||||
fn var(&self) -> T {
|
||||
let n = self.len();
|
||||
|
||||
let mut mu = T::zero();
|
||||
let mut sum = T::zero();
|
||||
let div = T::from_usize(n).unwrap();
|
||||
for i in 0..n {
|
||||
let xi = self.get(i);
|
||||
mu += xi;
|
||||
sum += xi * xi;
|
||||
}
|
||||
mu /= div;
|
||||
sum / div - mu.powi(2)
|
||||
}
|
||||
/// Computes the standard deviation.
|
||||
fn std(&self) -> T {
|
||||
self.var().sqrt()
|
||||
}
|
||||
|
||||
/// Copies content of `other` vector.
|
||||
fn copy_from(&mut self, other: &Self);
|
||||
|
||||
/// Take elements from an array.
|
||||
fn take(&self, index: &[usize]) -> Self {
|
||||
let n = index.len();
|
||||
|
||||
let mut result = Self::zeros(n);
|
||||
|
||||
for (i, idx) in index.iter().enumerate() {
|
||||
result.set(i, self.get(*idx));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic matrix type.
|
||||
pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
/// Row vector that is associated with this matrix type,
|
||||
/// e.g. if we have an implementation of sparce matrix
|
||||
/// we should have an associated sparce vector type that
|
||||
/// represents a row in this matrix.
|
||||
type RowVector: BaseVector<T> + Clone + Debug;
|
||||
|
||||
/// Transforms row vector `vec` into a 1xM matrix.
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||
|
||||
/// Transforms 1-d matrix of 1xM into a row vector.
|
||||
fn to_row_vector(self) -> Self::RowVector;
|
||||
|
||||
/// Get an element of the matrix.
|
||||
/// * `row` - row number
|
||||
/// * `col` - column number
|
||||
fn get(&self, row: usize, col: usize) -> T;
|
||||
|
||||
/// Get a vector with elements of the `row`'th row
|
||||
/// * `row` - row number
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
|
||||
|
||||
/// Get the `row`'th row
|
||||
/// * `row` - row number
|
||||
fn get_row(&self, row: usize) -> Self::RowVector;
|
||||
|
||||
/// Copies a vector with elements of the `row`'th row into `result`
|
||||
/// * `row` - row number
|
||||
/// * `result` - receiver for the row
|
||||
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
|
||||
|
||||
/// Get a vector with elements of the `col`'th column
|
||||
/// * `col` - column number
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
|
||||
|
||||
/// Copies a vector with elements of the `col`'th column into `result`
|
||||
/// * `col` - column number
|
||||
/// * `result` - receiver for the col
|
||||
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>);
|
||||
|
||||
/// Set an element at `col`, `row` to `x`
|
||||
fn set(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
/// Create an identity matrix of size `size`
|
||||
fn eye(size: usize) -> Self;
|
||||
|
||||
/// Create new matrix with zeros of size `nrows` by `ncols`.
|
||||
fn zeros(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
/// Create new matrix with ones of size `nrows` by `ncols`.
|
||||
fn ones(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
/// Create new matrix of size `nrows` by `ncols` where each element is set to `value`.
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self;
|
||||
|
||||
/// Return the shape of an array.
|
||||
fn shape(&self) -> (usize, usize);
|
||||
|
||||
/// Stack arrays in sequence vertically (row wise).
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[
|
||||
/// &[1., 2., 3., 1., 2.],
|
||||
/// &[4., 5., 6., 3., 4.]
|
||||
/// ]);
|
||||
///
|
||||
/// assert_eq!(a.h_stack(&b), expected);
|
||||
/// ```
|
||||
fn h_stack(&self, other: &Self) -> Self;
|
||||
|
||||
/// Stack arrays in sequence horizontally (column wise).
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
||||
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[
|
||||
/// &[1., 2., 3.],
|
||||
/// &[4., 5., 6.]
|
||||
/// ]);
|
||||
///
|
||||
/// assert_eq!(a.v_stack(&b), expected);
|
||||
/// ```
|
||||
fn v_stack(&self, other: &Self) -> Self;
|
||||
|
||||
/// Matrix product.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[
|
||||
/// &[7., 10.],
|
||||
/// &[15., 22.]
|
||||
/// ]);
|
||||
///
|
||||
/// assert_eq!(a.matmul(&a), expected);
|
||||
/// ```
|
||||
fn matmul(&self, other: &Self) -> Self;
|
||||
|
||||
/// Vector dot product
|
||||
/// Both matrices should be of size _1xM_
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
||||
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
|
||||
///
|
||||
/// assert_eq!(a.dot(&b), 32.);
|
||||
/// ```
|
||||
fn dot(&self, other: &Self) -> T;
|
||||
|
||||
/// Return a slice of the matrix.
|
||||
/// * `rows` - range of rows to return
|
||||
/// * `cols` - range of columns to return
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let m = DenseMatrix::from_2d_array(&[
|
||||
/// &[1., 2., 3., 1.],
|
||||
/// &[4., 5., 6., 3.],
|
||||
/// &[7., 8., 9., 5.]
|
||||
/// ]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
|
||||
/// let result = m.slice(0..2, 1..3);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
|
||||
|
||||
/// Returns True if matrices are element-wise equal within a tolerance `error`.
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool;
|
||||
|
||||
/// Add matrices, element-wise, overriding original matrix with result.
|
||||
fn add_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Subtract matrices, element-wise, overriding original matrix with result.
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Multiply matrices, element-wise, overriding original matrix with result.
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Divide matrices, element-wise, overriding original matrix with result.
|
||||
fn div_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Divide single element of the matrix by `x`, write result to original matrix.
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
/// Multiply single element of the matrix by `x`, write result to original matrix.
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
/// Add single element of the matrix to `x`, write result to original matrix.
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
/// Subtract `x` from single element of the matrix, write result to original matrix.
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: T);
|
||||
|
||||
/// Add matrices, element-wise
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract matrices, element-wise
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply matrices, element-wise
|
||||
fn mul(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide matrices, element-wise
|
||||
fn div(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Add `scalar` to the matrix, override original matrix with result.
|
||||
fn add_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
/// Subtract `scalar` from the elements of matrix, override original matrix with result.
|
||||
fn sub_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
/// Multiply `scalar` by the elements of matrix, override original matrix with result.
|
||||
fn mul_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
/// Divide elements of the matrix by `scalar`, override original matrix with result.
|
||||
fn div_scalar_mut(&mut self, scalar: T) -> &Self;
|
||||
|
||||
/// Add `scalar` to the matrix.
|
||||
fn add_scalar(&self, scalar: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract `scalar` from the elements of matrix.
|
||||
fn sub_scalar(&self, scalar: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply `scalar` by the elements of matrix.
|
||||
fn mul_scalar(&self, scalar: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide elements of the matrix by `scalar`.
|
||||
fn div_scalar(&self, scalar: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_scalar_mut(scalar);
|
||||
r
|
||||
}
|
||||
|
||||
/// Reverse or permute the axes of the matrix, return new matrix.
|
||||
fn transpose(&self) -> Self;
|
||||
|
||||
/// Create new `nrows` by `ncols` matrix and populate it with random samples from a uniform distribution over [0, 1).
|
||||
fn rand(nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
/// Returns [L2 norm](https://en.wikipedia.org/wiki/Matrix_norm).
|
||||
fn norm2(&self) -> T;
|
||||
|
||||
/// Returns [matrix norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
|
||||
fn norm(&self, p: T) -> T;
|
||||
|
||||
/// Returns the average of the matrix columns.
|
||||
fn column_mean(&self) -> Vec<T>;
|
||||
|
||||
/// Numerical negative, element-wise. Overrides original matrix.
|
||||
fn negative_mut(&mut self);
|
||||
|
||||
/// Numerical negative, element-wise.
|
||||
fn negative(&self) -> Self {
|
||||
let mut result = self.clone();
|
||||
result.negative_mut();
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns new matrix of shape `nrows` by `ncols` with data copied from original matrix.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_array(1, 6, &[1., 2., 3., 4., 5., 6.]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[
|
||||
/// &[1., 2., 3.],
|
||||
/// &[4., 5., 6.]
|
||||
/// ]);
|
||||
///
|
||||
/// assert_eq!(a.reshape(2, 3), expected);
|
||||
/// ```
|
||||
fn reshape(&self, nrows: usize, ncols: usize) -> Self;
|
||||
|
||||
/// Copies content of `other` matrix.
|
||||
fn copy_from(&mut self, other: &Self);
|
||||
|
||||
/// Calculate the absolute value element-wise. Overrides original matrix.
|
||||
fn abs_mut(&mut self) -> &Self;
|
||||
|
||||
/// Calculate the absolute value element-wise.
|
||||
fn abs(&self) -> Self {
|
||||
let mut result = self.clone();
|
||||
result.abs_mut();
|
||||
result
|
||||
}
|
||||
|
||||
/// Calculates sum of all elements of the matrix.
|
||||
fn sum(&self) -> T;
|
||||
|
||||
/// Calculates max of all elements of the matrix.
|
||||
fn max(&self) -> T;
|
||||
|
||||
/// Calculates min of all elements of the matrix.
|
||||
fn min(&self) -> T;
|
||||
|
||||
/// Calculates max(|a - b|) of two matrices
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
///
|
||||
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., 4., -5., 6.]);
|
||||
/// let b = DenseMatrix::from_array(2, 3, &[2., 3., 4., 1., 0., -12.]);
|
||||
///
|
||||
/// assert_eq!(a.max_diff(&b), 18.);
|
||||
/// assert_eq!(b.max_diff(&b), 0.);
|
||||
/// ```
|
||||
fn max_diff(&self, other: &Self) -> T {
|
||||
self.sub(other).abs().max()
|
||||
}
|
||||
|
||||
/// Calculates [Softmax function](https://en.wikipedia.org/wiki/Softmax_function). Overrides the matrix with result.
|
||||
fn softmax_mut(&mut self);
|
||||
|
||||
/// Raises elements of the matrix to the power of `p`
|
||||
fn pow_mut(&mut self, p: T) -> &Self;
|
||||
|
||||
/// Returns new matrix with elements raised to the power of `p`
|
||||
fn pow(&mut self, p: T) -> Self {
|
||||
let mut result = self.clone();
|
||||
result.pow_mut(p);
|
||||
result
|
||||
}
|
||||
|
||||
/// Returns the indices of the maximum values in each row.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., -5., -6., -7.]);
|
||||
///
|
||||
/// assert_eq!(a.argmax(), vec![2, 0]);
|
||||
/// ```
|
||||
fn argmax(&self) -> Vec<usize>;
|
||||
|
||||
/// Returns vector with unique values from the matrix.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a = DenseMatrix::from_array(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
|
||||
///
|
||||
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
/// ```
|
||||
fn unique(&self) -> Vec<T>;
|
||||
|
||||
/// Calculates the covariance matrix
|
||||
fn cov(&self) -> Self;
|
||||
|
||||
/// Take elements from an array along an axis.
|
||||
fn take(&self, index: &[usize], axis: u8) -> Self {
|
||||
let (n, p) = self.shape();
|
||||
|
||||
let k = match axis {
|
||||
0 => p,
|
||||
_ => n,
|
||||
};
|
||||
|
||||
let mut result = match axis {
|
||||
0 => Self::zeros(index.len(), p),
|
||||
_ => Self::zeros(n, index.len()),
|
||||
};
|
||||
|
||||
for (i, idx) in index.iter().enumerate() {
|
||||
for j in 0..k {
|
||||
match axis {
|
||||
0 => result.set(i, j, self.get(*idx, j)),
|
||||
_ => result.set(j, i, self.get(j, *idx)),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
/// Take an individual column from the matrix.
|
||||
fn take_column(&self, column_index: usize) -> Self {
|
||||
self.take(&[column_index], 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic matrix with additional mixins like various factorization methods.
|
||||
pub trait Matrix<T: RealNumber>:
|
||||
BaseMatrix<T>
|
||||
+ SVDDecomposableMatrix<T>
|
||||
+ EVDDecomposableMatrix<T>
|
||||
+ QRDecomposableMatrix<T>
|
||||
+ LUDecomposableMatrix<T>
|
||||
+ CholeskyDecomposableMatrix<T>
|
||||
+ MatrixStats<T>
|
||||
+ MatrixPreprocessing<T>
|
||||
+ HighOrderOperations<T>
|
||||
+ PartialEq
|
||||
+ Display
|
||||
{
|
||||
}
|
||||
|
||||
pub(crate) fn row_iter<F: RealNumber, M: BaseMatrix<F>>(m: &M) -> RowIter<'_, F, M> {
|
||||
RowIter {
|
||||
m,
|
||||
pos: 0,
|
||||
max_pos: m.shape().0,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct RowIter<'a, T: RealNumber, M: BaseMatrix<T>> {
|
||||
m: &'a M,
|
||||
pos: usize,
|
||||
max_pos: usize,
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
||||
type Item = Vec<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Vec<T>> {
|
||||
let res = if self.pos < self.max_pos {
|
||||
Some(self.m.get_row_as_vec(self.pos))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
self.pos += 1;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean() {
|
||||
let m = vec![1., 2., 3.];
|
||||
|
||||
assert_eq!(m.mean(), 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn std() {
|
||||
let m = vec![1., 2., 3.];
|
||||
|
||||
assert!((m.std() - 0.81f64).abs() < 1e-2);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn var() {
|
||||
let m = vec![1., 2., 3., 4.];
|
||||
|
||||
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_take() {
|
||||
let m = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn take() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 2.0],
|
||||
&[3.0, 4.0],
|
||||
&[5.0, 6.0],
|
||||
&[7.0, 8.0],
|
||||
&[9.0, 10.0],
|
||||
]);
|
||||
|
||||
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
|
||||
|
||||
let expected_1 = DenseMatrix::from_2d_array(&[
|
||||
&[2.0, 1.0],
|
||||
&[4.0, 3.0],
|
||||
&[6.0, 5.0],
|
||||
&[8.0, 7.0],
|
||||
&[10.0, 9.0],
|
||||
]);
|
||||
|
||||
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
|
||||
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn take_second_column_from_matrix() {
|
||||
let four_columns: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
]);
|
||||
|
||||
let second_column = four_columns.take_column(1);
|
||||
assert_eq!(
|
||||
second_column,
|
||||
DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]),
|
||||
"The second column was not extracted correctly"
|
||||
);
|
||||
}
|
||||
}
|
||||
/// ndarray bindings
|
||||
pub mod ndarray;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,26 +0,0 @@
|
||||
//! # Simple Dense Matrix
|
||||
//!
|
||||
//! Implements [`BaseMatrix`](../../trait.BaseMatrix.html) and [`BaseVector`](../../trait.BaseVector.html) for [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
|
||||
//! Data is stored in dense format with [column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order).
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//!
|
||||
//! // 3x3 matrix
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.9000, 0.4000, 0.7000],
|
||||
//! &[0.4000, 0.5000, 0.3000],
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//!
|
||||
//! // row vector
|
||||
//! let B = DenseMatrix::from_array(1, 3, &[0.9, 0.4, 0.7]);
|
||||
//!
|
||||
//! // column vector
|
||||
//! let C = DenseMatrix::from_vec(3, 1, &vec!(0.9, 0.4, 0.7));
|
||||
//! ```
|
||||
|
||||
/// Add this module to use Dense Matrix
|
||||
pub mod dense_matrix;
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,282 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::linalg::basic::arrays::{
|
||||
Array as BaseArray, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
|
||||
};
|
||||
|
||||
use crate::linalg::traits::cholesky::CholeskyDecomposable;
|
||||
use crate::linalg::traits::evd::EVDDecomposable;
|
||||
use crate::linalg::traits::lu::LUDecomposable;
|
||||
use crate::linalg::traits::qr::QRDecomposable;
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr};
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
|
||||
for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
&self[[pos.0, pos.1]]
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows(), self.ncols())
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(self.iter()),
|
||||
_ => Box::new(
|
||||
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
||||
for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
self[[pos.0, pos.1]] = x
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
let ptr = self.as_mut_ptr();
|
||||
let stride = self.strides();
|
||||
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
|
||||
match axis {
|
||||
0 => Box::new(self.iter_mut()),
|
||||
_ => Box::new((0..self.ncols()).flat_map(move |c| {
|
||||
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
&self[[pos.0, pos.1]]
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows(), self.ncols())
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(self.iter()),
|
||||
_ => Box::new(
|
||||
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array2<T> for ArrayBase<OwnedRepr<T>, Ix2> {
|
||||
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(self.row(row))
|
||||
}
|
||||
|
||||
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(self.column(col))
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
|
||||
Box::new(self.slice(s![rows, cols]))
|
||||
}
|
||||
|
||||
fn slice_mut<'a>(
|
||||
&'a mut self,
|
||||
rows: Range<usize>,
|
||||
cols: Range<usize>,
|
||||
) -> Box<dyn MutArrayView2<T> + 'a>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Box::new(self.slice_mut(s![rows, cols]))
|
||||
}
|
||||
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
|
||||
Array::from_elem([nrows, ncols], value)
|
||||
}
|
||||
|
||||
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
|
||||
let a = Array::from_iter(iter.take(nrows * ncols))
|
||||
.into_shape((nrows, ncols))
|
||||
.unwrap();
|
||||
match axis {
|
||||
0 => a,
|
||||
_ => a.reversed_axes().into_shape((nrows, ncols)).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
fn transpose(&self) -> Self {
|
||||
self.t().to_owned()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber> QRDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: Number + RealNumber> CholeskyDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
&self[[pos.0, pos.1]]
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
(self.nrows(), self.ncols())
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(
|
||||
axis == 1 || axis == 0,
|
||||
"For two dimensional array `axis` should be either 0 or 1"
|
||||
);
|
||||
match axis {
|
||||
0 => Box::new(self.iter()),
|
||||
_ => Box::new(
|
||||
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
self[[pos.0, pos.1]] = x
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
let ptr = self.as_mut_ptr();
|
||||
let stride = self.strides();
|
||||
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
|
||||
match axis {
|
||||
0 => Box::new(self.iter_mut()),
|
||||
_ => Box::new((0..self.ncols()).flat_map(move |c| {
|
||||
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
|
||||
})),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::{arr2, Array2 as NDArray2};
|
||||
|
||||
#[test]
|
||||
fn test_get_set() {
|
||||
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
assert_eq!(*BaseArray::get(&a, (1, 1)), 5);
|
||||
a.set((1, 1), 9);
|
||||
assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iterator() {
|
||||
let a = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
let v: Vec<i32> = a.iterator(0).copied().collect();
|
||||
assert_eq!(v, vec!(1, 2, 3, 4, 5, 6));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut_iterator() {
|
||||
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i);
|
||||
assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]]));
|
||||
a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i);
|
||||
assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice() {
|
||||
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
let x_slice = Array2::slice(&x, 0..2, 1..2);
|
||||
assert_eq!((2, 1), x_slice.shape());
|
||||
let v: Vec<i32> = x_slice.iterator(0).copied().collect();
|
||||
assert_eq!(v, [2, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_iter() {
|
||||
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
let x_slice = Array2::slice(&x, 0..2, 0..3);
|
||||
assert_eq!(
|
||||
x_slice.iterator(0).copied().collect::<Vec<i32>>(),
|
||||
vec![1, 2, 3, 4, 5, 6]
|
||||
);
|
||||
assert_eq!(
|
||||
x_slice.iterator(1).copied().collect::<Vec<i32>>(),
|
||||
vec![1, 4, 2, 5, 3, 6]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_mut_iter() {
|
||||
let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]);
|
||||
{
|
||||
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
|
||||
x_slice
|
||||
.iterator_mut(0)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i);
|
||||
}
|
||||
assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]]));
|
||||
{
|
||||
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
|
||||
x_slice
|
||||
.iterator_mut(1)
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i);
|
||||
}
|
||||
assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_c_from_iterator() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
|
||||
let a: NDArray2<i32> = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0);
|
||||
println!("{a}");
|
||||
let a: NDArray2<i32> = Array2::from_iterator(data.into_iter(), 4, 3, 1);
|
||||
println!("{a}");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
/// matrix bindings
|
||||
pub mod matrix;
|
||||
/// vector bindings
|
||||
pub mod vector;
|
||||
@@ -0,0 +1,184 @@
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::linalg::basic::arrays::{
|
||||
Array as BaseArray, Array1, ArrayView1, MutArray, MutArrayView1,
|
||||
};
|
||||
|
||||
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix1, OwnedRepr};
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
self[i] = x
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
|
||||
fn shape(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() > 0
|
||||
}
|
||||
|
||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
self[i] = x;
|
||||
}
|
||||
|
||||
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
assert!(axis == 0, "For one dimensional array `axis` should == 0");
|
||||
Box::new(self.iter_mut())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
assert!(
|
||||
range.end <= self.len(),
|
||||
"`range` should be <= {}",
|
||||
self.len()
|
||||
);
|
||||
Box::new(self.slice(s![range]))
|
||||
}
|
||||
|
||||
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
|
||||
assert!(
|
||||
range.end <= self.len(),
|
||||
"`range` should be <= {}",
|
||||
self.len()
|
||||
);
|
||||
Box::new(self.slice_mut(s![range]))
|
||||
}
|
||||
|
||||
fn fill(len: usize, value: T) -> Self {
|
||||
Array::from_elem(len, value)
|
||||
}
|
||||
|
||||
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Array::from_iter(iter.take(len))
|
||||
}
|
||||
|
||||
fn from_vec_slice(slice: &[T]) -> Self {
|
||||
Array::from_iter(slice.iter().copied())
|
||||
}
|
||||
|
||||
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
|
||||
Array::from_iter(slice.iterator(0).copied())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use ndarray::arr1;
|
||||
|
||||
#[test]
|
||||
fn test_get_set() {
|
||||
let mut a = arr1(&[1, 2, 3]);
|
||||
|
||||
assert_eq!(*BaseArray::get(&a, 1), 2);
|
||||
a.set(1, 9);
|
||||
assert_eq!(a, arr1(&[1, 9, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iterator() {
|
||||
let a = arr1(&[1, 2, 3]);
|
||||
|
||||
let v: Vec<i32> = a.iterator(0).copied().collect();
|
||||
assert_eq!(v, vec!(1, 2, 3));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut_iterator() {
|
||||
let mut a = arr1(&[1, 2, 3]);
|
||||
|
||||
a.iterator_mut(0).for_each(|v| *v = 1);
|
||||
assert_eq!(a, arr1(&[1, 1, 1]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice() {
|
||||
let x = arr1(&[1, 2, 3, 4, 5]);
|
||||
let x_slice = Array1::slice(&x, 2..3);
|
||||
assert_eq!(1, x_slice.shape());
|
||||
assert_eq!(3, *x_slice.get(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mut_slice() {
|
||||
let mut x = arr1(&[1, 2, 3, 4, 5]);
|
||||
let mut x_slice = Array1::slice_mut(&mut x, 2..4);
|
||||
x_slice.set(0, 9);
|
||||
assert_eq!(2, x_slice.shape());
|
||||
assert_eq!(9, *x_slice.get(0));
|
||||
assert_eq!(4, *x_slice.get(1));
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,207 +0,0 @@
|
||||
//! # Various Statistical Methods
|
||||
//!
|
||||
//! This module provides reference implementations for various statistical functions.
|
||||
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
|
||||
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Defines baseline implementations for various statistical functions
|
||||
pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Computes the arithmetic mean along the specified axis.
|
||||
fn mean(&self, axis: u8) -> Vec<T> {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
let div = T::from_usize(m).unwrap();
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
for j in 0..m {
|
||||
*x_i += match axis {
|
||||
0 => self.get(j, i),
|
||||
_ => self.get(i, j),
|
||||
};
|
||||
}
|
||||
*x_i /= div;
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes variance along the specified axis.
|
||||
fn var(&self, axis: u8) -> Vec<T> {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
let div = T::from_usize(m).unwrap();
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
let mut mu = T::zero();
|
||||
let mut sum = T::zero();
|
||||
for j in 0..m {
|
||||
let a = match axis {
|
||||
0 => self.get(j, i),
|
||||
_ => self.get(i, j),
|
||||
};
|
||||
mu += a;
|
||||
sum += a * a;
|
||||
}
|
||||
mu /= div;
|
||||
*x_i = sum / div - mu.powi(2);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes the standard deviation along the specified axis.
|
||||
fn std(&self, axis: u8) -> Vec<T> {
|
||||
let mut x = self.var(axis);
|
||||
|
||||
let n = match axis {
|
||||
0 => self.shape().1,
|
||||
_ => self.shape().0,
|
||||
};
|
||||
|
||||
for x_i in x.iter_mut().take(n) {
|
||||
*x_i = x_i.sqrt();
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// standardize values by removing the mean and scaling to unit variance
|
||||
fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..m {
|
||||
match axis {
|
||||
0 => self.set(j, i, (self.get(j, i) - mean[i]) / std[i]),
|
||||
_ => self.set(i, j, (self.get(i, j) - mean[i]) / std[i]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines baseline implementations for various matrix processing functions
|
||||
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
/// a.binarize_mut(0.);
|
||||
///
|
||||
/// assert_eq!(a, expected);
|
||||
/// ```
|
||||
|
||||
fn binarize_mut(&mut self, threshold: T) {
|
||||
let (nrows, ncols) = self.shape();
|
||||
for row in 0..nrows {
|
||||
for col in 0..ncols {
|
||||
if self.get(row, col) > threshold {
|
||||
self.set(row, col, T::one());
|
||||
} else {
|
||||
self.set(row, col, T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns new matrix where elements are binarized according to a given threshold.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
///
|
||||
/// assert_eq!(a.binarize(0.), expected);
|
||||
/// ```
|
||||
fn binarize(&self, threshold: T) -> Self {
|
||||
let mut m = self.clone();
|
||||
m.binarize_mut(threshold);
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseVector;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
let expected_0 = vec![4., 5., 6., 3., 4.];
|
||||
let expected_1 = vec![1.8, 4.4, 7.];
|
||||
|
||||
assert_eq!(m.mean(0), expected_0);
|
||||
assert_eq!(m.mean(1), expected_1);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn std() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
let expected_0 = vec![2.44, 2.44, 2.44, 1.63, 1.63];
|
||||
let expected_1 = vec![0.74, 1.01, 1.41];
|
||||
|
||||
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
|
||||
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn var() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
||||
let expected_0 = vec![4., 4., 4., 4.];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn scale() {
|
||||
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]);
|
||||
let expected_1 = DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]);
|
||||
|
||||
{
|
||||
let mut m = m.clone();
|
||||
m.scale_mut(&m.mean(0), &m.std(0), 0);
|
||||
assert!(m.approximate_eq(&expected_0, std::f32::EPSILON));
|
||||
}
|
||||
|
||||
m.scale_mut(&m.mean(1), &m.std(1), 1);
|
||||
assert!(m.approximate_eq(&expected_1, 1e-2));
|
||||
}
|
||||
}
|
||||
@@ -8,14 +8,14 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use crate::smartcore::linalg::cholesky::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::traits::cholesky::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[25., 15., -5.],
|
||||
//! &[15., 18., 0.],
|
||||
//! &[-5., 0., 11.]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let cholesky = A.cholesky().unwrap();
|
||||
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
|
||||
@@ -34,17 +34,18 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Results of Cholesky decomposition.
|
||||
pub struct Cholesky<T: RealNumber, M: BaseMatrix<T>> {
|
||||
pub struct Cholesky<T: Number + RealNumber, M: Array2<T>> {
|
||||
R: M,
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
|
||||
pub(crate) fn new(R: M) -> Cholesky<T, M> {
|
||||
Cholesky { R, t: PhantomData }
|
||||
}
|
||||
@@ -57,7 +58,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if j <= i {
|
||||
R.set(i, j, self.R.get(i, j));
|
||||
R.set((i, j), *self.R.get((i, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -72,7 +73,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if j <= i {
|
||||
R.set(j, i, self.R.get(i, j));
|
||||
R.set((j, i), *self.R.get((i, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -87,25 +88,25 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
if bn != rn {
|
||||
return Err(Failed::because(
|
||||
FailedError::SolutionFailed,
|
||||
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
|
||||
"Can\'t solve Ax = b for x. FloatNumber of rows in b != number of rows in R.",
|
||||
));
|
||||
}
|
||||
|
||||
for k in 0..bn {
|
||||
for j in 0..m {
|
||||
for i in 0..k {
|
||||
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i));
|
||||
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((k, i)));
|
||||
}
|
||||
b.div_element_mut(k, j, self.R.get(k, k));
|
||||
b.div_element_mut((k, j), *self.R.get((k, k)));
|
||||
}
|
||||
}
|
||||
|
||||
for k in (0..bn).rev() {
|
||||
for j in 0..m {
|
||||
for i in k + 1..bn {
|
||||
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k));
|
||||
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((i, k)));
|
||||
}
|
||||
b.div_element_mut(k, j, self.R.get(k, k));
|
||||
b.div_element_mut((k, j), *self.R.get((k, k)));
|
||||
}
|
||||
}
|
||||
Ok(b)
|
||||
@@ -113,7 +114,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
}
|
||||
|
||||
/// Trait that implements Cholesky decomposition routine for any matrix.
|
||||
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
|
||||
/// Compute the Cholesky decomposition of a matrix.
|
||||
fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
|
||||
self.clone().cholesky_mut()
|
||||
@@ -136,13 +137,13 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for k in 0..j {
|
||||
let mut s = T::zero();
|
||||
for i in 0..k {
|
||||
s += self.get(k, i) * self.get(j, i);
|
||||
s += *self.get((k, i)) * *self.get((j, i));
|
||||
}
|
||||
s = (self.get(j, k) - s) / self.get(k, k);
|
||||
self.set(j, k, s);
|
||||
s = (*self.get((j, k)) - s) / *self.get((k, k));
|
||||
self.set((j, k), s);
|
||||
d += s * s;
|
||||
}
|
||||
d = self.get(j, j) - d;
|
||||
d = *self.get((j, j)) - d;
|
||||
|
||||
if d < T::zero() {
|
||||
return Err(Failed::because(
|
||||
@@ -151,7 +152,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
));
|
||||
}
|
||||
|
||||
self.set(j, j, d.sqrt());
|
||||
self.set((j, j), d.sqrt());
|
||||
}
|
||||
|
||||
Ok(Cholesky::new(self))
|
||||
@@ -166,39 +167,50 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cholesky_decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let l =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]])
|
||||
.unwrap();
|
||||
let u =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]])
|
||||
.unwrap();
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
|
||||
assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
|
||||
assert!(cholesky
|
||||
.L()
|
||||
.matmul(&cholesky.U())
|
||||
.abs()
|
||||
.approximate_eq(&a.abs(), 1e-4));
|
||||
assert!(relative_eq!(cholesky.L().abs(), l.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(cholesky.U().abs(), u.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(
|
||||
cholesky.L().matmul(&cholesky.U()).abs(),
|
||||
a.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cholesky_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]).unwrap();
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
assert!(cholesky
|
||||
.solve(b.transpose())
|
||||
.unwrap()
|
||||
.transpose()
|
||||
.approximate_eq(&expected, 1e-4));
|
||||
assert!(relative_eq!(
|
||||
cholesky.solve(b.transpose()).unwrap().transpose(),
|
||||
expected,
|
||||
epsilon = 1e-4
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -12,32 +12,19 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::evd::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::traits::evd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.9000, 0.4000, 0.7000],
|
||||
//! &[0.4000, 0.5000, 0.3000],
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let evd = A.evd(true).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::evd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[-5.0, 2.0],
|
||||
//! &[-7.0, 4.0],
|
||||
//! ]);
|
||||
//!
|
||||
//! let evd = A.evd(false).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/)
|
||||
@@ -48,14 +35,15 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use num::complex::Complex;
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Results of eigen decomposition
|
||||
pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
|
||||
pub struct EVD<T: Number + RealNumber, M: Array2<T>> {
|
||||
/// Real part of eigenvalues.
|
||||
pub d: Vec<T>,
|
||||
/// Imaginary part of eigenvalues.
|
||||
@@ -65,7 +53,7 @@ pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
|
||||
}
|
||||
|
||||
/// Trait that implements EVD decomposition routine for any matrix.
|
||||
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
|
||||
/// Compute the eigen decomposition of a square matrix.
|
||||
/// * `symmetric` - whether the matrix is symmetric
|
||||
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||
@@ -78,7 +66,7 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
|
||||
let (nrows, ncols) = self.shape();
|
||||
if ncols != nrows {
|
||||
panic!("Matrix is not square: {} x {}", nrows, ncols);
|
||||
panic!("Matrix is not square: {nrows} x {ncols}");
|
||||
}
|
||||
|
||||
let n = nrows;
|
||||
@@ -106,14 +94,14 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
sort(&mut d, &mut e, &mut V);
|
||||
}
|
||||
|
||||
Ok(EVD { d, e, V })
|
||||
Ok(EVD { V, d, e })
|
||||
}
|
||||
}
|
||||
|
||||
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = V.shape();
|
||||
for (i, d_i) in d.iter_mut().enumerate().take(n) {
|
||||
*d_i = V.get(n - 1, i);
|
||||
*d_i = *V.get((n - 1, i));
|
||||
}
|
||||
|
||||
for i in (1..n).rev() {
|
||||
@@ -125,9 +113,9 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
if scale == T::zero() {
|
||||
e[i] = d[i - 1];
|
||||
for (j, d_j) in d.iter_mut().enumerate().take(i) {
|
||||
*d_j = V.get(i - 1, j);
|
||||
V.set(i, j, T::zero());
|
||||
V.set(j, i, T::zero());
|
||||
*d_j = *V.get((i - 1, j));
|
||||
V.set((i, j), T::zero());
|
||||
V.set((j, i), T::zero());
|
||||
}
|
||||
} else {
|
||||
for d_k in d.iter_mut().take(i) {
|
||||
@@ -148,11 +136,11 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
|
||||
for j in 0..i {
|
||||
f = d[j];
|
||||
V.set(j, i, f);
|
||||
g = e[j] + V.get(j, j) * f;
|
||||
V.set((j, i), f);
|
||||
g = e[j] + *V.get((j, j)) * f;
|
||||
for k in j + 1..=i - 1 {
|
||||
g += V.get(k, j) * d[k];
|
||||
e[k] += V.get(k, j) * f;
|
||||
g += *V.get((k, j)) * d[k];
|
||||
e[k] += *V.get((k, j)) * f;
|
||||
}
|
||||
e[j] = g;
|
||||
}
|
||||
@@ -169,46 +157,46 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
f = d[j];
|
||||
g = e[j];
|
||||
for k in j..=i - 1 {
|
||||
V.sub_element_mut(k, j, f * e[k] + g * d[k]);
|
||||
V.sub_element_mut((k, j), f * e[k] + g * d[k]);
|
||||
}
|
||||
d[j] = V.get(i - 1, j);
|
||||
V.set(i, j, T::zero());
|
||||
d[j] = *V.get((i - 1, j));
|
||||
V.set((i, j), T::zero());
|
||||
}
|
||||
}
|
||||
d[i] = h;
|
||||
}
|
||||
|
||||
for i in 0..n - 1 {
|
||||
V.set(n - 1, i, V.get(i, i));
|
||||
V.set(i, i, T::one());
|
||||
V.set((n - 1, i), *V.get((i, i)));
|
||||
V.set((i, i), T::one());
|
||||
let h = d[i + 1];
|
||||
if h != T::zero() {
|
||||
for (k, d_k) in d.iter_mut().enumerate().take(i + 1) {
|
||||
*d_k = V.get(k, i + 1) / h;
|
||||
*d_k = *V.get((k, i + 1)) / h;
|
||||
}
|
||||
for j in 0..=i {
|
||||
let mut g = T::zero();
|
||||
for k in 0..=i {
|
||||
g += V.get(k, i + 1) * V.get(k, j);
|
||||
g += *V.get((k, i + 1)) * *V.get((k, j));
|
||||
}
|
||||
for (k, d_k) in d.iter().enumerate().take(i + 1) {
|
||||
V.sub_element_mut(k, j, g * (*d_k));
|
||||
V.sub_element_mut((k, j), g * (*d_k));
|
||||
}
|
||||
}
|
||||
}
|
||||
for k in 0..=i {
|
||||
V.set(k, i + 1, T::zero());
|
||||
V.set((k, i + 1), T::zero());
|
||||
}
|
||||
}
|
||||
for (j, d_j) in d.iter_mut().enumerate().take(n) {
|
||||
*d_j = V.get(n - 1, j);
|
||||
V.set(n - 1, j, T::zero());
|
||||
*d_j = *V.get((n - 1, j));
|
||||
V.set((n - 1, j), T::zero());
|
||||
}
|
||||
V.set(n - 1, n - 1, T::one());
|
||||
V.set((n - 1, n - 1), T::one());
|
||||
e[0] = T::zero();
|
||||
}
|
||||
|
||||
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
fn tql2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = V.shape();
|
||||
for i in 1..n {
|
||||
e[i - 1] = e[i];
|
||||
@@ -277,9 +265,9 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
d[i + 1] = h + s * (c * g + s * d[i]);
|
||||
|
||||
for k in 0..n {
|
||||
h = V.get(k, i + 1);
|
||||
V.set(k, i + 1, s * V.get(k, i) + c * h);
|
||||
V.set(k, i, c * V.get(k, i) - s * h);
|
||||
h = *V.get((k, i + 1));
|
||||
V.set((k, i + 1), s * *V.get((k, i)) + c * h);
|
||||
V.set((k, i), c * *V.get((k, i)) - s * h);
|
||||
}
|
||||
}
|
||||
p = -s * s2 * c3 * el1 * e[l] / dl1;
|
||||
@@ -308,15 +296,15 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
d[k] = d[i];
|
||||
d[i] = p;
|
||||
for j in 0..n {
|
||||
p = V.get(j, i);
|
||||
V.set(j, i, V.get(j, k));
|
||||
V.set(j, k, p);
|
||||
p = *V.get((j, i));
|
||||
V.set((j, i), *V.get((j, k)));
|
||||
V.set((j, k), p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
|
||||
let radix = T::two();
|
||||
let sqrdx = radix * radix;
|
||||
|
||||
@@ -334,8 +322,8 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
let mut c = T::zero();
|
||||
for j in 0..n {
|
||||
if j != i {
|
||||
c += A.get(j, i).abs();
|
||||
r += A.get(i, j).abs();
|
||||
c += A.get((j, i)).abs();
|
||||
r += A.get((i, j)).abs();
|
||||
}
|
||||
}
|
||||
if c != T::zero() && r != T::zero() {
|
||||
@@ -356,10 +344,10 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
g = T::one() / f;
|
||||
*scale_i *= f;
|
||||
for j in 0..n {
|
||||
A.mul_element_mut(i, j, g);
|
||||
A.mul_element_mut((i, j), g);
|
||||
}
|
||||
for j in 0..n {
|
||||
A.mul_element_mut(j, i, f);
|
||||
A.mul_element_mut((j, i), f);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -369,7 +357,7 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
scale
|
||||
}
|
||||
|
||||
fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
fn elmhes<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<usize> {
|
||||
let (n, _) = A.shape();
|
||||
let mut perm = vec![0; n];
|
||||
|
||||
@@ -377,35 +365,31 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
let mut x = T::zero();
|
||||
let mut i = m;
|
||||
for j in m..n {
|
||||
if A.get(j, m - 1).abs() > x.abs() {
|
||||
x = A.get(j, m - 1);
|
||||
if A.get((j, m - 1)).abs() > x.abs() {
|
||||
x = *A.get((j, m - 1));
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
*perm_m = i;
|
||||
if i != m {
|
||||
for j in (m - 1)..n {
|
||||
let swap = A.get(i, j);
|
||||
A.set(i, j, A.get(m, j));
|
||||
A.set(m, j, swap);
|
||||
A.swap((i, j), (m, j));
|
||||
}
|
||||
for j in 0..n {
|
||||
let swap = A.get(j, i);
|
||||
A.set(j, i, A.get(j, m));
|
||||
A.set(j, m, swap);
|
||||
A.swap((j, i), (j, m));
|
||||
}
|
||||
}
|
||||
if x != T::zero() {
|
||||
for i in (m + 1)..n {
|
||||
let mut y = A.get(i, m - 1);
|
||||
let mut y = *A.get((i, m - 1));
|
||||
if y != T::zero() {
|
||||
y /= x;
|
||||
A.set(i, m - 1, y);
|
||||
A.set((i, m - 1), y);
|
||||
for j in m..n {
|
||||
A.sub_element_mut(i, j, y * A.get(m, j));
|
||||
A.sub_element_mut((i, j), y * *A.get((m, j)));
|
||||
}
|
||||
for j in 0..n {
|
||||
A.add_element_mut(j, m, y * A.get(j, i));
|
||||
A.add_element_mut((j, m), y * *A.get((j, i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -415,24 +399,24 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
perm
|
||||
}
|
||||
|
||||
fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
|
||||
fn eltran<T: Number + RealNumber, M: Array2<T>>(A: &M, V: &mut M, perm: &[usize]) {
|
||||
let (n, _) = A.shape();
|
||||
for mp in (1..n - 1).rev() {
|
||||
for k in mp + 1..n {
|
||||
V.set(k, mp, A.get(k, mp - 1));
|
||||
V.set((k, mp), *A.get((k, mp - 1)));
|
||||
}
|
||||
let i = perm[mp];
|
||||
if i != mp {
|
||||
for j in mp..n {
|
||||
V.set(mp, j, V.get(i, j));
|
||||
V.set(i, j, T::zero());
|
||||
V.set((mp, j), *V.get((i, j)));
|
||||
V.set((i, j), T::zero());
|
||||
}
|
||||
V.set(i, mp, T::one());
|
||||
V.set((i, mp), T::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = A.shape();
|
||||
let mut z = T::zero();
|
||||
let mut s = T::zero();
|
||||
@@ -443,7 +427,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
|
||||
for i in 0..n {
|
||||
for j in i32::max(i as i32 - 1, 0)..n as i32 {
|
||||
anorm += A.get(i, j as usize).abs();
|
||||
anorm += A.get((i, j as usize)).abs();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,43 +438,43 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
loop {
|
||||
let mut l = nn;
|
||||
while l > 0 {
|
||||
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
|
||||
s = A.get((l - 1, l - 1)).abs() + A.get((l, l)).abs();
|
||||
if s == T::zero() {
|
||||
s = anorm;
|
||||
}
|
||||
if A.get(l, l - 1).abs() <= T::epsilon() * s {
|
||||
A.set(l, l - 1, T::zero());
|
||||
if A.get((l, l - 1)).abs() <= T::epsilon() * s {
|
||||
A.set((l, l - 1), T::zero());
|
||||
break;
|
||||
}
|
||||
l -= 1;
|
||||
}
|
||||
let mut x = A.get(nn, nn);
|
||||
let mut x = *A.get((nn, nn));
|
||||
if l == nn {
|
||||
d[nn] = x + t;
|
||||
A.set(nn, nn, x + t);
|
||||
A.set((nn, nn), x + t);
|
||||
if nn == 0 {
|
||||
break 'outer;
|
||||
} else {
|
||||
nn -= 1;
|
||||
}
|
||||
} else {
|
||||
let mut y = A.get(nn - 1, nn - 1);
|
||||
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
|
||||
let mut y = *A.get((nn - 1, nn - 1));
|
||||
let mut w = *A.get((nn, nn - 1)) * *A.get((nn - 1, nn));
|
||||
if l == nn - 1 {
|
||||
p = T::half() * (y - x);
|
||||
q = p * p + w;
|
||||
z = q.abs().sqrt();
|
||||
x += t;
|
||||
A.set(nn, nn, x);
|
||||
A.set(nn - 1, nn - 1, y + t);
|
||||
A.set((nn, nn), x);
|
||||
A.set((nn - 1, nn - 1), y + t);
|
||||
if q >= T::zero() {
|
||||
z = p + RealNumber::copysign(z, p);
|
||||
z = p + <T as RealNumber>::copysign(z, p);
|
||||
d[nn - 1] = x + z;
|
||||
d[nn] = x + z;
|
||||
if z != T::zero() {
|
||||
d[nn] = x - w / z;
|
||||
}
|
||||
x = A.get(nn, nn - 1);
|
||||
x = *A.get((nn, nn - 1));
|
||||
s = x.abs() + z.abs();
|
||||
p = x / s;
|
||||
q = z / s;
|
||||
@@ -498,19 +482,19 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
p /= r;
|
||||
q /= r;
|
||||
for j in nn - 1..n {
|
||||
z = A.get(nn - 1, j);
|
||||
A.set(nn - 1, j, q * z + p * A.get(nn, j));
|
||||
A.set(nn, j, q * A.get(nn, j) - p * z);
|
||||
z = *A.get((nn - 1, j));
|
||||
A.set((nn - 1, j), q * z + p * *A.get((nn, j)));
|
||||
A.set((nn, j), q * *A.get((nn, j)) - p * z);
|
||||
}
|
||||
for i in 0..=nn {
|
||||
z = A.get(i, nn - 1);
|
||||
A.set(i, nn - 1, q * z + p * A.get(i, nn));
|
||||
A.set(i, nn, q * A.get(i, nn) - p * z);
|
||||
z = *A.get((i, nn - 1));
|
||||
A.set((i, nn - 1), q * z + p * *A.get((i, nn)));
|
||||
A.set((i, nn), q * *A.get((i, nn)) - p * z);
|
||||
}
|
||||
for i in 0..n {
|
||||
z = V.get(i, nn - 1);
|
||||
V.set(i, nn - 1, q * z + p * V.get(i, nn));
|
||||
V.set(i, nn, q * V.get(i, nn) - p * z);
|
||||
z = *V.get((i, nn - 1));
|
||||
V.set((i, nn - 1), q * z + p * *V.get((i, nn)));
|
||||
V.set((i, nn), q * *V.get((i, nn)) - p * z);
|
||||
}
|
||||
} else {
|
||||
d[nn] = x + p;
|
||||
@@ -531,22 +515,22 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
if its == 10 || its == 20 {
|
||||
t += x;
|
||||
for i in 0..nn + 1 {
|
||||
A.sub_element_mut(i, i, x);
|
||||
A.sub_element_mut((i, i), x);
|
||||
}
|
||||
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
|
||||
y = T::from(0.75).unwrap() * s;
|
||||
x = T::from(0.75).unwrap() * s;
|
||||
w = T::from(-0.4375).unwrap() * s * s;
|
||||
s = A.get((nn, nn - 1)).abs() + A.get((nn - 1, nn - 2)).abs();
|
||||
y = T::from_f64(0.75).unwrap() * s;
|
||||
x = T::from_f64(0.75).unwrap() * s;
|
||||
w = T::from_f64(-0.4375).unwrap() * s * s;
|
||||
}
|
||||
its += 1;
|
||||
let mut m = nn - 2;
|
||||
while m >= l {
|
||||
z = A.get(m, m);
|
||||
z = *A.get((m, m));
|
||||
r = x - z;
|
||||
s = y - z;
|
||||
p = (r * s - w) / A.get(m + 1, m) + A.get(m, m + 1);
|
||||
q = A.get(m + 1, m + 1) - z - r - s;
|
||||
r = A.get(m + 2, m + 1);
|
||||
p = (r * s - w) / *A.get((m + 1, m)) + *A.get((m, m + 1));
|
||||
q = *A.get((m + 1, m + 1)) - z - r - s;
|
||||
r = *A.get((m + 2, m + 1));
|
||||
s = p.abs() + q.abs() + r.abs();
|
||||
p /= s;
|
||||
q /= s;
|
||||
@@ -554,27 +538,27 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
if m == l {
|
||||
break;
|
||||
}
|
||||
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
|
||||
let u = A.get((m, m - 1)).abs() * (q.abs() + r.abs());
|
||||
let v = p.abs()
|
||||
* (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
|
||||
* (A.get((m - 1, m - 1)).abs() + z.abs() + A.get((m + 1, m + 1)).abs());
|
||||
if u <= T::epsilon() * v {
|
||||
break;
|
||||
}
|
||||
m -= 1;
|
||||
}
|
||||
for i in m..nn - 1 {
|
||||
A.set(i + 2, i, T::zero());
|
||||
A.set((i + 2, i), T::zero());
|
||||
if i != m {
|
||||
A.set(i + 2, i - 1, T::zero());
|
||||
A.set((i + 2, i - 1), T::zero());
|
||||
}
|
||||
}
|
||||
for k in m..nn {
|
||||
if k != m {
|
||||
p = A.get(k, k - 1);
|
||||
q = A.get(k + 1, k - 1);
|
||||
p = *A.get((k, k - 1));
|
||||
q = *A.get((k + 1, k - 1));
|
||||
r = T::zero();
|
||||
if k + 1 != nn {
|
||||
r = A.get(k + 2, k - 1);
|
||||
r = *A.get((k + 2, k - 1));
|
||||
}
|
||||
x = p.abs() + q.abs() + r.abs();
|
||||
if x != T::zero() {
|
||||
@@ -583,14 +567,14 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
r /= x;
|
||||
}
|
||||
}
|
||||
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
|
||||
let s = <T as RealNumber>::copysign((p * p + q * q + r * r).sqrt(), p);
|
||||
if s != T::zero() {
|
||||
if k == m {
|
||||
if l != m {
|
||||
A.set(k, k - 1, -A.get(k, k - 1));
|
||||
A.set((k, k - 1), -*A.get((k, k - 1)));
|
||||
}
|
||||
} else {
|
||||
A.set(k, k - 1, -s * x);
|
||||
A.set((k, k - 1), -s * x);
|
||||
}
|
||||
p += s;
|
||||
x = p / s;
|
||||
@@ -599,32 +583,33 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
q /= p;
|
||||
r /= p;
|
||||
for j in k..n {
|
||||
p = A.get(k, j) + q * A.get(k + 1, j);
|
||||
p = *A.get((k, j)) + q * *A.get((k + 1, j));
|
||||
if k + 1 != nn {
|
||||
p += r * A.get(k + 2, j);
|
||||
A.sub_element_mut(k + 2, j, p * z);
|
||||
p += r * *A.get((k + 2, j));
|
||||
A.sub_element_mut((k + 2, j), p * z);
|
||||
}
|
||||
A.sub_element_mut(k + 1, j, p * y);
|
||||
A.sub_element_mut(k, j, p * x);
|
||||
A.sub_element_mut((k + 1, j), p * y);
|
||||
A.sub_element_mut((k, j), p * x);
|
||||
}
|
||||
|
||||
let mmin = if nn < k + 3 { nn } else { k + 3 };
|
||||
for i in 0..mmin + 1 {
|
||||
p = x * A.get(i, k) + y * A.get(i, k + 1);
|
||||
for i in 0..(mmin + 1) {
|
||||
p = x * *A.get((i, k)) + y * *A.get((i, k + 1));
|
||||
if k + 1 != nn {
|
||||
p += z * A.get(i, k + 2);
|
||||
A.sub_element_mut(i, k + 2, p * r);
|
||||
p += z * *A.get((i, k + 2));
|
||||
A.sub_element_mut((i, k + 2), p * r);
|
||||
}
|
||||
A.sub_element_mut(i, k + 1, p * q);
|
||||
A.sub_element_mut(i, k, p);
|
||||
A.sub_element_mut((i, k + 1), p * q);
|
||||
A.sub_element_mut((i, k), p);
|
||||
}
|
||||
for i in 0..n {
|
||||
p = x * V.get(i, k) + y * V.get(i, k + 1);
|
||||
p = x * *V.get((i, k)) + y * *V.get((i, k + 1));
|
||||
if k + 1 != nn {
|
||||
p += z * V.get(i, k + 2);
|
||||
V.sub_element_mut(i, k + 2, p * r);
|
||||
p += z * *V.get((i, k + 2));
|
||||
V.sub_element_mut((i, k + 2), p * r);
|
||||
}
|
||||
V.sub_element_mut(i, k + 1, p * q);
|
||||
V.sub_element_mut(i, k, p);
|
||||
V.sub_element_mut((i, k + 1), p * q);
|
||||
V.sub_element_mut((i, k), p);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -643,14 +628,14 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
let na = nn.wrapping_sub(1);
|
||||
if q == T::zero() {
|
||||
let mut m = nn;
|
||||
A.set(nn, nn, T::one());
|
||||
A.set((nn, nn), T::one());
|
||||
if nn > 0 {
|
||||
let mut i = nn - 1;
|
||||
loop {
|
||||
let w = A.get(i, i) - p;
|
||||
let w = *A.get((i, i)) - p;
|
||||
r = T::zero();
|
||||
for j in m..=nn {
|
||||
r += A.get(i, j) * A.get(j, nn);
|
||||
r += *A.get((i, j)) * *A.get((j, nn));
|
||||
}
|
||||
if e[i] < T::zero() {
|
||||
z = w;
|
||||
@@ -663,23 +648,23 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
if t == T::zero() {
|
||||
t = T::epsilon() * anorm;
|
||||
}
|
||||
A.set(i, nn, -r / t);
|
||||
A.set((i, nn), -r / t);
|
||||
} else {
|
||||
let x = A.get(i, i + 1);
|
||||
let y = A.get(i + 1, i);
|
||||
let x = *A.get((i, i + 1));
|
||||
let y = *A.get((i + 1, i));
|
||||
q = (d[i] - p).powf(T::two()) + e[i].powf(T::two());
|
||||
t = (x * s - z * r) / q;
|
||||
A.set(i, nn, t);
|
||||
A.set((i, nn), t);
|
||||
if x.abs() > z.abs() {
|
||||
A.set(i + 1, nn, (-r - w * t) / x);
|
||||
A.set((i + 1, nn), (-r - w * t) / x);
|
||||
} else {
|
||||
A.set(i + 1, nn, (-s - y * t) / z);
|
||||
A.set((i + 1, nn), (-s - y * t) / z);
|
||||
}
|
||||
}
|
||||
t = A.get(i, nn).abs();
|
||||
t = A.get((i, nn)).abs();
|
||||
if T::epsilon() * t * t > T::one() {
|
||||
for j in i..=nn {
|
||||
A.div_element_mut(j, nn, t);
|
||||
A.div_element_mut((j, nn), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -692,25 +677,25 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
}
|
||||
} else if q < T::zero() {
|
||||
let mut m = na;
|
||||
if A.get(nn, na).abs() > A.get(na, nn).abs() {
|
||||
A.set(na, na, q / A.get(nn, na));
|
||||
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
|
||||
if A.get((nn, na)).abs() > A.get((na, nn)).abs() {
|
||||
A.set((na, na), q / *A.get((nn, na)));
|
||||
A.set((na, nn), -(*A.get((nn, nn)) - p) / *A.get((nn, na)));
|
||||
} else {
|
||||
let temp = Complex::new(T::zero(), -A.get(na, nn))
|
||||
/ Complex::new(A.get(na, na) - p, q);
|
||||
A.set(na, na, temp.re);
|
||||
A.set(na, nn, temp.im);
|
||||
let temp = Complex::new(T::zero(), -*A.get((na, nn)))
|
||||
/ Complex::new(*A.get((na, na)) - p, q);
|
||||
A.set((na, na), temp.re);
|
||||
A.set((na, nn), temp.im);
|
||||
}
|
||||
A.set(nn, na, T::zero());
|
||||
A.set(nn, nn, T::one());
|
||||
A.set((nn, na), T::zero());
|
||||
A.set((nn, nn), T::one());
|
||||
if nn >= 2 {
|
||||
for i in (0..nn - 1).rev() {
|
||||
let w = A.get(i, i) - p;
|
||||
let w = *A.get((i, i)) - p;
|
||||
let mut ra = T::zero();
|
||||
let mut sa = T::zero();
|
||||
for j in m..=nn {
|
||||
ra += A.get(i, j) * A.get(j, na);
|
||||
sa += A.get(i, j) * A.get(j, nn);
|
||||
ra += *A.get((i, j)) * *A.get((j, na));
|
||||
sa += *A.get((i, j)) * *A.get((j, nn));
|
||||
}
|
||||
if e[i] < T::zero() {
|
||||
z = w;
|
||||
@@ -720,11 +705,11 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
m = i;
|
||||
if e[i] == T::zero() {
|
||||
let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
|
||||
A.set(i, na, temp.re);
|
||||
A.set(i, nn, temp.im);
|
||||
A.set((i, na), temp.re);
|
||||
A.set((i, nn), temp.im);
|
||||
} else {
|
||||
let x = A.get(i, i + 1);
|
||||
let y = A.get(i + 1, i);
|
||||
let x = *A.get((i, i + 1));
|
||||
let y = *A.get((i + 1, i));
|
||||
let mut vr =
|
||||
(d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
|
||||
let vi = T::two() * q * (d[i] - p);
|
||||
@@ -736,33 +721,32 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
let temp =
|
||||
Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra)
|
||||
/ Complex::new(vr, vi);
|
||||
A.set(i, na, temp.re);
|
||||
A.set(i, nn, temp.im);
|
||||
A.set((i, na), temp.re);
|
||||
A.set((i, nn), temp.im);
|
||||
if x.abs() > z.abs() + q.abs() {
|
||||
A.set(
|
||||
i + 1,
|
||||
na,
|
||||
(-ra - w * A.get(i, na) + q * A.get(i, nn)) / x,
|
||||
(i + 1, na),
|
||||
(-ra - w * *A.get((i, na)) + q * *A.get((i, nn))) / x,
|
||||
);
|
||||
A.set(
|
||||
i + 1,
|
||||
nn,
|
||||
(-sa - w * A.get(i, nn) - q * A.get(i, na)) / x,
|
||||
(i + 1, nn),
|
||||
(-sa - w * *A.get((i, nn)) - q * *A.get((i, na))) / x,
|
||||
);
|
||||
} else {
|
||||
let temp =
|
||||
Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn))
|
||||
/ Complex::new(z, q);
|
||||
A.set(i + 1, na, temp.re);
|
||||
A.set(i + 1, nn, temp.im);
|
||||
let temp = Complex::new(
|
||||
-r - y * *A.get((i, na)),
|
||||
-s - y * *A.get((i, nn)),
|
||||
) / Complex::new(z, q);
|
||||
A.set((i + 1, na), temp.re);
|
||||
A.set((i + 1, nn), temp.im);
|
||||
}
|
||||
}
|
||||
}
|
||||
t = T::max(A.get(i, na).abs(), A.get(i, nn).abs());
|
||||
t = T::max(A.get((i, na)).abs(), A.get((i, nn)).abs());
|
||||
if T::epsilon() * t * t > T::one() {
|
||||
for j in i..=nn {
|
||||
A.div_element_mut(j, na, t);
|
||||
A.div_element_mut(j, nn, t);
|
||||
A.div_element_mut((j, na), t);
|
||||
A.div_element_mut((j, nn), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -774,31 +758,31 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &
|
||||
for i in 0..n {
|
||||
z = T::zero();
|
||||
for k in 0..=j {
|
||||
z += V.get(i, k) * A.get(k, j);
|
||||
z += *V.get((i, k)) * *A.get((k, j));
|
||||
}
|
||||
V.set(i, j, z);
|
||||
V.set((i, j), z);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
|
||||
fn balbak<T: Number + RealNumber, M: Array2<T>>(V: &mut M, scale: &[T]) {
|
||||
let (n, _) = V.shape();
|
||||
for (i, scale_i) in scale.iter().enumerate().take(n) {
|
||||
for j in 0..n {
|
||||
V.mul_element_mut(i, j, *scale_i);
|
||||
V.mul_element_mut((i, j), *scale_i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
fn sort<T: Number + RealNumber, M: Array2<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
let n = d.len();
|
||||
let mut temp = vec![T::zero(); n];
|
||||
for j in 1..n {
|
||||
let real = d[j];
|
||||
let img = e[j];
|
||||
for (k, temp_k) in temp.iter_mut().enumerate().take(n) {
|
||||
*temp_k = V.get(k, j);
|
||||
*temp_k = *V.get((k, j));
|
||||
}
|
||||
let mut i = j as i32 - 1;
|
||||
while i >= 0 {
|
||||
@@ -808,14 +792,14 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
d[i as usize + 1] = d[i as usize];
|
||||
e[i as usize + 1] = e[i as usize];
|
||||
for k in 0..n {
|
||||
V.set(k, i as usize + 1, V.get(k, i as usize));
|
||||
V.set((k, i as usize + 1), *V.get((k, i as usize)));
|
||||
}
|
||||
i -= 1;
|
||||
}
|
||||
d[(i + 1) as usize] = real;
|
||||
e[(i + 1) as usize] = img;
|
||||
d[i as usize + 1] = real;
|
||||
e[i as usize + 1] = img;
|
||||
for (k, temp_k) in temp.iter().enumerate().take(n) {
|
||||
V.set(k, (i + 1) as usize, *temp_k);
|
||||
V.set((k, i as usize + 1), *temp_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -823,15 +807,21 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_symmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.7000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||
|
||||
@@ -839,26 +829,33 @@ mod tests {
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(true).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values.len() {
|
||||
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
|
||||
}
|
||||
for i in 0..eigen_values.len() {
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
assert!(relative_eq!(
|
||||
eigen_vectors.abs(),
|
||||
evd.V.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
|
||||
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
|
||||
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_asymmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.8000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
|
||||
|
||||
@@ -866,19 +863,25 @@ mod tests {
|
||||
&[0.7178958, 0.05322098, 0.6812010],
|
||||
&[0.3837711, -0.84702111, -0.1494582],
|
||||
&[0.6952105, 0.43984484, -0.7036135],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values.len() {
|
||||
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
|
||||
}
|
||||
for i in 0..eigen_values.len() {
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
assert!(relative_eq!(
|
||||
eigen_vectors.abs(),
|
||||
evd.V.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
|
||||
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
|
||||
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_complex() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -886,7 +889,8 @@ mod tests {
|
||||
&[4.0, -1.0, 1.0, 1.0],
|
||||
&[1.0, 1.0, 3.0, -2.0],
|
||||
&[1.0, 1.0, 4.0, -1.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
|
||||
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
|
||||
@@ -896,16 +900,21 @@ mod tests {
|
||||
&[-0.6707, 0.1059, 0.901, 0.6289],
|
||||
&[0.9159, -0.1378, 0.3816, 0.0806],
|
||||
&[0.6707, 0.1059, 0.901, -0.6289],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
|
||||
for i in 0..eigen_values_d.len() {
|
||||
assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4);
|
||||
assert!(relative_eq!(
|
||||
eigen_vectors.abs(),
|
||||
evd.V.abs(),
|
||||
epsilon = 1e-4
|
||||
));
|
||||
for (i, eigen_values_d_i) in eigen_values_d.iter().enumerate() {
|
||||
assert!((eigen_values_d_i - evd.d[i]).abs() < 1e-4);
|
||||
}
|
||||
for i in 0..eigen_values_e.len() {
|
||||
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
|
||||
for (i, eigen_values_e_i) in eigen_values_e.iter().enumerate() {
|
||||
assert!((eigen_values_e_i - evd.e[i]).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,20 @@
|
||||
//! In this module you will find composite of matrix operations that are used elsewhere
|
||||
//! for improved efficiency.
|
||||
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
/// High order matrix operations.
|
||||
pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait HighOrderOperations<T: Number>: Array2<T> {
|
||||
/// Y = AB
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use smartcore::linalg::high_order::HighOrderOperations;
|
||||
/// use smartcore::linalg::basic::matrix::*;
|
||||
/// use smartcore::linalg::traits::high_order::HighOrderOperations;
|
||||
/// use smartcore::linalg::basic::arrays::Array2;
|
||||
///
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]).unwrap();
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.ab(true, &b, false), expected);
|
||||
/// ```
|
||||
@@ -26,3 +27,7 @@ pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
/* TODO: Add tests */
|
||||
}
|
||||
@@ -11,14 +11,14 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::lu::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::traits::lu::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[1., 2., 3.],
|
||||
//! &[0., 1., 5.],
|
||||
//! &[5., 6., 0.]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let lu = A.lu().unwrap();
|
||||
//! let lower: DenseMatrix<f64> = lu.L();
|
||||
@@ -38,26 +38,27 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
#[derive(Debug, Clone)]
|
||||
/// Result of LU decomposition.
|
||||
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
|
||||
pub struct LU<T: Number + RealNumber, M: Array2<T>> {
|
||||
LU: M,
|
||||
pivot: Vec<usize>,
|
||||
_pivot_sign: i8,
|
||||
#[allow(dead_code)]
|
||||
pivot_sign: i8,
|
||||
singular: bool,
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
|
||||
impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
|
||||
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
||||
let (_, n) = LU.shape();
|
||||
|
||||
let mut singular = false;
|
||||
for j in 0..n {
|
||||
if LU.get(j, j) == T::zero() {
|
||||
if LU.get((j, j)) == &T::zero() {
|
||||
singular = true;
|
||||
break;
|
||||
}
|
||||
@@ -66,7 +67,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
LU {
|
||||
LU,
|
||||
pivot,
|
||||
_pivot_sign,
|
||||
pivot_sign,
|
||||
singular,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
@@ -80,9 +81,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
for i in 0..n_rows {
|
||||
for j in 0..n_cols {
|
||||
match i.cmp(&j) {
|
||||
Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
|
||||
Ordering::Equal => L.set(i, j, T::one()),
|
||||
Ordering::Less => L.set(i, j, T::zero()),
|
||||
Ordering::Greater => L.set((i, j), *self.LU.get((i, j))),
|
||||
Ordering::Equal => L.set((i, j), T::one()),
|
||||
Ordering::Less => L.set((i, j), T::zero()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -98,9 +99,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
for i in 0..n_rows {
|
||||
for j in 0..n_cols {
|
||||
if i <= j {
|
||||
U.set(i, j, self.LU.get(i, j));
|
||||
U.set((i, j), *self.LU.get((i, j)));
|
||||
} else {
|
||||
U.set(i, j, T::zero());
|
||||
U.set((i, j), T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,7 +115,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
let mut piv = M::zeros(n, n);
|
||||
|
||||
for i in 0..n {
|
||||
piv.set(i, self.pivot[i], T::one());
|
||||
piv.set((i, self.pivot[i]), T::one());
|
||||
}
|
||||
|
||||
piv
|
||||
@@ -125,13 +126,13 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
let (m, n) = self.LU.shape();
|
||||
|
||||
if m != n {
|
||||
panic!("Matrix is not square: {}x{}", m, n);
|
||||
panic!("Matrix is not square: {m}x{n}");
|
||||
}
|
||||
|
||||
let mut inv = M::zeros(n, n);
|
||||
|
||||
for i in 0..n {
|
||||
inv.set(i, i, T::one());
|
||||
inv.set((i, i), T::one());
|
||||
}
|
||||
|
||||
self.solve(inv)
|
||||
@@ -142,10 +143,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
let (b_m, b_n) = b.shape();
|
||||
|
||||
if b_m != m {
|
||||
panic!(
|
||||
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
|
||||
m, n, b_m, b_n
|
||||
);
|
||||
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_m} x {b_n}");
|
||||
}
|
||||
|
||||
if self.singular {
|
||||
@@ -156,33 +154,33 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
|
||||
for j in 0..b_n {
|
||||
for i in 0..m {
|
||||
X.set(i, j, b.get(self.pivot[i], j));
|
||||
X.set((i, j), *b.get((self.pivot[i], j)));
|
||||
}
|
||||
}
|
||||
|
||||
for k in 0..n {
|
||||
for i in k + 1..n {
|
||||
for j in 0..b_n {
|
||||
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
|
||||
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k in (0..n).rev() {
|
||||
for j in 0..b_n {
|
||||
X.div_element_mut(k, j, self.LU.get(k, k));
|
||||
X.div_element_mut((k, j), *self.LU.get((k, k)));
|
||||
}
|
||||
|
||||
for i in 0..k {
|
||||
for j in 0..b_n {
|
||||
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
|
||||
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for j in 0..b_n {
|
||||
for i in 0..m {
|
||||
b.set(i, j, X.get(i, j));
|
||||
b.set((i, j), *X.get((i, j)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,7 +189,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
/// Trait that implements LU decomposition routine for any matrix.
|
||||
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
|
||||
/// Compute the LU decomposition of a square matrix.
|
||||
fn lu(&self) -> Result<LU<T, Self>, Failed> {
|
||||
self.clone().lu_mut()
|
||||
@@ -209,18 +207,18 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
|
||||
for j in 0..n {
|
||||
for (i, LUcolj_i) in LUcolj.iter_mut().enumerate().take(m) {
|
||||
*LUcolj_i = self.get(i, j);
|
||||
*LUcolj_i = *self.get((i, j));
|
||||
}
|
||||
|
||||
for i in 0..m {
|
||||
let kmax = usize::min(i, j);
|
||||
let mut s = T::zero();
|
||||
for (k, LUcolj_k) in LUcolj.iter().enumerate().take(kmax) {
|
||||
s += self.get(i, k) * (*LUcolj_k);
|
||||
s += *self.get((i, k)) * (*LUcolj_k);
|
||||
}
|
||||
|
||||
LUcolj[i] -= s;
|
||||
self.set(i, j, LUcolj[i]);
|
||||
self.set((i, j), LUcolj[i]);
|
||||
}
|
||||
|
||||
let mut p = j;
|
||||
@@ -231,17 +229,15 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
if p != j {
|
||||
for k in 0..n {
|
||||
let t = self.get(p, k);
|
||||
self.set(p, k, self.get(j, k));
|
||||
self.set(j, k, t);
|
||||
self.swap((p, k), (j, k));
|
||||
}
|
||||
piv.swap(p, j);
|
||||
pivsign = -pivsign;
|
||||
}
|
||||
|
||||
if j < m && self.get(j, j) != T::zero() {
|
||||
if j < m && self.get((j, j)) != &T::zero() {
|
||||
for i in j + 1..m {
|
||||
self.div_element_mut(i, j, self.get(j, j));
|
||||
self.div_element_mut((i, j), *self.get((j, j)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -258,30 +254,38 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
|
||||
let expected_L =
|
||||
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]).unwrap();
|
||||
let expected_U =
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]).unwrap();
|
||||
let expected_pivot =
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]).unwrap();
|
||||
let lu = a.lu().unwrap();
|
||||
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
assert!(relative_eq!(lu.L(), expected_L, epsilon = 1e-4));
|
||||
assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4));
|
||||
assert!(relative_eq!(lu.pivot(), expected_pivot, epsilon = 1e-4));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn inverse() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
|
||||
let expected =
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]])
|
||||
.unwrap();
|
||||
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
assert!(a_inv.approximate_eq(&expected, 1e-4));
|
||||
assert!(relative_eq!(a_inv, expected, epsilon = 1e-4));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
#![allow(clippy::wrong_self_convention)]
|
||||
|
||||
pub mod cholesky;
|
||||
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
|
||||
pub mod evd;
|
||||
pub mod high_order;
|
||||
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
|
||||
pub mod lu;
|
||||
|
||||
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
|
||||
pub mod qr;
|
||||
/// statistacal tools for DenseMatrix
|
||||
pub mod stats;
|
||||
/// Singular value decomposition.
|
||||
pub mod svd;
|
||||
@@ -6,14 +6,14 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::qr::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::traits::qr::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.9, 0.4, 0.7],
|
||||
//! &[0.4, 0.5, 0.3],
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let qr = A.qr().unwrap();
|
||||
//! let orthogonal: DenseMatrix<f64> = qr.Q();
|
||||
@@ -28,20 +28,22 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Results of QR decomposition.
|
||||
pub struct QR<T: RealNumber, M: BaseMatrix<T>> {
|
||||
pub struct QR<T: Number + RealNumber, M: Array2<T>> {
|
||||
QR: M,
|
||||
tau: Vec<T>,
|
||||
singular: bool,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
|
||||
pub(crate) fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
|
||||
let mut singular = false;
|
||||
for tau_elem in tau.iter() {
|
||||
@@ -59,9 +61,9 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
let (_, n) = self.QR.shape();
|
||||
let mut R = M::zeros(n, n);
|
||||
for i in 0..n {
|
||||
R.set(i, i, self.tau[i]);
|
||||
R.set((i, i), self.tau[i]);
|
||||
for j in i + 1..n {
|
||||
R.set(i, j, self.QR.get(i, j));
|
||||
R.set((i, j), *self.QR.get((i, j)));
|
||||
}
|
||||
}
|
||||
R
|
||||
@@ -73,16 +75,16 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
let mut Q = M::zeros(m, n);
|
||||
let mut k = n - 1;
|
||||
loop {
|
||||
Q.set(k, k, T::one());
|
||||
Q.set((k, k), T::one());
|
||||
for j in k..n {
|
||||
if self.QR.get(k, k) != T::zero() {
|
||||
if self.QR.get((k, k)) != &T::zero() {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s += self.QR.get(i, k) * Q.get(i, j);
|
||||
s += *self.QR.get((i, k)) * *Q.get((i, j));
|
||||
}
|
||||
s = -s / self.QR.get(k, k);
|
||||
s = -s / *self.QR.get((k, k));
|
||||
for i in k..m {
|
||||
Q.add_element_mut(i, j, s * self.QR.get(i, k));
|
||||
Q.add_element_mut((i, j), s * *self.QR.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -100,10 +102,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
let (b_nrows, b_ncols) = b.shape();
|
||||
|
||||
if b_nrows != m {
|
||||
panic!(
|
||||
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
|
||||
m, n, b_nrows, b_ncols
|
||||
);
|
||||
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_nrows} x {b_ncols}");
|
||||
}
|
||||
|
||||
if self.singular {
|
||||
@@ -114,23 +113,23 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
for j in 0..b_ncols {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s += self.QR.get(i, k) * b.get(i, j);
|
||||
s += *self.QR.get((i, k)) * *b.get((i, j));
|
||||
}
|
||||
s = -s / self.QR.get(k, k);
|
||||
s = -s / *self.QR.get((k, k));
|
||||
for i in k..m {
|
||||
b.add_element_mut(i, j, s * self.QR.get(i, k));
|
||||
b.add_element_mut((i, j), s * *self.QR.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k in (0..n).rev() {
|
||||
for j in 0..b_ncols {
|
||||
b.set(k, j, b.get(k, j) / self.tau[k]);
|
||||
b.set((k, j), *b.get((k, j)) / self.tau[k]);
|
||||
}
|
||||
|
||||
for i in 0..k {
|
||||
for j in 0..b_ncols {
|
||||
b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k));
|
||||
b.sub_element_mut((i, j), *b.get((k, j)) * *self.QR.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,7 +139,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
}
|
||||
|
||||
/// Trait that implements QR decomposition routine for any matrix.
|
||||
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait QRDecomposable<T: Number + RealNumber>: Array2<T> {
|
||||
/// Compute the QR decomposition of a matrix.
|
||||
fn qr(&self) -> Result<QR<T, Self>, Failed> {
|
||||
self.clone().qr_mut()
|
||||
@@ -156,26 +155,26 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for (k, r_diagonal_k) in r_diagonal.iter_mut().enumerate().take(n) {
|
||||
let mut nrm = T::zero();
|
||||
for i in k..m {
|
||||
nrm = nrm.hypot(self.get(i, k));
|
||||
nrm = nrm.hypot(*self.get((i, k)));
|
||||
}
|
||||
|
||||
if nrm.abs() > T::epsilon() {
|
||||
if self.get(k, k) < T::zero() {
|
||||
if self.get((k, k)) < &T::zero() {
|
||||
nrm = -nrm;
|
||||
}
|
||||
for i in k..m {
|
||||
self.div_element_mut(i, k, nrm);
|
||||
self.div_element_mut((i, k), nrm);
|
||||
}
|
||||
self.add_element_mut(k, k, T::one());
|
||||
self.add_element_mut((k, k), T::one());
|
||||
|
||||
for j in k + 1..n {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s += self.get(i, k) * self.get(i, j);
|
||||
s += *self.get((i, k)) * *self.get((i, j));
|
||||
}
|
||||
s = -s / self.get(k, k);
|
||||
s = -s / *self.get((k, k));
|
||||
for i in k..m {
|
||||
self.add_element_mut(i, j, s * self.get(i, k));
|
||||
self.add_element_mut((i, j), s * *self.get((i, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,37 +193,49 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let q = DenseMatrix::from_2d_array(&[
|
||||
&[-0.7448, 0.2436, 0.6212],
|
||||
&[-0.331, -0.9432, -0.027],
|
||||
&[-0.5793, 0.2257, -0.7832],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let r = DenseMatrix::from_2d_array(&[
|
||||
&[-1.2083, -0.6373, -1.0842],
|
||||
&[0.0, -0.3064, 0.0682],
|
||||
&[0.0, 0.0, -0.1999],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let qr = a.qr().unwrap();
|
||||
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
|
||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||
assert!(relative_eq!(qr.Q().abs(), q.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn qr_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
|
||||
let expected_w = DenseMatrix::from_2d_array(&[
|
||||
&[-0.2027027, -1.2837838],
|
||||
&[0.8783784, 2.2297297],
|
||||
&[0.4729730, 0.6621622],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let w = a.qr_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
//! # Various Statistical Methods
|
||||
//!
|
||||
//! This module provides reference implementations for various statistical functions.
|
||||
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
|
||||
|
||||
//! This methods shall be used when dealing with `DenseMatrix`. Use the ones in `linalg::arrays` for `Array` types.
|
||||
|
||||
use crate::linalg::basic::arrays::{Array2, ArrayView2, MutArrayView2};
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
/// Defines baseline implementations for various statistical functions
|
||||
pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
|
||||
/// Computes the arithmetic mean along the specified axis.
|
||||
fn mean(&self, axis: u8) -> Vec<T> {
|
||||
let (n, _m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
let vec = match axis {
|
||||
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
|
||||
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
|
||||
};
|
||||
*x_i = Self::_mean_of_vector(&vec[..]);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes variance along the specified axis.
|
||||
fn var(&self, axis: u8) -> Vec<T> {
|
||||
let (n, _m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
let vec = match axis {
|
||||
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
|
||||
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
|
||||
};
|
||||
*x_i = Self::_var_of_vec(&vec[..], Option::None);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes the standard deviation along the specified axis.
|
||||
fn std(&self, axis: u8) -> Vec<T> {
|
||||
let mut x = Self::var(self, axis);
|
||||
|
||||
let n = match axis {
|
||||
0 => self.shape().1,
|
||||
_ => self.shape().0,
|
||||
};
|
||||
|
||||
for x_i in x.iter_mut().take(n) {
|
||||
*x_i = x_i.sqrt();
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// <http://en.wikipedia.org/wiki/Arithmetic_mean>
|
||||
/// Taken from `statistical`
|
||||
/// The MIT License (MIT)
|
||||
/// Copyright (c) 2015 Jeff Belgum
|
||||
fn _mean_of_vector(v: &[T]) -> T {
|
||||
let len = num::cast(v.len()).unwrap();
|
||||
v.iter().fold(T::zero(), |acc: T, elem| acc + *elem) / len
|
||||
}
|
||||
|
||||
/// Taken from statistical
|
||||
/// The MIT License (MIT)
|
||||
/// Copyright (c) 2015 Jeff Belgum
|
||||
fn _sum_square_deviations_vec(v: &[T], c: Option<T>) -> T {
|
||||
let c = match c {
|
||||
Some(c) => c,
|
||||
None => Self::_mean_of_vector(v),
|
||||
};
|
||||
|
||||
let sum = v
|
||||
.iter()
|
||||
.map(|x| (*x - c) * (*x - c))
|
||||
.fold(T::zero(), |acc, elem| acc + elem);
|
||||
assert!(sum >= T::zero(), "negative sum of square root deviations");
|
||||
sum
|
||||
}
|
||||
|
||||
/// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
|
||||
/// Taken from statistical
|
||||
/// The MIT License (MIT)
|
||||
/// Copyright (c) 2015 Jeff Belgum
|
||||
fn _var_of_vec(v: &[T], xbar: Option<T>) -> T {
|
||||
assert!(v.len() > 1, "variance requires at least two data points");
|
||||
let len: T = num::cast(v.len()).unwrap();
|
||||
let sum = Self::_sum_square_deviations_vec(v, xbar);
|
||||
sum / len
|
||||
}
|
||||
|
||||
/// standardize values by removing the mean and scaling to unit variance
|
||||
fn standard_scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..m {
|
||||
match axis {
|
||||
0 => self.set((j, i), (*self.get((j, i)) - mean[i]) / std[i]),
|
||||
_ => self.set((i, j), (*self.get((i, j)) - mean[i]) / std[i]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: this is processing. Should have its own "processing.rs" module
|
||||
/// Defines baseline implementations for various matrix processing functions
|
||||
pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
|
||||
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
|
||||
/// ```rust
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
|
||||
/// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
|
||||
/// a.binarize_mut(0.);
|
||||
///
|
||||
/// assert_eq!(a, expected);
|
||||
/// ```
|
||||
fn binarize_mut(&mut self, threshold: T) {
|
||||
let (nrows, ncols) = self.shape();
|
||||
for row in 0..nrows {
|
||||
for col in 0..ncols {
|
||||
if *self.get((row, col)) > threshold {
|
||||
self.set((row, col), T::one());
|
||||
} else {
|
||||
self.set((row, col), T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns new matrix where elements are binarized according to a given threshold.
|
||||
/// ```rust
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.binarize(0.), expected);
|
||||
/// ```
|
||||
fn binarize(self, threshold: T) -> Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
let mut m = self;
|
||||
m.binarize_mut(threshold);
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linalg::basic::arrays::Array1;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::linalg::traits::stats::MatrixStats;
|
||||
|
||||
#[test]
|
||||
fn test_mean() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![4., 5., 6., 3., 4.];
|
||||
let expected_1 = vec![1.8, 4.4, 7.];
|
||||
|
||||
assert_eq!(m.mean(0), expected_0);
|
||||
assert_eq!(m.mean(1), expected_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_var() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
|
||||
let expected_0 = vec![4., 4., 4., 4.];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
assert!(m.var(0).approximate_eq(&expected_0, 1e-6));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, 1e-6));
|
||||
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
|
||||
assert_eq!(m.mean(1), vec![2.5, 6.5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_var_other() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
|
||||
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
assert!(m.var(0).approximate_eq(&expected_0, f64::EPSILON));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, f64::EPSILON));
|
||||
assert_eq!(
|
||||
m.mean(0),
|
||||
vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
|
||||
);
|
||||
assert_eq!(m.mean(1), vec![1.375, 1.375]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_std() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![
|
||||
2.449489742783178,
|
||||
2.449489742783178,
|
||||
2.449489742783178,
|
||||
1.632993161855452,
|
||||
1.632993161855452,
|
||||
];
|
||||
let expected_1 = vec![0.7483314773547883, 1.019803902718557, 1.4142135623730951];
|
||||
|
||||
println!("{:?}", m.var(0));
|
||||
|
||||
assert!(m.std(0).approximate_eq(&expected_0, f64::EPSILON));
|
||||
assert!(m.std(1).approximate_eq(&expected_1, f64::EPSILON));
|
||||
assert_eq!(m.mean(0), vec![4.0, 5.0, 6.0, 3.0, 4.0]);
|
||||
assert_eq!(m.mean(1), vec![1.8, 4.4, 7.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scale() {
|
||||
let m: DenseMatrix<f64> =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
|
||||
|
||||
let expected_0: DenseMatrix<f64> =
|
||||
DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]).unwrap();
|
||||
let expected_1: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[
|
||||
-1.3416407864998738,
|
||||
-0.4472135954999579,
|
||||
0.4472135954999579,
|
||||
1.3416407864998738,
|
||||
],
|
||||
&[
|
||||
-1.3416407864998738,
|
||||
-0.4472135954999579,
|
||||
0.4472135954999579,
|
||||
1.3416407864998738,
|
||||
],
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
|
||||
assert_eq!(m.mean(1), vec![2.5, 6.5]);
|
||||
|
||||
assert_eq!(m.var(0), vec![4., 4., 4., 4.]);
|
||||
assert_eq!(m.var(1), vec![1.25, 1.25]);
|
||||
|
||||
assert_eq!(m.std(0), vec![2., 2., 2., 2.]);
|
||||
assert_eq!(m.std(1), vec![1.118033988749895, 1.118033988749895]);
|
||||
|
||||
{
|
||||
let mut m = m.clone();
|
||||
m.standard_scale_mut(&m.mean(0), &m.std(0), 0);
|
||||
assert_eq!(&m, &expected_0);
|
||||
}
|
||||
|
||||
{
|
||||
let mut m = m;
|
||||
m.standard_scale_mut(&m.mean(1), &m.std(1), 1);
|
||||
assert_eq!(&m, &expected_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,14 +10,14 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::svd::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::traits::svd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.9, 0.4, 0.7],
|
||||
//! &[0.4, 0.5, 0.3],
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let svd = A.svd().unwrap();
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
@@ -34,32 +34,33 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use std::fmt::Debug;
|
||||
|
||||
/// Results of SVD decomposition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
||||
pub struct SVD<T: Number + RealNumber, M: SVDDecomposable<T>> {
|
||||
/// Left-singular vectors of _A_
|
||||
pub U: M,
|
||||
/// Right-singular vectors of _A_
|
||||
pub V: M,
|
||||
/// Singular values of the original matrix
|
||||
pub s: Vec<T>,
|
||||
_full: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
/// Tolerance
|
||||
tol: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
|
||||
/// Diagonal matrix with singular values
|
||||
pub fn S(&self) -> M {
|
||||
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
|
||||
|
||||
for i in 0..self.s.len() {
|
||||
s.set(i, i, self.s[i]);
|
||||
s.set((i, i), self.s[i]);
|
||||
}
|
||||
|
||||
s
|
||||
@@ -67,7 +68,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
}
|
||||
|
||||
/// Trait that implements SVD decomposition routine for any matrix.
|
||||
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
|
||||
/// Solves Ax = b. Overrides original matrix in the process.
|
||||
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.svd_mut().and_then(|svd| svd.solve(b))
|
||||
@@ -106,31 +107,31 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
|
||||
if i < m {
|
||||
for k in i..m {
|
||||
scale += U.get(k, i).abs();
|
||||
scale += U.get((k, i)).abs();
|
||||
}
|
||||
|
||||
if scale.abs() > T::epsilon() {
|
||||
for k in i..m {
|
||||
U.div_element_mut(k, i, scale);
|
||||
s += U.get(k, i) * U.get(k, i);
|
||||
U.div_element_mut((k, i), scale);
|
||||
s += *U.get((k, i)) * *U.get((k, i));
|
||||
}
|
||||
|
||||
let mut f = U.get(i, i);
|
||||
g = -RealNumber::copysign(s.sqrt(), f);
|
||||
let mut f = *U.get((i, i));
|
||||
g = -<T as RealNumber>::copysign(s.sqrt(), f);
|
||||
let h = f * g - s;
|
||||
U.set(i, i, f - g);
|
||||
U.set((i, i), f - g);
|
||||
for j in l - 1..n {
|
||||
s = T::zero();
|
||||
for k in i..m {
|
||||
s += U.get(k, i) * U.get(k, j);
|
||||
s += *U.get((k, i)) * *U.get((k, j));
|
||||
}
|
||||
f = s / h;
|
||||
for k in i..m {
|
||||
U.add_element_mut(k, j, f * U.get(k, i));
|
||||
U.add_element_mut((k, j), f * *U.get((k, i)));
|
||||
}
|
||||
}
|
||||
for k in i..m {
|
||||
U.mul_element_mut(k, i, scale);
|
||||
U.mul_element_mut((k, i), scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,37 +143,37 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
|
||||
if i < m && i + 1 != n {
|
||||
for k in l - 1..n {
|
||||
scale += U.get(i, k).abs();
|
||||
scale += U.get((i, k)).abs();
|
||||
}
|
||||
|
||||
if scale.abs() > T::epsilon() {
|
||||
for k in l - 1..n {
|
||||
U.div_element_mut(i, k, scale);
|
||||
s += U.get(i, k) * U.get(i, k);
|
||||
U.div_element_mut((i, k), scale);
|
||||
s += *U.get((i, k)) * *U.get((i, k));
|
||||
}
|
||||
|
||||
let f = U.get(i, l - 1);
|
||||
g = -RealNumber::copysign(s.sqrt(), f);
|
||||
let f = *U.get((i, l - 1));
|
||||
g = -<T as RealNumber>::copysign(s.sqrt(), f);
|
||||
let h = f * g - s;
|
||||
U.set(i, l - 1, f - g);
|
||||
U.set((i, l - 1), f - g);
|
||||
|
||||
for (k, rv1_k) in rv1.iter_mut().enumerate().take(n).skip(l - 1) {
|
||||
*rv1_k = U.get(i, k) / h;
|
||||
*rv1_k = *U.get((i, k)) / h;
|
||||
}
|
||||
|
||||
for j in l - 1..m {
|
||||
s = T::zero();
|
||||
for k in l - 1..n {
|
||||
s += U.get(j, k) * U.get(i, k);
|
||||
s += *U.get((j, k)) * *U.get((i, k));
|
||||
}
|
||||
|
||||
for (k, rv1_k) in rv1.iter().enumerate().take(n).skip(l - 1) {
|
||||
U.add_element_mut(j, k, s * (*rv1_k));
|
||||
U.add_element_mut((j, k), s * (*rv1_k));
|
||||
}
|
||||
}
|
||||
|
||||
for k in l - 1..n {
|
||||
U.mul_element_mut(i, k, scale);
|
||||
U.mul_element_mut((i, k), scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,24 +185,24 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
if i < n - 1 {
|
||||
if g != T::zero() {
|
||||
for j in l..n {
|
||||
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
|
||||
v.set((j, i), (*U.get((i, j)) / *U.get((i, l))) / g);
|
||||
}
|
||||
for j in l..n {
|
||||
let mut s = T::zero();
|
||||
for k in l..n {
|
||||
s += U.get(i, k) * v.get(k, j);
|
||||
s += *U.get((i, k)) * *v.get((k, j));
|
||||
}
|
||||
for k in l..n {
|
||||
v.add_element_mut(k, j, s * v.get(k, i));
|
||||
v.add_element_mut((k, j), s * *v.get((k, i)));
|
||||
}
|
||||
}
|
||||
}
|
||||
for j in l..n {
|
||||
v.set(i, j, T::zero());
|
||||
v.set(j, i, T::zero());
|
||||
v.set((i, j), T::zero());
|
||||
v.set((j, i), T::zero());
|
||||
}
|
||||
}
|
||||
v.set(i, i, T::one());
|
||||
v.set((i, i), T::one());
|
||||
g = rv1[i];
|
||||
l = i;
|
||||
}
|
||||
@@ -210,7 +211,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
l = i + 1;
|
||||
g = w[i];
|
||||
for j in l..n {
|
||||
U.set(i, j, T::zero());
|
||||
U.set((i, j), T::zero());
|
||||
}
|
||||
|
||||
if g.abs() > T::epsilon() {
|
||||
@@ -218,23 +219,23 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for j in l..n {
|
||||
let mut s = T::zero();
|
||||
for k in l..m {
|
||||
s += U.get(k, i) * U.get(k, j);
|
||||
s += *U.get((k, i)) * *U.get((k, j));
|
||||
}
|
||||
let f = (s / U.get(i, i)) * g;
|
||||
let f = (s / *U.get((i, i))) * g;
|
||||
for k in i..m {
|
||||
U.add_element_mut(k, j, f * U.get(k, i));
|
||||
U.add_element_mut((k, j), f * *U.get((k, i)));
|
||||
}
|
||||
}
|
||||
for j in i..m {
|
||||
U.mul_element_mut(j, i, g);
|
||||
U.mul_element_mut((j, i), g);
|
||||
}
|
||||
} else {
|
||||
for j in i..m {
|
||||
U.set(j, i, T::zero());
|
||||
U.set((j, i), T::zero());
|
||||
}
|
||||
}
|
||||
|
||||
U.add_element_mut(i, i, T::one());
|
||||
U.add_element_mut((i, i), T::one());
|
||||
}
|
||||
|
||||
for k in (0..n).rev() {
|
||||
@@ -269,10 +270,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
c = g * h;
|
||||
s = -f * h;
|
||||
for j in 0..m {
|
||||
let y = U.get(j, nm);
|
||||
let z = U.get(j, i);
|
||||
U.set(j, nm, y * c + z * s);
|
||||
U.set(j, i, z * c - y * s);
|
||||
let y = *U.get((j, nm));
|
||||
let z = *U.get((j, i));
|
||||
U.set((j, nm), y * c + z * s);
|
||||
U.set((j, i), z * c - y * s);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -282,7 +283,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
if z < T::zero() {
|
||||
w[k] = -z;
|
||||
for j in 0..n {
|
||||
v.set(j, k, -v.get(j, k));
|
||||
v.set((j, k), -*v.get((j, k)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -299,7 +300,8 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
let mut h = rv1[k];
|
||||
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
|
||||
g = f.hypot(T::one());
|
||||
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
|
||||
f = ((x - z) * (x + z) + h * ((y / (f + <T as RealNumber>::copysign(g, f))) - h))
|
||||
/ x;
|
||||
let mut c = T::one();
|
||||
let mut s = T::one();
|
||||
|
||||
@@ -319,10 +321,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
y *= c;
|
||||
|
||||
for jj in 0..n {
|
||||
x = v.get(jj, j);
|
||||
z = v.get(jj, i);
|
||||
v.set(jj, j, x * c + z * s);
|
||||
v.set(jj, i, z * c - x * s);
|
||||
x = *v.get((jj, j));
|
||||
z = *v.get((jj, i));
|
||||
v.set((jj, j), x * c + z * s);
|
||||
v.set((jj, i), z * c - x * s);
|
||||
}
|
||||
|
||||
z = f.hypot(h);
|
||||
@@ -336,10 +338,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
f = c * g + s * y;
|
||||
x = c * y - s * g;
|
||||
for jj in 0..m {
|
||||
y = U.get(jj, j);
|
||||
z = U.get(jj, i);
|
||||
U.set(jj, j, y * c + z * s);
|
||||
U.set(jj, i, z * c - y * s);
|
||||
y = *U.get((jj, j));
|
||||
z = *U.get((jj, i));
|
||||
U.set((jj, j), y * c + z * s);
|
||||
U.set((jj, i), z * c - y * s);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -366,19 +368,19 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for i in inc..n {
|
||||
let sw = w[i];
|
||||
for (k, su_k) in su.iter_mut().enumerate().take(m) {
|
||||
*su_k = U.get(k, i);
|
||||
*su_k = *U.get((k, i));
|
||||
}
|
||||
for (k, sv_k) in sv.iter_mut().enumerate().take(n) {
|
||||
*sv_k = v.get(k, i);
|
||||
*sv_k = *v.get((k, i));
|
||||
}
|
||||
let mut j = i;
|
||||
while w[j - inc] < sw {
|
||||
w[j] = w[j - inc];
|
||||
for k in 0..m {
|
||||
U.set(k, j, U.get(k, j - inc));
|
||||
U.set((k, j), *U.get((k, j - inc)));
|
||||
}
|
||||
for k in 0..n {
|
||||
v.set(k, j, v.get(k, j - inc));
|
||||
v.set((k, j), *v.get((k, j - inc)));
|
||||
}
|
||||
j -= inc;
|
||||
if j < inc {
|
||||
@@ -387,10 +389,10 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
w[j] = sw;
|
||||
for (k, su_k) in su.iter().enumerate().take(m) {
|
||||
U.set(k, j, *su_k);
|
||||
U.set((k, j), *su_k);
|
||||
}
|
||||
for (k, sv_k) in sv.iter().enumerate().take(n) {
|
||||
v.set(k, j, *sv_k);
|
||||
v.set((k, j), *sv_k);
|
||||
}
|
||||
}
|
||||
if inc <= 1 {
|
||||
@@ -401,21 +403,21 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for k in 0..n {
|
||||
let mut s = 0.;
|
||||
for i in 0..m {
|
||||
if U.get(i, k) < T::zero() {
|
||||
if U.get((i, k)) < &T::zero() {
|
||||
s += 1.;
|
||||
}
|
||||
}
|
||||
for j in 0..n {
|
||||
if v.get(j, k) < T::zero() {
|
||||
if v.get((j, k)) < &T::zero() {
|
||||
s += 1.;
|
||||
}
|
||||
}
|
||||
if s > (m + n) as f64 / 2. {
|
||||
for i in 0..m {
|
||||
U.set(i, k, -U.get(i, k));
|
||||
U.set((i, k), -*U.get((i, k)));
|
||||
}
|
||||
for j in 0..n {
|
||||
v.set(j, k, -v.get(j, k));
|
||||
v.set((j, k), -*v.get((j, k)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -424,21 +426,12 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
|
||||
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
|
||||
let m = U.shape().0;
|
||||
let n = V.shape().0;
|
||||
let _full = s.len() == m.min(n);
|
||||
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
|
||||
SVD {
|
||||
U,
|
||||
V,
|
||||
s,
|
||||
_full,
|
||||
m,
|
||||
n,
|
||||
tol,
|
||||
}
|
||||
SVD { U, V, s, m, n, tol }
|
||||
}
|
||||
|
||||
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
@@ -458,7 +451,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
let mut r = T::zero();
|
||||
if self.s[j] > self.tol {
|
||||
for i in 0..self.m {
|
||||
r += self.U.get(i, j) * b.get(i, k);
|
||||
r += *self.U.get((i, j)) * *b.get((i, k));
|
||||
}
|
||||
r /= self.s[j];
|
||||
}
|
||||
@@ -468,9 +461,9 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
for j in 0..self.n {
|
||||
let mut r = T::zero();
|
||||
for (jj, tmp_jj) in tmp.iter().enumerate().take(self.n) {
|
||||
r += self.V.get(j, jj) * (*tmp_jj);
|
||||
r += *self.V.get((j, jj)) * (*tmp_jj);
|
||||
}
|
||||
b.set(j, k, r);
|
||||
b.set((j, k), r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -481,15 +474,21 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_symmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.7000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||
|
||||
@@ -497,23 +496,28 @@ mod tests {
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.639158],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let V = DenseMatrix::from_2d_array(&[
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
for i in 0..s.len() {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
|
||||
for (i, s_i) in s.iter().enumerate() {
|
||||
assert!((s_i - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_asymmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -574,7 +578,8 @@ mod tests {
|
||||
-0.2158704,
|
||||
-0.27529472,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let s: Vec<f64> = vec![
|
||||
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
|
||||
@@ -644,7 +649,8 @@ mod tests {
|
||||
0.73034065,
|
||||
-0.43965505,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let V = DenseMatrix::from_2d_array(&[
|
||||
&[
|
||||
@@ -704,31 +710,40 @@ mod tests {
|
||||
0.1654796,
|
||||
-0.32346758,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
|
||||
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
|
||||
for i in 0..s.len() {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
|
||||
for (i, s_i) in s.iter().enumerate() {
|
||||
assert!((s_i - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn solve() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
|
||||
let expected_w =
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]).unwrap();
|
||||
let w = a.svd_solve_mut(b).unwrap();
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_restore() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||
let a =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]).unwrap();
|
||||
let svd = a.svd().unwrap();
|
||||
let u: &DenseMatrix<f32> = &svd.U; //U
|
||||
let v: &DenseMatrix<f32> = &svd.V; // V
|
||||
@@ -736,8 +751,6 @@ mod tests {
|
||||
|
||||
let a_hat = u.matmul(s).matmul(&v.transpose());
|
||||
|
||||
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
|
||||
assert!((a - a_hat).abs() < 1e-3)
|
||||
}
|
||||
assert!(relative_eq!(a, a_hat, epsilon = 1e-3));
|
||||
}
|
||||
}
|
||||
+81
-49
@@ -1,13 +1,43 @@
|
||||
//! This is a generic solver for Ax = b type of equation
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::basic::arrays::Array1;
|
||||
//! use smartcore::linalg::basic::arrays::Array2;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::bg_solver::*;
|
||||
//! use smartcore::numbers::floatnum::FloatNumber;
|
||||
//! use smartcore::linear::bg_solver::BiconjugateGradientSolver;
|
||||
//!
|
||||
//! pub struct BGSolver {}
|
||||
//! impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X> for BGSolver {}
|
||||
//!
|
||||
//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0.,
|
||||
//! 11.]]).unwrap();
|
||||
//! let b = vec![40., 51., 28.];
|
||||
//! let expected = vec![1.0, 2.0, 3.0];
|
||||
//! let mut x = Vec::zeros(3);
|
||||
//! let solver = BGSolver {};
|
||||
//! let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! for more information take a look at [this Wikipedia article](https://en.wikipedia.org/wiki/Biconjugate_gradient_method)
|
||||
//! and [this paper](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf)
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2, ArrayView1, MutArrayView1};
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
fn solve_mut(&self, a: &M, b: &M, x: &mut M, tol: T, max_iter: usize) -> Result<T, Failed> {
|
||||
/// Trait for Biconjugate Gradient Solver
|
||||
pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
|
||||
/// Solve Ax = b
|
||||
fn solve_mut(
|
||||
&self,
|
||||
a: &'a X,
|
||||
b: &Vec<T>,
|
||||
x: &mut Vec<T>,
|
||||
tol: T,
|
||||
max_iter: usize,
|
||||
) -> Result<T, Failed> {
|
||||
if tol <= T::zero() {
|
||||
return Err(Failed::fit("tolerance shoud be > 0"));
|
||||
}
|
||||
@@ -16,25 +46,25 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
return Err(Failed::fit("maximum number of iterations should be > 0"));
|
||||
}
|
||||
|
||||
let (n, _) = b.shape();
|
||||
let n = b.shape();
|
||||
|
||||
let mut r = M::zeros(n, 1);
|
||||
let mut rr = M::zeros(n, 1);
|
||||
let mut z = M::zeros(n, 1);
|
||||
let mut zz = M::zeros(n, 1);
|
||||
let mut r = Vec::zeros(n);
|
||||
let mut rr = Vec::zeros(n);
|
||||
let mut z = Vec::zeros(n);
|
||||
let mut zz = Vec::zeros(n);
|
||||
|
||||
self.mat_vec_mul(a, x, &mut r);
|
||||
|
||||
for j in 0..n {
|
||||
r.set(j, 0, b.get(j, 0) - r.get(j, 0));
|
||||
rr.set(j, 0, r.get(j, 0));
|
||||
r[j] = b[j] - r[j];
|
||||
rr[j] = r[j];
|
||||
}
|
||||
|
||||
let bnrm = b.norm(T::two());
|
||||
self.solve_preconditioner(a, &r, &mut z);
|
||||
let bnrm = b.norm(2f64);
|
||||
self.solve_preconditioner(a, &r[..], &mut z[..]);
|
||||
|
||||
let mut p = M::zeros(n, 1);
|
||||
let mut pp = M::zeros(n, 1);
|
||||
let mut p = Vec::zeros(n);
|
||||
let mut pp = Vec::zeros(n);
|
||||
let mut bkden = T::zero();
|
||||
let mut err = T::zero();
|
||||
|
||||
@@ -43,35 +73,33 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
|
||||
self.solve_preconditioner(a, &rr, &mut zz);
|
||||
for j in 0..n {
|
||||
bknum += z.get(j, 0) * rr.get(j, 0);
|
||||
bknum += z[j] * rr[j];
|
||||
}
|
||||
if iter == 1 {
|
||||
for j in 0..n {
|
||||
p.set(j, 0, z.get(j, 0));
|
||||
pp.set(j, 0, zz.get(j, 0));
|
||||
}
|
||||
p[..n].copy_from_slice(&z[..n]);
|
||||
pp[..n].copy_from_slice(&zz[..n]);
|
||||
} else {
|
||||
let bk = bknum / bkden;
|
||||
for j in 0..n {
|
||||
p.set(j, 0, bk * p.get(j, 0) + z.get(j, 0));
|
||||
pp.set(j, 0, bk * pp.get(j, 0) + zz.get(j, 0));
|
||||
p[j] = bk * pp[j] + z[j];
|
||||
pp[j] = bk * pp[j] + zz[j];
|
||||
}
|
||||
}
|
||||
bkden = bknum;
|
||||
self.mat_vec_mul(a, &p, &mut z);
|
||||
let mut akden = T::zero();
|
||||
for j in 0..n {
|
||||
akden += z.get(j, 0) * pp.get(j, 0);
|
||||
akden += z[j] * pp[j];
|
||||
}
|
||||
let ak = bknum / akden;
|
||||
self.mat_t_vec_mul(a, &pp, &mut zz);
|
||||
for j in 0..n {
|
||||
x.set(j, 0, x.get(j, 0) + ak * p.get(j, 0));
|
||||
r.set(j, 0, r.get(j, 0) - ak * z.get(j, 0));
|
||||
rr.set(j, 0, rr.get(j, 0) - ak * zz.get(j, 0));
|
||||
x[j] += ak * p[j];
|
||||
r[j] -= ak * z[j];
|
||||
rr[j] -= ak * zz[j];
|
||||
}
|
||||
self.solve_preconditioner(a, &r, &mut z);
|
||||
err = r.norm(T::two()) / bnrm;
|
||||
err = T::from_f64(r.norm(2f64) / bnrm).unwrap();
|
||||
|
||||
if err <= tol {
|
||||
break;
|
||||
@@ -81,36 +109,38 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||
/// solve preconditioner
|
||||
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
|
||||
let diag = Self::diag(a);
|
||||
let n = diag.len();
|
||||
|
||||
for (i, diag_i) in diag.iter().enumerate().take(n) {
|
||||
if *diag_i != T::zero() {
|
||||
x.set(i, 0, b.get(i, 0) / *diag_i);
|
||||
x[i] = b[i] / *diag_i;
|
||||
} else {
|
||||
x.set(i, 0, b.get(i, 0));
|
||||
x[i] = b[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// y = Ax
|
||||
fn mat_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
y.copy_from(&a.matmul(x));
|
||||
/// y = Ax
|
||||
fn mat_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
y.copy_from(&x.xa(false, a));
|
||||
}
|
||||
|
||||
// y = Atx
|
||||
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
y.copy_from(&a.ab(true, x, false));
|
||||
/// y = Atx
|
||||
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
y.copy_from(&x.xa(true, a));
|
||||
}
|
||||
|
||||
fn diag(a: &M) -> Vec<T> {
|
||||
/// Extract the diagonal from a matrix
|
||||
fn diag(a: &X) -> Vec<T> {
|
||||
let (nrows, ncols) = a.shape();
|
||||
let n = nrows.min(ncols);
|
||||
|
||||
let mut d = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
d.push(a.get(i, i));
|
||||
d.push(*a.get((i, i)));
|
||||
}
|
||||
|
||||
d
|
||||
@@ -120,28 +150,30 @@ pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
pub struct BGSolver {}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
|
||||
impl<T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'_, T, X> for BGSolver {}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn bg_solver() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let b = vec![40., 51., 28.];
|
||||
let expected = [1.0, 2.0, 3.0];
|
||||
|
||||
let mut x = DenseMatrix::zeros(3, 1);
|
||||
let mut x = Vec::zeros(3);
|
||||
|
||||
let solver = BGSolver {};
|
||||
|
||||
let err: f64 = solver
|
||||
.solve_mut(&a, &b.transpose(), &mut x, 1e-6, 6)
|
||||
.unwrap();
|
||||
let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
|
||||
|
||||
assert!(x.transpose().approximate_eq(&expected, 1e-4));
|
||||
assert!(x
|
||||
.iter()
|
||||
.zip(expected.iter())
|
||||
.all(|(&a, &b)| (a - b).abs() < 1e-4));
|
||||
assert!((err - 0.0).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
+320
-114
@@ -17,7 +17,7 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::elastic_net::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
@@ -38,7 +38,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -55,32 +55,39 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
|
||||
/// Elastic net parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetParameters<T: RealNumber> {
|
||||
pub struct ElasticNetParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
pub alpha: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: T,
|
||||
pub l1_ratio: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
pub tol: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
@@ -88,21 +95,23 @@ pub struct ElasticNetParameters<T: RealNumber> {
|
||||
/// Elastic net
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
pub struct ElasticNet<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
coefficients: Option<X>,
|
||||
intercept: Option<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> ElasticNetParameters<T> {
|
||||
impl ElasticNetParameters {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
pub fn with_alpha(mut self, alpha: f64) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub fn with_l1_ratio(mut self, l1_ratio: T) -> Self {
|
||||
pub fn with_l1_ratio(mut self, l1_ratio: f64) -> Self {
|
||||
self.l1_ratio = l1_ratio;
|
||||
self
|
||||
}
|
||||
@@ -112,7 +121,7 @@ impl<T: RealNumber> ElasticNetParameters<T> {
|
||||
self
|
||||
}
|
||||
/// The tolerance for the optimization
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
pub fn with_tol(mut self, tol: f64) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
@@ -123,61 +132,205 @@ impl<T: RealNumber> ElasticNetParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for ElasticNetParameters<T> {
|
||||
impl Default for ElasticNetParameters {
|
||||
fn default() -> Self {
|
||||
ElasticNetParameters {
|
||||
alpha: T::one(),
|
||||
l1_ratio: T::half(),
|
||||
alpha: 1.0,
|
||||
l1_ratio: 0.5,
|
||||
normalize: true,
|
||||
tol: T::from_f64(1e-4).unwrap(),
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
/// ElasticNet grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// ElasticNet grid search iterator
|
||||
pub struct ElasticNetSearchParametersIterator {
|
||||
lasso_regression_search_parameters: ElasticNetSearchParameters,
|
||||
current_alpha: usize,
|
||||
current_l1_ratio: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for ElasticNetSearchParameters {
|
||||
type Item = ElasticNetParameters;
|
||||
type IntoIter = ElasticNetSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
ElasticNetSearchParametersIterator {
|
||||
lasso_regression_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_l1_ratio: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, ElasticNetParameters<T>>
|
||||
for ElasticNet<T, M>
|
||||
impl Iterator for ElasticNetSearchParametersIterator {
|
||||
type Item = ElasticNetParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
||||
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = ElasticNetParameters {
|
||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
||||
l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
|
||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.lasso_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_l1_ratio += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ElasticNetSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = ElasticNetParameters::default();
|
||||
|
||||
ElasticNetSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
l1_ratio: vec![default_params.l1_ratio],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for ElasticNet<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters<T>) -> Result<Self, Failed> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.intercept() != other.intercept() {
|
||||
return false;
|
||||
}
|
||||
if self.coefficients().shape() != other.coefficients().shape() {
|
||||
return false;
|
||||
}
|
||||
self.coefficients()
|
||||
.iterator(0)
|
||||
.zip(other.coefficients().iterator(0))
|
||||
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimator<X, Y, ElasticNetParameters> for ElasticNet<TX, TY, X, Y>
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
coefficients: Option::None,
|
||||
intercept: Option::None,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: ElasticNetParameters) -> Result<Self, Failed> {
|
||||
ElasticNet::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
|
||||
for ElasticNet<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
ElasticNet<TX, TY, X, Y>
|
||||
{
|
||||
/// Fits elastic net regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: ElasticNetParameters<T>,
|
||||
) -> Result<ElasticNet<T, M>, Failed> {
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: ElasticNetParameters,
|
||||
) -> Result<ElasticNet<TX, TY, X, Y>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if y.len() != n {
|
||||
if y.shape() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let n_float = T::from_usize(n).unwrap();
|
||||
let n_float = n as f64;
|
||||
|
||||
let l1_reg = parameters.alpha * parameters.l1_ratio * n_float;
|
||||
let l2_reg = parameters.alpha * (T::one() - parameters.l1_ratio) * n_float;
|
||||
let l1_reg = TX::from_f64(parameters.alpha * parameters.l1_ratio * n_float).unwrap();
|
||||
let l2_reg =
|
||||
TX::from_f64(parameters.alpha * (1.0 - parameters.l1_ratio) * n_float).unwrap();
|
||||
|
||||
let y_mean = y.mean();
|
||||
let y_mean = TX::from_f64(y.mean_by()).unwrap();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
@@ -186,72 +339,93 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
|
||||
let mut w = optimizer.optimize(
|
||||
&x,
|
||||
&y,
|
||||
l1_reg * gamma,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
)?;
|
||||
|
||||
for i in 0..p {
|
||||
w.set(i, 0, gamma * w.get(i, 0) / col_std[i]);
|
||||
w.set(i, gamma * *w.get(i) / col_std[i]);
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
let mut b = TX::zero();
|
||||
|
||||
for i in 0..p {
|
||||
b += w.get(i, 0) * col_mean[i];
|
||||
b += *w.get(i) * col_mean[i];
|
||||
}
|
||||
|
||||
b = y_mean - b;
|
||||
|
||||
(w, b)
|
||||
(X::from_column(&w), b)
|
||||
} else {
|
||||
let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
|
||||
let mut w = optimizer.optimize(
|
||||
&x,
|
||||
&y,
|
||||
l1_reg * gamma,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
)?;
|
||||
|
||||
for i in 0..p {
|
||||
w.set(i, 0, gamma * w.get(i, 0));
|
||||
w.set(i, gamma * *w.get(i));
|
||||
}
|
||||
|
||||
(w, y_mean)
|
||||
(X::from_column(&w), y_mean)
|
||||
};
|
||||
|
||||
Ok(ElasticNet {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
intercept: Some(b),
|
||||
coefficients: Some(w),
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
let mut y_hat = x.matmul(self.coefficients.as_ref().unwrap());
|
||||
let bias = X::fill(nrows, 1, self.intercept.unwrap());
|
||||
y_hat.add_mut(&bias);
|
||||
Ok(Y::from_iterator(
|
||||
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
|
||||
nrows,
|
||||
))
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
pub fn coefficients(&self) -> &X {
|
||||
self.coefficients.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
pub fn intercept(&self) -> &TX {
|
||||
self.intercept.as_ref().unwrap()
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
|
||||
let col_mean: Vec<TX> = x
|
||||
.mean_by(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
let col_std: Vec<TX> = x
|
||||
.std_dev(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
|
||||
for i in 0..col_std.len() {
|
||||
if (col_std[i] - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
for (i, col_std_i) in col_std.iter().enumerate() {
|
||||
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
|
||||
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,25 +434,25 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
Ok((scaled_x, col_mean, col_std))
|
||||
}
|
||||
|
||||
fn augment_x_and_y(x: &M, y: &M::RowVector, l2_reg: T) -> (M, M::RowVector, T) {
|
||||
fn augment_x_and_y(x: &X, y: &Y, l2_reg: TX) -> (X, Vec<TX>, TX) {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
let gamma = T::one() / (T::one() + l2_reg).sqrt();
|
||||
let gamma = TX::one() / (TX::one() + l2_reg).sqrt();
|
||||
let padding = gamma * l2_reg.sqrt();
|
||||
|
||||
let mut y2 = M::RowVector::zeros(n + p);
|
||||
for i in 0..y.len() {
|
||||
y2.set(i, y.get(i));
|
||||
let mut y2 = Vec::<TX>::zeros(n + p);
|
||||
for i in 0..y.shape() {
|
||||
y2.set(i, TX::from(*y.get(i)).unwrap());
|
||||
}
|
||||
|
||||
let mut x2 = M::zeros(n + p, p);
|
||||
let mut x2 = X::zeros(n + p, p);
|
||||
|
||||
for j in 0..p {
|
||||
for i in 0..n {
|
||||
x2.set(i, j, gamma * x.get(i, j));
|
||||
x2.set((i, j), gamma * *x.get((i, j)));
|
||||
}
|
||||
|
||||
x2.set(j + n, j, padding);
|
||||
x2.set((j + n, j), padding);
|
||||
}
|
||||
|
||||
(x2, y2, gamma)
|
||||
@@ -288,10 +462,36 @@ impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = ElasticNetSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn elasticnet_longley() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -311,7 +511,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -335,7 +536,10 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn elasticnet_fit_predict1() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -359,7 +563,8 @@ mod tests {
|
||||
&[17.0, 1918.0, 1.4054969025700674],
|
||||
&[18.0, 1929.0, 1.3271699396384906],
|
||||
&[19.0, 1915.0, 1.1373332337674806],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
|
||||
@@ -398,43 +603,44 @@ mod tests {
|
||||
assert!(mae_l1 < 2.0);
|
||||
assert!(mae_l2 < 2.0);
|
||||
|
||||
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(1, 0));
|
||||
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
|
||||
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((1, 0)));
|
||||
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((2, 0)));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]).unwrap();
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
// 114.2, 115.7, 116.9,
|
||||
// ];
|
||||
|
||||
let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
|
||||
// let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: ElasticNet<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
// let deserialized_lr: ElasticNet<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
// assert_eq!(lr, deserialized_lr);
|
||||
// }
|
||||
}
|
||||
|
||||
+268
-94
@@ -23,28 +23,34 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
/// Lasso regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoParameters<T: RealNumber> {
|
||||
pub struct LassoParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
pub alpha: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
pub tol: f64,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
@@ -52,14 +58,16 @@ pub struct LassoParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Lasso regressor
|
||||
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
pub struct Lasso<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
coefficients: Option<X>,
|
||||
intercept: Option<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> LassoParameters<T> {
|
||||
impl LassoParameters {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
pub fn with_alpha(mut self, alpha: f64) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
@@ -69,7 +77,7 @@ impl<T: RealNumber> LassoParameters<T> {
|
||||
self
|
||||
}
|
||||
/// The tolerance for the optimization
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
pub fn with_tol(mut self, tol: f64) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
@@ -80,48 +88,162 @@ impl<T: RealNumber> LassoParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LassoParameters<T> {
|
||||
impl Default for LassoParameters {
|
||||
fn default() -> Self {
|
||||
LassoParameters {
|
||||
alpha: T::one(),
|
||||
alpha: 1f64,
|
||||
normalize: true,
|
||||
tol: T::from_f64(1e-4).unwrap(),
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for Lasso<TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
self.intercept == other.intercept
|
||||
&& self.coefficients().shape() == other.coefficients().shape()
|
||||
&& self
|
||||
.coefficients()
|
||||
.iterator(0)
|
||||
.zip(other.coefficients().iterator(0))
|
||||
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LassoParameters<T>>
|
||||
for Lasso<T, M>
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimator<X, Y, LassoParameters> for Lasso<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters<T>) -> Result<Self, Failed> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
coefficients: Option::None,
|
||||
intercept: Option::None,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Self, Failed> {
|
||||
Lasso::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
|
||||
for Lasso<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
/// Lasso grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<f64>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Lasso grid search iterator
|
||||
pub struct LassoSearchParametersIterator {
|
||||
lasso_search_parameters: LassoSearchParameters,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LassoSearchParameters {
|
||||
type Item = LassoParameters;
|
||||
type IntoIter = LassoSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LassoSearchParametersIterator {
|
||||
lasso_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for LassoSearchParametersIterator {
|
||||
type Item = LassoParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_search_parameters.alpha.len()
|
||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LassoParameters {
|
||||
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LassoSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = LassoParameters::default();
|
||||
|
||||
LassoSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Lasso<TX, TY, X, Y> {
|
||||
/// Fits Lasso regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LassoParameters<T>,
|
||||
) -> Result<Lasso<T, M>, Failed> {
|
||||
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if n <= p {
|
||||
@@ -130,11 +252,11 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
));
|
||||
}
|
||||
|
||||
if parameters.alpha < T::zero() {
|
||||
if parameters.alpha < 0f64 {
|
||||
return Err(Failed::fit("alpha should be >= 0"));
|
||||
}
|
||||
|
||||
if parameters.tol <= T::zero() {
|
||||
if parameters.tol <= 0f64 {
|
||||
return Err(Failed::fit("tol should be > 0"));
|
||||
}
|
||||
|
||||
@@ -142,75 +264,99 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
return Err(Failed::fit("max_iter should be > 0"));
|
||||
}
|
||||
|
||||
if y.len() != n {
|
||||
if y.shape() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
|
||||
let y: Vec<TX> = y.iterator(0).map(|&v| TX::from(v).unwrap()).collect();
|
||||
|
||||
let l1_reg = TX::from_f64(parameters.alpha * n as f64).unwrap();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
let mut w = optimizer.optimize(
|
||||
&scaled_x,
|
||||
&y,
|
||||
l1_reg,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
)?;
|
||||
|
||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||
w.set(j, 0, w.get(j, 0) / *col_std_j);
|
||||
w[j] /= *col_std_j;
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
let mut b = TX::zero();
|
||||
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
b += w.get(i, 0) * *col_mean_i;
|
||||
b += w[i] * *col_mean_i;
|
||||
}
|
||||
|
||||
b = y.mean() - b;
|
||||
(w, b)
|
||||
b = TX::from_f64(y.mean_by()).unwrap() - b;
|
||||
(X::from_column(&w), b)
|
||||
} else {
|
||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||
|
||||
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
let w = optimizer.optimize(
|
||||
x,
|
||||
&y,
|
||||
l1_reg,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
)?;
|
||||
|
||||
(w, y.mean())
|
||||
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
|
||||
};
|
||||
|
||||
Ok(Lasso {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
intercept: Some(b),
|
||||
coefficients: Some(w),
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
let mut y_hat = x.matmul(self.coefficients());
|
||||
let bias = X::fill(nrows, 1, self.intercept.unwrap());
|
||||
y_hat.add_mut(&bias);
|
||||
Ok(Y::from_iterator(
|
||||
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
|
||||
nrows,
|
||||
))
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
pub fn coefficients(&self) -> &X {
|
||||
self.coefficients.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
pub fn intercept(&self) -> &TX {
|
||||
self.intercept.as_ref().unwrap()
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
|
||||
let col_mean: Vec<TX> = x
|
||||
.mean_by(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
let col_std: Vec<TX> = x
|
||||
.std_dev(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate() {
|
||||
if (*col_std_i - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
|
||||
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,10 +369,36 @@ impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LassoSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -246,7 +418,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -275,39 +448,40 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]);
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
// 114.2, 115.7, 116.9,
|
||||
// ];
|
||||
|
||||
let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: Lasso<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
// assert_eq!(lr, deserialized_lr);
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -12,21 +12,22 @@
|
||||
//!
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1, MutArray, MutArrayView1};
|
||||
use crate::linear::bg_solver::BiconjugateGradientSolver;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
pub struct InteriorPointOptimizer<T: RealNumber, M: Matrix<T>> {
|
||||
ata: M,
|
||||
/// Interior Point Optimizer
|
||||
pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
|
||||
ata: X,
|
||||
d1: Vec<T>,
|
||||
d2: Vec<T>,
|
||||
prb: Vec<T>,
|
||||
prs: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
pub fn new(a: &M, n: usize) -> InteriorPointOptimizer<T, M> {
|
||||
impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
/// Initialize a new Interior Point Optimizer
|
||||
pub fn new(a: &X, n: usize) -> InteriorPointOptimizer<T, X> {
|
||||
InteriorPointOptimizer {
|
||||
ata: a.ab(true, a, false),
|
||||
d1: vec![T::zero(); n],
|
||||
@@ -36,14 +37,15 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the optimization
|
||||
pub fn optimize(
|
||||
&mut self,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
x: &X,
|
||||
y: &Vec<T>,
|
||||
lambda: T,
|
||||
max_iter: usize,
|
||||
tol: T,
|
||||
) -> Result<M, Failed> {
|
||||
) -> Result<Vec<T>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
let p_f64 = T::from_usize(p).unwrap();
|
||||
|
||||
@@ -58,50 +60,53 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
let gamma = T::from_f64(-0.25).unwrap();
|
||||
let mu = T::two();
|
||||
|
||||
let y = M::from_row_vector(y.sub_scalar(y.mean())).transpose();
|
||||
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
||||
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
|
||||
|
||||
let mut max_ls_iter = 100;
|
||||
let mut pitr = 0;
|
||||
let mut w = M::zeros(p, 1);
|
||||
let mut w = Vec::zeros(p);
|
||||
let mut neww = w.clone();
|
||||
let mut u = M::ones(p, 1);
|
||||
let mut u = Vec::ones(p);
|
||||
let mut newu = u.clone();
|
||||
|
||||
let mut f = M::fill(p, 2, -T::one());
|
||||
let mut f = X::fill(p, 2, -T::one());
|
||||
let mut newf = f.clone();
|
||||
|
||||
let mut q1 = vec![T::zero(); p];
|
||||
let mut q2 = vec![T::zero(); p];
|
||||
|
||||
let mut dx = M::zeros(p, 1);
|
||||
let mut du = M::zeros(p, 1);
|
||||
let mut dxu = M::zeros(2 * p, 1);
|
||||
let mut grad = M::zeros(2 * p, 1);
|
||||
let mut dx = Vec::zeros(p);
|
||||
let mut du = Vec::zeros(p);
|
||||
let mut dxu = Vec::zeros(2 * p);
|
||||
let mut grad = Vec::zeros(2 * p);
|
||||
|
||||
let mut nu = M::zeros(n, 1);
|
||||
let mut nu = Vec::zeros(n);
|
||||
let mut dobj = T::zero();
|
||||
let mut s = T::infinity();
|
||||
let mut t = T::one()
|
||||
.max(T::one() / lambda)
|
||||
.min(T::two() * p_f64 / T::from(1e-3).unwrap());
|
||||
|
||||
let lambda_f64 = lambda.to_f64().unwrap();
|
||||
|
||||
for ntiter in 0..max_iter {
|
||||
let mut z = x.matmul(&w);
|
||||
let mut z = w.xa(true, x);
|
||||
|
||||
for i in 0..n {
|
||||
z.set(i, 0, z.get(i, 0) - y.get(i, 0));
|
||||
nu.set(i, 0, T::two() * z.get(i, 0));
|
||||
z[i] -= y[i];
|
||||
nu[i] = T::two() * z[i];
|
||||
}
|
||||
|
||||
// CALCULATE DUALITY GAP
|
||||
let xnu = x.ab(true, &nu, false);
|
||||
let max_xnu = xnu.norm(T::infinity());
|
||||
if max_xnu > lambda {
|
||||
let lnu = lambda / max_xnu;
|
||||
let xnu = nu.xa(false, x);
|
||||
let max_xnu = xnu.norm(f64::INFINITY);
|
||||
if max_xnu > lambda_f64 {
|
||||
let lnu = T::from_f64(lambda_f64 / max_xnu).unwrap();
|
||||
nu.mul_scalar_mut(lnu);
|
||||
}
|
||||
|
||||
let pobj = z.dot(&z) + lambda * w.norm(T::one());
|
||||
let pobj = z.dot(&z) + lambda * T::from_f64(w.norm(1f64)).unwrap();
|
||||
dobj = dobj.max(gamma * nu.dot(&nu) - nu.dot(&y));
|
||||
|
||||
let gap = pobj - dobj;
|
||||
@@ -118,22 +123,22 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
|
||||
// CALCULATE NEWTON STEP
|
||||
for i in 0..p {
|
||||
let q1i = T::one() / (u.get(i, 0) + w.get(i, 0));
|
||||
let q2i = T::one() / (u.get(i, 0) - w.get(i, 0));
|
||||
let q1i = T::one() / (u[i] + w[i]);
|
||||
let q2i = T::one() / (u[i] - w[i]);
|
||||
q1[i] = q1i;
|
||||
q2[i] = q2i;
|
||||
self.d1[i] = (q1i * q1i + q2i * q2i) / t;
|
||||
self.d2[i] = (q1i * q1i - q2i * q2i) / t;
|
||||
}
|
||||
|
||||
let mut gradphi = x.ab(true, &z, false);
|
||||
let mut gradphi = z.xa(false, x);
|
||||
|
||||
for i in 0..p {
|
||||
let g1 = T::two() * gradphi.get(i, 0) - (q1[i] - q2[i]) / t;
|
||||
let g1 = T::two() * gradphi[i] - (q1[i] - q2[i]) / t;
|
||||
let g2 = lambda - (q1[i] + q2[i]) / t;
|
||||
gradphi.set(i, 0, g1);
|
||||
grad.set(i, 0, -g1);
|
||||
grad.set(i + p, 0, -g2);
|
||||
gradphi[i] = g1;
|
||||
grad[i] = -g1;
|
||||
grad[i + p] = -g2;
|
||||
}
|
||||
|
||||
for i in 0..p {
|
||||
@@ -141,7 +146,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
|
||||
}
|
||||
|
||||
let normg = grad.norm2();
|
||||
let normg = T::from_f64(grad.norm2()).unwrap();
|
||||
let mut pcgtol = min_pcgtol.min(eta * gap / T::one().min(normg));
|
||||
if ntiter != 0 && pitr == 0 {
|
||||
pcgtol *= min_pcgtol;
|
||||
@@ -152,10 +157,8 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
pitr = pcgmaxi;
|
||||
}
|
||||
|
||||
for i in 0..p {
|
||||
dx.set(i, 0, dxu.get(i, 0));
|
||||
du.set(i, 0, dxu.get(i + p, 0));
|
||||
}
|
||||
dx[..p].copy_from_slice(&dxu[..p]);
|
||||
du[..p].copy_from_slice(&dxu[p..(p + p)]);
|
||||
|
||||
// BACKTRACKING LINE SEARCH
|
||||
let phi = z.dot(&z) + lambda * u.sum() - Self::sumlogneg(&f) / t;
|
||||
@@ -165,16 +168,20 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
let lsiter = 0;
|
||||
while lsiter < max_ls_iter {
|
||||
for i in 0..p {
|
||||
neww.set(i, 0, w.get(i, 0) + s * dx.get(i, 0));
|
||||
newu.set(i, 0, u.get(i, 0) + s * du.get(i, 0));
|
||||
newf.set(i, 0, neww.get(i, 0) - newu.get(i, 0));
|
||||
newf.set(i, 1, -neww.get(i, 0) - newu.get(i, 0));
|
||||
neww[i] = w[i] + s * dx[i];
|
||||
newu[i] = u[i] + s * du[i];
|
||||
newf.set((i, 0), neww[i] - newu[i]);
|
||||
newf.set((i, 1), -neww[i] - newu[i]);
|
||||
}
|
||||
|
||||
if newf.max() < T::zero() {
|
||||
let mut newz = x.matmul(&neww);
|
||||
if newf
|
||||
.iterator(0)
|
||||
.fold(T::neg_infinity(), |max, v| v.max(max))
|
||||
< T::zero()
|
||||
{
|
||||
let mut newz = neww.xa(true, x);
|
||||
for i in 0..n {
|
||||
newz.set(i, 0, newz.get(i, 0) - y.get(i, 0));
|
||||
newz[i] -= y[i];
|
||||
}
|
||||
|
||||
let newphi = newz.dot(&newz) + lambda * newu.sum() - Self::sumlogneg(&newf) / t;
|
||||
@@ -200,54 +207,41 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
fn sumlogneg(f: &M) -> T {
|
||||
fn sumlogneg(f: &X) -> T {
|
||||
let (n, _) = f.shape();
|
||||
let mut sum = T::zero();
|
||||
for i in 0..n {
|
||||
sum += (-f.get(i, 0)).ln();
|
||||
sum += (-f.get(i, 1)).ln();
|
||||
sum += (-*f.get((i, 0))).ln();
|
||||
sum += (-*f.get((i, 1))).ln();
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for InteriorPointOptimizer<T, M> {
|
||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||
impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
|
||||
for InteriorPointOptimizer<T, X>
|
||||
{
|
||||
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
|
||||
let (_, p) = a.shape();
|
||||
|
||||
for i in 0..p {
|
||||
x.set(
|
||||
i,
|
||||
0,
|
||||
(self.d1[i] * b.get(i, 0) - self.d2[i] * b.get(i + p, 0)) / self.prs[i],
|
||||
);
|
||||
x.set(
|
||||
i + p,
|
||||
0,
|
||||
(-self.d2[i] * b.get(i, 0) + self.prb[i] * b.get(i + p, 0)) / self.prs[i],
|
||||
);
|
||||
x[i] = (self.d1[i] * b[i] - self.d2[i] * b[i + p]) / self.prs[i];
|
||||
x[i + p] = (-self.d2[i] * b[i] + self.prb[i] * b[i + p]) / self.prs[i];
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_vec_mul(&self, _: &M, x: &M, y: &mut M) {
|
||||
fn mat_vec_mul(&self, _: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
let (_, p) = self.ata.shape();
|
||||
let atax = self.ata.matmul(&x.slice(0..p, 0..1));
|
||||
let x_slice = Vec::from_slice(x.slice(0..p).as_ref());
|
||||
let atax = x_slice.xa(true, &self.ata);
|
||||
|
||||
for i in 0..p {
|
||||
y.set(
|
||||
i,
|
||||
0,
|
||||
T::two() * atax.get(i, 0) + self.d1[i] * x.get(i, 0) + self.d2[i] * x.get(i + p, 0),
|
||||
);
|
||||
y.set(
|
||||
i + p,
|
||||
0,
|
||||
self.d2[i] * x.get(i, 0) + self.d1[i] * x.get(i + p, 0),
|
||||
);
|
||||
y[i] = T::two() * atax[i] + self.d1[i] * x[i] + self.d2[i] * x[i + p];
|
||||
y[i + p] = self.d2[i] * x[i] + self.d1[i] * x[i + p];
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
self.mat_vec_mul(a, x, y);
|
||||
}
|
||||
}
|
||||
|
||||
+221
-95
@@ -12,14 +12,14 @@
|
||||
//! \\[\hat{\beta} = (X^TX)^{-1}X^Ty \\]
|
||||
//!
|
||||
//! the \\((X^TX)^{-1}\\) term is both computationally expensive and numerically unstable. An alternative approach is to use a matrix decomposition to avoid this operation.
|
||||
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
|
||||
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
|
||||
//! The QR decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
|
||||
//! but does not work for all data matrices. Unlike the QR decomposition, all matrices have an SVD decomposition.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::linear_regression::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
@@ -40,7 +40,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -61,21 +61,26 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::linalg::traits::qr::QRDecomposable;
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Default, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
QR,
|
||||
#[default]
|
||||
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
||||
SVD,
|
||||
}
|
||||
@@ -84,27 +89,11 @@ pub enum LinearRegressionSolverName {
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
/// Linear Regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
_solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
impl LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionParameters {
|
||||
fn default() -> Self {
|
||||
LinearRegressionParameters {
|
||||
@@ -113,43 +102,157 @@ impl Default for LinearRegressionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
/// Linear Regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearRegression<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
coefficients: Option<X>,
|
||||
intercept: Option<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LinearRegressionParameters>
|
||||
for LinearRegression<T, M>
|
||||
/// Linear Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LinearRegressionSolverName>,
|
||||
}
|
||||
|
||||
/// Linear Regression grid search iterator
|
||||
pub struct LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: LinearRegressionSearchParameters,
|
||||
current_solver: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LinearRegressionSearchParameters {
|
||||
type Item = LinearRegressionParameters;
|
||||
type IntoIter = LinearRegressionSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for LinearRegressionSearchParametersIterator {
|
||||
type Item = LinearRegressionParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LinearRegressionParameters {
|
||||
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
};
|
||||
|
||||
self.current_solver += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = LinearRegressionParameters::default();
|
||||
|
||||
LinearRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> PartialEq for LinearRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LinearRegressionParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.intercept == other.intercept
|
||||
&& self.coefficients().shape() == other.coefficients().shape()
|
||||
&& self
|
||||
.coefficients()
|
||||
.iterator(0)
|
||||
.zip(other.coefficients().iterator(0))
|
||||
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> SupervisedEstimator<X, Y, LinearRegressionParameters> for LinearRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
coefficients: Option::None,
|
||||
intercept: Option::None,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: LinearRegressionParameters) -> Result<Self, Failed> {
|
||||
LinearRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> Predictor<X, Y> for LinearRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> LinearRegression<TX, TY, X, Y>
|
||||
{
|
||||
/// Fits Linear Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: LinearRegressionParameters,
|
||||
) -> Result<LinearRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let b = y_m.transpose();
|
||||
) -> Result<LinearRegression<TX, TY, X, Y>, Failed> {
|
||||
let b = X::from_iterator(
|
||||
y.iterator(0).map(|&v| TX::from(v).unwrap()),
|
||||
y.shape(),
|
||||
1,
|
||||
0,
|
||||
);
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (y_nrows, _) = b.shape();
|
||||
|
||||
@@ -159,59 +262,77 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
));
|
||||
}
|
||||
|
||||
let a = x.h_stack(&M::ones(x_nrows, 1));
|
||||
let a = x.h_stack(&X::ones(x_nrows, 1));
|
||||
|
||||
let w = match parameters.solver {
|
||||
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
|
||||
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
|
||||
};
|
||||
|
||||
let wights = w.slice(0..num_attributes, 0..1);
|
||||
let weights = X::from_slice(w.slice(0..num_attributes, 0..1).as_ref());
|
||||
|
||||
Ok(LinearRegression {
|
||||
intercept: w.get(num_attributes, 0),
|
||||
coefficients: wights,
|
||||
_solver: parameters.solver,
|
||||
intercept: Some(*w.get((num_attributes, 0))),
|
||||
coefficients: Some(weights),
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
let bias = X::fill(nrows, 1, *self.intercept());
|
||||
let mut y_hat = x.matmul(self.coefficients());
|
||||
y_hat.add_mut(&bias);
|
||||
Ok(Y::from_iterator(
|
||||
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
|
||||
nrows,
|
||||
))
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
pub fn coefficients(&self) -> &X {
|
||||
self.coefficients.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
pub fn intercept(&self) -> &TX {
|
||||
self.intercept.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LinearRegressionSearchParameters {
|
||||
solver: vec![
|
||||
LinearRegressionSolverName::QR,
|
||||
LinearRegressionSolverName::SVD,
|
||||
],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
@@ -220,11 +341,11 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
|
||||
];
|
||||
|
||||
let y_hat_qr = LinearRegression::fit(
|
||||
@@ -251,39 +372,44 @@ mod tests {
|
||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]).unwrap();
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
// 114.2, 115.7, 116.9,
|
||||
// ];
|
||||
|
||||
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
// let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
// let deserialized_lr: LinearRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
// assert_eq!(lr, deserialized_lr);
|
||||
|
||||
// let default = LinearRegressionParameters::default();
|
||||
// let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
|
||||
// assert_eq!(parameters.solver, default.solver);
|
||||
// }
|
||||
}
|
||||
|
||||
+446
-242
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -20,10 +20,10 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
pub(crate) mod bg_solver;
|
||||
pub mod bg_solver;
|
||||
pub mod elastic_net;
|
||||
pub mod lasso;
|
||||
pub(crate) mod lasso_optimizer;
|
||||
pub mod lasso_optimizer;
|
||||
pub mod linear_regression;
|
||||
pub mod logistic_regression;
|
||||
pub mod ridge_regression;
|
||||
|
||||
+258
-95
@@ -12,14 +12,14 @@
|
||||
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression will produce the least squares estimates.
|
||||
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates will approach zero.
|
||||
//!
|
||||
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
|
||||
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
|
||||
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
|
||||
//! but does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::ridge_regression::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
@@ -40,7 +40,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -57,21 +57,25 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::linalg::traits::cholesky::CholeskyDecomposable;
|
||||
use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq, Default)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
#[default]
|
||||
Cholesky,
|
||||
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
||||
SVD,
|
||||
@@ -80,7 +84,7 @@ pub enum RidgeRegressionSolverName {
|
||||
/// Ridge Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
pub struct RidgeRegressionParameters<T: Number + RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: RidgeRegressionSolverName,
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
@@ -90,16 +94,109 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionSearchParameters<T: Number + RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<RidgeRegressionSolverName>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search iterator
|
||||
pub struct RidgeRegressionSearchParametersIterator<T: Number + RealNumber> {
|
||||
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
|
||||
current_solver: usize,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RidgeRegressionSearchParametersIterator {
|
||||
ridge_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
|
||||
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RidgeRegressionParameters {
|
||||
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_solver += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.ridge_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_solver = 0;
|
||||
self.current_normalize += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_solver += 1;
|
||||
self.current_normalize += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber> Default for RidgeRegressionSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = RidgeRegressionParameters::default();
|
||||
|
||||
RidgeRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ridge regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
_solver: RidgeRegressionSolverName,
|
||||
pub struct RidgeRegression<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
coefficients: Option<X>,
|
||||
intercept: Option<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RidgeRegressionParameters<T> {
|
||||
impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
@@ -117,51 +214,83 @@ impl<T: RealNumber> RidgeRegressionParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
|
||||
impl<T: Number + RealNumber> Default for RidgeRegressionParameters<T> {
|
||||
fn default() -> Self {
|
||||
RidgeRegressionParameters {
|
||||
solver: RidgeRegressionSolverName::Cholesky,
|
||||
alpha: T::one(),
|
||||
solver: RidgeRegressionSolverName::default(),
|
||||
alpha: T::from_f64(1.0).unwrap(),
|
||||
normalize: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> PartialEq for RidgeRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
self.intercept() == other.intercept()
|
||||
&& self.coefficients().shape() == other.coefficients().shape()
|
||||
&& self
|
||||
.coefficients()
|
||||
.iterator(0)
|
||||
.zip(other.coefficients().iterator(0))
|
||||
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, RidgeRegressionParameters<T>>
|
||||
for RidgeRegression<T, M>
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> SupervisedEstimator<X, Y, RidgeRegressionParameters<TX>> for RidgeRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RidgeRegressionParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
coefficients: Option::None,
|
||||
intercept: Option::None,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn fit(x: &X, y: &Y, parameters: RidgeRegressionParameters<TX>) -> Result<Self, Failed> {
|
||||
RidgeRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> Predictor<X, Y> for RidgeRegression<TX, TY, X, Y>
|
||||
{
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
impl<
|
||||
TX: Number + RealNumber,
|
||||
TY: Number,
|
||||
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
|
||||
Y: Array1<TY>,
|
||||
> RidgeRegression<TX, TY, X, Y>
|
||||
{
|
||||
/// Fits ridge regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RidgeRegressionParameters<T>,
|
||||
) -> Result<RidgeRegression<T, M>, Failed> {
|
||||
x: &X,
|
||||
y: &Y,
|
||||
parameters: RidgeRegressionParameters<TX>,
|
||||
) -> Result<RidgeRegression<TX, TY, X, Y>, Failed> {
|
||||
//w = inv(X^t X + alpha*Id) * X.T y
|
||||
|
||||
let (n, p) = x.shape();
|
||||
@@ -172,11 +301,16 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
));
|
||||
}
|
||||
|
||||
if y.len() != n {
|
||||
if y.shape() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let y_column = M::from_row_vector(y.clone()).transpose();
|
||||
let y_column = X::from_iterator(
|
||||
y.iterator(0).map(|&v| TX::from(v).unwrap()),
|
||||
y.shape(),
|
||||
1,
|
||||
0,
|
||||
);
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
@@ -185,7 +319,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
let mut x_t_x = x_t.matmul(&scaled_x);
|
||||
|
||||
for i in 0..p {
|
||||
x_t_x.add_element_mut(i, i, parameters.alpha);
|
||||
x_t_x.add_element_mut((i, i), parameters.alpha);
|
||||
}
|
||||
|
||||
let mut w = match parameters.solver {
|
||||
@@ -194,16 +328,16 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
};
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate().take(p) {
|
||||
w.set(i, 0, w.get(i, 0) / *col_std_i);
|
||||
w.set((i, 0), *w.get((i, 0)) / *col_std_i);
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
let mut b = TX::zero();
|
||||
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
b += w.get(i, 0) * *col_mean_i;
|
||||
b += *w.get((i, 0)) * *col_mean_i;
|
||||
}
|
||||
|
||||
let b = y.mean() - b;
|
||||
let b = TX::from_f64(y.mean_by()).unwrap() - b;
|
||||
|
||||
(w, b)
|
||||
} else {
|
||||
@@ -212,7 +346,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
let mut x_t_x = x_t.matmul(x);
|
||||
|
||||
for i in 0..p {
|
||||
x_t_x.add_element_mut(i, i, parameters.alpha);
|
||||
x_t_x.add_element_mut((i, i), parameters.alpha);
|
||||
}
|
||||
|
||||
let w = match parameters.solver {
|
||||
@@ -220,26 +354,32 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
|
||||
};
|
||||
|
||||
(w, T::zero())
|
||||
(w, TX::zero())
|
||||
};
|
||||
|
||||
Ok(RidgeRegression {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
_solver: parameters.solver,
|
||||
intercept: Some(b),
|
||||
coefficients: Some(w),
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
|
||||
let col_mean: Vec<TX> = x
|
||||
.mean_by(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
let col_std: Vec<TX> = x
|
||||
.std_dev(0)
|
||||
.iter()
|
||||
.map(|&v| TX::from_f64(v).unwrap())
|
||||
.collect();
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate() {
|
||||
if (*col_std_i - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
|
||||
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,31 +390,52 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
let mut y_hat = x.matmul(self.coefficients());
|
||||
y_hat.add_mut(&X::fill(nrows, 1, self.intercept.unwrap()));
|
||||
Ok(Y::from_iterator(
|
||||
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
|
||||
nrows,
|
||||
))
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
pub fn coefficients(&self) -> &X {
|
||||
self.coefficients.as_ref().unwrap()
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
pub fn intercept(&self) -> &TX {
|
||||
self.intercept.as_ref().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RidgeRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().alpha, 0.);
|
||||
assert_eq!(
|
||||
iter.next().unwrap().solver,
|
||||
RidgeRegressionSolverName::Cholesky
|
||||
);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn ridge_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -294,7 +455,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -330,39 +492,40 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
// TODO: implement serialization for new DenseMatrix
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]).unwrap();
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
// 114.2, 115.7, 116.9,
|
||||
// ];
|
||||
|
||||
let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
// let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: RidgeRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
// let deserialized_lr: RidgeRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
// assert_eq!(lr, deserialized_lr);
|
||||
// }
|
||||
}
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
//! # Euclidian Metric Distance
|
||||
//!
|
||||
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
|
||||
//!
|
||||
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::math::distance::euclidian::Euclidian;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let l2: f64 = Euclidian{}.distance(&x, &y);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Euclidian {}
|
||||
|
||||
impl Euclidian {
|
||||
#[inline]
|
||||
pub(crate) fn squared_distance<T: RealNumber>(x: &[T], y: &[T]) -> T {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
|
||||
let mut sum = T::zero();
|
||||
for i in 0..x.len() {
|
||||
let d = x[i] - y[i];
|
||||
sum += d * d;
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
|
||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
Euclidian::squared_distance(x, y).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn squared_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let l2: f64 = Euclidian {}.distance(&a, &b);
|
||||
|
||||
assert!((l2 - 5.19615242).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
//! # Manhattan Distance
|
||||
//!
|
||||
//! The Manhattan distance between two points \\(x \in ℝ^n \\) and \\( y \in ℝ^n \\) in n-dimensional space is the sum of the distances in each dimension.
|
||||
//!
|
||||
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::math::distance::manhattan::Manhattan;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let l1: f64 = Manhattan {}.distance(&x, &y);
|
||||
//! ```
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Manhattan distance
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Manhattan {}
|
||||
|
||||
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different");
|
||||
}
|
||||
|
||||
let mut dist = T::zero();
|
||||
for i in 0..x.len() {
|
||||
dist += (x[i] - y[i]).abs();
|
||||
}
|
||||
|
||||
dist
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn manhattan_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let l1: f64 = Manhattan {}.distance(&a, &b);
|
||||
|
||||
assert!((l1 - 9.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
//! # Collection of Distance Functions
|
||||
//!
|
||||
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
|
||||
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
|
||||
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
|
||||
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//!
|
||||
//! for all \\(x, y, z \in Z \\)
|
||||
//!
|
||||
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||
pub mod euclidian;
|
||||
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
||||
pub mod hamming;
|
||||
/// The Mahalanobis distance is the distance between two points in multivariate space.
|
||||
pub mod mahalanobis;
|
||||
/// Also known as rectilinear distance, city block distance, taxicab metric.
|
||||
pub mod manhattan;
|
||||
/// A generalization of both the Euclidean distance and the Manhattan distance.
|
||||
pub mod minkowski;
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Distance metric, a function that calculates distance between two points
|
||||
pub trait Distance<T, F: RealNumber>: Clone {
|
||||
/// Calculates distance between _a_ and _b_
|
||||
fn distance(&self, a: &T, b: &T) -> F;
|
||||
}
|
||||
|
||||
/// Multitude of distance metric functions
|
||||
pub struct Distances {}
|
||||
|
||||
impl Distances {
|
||||
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
|
||||
pub fn euclidian() -> euclidian::Euclidian {
|
||||
euclidian::Euclidian {}
|
||||
}
|
||||
|
||||
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
|
||||
/// * `p` - function order. Should be >= 1
|
||||
pub fn minkowski(p: u16) -> minkowski::Minkowski {
|
||||
minkowski::Minkowski { p }
|
||||
}
|
||||
|
||||
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
|
||||
pub fn manhattan() -> manhattan::Manhattan {
|
||||
manhattan::Manhattan {}
|
||||
}
|
||||
|
||||
/// Hamming distance, see [`Hamming`](hamming/index.html)
|
||||
pub fn hamming() -> hamming::Hamming {
|
||||
hamming::Hamming {}
|
||||
}
|
||||
|
||||
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
|
||||
pub fn mahalanobis<T: RealNumber, M: Matrix<T>>(data: &M) -> mahalanobis::Mahalanobis<T, M> {
|
||||
mahalanobis::Mahalanobis::new(data)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
/// Multitude of distance metrics are defined here
|
||||
pub mod distance;
|
||||
pub mod num;
|
||||
pub(crate) mod vector;
|
||||
@@ -1,42 +0,0 @@
|
||||
use crate::math::num::RealNumber;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
pub trait RealNumberVector<T: RealNumber> {
|
||||
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>);
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
|
||||
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>) {
|
||||
let mut unique = self.to_vec();
|
||||
unique.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
unique.dedup();
|
||||
|
||||
let mut index = HashMap::with_capacity(unique.len());
|
||||
for (i, u) in unique.iter().enumerate() {
|
||||
index.insert(u.to_i64().unwrap(), i);
|
||||
}
|
||||
|
||||
let mut unique_index = Vec::with_capacity(self.len());
|
||||
for idx in 0..self.len() {
|
||||
unique_index.push(index[&self.get(idx).to_i64().unwrap()]);
|
||||
}
|
||||
|
||||
(unique, unique_index)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn unique_with_indices() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
assert_eq!(
|
||||
(vec!(0.0, 1.0, 2.0, 4.0), vec!(0, 0, 1, 1, 2, 0, 3)),
|
||||
v1.unique_with_indices()
|
||||
);
|
||||
}
|
||||
}
|
||||
+62
-17
@@ -8,10 +8,20 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::accuracy::Accuracy;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//! let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
||||
//! let y_true: Vec<f64> = vec![0., 1., 2., 3.];
|
||||
//!
|
||||
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
|
||||
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
|
||||
//! ```
|
||||
//! With integers:
|
||||
//! ```
|
||||
//! use smartcore::metrics::accuracy::Accuracy;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//! let y_pred: Vec<i64> = vec![0, 2, 1, 3];
|
||||
//! let y_true: Vec<i64> = vec![0, 1, 2, 3];
|
||||
//!
|
||||
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
@@ -19,37 +29,53 @@
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
/// Accuracy metric.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Accuracy {}
|
||||
pub struct Accuracy<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl Accuracy {
|
||||
impl<T: Number> Metrics<T> for Accuracy<T> {
|
||||
/// create a typed object to call Accuracy functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Function that calculated accuracy score.
|
||||
/// * `y_true` - cround truth (correct) labels
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
if y_true.shape() != y_pred.shape() {
|
||||
panic!(
|
||||
"The vector sizes don't match: {} != {}",
|
||||
y_true.len(),
|
||||
y_pred.len()
|
||||
y_true.shape(),
|
||||
y_pred.shape()
|
||||
);
|
||||
}
|
||||
|
||||
let n = y_true.len();
|
||||
let n = y_true.shape();
|
||||
|
||||
let mut positive = 0;
|
||||
let mut positive: i32 = 0;
|
||||
for i in 0..n {
|
||||
if y_true.get(i) == y_pred.get(i) {
|
||||
if *y_true.get(i) == *y_pred.get(i) {
|
||||
positive += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
|
||||
positive as f64 / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,16 +83,35 @@ impl Accuracy {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn accuracy() {
|
||||
fn accuracy_float() {
|
||||
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 2., 3.];
|
||||
|
||||
let score1: f64 = Accuracy {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Accuracy {}.get_score(&y_true, &y_true);
|
||||
let score1: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_pred);
|
||||
let score2: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_true);
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn accuracy_int() {
|
||||
let y_pred: Vec<i32> = vec![0, 2, 1, 3];
|
||||
let y_true: Vec<i32> = vec![0, 1, 2, 3];
|
||||
|
||||
let score1: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_pred);
|
||||
let score2: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_true);
|
||||
|
||||
assert_eq!(score1, 0.5);
|
||||
assert_eq!(score2, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
+50
-27
@@ -2,16 +2,17 @@
|
||||
//! Computes the area under the receiver operating characteristic (ROC) curve that is equal to the probability that a classifier will rank a
|
||||
//! randomly chosen positive instance higher than a randomly chosen negative one.
|
||||
//!
|
||||
//! SmartCore calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
|
||||
//! `smartcore` calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::metrics::auc::AUC;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//!
|
||||
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
//! let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
|
||||
//!
|
||||
//! let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
|
||||
//! let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -20,32 +21,48 @@
|
||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct AUC {}
|
||||
pub struct AUC<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl AUC {
|
||||
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
|
||||
/// create a typed object to call AUC functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// AUC score.
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred_prob: &V) -> T {
|
||||
/// * `y_true` - ground truth (correct) labels.
|
||||
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
|
||||
let mut pos = T::zero();
|
||||
let mut neg = T::zero();
|
||||
|
||||
let n = y_true.len();
|
||||
let n = y_true.shape();
|
||||
|
||||
for i in 0..n {
|
||||
if y_true.get(i) == T::zero() {
|
||||
if y_true.get(i) == &T::zero() {
|
||||
neg += T::one();
|
||||
} else if y_true.get(i) == T::one() {
|
||||
} else if y_true.get(i) == &T::one() {
|
||||
pos += T::one();
|
||||
} else {
|
||||
panic!(
|
||||
@@ -55,21 +72,22 @@ impl AUC {
|
||||
}
|
||||
}
|
||||
|
||||
let mut y_pred = y_pred_prob.to_vec();
|
||||
let y_pred: Vec<T> =
|
||||
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
|
||||
// TODO: try to use `crate::algorithm::sort::quick_sort` here
|
||||
let label_idx: Vec<usize> = y_pred.argsort();
|
||||
|
||||
let label_idx = y_pred.quick_argsort_mut();
|
||||
|
||||
let mut rank = vec![T::zero(); n];
|
||||
let mut rank = vec![0f64; n];
|
||||
let mut i = 0;
|
||||
while i < n {
|
||||
if i == n - 1 || y_pred[i] != y_pred[i + 1] {
|
||||
rank[i] = T::from_usize(i + 1).unwrap();
|
||||
if i == n - 1 || y_pred.get(i) != y_pred.get(i + 1) {
|
||||
rank[i] = (i + 1) as f64;
|
||||
} else {
|
||||
let mut j = i + 1;
|
||||
while j < n && y_pred[j] == y_pred[i] {
|
||||
while j < n && y_pred.get(j) == y_pred.get(i) {
|
||||
j += 1;
|
||||
}
|
||||
let r = T::from_usize(i + 1 + j).unwrap() / T::two();
|
||||
let r = (i + 1 + j) as f64 / 2f64;
|
||||
for rank_k in rank.iter_mut().take(j).skip(i) {
|
||||
*rank_k = r;
|
||||
}
|
||||
@@ -78,14 +96,16 @@ impl AUC {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let mut auc = T::zero();
|
||||
let mut auc = 0f64;
|
||||
for i in 0..n {
|
||||
if y_true.get(label_idx[i]) == T::one() {
|
||||
if y_true.get(label_idx[i]) == &T::one() {
|
||||
auc += rank[i];
|
||||
}
|
||||
}
|
||||
let pos = pos.to_f64().unwrap();
|
||||
let neg = neg.to_f64().unwrap();
|
||||
|
||||
(auc - (pos * (pos + T::one()) / T::two())) / (pos * neg)
|
||||
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,14 +113,17 @@ impl AUC {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn auc() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
|
||||
|
||||
let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
|
||||
let score2: f64 = AUC {}.get_score(&y_true, &y_true);
|
||||
let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
|
||||
let score2: f64 = AUC::new().get_score(&y_true, &y_true);
|
||||
|
||||
assert!((score1 - 0.75).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
+79
-31
@@ -1,41 +1,85 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::metrics::cluster_helpers::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Homogeneity, completeness and V-Measure scores.
|
||||
pub struct HCVScore {}
|
||||
pub struct HCVScore<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
homogeneity: Option<f64>,
|
||||
completeness: Option<f64>,
|
||||
v_measure: Option<f64>,
|
||||
}
|
||||
|
||||
impl HCVScore {
|
||||
/// Computes Homogeneity, completeness and V-Measure scores at once.
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(
|
||||
&self,
|
||||
labels_true: &V,
|
||||
labels_pred: &V,
|
||||
) -> (T, T, T) {
|
||||
let labels_true = labels_true.to_vec();
|
||||
let labels_pred = labels_pred.to_vec();
|
||||
let entropy_c = entropy(&labels_true);
|
||||
let entropy_k = entropy(&labels_pred);
|
||||
let contingency = contingency_matrix(&labels_true, &labels_pred);
|
||||
let mi: T = mutual_info_score(&contingency);
|
||||
impl<T: Number + Ord> HCVScore<T> {
|
||||
/// return homogenity score
|
||||
pub fn homogeneity(&self) -> Option<f64> {
|
||||
self.homogeneity
|
||||
}
|
||||
/// return completeness score
|
||||
pub fn completeness(&self) -> Option<f64> {
|
||||
self.completeness
|
||||
}
|
||||
/// return v_measure score
|
||||
pub fn v_measure(&self) -> Option<f64> {
|
||||
self.v_measure
|
||||
}
|
||||
/// run computation for measures
|
||||
pub fn compute(&mut self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) {
|
||||
let entropy_c: Option<f64> = entropy(y_true);
|
||||
let entropy_k: Option<f64> = entropy(y_pred);
|
||||
let contingency = contingency_matrix(y_true, y_pred);
|
||||
let mi = mutual_info_score(&contingency);
|
||||
|
||||
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or_else(T::one);
|
||||
let completeness = entropy_k.map(|e| mi / e).unwrap_or_else(T::one);
|
||||
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or(0f64);
|
||||
let completeness = entropy_k.map(|e| mi / e).unwrap_or(0f64);
|
||||
|
||||
let v_measure_score = if homogeneity + completeness == T::zero() {
|
||||
T::zero()
|
||||
let v_measure_score = if homogeneity + completeness == 0f64 {
|
||||
0f64
|
||||
} else {
|
||||
T::two() * homogeneity * completeness / (T::one() * homogeneity + completeness)
|
||||
2.0f64 * homogeneity * completeness / (1.0f64 * homogeneity + completeness)
|
||||
};
|
||||
|
||||
(homogeneity, completeness, v_measure_score)
|
||||
self.homogeneity = Some(homogeneity);
|
||||
self.completeness = Some(completeness);
|
||||
self.v_measure = Some(v_measure_score);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + Ord> Metrics<T> for HCVScore<T> {
|
||||
/// create a typed object to call HCVScore functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
homogeneity: Option::None,
|
||||
completeness: Option::None,
|
||||
v_measure: Option::None,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
homogeneity: Option::None,
|
||||
completeness: Option::None,
|
||||
v_measure: Option::None,
|
||||
}
|
||||
}
|
||||
/// Computes Homogeneity, completeness and V-Measure scores at once.
|
||||
/// * `y_true` - ground truth class labels to be used as a reference.
|
||||
/// * `y_pred` - cluster labels to evaluate.
|
||||
fn get_score(&self, _y_true: &dyn ArrayView1<T>, _y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
// this functions should not be used for this struct
|
||||
// use homogeneity(), completeness(), v_measure()
|
||||
// TODO: implement Metrics -> Result<T, Failed>
|
||||
0f64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,15 +87,19 @@ impl HCVScore {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn homogeneity_score() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
|
||||
let scores = HCVScore {}.get_score(&v1, &v2);
|
||||
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
|
||||
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
|
||||
let mut scores = HCVScore::new();
|
||||
scores.compute(&v1, &v2);
|
||||
|
||||
assert!((0.2548f32 - scores.0).abs() < 1e-4);
|
||||
assert!((0.5440f32 - scores.1).abs() < 1e-4);
|
||||
assert!((0.3471f32 - scores.2).abs() < 1e-4);
|
||||
assert!((0.2548 - scores.homogeneity.unwrap()).abs() < 1e-4);
|
||||
assert!((0.5440 - scores.completeness.unwrap()).abs() < 1e-4);
|
||||
assert!((0.3471 - scores.v_measure.unwrap()).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
#![allow(clippy::ptr_arg)]
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
pub fn contingency_matrix<T: RealNumber>(
|
||||
labels_true: &Vec<T>,
|
||||
labels_pred: &Vec<T>,
|
||||
pub fn contingency_matrix<T: Number + Ord, V: ArrayView1<T> + ?Sized>(
|
||||
labels_true: &V,
|
||||
labels_pred: &V,
|
||||
) -> Vec<Vec<usize>> {
|
||||
let (classes, class_idx) = labels_true.unique_with_indices();
|
||||
let (clusters, cluster_idx) = labels_pred.unique_with_indices();
|
||||
@@ -24,28 +24,30 @@ pub fn contingency_matrix<T: RealNumber>(
|
||||
contingency_matrix
|
||||
}
|
||||
|
||||
pub fn entropy<T: RealNumber>(data: &[T]) -> Option<T> {
|
||||
let mut bincounts = HashMap::with_capacity(data.len());
|
||||
pub fn entropy<T: Number + Ord, V: ArrayView1<T> + ?Sized>(data: &V) -> Option<f64> {
|
||||
let mut bincounts = HashMap::with_capacity(data.shape());
|
||||
|
||||
for e in data.iter() {
|
||||
for e in data.iterator(0) {
|
||||
let k = e.to_i64().unwrap();
|
||||
bincounts.insert(k, bincounts.get(&k).unwrap_or(&0) + 1);
|
||||
}
|
||||
|
||||
let mut entropy = T::zero();
|
||||
let sum = T::from_usize(bincounts.values().sum()).unwrap();
|
||||
let mut entropy = 0f64;
|
||||
let sum: i64 = bincounts.values().sum();
|
||||
|
||||
for &c in bincounts.values() {
|
||||
if c > 0 {
|
||||
let pi = T::from_usize(c).unwrap();
|
||||
entropy -= (pi / sum) * (pi.ln() - sum.ln());
|
||||
let pi = c as f64;
|
||||
let pi_ln = pi.ln();
|
||||
let sum_ln = (sum as f64).ln();
|
||||
entropy -= (pi / sum as f64) * (pi_ln - sum_ln);
|
||||
}
|
||||
}
|
||||
|
||||
Some(entropy)
|
||||
}
|
||||
|
||||
pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
|
||||
pub fn mutual_info_score(contingency: &[Vec<usize>]) -> f64 {
|
||||
let mut contingency_sum = 0;
|
||||
let mut pi = vec![0; contingency.len()];
|
||||
let mut pj = vec![0; contingency[0].len()];
|
||||
@@ -64,48 +66,50 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
|
||||
}
|
||||
}
|
||||
|
||||
let contingency_sum = T::from_usize(contingency_sum).unwrap();
|
||||
let contingency_sum = contingency_sum as f64;
|
||||
let contingency_sum_ln = contingency_sum.ln();
|
||||
let pi_sum_l = T::from_usize(pi.iter().sum()).unwrap().ln();
|
||||
let pj_sum_l = T::from_usize(pj.iter().sum()).unwrap().ln();
|
||||
let pi_sum: usize = pi.iter().sum();
|
||||
let pj_sum: usize = pj.iter().sum();
|
||||
let pi_sum_l = (pi_sum as f64).ln();
|
||||
let pj_sum_l = (pj_sum as f64).ln();
|
||||
|
||||
let log_contingency_nm: Vec<T> = nz_val
|
||||
let log_contingency_nm: Vec<f64> = nz_val.iter().map(|v| (*v as f64).ln()).collect();
|
||||
let contingency_nm: Vec<f64> = nz_val
|
||||
.iter()
|
||||
.map(|v| T::from_usize(*v).unwrap().ln())
|
||||
.collect();
|
||||
let contingency_nm: Vec<T> = nz_val
|
||||
.iter()
|
||||
.map(|v| T::from_usize(*v).unwrap() / contingency_sum)
|
||||
.map(|v| (*v as f64) / contingency_sum)
|
||||
.collect();
|
||||
let outer: Vec<usize> = nzx
|
||||
.iter()
|
||||
.zip(nzy.iter())
|
||||
.map(|(&x, &y)| pi[x] * pj[y])
|
||||
.collect();
|
||||
let log_outer: Vec<T> = outer
|
||||
let log_outer: Vec<f64> = outer
|
||||
.iter()
|
||||
.map(|&o| -T::from_usize(o).unwrap().ln() + pi_sum_l + pj_sum_l)
|
||||
.map(|&o| -(o as f64).ln() + pi_sum_l + pj_sum_l)
|
||||
.collect();
|
||||
|
||||
let mut result = T::zero();
|
||||
let mut result = 0f64;
|
||||
|
||||
for i in 0..log_outer.len() {
|
||||
result += (contingency_nm[i] * (log_contingency_nm[i] - contingency_sum_ln))
|
||||
+ contingency_nm[i] * log_outer[i]
|
||||
}
|
||||
|
||||
result.max(T::zero())
|
||||
result.max(0f64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn contingency_matrix_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
|
||||
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
|
||||
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
|
||||
|
||||
assert_eq!(
|
||||
vec!(vec!(1, 2), vec!(2, 0), vec!(1, 0), vec!(1, 0)),
|
||||
@@ -113,20 +117,26 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn entropy_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
|
||||
|
||||
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
|
||||
assert!((1.2770 - entropy(&v1).unwrap()).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn mutual_info_score_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
|
||||
let s: f32 = mutual_info_score(&contingency_matrix(&v1, &v2));
|
||||
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
|
||||
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
|
||||
let s = mutual_info_score(&contingency_matrix(&v1, &v2));
|
||||
|
||||
assert!((0.3254 - s).abs() < 1e-4);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,219 @@
|
||||
//! # Cosine Distance Metric
|
||||
//!
|
||||
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
|
||||
//!
|
||||
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
|
||||
//!
|
||||
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
|
||||
//! are their respective magnitudes (Euclidean norms).
|
||||
//!
|
||||
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
|
||||
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
|
||||
//! greater angular separation.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::cosine::Cosine;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
|
||||
/// It is defined as 1 minus the cosine similarity of the vectors.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cosine<T> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number> Default for Cosine<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number> Cosine<T> {
|
||||
/// Instantiate the initial structure
|
||||
pub fn new() -> Cosine<T> {
|
||||
Cosine { _t: PhantomData }
|
||||
}
|
||||
|
||||
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
|
||||
#[inline]
|
||||
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||
if x.shape() != y.shape() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
|
||||
// Use the built-in dot product method from ArrayView1 trait
|
||||
x.dot(y).to_f64().unwrap()
|
||||
}
|
||||
|
||||
/// Calculate the squared magnitude (norm squared) of a vector
|
||||
#[inline]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||
x.iterator(0)
|
||||
.map(|&a| {
|
||||
let val = a.to_f64().unwrap();
|
||||
val * val
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
|
||||
#[inline]
|
||||
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||
// Use the built-in norm2 method from ArrayView1 trait
|
||||
x.norm2()
|
||||
}
|
||||
|
||||
/// Calculate cosine similarity between two vectors
|
||||
#[inline]
|
||||
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||
let dot_product = Self::dot_product(x, y);
|
||||
let magnitude_x = Self::magnitude(x);
|
||||
let magnitude_y = Self::magnitude(y);
|
||||
|
||||
if magnitude_x == 0.0 || magnitude_y == 0.0 {
|
||||
return f64::MIN;
|
||||
}
|
||||
|
||||
dot_product / (magnitude_x * magnitude_y)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
let similarity = Cosine::cosine_similarity(x, y);
|
||||
1.0 - similarity
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_identical_vectors() {
|
||||
let a = vec![1, 2, 3];
|
||||
let b = vec![1, 2, 3];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
|
||||
assert!((dist - 0.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_orthogonal_vectors() {
|
||||
let a = vec![1, 0];
|
||||
let b = vec![0, 1];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
|
||||
assert!((dist - 1.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_opposite_vectors() {
|
||||
let a = vec![1, 2, 3];
|
||||
let b = vec![-1, -2, -3];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
|
||||
assert!((dist - 2.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_general_case() {
|
||||
let a = vec![1.0, 2.0, 3.0];
|
||||
let b = vec![2.0, 1.0, 3.0];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
|
||||
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
|
||||
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
|
||||
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
|
||||
let expected_dist = 1.0 - (13.0 / 14.0);
|
||||
assert!((dist - expected_dist).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
#[should_panic(expected = "Input vector sizes are different.")]
|
||||
fn cosine_distance_different_sizes() {
|
||||
let a = vec![1, 2];
|
||||
let b = vec![1, 2, 3];
|
||||
|
||||
let _dist: f64 = Cosine::new().distance(&a, &b);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_zero_vector() {
|
||||
let a = vec![0, 0, 0];
|
||||
let b = vec![1, 2, 3];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
assert!(dist > 1e300)
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn cosine_distance_float_precision() {
|
||||
let a = vec![1.0f32, 2.0, 3.0];
|
||||
let b = vec![4.0f32, 5.0, 6.0];
|
||||
|
||||
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||
|
||||
// Calculate expected value manually
|
||||
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
|
||||
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
|
||||
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
|
||||
let expected_similarity = dot_product / (mag_a * mag_b);
|
||||
let expected_distance = 1.0 - expected_similarity;
|
||||
|
||||
assert!((dist - expected_distance).abs() < 1e-6);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
//! # Euclidian Metric Distance
|
||||
//!
|
||||
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
|
||||
//!
|
||||
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::euclidian::Euclidian;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let l2: f64 = Euclidian::new().distance(&x, &y);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Euclidian<T> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number> Default for Euclidian<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number> Euclidian<T> {
|
||||
/// instatiate the initial structure
|
||||
pub fn new() -> Euclidian<T> {
|
||||
Euclidian { _t: PhantomData }
|
||||
}
|
||||
|
||||
/// return sum of squared distances
|
||||
#[inline]
|
||||
pub(crate) fn squared_distance<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||
if x.shape() != y.shape() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
|
||||
let sum: f64 = x
|
||||
.iterator(0)
|
||||
.zip(y.iterator(0))
|
||||
.map(|(&a, &b)| {
|
||||
let r = a - b;
|
||||
(r * r).to_f64().unwrap()
|
||||
})
|
||||
.sum();
|
||||
|
||||
sum
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Euclidian<T> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
Euclidian::squared_distance(x, y).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn squared_distance() {
|
||||
let a = vec![1, 2, 3];
|
||||
let b = vec![4, 5, 6];
|
||||
|
||||
let l2: f64 = Euclidian::new().distance(&a, &b);
|
||||
|
||||
assert!((l2 - 5.19615242).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
@@ -6,13 +6,13 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::math::distance::hamming::Hamming;
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::hamming::Hamming;
|
||||
//!
|
||||
//! let a = vec![1, 0, 0, 1, 0, 0, 1];
|
||||
//! let b = vec![1, 1, 0, 0, 1, 0, 1];
|
||||
//!
|
||||
//! let h: f64 = Hamming {}.distance(&a, &b);
|
||||
//! let h: f64 = Hamming::new().distance(&a, &b);
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
@@ -21,30 +21,48 @@
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::Distance;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Hamming {}
|
||||
pub struct Hamming<T: Number> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
|
||||
if x.len() != y.len() {
|
||||
impl<T: Number> Hamming<T> {
|
||||
/// instatiate the initial structure
|
||||
pub fn new() -> Hamming<T> {
|
||||
Hamming { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number> Default for Hamming<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Hamming<T> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
if x.shape() != y.shape() {
|
||||
panic!("Input vector sizes are different");
|
||||
}
|
||||
|
||||
let mut dist = 0;
|
||||
for i in 0..x.len() {
|
||||
if x[i] != y[i] {
|
||||
dist += 1;
|
||||
}
|
||||
}
|
||||
let dist: usize = x
|
||||
.iterator(0)
|
||||
.zip(y.iterator(0))
|
||||
.map(|(a, b)| match a != b {
|
||||
true => 1,
|
||||
false => 0,
|
||||
})
|
||||
.sum();
|
||||
|
||||
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
|
||||
dist as f64 / x.shape() as f64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,13 +70,16 @@ impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn hamming_distance() {
|
||||
let a = vec![1, 0, 0, 1, 0, 0, 1];
|
||||
let b = vec![1, 1, 0, 0, 1, 0, 1];
|
||||
|
||||
let h: f64 = Hamming {}.distance(&a, &b);
|
||||
let h: f64 = Hamming::new().distance(&a, &b);
|
||||
|
||||
assert!((h - 0.42857142).abs() < 1e-8);
|
||||
}
|
||||
@@ -14,9 +14,10 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::math::distance::mahalanobis::Mahalanobis;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linalg::basic::arrays::ArrayView2;
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::mahalanobis::Mahalanobis;
|
||||
//!
|
||||
//! let data = DenseMatrix::from_2d_array(&[
|
||||
//! &[64., 580., 29.],
|
||||
@@ -24,9 +25,9 @@
|
||||
//! &[68., 590., 37.],
|
||||
//! &[69., 660., 46.],
|
||||
//! &[73., 600., 55.],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let a = data.column_mean();
|
||||
//! let a = data.mean_by(0);
|
||||
//! let b = vec![66., 640., 44.];
|
||||
//!
|
||||
//! let mahalanobis = Mahalanobis::new(&data);
|
||||
@@ -42,85 +43,89 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::Distance;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::basic::arrays::{Array, Array2, ArrayView1};
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::linalg::traits::lu::LUDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
/// Mahalanobis distance.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
||||
pub struct Mahalanobis<T: Number, M: Array2<f64>> {
|
||||
/// covariance matrix of the dataset
|
||||
pub sigma: M,
|
||||
/// inverse of the covariance matrix
|
||||
pub sigmaInv: M,
|
||||
t: PhantomData<T>,
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
|
||||
impl<T: Number, M: Array2<f64> + LUDecomposable<f64>> Mahalanobis<T, M> {
|
||||
/// Constructs new instance of `Mahalanobis` from given dataset
|
||||
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
|
||||
pub fn new(data: &M) -> Mahalanobis<T, M> {
|
||||
let sigma = data.cov();
|
||||
pub fn new<X: Array2<T>>(data: &X) -> Mahalanobis<T, M> {
|
||||
let (_, m) = data.shape();
|
||||
let mut sigma = M::zeros(m, m);
|
||||
data.cov(&mut sigma);
|
||||
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
Mahalanobis {
|
||||
sigma,
|
||||
sigmaInv,
|
||||
t: PhantomData,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Constructs new instance of `Mahalanobis` from given covariance matrix
|
||||
/// * `cov` - a covariance matrix
|
||||
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
|
||||
pub fn new_from_covariance<X: Array2<f64> + LUDecomposable<f64>>(cov: &X) -> Mahalanobis<T, X> {
|
||||
let sigma = cov.clone();
|
||||
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
Mahalanobis {
|
||||
sigma,
|
||||
sigmaInv,
|
||||
t: PhantomData,
|
||||
_t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Mahalanobis<T, DenseMatrix<f64>> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
let (nrows, ncols) = self.sigma.shape();
|
||||
if x.len() != nrows {
|
||||
if x.shape() != nrows {
|
||||
panic!(
|
||||
"Array x[{}] has different dimension with Sigma[{}][{}].",
|
||||
x.len(),
|
||||
x.shape(),
|
||||
nrows,
|
||||
ncols
|
||||
);
|
||||
}
|
||||
|
||||
if y.len() != nrows {
|
||||
if y.shape() != nrows {
|
||||
panic!(
|
||||
"Array y[{}] has different dimension with Sigma[{}][{}].",
|
||||
y.len(),
|
||||
y.shape(),
|
||||
nrows,
|
||||
ncols
|
||||
);
|
||||
}
|
||||
|
||||
let n = x.len();
|
||||
let mut z = vec![T::zero(); n];
|
||||
for i in 0..n {
|
||||
z[i] = x[i] - y[i];
|
||||
}
|
||||
let n = x.shape();
|
||||
|
||||
let z: Vec<f64> = x
|
||||
.iterator(0)
|
||||
.zip(y.iterator(0))
|
||||
.map(|(&a, &b)| (a - b).to_f64().unwrap())
|
||||
.collect();
|
||||
|
||||
// np.dot(np.dot((a-b),VI),(a-b).T)
|
||||
let mut s = T::zero();
|
||||
let mut s = 0f64;
|
||||
for j in 0..n {
|
||||
for i in 0..n {
|
||||
s += self.sigmaInv.get(i, j) * z[i] * z[j];
|
||||
s += *self.sigmaInv.get((i, j)) * z[i] * z[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,9 +136,13 @@ impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::linalg::basic::arrays::ArrayView2;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn mahalanobis_distance() {
|
||||
let data = DenseMatrix::from_2d_array(&[
|
||||
@@ -142,9 +151,10 @@ mod tests {
|
||||
&[68., 590., 37.],
|
||||
&[69., 660., 46.],
|
||||
&[73., 600., 55.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let a = data.column_mean();
|
||||
let a = data.mean_by(0);
|
||||
let b = vec![66., 640., 44.];
|
||||
|
||||
let mahalanobis = Mahalanobis::new(&data);
|
||||
@@ -0,0 +1,82 @@
|
||||
//! # Manhattan Distance
|
||||
//!
|
||||
//! The Manhattan distance between two points \\(x \in ℝ^n \\) and \\( y \in ℝ^n \\) in n-dimensional space is the sum of the distances in each dimension.
|
||||
//!
|
||||
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::manhattan::Manhattan;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let l1: f64 = Manhattan::new().distance(&x, &y);
|
||||
//! ```
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Manhattan distance
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Manhattan<T: Number> {
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number> Manhattan<T> {
|
||||
/// instatiate the initial structure
|
||||
pub fn new() -> Manhattan<T> {
|
||||
Manhattan { _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number> Default for Manhattan<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Manhattan<T> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
if x.shape() != y.shape() {
|
||||
panic!("Input vector sizes are different");
|
||||
}
|
||||
|
||||
let dist: f64 = x
|
||||
.iterator(0)
|
||||
.zip(y.iterator(0))
|
||||
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs())
|
||||
.sum();
|
||||
|
||||
dist
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn manhattan_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let l1: f64 = Manhattan::new().distance(&a, &b);
|
||||
|
||||
assert!((l1 - 9.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
@@ -8,14 +8,14 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//! use smartcore::math::distance::minkowski::Minkowski;
|
||||
//! use smartcore::metrics::distance::Distance;
|
||||
//! use smartcore::metrics::distance::minkowski::Minkowski;
|
||||
//!
|
||||
//! let x = vec![1., 1.];
|
||||
//! let y = vec![2., 2.];
|
||||
//!
|
||||
//! let l1: f64 = Minkowski { p: 1 }.distance(&x, &y);
|
||||
//! let l2: f64 = Minkowski { p: 2 }.distance(&x, &y);
|
||||
//! let l1: f64 = Minkowski::new(1).distance(&x, &y);
|
||||
//! let l2: f64 = Minkowski::new(2).distance(&x, &y);
|
||||
//!
|
||||
//! ```
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
@@ -23,37 +23,47 @@
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
use super::Distance;
|
||||
|
||||
/// Defines the Minkowski distance of order `p`
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Minkowski {
|
||||
pub struct Minkowski<T: Number> {
|
||||
/// order, integer
|
||||
pub p: u16,
|
||||
_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
|
||||
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
if x.len() != y.len() {
|
||||
impl<T: Number> Minkowski<T> {
|
||||
/// instatiate the initial structure
|
||||
pub fn new(p: u16) -> Minkowski<T> {
|
||||
Minkowski { p, _t: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number, A: ArrayView1<T>> Distance<A> for Minkowski<T> {
|
||||
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||
if x.shape() != y.shape() {
|
||||
panic!("Input vector sizes are different");
|
||||
}
|
||||
if self.p < 1 {
|
||||
panic!("p must be at least 1");
|
||||
}
|
||||
|
||||
let mut dist = T::zero();
|
||||
let p_t = T::from_u16(self.p).unwrap();
|
||||
let p_t = self.p as f64;
|
||||
|
||||
for i in 0..x.len() {
|
||||
let d = (x[i] - y[i]).abs();
|
||||
dist += d.powf(p_t);
|
||||
}
|
||||
let dist: f64 = x
|
||||
.iterator(0)
|
||||
.zip(y.iterator(0))
|
||||
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs().powf(p_t))
|
||||
.sum();
|
||||
|
||||
dist.powf(T::one() / p_t)
|
||||
dist.powf(1f64 / p_t)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,15 +71,18 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn minkowski_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let l1: f64 = Minkowski { p: 1 }.distance(&a, &b);
|
||||
let l2: f64 = Minkowski { p: 2 }.distance(&a, &b);
|
||||
let l3: f64 = Minkowski { p: 3 }.distance(&a, &b);
|
||||
let l1: f64 = Minkowski::new(1).distance(&a, &b);
|
||||
let l2: f64 = Minkowski::new(2).distance(&a, &b);
|
||||
let l3: f64 = Minkowski::new(3).distance(&a, &b);
|
||||
|
||||
assert!((l1 - 9.0).abs() < 1e-8);
|
||||
assert!((l2 - 5.19615242).abs() < 1e-8);
|
||||
@@ -82,6 +95,6 @@ mod tests {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![4., 5., 6.];
|
||||
|
||||
let _: f64 = Minkowski { p: 0 }.distance(&a, &b);
|
||||
let _: f64 = Minkowski::new(0).distance(&a, &b);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
//! # Collection of Distance Functions
|
||||
//!
|
||||
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
|
||||
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
|
||||
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
|
||||
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//!
|
||||
//! for all \\(x, y, z \in Z \\)
|
||||
//!
|
||||
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
/// Cosine distance
|
||||
pub mod cosine;
|
||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||
pub mod euclidian;
|
||||
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
||||
pub mod hamming;
|
||||
/// The Mahalanobis distance is the distance between two points in multivariate space.
|
||||
pub mod mahalanobis;
|
||||
/// Also known as rectilinear distance, city block distance, taxicab metric.
|
||||
pub mod manhattan;
|
||||
/// A generalization of both the Euclidean distance and the Manhattan distance.
|
||||
pub mod minkowski;
|
||||
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::linalg::traits::lu::LUDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Distance metric, a function that calculates distance between two points
|
||||
pub trait Distance<T>: Clone {
|
||||
/// Calculates distance between _a_ and _b_
|
||||
fn distance(&self, a: &T, b: &T) -> f64;
|
||||
}
|
||||
|
||||
/// Multitude of distance metric functions
|
||||
pub struct Distances {}
|
||||
|
||||
impl Distances {
|
||||
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
|
||||
pub fn euclidian<T: Number>() -> euclidian::Euclidian<T> {
|
||||
euclidian::Euclidian::new()
|
||||
}
|
||||
|
||||
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
|
||||
/// * `p` - function order. Should be >= 1
|
||||
pub fn minkowski<T: Number>(p: u16) -> minkowski::Minkowski<T> {
|
||||
minkowski::Minkowski::new(p)
|
||||
}
|
||||
|
||||
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
|
||||
pub fn manhattan<T: Number>() -> manhattan::Manhattan<T> {
|
||||
manhattan::Manhattan::new()
|
||||
}
|
||||
|
||||
/// Hamming distance, see [`Hamming`](hamming/index.html)
|
||||
pub fn hamming<T: Number>() -> hamming::Hamming<T> {
|
||||
hamming::Hamming::new()
|
||||
}
|
||||
|
||||
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
|
||||
pub fn mahalanobis<T: Number, M: Array2<T>, C: Array2<f64> + LUDecomposable<f64>>(
|
||||
data: &M,
|
||||
) -> mahalanobis::Mahalanobis<T, C> {
|
||||
mahalanobis::Mahalanobis::new(data)
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// ### Pairwise dissimilarities.
|
||||
///
|
||||
/// Representing distances as pairwise dissimilarities, so to build a
|
||||
/// graph of closest neighbours. This representation can be reused for
|
||||
/// different implementations
|
||||
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// index of the vector in the original `Matrix` or list
|
||||
pub node: usize,
|
||||
|
||||
/// index of the closest neighbor in the original `Matrix` or same list
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// measure of distance, according to the algorithm distance function
|
||||
/// if the distance is None, the edge has value "infinite" or max distance
|
||||
/// each algorithm has to match
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.node == other.node
|
||||
&& self.neighbour == other.neighbour
|
||||
&& self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
+46
-16
@@ -10,48 +10,71 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::f1::F1;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//! let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
//! let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
//!
|
||||
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
|
||||
//! let beta = 1.0; // beta default is equal 1.0 anyway
|
||||
//! let score: f64 = F1::new_with(beta).get_score( &y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::metrics::precision::Precision;
|
||||
use crate::metrics::recall::Recall;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
/// F-measure
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct F1<T: RealNumber> {
|
||||
pub struct F1<T> {
|
||||
/// a positive real factor
|
||||
pub beta: T,
|
||||
pub beta: f64,
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> F1<T> {
|
||||
impl<T: Number + RealNumber + FloatNumber> Metrics<T> for F1<T> {
|
||||
fn new() -> Self {
|
||||
let beta: f64 = 1f64;
|
||||
Self {
|
||||
beta,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// create a typed object to call Recall functions
|
||||
fn new_with(beta: f64) -> Self {
|
||||
Self {
|
||||
beta,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Computes F1 score
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn get_score<V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
if y_true.shape() != y_pred.shape() {
|
||||
panic!(
|
||||
"The vector sizes don't match: {} != {}",
|
||||
y_true.len(),
|
||||
y_pred.len()
|
||||
y_true.shape(),
|
||||
y_pred.shape()
|
||||
);
|
||||
}
|
||||
let beta2 = self.beta * self.beta;
|
||||
|
||||
let p = Precision {}.get_score(y_true, y_pred);
|
||||
let r = Recall {}.get_score(y_true, y_pred);
|
||||
let p = Precision::new().get_score(y_true, y_pred);
|
||||
let r = Recall::new().get_score(y_true, y_pred);
|
||||
|
||||
(T::one() + beta2) * (p * r) / (beta2 * p + r)
|
||||
(1f64 + beta2) * (p * r) / ((beta2 * p) + r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,14 +82,21 @@ impl<T: RealNumber> F1<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn f1() {
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score1: f64 = F1 { beta: 1.0 }.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = F1 { beta: 1.0 }.get_score(&y_true, &y_true);
|
||||
let beta = 1.0;
|
||||
let score1: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
|
||||
let score2: f64 = F1::new_with(beta).get_score(&y_true, &y_true);
|
||||
|
||||
println!("{score1:?}");
|
||||
println!("{score2:?}");
|
||||
|
||||
assert!((score1 - 0.57142857).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
@@ -10,45 +10,65 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
//!
|
||||
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
//! let mse: f64 = MeanAbsoluteError::new().get_score( &y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Absolute Error
|
||||
pub struct MeanAbsoluteError {}
|
||||
pub struct MeanAbsoluteError<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl MeanAbsoluteError {
|
||||
impl<T: Number + FloatNumber> Metrics<T> for MeanAbsoluteError<T> {
|
||||
/// create a typed object to call MeanAbsoluteError functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Computes mean absolute error
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
if y_true.shape() != y_pred.shape() {
|
||||
panic!(
|
||||
"The vector sizes don't match: {} != {}",
|
||||
y_true.len(),
|
||||
y_pred.len()
|
||||
y_true.shape(),
|
||||
y_pred.shape()
|
||||
);
|
||||
}
|
||||
|
||||
let n = y_true.len();
|
||||
let mut ras = T::zero();
|
||||
let n = y_true.shape();
|
||||
let mut ras: T = T::zero();
|
||||
for i in 0..n {
|
||||
ras += (y_true.get(i) - y_pred.get(i)).abs();
|
||||
let res: T = *y_true.get(i) - *y_pred.get(i);
|
||||
ras += res.abs();
|
||||
}
|
||||
|
||||
ras / T::from_usize(n).unwrap()
|
||||
ras.to_f64().unwrap() / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,14 +76,17 @@ impl MeanAbsoluteError {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn mean_absolute_error() {
|
||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
|
||||
let score1: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = MeanAbsoluteError {}.get_score(&y_true, &y_true);
|
||||
let score1: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
|
||||
let score2: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_true);
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 0.0).abs() < 1e-8);
|
||||
|
||||
@@ -10,45 +10,65 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::metrics::mean_squared_error::MeanSquareError;
|
||||
//! use smartcore::metrics::Metrics;
|
||||
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
//!
|
||||
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
|
||||
//! let mse: f64 = MeanSquareError::new().get_score( &y_true, &y_pred);
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::ArrayView1;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Squared Error
|
||||
pub struct MeanSquareError {}
|
||||
pub struct MeanSquareError<T> {
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl MeanSquareError {
|
||||
impl<T: Number + FloatNumber> Metrics<T> for MeanSquareError<T> {
|
||||
/// create a typed object to call MeanSquareError functions
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
fn new_with(_parameter: f64) -> Self {
|
||||
Self {
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
/// Computes mean squared error
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
|
||||
if y_true.len() != y_pred.len() {
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
if y_true.shape() != y_pred.shape() {
|
||||
panic!(
|
||||
"The vector sizes don't match: {} != {}",
|
||||
y_true.len(),
|
||||
y_pred.len()
|
||||
y_true.shape(),
|
||||
y_pred.shape()
|
||||
);
|
||||
}
|
||||
|
||||
let n = y_true.len();
|
||||
let n = y_true.shape();
|
||||
let mut rss = T::zero();
|
||||
for i in 0..n {
|
||||
rss += (y_true.get(i) - y_pred.get(i)).square();
|
||||
let res = *y_true.get(i) - *y_pred.get(i);
|
||||
rss += res * res;
|
||||
}
|
||||
|
||||
rss / T::from_usize(n).unwrap()
|
||||
rss.to_f64().unwrap() / n as f64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,14 +76,17 @@ impl MeanSquareError {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn mean_squared_error() {
|
||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
|
||||
|
||||
let score1: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = MeanSquareError {}.get_score(&y_true, &y_true);
|
||||
let score1: f64 = MeanSquareError::new().get_score(&y_true, &y_pred);
|
||||
let score2: f64 = MeanSquareError::new().get_score(&y_true, &y_true);
|
||||
|
||||
assert!((score1 - 0.375).abs() < 1e-8);
|
||||
assert!((score2 - 0.0).abs() < 1e-8);
|
||||
|
||||
+142
-64
@@ -4,7 +4,7 @@
|
||||
//! In a feedback loop you build your model first, then you get feedback from metrics, improve it and repeat until your model achieve desirable performance.
|
||||
//! Evaluation metrics helps to explain the performance of a model and compare models based on an objective criterion.
|
||||
//!
|
||||
//! Choosing the right metric is crucial while evaluating machine learning models. In SmartCore you will find metrics for these classes of ML models:
|
||||
//! Choosing the right metric is crucial while evaluating machine learning models. In `smartcore` you will find metrics for these classes of ML models:
|
||||
//!
|
||||
//! * [Classification metrics](struct.ClassificationMetrics.html)
|
||||
//! * [Regression metrics](struct.RegressionMetrics.html)
|
||||
@@ -12,7 +12,7 @@
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::linear::logistic_regression::LogisticRegression;
|
||||
//! use smartcore::metrics::*;
|
||||
//!
|
||||
@@ -37,27 +37,30 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<i8> = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//!
|
||||
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
|
||||
//! let acc = ClassificationMetricsOrd::accuracy().get_score(&y, &y_hat);
|
||||
//! // or
|
||||
//! let acc = accuracy(&y, &y_hat);
|
||||
//! ```
|
||||
|
||||
/// Accuracy score.
|
||||
pub mod accuracy;
|
||||
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
// TODO: reimplement AUC
|
||||
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
|
||||
pub mod auc;
|
||||
/// Compute the homogeneity, completeness and V-Measure scores.
|
||||
pub mod cluster_hcv;
|
||||
pub(crate) mod cluster_helpers;
|
||||
/// Multitude of distance metrics are defined here
|
||||
pub mod distance;
|
||||
/// F1 score, also known as balanced F-score or F-measure.
|
||||
pub mod f1;
|
||||
/// Mean absolute error regression loss.
|
||||
@@ -71,150 +74,225 @@ pub mod r2;
|
||||
/// Computes the recall.
|
||||
pub mod recall;
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A trait to be implemented by all metrics
|
||||
pub trait Metrics<T> {
|
||||
/// instantiate a new Metrics trait-object
|
||||
/// <https://doc.rust-lang.org/error-index.html#E0038>
|
||||
fn new() -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
/// used to instantiate metric with a paramenter
|
||||
fn new_with(_parameter: f64) -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
/// compute score realated to this metric
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64;
|
||||
}
|
||||
|
||||
/// Use these metrics to compare classification models.
|
||||
pub struct ClassificationMetrics {}
|
||||
pub struct ClassificationMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Use these metrics to compare classification models for
|
||||
/// numbers that require `Ord`.
|
||||
pub struct ClassificationMetricsOrd<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Metrics for regression models.
|
||||
pub struct RegressionMetrics {}
|
||||
pub struct RegressionMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// Cluster metrics.
|
||||
pub struct ClusterMetrics {}
|
||||
|
||||
impl ClassificationMetrics {
|
||||
/// Accuracy score, see [accuracy](accuracy/index.html).
|
||||
pub fn accuracy() -> accuracy::Accuracy {
|
||||
accuracy::Accuracy {}
|
||||
}
|
||||
pub struct ClusterMetrics<T> {
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
|
||||
/// Recall, see [recall](recall/index.html).
|
||||
pub fn recall() -> recall::Recall {
|
||||
recall::Recall {}
|
||||
pub fn recall() -> recall::Recall<T> {
|
||||
recall::Recall::new()
|
||||
}
|
||||
|
||||
/// Precision, see [precision](precision/index.html).
|
||||
pub fn precision() -> precision::Precision {
|
||||
precision::Precision {}
|
||||
pub fn precision() -> precision::Precision<T> {
|
||||
precision::Precision::new()
|
||||
}
|
||||
|
||||
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
|
||||
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
|
||||
f1::F1 { beta }
|
||||
pub fn f1(beta: f64) -> f1::F1<T> {
|
||||
f1::F1::new_with(beta)
|
||||
}
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
pub fn roc_auc_score() -> auc::AUC {
|
||||
auc::AUC {}
|
||||
pub fn roc_auc_score() -> auc::AUC<T> {
|
||||
auc::AUC::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RegressionMetrics {
|
||||
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
|
||||
/// Accuracy score, see [accuracy](accuracy/index.html).
|
||||
pub fn accuracy() -> accuracy::Accuracy<T> {
|
||||
accuracy::Accuracy::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Number + FloatNumber> RegressionMetrics<T> {
|
||||
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
|
||||
mean_squared_error::MeanSquareError {}
|
||||
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError<T> {
|
||||
mean_squared_error::MeanSquareError::new()
|
||||
}
|
||||
|
||||
/// Mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
|
||||
mean_absolute_error::MeanAbsoluteError {}
|
||||
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError<T> {
|
||||
mean_absolute_error::MeanAbsoluteError::new()
|
||||
}
|
||||
|
||||
/// Coefficient of determination (R2), see [R2](r2/index.html).
|
||||
pub fn r2() -> r2::R2 {
|
||||
r2::R2 {}
|
||||
pub fn r2() -> r2::R2<T> {
|
||||
r2::R2::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClusterMetrics {
|
||||
impl<T: Number + Ord> ClusterMetrics<T> {
|
||||
/// Homogeneity and completeness and V-Measure scores at once.
|
||||
pub fn hcv_score() -> cluster_hcv::HCVScore {
|
||||
cluster_hcv::HCVScore {}
|
||||
pub fn hcv_score() -> cluster_hcv::HCVScore<T> {
|
||||
cluster_hcv::HCVScore::<T>::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn accuracy<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
|
||||
pub fn accuracy<T: Number + Ord, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
|
||||
let obj = ClassificationMetricsOrd::<T>::accuracy();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated recall score, see [recall](recall/index.html)
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn recall<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::recall().get_score(y_true, y_pred)
|
||||
pub fn recall<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::recall();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Calculated precision score, see [precision](precision/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn precision<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
ClassificationMetrics::precision().get_score(y_true, y_pred)
|
||||
pub fn precision<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::precision();
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes F1 score, see [F1](f1/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V, beta: T) -> T {
|
||||
ClassificationMetrics::f1(beta).get_score(y_true, y_pred)
|
||||
pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
beta: f64,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::f1(beta);
|
||||
obj.get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// AUC score, see [AUC](auc/index.html).
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
|
||||
pub fn roc_auc_score<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
|
||||
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
|
||||
pub fn roc_auc_score<
|
||||
T: Number + RealNumber + FloatNumber + PartialOrd,
|
||||
V: ArrayView1<T> + Array1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred_probabilities: &V,
|
||||
) -> f64 {
|
||||
let obj = ClassificationMetrics::<T>::roc_auc_score();
|
||||
obj.get_score(y_true, y_pred_probabilities)
|
||||
}
|
||||
|
||||
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_squared_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
|
||||
pub fn mean_squared_error<T: Number + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
RegressionMetrics::<T>::mean_squared_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
|
||||
pub fn mean_absolute_error<T: Number + FloatNumber, V: ArrayView1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
RegressionMetrics::<T>::mean_absolute_error().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Computes R2 score, see [R2](r2/index.html).
|
||||
/// * `y_true` - Ground truth (correct) target values.
|
||||
/// * `y_pred` - Estimated target values.
|
||||
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
|
||||
RegressionMetrics::r2().get_score(y_true, y_pred)
|
||||
pub fn r2<T: Number + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
|
||||
RegressionMetrics::<T>::r2().get_score(y_true, y_pred)
|
||||
}
|
||||
|
||||
/// Homogeneity metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
|
||||
/// A cluster result satisfies homogeneity if all of its clusters contain only data points which are members of a single class.
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn homogeneity_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.0
|
||||
pub fn homogeneity_score<
|
||||
T: Number + FloatNumber + RealNumber + Ord,
|
||||
V: ArrayView1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.homogeneity().unwrap()
|
||||
}
|
||||
|
||||
///
|
||||
/// Completeness metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn completeness_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.1
|
||||
pub fn completeness_score<
|
||||
T: Number + FloatNumber + RealNumber + Ord,
|
||||
V: ArrayView1<T> + Array1<T>,
|
||||
>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.completeness().unwrap()
|
||||
}
|
||||
|
||||
/// The harmonic mean between homogeneity and completeness.
|
||||
/// * `labels_true` - ground truth class labels to be used as a reference.
|
||||
/// * `labels_pred` - cluster labels to evaluate.
|
||||
pub fn v_measure_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
|
||||
ClusterMetrics::hcv_score()
|
||||
.get_score(labels_true, labels_pred)
|
||||
.2
|
||||
pub fn v_measure_score<T: Number + FloatNumber + RealNumber + Ord, V: ArrayView1<T> + Array1<T>>(
|
||||
y_true: &V,
|
||||
y_pred: &V,
|
||||
) -> f64 {
|
||||
let mut obj = ClusterMetrics::<T>::hcv_score();
|
||||
obj.compute(y_true, y_pred);
|
||||
obj.v_measure().unwrap()
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user