3 Commits

Author SHA1 Message Date
Lorenzo (Mec-iS)
61db4ebd90 Add test 2022-08-24 12:34:56 +01:00
Lorenzo (Mec-iS)
2603a1f42b Add test 2022-08-24 11:44:30 +01:00
Alan Race
663db0334d Added per-class probability prediction for random forests 2022-07-11 16:08:03 +02:00
143 changed files with 9558 additions and 21394 deletions
-6
View File
@@ -1,6 +0,0 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# Developers in this list will be requested for
# review when someone opens a pull request.
* @morenol
* @Mec-iS
-22
View File
@@ -1,22 +0,0 @@
# Code of Conduct
As contributors and maintainers of this project, and in the interest of fostering an open and welcoming community, we pledge to respect all people who contribute through reporting issues, posting feature requests, updating documentation, submitting pull requests or patches, and other activities.
We are committed to making participation in this project a harassment-free experience for everyone, regardless of level of experience, gender, gender identity and expression, sexual orientation, disability, personal appearance, body size, race, ethnicity, age, religion, or nationality.
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery
* Personal attacks
* Trolling or insulting/derogatory comments
* Public or private harassment
* Publishing other's private information, such as physical or electronic addresses, without explicit permission
* Other unethical or unprofessional conduct.
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct. By adopting this Code of Conduct, project maintainers commit themselves to fairly and consistently applying these principles to every aspect of managing this project. Project maintainers who do not follow or enforce the Code of Conduct may be permanently removed from the project team.
This code of conduct applies both within project spaces and in public spaces when an individual is representing the project or its community.
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by opening an issue or contacting one or more of the project maintainers.
This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/)
-72
View File
@@ -1,72 +0,0 @@
# **Contributing**
When contributing to this repository, please first discuss the change you wish to make via issue,
email, or any other method with the owners of this repository before making a change.
Please note we have a [code of conduct](CODE_OF_CONDUCT.md), please follow it in all your interactions with the project.
## Background
We try to follow these principles:
* follow as much as possible the sklearn API to give a frictionless user experience for practitioners already familiar with it
* use only pure-Rust implementations for safety and future-proofing (with some low-level limited exceptions)
* do not use macros in the library code to allow readability and transparent behavior
* priority is not on "big data" dataset, try to be fast for small/average dataset with limited memory footprint.
## Pull Request Process
1. Open a PR following the template (erase the part of the template you don't need).
2. Update the CHANGELOG.md with details of changes to the interface if they are breaking changes, this includes new environment variables, exposed ports useful file locations and container parameters.
3. Pull Request can be merged in once you have the sign-off of one other developer, or if you do not have permission to do that you may request the reviewer to merge it for you.
### generic guidelines
Take a look to the conventions established by existing code:
* Every module should come with some reference to scientific literature that allows relating the code to research. Use the `//!` comments at the top of the module to tell readers about the basics of the procedure you are implementing.
* Every module should provide a Rust doctest, a brief test embedded with the documentation that explains how to use the procedure implemented.
* Every module should provide comprehensive tests at the end, in its `mod tests {}` sub-module. These tests can be flagged or not with configuration flags to allow WebAssembly target.
* Run `cargo doc --no-deps --open` and read the generated documentation in the browser to be sure that your changes reflects in the documentation and new code is documented.
#### digging deeper
* a nice overview of the codebase is given by [static analyzer](https://mozilla.github.io/rust-code-analysis/metrics.html):
```
$ cargo install rust-code-analysis-cli
// print metrics for every module
$ rust-code-analysis-cli -m -O json -o . -p src/ --pr
// print full AST for a module
$ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 -d > ast.txt
```
* find more information about what happens in your binary with [`twiggy`](https://rustwasm.github.io/twiggy/install.html). This need a compiled binary so create a brief `main {}` function using `smartcore` and then point `twiggy` to that file.
* Please take a look to the output of a profiler to spot most evident performance problems, see [this guide about using a profiler](http://www.codeofview.com/fix-rs/2017/01/24/how-to-optimize-rust-programs-on-linux/).
## Issue Report Process
1. Go to the project's issues.
2. Select the template that better fits your issue.
3. Read carefully the instructions and write within the template guidelines.
4. Submit it and wait for support.
## Reviewing process
1. After a PR is opened maintainers are notified
2. Probably changes will be required to comply with the workflow, these commands are run automatically and all tests shall pass:
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
* **Testing**: multiple test pipelines are run for different targets
3. When everything is OK, code is merged.
## Contribution Best Practices
* Read this [how-to about Github workflow here](https://guides.github.com/introduction/flow/) if you are not familiar with.
* Read all the texts related to [contributing for an OS community](https://github.com/HTTP-APIs/hydrus/tree/master/.github).
* Read this [how-to about writing a PR](https://github.com/blog/1943-how-to-write-the-perfect-pull-request) and this [other how-to about writing a issue](https://wiredcraft.com/blog/how-we-write-our-github-issues/)
* **read history**: search past open or closed issues for your problem before opening a new issue.
* **PRs on develop**: any change should be PRed first in `development`
* **testing**: everything should work and be tested as defined in the workflow. If any is failing for non-related reasons, annotate the test failure in the PR comment.
-43
View File
@@ -1,43 +0,0 @@
# smartcore: Introduction to modules
Important source of information:
* [Rust API guidelines](https://rust-lang.github.io/api-guidelines/about.html)
## Walkthrough: traits system and basic structures
#### numbers
The library is founded on basic traits provided by `num-traits`. Basic traits are in `src/numbers`. These traits are used to define all the procedures in the library to make everything safer and provide constraints to what implementations can handle.
#### linalg
`numbers` are made at use in linear algebra structures in the **`src/linalg/basic`** module. These sub-modules define the traits used all over the code base.
* *arrays*: In particular data structures like `Array`, `Array1` (1-dimensional), `Array2` (matrix, 2-D); plus their "views" traits. Views are used to provide no-footprint access to data, they have composed traits to allow writing (mutable traits: `MutArray`, `ArrayViewMut`, ...).
* *matrix*: This provides the main entrypoint to matrices operations and currently the only structure provided in the shape of `struct DenseMatrix`. A matrix can be instantiated and automatically make available all the traits in "arrays" (sparse matrices implementation will be provided).
* *vector*: Convenience traits are implemented for `std::Vec` to allow extensive reuse.
These are all traits and by definition they do not allow instantiation. For instantiable structures see implementation like `DenseMatrix` with relative constructor.
#### linalg/traits
The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. These traits are used to specify characteristics and constraints for types accepted by various algorithms. For example these allow to define if a matrix is `QRDecomposable` and/or `SVDDecomposable`. See docstring for referencese to theoretical framework.
As above these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation for Linear Regression requires the input data `X` to be in `smartcore`'s trait system `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`, a 2-D matrix that is both QR and SVD decomposable; that is what the provided strucure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {};impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
#### metrics
Implementations for metrics (classification, regression, cluster, ...) and distance measure (Euclidean, Hamming, Manhattan, ...). For example: `Accuracy`, `F1`, `AUC`, `Precision`, `R2`. As everything else in the code base, these implementations reuse `numbers` and `linalg` traits and structures.
These are collected in structures like `pub struct ClassificationMetrics<T> {}` that implements `metrics::Metrics`, these are groups of functions (classification, regression, cluster, ...) that provide instantiation for the structures. Each of those instantiation can be passed around using the relative function, like `pub fn accuracy<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> T`. This provides a mechanism for metrics to be passed to higher interfaces like the `cross_validate`:
```rust
let results =
cross_validate(
BiasedEstimator::new(), // custom estimator
&x, &y, // input data
NoParameters {}, // extra parameters
cv, // type of cross validator
&accuracy // **metrics function** <--------
).unwrap();
```
TODO: complete for all modules
## Notebooks
Proceed to the [**notebooks**](https://github.com/smartcorelib/smartcore-jupyter/) to see these modules in action.
-25
View File
@@ -1,25 +0,0 @@
### I'm submitting a
- [ ] bug report.
- [ ] improvement.
- [ ] feature request.
### Current Behaviour:
<!-- Describe about the bug -->
### Expected Behaviour:
<!-- Describe what will happen if bug is removed -->
### Steps to reproduce:
<!-- If you can then please provide the steps to reproduce the bug -->
### Snapshot:
<!-- If you can then please provide the screenshot of the issue you are facing -->
### Environment:
<!-- Please provide the following environment details if relevant -->
* rustc version
* cargo version
* OS details
### Do you want to work on this issue?
<!-- yes/no -->
-29
View File
@@ -1,29 +0,0 @@
<!-- Please create (if there is not one yet) a issue before sending a PR -->
<!-- Add issue number (Eg: fixes #123) -->
<!-- Always provide changes in existing tests or new tests -->
Fixes #
### Checklist
- [ ] My branch is up-to-date with development branch.
- [ ] Everything works and tested on latest stable Rust.
- [ ] Coverage and Linting have been applied
### Current behaviour
<!-- Describe the code you are going to change and its behaviour -->
### New expected behaviour
<!-- Describe the new code and its expected behaviour -->
### Change logs
<!-- #### Added -->
<!-- Edit these points below to describe the new features added with this PR -->
<!-- - Feature 1 -->
<!-- - Feature 2 -->
<!-- #### Changed -->
<!-- Edit these points below to describe the changes made in existing functionality with this PR -->
<!-- - Change 1 -->
<!-- - Change 1 -->
+29 -46
View File
@@ -2,73 +2,56 @@ name: CI
on:
push:
branches: [main, development]
branches: [ main, development ]
pull_request:
branches: [development]
branches: [ development ]
jobs:
tests:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform:
[
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
platform: [
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Cache .cargo and target
uses: actions/cache@v4
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
uses: actions-rs/toolchain@v1
with:
targets: ${{ matrix.platform.target }}
toolchain: stable
target: ${{ matrix.platform.target }}
profile: minimal
default: true
- name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build with all features
run: cargo build --all-features --target ${{ matrix.platform.target }}
- name: Stable Build without features
run: cargo build --target ${{ matrix.platform.target }}
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build
uses: actions-rs/cargo@v1
with:
command: build
args: --all-features --target ${{ matrix.platform.target }}
- name: Tests
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
run: cargo test --all-features
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
- name: Tests in WASM
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: wasm-pack test --node -- --all-features
check_features:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform: [{ os: "ubuntu" }]
features: ["--features serde", "--features datasets", ""]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v4
- name: Cache .cargo and target
uses: actions/cache@v4
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-features
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Stable Build
run: cargo build --no-default-features ${{ matrix.features }}
+23 -12
View File
@@ -12,22 +12,33 @@ jobs:
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Cache .cargo
uses: actions/cache@v4
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@nightly
- name: Install cargo-tarpaulin
run: cargo install cargo-tarpaulin
- name: Run cargo-tarpaulin
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
- name: Upload to codecov.io
uses: codecov/codecov-action@v4
uses: actions-rs/toolchain@v1
with:
fail_ci_if_error: false
toolchain: nightly
profile: minimal
default: true
- name: Install cargo-tarpaulin
uses: actions-rs/install@v0.1
with:
crate: cargo-tarpaulin
version: latest
use-tool-cache: true
- name: Run cargo-tarpaulin
uses: actions-rs/cargo@v1
with:
command: tarpaulin
args: --out Lcov --all-features -- --test-threads 1
- name: Upload to codecov.io
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true
+19 -10
View File
@@ -6,27 +6,36 @@ on:
pull_request:
branches: [ development ]
jobs:
lint:
runs-on: ubuntu-latest
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v2
- name: Cache .cargo and target
uses: actions/cache@v4
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: dtolnay/rust-toolchain@stable
uses: actions-rs/toolchain@v1
with:
components: rustfmt, clippy
- name: Check format
run: cargo fmt --all -- --check
toolchain: stable
profile: minimal
default: true
- run: rustup component add rustfmt
- name: Check formt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- run: rustup component add clippy
- name: Run clippy
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -Drust-2018-idioms -Dwarnings
-12
View File
@@ -17,15 +17,3 @@ smartcore.code-workspace
# OS
.DS_Store
flamegraph.svg
perf.data
perf.data.old
src.dot
out.svg
FlameGraph/
out.stacks
*.json
*.txt
+1 -34
View File
@@ -4,40 +4,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.4.8] - 2025-11-29
- WARNING: Breaking changes!
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
## [0.4.0] - 2023-04-05
## Added
- WARNING: Breaking changes!
- `DenseMatrix` constructor now returns `Result` to avoid user instantiating inconsistent rows/cols count. Their return values need to be unwrapped with `unwrap()`, see tests
## [0.3.0] - 2022-11-09
## Added
- WARNING: Breaking changes!
- Complete refactoring with **extensive API changes** that includes:
* moving to a new traits system, less structs more traits
* adapting all the modules to the new traits system
* moving to Rust 2021, use of object-safe traits and `as_ref`
* reorganization of the code base, eliminate duplicates
- implements `readers` (needs "serde" feature) for read/write CSV file, extendible to other formats
- default feature is now Wasm-/Wasi-first
## Changed
- WARNING: Breaking changes!
- Seeds to multiple algorithims that depend on random number generation
- Added a new parameter to `train_test_split` to define the seed
- changed use of "serde" feature
## Dropped
- WARNING: Breaking changes!
- Drop `nalgebra-bindings` feature, only `ndarray` as supported library
## [0.2.1] - 2021-05-10
## [Unreleased]
## Added
- L2 regularization penalty to the Logistic Regression
-41
View File
@@ -1,41 +0,0 @@
cff-version: 1.2.0
message: "If this software contributes to published work, please cite smartcore."
type: software
title: "smartcore: Machine Learning in Rust"
abstract: "smartcore is a comprehensive machine learning and numerical computing library for Rust, offering supervised and unsupervised algorithms, model evaluation tools, and linear algebra abstractions, with optional ndarray integration." [web:5][web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
url: "https://github.com/smartcorelib" [web:3]
license: "MIT" [web:13]
keywords:
- Rust
- machine learning
- numerical computing
- linear algebra
- classification
- regression
- clustering
- SVM
- Random Forest
- XGBoost [web:5]
authors:
- name: "smartcore Developers" [web:7]
- name: "Lorenzo (contributor)" [web:16]
- name: "Community contributors" [web:7]
version: "0.4.2" [attached_file:1]
date-released: "2025-09-14" [attached_file:1]
preferred-citation:
type: software
title: "smartcore: Machine Learning in Rust"
authors:
- name: "smartcore Developers" [web:7]
url: "https://github.com/smartcorelib" [web:3]
repository-code: "https://github.com/smartcorelib/smartcore" [web:5]
license: "MIT" [web:13]
references:
- type: manual
title: "smartcore Documentation"
url: "https://docs.rs/smartcore" [web:5]
- type: webpage
title: "smartcore Homepage"
url: "https://github.com/smartcorelib" [web:3]
notes: "For development features, see the docs.rs page and the repository README; SmartCore includes algorithms such as SVM, Random Forest, K-Means, PCA, DBSCAN, and XGBoost." [web:5]
+28 -46
View File
@@ -1,66 +1,48 @@
[package]
name = "smartcore"
description = "Machine Learning in Rust."
description = "The most advanced machine learning library in rust."
homepage = "https://smartcorelib.org"
version = "0.4.9"
authors = ["smartcore Developers"]
edition = "2021"
version = "0.2.1"
authors = ["SmartCore Developers"]
edition = "2018"
license = "Apache-2.0"
documentation = "https://docs.rs/smartcore"
repository = "https://github.com/smartcorelib/smartcore"
readme = "README.md"
keywords = ["machine-learning", "statistical", "ai", "optimization", "linear-algebra"]
categories = ["science"]
exclude = [
".github",
".gitignore",
"smartcore.iml",
"smartcore.svg",
"tests/"
]
[dependencies]
approx = "0.5.1"
cfg-if = "1.0.0"
ndarray = { version = "0.17", optional = true }
num-traits = "0.2.12"
num = "0.4"
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
ordered-float = "5.1.0"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
typetag = { version = "0.2", optional = true }
[features]
default = []
serde = ["dep:serde", "dep:typetag"]
ndarray-bindings = ["dep:ndarray"]
datasets = ["dep:rand_distr", "std_rand", "serde"]
std_rand = ["rand/std_rng", "rand/std"]
# used by wasm32-unknown-unknown for in-browser usage
js = ["getrandom/js"]
default = ["datasets"]
ndarray-bindings = ["ndarray"]
nalgebra-bindings = ["nalgebra"]
datasets = []
[dependencies]
ndarray = { version = "0.15", optional = true }
nalgebra = { version = "0.31", optional = true }
num-traits = "0.2"
num = "0.4"
rand = "0.8"
rand_distr = "0.4"
serde = { version = "1", features = ["derive"], optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", optional = true }
[target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies]
wasm-bindgen-test = "0.3"
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies]
itertools = "0.13.0"
criterion = "0.3"
serde_json = "1.0"
bincode = "1.3.1"
[workspace]
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"
[profile.test]
debug = 1
opt-level = 3
[[bench]]
name = "distance"
harness = false
[profile.release]
strip = true
lto = true
codegen-units = 1
overflow-checks = true
[[bench]]
name = "naive_bayes"
harness = false
required-features = ["ndarray-bindings", "nalgebra-bindings"]
+1 -1
View File
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2019-present at smartcore developers (smartcorelib.org)
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
+4 -133
View File
@@ -1,147 +1,18 @@
<p align="center">
<a href="https://smartcorelib.org">
<img src="smartcore.svg" width="450" alt="smartcore">
<img src="smartcore.svg" width="450" alt="SmartCore">
</a>
</p>
<p align = "center">
<strong>
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-jupyter">Notebooks</a>
<a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-examples">Examples</a>
</strong>
</p>
-----
<p align = "center">
<b>Machine Learning in Rust</b>
<b>The Most Advanced Machine Learning Library In Rust.</b>
</p>
-----
[![CI](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml/badge.svg)](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.17219259.svg)](https://doi.org/10.5281/zenodo.17219259)
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
smartcore is a fast, ergonomic machine learning library for Rust, covering classical supervised and unsupervised methods with a modular linear algebra abstraction and optional ndarray support. It aims to provide production-friendly APIs, strong typing, and good defaults while remaining flexible for research and experimentation.
## Highlights
- Broad algorithm coverage: linear models, tree-based methods, ensembles, SVMs, neighbors, clustering, decomposition, and preprocessing.
- Strong linear algebra traits with optional ndarray integration for users who prefer array-first workflows.
- WASM-first defaults with attention to portability; features such as serde and datasets are opt-in.
- Practical utilities for model selection, evaluation, readers (CSV), dataset generators, and built-in sample datasets.
## Install
Add to Cargo.toml:
```toml
[dependencies]
smartcore = "^0.4.3"
```
For the latest development branch:
```toml
[dependencies]
smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
```
Optional features (examples):
- datasets
- serde
- ndarray-bindings (deprecated in favor of ndarray-only support per recent changes)
Check Cargo.toml for available features and compatibility notes.
## Quick start
Here is a minimal example fitting a KNN classifier from native Rust vectors using DenseMatrix:
```rust
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::neighbors::knn_classifier::KNNClassifier;
// Turn vector slices into a matrix
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
]).unwrap;
// Class labels
let y = vec![2, 2, 2, 3, 3];
// Train classifier
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
// Predict
let yhat = knn.predict(&x).unwrap();
```
This example mirrors the “First Example” section of the crate docs and demonstrates smartcores ergonomic API surface.
## Algorithms
smartcore organizes algorithms into clear modules with consistent traits:
- Clustering: K-Means, DBSCAN, agglomerative (including single-linkage), with K-Means++ initialization and utilities.
- Matrix decomposition: SVD, EVD, Cholesky, LU, QR, plus related linear algebra helpers.
- Linear models: OLS, Ridge, Lasso, ElasticNet, Logistic Regression.
- Ensemble and tree-based: Random Forest (classifier and regressor), Extra Trees, shared reusable components across trees and forests.
- SVM: SVC/SVR with kernel enum support and multiclass extensions.
- Neighbors: KNN classification and regression with distance metrics and fast selection helpers.
- Naive Bayes: Gaussian, Bernoulli, Categorical, Multinomial.
- Preprocessing: encoders, split utilities, and common transforms.
- Model selection and metrics: K-fold, search parameters, and evaluation utilities.
Recent refactors emphasize reusable components in trees/forests and expanded multiclass SVM capabilities. XGBoost-style regression and single-linkage clustering have been added. See CHANGELOG for API changes and migration notes.
## Data access and readers
- CSV readers: Read matrices from CSV with configurable delimiter and header rows, with helpful error messages and testing utilities (including non-IO reader abstractions).
- Dataset generators: make_blobs, make_circles, make_moons for quick experiments.
- Built-in datasets (feature-gated): digits, diabetes, breast cancer, boston, with serialization utilities to persist or refresh .xy bundles.
## WebAssembly and portability
smartcore adopts a WASM/WASI-first posture in defaults to ease browser and embedded deployments. Some file-system operations are restricted in wasm targets; tests and IO utilities are structured to avoid unsupported calls where possible. Enable features like serde selectively to minimize footprint. Consult module-level docs and CHANGELOG for target-specific caveats.
## Notebooks
A curated set of Jupyter notebooks is available via the companion repository to explore smartcore interactively. To run locally, use EVCXR to enable Rust notebooks. This is the recommended path to quickly experiment with the v0.4 API.
## Roadmap and recent changes
- Trait-system refactor, fewer structs and more object-safe traits, large codebase reorganization.
- Move to Rust 2021 edition and cleanup of duplicate code paths.
- Seeds and deterministic controls across algorithms using RNG plumbing.
- Search parameter API for hyperparameter exploration in K-Means and SVM families.
- Tree and forest components refactored for reuse; Extra Trees added.
- SVM multiclass support; SVR kernel enum and related improvements.
- XGBoost-style regression introduced; single-linkage clustering implemented.
See CHANGELOG.md for precise details, deprecations, and breaking changes. Some features like nalgebra-bindings have been dropped in favor of ndarray-only paths. Default features are tuned for WASM/WASI builds; enable serde/datasets as needed.
## Contributing
Contributions are welcome:
- Open an issue describing the change and link it in the PR.
- Keep PRs in sync with the development branch and ensure tests pass on stable Rust.
- Provide or update tests; run clippy and apply formatting. Coverage and linting are part of the workflow.
- Use the provided PR and issue templates to describe behavior changes, new features, and expectations.
If adding IO, prefer abstractions that make non-IO testing straightforward (see readers/iotesting). For datasets, keep serialization helpers in tests gated appropriately to avoid unintended file writes in wasm targets.
## License
smartcore is open source under a permissive license; see Cargo.toml and LICENSE for details. The crate metadata identifies “smartcore Developers” as authors; community contributions are credited via Git history and releases.
## Acknowledgments
smartcores design incorporates well-known ML patterns while staying idiomatic to Rust. Thanks to all contributors who have helped expand algorithms, improve docs, modernize traits, and harden the codebase for production.
-----
+18
View File
@@ -0,0 +1,18 @@
#[macro_use]
extern crate criterion;
extern crate smartcore;
use criterion::black_box;
use criterion::Criterion;
use smartcore::math::distance::*;
fn criterion_benchmark(c: &mut Criterion) {
let a = vec![1., 2., 3.];
c.bench_function("Euclidean Distance", move |b| {
b.iter(|| Distances::euclidian().distance(black_box(&a), black_box(&a)))
});
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
+73
View File
@@ -0,0 +1,73 @@
use criterion::BenchmarkId;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use nalgebra::DMatrix;
use ndarray::Array2;
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linalg::BaseMatrix;
use smartcore::linalg::BaseVector;
use smartcore::naive_bayes::gaussian::GaussianNB;
pub fn gaussian_naive_bayes_fit_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("GaussianNB::fit");
for n_samples in [100_usize, 1000_usize, 10000_usize].iter() {
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
let x = DenseMatrix::<f64>::rand(*n_samples, *n_features);
let y: Vec<f64> = (0..*n_samples)
.map(|i| (i % *n_samples / 5_usize) as f64)
.collect::<Vec<f64>>();
group.bench_with_input(
BenchmarkId::from_parameter(format!(
"n_samples: {}, n_features: {}",
n_samples, n_features
)),
n_samples,
|b, _| {
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
},
);
}
}
group.finish();
}
pub fn gaussian_naive_matrix_datastructure(c: &mut Criterion) {
let mut group = c.benchmark_group("GaussianNB");
let classes = (0..10000).map(|i| (i % 25) as f64).collect::<Vec<f64>>();
group.bench_function("DenseMatrix", |b| {
let x = DenseMatrix::<f64>::rand(10000, 500);
let y = <DenseMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
group.bench_function("ndarray", |b| {
let x = Array2::<f64>::rand(10000, 500);
let y = <Array2<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
group.bench_function("ndalgebra", |b| {
let x = DMatrix::<f64>::rand(10000, 500);
let y = <DMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
b.iter(|| {
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
})
});
}
criterion_group!(
benches,
gaussian_naive_bayes_fit_benchmark,
gaussian_naive_matrix_datastructure
);
criterion_main!(benches);
+15
View File
@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="RUST_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/examples" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/benches" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+1 -1
View File
@@ -76,5 +76,5 @@
y="81.876823"
x="91.861809"
id="tspan842"
sodipodi:role="line">smartcore</tspan></text>
sodipodi:role="line">SmartCore</tspan></text>
</svg>

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

+60 -66
View File
@@ -1,50 +1,50 @@
use std::fmt::Debug;
use crate::linalg::basic::arrays::Array2;
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
#[derive(Debug)]
pub struct BBDTree {
nodes: Vec<BBDTreeNode>,
pub struct BBDTree<T: RealNumber> {
nodes: Vec<BBDTreeNode<T>>,
index: Vec<usize>,
root: usize,
}
#[derive(Debug)]
struct BBDTreeNode {
struct BBDTreeNode<T: RealNumber> {
count: usize,
index: usize,
center: Vec<f64>,
radius: Vec<f64>,
sum: Vec<f64>,
cost: f64,
center: Vec<T>,
radius: Vec<T>,
sum: Vec<T>,
cost: T,
lower: Option<usize>,
upper: Option<usize>,
}
impl BBDTreeNode {
fn new(d: usize) -> BBDTreeNode {
impl<T: RealNumber> BBDTreeNode<T> {
fn new(d: usize) -> BBDTreeNode<T> {
BBDTreeNode {
count: 0,
index: 0,
center: vec![0f64; d],
radius: vec![0f64; d],
sum: vec![0f64; d],
cost: 0f64,
center: vec![T::zero(); d],
radius: vec![T::zero(); d],
sum: vec![T::zero(); d],
cost: T::zero(),
lower: Option::None,
upper: Option::None,
}
}
}
impl BBDTree {
pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
let nodes: Vec<BBDTreeNode> = Vec::new();
impl<T: RealNumber> BBDTree<T> {
pub fn new<M: Matrix<T>>(data: &M) -> BBDTree<T> {
let nodes = Vec::new();
let (n, _) = data.shape();
let index = (0..n).collect::<Vec<usize>>();
let index = (0..n).collect::<Vec<_>>();
let mut tree = BBDTree {
nodes,
@@ -59,20 +59,20 @@ impl BBDTree {
tree
}
pub(crate) fn clustering(
pub(in crate) fn clustering(
&self,
centroids: &[Vec<f64>],
sums: &mut Vec<Vec<f64>>,
centroids: &[Vec<T>],
sums: &mut Vec<Vec<T>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> f64 {
) -> T {
let k = centroids.len();
counts.iter_mut().for_each(|v| *v = 0);
let mut candidates = vec![0; k];
for i in 0..k {
candidates[i] = i;
sums[i].iter_mut().for_each(|v| *v = 0f64);
sums[i].iter_mut().for_each(|v| *v = T::zero());
}
self.filter(
@@ -89,13 +89,13 @@ impl BBDTree {
fn filter(
&self,
node: usize,
centroids: &[Vec<f64>],
centroids: &[Vec<T>],
candidates: &[usize],
k: usize,
sums: &mut Vec<Vec<f64>>,
sums: &mut Vec<Vec<T>>,
counts: &mut Vec<usize>,
membership: &mut Vec<usize>,
) -> f64 {
) -> T {
let d = centroids[0].len();
let mut min_dist =
@@ -163,9 +163,9 @@ impl BBDTree {
}
fn prune(
center: &[f64],
radius: &[f64],
centroids: &[Vec<f64>],
center: &[T],
radius: &[T],
centroids: &[Vec<T>],
best_index: usize,
test_index: usize,
) -> bool {
@@ -177,22 +177,22 @@ impl BBDTree {
let best = &centroids[best_index];
let test = &centroids[test_index];
let mut lhs = 0f64;
let mut rhs = 0f64;
let mut lhs = T::zero();
let mut rhs = T::zero();
for i in 0..d {
let diff = test[i] - best[i];
lhs += diff * diff;
if diff > 0f64 {
if diff > T::zero() {
rhs += (center[i] + radius[i] - best[i]) * diff;
} else {
rhs += (center[i] - radius[i] - best[i]) * diff;
}
}
lhs >= 2f64 * rhs
lhs >= T::two() * rhs
}
fn build_node<T: Number, M: Array2<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
fn build_node<M: Matrix<T>>(&mut self, data: &M, begin: usize, end: usize) -> usize {
let (_, d) = data.shape();
let mut node = BBDTreeNode::new(d);
@@ -200,17 +200,17 @@ impl BBDTree {
node.count = end - begin;
node.index = begin;
let mut lower_bound = vec![0f64; d];
let mut upper_bound = vec![0f64; d];
let mut lower_bound = vec![T::zero(); d];
let mut upper_bound = vec![T::zero(); d];
for i in 0..d {
lower_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
upper_bound[i] = data.get((self.index[begin], i)).to_f64().unwrap();
lower_bound[i] = data.get(self.index[begin], i);
upper_bound[i] = data.get(self.index[begin], i);
}
for i in begin..end {
for j in 0..d {
let c = data.get((self.index[i], j)).to_f64().unwrap();
let c = data.get(self.index[i], j);
if lower_bound[j] > c {
lower_bound[j] = c;
}
@@ -220,32 +220,32 @@ impl BBDTree {
}
}
let mut max_radius = -1f64;
let mut max_radius = T::from(-1.).unwrap();
let mut split_index = 0;
for i in 0..d {
node.center[i] = (lower_bound[i] + upper_bound[i]) / 2f64;
node.radius[i] = (upper_bound[i] - lower_bound[i]) / 2f64;
node.center[i] = (lower_bound[i] + upper_bound[i]) / T::two();
node.radius[i] = (upper_bound[i] - lower_bound[i]) / T::two();
if node.radius[i] > max_radius {
max_radius = node.radius[i];
split_index = i;
}
}
if max_radius < 1E-10 {
if max_radius < T::from(1E-10).unwrap() {
node.lower = Option::None;
node.upper = Option::None;
for i in 0..d {
node.sum[i] = data.get((self.index[begin], i)).to_f64().unwrap();
node.sum[i] = data.get(self.index[begin], i);
}
if end > begin + 1 {
let len = end - begin;
for i in 0..d {
node.sum[i] *= len as f64;
node.sum[i] *= T::from(len).unwrap();
}
}
node.cost = 0f64;
node.cost = T::zero();
return self.add_node(node);
}
@@ -254,10 +254,8 @@ impl BBDTree {
let mut i2 = end - 1;
let mut size = 0;
while i1 <= i2 {
let mut i1_good =
data.get((self.index[i1], split_index)).to_f64().unwrap() < split_cutoff;
let mut i2_good =
data.get((self.index[i2], split_index)).to_f64().unwrap() >= split_cutoff;
let mut i1_good = data.get(self.index[i1], split_index) < split_cutoff;
let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff;
if !i1_good && !i2_good {
self.index.swap(i1, i2);
@@ -283,9 +281,9 @@ impl BBDTree {
self.nodes[node.lower.unwrap()].sum[i] + self.nodes[node.upper.unwrap()].sum[i];
}
let mut mean = vec![0f64; d];
let mut mean = vec![T::zero(); d];
for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
*mean_i = node.sum[i] / node.count as f64;
*mean_i = node.sum[i] / T::from(node.count).unwrap();
}
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
@@ -294,17 +292,17 @@ impl BBDTree {
self.add_node(node)
}
fn node_cost(node: &BBDTreeNode, center: &[f64]) -> f64 {
fn node_cost(node: &BBDTreeNode<T>, center: &[T]) -> T {
let d = center.len();
let mut scatter = 0f64;
let mut scatter = T::zero();
for (i, center_i) in center.iter().enumerate().take(d) {
let x = (node.sum[i] / node.count as f64) - *center_i;
let x = (node.sum[i] / T::from(node.count).unwrap()) - *center_i;
scatter += x * x;
}
node.cost + node.count as f64 * scatter
node.cost + T::from(node.count).unwrap() * scatter
}
fn add_node(&mut self, new_node: BBDTreeNode) -> usize {
fn add_node(&mut self, new_node: BBDTreeNode<T>) -> usize {
let idx = self.nodes.len();
self.nodes.push(new_node);
idx
@@ -314,12 +312,9 @@ impl BBDTree {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn bbdtree_iris() {
let data = DenseMatrix::from_2d_array(&[
@@ -343,8 +338,7 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
]);
let tree = BBDTree::new(&data);
File diff suppressed because it is too large Load Diff
+75 -83
View File
@@ -4,12 +4,12 @@
//!
//! ```
//! use smartcore::algorithm::neighbour::cover_tree::*;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::math::distance::Distance;
//!
//! #[derive(Clone)]
//! struct SimpleDistance {} // Our distance function
//!
//! impl Distance<i32> for SimpleDistance {
//! impl Distance<i32, f64> for SimpleDistance {
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
//! (a - b).abs() as f64
//! }
@@ -29,27 +29,28 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::metrics::distance::Distance;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
/// Implements Cover Tree algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct CoverTree<T, D: Distance<T>> {
base: f64,
inv_log_base: f64,
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
base: F,
inv_log_base: F,
distance: D,
root: Node,
root: Node<F>,
data: Vec<T>,
identical_excluded: bool,
}
impl<T, D: Distance<T>> PartialEq for CoverTree<T, D> {
impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
fn eq(&self, other: &Self) -> bool {
if self.data.len() != other.data.len() {
return false;
}
for i in 0..self.data.len() {
if self.distance.distance(&self.data[i], &other.data[i]) != 0f64 {
if self.distance.distance(&self.data[i], &other.data[i]) != F::zero() {
return false;
}
}
@@ -59,36 +60,36 @@ impl<T, D: Distance<T>> PartialEq for CoverTree<T, D> {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node {
struct Node<F: RealNumber> {
idx: usize,
max_dist: f64,
parent_dist: f64,
children: Vec<Node>,
max_dist: F,
parent_dist: F,
children: Vec<Node<F>>,
_scale: i64,
}
#[derive(Debug)]
struct DistanceSet {
struct DistanceSet<F: RealNumber> {
idx: usize,
dist: Vec<f64>,
dist: Vec<F>,
}
impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D> {
/// Construct a cover tree.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, D>, Failed> {
let base = 1.3f64;
pub fn new(data: Vec<T>, distance: D) -> Result<CoverTree<T, F, D>, Failed> {
let base = F::from_f64(1.3).unwrap();
let root = Node {
idx: 0,
max_dist: 0f64,
parent_dist: 0f64,
max_dist: F::zero(),
parent_dist: F::zero(),
children: Vec::new(),
_scale: 0,
};
let mut tree = CoverTree {
base,
inv_log_base: 1f64 / base.ln(),
inv_log_base: F::one() / base.ln(),
distance,
root,
data,
@@ -103,7 +104,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
/// Find k nearest neighbors of `p`
/// * `p` - look for k nearest points to `p`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
if k == 0 {
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
}
@@ -118,13 +119,13 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(e, p);
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
current_cover_set.push((d, &self.root));
let mut heap = HeapSelection::with_capacity(k);
heap.add(f64::MAX);
heap.add(F::max_value());
let mut empty_heap = true;
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
@@ -133,7 +134,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
}
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
@@ -145,7 +146,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
}
let upper_bound = if empty_heap {
f64::INFINITY
F::infinity()
} else {
*heap.peek()
};
@@ -168,7 +169,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
current_cover_set = next_cover_set;
}
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let upper_bound = *heap.peek();
for ds in zero_set {
if ds.0 <= upper_bound {
@@ -188,25 +189,25 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, p: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
if radius <= 0f64 {
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
let mut current_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut zero_set: Vec<(f64, &Node)> = Vec::new();
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(e, p);
current_cover_set.push((d, &self.root));
while !current_cover_set.is_empty() {
let mut next_cover_set: Vec<(f64, &Node)> = Vec::new();
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
for par in current_cover_set {
let parent = par.1;
for c in 0..parent.children.len() {
@@ -239,23 +240,23 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
Ok(neighbors)
}
fn new_leaf(&self, idx: usize) -> Node {
fn new_leaf(&self, idx: usize) -> Node<F> {
Node {
idx,
max_dist: 0f64,
parent_dist: 0f64,
max_dist: F::zero(),
parent_dist: F::zero(),
children: Vec::new(),
_scale: 100,
}
}
fn build_cover_tree(&mut self) {
let mut point_set: Vec<DistanceSet> = Vec::new();
let mut consumed_set: Vec<DistanceSet> = Vec::new();
let mut point_set: Vec<DistanceSet<F>> = Vec::new();
let mut consumed_set: Vec<DistanceSet<F>> = Vec::new();
let point = &self.data[0];
let idx = 0;
let mut max_dist = -1f64;
let mut max_dist = -F::one();
for i in 1..self.data.len() {
let dist = self.distance.distance(point, &self.data[i]);
@@ -283,16 +284,16 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
p: usize,
max_scale: i64,
top_scale: i64,
point_set: &mut Vec<DistanceSet>,
consumed_set: &mut Vec<DistanceSet>,
) -> Node {
point_set: &mut Vec<DistanceSet<F>>,
consumed_set: &mut Vec<DistanceSet<F>>,
) -> Node<F> {
if point_set.is_empty() {
self.new_leaf(p)
} else {
let max_dist = self.max(point_set);
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
if next_scale == i64::MIN {
let mut children: Vec<Node> = Vec::new();
if next_scale == std::i64::MIN {
let mut children: Vec<Node<F>> = Vec::new();
let mut leaf = self.new_leaf(p);
children.push(leaf);
while !point_set.is_empty() {
@@ -303,13 +304,13 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
}
Node {
idx: p,
max_dist: 0f64,
parent_dist: 0f64,
max_dist: F::zero(),
parent_dist: F::zero(),
children,
_scale: 100,
}
} else {
let mut far: Vec<DistanceSet> = Vec::new();
let mut far: Vec<DistanceSet<F>> = Vec::new();
self.split(point_set, &mut far, max_scale);
let child = self.batch_insert(p, next_scale, top_scale, point_set, consumed_set);
@@ -318,14 +319,14 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
point_set.append(&mut far);
child
} else {
let mut children: Vec<Node> = vec![child];
let mut new_point_set: Vec<DistanceSet> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet> = Vec::new();
let mut children: Vec<Node<F>> = vec![child];
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
while !point_set.is_empty() {
let set: DistanceSet = point_set.remove(point_set.len() - 1);
let set: DistanceSet<F> = point_set.remove(point_set.len() - 1);
let new_dist = set.dist[set.dist.len() - 1];
let new_dist: F = set.dist[set.dist.len() - 1];
self.dist_split(
point_set,
@@ -373,7 +374,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
Node {
idx: p,
max_dist: self.max(consumed_set),
parent_dist: 0f64,
parent_dist: F::zero(),
children,
_scale: (top_scale - max_scale),
}
@@ -384,12 +385,12 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
fn split(
&self,
point_set: &mut Vec<DistanceSet>,
far_set: &mut Vec<DistanceSet>,
point_set: &mut Vec<DistanceSet<F>>,
far_set: &mut Vec<DistanceSet<F>>,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet> = Vec::new();
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
for n in point_set.drain(0..) {
if n.dist[n.dist.len() - 1] <= fmax {
new_set.push(n);
@@ -403,13 +404,13 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
fn dist_split(
&self,
point_set: &mut Vec<DistanceSet>,
new_point_set: &mut Vec<DistanceSet>,
point_set: &mut Vec<DistanceSet<F>>,
new_point_set: &mut Vec<DistanceSet<F>>,
new_point: &T,
max_scale: i64,
) {
let fmax = self.get_cover_radius(max_scale);
let mut new_set: Vec<DistanceSet> = Vec::new();
let mut new_set: Vec<DistanceSet<F>> = Vec::new();
for mut n in point_set.drain(0..) {
let new_dist = self
.distance
@@ -425,24 +426,24 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
point_set.append(&mut new_set);
}
fn get_cover_radius(&self, s: i64) -> f64 {
self.base.powf(s as f64)
fn get_cover_radius(&self, s: i64) -> F {
self.base.powf(F::from_i64(s).unwrap())
}
fn get_data_value(&self, idx: usize) -> &T {
&self.data[idx]
}
fn get_scale(&self, d: f64) -> i64 {
if d == 0f64 {
i64::MIN
fn get_scale(&self, d: F) -> i64 {
if d == F::zero() {
std::i64::MIN
} else {
(self.inv_log_base * d.ln()).ceil() as i64
(self.inv_log_base * d.ln()).ceil().to_i64().unwrap()
}
}
fn max(&self, distance_set: &[DistanceSet]) -> f64 {
let mut max = 0f64;
fn max(&self, distance_set: &[DistanceSet<F>]) -> F {
let mut max = F::zero();
for n in distance_set {
if max < n.dist[n.dist.len() - 1] {
max = n.dist[n.dist.len() - 1];
@@ -456,22 +457,19 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
mod tests {
use super::*;
use crate::metrics::distance::Distances;
use crate::math::distance::Distances;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {}
impl Distance<i32> for SimpleDistance {
impl Distance<i32, f64> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
@@ -488,10 +486,7 @@ mod tests {
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cover_tree_test1() {
let data = vec![
@@ -510,10 +505,7 @@ mod tests {
assert_eq!(vec!(0, 1, 2), knn);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -521,7 +513,7 @@ mod tests {
let tree = CoverTree::new(data, SimpleDistance {}).unwrap();
let deserialized_tree: CoverTree<i32, SimpleDistance> =
let deserialized_tree: CoverTree<i32, f64, SimpleDistance> =
serde_json::from_str(&serde_json::to_string(&tree).unwrap()).unwrap();
assert_eq!(tree, deserialized_tree);
-705
View File
@@ -1,705 +0,0 @@
///
/// ### FastPair: Data-structure for the dynamic closest-pair problem.
///
/// Reference:
/// Eppstein, David: Fast hierarchical clustering and other applications of
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
///
/// Example:
/// ```
/// use smartcore::metrics::distance::PairwiseDistance;
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
/// let x = DenseMatrix::<f64>::from_2d_array(&[
/// &[5.1, 3.5, 1.4, 0.2],
/// &[4.9, 3.0, 1.4, 0.2],
/// &[4.7, 3.2, 1.3, 0.2],
/// &[4.6, 3.1, 1.5, 0.2],
/// &[5.0, 3.6, 1.4, 0.2],
/// &[5.4, 3.9, 1.7, 0.4],
/// ]).unwrap();
/// let fastpair = FastPair::new(&x);
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
/// ```
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::HashMap;
use num::Bounded;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::PairwiseDistance;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
///
/// Inspired by Python implementation:
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
///
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
///
#[derive(Debug, Clone)]
pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
/// initial matrix
samples: &'a M,
/// closest pair hashmap (connectivity matrix for closest pairs)
pub distances: HashMap<usize, PairwiseDistance<T>>,
/// conga line used to keep track of the closest pair
pub neighbours: Vec<usize>,
}
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
/// Constructor
/// Instantiate and initialize the algorithm
pub fn new(m: &'a M) -> Result<Self, Failed> {
if m.shape().0 < 3 {
return Err(Failed::because(
FailedError::FindFailed,
"min number of rows should be 3",
));
}
let mut init = Self {
samples: m,
// to be computed in init(..)
distances: HashMap::with_capacity(m.shape().0),
neighbours: Vec::with_capacity(m.shape().0 + 1),
};
init.init();
Ok(init)
}
/// Initialise `FastPair` by passing a `Array2`.
/// Build a FastPairs data-structure from a set of (new) points.
fn init(&mut self) {
// basic measures
let len = self.samples.shape().0;
let max_index = self.samples.shape().0 - 1;
// Store all closest neighbors
let _distances = Box::new(HashMap::with_capacity(len));
let _neighbours = Box::new(Vec::with_capacity(len));
let mut distances = *_distances;
let mut neighbours = *_neighbours;
// fill neighbours with -1 values
neighbours.extend(0..len);
// init closest neighbour pairwise data
for index_row_i in 0..(max_index) {
distances.insert(
index_row_i,
PairwiseDistance {
node: index_row_i,
neighbour: Option::None,
distance: Some(<T as Bounded>::max_value()),
},
);
}
// loop through indeces and neighbours
for index_row_i in 0..(len) {
// start looking for the neighbour in the second element
let mut index_closest = index_row_i + 1; // closest neighbour index
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
for index_row_j in (index_row_i + 1)..len {
distances.insert(
index_row_j,
PairwiseDistance {
node: index_row_j,
neighbour: Some(index_row_i),
distance: nbd,
},
);
let d = Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row_i).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(index_row_j).iterator(0).copied(),
self.samples.shape().1,
),
);
if d < nbd.unwrap().to_f64().unwrap() {
// set this j-value to be the closest neighbour
index_closest = index_row_j;
nbd = Some(T::from(d).unwrap());
}
}
// Add that edge
distances.entry(index_row_i).and_modify(|e| {
e.distance = nbd;
e.neighbour = Some(index_closest);
});
}
// No more neighbors, terminate conga line.
// Last person on the line has no neigbors
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
// compute sparse matrix (connectivity matrix)
let mut sparse_matrix = M::zeros(len, len);
for (_, p) in distances.iter() {
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
}
self.distances = distances;
self.neighbours = neighbours;
}
/// Find closest pair by scanning list of nearest neighbors.
#[allow(dead_code)]
pub fn closest_pair(&self) -> PairwiseDistance<T> {
let mut a = self.neighbours[0]; // Start with first point
let mut d = self.distances[&a].distance;
for p in self.neighbours.iter() {
if self.distances[p].distance < d {
a = *p; // Update `a` and distance `d`
d = self.distances[p].distance;
}
}
let b = self.distances[&a].neighbour;
PairwiseDistance {
node: a,
neighbour: b,
distance: d,
}
}
///
/// Return order dissimilarities from closest to furthest
///
#[allow(dead_code)]
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
let mut distances = self
.distances
.values()
.collect::<Vec<&PairwiseDistance<T>>>();
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
distances.into_iter()
}
//
// Compute distances from input to all other points in data-structure.
// input is the row index of the sample matrix
//
#[allow(dead_code)]
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
for other in self.neighbours.iter() {
if index_row != *other {
distances.push(PairwiseDistance {
node: index_row,
neighbour: Some(*other),
distance: Some(
T::from(Euclidian::squared_distance(
&Vec::from_iterator(
self.samples.get_row(index_row).iterator(0).copied(),
self.samples.shape().1,
),
&Vec::from_iterator(
self.samples.get_row(*other).iterator(0).copied(),
self.samples.shape().1,
),
))
.unwrap(),
),
})
}
}
distances
}
}
#[cfg(test)]
mod tests_fastpair {
use super::*;
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
/// Brute force algorithm, used only for comparison and testing
pub fn closest_pair_brute(
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
) -> PairwiseDistance<f64> {
use itertools::Itertools;
let m = fastpair.samples.shape().0;
let mut closest_pair = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::max_value()),
};
for pair in (0..m).combinations(2) {
let d = Euclidian::squared_distance(
&Vec::from_iterator(
fastpair.samples.get_row(pair[0]).iterator(0).copied(),
fastpair.samples.shape().1,
),
&Vec::from_iterator(
fastpair.samples.get_row(pair[1]).iterator(0).copied(),
fastpair.samples.shape().1,
),
);
if d < closest_pair.distance.unwrap() {
closest_pair.node = pair[0];
closest_pair.neighbour = Some(pair[1]);
closest_pair.distance = Some(d);
}
}
closest_pair
}
#[test]
fn fastpair_init() {
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
let _fastpair = FastPair::new(&x);
assert!(_fastpair.is_ok());
let fastpair = _fastpair.unwrap();
let distances = fastpair.distances;
let neighbours = fastpair.neighbours;
assert!(!distances.is_empty());
assert!(!neighbours.is_empty());
assert_eq!(10, neighbours.len());
assert_eq!(10, distances.len());
}
#[test]
fn dataset_has_at_least_three_points() {
// Create a dataset which consists of only two points:
// A(0.0, 0.0) and B(1.0, 1.0).
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();
// We expect an error when we run `FastPair` on this dataset,
// becuase `FastPair` currently only works on a minimum of 3
// points.
let fastpair = FastPair::new(&dataset);
assert!(fastpair.is_err());
if let Err(e) = fastpair {
let expected_error =
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
assert_eq!(e, expected_error)
}
}
#[test]
fn one_dimensional_dataset_minimal() {
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();
let result = FastPair::new(&dataset);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
let expected_closest_pair = PairwiseDistance {
node: 0,
neighbour: Some(1),
distance: Some(4.0),
};
assert_eq!(closest_pair, expected_closest_pair);
let closest_pair_brute = closest_pair_brute(&fastpair);
assert_eq!(closest_pair_brute, expected_closest_pair);
}
#[test]
fn one_dimensional_dataset_2() {
let dataset =
DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();
let result = FastPair::new(&dataset);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
let expected_closest_pair = PairwiseDistance {
node: 1,
neighbour: Some(3),
distance: Some(4.0),
};
assert_eq!(closest_pair, closest_pair_brute(&fastpair));
assert_eq!(closest_pair, expected_closest_pair);
}
#[test]
fn fastpair_new() {
// compute
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
// unwrap results
let result = fastpair.unwrap();
// list of minimal pairwise dissimilarities
let dissimilarities = vec![
(
1,
PairwiseDistance {
node: 1,
neighbour: Some(9),
distance: Some(0.030000000000000037),
},
),
(
10,
PairwiseDistance {
node: 10,
neighbour: Some(12),
distance: Some(0.07000000000000003),
},
),
(
11,
PairwiseDistance {
node: 11,
neighbour: Some(14),
distance: Some(0.18000000000000013),
},
),
(
12,
PairwiseDistance {
node: 12,
neighbour: Some(14),
distance: Some(0.34000000000000086),
},
),
(
13,
PairwiseDistance {
node: 13,
neighbour: Some(14),
distance: Some(1.6499999999999997),
},
),
(
14,
PairwiseDistance {
node: 14,
neighbour: Some(14),
distance: Some(f64::MAX),
},
),
(
6,
PairwiseDistance {
node: 6,
neighbour: Some(7),
distance: Some(0.18000000000000027),
},
),
(
0,
PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
},
),
(
8,
PairwiseDistance {
node: 8,
neighbour: Some(9),
distance: Some(0.3100000000000001),
},
),
(
2,
PairwiseDistance {
node: 2,
neighbour: Some(3),
distance: Some(0.0600000000000001),
},
),
(
3,
PairwiseDistance {
node: 3,
neighbour: Some(8),
distance: Some(0.08999999999999982),
},
),
(
7,
PairwiseDistance {
node: 7,
neighbour: Some(9),
distance: Some(0.10999999999999982),
},
),
(
9,
PairwiseDistance {
node: 9,
neighbour: Some(13),
distance: Some(8.69),
},
),
(
4,
PairwiseDistance {
node: 4,
neighbour: Some(7),
distance: Some(0.050000000000000086),
},
),
(
5,
PairwiseDistance {
node: 5,
neighbour: Some(7),
distance: Some(0.4900000000000002),
},
),
];
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
for i in 0..(x.shape().0 - 1) {
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
let distance = Euclidian::squared_distance(
&Vec::from_iterator(
result.samples.get_row(i).iterator(0).copied(),
result.samples.shape().1,
),
&Vec::from_iterator(
result.samples.get_row(input_neighbour).iterator(0).copied(),
result.samples.shape().1,
),
);
assert_eq!(i, expected.get(&i).unwrap().node);
assert_eq!(
input_neighbour,
expected.get(&i).unwrap().neighbour.unwrap()
);
assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap());
}
}
#[test]
fn fastpair_closest_pair() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let dissimilarity = fastpair.unwrap().closest_pair();
let closest = PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
};
assert_eq!(closest, dissimilarity);
}
#[test]
fn fastpair_closest_pair_random_matrix() {
let x = DenseMatrix::<f64>::rand(200, 25);
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let result = fastpair.unwrap();
let dissimilarity1 = result.closest_pair();
let dissimilarity2 = closest_pair_brute(&result);
assert_eq!(dissimilarity1, dissimilarity2);
}
#[test]
fn fastpair_distances() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
])
.unwrap();
// compute
let fastpair = FastPair::new(&x);
assert!(fastpair.is_ok());
let dissimilarities = fastpair.unwrap().distances_from(0);
let mut min_dissimilarity = PairwiseDistance {
node: 0,
neighbour: Option::None,
distance: Some(f64::MAX),
};
for p in dissimilarities.iter() {
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
min_dissimilarity = *p
}
}
let closest = PairwiseDistance {
node: 0,
neighbour: Some(4),
distance: Some(0.01999999999999995),
};
assert_eq!(closest, min_dissimilarity);
}
#[test]
fn fastpair_ordered_pairs() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
])
.unwrap();
let fastpair = FastPair::new(&x).unwrap();
let ordered = fastpair.ordered_pairs();
let mut previous: f64 = -1.0;
for p in ordered {
if previous == -1.0 {
previous = p.distance.unwrap();
} else {
let current = p.distance.unwrap();
assert!(current >= previous);
previous = current;
}
}
}
#[test]
fn test_empty_set() {
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
let result = FastPair::new(&empty_matrix);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_single_point() {
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
let result = FastPair::new(&single_point);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_two_points() {
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
let result = FastPair::new(&two_points);
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(
e,
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
);
}
}
#[test]
fn test_three_identical_points() {
let identical_points =
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
let result = FastPair::new(&identical_points);
assert!(result.is_ok());
let fastpair = result.unwrap();
let closest_pair = fastpair.closest_pair();
assert_eq!(closest_pair.distance, Some(0.0));
}
#[test]
fn test_result_unwrapping() {
let valid_matrix =
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
.unwrap();
let result = FastPair::new(&valid_matrix);
assert!(result.is_ok());
// This should not panic
let _fastpair = result.unwrap();
}
}
+30 -29
View File
@@ -3,12 +3,12 @@
//! see [KNN algorithms](../index.html)
//! ```
//! use smartcore::algorithm::neighbour::linear_search::*;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::math::distance::Distance;
//!
//! #[derive(Clone)]
//! struct SimpleDistance {} // Our distance function
//!
//! impl Distance<i32> for SimpleDistance {
//! impl Distance<i32, f64> for SimpleDistance {
//! fn distance(&self, a: &i32, b: &i32) -> f64 { // simple simmetrical scalar distance
//! (a - b).abs() as f64
//! }
@@ -25,31 +25,38 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData;
use crate::algorithm::sort::heap_select::HeapSelection;
use crate::error::{Failed, FailedError};
use crate::metrics::distance::Distance;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearKNNSearch<T, D: Distance<T>> {
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
distance: D,
data: Vec<T>,
f: PhantomData<F>,
}
impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
/// Initializes algorithm.
/// * `data` - vector of data points to search for.
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, D>, Failed> {
Ok(LinearKNNSearch { data, distance })
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
Ok(LinearKNNSearch {
data,
distance,
f: PhantomData,
})
}
/// Find k nearest neighbors
/// * `from` - look for k nearest points to `from`
/// * `k` - the number of nearest neighbors to return
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, f64, &T)>, Failed> {
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
if k < 1 || k > self.data.len() {
return Err(Failed::because(
FailedError::FindFailed,
@@ -57,11 +64,11 @@ impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
));
}
let mut heap = HeapSelection::<KNNPoint>::with_capacity(k);
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
for _ in 0..k {
heap.add(KNNPoint {
distance: f64::INFINITY,
distance: F::infinity(),
index: None,
});
}
@@ -86,15 +93,15 @@ impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
/// Find all nearest neighbors within radius `radius` from `p`
/// * `p` - look for k nearest points to `p`
/// * `radius` - radius of the search
pub fn find_radius(&self, from: &T, radius: f64) -> Result<Vec<(usize, f64, &T)>, Failed> {
if radius <= 0f64 {
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
if radius <= F::zero() {
return Err(Failed::because(
FailedError::FindFailed,
"radius should be > 0",
));
}
let mut neighbors: Vec<(usize, f64, &T)> = Vec::new();
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
for i in 0..self.data.len() {
let d = self.distance.distance(from, &self.data[i]);
@@ -109,44 +116,41 @@ impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
}
#[derive(Debug)]
struct KNNPoint {
distance: f64,
struct KNNPoint<F: RealNumber> {
distance: F,
index: Option<usize>,
}
impl PartialOrd for KNNPoint {
impl<F: RealNumber> PartialOrd for KNNPoint<F> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
impl PartialEq for KNNPoint {
impl<F: RealNumber> PartialEq for KNNPoint<F> {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl Eq for KNNPoint {}
impl<F: RealNumber> Eq for KNNPoint<F> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::metrics::distance::Distances;
use crate::math::distance::Distances;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {}
impl Distance<i32> for SimpleDistance {
impl Distance<i32, f64> for SimpleDistance {
fn distance(&self, a: &i32, b: &i32) -> f64 {
(a - b).abs() as f64
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn knn_find() {
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
@@ -193,10 +197,7 @@ mod tests {
assert_eq!(vec!(1, 2, 3), found_idxs2);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn knn_point_eq() {
let point1 = KNNPoint {
@@ -215,7 +216,7 @@ mod tests {
};
let point_inf = KNNPoint {
distance: f64::INFINITY,
distance: std::f64::INFINITY,
index: Some(3),
};
+12 -18
View File
@@ -1,4 +1,4 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
#![allow(clippy::ptr_arg)]
//! # Nearest Neighbors Search Algorithms and Data Structures
//!
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
@@ -33,43 +33,37 @@
use crate::algorithm::neighbour::cover_tree::CoverTree;
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed;
use crate::metrics::distance::Distance;
use crate::numbers::basenum::Number;
use crate::math::distance::Distance;
use crate::math::num::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree;
/// a variant of fastpair using cosine distance
pub mod cosinepair;
/// tree data structure for fast nearest neighbor search
pub mod cover_tree;
/// fastpair closest neighbour algorithm
pub mod fastpair;
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
pub mod linear_search;
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone)]
pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch,
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
#[default]
CoverTree,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
LinearSearch(LinearKNNSearch<Vec<T>, D>),
CoverTree(CoverTree<Vec<T>, D>),
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>),
}
// TODO: missing documentation
impl KNNAlgorithmName {
pub(crate) fn fit<T: Number, D: Distance<Vec<T>>>(
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
&self,
data: Vec<Vec<T>>,
distance: D,
@@ -85,8 +79,8 @@ impl KNNAlgorithmName {
}
}
impl<T: Number, D: Distance<Vec<T>>> KNNAlgorithm<T, D> {
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
@@ -96,8 +90,8 @@ impl<T: Number, D: Distance<Vec<T>>> KNNAlgorithm<T, D> {
pub fn find_radius(
&self,
from: &Vec<T>,
radius: f64,
) -> Result<Vec<(usize, f64, &Vec<T>)>, Failed> {
radius: T,
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
match *self {
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
+8 -23
View File
@@ -12,7 +12,7 @@ pub struct HeapSelection<T: PartialOrd + Debug> {
heap: Vec<T>,
}
impl<T: PartialOrd + Debug> HeapSelection<T> {
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
pub fn with_capacity(k: usize) -> HeapSelection<T> {
HeapSelection {
k,
@@ -95,20 +95,14 @@ impl<T: PartialOrd + Debug> HeapSelection<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn with_capacity() {
let heap = HeapSelection::<i32>::with_capacity(3);
assert_eq!(3, heap.k);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_add() {
let mut heap = HeapSelection::with_capacity(3);
@@ -126,14 +120,11 @@ mod tests {
assert_eq!(vec![2, 0, -5], heap.get());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_add1() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(f64::INFINITY);
heap.add(std::f64::INFINITY);
heap.add(-5f64);
heap.add(4f64);
heap.add(-1f64);
@@ -144,14 +135,11 @@ mod tests {
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_add2() {
let mut heap = HeapSelection::with_capacity(3);
heap.add(f64::INFINITY);
heap.add(std::f64::INFINITY);
heap.add(0.0);
heap.add(8.4852);
heap.add(5.6568);
@@ -160,10 +148,7 @@ mod tests {
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_add_ordered() {
let mut heap = HeapSelection::with_capacity(3);
+3 -8
View File
@@ -1,14 +1,12 @@
use num_traits::Num;
use num_traits::Float;
pub trait QuickArgSort {
#[allow(dead_code)]
fn quick_argsort_mut(&mut self) -> Vec<usize>;
#[allow(dead_code)]
fn quick_argsort(&self) -> Vec<usize>;
}
impl<T: Num + PartialOrd + Copy> QuickArgSort for Vec<T> {
impl<T: Float> QuickArgSort for Vec<T> {
fn quick_argsort(&self) -> Vec<usize> {
let mut v = self.clone();
v.quick_argsort_mut()
@@ -115,10 +113,7 @@ impl<T: Num + PartialOrd + Copy> QuickArgSort for Vec<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn with_capacity() {
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
+2 -34
View File
@@ -16,12 +16,8 @@ pub trait UnsupervisedEstimator<X, P> {
P: Clone;
}
/// An estimator for supervised learning, that provides method `fit` to learn from data and training values
pub trait SupervisedEstimator<X, Y, P>: Predictor<X, Y> {
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
/// used to pass around the correct `fit()` implementation.
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
fn new() -> Self;
/// An estimator for supervised learning, , that provides method `fit` to learn from data and training values
pub trait SupervisedEstimator<X, Y, P> {
/// Fit a model to a training dataset, estimate model's parameters.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target training values of size _N_.
@@ -32,24 +28,6 @@ pub trait SupervisedEstimator<X, Y, P>: Predictor<X, Y> {
P: Clone;
}
/// An estimator for supervised learning.
/// In this one parameters are borrowed instead of moved, this is useful for parameters that carry
/// references. Also to be used when there is no predictor attached to the estimator.
pub trait SupervisedEstimatorBorrow<'a, X, Y, P> {
/// Empty constructor, instantiate an empty estimator. Object is dropped as soon as `fit()` is called.
/// used to pass around the correct `fit()` implementation.
/// by calling `::fit()`. mostly used to be used with `model_selection::cross_validate(...)`
fn new() -> Self;
/// Fit a model to a training dataset, estimate model's parameters.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target training values of size _N_.
/// * `&parameters` - hyperparameters of an algorithm
fn fit(x: &'a X, y: &'a Y, parameters: &'a P) -> Result<Self, Failed>
where
Self: Sized,
P: Clone;
}
/// Implements method predict that estimates target value from new data
pub trait Predictor<X, Y> {
/// Estimate target values from new data.
@@ -57,19 +35,9 @@ pub trait Predictor<X, Y> {
fn predict(&self, x: &X) -> Result<Y, Failed>;
}
/// Implements method predict that estimates target value from new data, with borrowing
pub trait PredictorBorrow<'a, X, T> {
/// Estimate target values from new data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
fn predict(&self, x: &'a X) -> Result<Vec<T>, Failed>;
}
/// Implements method transform that filters or modifies input data
pub trait Transformer<X> {
/// Transform data by modifying or filtering it
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
fn transform(&self, x: &X) -> Result<X, Failed>;
}
/// empty parameters for an estimator, see `BiasedEstimator`
pub trait NoParameters {}
-317
View File
@@ -1,317 +0,0 @@
//! # Agglomerative Hierarchical Clustering
//!
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
//!
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
//!
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
//!
//! ## Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
//!
//! // A dataset with 2 distinct groups of points.
//! let x = DenseMatrix::from_2d_array(&[
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
//! ]).unwrap();
//!
//! // Set parameters to find 2 clusters.
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
//!
//! // Fit the model to the data.
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
//!
//! // Get the cluster assignments.
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
//! ```
//!
//! ## References:
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
use std::collections::HashMap;
use std::marker::PhantomData;
use crate::api::UnsupervisedEstimator;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
/// The number of clusters to find.
pub n_clusters: usize,
}
impl AgglomerativeClusteringParameters {
/// Sets the number of clusters.
///
/// # Arguments
/// * `n_clusters` - The desired number of clusters.
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
self.n_clusters = n_clusters;
self
}
}
impl Default for AgglomerativeClusteringParameters {
fn default() -> Self {
AgglomerativeClusteringParameters { n_clusters: 2 }
}
}
/// Agglomerative Clustering model.
///
/// This implementation uses single-linkage clustering, which is mathematically
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
/// The core logic is an efficient implementation of Kruskal's algorithm, which
/// processes all pairwise distances in increasing order and uses a Disjoint
/// Set Union (DSU) data structure to track cluster membership.
#[derive(Debug)]
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
/// The cluster label assigned to each sample.
pub labels: Vec<usize>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
/// Fits the agglomerative clustering model to the data.
///
/// # Arguments
/// * `data` - A reference to the input data matrix.
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
///
/// # Returns
/// A `Result` containing the fitted model with cluster labels, or an error if
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
let (num_samples, _) = data.shape();
let n_clusters = parameters.n_clusters;
if n_clusters > num_samples {
return Err(Failed::because(
FailedError::ParametersError,
&format!(
"n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"
),
));
}
let mut distance_pairs = Vec::new();
for i in 0..num_samples {
for j in (i + 1)..num_samples {
let distance: f64 = data
.get_row(i)
.iterator(0)
.zip(data.get_row(j).iterator(0))
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
.sum::<f64>();
distance_pairs.push((distance, i, j));
}
}
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
let mut parent = HashMap::new();
let mut children = HashMap::new();
for i in 0..num_samples {
parent.insert(i, i);
children.insert(i, vec![i]);
}
let mut merge_history = Vec::new();
let num_merges_needed = num_samples - 1;
while merge_history.len() < num_merges_needed {
let (_, p1, p2) = distance_pairs.pop().unwrap();
let root1 = parent[&p1];
let root2 = parent[&p2];
if root1 != root2 {
let root2_children = children.remove(&root2).unwrap();
for child in root2_children.iter() {
parent.insert(*child, root1);
}
let root1_children = children.get_mut(&root1).unwrap();
root1_children.extend(root2_children);
merge_history.push((root1, root2));
}
}
let mut clusters = HashMap::new();
let mut assignments = HashMap::new();
for i in 0..num_samples {
clusters.insert(i, vec![i]);
assignments.insert(i, i);
}
let merges_to_apply = num_samples - n_clusters;
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
let root1_cluster = assignments[root1];
let root2_cluster = assignments[root2];
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
for assignment in root2_assignments.iter() {
assignments.insert(*assignment, root1_cluster);
}
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
root1_assignments.extend(root2_assignments);
}
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
cluster_keys.sort();
for (i, key) in cluster_keys.into_iter().enumerate() {
for index in clusters[key].iter() {
labels[*index] = i;
}
}
Ok(AgglomerativeClustering {
labels,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
for AgglomerativeClustering<TX, TY, X, Y>
{
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
AgglomerativeClustering::fit(x, parameters)
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::matrix::DenseMatrix;
use std::collections::HashSet;
use super::*;
#[test]
fn test_simple_clustering() {
// Two distinct clusters, far apart.
let data = vec![
0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
];
let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
// Using f64 for TY as usize doesn't satisfy the Number trait bound.
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Check that all points in the first group have the same label.
let first_group_label = labels[0];
assert!(labels[0..3].iter().all(|&l| l == first_group_label));
// Check that all points in the second group have the same label.
let second_group_label = labels[3];
assert!(labels[3..6].iter().all(|&l| l == second_group_label));
// Check that the two groups have different labels.
assert_ne!(first_group_label, second_group_label);
}
#[test]
fn test_four_clusters() {
// Four distinct clusters in the corners of a square.
let data = vec![
0.0, 0.0, 1.0, 1.0, // Cluster A
100.0, 100.0, 101.0, 101.0, // Cluster B
0.0, 100.0, 1.0, 101.0, // Cluster C
100.0, 0.0, 101.0, 1.0, // Cluster D
];
let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
let labels = clustering.labels;
// Verify that there are exactly 4 unique labels produced.
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
assert_eq!(unique_labels.len(), 4);
// Verify that points within each original group were assigned the same cluster label.
let label_a = labels[0];
assert_eq!(label_a, labels[1]);
let label_b = labels[2];
assert_eq!(label_b, labels[3]);
let label_c = labels[4];
assert_eq!(label_c, labels[5]);
let label_d = labels[6];
assert_eq!(label_d, labels[7]);
// Verify that all four groups received different labels.
assert_ne!(label_a, label_b);
assert_ne!(label_a, label_c);
assert_ne!(label_a, label_d);
assert_ne!(label_b, label_c);
assert_ne!(label_b, label_d);
assert_ne!(label_c, label_d);
}
#[test]
fn test_n_clusters_equal_to_samples() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// Each point should be its own cluster. Sorting makes the test deterministic.
let mut labels = clustering.labels;
labels.sort();
assert_eq!(labels, vec![0, 1, 2]);
}
#[test]
fn test_one_cluster() {
let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
)
.unwrap();
// All points should be in the same cluster.
assert_eq!(clustering.labels, vec![0, 0, 0]);
}
#[test]
fn test_error_on_too_many_clusters() {
let data = vec![0.0, 0.0, 5.0, 5.0];
let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
&matrix, parameters,
);
assert!(result.is_err());
}
}
+59 -239
View File
@@ -18,20 +18,19 @@
//!
//! Example:
//!
//! ```ignore
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::Array2;
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::cluster::dbscan::*;
//! use smartcore::metrics::distance::Distances;
//! use smartcore::math::distance::Distances;
//! use smartcore::neighbors::KNNAlgorithmName;
//! use smartcore::dataset::generator;
//!
//! // Generate three blobs
//! let blobs = generator::make_blobs(100, 2, 3);
//! let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
//! // Fit the algorithm and predict cluster labels
//! let labels: Vec<u32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
//! and_then(|dbscan| dbscan.predict(&x)).unwrap();
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
//! and_then(|dbscan| dbscan.predict(&x));
//!
//! println!("{:?}", labels);
//! ```
@@ -42,7 +41,7 @@
//! * ["Density-Based Clustering in Spatial Databases: The Algorithm GDBSCAN and its Applications", Sander J., Ester M., Kriegel HP., Xu X.](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1629&rep=rep1&type=pdf)
use std::fmt::Debug;
use std::marker::PhantomData;
use std::iter::Sum;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
@@ -50,58 +49,47 @@ use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::Euclidian;
use crate::metrics::distance::{Distance, Distances};
use crate::numbers::basenum::Number;
use crate::linalg::{row_iter, Matrix};
use crate::math::distance::euclidian::Euclidian;
use crate::math::distance::{Distance, Distances};
use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::which_max;
/// DBSCAN clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DBSCAN<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> {
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
cluster_labels: Vec<i16>,
num_classes: usize,
knn_algorithm: KNNAlgorithm<TX, D>,
eps: f64,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
knn_algorithm: KNNAlgorithm<T, D>,
eps: T,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// DBSCAN clustering algorithm parameters
pub struct DBSCANParameters<T: Number, D: Distance<Vec<T>>> {
#[cfg_attr(feature = "serde", serde(default))]
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: D,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
pub min_samples: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub eps: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub eps: T,
/// KNN algorithm to use.
pub algorithm: KNNAlgorithmName,
#[cfg_attr(feature = "serde", serde(default))]
_phantom_t: PhantomData<T>,
}
impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub fn with_distance<DD: Distance<Vec<T>>>(self, distance: DD) -> DBSCANParameters<T, DD> {
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
DBSCANParameters {
distance,
min_samples: self.min_samples,
eps: self.eps,
algorithm: self.algorithm,
_phantom_t: PhantomData,
}
}
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
@@ -110,7 +98,7 @@ impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
self
}
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub fn with_eps(mut self, eps: f64) -> Self {
pub fn with_eps(mut self, eps: T) -> Self {
self.eps = eps;
self
}
@@ -121,113 +109,7 @@ impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
}
}
/// DBSCAN grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DBSCANSearchParameters<T: Number, D: Distance<Vec<T>>> {
#[cfg_attr(feature = "serde", serde(default))]
/// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
pub distance: Vec<D>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
pub min_samples: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
pub eps: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// KNN algorithm to use.
pub algorithm: Vec<KNNAlgorithmName>,
_phantom_t: PhantomData<T>,
}
/// DBSCAN grid search iterator
pub struct DBSCANSearchParametersIterator<T: Number, D: Distance<Vec<T>>> {
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
current_distance: usize,
current_min_samples: usize,
current_eps: usize,
current_algorithm: usize,
}
impl<T: Number, D: Distance<Vec<T>>> IntoIterator for DBSCANSearchParameters<T, D> {
type Item = DBSCANParameters<T, D>;
type IntoIter = DBSCANSearchParametersIterator<T, D>;
fn into_iter(self) -> Self::IntoIter {
DBSCANSearchParametersIterator {
dbscan_search_parameters: self,
current_distance: 0,
current_min_samples: 0,
current_eps: 0,
current_algorithm: 0,
}
}
}
impl<T: Number, D: Distance<Vec<T>>> Iterator for DBSCANSearchParametersIterator<T, D> {
type Item = DBSCANParameters<T, D>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_distance == self.dbscan_search_parameters.distance.len()
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
&& self.current_eps == self.dbscan_search_parameters.eps.len()
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
{
return None;
}
let next = DBSCANParameters {
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
eps: self.dbscan_search_parameters.eps[self.current_eps],
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
_phantom_t: PhantomData,
};
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
self.current_distance += 1;
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
self.current_distance = 0;
self.current_min_samples += 1;
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
self.current_distance = 0;
self.current_min_samples = 0;
self.current_eps += 1;
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
self.current_distance = 0;
self.current_min_samples = 0;
self.current_eps = 0;
self.current_algorithm += 1;
} else {
self.current_distance += 1;
self.current_min_samples += 1;
self.current_eps += 1;
self.current_algorithm += 1;
}
Some(next)
}
}
impl<T: Number> Default for DBSCANSearchParameters<T, Euclidian<T>> {
fn default() -> Self {
let default_params = DBSCANParameters::default();
DBSCANSearchParameters {
distance: vec![default_params.distance],
min_samples: vec![default_params.min_samples],
eps: vec![default_params.eps],
algorithm: vec![default_params.algorithm],
_phantom_t: PhantomData,
}
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
for DBSCAN<TX, TY, X, Y, D>
{
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
fn eq(&self, other: &Self) -> bool {
self.cluster_labels.len() == other.cluster_labels.len()
&& self.num_classes == other.num_classes
@@ -236,50 +118,47 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
}
}
impl<T: Number> Default for DBSCANParameters<T, Euclidian<T>> {
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
fn default() -> Self {
DBSCANParameters {
distance: Distances::euclidian(),
min_samples: 5,
eps: 0.5f64,
algorithm: KNNAlgorithmName::default(),
_phantom_t: PhantomData,
eps: T::half(),
algorithm: KNNAlgorithmName::CoverTree,
}
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
UnsupervisedEstimator<X, DBSCANParameters<TX, D>> for DBSCAN<TX, TY, X, Y, D>
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
{
fn fit(x: &X, parameters: DBSCANParameters<TX, D>) -> Result<Self, Failed> {
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
DBSCAN::fit(x, parameters)
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> Predictor<X, Y>
for DBSCAN<TX, TY, X, Y, D>
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
for DBSCAN<T, D>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
DBSCAN<TX, TY, X, Y, D>
{
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
/// * `data` - training instances to cluster
/// * `k` - number of clusters
/// * `parameters` - cluster parameters
pub fn fit(
x: &X,
parameters: DBSCANParameters<TX, D>,
) -> Result<DBSCAN<TX, TY, X, Y, D>, Failed> {
pub fn fit<M: Matrix<T>>(
x: &M,
parameters: DBSCANParameters<T, D>,
) -> Result<DBSCAN<T, D>, Failed> {
if parameters.min_samples < 1 {
return Err(Failed::fit("Invalid minPts"));
}
if parameters.eps <= 0f64 {
if parameters.eps <= T::zero() {
return Err(Failed::fit("Invalid radius: "));
}
@@ -291,19 +170,13 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
let n = x.shape().0;
let mut y = vec![undefined; n];
let algo = parameters.algorithm.fit(
x.row_iter()
.map(|row| row.iterator(0).cloned().collect())
.collect(),
parameters.distance,
)?;
let algo = parameters
.algorithm
.fit(row_iter(x).collect(), parameters.distance)?;
let mut row = vec![TX::zero(); x.shape().1];
for (i, e) in x.row_iter().enumerate() {
for (i, e) in row_iter(x).enumerate() {
if y[i] == undefined {
e.iterator(0).zip(row.iter_mut()).for_each(|(&x, r)| *r = x);
let mut neighbors = algo.find_radius(&row, parameters.eps)?;
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
if neighbors.len() < parameters.min_samples {
y[i] = outlier;
} else {
@@ -315,7 +188,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
}
}
while let Some(neighbor) = neighbors.pop() {
while !neighbors.is_empty() {
let neighbor = neighbors.pop().unwrap();
let index = neighbor.0;
if y[index] == outlier {
@@ -353,25 +227,18 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
num_classes: k as usize,
knn_algorithm: algo,
eps: parameters.eps,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
let mut result = Y::zeros(n);
let mut row = vec![TX::zero(); x.shape().1];
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, m) = x.shape();
let mut result = M::zeros(1, n);
let mut row = vec![T::zero(); m];
for i in 0..n {
x.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
x.copy_row_as_vec(i, &mut row);
let neighbors = self.knn_algorithm.find_radius(&row, self.eps)?;
let mut label = vec![0usize; self.num_classes + 1];
for neighbor in neighbors {
@@ -384,50 +251,24 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
}
let class = which_max(&label);
if class != self.num_classes {
result.set(i, TY::from(class + 1).unwrap());
result.set(0, i, T::from(class).unwrap());
} else {
result.set(i, TY::zero());
result.set(0, i, -T::one());
}
}
Ok(result)
Ok(result.to_row_vector())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg(feature = "serde")]
use crate::metrics::distance::euclidian::Euclidian;
use crate::math::distance::euclidian::Euclidian;
#[test]
fn search_parameters() {
let parameters: DBSCANSearchParameters<f64, Euclidian<f64>> = DBSCANSearchParameters {
min_samples: vec![10, 100],
eps: vec![1., 2.],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 10);
assert_eq!(next.eps, 1.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 100);
assert_eq!(next.eps, 1.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 10);
assert_eq!(next.eps, 2.);
let next = iter.next().unwrap();
assert_eq!(next.min_samples, 100);
assert_eq!(next.eps, 2.);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_dbscan() {
let x = DenseMatrix::from_2d_array(&[
@@ -442,10 +283,9 @@ mod tests {
&[2.2, 1.2],
&[1.8, 0.8],
&[3.0, 5.0],
])
.unwrap();
]);
let expected_labels = vec![1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0];
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
let dbscan = DBSCAN::fit(
&x,
@@ -455,15 +295,12 @@ mod tests {
)
.unwrap();
let predicted_labels: Vec<i32> = dbscan.predict(&x).unwrap();
let predicted_labels = dbscan.predict(&x).unwrap();
assert_eq!(expected_labels, predicted_labels);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -488,30 +325,13 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
]);
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
let deserialized_dbscan: DBSCAN<f32, f32, DenseMatrix<f32>, Vec<f32>, Euclidian<f32>> =
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();
assert_eq!(dbscan, deserialized_dbscan);
}
#[cfg(feature = "datasets")]
#[test]
fn from_vec() {
use crate::dataset::generator;
// Generate three blobs
let blobs = generator::make_blobs(100, 2, 3);
let x: DenseMatrix<f32> = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0);
// Fit the algorithm and predict cluster labels
let labels: Vec<i32> = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0))
.and_then(|dbscan| dbscan.predict(&x))
.unwrap();
println!("{labels:?}");
}
}
+67 -228
View File
@@ -11,12 +11,12 @@
//! these re-calculated centroids becoming the new centers of their respective clusters. Next all instances of the training set are re-assigned to their closest cluster again.
//! This iterative process continues until convergence is achieved and the clusters are considered settled.
//!
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. `smartcore` uses k-means++ algorithm to initialize cluster centers.
//! Initial choice of K data points is very important and has big effect on performance of the algorithm. SmartCore uses k-means++ algorithm to initialize cluster centers.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::cluster::kmeans::*;
//!
//! // Iris data
@@ -41,10 +41,10 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! ]);
//!
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
//! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
//! ```
//!
//! ## References:
@@ -52,37 +52,32 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
use std::fmt::Debug;
use std::marker::PhantomData;
use rand::Rng;
use std::fmt::Debug;
use std::iter::Sum;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree;
use crate::api::{Predictor, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::metrics::distance::euclidian::*;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
/// K-Means clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
pub struct KMeans<T: RealNumber> {
k: usize,
_y: Vec<usize>,
size: Vec<usize>,
_distortion: f64,
centroids: Vec<Vec<f64>>,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
_distortion: T,
centroids: Vec<Vec<T>>,
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<TX, TY, X, Y> {
impl<T: RealNumber> PartialEq for KMeans<T> {
fn eq(&self, other: &Self) -> bool {
if self.k != other.k
|| self.size != other.size
@@ -96,7 +91,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<
return false;
}
for j in 0..self.centroids[i].len() {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
if (self.centroids[i][j] - other.centroids[i][j]).abs() > T::epsilon() {
return false;
}
}
@@ -106,20 +101,13 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// K-Means clustering algorithm parameters
pub struct KMeansParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of clusters.
pub k: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Determines random number generation for centroid initialization.
/// Use an int to make the randomness deterministic
pub seed: Option<u64>,
}
impl KMeansParameters {
@@ -140,118 +128,27 @@ impl Default for KMeansParameters {
KMeansParameters {
k: 2,
max_iter: 100,
seed: Option::None,
}
}
}
/// KMeans grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KMeansSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of clusters.
pub k: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Determines random number generation for centroid initialization.
/// Use an int to make the randomness deterministic
pub seed: Vec<Option<u64>>,
}
/// KMeans grid search iterator
pub struct KMeansSearchParametersIterator {
kmeans_search_parameters: KMeansSearchParameters,
current_k: usize,
current_max_iter: usize,
current_seed: usize,
}
impl IntoIterator for KMeansSearchParameters {
type Item = KMeansParameters;
type IntoIter = KMeansSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
KMeansSearchParametersIterator {
kmeans_search_parameters: self,
current_k: 0,
current_max_iter: 0,
current_seed: 0,
}
}
}
impl Iterator for KMeansSearchParametersIterator {
type Item = KMeansParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_k == self.kmeans_search_parameters.k.len()
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
&& self.current_seed == self.kmeans_search_parameters.seed.len()
{
return None;
}
let next = KMeansParameters {
k: self.kmeans_search_parameters.k[self.current_k],
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
seed: self.kmeans_search_parameters.seed[self.current_seed],
};
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
self.current_k += 1;
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
self.current_k = 0;
self.current_max_iter += 1;
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
self.current_k = 0;
self.current_max_iter = 0;
self.current_seed += 1;
} else {
self.current_k += 1;
self.current_max_iter += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for KMeansSearchParameters {
fn default() -> Self {
let default_params = KMeansParameters::default();
KMeansSearchParameters {
k: vec![default_params.k],
max_iter: vec![default_params.max_iter],
seed: vec![default_params.seed],
}
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
UnsupervisedEstimator<X, KMeansParameters> for KMeans<TX, TY, X, Y>
{
fn fit(x: &X, parameters: KMeansParameters) -> Result<Self, Failed> {
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
KMeans::fit(x, parameters)
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for KMeans<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
impl<T: RealNumber + Sum> KMeans<T> {
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
/// * `data` - training instances to cluster
/// * `data` - training instances to cluster
/// * `parameters` - cluster parameters
pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
let bbd = BBDTree::new(data);
if parameters.k < 2 {
@@ -270,10 +167,10 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
let (n, d) = data.shape();
let mut distortion = f64::MAX;
let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
let mut distortion = T::max_value();
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
let mut size = vec![0; parameters.k];
let mut centroids = vec![vec![0f64; d]; parameters.k];
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
for i in 0..n {
size[y[i]] += 1;
@@ -281,23 +178,23 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
for i in 0..n {
for j in 0..d {
centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
centroids[y[i]][j] += data.get(i, j);
}
}
for i in 0..parameters.k {
for j in 0..d {
centroids[i][j] /= size[i] as f64;
centroids[i][j] /= T::from(size[i]).unwrap();
}
}
let mut sums = vec![vec![0f64; d]; parameters.k];
let mut sums = vec![vec![T::zero(); d]; parameters.k];
for _ in 1..=parameters.max_iter {
let dist = bbd.clustering(&centroids, &mut sums, &mut size, &mut y);
for i in 0..parameters.k {
if size[i] > 0 {
for j in 0..d {
centroids[i][j] = sums[i][j] / size[i] as f64;
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
}
}
}
@@ -315,61 +212,48 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
size,
_distortion: distortion,
centroids,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
})
}
/// Predict clusters for `x`
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
let mut result = Y::zeros(n);
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, m) = x.shape();
let mut result = M::zeros(1, n);
let mut row = vec![0f64; x.shape().1];
let mut row = vec![T::zero(); m];
for i in 0..n {
let mut min_dist = f64::MAX;
let mut min_dist = T::max_value();
let mut best_cluster = 0;
for j in 0..self.k {
x.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x.to_f64().unwrap());
x.copy_row_as_vec(i, &mut row);
let dist = Euclidian::squared_distance(&row, &self.centroids[j]);
if dist < min_dist {
min_dist = dist;
best_cluster = j;
}
}
result.set(i, TY::from_usize(best_cluster).unwrap());
result.set(0, i, T::from(best_cluster).unwrap());
}
Ok(result)
Ok(result.to_row_vector())
}
fn kmeans_plus_plus(data: &X, k: usize, seed: Option<u64>) -> Vec<usize> {
let mut rng = get_rng_impl(seed);
let (n, _) = data.shape();
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
let mut rng = rand::thread_rng();
let (n, m) = data.shape();
let mut y = vec![0; n];
let mut centroid: Vec<TX> = data
.get_row(rng.gen_range(0..n))
.iterator(0)
.cloned()
.collect();
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
let mut d = vec![f64::MAX; n];
let mut row = vec![TX::zero(); data.shape().1];
let mut d = vec![T::max_value(); n];
let mut row = vec![T::zero(); m];
for j in 1..k {
for i in 0..n {
data.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
data.copy_row_as_vec(i, &mut row);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
@@ -378,12 +262,12 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
}
}
let mut sum = 0f64;
let mut sum: T = T::zero();
for i in d.iter() {
sum += *i;
}
let cutoff = rng.gen::<f64>() * sum;
let mut cost = 0f64;
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
let mut cost = T::zero();
let mut index = 0;
while index < n {
cost += d[index];
@@ -393,14 +277,11 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
index += 1;
}
centroid = data.get_row(index).iterator(0).cloned().collect();
data.copy_row_as_vec(index, &mut centroid);
}
for i in 0..n {
data.get_row(i)
.iterator(0)
.zip(row.iter_mut())
.for_each(|(&x, r)| *r = x);
data.copy_row_as_vec(i, &mut row);
let dist = Euclidian::squared_distance(&row, &centroid);
if dist < d[i] {
@@ -416,61 +297,25 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn invalid_k() {
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
&x,
KMeansParameters::default().with_k(0)
)
.is_err());
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
assert_eq!(
"Fit failed: invalid number of clusters: 1",
KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
&x,
KMeansParameters::default().with_k(1)
)
.unwrap_err()
.to_string()
KMeans::fit(&x, KMeansParameters::default().with_k(1))
.unwrap_err()
.to_string()
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = KMeansSearchParameters {
k: vec![2, 4],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.k, 2);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.k, 4);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict() {
fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
@@ -492,22 +337,18 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
]);
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
let y: Vec<usize> = kmeans.predict(&x).unwrap();
let y = kmeans.predict(&x).unwrap();
for (i, _y_i) in y.iter().enumerate() {
assert_eq!({ y[i] }, kmeans._y[i]);
for i in 0..y.len() {
assert_eq!(y[i] as usize, kmeans._y[i]);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -532,13 +373,11 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
]);
let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
KMeans::fit(&x, Default::default()).unwrap();
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
let deserialized_kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
let deserialized_kmeans: KMeans<f64> =
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
assert_eq!(kmeans, deserialized_kmeans);
-2
View File
@@ -1,10 +1,8 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! # Clustering
//!
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
pub mod agglomerative;
pub mod dbscan;
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
pub mod kmeans;
+2 -5
View File
@@ -31,7 +31,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("boston.xy"))
{
Err(why) => panic!("Can't deserialize boston.xy. {why}"),
Err(why) => panic!("Can't deserialize boston.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
@@ -69,10 +69,7 @@ mod tests {
assert!(serialize_data(&dataset, "boston.xy").is_ok());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn boston_dataset() {
let dataset = load_dataset();
+14 -21
View File
@@ -30,16 +30,11 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, u32> {
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("breast_cancer.xy")) {
Err(why) => panic!("Can't deserialize breast_cancer.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
Err(why) => panic!("Can't deserialize breast_cancer.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
Dataset {
@@ -71,22 +66,20 @@ pub fn load_dataset() -> Dataset<f32, u32> {
#[cfg(test)]
mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
use super::*;
// TODO: implement serialization
// #[test]
// #[ignore]
// #[cfg(not(target_arch = "wasm32"))]
// fn refresh_cancer_dataset() {
// // run this test to generate breast_cancer.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
// }
#[test]
#[ignore]
#[cfg(not(target_arch = "wasm32"))]
fn refresh_cancer_dataset() {
// run this test to generate breast_cancer.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cancer_dataset() {
let dataset = load_dataset();
+15 -22
View File
@@ -23,16 +23,11 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, u32> {
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) =
match deserialize_data(std::include_bytes!("diabetes.xy")) {
Err(why) => panic!("Can't deserialize diabetes.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
Err(why) => panic!("Can't deserialize diabetes.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
Dataset {
@@ -40,7 +35,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
target: y,
num_samples,
num_features,
feature_names: [
feature_names: vec![
"Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
]
.iter()
@@ -55,22 +50,20 @@ pub fn load_dataset() -> Dataset<f32, u32> {
#[cfg(test)]
mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
use super::*;
// TODO: fix serialization
// #[cfg(not(target_arch = "wasm32"))]
// #[test]
// #[ignore]
// fn refresh_diabetes_dataset() {
// // run this test to generate diabetes.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
// }
#[cfg(not(target_arch = "wasm32"))]
#[test]
#[ignore]
fn refresh_diabetes_dataset() {
// run this test to generate diabetes.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn boston_dataset() {
let dataset = load_dataset();
+8 -9
View File
@@ -1,4 +1,4 @@
//! # Optical Recognition of Handwritten Digits Dataset
//! # Optical Recognition of Handwritten Digits Data Set
//!
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
//! |-|-|-|-|
@@ -16,7 +16,7 @@ use crate::dataset::Dataset;
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("digits.xy"))
{
Err(why) => panic!("Can't deserialize digits.xy. {why}"),
Err(why) => panic!("Can't deserialize digits.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
@@ -25,14 +25,16 @@ pub fn load_dataset() -> Dataset<f32, f32> {
target: y,
num_samples,
num_features,
feature_names: ["sepal length (cm)",
feature_names: vec![
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
"petal width (cm)"]
"petal width (cm)",
]
.iter()
.map(|s| s.to_string())
.collect(),
target_names: ["setosa", "versicolor", "virginica"]
target_names: vec!["setosa", "versicolor", "virginica"]
.iter()
.map(|s| s.to_string())
.collect(),
@@ -55,10 +57,7 @@ mod tests {
let dataset = load_dataset();
assert!(serialize_data(&dataset, "digits.xy").is_ok());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn digits_dataset() {
let dataset = load_dataset();
+7 -16
View File
@@ -48,7 +48,7 @@ pub fn make_blobs(
}
/// Make a large circle containing a smaller circle in 2d.
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, u32> {
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, f32> {
if !(0.0..1.0).contains(&factor) {
panic!("'factor' has to be between 0 and 1.");
}
@@ -79,7 +79,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
Dataset {
data: x,
target: y.into_iter().map(|x| x as u32).collect(),
target: y,
num_samples,
num_features: 2,
feature_names: (0..2).map(|n| n.to_string()).collect(),
@@ -89,7 +89,7 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
}
/// Make two interleaving half circles in 2d
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, u32> {
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
let num_samples_out = num_samples / 2;
let num_samples_in = num_samples - num_samples_out;
@@ -116,7 +116,7 @@ pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, u32> {
Dataset {
data: x,
target: y.into_iter().map(|x| x as u32).collect(),
target: y,
num_samples,
num_features: 2,
feature_names: (0..2).map(|n| n.to_string()).collect(),
@@ -137,10 +137,7 @@ mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_make_blobs() {
let dataset = make_blobs(10, 2, 3);
@@ -153,10 +150,7 @@ mod tests {
assert_eq!(dataset.num_samples, 10);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_make_circles() {
let dataset = make_circles(10, 0.5, 0.05);
@@ -169,10 +163,7 @@ mod tests {
assert_eq!(dataset.num_samples, 10);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_make_moons() {
let dataset = make_moons(10, 0.05);
+19 -29
View File
@@ -1,4 +1,4 @@
//! # The Iris flower dataset
//! # The Iris Dataset flower
//!
//! | Number of Instances | Number of Attributes | Missing Values? | Associated Tasks: |
//! |-|-|-|-|
@@ -19,24 +19,18 @@ use crate::dataset::deserialize_data;
use crate::dataset::Dataset;
/// Get dataset
pub fn load_dataset() -> Dataset<f32, u32> {
let (x, y, num_samples, num_features): (Vec<f32>, Vec<u32>, usize, usize) =
match deserialize_data(std::include_bytes!("iris.xy")) {
Err(why) => panic!("Can't deserialize iris.xy. {why}"),
Ok((x, y, num_samples, num_features)) => (
x,
y.into_iter().map(|x| x as u32).collect(),
num_samples,
num_features,
),
};
pub fn load_dataset() -> Dataset<f32, f32> {
let (x, y, num_samples, num_features) = match deserialize_data(std::include_bytes!("iris.xy")) {
Err(why) => panic!("Can't deserialize iris.xy. {}", why),
Ok((x, y, num_samples, num_features)) => (x, y, num_samples, num_features),
};
Dataset {
data: x,
target: y,
num_samples,
num_features,
feature_names: [
feature_names: vec![
"sepal length (cm)",
"sepal width (cm)",
"petal length (cm)",
@@ -45,7 +39,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
.iter()
.map(|s| s.to_string())
.collect(),
target_names: ["setosa", "versicolor", "virginica"]
target_names: vec!["setosa", "versicolor", "virginica"]
.iter()
.map(|s| s.to_string())
.collect(),
@@ -56,24 +50,20 @@ pub fn load_dataset() -> Dataset<f32, u32> {
#[cfg(test)]
mod tests {
// #[cfg(not(target_arch = "wasm32"))]
// use super::super::*;
#[cfg(not(target_arch = "wasm32"))]
use super::super::*;
use super::*;
// TODO: fix serialization
// #[cfg(not(target_arch = "wasm32"))]
// #[test]
// #[ignore]
// fn refresh_iris_dataset() {
// // run this test to generate iris.xy file.
// let dataset = load_dataset();
// assert!(serialize_data(&dataset, "iris.xy").is_ok());
// }
#[cfg(not(target_arch = "wasm32"))]
#[test]
#[ignore]
fn refresh_iris_dataset() {
// run this test to generate iris.xy file.
let dataset = load_dataset();
assert!(serialize_data(&dataset, "iris.xy").is_ok());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn iris_dataset() {
let dataset = load_dataset();
+5 -9
View File
@@ -1,7 +1,6 @@
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
//! Datasets
//!
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
//! In this module you will find small datasets that are used in SmartCore for demonstration purpose mostly.
pub mod boston;
pub mod breast_cancer;
pub mod diabetes;
@@ -10,7 +9,7 @@ pub mod generator;
pub mod iris;
#[cfg(not(target_arch = "wasm32"))]
use crate::numbers::{basenum::Number, realnum::RealNumber};
use crate::math::num::RealNumber;
#[cfg(not(target_arch = "wasm32"))]
use std::fs::File;
use std::io;
@@ -56,7 +55,7 @@ impl<X, Y> Dataset<X, Y> {
// Running this in wasm throws: operation not supported on this platform.
#[cfg(not(target_arch = "wasm32"))]
#[allow(dead_code)]
pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
dataset: &Dataset<X, Y>,
filename: &str,
) -> Result<(), io::Error> {
@@ -79,7 +78,7 @@ pub(crate) fn serialize_data<X: Number + RealNumber, Y: RealNumber>(
.collect();
file.write_all(&y)?;
}
Err(why) => panic!("couldn't create {filename}: {why}"),
Err(why) => panic!("couldn't create {}: {}", filename, why),
}
Ok(())
}
@@ -122,10 +121,7 @@ pub(crate) fn deserialize_data(
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn as_matrix() {
let dataset = Dataset {
+94 -240
View File
@@ -10,7 +10,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::decomposition::pca::*;
//!
//! // Iris data
@@ -35,7 +35,7 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! ]);
//!
//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
//!
@@ -52,33 +52,24 @@ use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// Principal components analysis algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct PCA<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
eigenvectors: X,
pub struct PCA<T: RealNumber, M: Matrix<T>> {
eigenvectors: M,
eigenvalues: Vec<T>,
projection: X,
projection: M,
mu: Vec<T>,
pmu: Vec<T>,
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
for PCA<T, X>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
fn eq(&self, other: &Self) -> bool {
if self.eigenvalues.len() != other.eigenvalues.len()
|| self
.eigenvectors
.iterator(0)
.zip(other.eigenvectors.iterator(0))
.any(|(&a, &b)| (a - b).abs() > T::epsilon())
if self.eigenvectors != other.eigenvectors
|| self.eigenvalues.len() != other.eigenvalues.len()
{
false
} else {
@@ -92,14 +83,11 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// PCA parameters
pub struct PCAParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// By default, covariance matrix is used to compute principal components.
/// Enable this flag if you want to use correlation matrix instead.
pub use_correlation_matrix: bool,
@@ -128,124 +116,40 @@ impl Default for PCAParameters {
}
}
/// PCA grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct PCASearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// By default, covariance matrix is used to compute principal components.
/// Enable this flag if you want to use correlation matrix instead.
pub use_correlation_matrix: Vec<bool>,
}
/// PCA grid search iterator
pub struct PCASearchParametersIterator {
pca_search_parameters: PCASearchParameters,
current_k: usize,
current_use_correlation_matrix: usize,
}
impl IntoIterator for PCASearchParameters {
type Item = PCAParameters;
type IntoIter = PCASearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
PCASearchParametersIterator {
pca_search_parameters: self,
current_k: 0,
current_use_correlation_matrix: 0,
}
}
}
impl Iterator for PCASearchParametersIterator {
type Item = PCAParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_k == self.pca_search_parameters.n_components.len()
&& self.current_use_correlation_matrix
== self.pca_search_parameters.use_correlation_matrix.len()
{
return None;
}
let next = PCAParameters {
n_components: self.pca_search_parameters.n_components[self.current_k],
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
[self.current_use_correlation_matrix],
};
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
self.current_k += 1;
} else if self.current_use_correlation_matrix + 1
< self.pca_search_parameters.use_correlation_matrix.len()
{
self.current_k = 0;
self.current_use_correlation_matrix += 1;
} else {
self.current_k += 1;
self.current_use_correlation_matrix += 1;
}
Some(next)
}
}
impl Default for PCASearchParameters {
fn default() -> Self {
let default_params = PCAParameters::default();
PCASearchParameters {
n_components: vec![default_params.n_components],
use_correlation_matrix: vec![default_params.use_correlation_matrix],
}
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
UnsupervisedEstimator<X, PCAParameters> for PCA<T, X>
{
fn fit(x: &X, parameters: PCAParameters) -> Result<Self, Failed> {
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
PCA::fit(x, parameters)
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
for PCA<T, X>
{
fn transform(&self, x: &X) -> Result<X, Failed> {
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for PCA<T, M> {
fn transform(&self, x: &M) -> Result<M, Failed> {
self.transform(x)
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PCA<T, X> {
impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
/// Fits PCA to your data.
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `n_components` - number of components to keep.
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(data: &X, parameters: PCAParameters) -> Result<PCA<T, X>, Failed> {
pub fn fit(data: &M, parameters: PCAParameters) -> Result<PCA<T, M>, Failed> {
let (m, n) = data.shape();
if parameters.n_components > n {
return Err(Failed::fit(&format!(
"Number of components, n_components should be <= number of attributes ({n})"
"Number of components, n_components should be <= number of attributes ({})",
n
)));
}
let mu: Vec<T> = data
.mean_by(0)
.iter()
.map(|&v| T::from_f64(v).unwrap())
.collect();
let mu = data.column_mean();
let mut x = data.clone();
for (c, &mu_c) in mu.iter().enumerate().take(n) {
for (c, mu_c) in mu.iter().enumerate().take(n) {
for r in 0..m {
x.sub_element_mut((r, c), mu_c);
x.sub_element_mut(r, c, *mu_c);
}
}
@@ -261,33 +165,33 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
eigenvectors = svd.V;
} else {
let mut cov = X::zeros(n, n);
let mut cov = M::zeros(n, n);
for k in 0..m {
for i in 0..n {
for j in 0..=i {
cov.add_element_mut((i, j), *x.get((k, i)) * *x.get((k, j)));
cov.add_element_mut(i, j, x.get(k, i) * x.get(k, j));
}
}
}
for i in 0..n {
for j in 0..=i {
cov.div_element_mut((i, j), T::from(m).unwrap());
cov.set((j, i), *cov.get((i, j)));
cov.div_element_mut(i, j, T::from(m).unwrap());
cov.set(j, i, cov.get(i, j));
}
}
if parameters.use_correlation_matrix {
let mut sd = vec![T::zero(); n];
for (i, sd_i) in sd.iter_mut().enumerate().take(n) {
*sd_i = cov.get((i, i)).sqrt();
*sd_i = cov.get(i, i).sqrt();
}
for i in 0..n {
for j in 0..=i {
cov.div_element_mut((i, j), sd[i] * sd[j]);
cov.set((j, i), *cov.get((i, j)));
cov.div_element_mut(i, j, sd[i] * sd[j]);
cov.set(j, i, cov.get(i, j));
}
}
@@ -299,7 +203,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
for (i, sd_i) in sd.iter().enumerate().take(n) {
for j in 0..n {
eigenvectors.div_element_mut((i, j), *sd_i);
eigenvectors.div_element_mut(i, j, *sd_i);
}
}
} else {
@@ -311,17 +215,17 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
}
}
let mut projection = X::zeros(parameters.n_components, n);
let mut projection = M::zeros(parameters.n_components, n);
for i in 0..n {
for j in 0..parameters.n_components {
projection.set((j, i), *eigenvectors.get((i, j)));
projection.set(j, i, eigenvectors.get(i, j));
}
}
let mut pmu = vec![T::zero(); parameters.n_components];
for (k, mu_k) in mu.iter().enumerate().take(n) {
for (i, pmu_i) in pmu.iter_mut().enumerate().take(parameters.n_components) {
*pmu_i += *projection.get((i, k)) * (*mu_k);
*pmu_i += projection.get(i, k) * (*mu_k);
}
}
@@ -336,7 +240,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
/// Run dimensionality reduction for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn transform(&self, x: &X) -> Result<X, Failed> {
pub fn transform(&self, x: &M) -> Result<M, Failed> {
let (nrows, ncols) = x.shape();
let (_, n_components) = self.projection.shape();
if ncols != self.mu.len() {
@@ -350,14 +254,14 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
let mut x_transformed = x.matmul(&self.projection);
for r in 0..nrows {
for c in 0..n_components {
x_transformed.sub_element_mut((r, c), self.pmu[c]);
x_transformed.sub_element_mut(r, c, self.pmu[c]);
}
}
Ok(x_transformed)
}
/// Get a projection matrix
pub fn components(&self) -> &X {
pub fn components(&self) -> &M {
&self.projection
}
}
@@ -365,30 +269,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[test]
fn search_parameters() {
let parameters = PCASearchParameters {
n_components: vec![2, 4],
use_correlation_matrix: vec![true, false],
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert!(next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 2);
assert!(!next.use_correlation_matrix);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 4);
assert!(!next.use_correlation_matrix);
assert!(iter.next().is_none());
}
use crate::linalg::naive::dense_matrix::*;
fn us_arrests_data() -> DenseMatrix<f64> {
DenseMatrix::from_2d_array(&[
@@ -443,12 +324,8 @@ mod tests {
&[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6],
])
.unwrap()
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn pca_components() {
let us_arrests = us_arrests_data();
@@ -458,21 +335,13 @@ mod tests {
&[0.9952, 0.0588],
&[0.0463, 0.9769],
&[0.0752, 0.2007],
])
.unwrap();
]);
let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
assert!(relative_eq!(
expected,
pca.components().abs(),
epsilon = 1e-3
));
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_covariance() {
let us_arrests = us_arrests_data();
@@ -502,8 +371,7 @@ mod tests {
-0.974080592182491,
0.0723250196376097,
],
])
.unwrap();
]);
let expected_projection = DenseMatrix::from_2d_array(&[
&[-64.8022, -11.448, 2.4949, -2.4079],
@@ -556,8 +424,7 @@ mod tests {
&[91.5446, -22.9529, 0.402, -0.7369],
&[118.1763, 5.5076, 2.7113, -0.205],
&[10.4345, -5.9245, 3.7944, 0.5179],
])
.unwrap();
]);
let expected_eigenvalues: Vec<f64> = vec![
343544.6277001563,
@@ -568,29 +435,23 @@ mod tests {
let pca = PCA::fit(&us_arrests, PCAParameters::default().with_n_components(4)).unwrap();
assert!(relative_eq!(
pca.eigenvectors.abs(),
&expected_eigenvectors.abs(),
epsilon = 1e-4
));
assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}
let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(relative_eq!(
us_arrests_t.abs(),
&expected_projection.abs(),
epsilon = 1e-4
));
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_correlation() {
let us_arrests = us_arrests_data();
@@ -620,8 +481,7 @@ mod tests {
-0.0881962972508558,
-0.0096011588898465,
],
])
.unwrap();
]);
let expected_projection = DenseMatrix::from_2d_array(&[
&[0.9856, -1.1334, 0.4443, -0.1563],
@@ -674,8 +534,7 @@ mod tests {
&[-2.1086, -1.4248, -0.1048, -0.1319],
&[-2.0797, 0.6113, 0.1389, -0.1841],
&[-0.6294, -0.321, 0.2407, 0.1667],
])
.unwrap();
]);
let expected_eigenvalues: Vec<f64> = vec![
2.480241579149493,
@@ -692,59 +551,54 @@ mod tests {
)
.unwrap();
assert!(relative_eq!(
pca.eigenvectors.abs(),
&expected_eigenvectors.abs(),
epsilon = 1e-4
));
assert!(pca
.eigenvectors
.abs()
.approximate_eq(&expected_eigenvectors.abs(), 1e-4));
for (i, pca_eigenvalues_i) in pca.eigenvalues.iter().enumerate() {
assert!((pca_eigenvalues_i.abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
for i in 0..pca.eigenvalues.len() {
assert!((pca.eigenvalues[i].abs() - expected_eigenvalues[i].abs()).abs() < 1e-8);
}
let us_arrests_t = pca.transform(&us_arrests).unwrap();
assert!(relative_eq!(
us_arrests_t.abs(),
&expected_projection.abs(),
epsilon = 1e-4
));
assert!(us_arrests_t
.abs()
.approximate_eq(&expected_projection.abs(), 1e-4));
}
// Disable this test for now
// TODO: implement deserialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn pca_serde() {
// let iris = DenseMatrix::from_2d_array(&[
// &[5.1, 3.5, 1.4, 0.2],
// &[4.9, 3.0, 1.4, 0.2],
// &[4.7, 3.2, 1.3, 0.2],
// &[4.6, 3.1, 1.5, 0.2],
// &[5.0, 3.6, 1.4, 0.2],
// &[5.4, 3.9, 1.7, 0.4],
// &[4.6, 3.4, 1.4, 0.3],
// &[5.0, 3.4, 1.5, 0.2],
// &[4.4, 2.9, 1.4, 0.2],
// &[4.9, 3.1, 1.5, 0.1],
// &[7.0, 3.2, 4.7, 1.4],
// &[6.4, 3.2, 4.5, 1.5],
// &[6.9, 3.1, 4.9, 1.5],
// &[5.5, 2.3, 4.0, 1.3],
// &[6.5, 2.8, 4.6, 1.5],
// &[5.7, 2.8, 4.5, 1.3],
// &[6.3, 3.3, 4.7, 1.6],
// &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap();
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
// let pca = PCA::fit(&iris, Default::default()).unwrap();
let pca = PCA::fit(&iris, Default::default()).unwrap();
// let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
// serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
// assert_eq!(pca, deserialized_pca);
// }
assert_eq!(pca, deserialized_pca);
}
}
+59 -149
View File
@@ -7,7 +7,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::decomposition::svd::*;
//!
//! // Iris data
@@ -32,7 +32,7 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! ]);
//!
//! let svd = SVD::fit(&iris, SVDParameters::default().
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
@@ -51,36 +51,27 @@ use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// SVD
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct SVD<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> {
components: X,
pub struct SVD<T: RealNumber, M: Matrix<T>> {
components: M,
phantom: PhantomData<T>,
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> PartialEq
for SVD<T, X>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
fn eq(&self, other: &Self) -> bool {
self.components
.iterator(0)
.zip(other.components.iterator(0))
.all(|(&a, &b)| (a - b).abs() <= T::epsilon())
.approximate_eq(&other.components, T::from_f64(1e-8).unwrap())
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVD parameters
pub struct SVDParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Number of components to keep.
pub n_components: usize,
}
@@ -99,94 +90,36 @@ impl SVDParameters {
}
}
/// SVD grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SVDSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Maximum number of iterations of the k-means algorithm for a single run.
pub n_components: Vec<usize>,
}
/// SVD grid search iterator
pub struct SVDSearchParametersIterator {
svd_search_parameters: SVDSearchParameters,
current_n_components: usize,
}
impl IntoIterator for SVDSearchParameters {
type Item = SVDParameters;
type IntoIter = SVDSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
SVDSearchParametersIterator {
svd_search_parameters: self,
current_n_components: 0,
}
}
}
impl Iterator for SVDSearchParametersIterator {
type Item = SVDParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_n_components == self.svd_search_parameters.n_components.len() {
return None;
}
let next = SVDParameters {
n_components: self.svd_search_parameters.n_components[self.current_n_components],
};
self.current_n_components += 1;
Some(next)
}
}
impl Default for SVDSearchParameters {
fn default() -> Self {
let default_params = SVDParameters::default();
SVDSearchParameters {
n_components: vec![default_params.n_components],
}
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>>
UnsupervisedEstimator<X, SVDParameters> for SVD<T, X>
{
fn fit(x: &X, parameters: SVDParameters) -> Result<Self, Failed> {
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
SVD::fit(x, parameters)
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> Transformer<X>
for SVD<T, X>
{
fn transform(&self, x: &X) -> Result<X, Failed> {
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for SVD<T, M> {
fn transform(&self, x: &M) -> Result<M, Failed> {
self.transform(x)
}
}
impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable<T>> SVD<T, X> {
impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
/// Fits SVD to your data.
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `n_components` - number of components to keep.
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(x: &X, parameters: SVDParameters) -> Result<SVD<T, X>, Failed> {
pub fn fit(x: &M, parameters: SVDParameters) -> Result<SVD<T, M>, Failed> {
let (_, p) = x.shape();
if parameters.n_components >= p {
return Err(Failed::fit(&format!(
"Number of components, n_components should be < number of attributes ({p})"
"Number of components, n_components should be < number of attributes ({})",
p
)));
}
let svd = x.svd()?;
let components = X::from_slice(svd.V.slice(0..p, 0..parameters.n_components).as_ref());
let components = svd.V.slice(0..p, 0..parameters.n_components);
Ok(SVD {
components,
@@ -196,12 +129,13 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
/// Run dimensionality reduction for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn transform(&self, x: &X) -> Result<X, Failed> {
pub fn transform(&self, x: &M) -> Result<M, Failed> {
let (n, p) = x.shape();
let (p_c, k) = self.components.shape();
if p_c != p {
return Err(Failed::transform(&format!(
"Can not transform a {n}x{p} matrix into {n}x{k} matrix, incorrect input dimentions"
"Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions",
n, p, n, k
)));
}
@@ -209,7 +143,7 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
}
/// Get a projection matrix
pub fn components(&self) -> &X {
pub fn components(&self) -> &M {
&self.components
}
}
@@ -217,27 +151,9 @@ impl<T: Number + RealNumber, X: Array2<T> + SVDDecomposable<T> + EVDDecomposable
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
let parameters = SVDSearchParameters {
n_components: vec![10, 100],
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_components, 10);
let next = iter.next().unwrap();
assert_eq!(next.n_components, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn svd_decompose() {
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
@@ -292,8 +208,7 @@ mod tests {
&[5.7, 81.0, 39.0, 9.3],
&[2.6, 53.0, 66.0, 10.8],
&[6.8, 161.0, 60.0, 15.6],
])
.unwrap();
]);
let expected = DenseMatrix::from_2d_array(&[
&[243.54655757, -18.76673788],
@@ -301,55 +216,50 @@ mod tests {
&[305.93972467, -15.39087376],
&[197.28420365, -11.66808306],
&[293.43187394, 1.91163633],
])
.unwrap();
]);
let svd = SVD::fit(&x, Default::default()).unwrap();
let x_transformed = svd.transform(&x).unwrap();
assert_eq!(svd.components.shape(), (x.shape().1, 2));
assert!(relative_eq!(
DenseMatrix::from_slice(x_transformed.slice(0..5, 0..2).as_ref()),
&expected,
epsilon = 1e-4
));
assert!(x_transformed
.slice(0..5, 0..2)
.approximate_eq(&expected, 1e-4));
}
// Disable this test for now
// TODO: implement deserialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let iris = DenseMatrix::from_2d_array(&[
// &[5.1, 3.5, 1.4, 0.2],
// &[4.9, 3.0, 1.4, 0.2],
// &[4.7, 3.2, 1.3, 0.2],
// &[4.6, 3.1, 1.5, 0.2],
// &[5.0, 3.6, 1.4, 0.2],
// &[5.4, 3.9, 1.7, 0.4],
// &[4.6, 3.4, 1.4, 0.3],
// &[5.0, 3.4, 1.5, 0.2],
// &[4.4, 2.9, 1.4, 0.2],
// &[4.9, 3.1, 1.5, 0.1],
// &[7.0, 3.2, 4.7, 1.4],
// &[6.4, 3.2, 4.5, 1.5],
// &[6.9, 3.1, 4.9, 1.5],
// &[5.5, 2.3, 4.0, 1.3],
// &[6.5, 2.8, 4.6, 1.5],
// &[5.7, 2.8, 4.5, 1.3],
// &[6.3, 3.3, 4.7, 1.6],
// &[4.9, 2.4, 3.3, 1.0],
// &[6.6, 2.9, 4.6, 1.3],
// &[5.2, 2.7, 3.9, 1.4],
// ]).unwrap();
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
// let svd = SVD::fit(&iris, Default::default()).unwrap();
let svd = SVD::fit(&iris, Default::default()).unwrap();
// let deserialized_svd: SVD<f32, DenseMatrix<f32>> =
// serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
let deserialized_svd: SVD<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
// assert_eq!(svd, deserialized_svd);
// }
assert_eq!(svd, deserialized_svd);
}
}
-214
View File
@@ -1,214 +0,0 @@
use rand::Rng;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct BaseForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
#[cfg_attr(feature = "serde", serde(default))]
pub bootstrap: bool,
#[cfg_attr(feature = "serde", serde(default))]
pub splitter: Splitter,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for BaseForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
false
} else {
self.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
}
}
}
/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
samples: Option<Vec<Vec<bool>>>,
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseForestRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseForestRegressorParameters,
) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
let (n_rows, num_attributes) = x.shape();
if n_rows != y.shape() {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
let mut samples: Vec<usize> = (0..n_rows).map(|_| 1).collect();
for _ in 0..parameters.n_trees {
if parameters.bootstrap {
samples =
BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
}
// keep samples is flag is on
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
splitter: parameters.splitter.clone(),
};
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
trees.push(tree);
}
Ok(BaseForestRegressor {
trees: Some(trees),
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
}
fn predict_for_row(&self, x: &X, row: usize) -> TY {
let n_trees = self.trees.as_ref().unwrap().len();
let mut result = TY::zero();
for tree in self.trees.as_ref().unwrap().iter() {
result += tree.predict_for_row(x, row);
}
result / TY::from_usize(n_trees).unwrap()
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
for i in 0..n {
result.set(i, self.predict_for_row_oob(x, i));
}
Ok(result)
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
let mut n_trees = 0;
let mut result = TY::zero();
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / TY::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
}
}
-318
View File
@@ -1,318 +0,0 @@
//! # Extra Trees Regressor
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
//!
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
//! reduce the variance of the model and often make the training process faster.
//!
//! The two key differences from a standard Random Forest are:
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
//!
//! See [ensemble models](../index.html) for more details.
//!
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::ensemble::extra_trees_regressor::*;
//!
//! // Longley dataset ([https://www.statsmodels.org/stable/datasets/generated/longley.html](https://www.statsmodels.org/stable/datasets/generated/longley.html))
//! let x = DenseMatrix::from_2d_array(&[
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
//! ];
//!
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::default::Default;
use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor
/// Some parameters here are passed directly into base estimator.
pub struct ExtraTreesRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
impl ExtraTreesRegressorParameters {
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
self.max_depth = Some(max_depth);
self
}
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
self.min_samples_leaf = min_samples_leaf;
self
}
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
self.min_samples_split = min_samples_split;
self
}
/// The number of trees in the forest.
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
self.n_trees = n_trees;
self
}
/// Number of random sample of predictors to use as split candidates.
pub fn with_m(mut self, m: usize) -> Self {
self.m = Some(m);
self
}
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
/// Seed used for bootstrap sampling and feature selection for each tree.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
}
impl Default for ExtraTreesRegressorParameters {
fn default() -> Self {
ExtraTreesRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
m: Option::None,
keep_samples: false,
seed: 0,
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn new() -> Self {
Self {
forest_regressor: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
ExtraTreesRegressor::fit(x, y, parameters)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
self.predict(x)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ExtraTreesRegressor<TX, TY, X, Y>
{
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
parameters: ExtraTreesRegressorParameters,
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: false,
splitter: Splitter::Random,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
Ok(ExtraTreesRegressor {
forest_regressor: Some(forest_regressor),
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::metrics::mean_squared_error;
#[test]
fn test_extra_trees_regressor_fit_predict() {
// Use a simpler, more predictable dataset for unit testing.
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
&[13., 14.],
&[15., 16.],
])
.unwrap();
let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
// A basic check to ensure the model is learning something.
// The error should be significantly less than the variance of y.
let mse = mean_squared_error(&y, &y_hat);
// With this simple dataset, the error should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_fit_predict_higher_dims() {
// Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
let x = DenseMatrix::from_2d_array(&[
// The 3rd column is the important one. The rest are noise.
&[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
&[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
&[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
&[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
&[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
&[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
])
.unwrap();
let y = vec![10., 20., 30., 40., 55., 65.];
let parameters = ExtraTreesRegressorParameters::default()
.with_n_trees(100)
.with_seed(42);
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
let y_hat = regressor.predict(&x).unwrap();
assert_eq!(y_hat.len(), y.len());
let mse = mean_squared_error(&y, &y_hat);
// The model should be able to learn this simple relationship perfectly,
// ignoring the noise features. The MSE should be very low.
assert!(mse < 1.0);
}
#[test]
fn test_reproducibility() {
let x = DenseMatrix::from_2d_array(&[
&[1., 2.],
&[3., 4.],
&[5., 6.],
&[7., 8.],
&[9., 10.],
&[11., 12.],
])
.unwrap();
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let params = ExtraTreesRegressorParameters::default().with_seed(42);
let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat1 = regressor1.predict(&x).unwrap();
let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
let y_hat2 = regressor2.predict(&x).unwrap();
assert_eq!(y_hat1, y_hat2);
}
}
+1 -3
View File
@@ -7,7 +7,7 @@
//! set and then aggregate their individual predictions to form a final prediction. In classification setting the overall prediction is the most commonly
//! occurring majority class among the individual predictions.
//!
//! In `smartcore` you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
//! In SmartCore you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
//! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
//! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
//! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
@@ -16,8 +16,6 @@
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
mod base_forest_regressor;
pub mod extra_trees_regressor;
/// Random forest classifier
pub mod random_forest_classifier;
/// Random forest regressor
+198 -416
View File
@@ -8,7 +8,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
//!
//! // Iris dataset
@@ -33,10 +33,10 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! ]);
//! let y = vec![
//! 0, 0, 0, 0, 0, 0, 0, 0,
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let classifier = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
@@ -45,8 +45,8 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::Rng;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::default::Default;
use std::fmt::Debug;
@@ -55,11 +55,9 @@ use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::rand_custom::get_rng_impl;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::{BaseMatrix, Matrix};
use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};
@@ -69,28 +67,20 @@ use crate::tree::decision_tree_classifier::{
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: SplitCriterion,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: u16,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
@@ -98,14 +88,10 @@ pub struct RandomForestClassifierParameters {
/// Random Forest Classifier
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestClassifier<
TX: Number + FloatNumber + PartialOrd,
TY: Number + Ord,
X: Array2<TX>,
Y: Array1<TY>,
> {
trees: Option<Vec<DecisionTreeClassifier<TX, TY, X, Y>>>,
classes: Option<Vec<TY>>,
pub struct RandomForestClassifier<T: RealNumber> {
_parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<T>,
samples: Option<Vec<Vec<bool>>>,
}
@@ -154,24 +140,22 @@ impl RandomForestClassifierParameters {
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
PartialEq for RandomForestClassifier<TX, TY, X, Y>
{
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
fn eq(&self, other: &Self) -> bool {
if self.classes.as_ref().unwrap().len() != other.classes.as_ref().unwrap().len()
|| self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len()
{
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
false
} else {
self.classes
.iter()
.zip(other.classes.iter())
.all(|(a, b)| a == b)
&& self
.trees
.iter()
.zip(other.trees.iter())
.all(|(a, b)| a == b)
for i in 0..self.classes.len() {
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
return false;
}
}
for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] {
return false;
}
}
true
}
}
}
@@ -180,7 +164,7 @@ impl Default for RandomForestClassifierParameters {
fn default() -> Self {
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -191,302 +175,65 @@ impl Default for RandomForestClassifierParameters {
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, RandomForestClassifierParameters>
for RandomForestClassifier<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>>
SupervisedEstimator<M, M::RowVector, RandomForestClassifierParameters>
for RandomForestClassifier<T>
{
fn new() -> Self {
Self {
trees: Option::None,
classes: Option::None,
samples: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: RandomForestClassifierParameters) -> Result<Self, Failed> {
fn fit(
x: &M,
y: &M::RowVector,
parameters: RandomForestClassifierParameters,
) -> Result<Self, Failed> {
RandomForestClassifier::fit(x, y, parameters)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for RandomForestClassifier<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
/// RandomForestClassifier grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: Vec<SplitCriterion>,
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: Vec<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Vec<Option<usize>>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: Vec<u64>,
}
/// RandomForestClassifier grid search iterator
pub struct RandomForestClassifierSearchParametersIterator {
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
current_criterion: usize,
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_n_trees: usize,
current_m: usize,
current_keep_samples: usize,
current_seed: usize,
}
impl IntoIterator for RandomForestClassifierSearchParameters {
type Item = RandomForestClassifierParameters;
type IntoIter = RandomForestClassifierSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
RandomForestClassifierSearchParametersIterator {
random_forest_classifier_search_parameters: self,
current_criterion: 0,
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_n_trees: 0,
current_m: 0,
current_keep_samples: 0,
current_seed: 0,
}
}
}
impl Iterator for RandomForestClassifierSearchParametersIterator {
type Item = RandomForestClassifierParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_criterion
== self
.random_forest_classifier_search_parameters
.criterion
.len()
&& self.current_max_depth
== self
.random_forest_classifier_search_parameters
.max_depth
.len()
&& self.current_min_samples_leaf
== self
.random_forest_classifier_search_parameters
.min_samples_leaf
.len()
&& self.current_min_samples_split
== self
.random_forest_classifier_search_parameters
.min_samples_split
.len()
&& self.current_n_trees
== self
.random_forest_classifier_search_parameters
.n_trees
.len()
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
&& self.current_keep_samples
== self
.random_forest_classifier_search_parameters
.keep_samples
.len()
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
{
return None;
}
let next = RandomForestClassifierParameters {
criterion: self.random_forest_classifier_search_parameters.criterion
[self.current_criterion]
.clone(),
max_depth: self.random_forest_classifier_search_parameters.max_depth
[self.current_max_depth],
min_samples_leaf: self
.random_forest_classifier_search_parameters
.min_samples_leaf[self.current_min_samples_leaf],
min_samples_split: self
.random_forest_classifier_search_parameters
.min_samples_split[self.current_min_samples_split],
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
m: self.random_forest_classifier_search_parameters.m[self.current_m],
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
[self.current_keep_samples],
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
};
if self.current_criterion + 1
< self
.random_forest_classifier_search_parameters
.criterion
.len()
{
self.current_criterion += 1;
} else if self.current_max_depth + 1
< self
.random_forest_classifier_search_parameters
.max_depth
.len()
{
self.current_criterion = 0;
self.current_max_depth += 1;
} else if self.current_min_samples_leaf + 1
< self
.random_forest_classifier_search_parameters
.min_samples_leaf
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf += 1;
} else if self.current_min_samples_split + 1
< self
.random_forest_classifier_search_parameters
.min_samples_split
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_n_trees + 1
< self
.random_forest_classifier_search_parameters
.n_trees
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees += 1;
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m += 1;
} else if self.current_keep_samples + 1
< self
.random_forest_classifier_search_parameters
.keep_samples
.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples += 1;
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
{
self.current_criterion = 0;
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples = 0;
self.current_seed += 1;
} else {
self.current_criterion += 1;
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_n_trees += 1;
self.current_m += 1;
self.current_keep_samples += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for RandomForestClassifierSearchParameters {
fn default() -> Self {
let default_params = RandomForestClassifierParameters::default();
RandomForestClassifierSearchParameters {
criterion: vec![default_params.criterion],
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
RandomForestClassifier<TX, TY, X, Y>
{
impl<T: RealNumber> RandomForestClassifier<T> {
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
parameters: RandomForestClassifierParameters,
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
let y_ncols = y.shape();
if x_nrows != y_ncols {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
) -> Result<RandomForestClassifier<T>, Failed> {
let (_, num_attributes) = x.shape();
let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape();
let mut yi: Vec<usize> = vec![0; y_ncols];
let classes = y.unique();
let classes = y_m.unique();
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
let yc = y.get(i);
*yi_i = classes.iter().position(|c| yc == c).unwrap();
let yc = y_m.get(0, i);
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
}
let mtry = parameters
.m
.unwrap_or_else(|| ((num_attributes as f64).sqrt().floor()) as usize);
let mtry = parameters.m.unwrap_or_else(|| {
(T::from(num_attributes).unwrap())
.sqrt()
.floor()
.to_usize()
.unwrap()
});
let mut rng = get_rng_impl(Some(parameters.seed));
let classes = y.unique();
let mut rng = StdRng::seed_from_u64(parameters.seed);
let classes = y_m.unique();
let k = classes.len();
// TODO: use with_capacity here
let mut trees: Vec<DecisionTreeClassifier<TX, TY, X, Y>> = Vec::new();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
// TODO: use with_capacity here
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples: Vec<usize> =
RandomForestClassifier::<TX, TY, X, Y>::sample_with_replacement(&yi, k, &mut rng);
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
@@ -496,40 +243,38 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
};
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
let tree =
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree);
}
Ok(RandomForestClassifier {
trees: Some(trees),
classes: Some(classes),
_parameters: parameters,
trees,
classes,
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0);
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row(x, i)],
);
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
}
Ok(result)
Ok(result.to_row_vector())
}
fn predict_for_row(&self, x: &X, row: usize) -> usize {
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()];
for tree in self.trees.as_ref().unwrap().iter() {
for tree in self.trees.iter() {
result[tree.predict_for_row(x, row)] += 1;
}
@@ -537,7 +282,7 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
@@ -550,28 +295,20 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = Y::zeros(n);
let mut result = M::zeros(1, n);
for i in 0..n {
result.set(
i,
self.classes.as_ref().unwrap()[self.predict_for_row_oob(x, i)],
);
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
}
Ok(result)
Ok(result.to_row_vector())
}
}
fn predict_for_row_oob(&self, x: &X, row: usize) -> usize {
let mut result = vec![0; self.classes.as_ref().unwrap().len()];
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()];
for (tree, samples) in self
.trees
.as_ref()
.unwrap()
.iter()
.zip(self.samples.as_ref().unwrap())
{
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
if !samples[row] {
result[tree.predict_for_row(x, row)] += 1;
}
@@ -580,6 +317,37 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
which_max(&result)
}
/// Predict the per-class probabilties for each observation.
/// The probability is calculated as the fraction of trees that predicted a given class
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
let (n, _) = x.shape();
for i in 0..n {
let row_probs = self.predict_probs_for_row(x, i);
for (j, item) in row_probs.iter().enumerate() {
result.set(i, j, *item);
}
}
Ok(result)
}
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
let mut result = vec![0; self.classes.len()];
for tree in self.trees.iter() {
result[tree.predict_for_row(x, row)] += 1;
}
result
.iter()
.map(|n| *n as f64 / self.trees.len() as f64)
.collect()
}
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
let class_weight = vec![1.; num_classes];
let nrows = y.len();
@@ -605,40 +373,14 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
}
#[cfg(test)]
mod tests {
mod tests_prob {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = RandomForestClassifierSearchParameters {
n_trees: vec![10, 100],
m: vec![None, Some(1)],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, Some(1));
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, Some(1));
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn fit_predict() {
fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
@@ -660,16 +402,17 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -683,34 +426,7 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
}
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let fail = RandomForestClassifier::fit(
&x_rand,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
);
assert!(fail.is_err());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_iris_oob() {
let x = DenseMatrix::from_2d_array(&[
@@ -734,16 +450,17 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
@@ -760,10 +477,7 @@ mod tests {
);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -788,15 +502,83 @@ mod tests {
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
])
.unwrap();
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestClassifier<f64, i64, DenseMatrix<f64>, Vec<i64>> =
let deserialized_forest: RandomForestClassifier<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
)
.unwrap();
println!("{:?}", classifier.classes);
let results = classifier.predict_probs(&x).unwrap();
println!("{:?}", x.shape());
println!("{:?}", results);
println!("{:?}", results.shape());
assert_eq!(
results,
DenseMatrix::<f64>::from_array(
20,
2,
&[
1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
]
)
);
assert!(false);
}
}
+149 -333
View File
@@ -8,7 +8,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::ensemble::random_forest_regressor::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -29,7 +29,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! ]);
//! let y = vec![
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
@@ -43,6 +43,8 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::default::Default;
use std::fmt::Debug;
@@ -50,37 +52,30 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::tree::base_tree_regressor::Splitter;
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Random Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct RandomForestRegressorParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub max_depth: Option<u16>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_leaf: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
pub min_samples_split: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
}
@@ -88,13 +83,10 @@ pub struct RandomForestRegressorParameters {
/// Random Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestRegressor<
TX: Number + FloatNumber + PartialOrd,
TY: Number,
X: Array2<TX>,
Y: Array1<TY>,
> {
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
pub struct RandomForestRegressor<T: RealNumber> {
_parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor<T>>,
samples: Option<Vec<Vec<bool>>>,
}
impl RandomForestRegressorParameters {
@@ -139,7 +131,7 @@ impl RandomForestRegressorParameters {
impl Default for RandomForestRegressorParameters {
fn default() -> Self {
RandomForestRegressorParameters {
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 10,
@@ -150,305 +142,167 @@ impl Default for RandomForestRegressorParameters {
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for RandomForestRegressor<TX, TY, X, Y>
{
impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
fn eq(&self, other: &Self) -> bool {
self.forest_regressor == other.forest_regressor
if self.trees.len() != other.trees.len() {
false
} else {
for i in 0..self.trees.len() {
if self.trees[i] != other.trees[i] {
return false;
}
}
true
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, RandomForestRegressorParameters>
for RandomForestRegressor<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>>
SupervisedEstimator<M, M::RowVector, RandomForestRegressorParameters>
for RandomForestRegressor<T>
{
fn new() -> Self {
Self {
forest_regressor: Option::None,
}
}
fn fit(x: &X, y: &Y, parameters: RandomForestRegressorParameters) -> Result<Self, Failed> {
fn fit(
x: &M,
y: &M::RowVector,
parameters: RandomForestRegressorParameters,
) -> Result<Self, Failed> {
RandomForestRegressor::fit(x, y, parameters)
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Predictor<X, Y> for RandomForestRegressor<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
/// RandomForestRegressor grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestRegressorSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub max_depth: Vec<Option<u16>>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_leaf: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub min_samples_split: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// The number of trees in the forest.
pub n_trees: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// Number of random sample of predictors to use as split candidates.
pub m: Vec<Option<usize>>,
#[cfg_attr(feature = "serde", serde(default))]
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: Vec<u64>,
}
/// RandomForestRegressor grid search iterator
pub struct RandomForestRegressorSearchParametersIterator {
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
current_max_depth: usize,
current_min_samples_leaf: usize,
current_min_samples_split: usize,
current_n_trees: usize,
current_m: usize,
current_keep_samples: usize,
current_seed: usize,
}
impl IntoIterator for RandomForestRegressorSearchParameters {
type Item = RandomForestRegressorParameters;
type IntoIter = RandomForestRegressorSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
RandomForestRegressorSearchParametersIterator {
random_forest_regressor_search_parameters: self,
current_max_depth: 0,
current_min_samples_leaf: 0,
current_min_samples_split: 0,
current_n_trees: 0,
current_m: 0,
current_keep_samples: 0,
current_seed: 0,
}
}
}
impl Iterator for RandomForestRegressorSearchParametersIterator {
type Item = RandomForestRegressorParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_max_depth
== self
.random_forest_regressor_search_parameters
.max_depth
.len()
&& self.current_min_samples_leaf
== self
.random_forest_regressor_search_parameters
.min_samples_leaf
.len()
&& self.current_min_samples_split
== self
.random_forest_regressor_search_parameters
.min_samples_split
.len()
&& self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
&& self.current_m == self.random_forest_regressor_search_parameters.m.len()
&& self.current_keep_samples
== self
.random_forest_regressor_search_parameters
.keep_samples
.len()
&& self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
{
return None;
}
let next = RandomForestRegressorParameters {
max_depth: self.random_forest_regressor_search_parameters.max_depth
[self.current_max_depth],
min_samples_leaf: self
.random_forest_regressor_search_parameters
.min_samples_leaf[self.current_min_samples_leaf],
min_samples_split: self
.random_forest_regressor_search_parameters
.min_samples_split[self.current_min_samples_split],
n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
m: self.random_forest_regressor_search_parameters.m[self.current_m],
keep_samples: self.random_forest_regressor_search_parameters.keep_samples
[self.current_keep_samples],
seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
};
if self.current_max_depth + 1
< self
.random_forest_regressor_search_parameters
.max_depth
.len()
{
self.current_max_depth += 1;
} else if self.current_min_samples_leaf + 1
< self
.random_forest_regressor_search_parameters
.min_samples_leaf
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf += 1;
} else if self.current_min_samples_split + 1
< self
.random_forest_regressor_search_parameters
.min_samples_split
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split += 1;
} else if self.current_n_trees + 1
< self.random_forest_regressor_search_parameters.n_trees.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees += 1;
} else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m += 1;
} else if self.current_keep_samples + 1
< self
.random_forest_regressor_search_parameters
.keep_samples
.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples += 1;
} else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
{
self.current_max_depth = 0;
self.current_min_samples_leaf = 0;
self.current_min_samples_split = 0;
self.current_n_trees = 0;
self.current_m = 0;
self.current_keep_samples = 0;
self.current_seed += 1;
} else {
self.current_max_depth += 1;
self.current_min_samples_leaf += 1;
self.current_min_samples_split += 1;
self.current_n_trees += 1;
self.current_m += 1;
self.current_keep_samples += 1;
self.current_seed += 1;
}
Some(next)
}
}
impl Default for RandomForestRegressorSearchParameters {
fn default() -> Self {
let default_params = RandomForestRegressorParameters::default();
RandomForestRegressorSearchParameters {
max_depth: vec![default_params.max_depth],
min_samples_leaf: vec![default_params.min_samples_leaf],
min_samples_split: vec![default_params.min_samples_split],
n_trees: vec![default_params.n_trees],
m: vec![default_params.m],
keep_samples: vec![default_params.keep_samples],
seed: vec![default_params.seed],
}
}
}
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
RandomForestRegressor<TX, TY, X, Y>
{
impl<T: RealNumber> RandomForestRegressor<T> {
/// Build a forest of trees from the training set.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target class values
pub fn fit(
x: &X,
y: &Y,
pub fn fit<M: Matrix<T>>(
x: &M,
y: &M::RowVector,
parameters: RandomForestRegressorParameters,
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
let regressor_params = BaseForestRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
n_trees: parameters.n_trees,
m: parameters.m,
keep_samples: parameters.keep_samples,
seed: parameters.seed,
bootstrap: true,
splitter: Splitter::Best,
};
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
) -> Result<RandomForestRegressor<T>, Failed> {
let (n_rows, num_attributes) = x.shape();
let mtry = parameters
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees {
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
};
let tree =
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree);
}
Ok(RandomForestRegressor {
forest_regressor: Some(forest_regressor),
_parameters: parameters,
trees,
samples: maybe_all_samples,
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict(x)
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let mut result = M::zeros(1, x.shape().0);
let (n, _) = x.shape();
for i in 0..n {
result.set(0, i, self.predict_for_row(x, i));
}
Ok(result.to_row_vector())
}
fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let n_trees = self.trees.len();
let mut result = T::zero();
for tree in self.trees.iter() {
result += tree.predict_for_row(x, row);
}
result / T::from(n_trees).unwrap()
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
let forest_regressor = self.forest_regressor.as_ref().unwrap();
forest_regressor.predict_oob(x)
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = M::zeros(1, n);
for i in 0..n {
result.set(0, i, self.predict_for_row_oob(x, i));
}
Ok(result.to_row_vector())
}
}
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let mut n_trees = 0;
let mut result = T::zero();
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / T::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows];
for _ in 0..nrows {
let xi = rng.gen_range(0..nrows);
samples[xi] += 1;
}
samples
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = RandomForestRegressorSearchParameters {
n_trees: vec![10, 100],
m: vec![None, Some(1)],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, None);
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 10);
assert_eq!(next.m, Some(1));
let next = iter.next().unwrap();
assert_eq!(next.n_trees, 100);
assert_eq!(next.m, Some(1));
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_longley() {
let x = DenseMatrix::from_2d_array(&[
@@ -468,8 +322,7 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -479,7 +332,7 @@ mod tests {
&x,
&y,
RandomForestRegressorParameters {
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
@@ -494,36 +347,7 @@ mod tests {
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
}
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let fail = RandomForestRegressor::fit(
&x_rand,
&y,
RandomForestRegressorParameters {
max_depth: Option::None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
m: Option::None,
keep_samples: false,
seed: 87,
},
);
assert!(fail.is_err());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_longley_oob() {
let x = DenseMatrix::from_2d_array(&[
@@ -543,8 +367,7 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -554,7 +377,7 @@ mod tests {
&x,
&y,
RandomForestRegressorParameters {
max_depth: Option::None,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
@@ -568,16 +391,10 @@ mod tests {
let y_hat = regressor.predict(&x).unwrap();
let y_hat_oob = regressor.predict_oob(&x).unwrap();
println!("{:?}", mean_absolute_error(&y, &y_hat));
println!("{:?}", mean_absolute_error(&y, &y_hat_oob));
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
@@ -598,8 +415,7 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
@@ -607,7 +423,7 @@ mod tests {
let forest = RandomForestRegressor::fit(&x, &y, Default::default()).unwrap();
let deserialized_forest: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
let deserialized_forest: RandomForestRegressor<f64> =
bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
assert_eq!(forest, deserialized_forest);
+1 -23
View File
@@ -30,10 +30,6 @@ pub enum FailedError {
DecompositionFailed,
/// Can't solve for x
SolutionFailed,
/// Error in input parameters
ParametersError,
/// Invalid state error (should never happen)
InvalidStateError,
}
impl Failed {
@@ -66,22 +62,6 @@ impl Failed {
}
}
/// new instance of `FailedError::ParametersError`
pub fn input(msg: &str) -> Self {
Failed {
err: FailedError::ParametersError,
msg: msg.to_string(),
}
}
/// new instance of `FailedError::InvalidStateError`
pub fn invalid_state(msg: &str) -> Self {
Failed {
err: FailedError::InvalidStateError,
msg: msg.to_string(),
}
}
/// new instance of `err`
pub fn because(err: FailedError, msg: &str) -> Self {
Failed {
@@ -114,10 +94,8 @@ impl fmt::Display for FailedError {
FailedError::FindFailed => "Find failed",
FailedError::DecompositionFailed => "Decomposition failed",
FailedError::SolutionFailed => "Can't find solution",
FailedError::ParametersError => "Error in input, check parameters",
FailedError::InvalidStateError => "Invalid state, this should never happen", // useful in development phase of lib
};
write!(f, "{failed_err_str}")
write!(f, "{}", failed_err_str)
}
}
+44 -78
View File
@@ -3,81 +3,32 @@
clippy::too_many_arguments,
clippy::many_single_char_names,
clippy::unnecessary_wraps,
clippy::upper_case_acronyms,
clippy::approx_constant
clippy::upper_case_acronyms
)]
#![warn(missing_docs)]
#![warn(rustdoc::missing_doc_code_examples)]
//! # smartcore
//! # SmartCore
//!
//! Welcome to `smartcore`, machine learning in Rust!
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
//!
//! `smartcore` features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! as well as tools for model selection and model evaluation.
//!
//! `smartcore` provides its own traits system that extends Rust standard library, to deal with linear algebra and common
//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
//! structures) is available via optional features.
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
//! * [ndarray](https://docs.rs/ndarray)
//! * [nalgebra](https://docs.rs/nalgebra/)
//!
//! ## Getting Started
//!
//! To start using `smartcore` latest stable version simply add the following to your `Cargo.toml` file:
//! To start using SmartCore simply add the following to your Cargo.toml file:
//! ```ignore
//! [dependencies]
//! smartcore = "*"
//! smartcore = "0.2.0"
//! ```
//!
//! To start using smartcore development version with latest unstable additions:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
//! ```
//!
//! There are different features that can be added to the base library, for example to add sample datasets:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", features = ["datasets"] }
//! ```
//! Check `smartcore`'s `Cargo.toml` for available features.
//!
//! ## Using Jupyter
//! For quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/).
//!
//!
//! ## First Example
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//!
//! ```
//! // DenseMatrix definition
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! // KNNClassifier
//! use smartcore::neighbors::knn_classifier::*;
//! // Various distance metrics
//! use smartcore::metrics::distance::*;
//!
//! // Turn Rust vector-slices with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]).unwrap();
//! // Our classes are defined as a vector
//! let y = vec![2, 2, 2, 3, 3];
//!
//! // Train classifier
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
//!
//! // Predict classes
//! let y_hat = knn.predict(&x).unwrap();
//! ```
//!
//! ## Overview
//!
//! ### Supported algorithms
//! All machine learning algorithms are grouped into these broad categories:
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
@@ -87,16 +38,37 @@
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
//! * [SVM](svm/index.html), support vector machines
//!
//! ### Linear Algebra traits system
//! For an introduction to `smartcore`'s traits system see [this notebook](https://github.com/smartcorelib/smartcore-jupyter/blob/5523993c53c6ec1fd72eea130ef4e7883121c1ea/notebooks/01-A-little-bit-about-numbers.ipynb)
//!
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
//!
//! ```
//! // DenseMatrix defenition
//! use smartcore::linalg::naive::dense_matrix::*;
//! // KNNClassifier
//! use smartcore::neighbors::knn_classifier::*;
//! // Various distance metrics
//! use smartcore::math::distance::*;
//!
//! // Turn Rust vectors with samples into a matrix
//! let x = DenseMatrix::from_2d_array(&[
//! &[1., 2.],
//! &[3., 4.],
//! &[5., 6.],
//! &[7., 8.],
//! &[9., 10.]]);
//! // Our classes are defined as a Vector
//! let y = vec![2., 2., 2., 3., 3.];
//!
//! // Train classifier
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
//!
//! // Predict classes
//! let y_hat = knn.predict(&x).unwrap();
//! ```
/// Foundamental numbers traits
pub mod numbers;
/// Various algorithms and helper methods that are used elsewhere in smartcore
/// Various algorithms and helper methods that are used elsewhere in SmartCore
pub mod algorithm;
pub mod api;
/// Algorithms for clustering of unlabeled data
pub mod cluster;
/// Various datasets
@@ -107,29 +79,23 @@ pub mod decomposition;
/// Ensemble methods, including Random Forest classifier and regressor
pub mod ensemble;
pub mod error;
/// Diverse collection of linear algebra abstractions and methods that power smartcore algorithms
/// Diverse collection of linear algebra abstractions and methods that power SmartCore algorithms
pub mod linalg;
/// Supervised classification and regression models that assume linear relationship between dependent and explanatory variables.
pub mod linear;
/// Helper methods and classes, including definitions of distance metrics
pub mod math;
/// Functions for assessing prediction error.
pub mod metrics;
/// TODO: add docstring for model_selection
pub mod model_selection;
/// Supervised learning algorithms based on applying the Bayes theorem with the independence assumptions between predictors
pub mod naive_bayes;
/// Supervised neighbors-based learning methods
pub mod neighbors;
/// Optimization procedures
pub mod optimization;
pub(crate) mod optimization;
/// Preprocessing utilities
pub mod preprocessing;
/// Reading in data from serialized formats
#[cfg(feature = "serde")]
pub mod readers;
/// Support Vector Machines
pub mod svm;
/// Supervised tree-based learning methods
pub mod tree;
pub mod xgboost;
pub(crate) mod rand_custom;
File diff suppressed because it is too large Load Diff
-845
View File
@@ -1,845 +0,0 @@
use std::fmt;
use std::fmt::{Debug, Display};
use std::ops::Range;
use std::slice::Iter;
use approx::{AbsDiffEq, RelativeEq};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::{
Array, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::lu::LUDecomposable;
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::stats::{MatrixPreprocessing, MatrixStats};
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::error::Failed;
/// Dense matrix
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DenseMatrix<T> {
ncols: usize,
nrows: usize,
values: Vec<T>,
column_major: bool,
}
/// View on dense matrix
#[derive(Debug, Clone)]
pub struct DenseMatrixView<'a, T: Debug + Display + Copy + Sized> {
values: &'a [T],
stride: usize,
nrows: usize,
ncols: usize,
column_major: bool,
}
/// Mutable view on dense matrix
#[derive(Debug)]
pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
values: &'a mut [T],
stride: usize,
nrows: usize,
ncols: usize,
column_major: bool,
}
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
fn new(
m: &'a DenseMatrix<T>,
vrows: Range<usize>,
vcols: Range<usize>,
) -> Result<Self, Failed> {
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
Err(Failed::input(
"The specified view is outside of the matrix range",
))
} else {
let (start, end, stride) =
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
Ok(DenseMatrixView {
values: &m.values[start..end],
stride,
nrows: vrows.end - vrows.start,
ncols: vcols.end - vcols.start,
column_major: m.column_major,
})
}
}
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
fn new(
m: &'a mut DenseMatrix<T>,
vrows: Range<usize>,
vcols: Range<usize>,
) -> Result<Self, Failed> {
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
Err(Failed::input(
"The specified view is outside of the matrix range",
))
} else {
let (start, end, stride) =
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
Ok(DenseMatrixMutView {
values: &mut m.values[start..end],
stride,
nrows: vrows.end - vrows.start,
ncols: vcols.end - vcols.start,
column_major: m.column_major,
})
}
}
fn iter<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let column_major = self.column_major;
let stride = self.stride;
let ptr = self.values.as_mut_ptr();
match axis {
0 => Box::new((0..self.nrows).flat_map(move |r| {
(0..self.ncols).map(move |c| unsafe {
&mut *ptr.add(if column_major {
r + c * stride
} else {
r * stride + c
})
})
})),
_ => Box::new((0..self.ncols).flat_map(move |c| {
(0..self.nrows).map(move |r| unsafe {
&mut *ptr.add(if column_major {
r + c * stride
} else {
r * stride + c
})
})
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
/// Create new instance of `DenseMatrix` without copying data.
/// `values` should be in column-major order.
pub fn new(
nrows: usize,
ncols: usize,
values: Vec<T>,
column_major: bool,
) -> Result<Self, Failed> {
let data_len = values.len();
if nrows * ncols != values.len() {
Err(Failed::input(&format!(
"The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
)))
} else {
Ok(DenseMatrix {
ncols,
nrows,
values,
column_major,
})
}
}
/// New instance of `DenseMatrix` from 2d array.
pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
}
/// New instance of `DenseMatrix` from 2d vector.
#[allow(clippy::ptr_arg)]
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
if values.is_empty() || values[0].is_empty() {
Err(Failed::input(
"The 2d vec provided is empty; cannot instantiate the matrix",
))
} else {
let nrows = values.len();
let ncols = values
.first()
.unwrap_or_else(|| {
panic!("Invalid state: Cannot create 2d matrix from an empty vector")
})
.len();
let mut m_values = Vec::with_capacity(nrows * ncols);
for c in 0..ncols {
for r in values.iter().take(nrows) {
m_values.push(r[c])
}
}
DenseMatrix::new(nrows, ncols, m_values, true)
}
}
/// Iterate over values of matrix
pub fn iter(&self) -> Iter<'_, T> {
self.values.iter()
}
/// Check if the size of the requested view is bounded to matrix rows/cols count
fn is_valid_view(
&self,
n_rows: usize,
n_cols: usize,
vrows: &Range<usize>,
vcols: &Range<usize>,
) -> bool {
!(vrows.end <= n_rows
&& vcols.end <= n_cols
&& vrows.start <= n_rows
&& vcols.start <= n_cols)
}
/// Compute the range of the requested view: start, end, size of the slice
fn stride_range(
&self,
n_rows: usize,
n_cols: usize,
vrows: &Range<usize>,
vcols: &Range<usize>,
column_major: bool,
) -> (usize, usize, usize) {
let (start, end, stride) = if column_major {
(
vrows.start + vcols.start * n_rows,
vrows.end + (vcols.end - 1) * n_rows,
n_rows,
)
} else {
(
vrows.start * n_cols + vcols.start,
(vrows.end - 1) * n_cols + vcols.end,
n_cols,
)
};
(start, end, stride)
}
}
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"DenseMatrix: nrows: {:?}, ncols: {:?}",
self.nrows, self.ncols
)?;
writeln!(f, "column_major: {:?}", self.column_major)?;
self.display(f)
}
}
impl<T: Debug + Display + Copy + Sized + PartialEq> PartialEq for DenseMatrix<T> {
fn eq(&self, other: &Self) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
return false;
}
let len = self.values.len();
let other_len = other.values.len();
if len != other_len {
return false;
}
match self.column_major == other.column_major {
true => self
.values
.iter()
.zip(other.values.iter())
.all(|(&v1, v2)| v1.eq(v2)),
false => self
.iterator(0)
.zip(other.iterator(0))
.all(|(&v1, v2)| v1.eq(v2)),
}
}
}
impl<T: Number + RealNumber + AbsDiffEq> AbsDiffEq for DenseMatrix<T>
where
T::Epsilon: Copy,
{
type Epsilon = T::Epsilon;
fn default_epsilon() -> T::Epsilon {
T::default_epsilon()
}
// equality in differences in absolute values, according to an epsilon
fn abs_diff_eq(&self, other: &Self, epsilon: T::Epsilon) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
false
} else {
self.values
.iter()
.zip(other.values.iter())
.all(|(v1, v2)| T::abs_diff_eq(v1, v2, epsilon))
}
}
}
impl<T: Number + RealNumber + RelativeEq> RelativeEq for DenseMatrix<T>
where
T::Epsilon: Copy,
{
fn default_max_relative() -> T::Epsilon {
T::default_max_relative()
}
fn relative_eq(&self, other: &Self, epsilon: T::Epsilon, max_relative: T::Epsilon) -> bool {
if self.ncols != other.ncols || self.nrows != other.nrows {
false
} else {
self.iterator(0)
.zip(other.iterator(0))
.all(|(v1, v2)| T::relative_eq(v1, v2, epsilon, max_relative))
}
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
fn get(&self, pos: (usize, usize)) -> &T {
let (row, col) = pos;
if row >= self.nrows || col >= self.ncols {
panic!(
"Invalid index ({},{}) for {}x{} matrix",
row, col, self.nrows, self.ncols
);
}
if self.column_major {
&self.values[col * self.nrows + row]
} else {
&self.values[col + self.ncols * row]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.ncols < 1 || self.nrows < 1
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(
(0..self.nrows).flat_map(move |r| (0..self.ncols).map(move |c| self.get((r, c)))),
),
_ => Box::new(
(0..self.ncols).flat_map(move |c| (0..self.nrows).map(move |r| self.get((r, c)))),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrix<T> {
fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major {
self.values[pos.1 * self.nrows + pos.0] = x;
} else {
self.values[pos.1 + pos.0 * self.ncols] = x;
}
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.values.as_mut_ptr();
let column_major = self.column_major;
let (nrows, ncols) = self.shape();
match axis {
0 => Box::new((0..self.nrows).flat_map(move |r| {
(0..self.ncols).map(move |c| unsafe {
&mut *ptr.add(if column_major {
r + c * nrows
} else {
r * ncols + c
})
})
})),
_ => Box::new((0..self.ncols).flat_map(move |c| {
(0..self.nrows).map(move |r| unsafe {
&mut *ptr.add(if column_major {
r + c * nrows
} else {
r * ncols + c
})
})
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
}
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
}
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
}
fn slice_mut<'a>(
&'a mut self,
rows: Range<usize>,
cols: Range<usize>,
) -> Box<dyn MutArrayView2<T> + 'a>
where
Self: Sized,
{
Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
}
// private function so for now assume infalible
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
}
// private function so for now assume infalible
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
}
fn transpose(&self) -> Self {
let mut m = self.clone();
m.ncols = self.nrows;
m.nrows = self.ncols;
m.column_major = !self.column_major;
m
}
}
impl<T: Number + RealNumber> QRDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> CholeskyDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
} else {
&self.values[pos.0 * self.stride + pos.1]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
fn get(&self, i: usize) -> &T {
if self.nrows == 1 {
if self.column_major {
&self.values[i * self.stride]
} else {
&self.values[i]
}
} else if self.ncols == 1 || (!self.column_major && self.nrows == 1) {
if self.column_major {
&self.values[i]
} else {
&self.values[i * self.stride]
}
} else {
panic!("This is neither a column nor a row");
}
}
fn shape(&self) -> usize {
if self.nrows == 1 {
self.ncols
} else if self.ncols == 1 {
self.nrows
} else {
panic!("This is neither a column nor a row");
}
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn get(&self, pos: (usize, usize)) -> &T {
if self.column_major {
&self.values[pos.0 + pos.1 * self.stride]
} else {
&self.values[pos.0 * self.stride + pos.1]
}
}
fn shape(&self) -> (usize, usize) {
(self.nrows, self.ncols)
}
fn is_empty(&self) -> bool {
self.nrows * self.ncols > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
self.iter(axis)
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
fn set(&mut self, pos: (usize, usize), x: T) {
if self.column_major {
self.values[pos.0 + pos.1 * self.stride] = x;
} else {
self.values[pos.0 * self.stride + pos.1] = x;
}
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
self.iter_mut(axis)
}
}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
#[cfg(test)]
#[warn(clippy::reversed_empty_ranges)]
mod tests {
use super::*;
use approx::relative_eq;
#[test]
fn test_instantiate_from_2d() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
assert!(x.is_ok());
}
#[test]
fn test_instantiate_from_2d_empty() {
let input: &[&[f64]] = &[&[]];
let x = DenseMatrix::from_2d_array(input);
assert!(x.is_err());
}
#[test]
fn test_instantiate_from_2d_empty2() {
let input: &[&[f64]] = &[&[], &[]];
let x = DenseMatrix::from_2d_array(input);
assert!(x.is_err());
}
#[test]
fn test_instantiate_ok_view1() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..2, 0..2);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view2() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view3() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 2..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_ok_view4() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 3..3, 0..3);
assert!(v.is_ok());
}
#[test]
fn test_instantiate_err_view1() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 3..4, 0..3);
assert!(v.is_err());
}
#[test]
fn test_instantiate_err_view2() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let v = DenseMatrixView::new(&x, 0..3, 3..4);
assert!(v.is_err());
}
#[test]
fn test_instantiate_err_view3() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
#[allow(clippy::reversed_empty_ranges)]
let v = DenseMatrixView::new(&x, 0..3, 4..3);
assert!(v.is_err());
}
#[test]
fn test_display() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
println!("{}", &x);
}
#[test]
fn test_get_row_col() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
assert_eq!(15.0, x.get_col(1).sum());
assert_eq!(15.0, x.get_row(1).sum());
assert_eq!(81.0, x.get_col(1).dot(&(*x.get_row(1))));
}
#[test]
fn test_row_major() {
let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();
assert_eq!(5, *x.get_col(1).get(1));
assert_eq!(7, x.get_col(1).sum());
assert_eq!(5, *x.get_row(1).get(1));
assert_eq!(15, x.get_row(1).sum());
x.slice_mut(0..2, 1..2)
.iterator_mut(0)
.for_each(|v| *v += 2);
assert_eq!(vec![1, 4, 3, 4, 7, 6], *x.values);
}
#[test]
fn test_get_slice() {
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
.unwrap();
assert_eq!(
vec![4, 5, 6],
DenseMatrix::from_slice(&(*x.slice(1..2, 0..3))).values
);
let second_row: Vec<i32> = x.slice(1..2, 0..3).iterator(0).copied().collect();
assert_eq!(vec![4, 5, 6], second_row);
let second_col: Vec<i32> = x.slice(0..3, 1..2).iterator(0).copied().collect();
assert_eq!(vec![2, 5, 8], second_col);
}
#[test]
fn test_iter_mut() {
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
// add +2 to some elements
x.slice_mut(1..2, 0..3)
.iterator_mut(0)
.for_each(|v| *v += 2);
assert_eq!(vec![1, 6, 7, 2, 7, 8, 3, 8, 9], x.values);
// add +1 to some others
x.slice_mut(0..3, 1..2)
.iterator_mut(0)
.for_each(|v| *v += 1);
assert_eq!(vec![1, 6, 7, 3, 8, 9, 3, 8, 9], x.values);
// rewrite matrix as indices of values per axis 1 (row-wise)
x.iterator_mut(1).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 2, 3, 4, 5, 6, 7, 8], x.values);
// rewrite matrix as indices of values per axis 0 (column-wise)
x.iterator_mut(0).enumerate().for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 3, 6, 1, 4, 7, 2, 5, 8], x.values);
// rewrite some by slice
x.slice_mut(0..3, 0..2)
.iterator_mut(0)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 2, 4, 1, 3, 5, 2, 5, 8], x.values);
x.slice_mut(0..2, 0..3)
.iterator_mut(1)
.enumerate()
.for_each(|(a, b)| *b = a);
assert_eq!(vec![0, 1, 4, 2, 3, 5, 4, 5, 8], x.values);
}
#[test]
fn test_str_array() {
let mut x =
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
.unwrap();
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
x.iterator_mut(0).for_each(|v| *v = "str");
assert_eq!(
vec!["str", "str", "str", "str", "str", "str", "str", "str", "str"],
x.values
);
}
#[test]
fn test_transpose() {
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert!(x.column_major);
// transpose
let x = x.transpose();
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
assert!(!x.column_major); // should change column_major
}
#[test]
fn test_from_iterator() {
let data = [1, 2, 3, 4, 5, 6];
let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);
// make a vector into a 2x3 matrix.
assert_eq!(
vec![1, 2, 3, 4, 5, 6],
m.values.iter().map(|e| **e).collect::<Vec<i32>>()
);
assert!(!m.column_major);
}
#[test]
fn test_take() {
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
println!("{a}");
// take column 0 and 2
assert_eq!(vec![1, 3, 4, 6], a.take(&[0, 2], 1).values);
println!("{b}");
// take rows 0 and 2
assert_eq!(vec![1, 2, 5, 6], b.take(&[0, 2], 0).values);
}
#[test]
fn test_mut() {
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();
let a = a.abs();
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
let a = a.neg();
assert_eq!(vec![-1.3, -4.0, -2.1, -5.3, -3.4, -6.1], a.values);
}
#[test]
fn test_reshape() {
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
.unwrap();
let a = a.reshape(2, 6, 0);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert!(a.ncols == 6 && a.nrows == 2 && !a.column_major);
let a = a.reshape(3, 4, 1);
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
assert!(a.ncols == 4 && a.nrows == 3 && a.column_major);
}
#[test]
fn test_eq() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
let c = DenseMatrix::from_2d_array(&[
&[1. + f32::EPSILON, 2., 3.],
&[4., 5., 6. + f32::EPSILON],
])
.unwrap();
let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
.unwrap();
assert!(!relative_eq!(a, b));
assert!(!relative_eq!(a, d));
assert!(relative_eq!(a, c));
}
}
-8
View File
@@ -1,8 +0,0 @@
/// `Array`, `ArrayView` and related multidimensional
pub mod arrays;
/// foundamental implementation for a `DenseMatrix` construct
pub mod matrix;
/// foundamental implementation for 1D constructs
pub mod vector;
-348
View File
@@ -1,348 +0,0 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{Array, Array1, ArrayView1, MutArray, MutArrayView1};
/// Provide mutable window on array
#[derive(Debug)]
pub struct VecMutView<'a, T: Debug + Display + Copy + Sized> {
ptr: &'a mut [T],
}
/// Provide window on array
#[derive(Debug, Clone)]
pub struct VecView<'a, T: Debug + Display + Copy + Sized> {
ptr: &'a [T],
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for &[T] {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
fn set(&mut self, i: usize, x: T) {
// NOTE: this panics in case of out of bounds index
self[i] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for Vec<T> {}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for &[T] {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for Vec<T> {}
impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
let view = VecView { ptr: &self[range] };
Box::new(view)
}
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
let view = VecMutView {
ptr: &mut self[range],
};
Box::new(view)
}
fn fill(len: usize, value: T) -> Self {
vec![value; len]
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
where
Self: Sized,
{
let mut v: Vec<T> = Vec::with_capacity(len);
iter.take(len).for_each(|i| v.push(i));
v
}
fn from_vec_slice(slice: &[T]) -> Self {
let mut v: Vec<T> = Vec::with_capacity(slice.len());
slice.iter().for_each(|i| v.push(*i));
v
}
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
let mut v: Vec<T> = Vec::with_capacity(slice.shape());
slice.iterator(0).for_each(|i| v.push(*i));
v
}
}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
fn shape(&self) -> usize {
self.ptr.len()
}
fn is_empty(&self) -> bool {
self.ptr.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
fn set(&mut self, i: usize, x: T) {
self.ptr[i] = x;
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
fn get(&self, i: usize) -> &T {
&self.ptr[i]
}
fn shape(&self) -> usize {
self.ptr.len()
}
fn is_empty(&self) -> bool {
self.ptr.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.ptr.iter())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::numbers::basenum::Number;
fn dot_product<T: Number, V: Array1<T>>(v: &V) -> T {
let vv = V::zeros(10);
let v_s = vv.slice(0..3);
v_s.dot(v)
}
fn vector_ops<T: Number + PartialOrd, V: Array1<T>>(_: &V) -> T {
let v = V::zeros(10);
v.max()
}
#[test]
fn test_get_set() {
let mut x = vec![1, 2, 3];
assert_eq!(3, *x.get(2));
x.set(1, 1);
assert_eq!(1, *x.get(1));
}
#[test]
#[should_panic]
fn test_failed_set() {
vec![1, 2, 3].set(3, 1);
}
#[test]
#[should_panic]
fn test_failed_get() {
vec![1, 2, 3].get(3);
}
#[test]
fn test_len() {
let x = [1, 2, 3];
assert_eq!(3, x.len());
}
#[test]
fn test_is_empty() {
assert!(vec![1; 0].is_empty());
assert!(!vec![1, 2, 3].is_empty());
}
#[test]
fn test_iterator() {
let v: Vec<i32> = vec![1, 2, 3].iterator(0).map(|&v| v * 2).collect();
assert_eq!(vec![2, 4, 6], v);
}
#[test]
#[should_panic]
fn test_failed_iterator() {
let _ = vec![1, 2, 3].iterator(1);
}
#[test]
fn test_mut_iterator() {
let mut x = vec![1, 2, 3];
x.iterator_mut(0).for_each(|v| *v *= 2);
assert_eq!(vec![2, 4, 6], x);
}
#[test]
#[should_panic]
fn test_failed_mut_iterator() {
let _ = vec![1, 2, 3].iterator_mut(1);
}
#[test]
fn test_slice() {
let x = vec![1, 2, 3, 4, 5];
let x_slice = x.slice(2..3);
assert_eq!(1, x_slice.shape());
assert_eq!(3, *x_slice.get(0));
}
#[test]
#[should_panic]
fn test_failed_slice() {
vec![1, 2, 3].slice(0..4);
}
#[test]
fn test_mut_slice() {
let mut x = vec![1, 2, 3, 4, 5];
let mut x_slice = x.slice_mut(2..4);
x_slice.set(0, 9);
assert_eq!(2, x_slice.shape());
assert_eq!(9, *x_slice.get(0));
assert_eq!(4, *x_slice.get(1));
}
#[test]
#[should_panic]
fn test_failed_mut_slice() {
vec![1, 2, 3].slice_mut(0..4);
}
#[test]
fn test_init() {
assert_eq!(Vec::fill(3, 0), vec![0, 0, 0]);
assert_eq!(
Vec::from_iterator([0, 1, 2, 3].iter().cloned(), 3),
vec![0, 1, 2]
);
assert_eq!(Vec::from_vec_slice(&[0, 1, 2]), vec![0, 1, 2]);
assert_eq!(Vec::from_vec_slice(&[0, 1, 2, 3, 4][2..]), vec![2, 3, 4]);
assert_eq!(Vec::from_slice(&vec![1, 2, 3, 4, 5]), vec![1, 2, 3, 4, 5]);
assert_eq!(
Vec::from_slice(vec![1, 2, 3, 4, 5].slice(0..3).as_ref()),
vec![1, 2, 3]
);
}
#[test]
fn test_mul_scalar() {
let mut x = vec![1., 2., 3.];
let mut y = Vec::<f32>::zeros(10);
y.slice_mut(0..2).add_scalar_mut(1.0);
y.sub_scalar(1.0);
x.slice_mut(0..2).sub_scalar_mut(2.);
assert_eq!(vec![-1.0, 0.0, 3.0], x);
}
#[test]
fn test_dot() {
let y_i = vec![1, 2, 3];
let y = vec![1.0, 2.0, 3.0];
println!("Regular dot1: {:?}", dot_product(&y));
let x = vec![4.0, 5.0, 6.0];
assert_eq!(32.0, y.slice(0..3).dot(&(*x.slice(0..3))));
assert_eq!(32.0, y.slice(0..3).dot(&x));
assert_eq!(32.0, y.dot(&x));
assert_eq!(14, y_i.dot(&y_i));
}
#[test]
fn test_operators() {
let mut x: Vec<f32> = Vec::zeros(10);
x.add_scalar(15.0);
{
let mut x_s = x.slice_mut(0..5);
x_s.add_scalar_mut(1.0);
assert_eq!(
vec![1.0, 1.0, 1.0, 1.0, 1.0],
x_s.iterator(0).copied().collect::<Vec<f32>>()
);
}
assert_eq!(1.0, x.slice(2..3).min());
assert_eq!(vec![1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], x);
}
#[test]
fn test_vector_ops() {
let x = vec![1., 2., 3.];
vector_ops(&x);
}
}
@@ -8,14 +8,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::cholesky::*;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use crate::smartcore::linalg::cholesky::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[25., 15., -5.],
//! &[15., 18., 0.],
//! &[-5., 0., 11.]
//! ]).unwrap();
//! ]);
//!
//! let cholesky = A.cholesky().unwrap();
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
@@ -34,18 +34,17 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{Failed, FailedError};
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
#[derive(Debug, Clone)]
/// Results of Cholesky decomposition.
pub struct Cholesky<T: Number + RealNumber, M: Array2<T>> {
pub struct Cholesky<T: RealNumber, M: BaseMatrix<T>> {
R: M,
t: PhantomData<T>,
}
impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
pub(crate) fn new(R: M) -> Cholesky<T, M> {
Cholesky { R, t: PhantomData }
}
@@ -58,7 +57,7 @@ impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
for i in 0..n {
for j in 0..n {
if j <= i {
R.set((i, j), *self.R.get((i, j)));
R.set(i, j, self.R.get(i, j));
}
}
}
@@ -73,7 +72,7 @@ impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
for i in 0..n {
for j in 0..n {
if j <= i {
R.set((j, i), *self.R.get((i, j)));
R.set(j, i, self.R.get(i, j));
}
}
}
@@ -88,25 +87,25 @@ impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
if bn != rn {
return Err(Failed::because(
FailedError::SolutionFailed,
"Can\'t solve Ax = b for x. FloatNumber of rows in b != number of rows in R.",
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
));
}
for k in 0..bn {
for j in 0..m {
for i in 0..k {
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((k, i)));
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i));
}
b.div_element_mut((k, j), *self.R.get((k, k)));
b.div_element_mut(k, j, self.R.get(k, k));
}
}
for k in (0..bn).rev() {
for j in 0..m {
for i in k + 1..bn {
b.sub_element_mut((k, j), *b.get((i, j)) * *self.R.get((i, k)));
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k));
}
b.div_element_mut((k, j), *self.R.get((k, k)));
b.div_element_mut(k, j, self.R.get(k, k));
}
}
Ok(b)
@@ -114,7 +113,7 @@ impl<T: Number + RealNumber, M: Array2<T>> Cholesky<T, M> {
}
/// Trait that implements Cholesky decomposition routine for any matrix.
pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the Cholesky decomposition of a matrix.
fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
self.clone().cholesky_mut()
@@ -137,13 +136,13 @@ pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
for k in 0..j {
let mut s = T::zero();
for i in 0..k {
s += *self.get((k, i)) * *self.get((j, i));
s += self.get(k, i) * self.get(j, i);
}
s = (*self.get((j, k)) - s) / *self.get((k, k));
self.set((j, k), s);
s = (self.get(j, k) - s) / self.get(k, k);
self.set(j, k, s);
d += s * s;
}
d = *self.get((j, j)) - d;
d = self.get(j, j) - d;
if d < T::zero() {
return Err(Failed::because(
@@ -152,7 +151,7 @@ pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
));
}
self.set((j, j), d.sqrt());
self.set(j, j, d.sqrt());
}
Ok(Cholesky::new(self))
@@ -167,50 +166,39 @@ pub trait CholeskyDecomposable<T: Number + RealNumber>: Array2<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cholesky_decompose() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let l =
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]])
.unwrap();
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
let u =
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]])
.unwrap();
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
let cholesky = a.cholesky().unwrap();
assert!(relative_eq!(cholesky.L().abs(), l.abs(), epsilon = 1e-4));
assert!(relative_eq!(cholesky.U().abs(), u.abs(), epsilon = 1e-4));
assert!(relative_eq!(
cholesky.L().matmul(&cholesky.U()).abs(),
a.abs(),
epsilon = 1e-4
));
assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
assert!(cholesky
.L()
.matmul(&cholesky.U())
.abs()
.approximate_eq(&a.abs(), 1e-4));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn cholesky_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]).unwrap();
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
let cholesky = a.cholesky().unwrap();
assert!(relative_eq!(
cholesky.solve(b.transpose()).unwrap().transpose(),
expected,
epsilon = 1e-4
));
assert!(cholesky
.solve(b.transpose())
.unwrap()
.transpose()
.approximate_eq(&expected, 1e-4));
}
}
+195 -217
View File
@@ -12,14 +12,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::evd::*;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::evd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]).unwrap();
//! ]);
//!
//! let evd = A.evd(true).unwrap();
//! let eigenvectors: DenseMatrix<f64> = evd.V;
@@ -35,15 +35,14 @@
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use num::complex::Complex;
use std::fmt::Debug;
#[derive(Debug, Clone)]
/// Results of eigen decomposition
pub struct EVD<T: Number + RealNumber, M: Array2<T>> {
pub struct EVD<T: RealNumber, M: BaseMatrix<T>> {
/// Real part of eigenvalues.
pub d: Vec<T>,
/// Imaginary part of eigenvalues.
@@ -53,7 +52,7 @@ pub struct EVD<T: Number + RealNumber, M: Array2<T>> {
}
/// Trait that implements EVD decomposition routine for any matrix.
pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the eigen decomposition of a square matrix.
/// * `symmetric` - whether the matrix is symmetric
fn evd(&self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
@@ -66,7 +65,7 @@ pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
fn evd_mut(mut self, symmetric: bool) -> Result<EVD<T, Self>, Failed> {
let (nrows, ncols) = self.shape();
if ncols != nrows {
panic!("Matrix is not square: {nrows} x {ncols}");
panic!("Matrix is not square: {} x {}", nrows, ncols);
}
let n = nrows;
@@ -94,14 +93,14 @@ pub trait EVDDecomposable<T: Number + RealNumber>: Array2<T> {
sort(&mut d, &mut e, &mut V);
}
Ok(EVD { V, d, e })
Ok(EVD { d, e, V })
}
}
fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for (i, d_i) in d.iter_mut().enumerate().take(n) {
*d_i = *V.get((n - 1, i));
*d_i = V.get(n - 1, i);
}
for i in (1..n).rev() {
@@ -113,9 +112,9 @@ fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [
if scale == T::zero() {
e[i] = d[i - 1];
for (j, d_j) in d.iter_mut().enumerate().take(i) {
*d_j = *V.get((i - 1, j));
V.set((i, j), T::zero());
V.set((j, i), T::zero());
*d_j = V.get(i - 1, j);
V.set(i, j, T::zero());
V.set(j, i, T::zero());
}
} else {
for d_k in d.iter_mut().take(i) {
@@ -136,11 +135,11 @@ fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [
for j in 0..i {
f = d[j];
V.set((j, i), f);
g = e[j] + *V.get((j, j)) * f;
V.set(j, i, f);
g = e[j] + V.get(j, j) * f;
for k in j + 1..=i - 1 {
g += *V.get((k, j)) * d[k];
e[k] += *V.get((k, j)) * f;
g += V.get(k, j) * d[k];
e[k] += V.get(k, j) * f;
}
e[j] = g;
}
@@ -157,46 +156,46 @@ fn tred2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [
f = d[j];
g = e[j];
for k in j..=i - 1 {
V.sub_element_mut((k, j), f * e[k] + g * d[k]);
V.sub_element_mut(k, j, f * e[k] + g * d[k]);
}
d[j] = *V.get((i - 1, j));
V.set((i, j), T::zero());
d[j] = V.get(i - 1, j);
V.set(i, j, T::zero());
}
}
d[i] = h;
}
for i in 0..n - 1 {
V.set((n - 1, i), *V.get((i, i)));
V.set((i, i), T::one());
V.set(n - 1, i, V.get(i, i));
V.set(i, i, T::one());
let h = d[i + 1];
if h != T::zero() {
for (k, d_k) in d.iter_mut().enumerate().take(i + 1) {
*d_k = *V.get((k, i + 1)) / h;
*d_k = V.get(k, i + 1) / h;
}
for j in 0..=i {
let mut g = T::zero();
for k in 0..=i {
g += *V.get((k, i + 1)) * *V.get((k, j));
g += V.get(k, i + 1) * V.get(k, j);
}
for (k, d_k) in d.iter().enumerate().take(i + 1) {
V.sub_element_mut((k, j), g * (*d_k));
V.sub_element_mut(k, j, g * (*d_k));
}
}
}
for k in 0..=i {
V.set((k, i + 1), T::zero());
V.set(k, i + 1, T::zero());
}
}
for (j, d_j) in d.iter_mut().enumerate().take(n) {
*d_j = *V.get((n - 1, j));
V.set((n - 1, j), T::zero());
*d_j = V.get(n - 1, j);
V.set(n - 1, j, T::zero());
}
V.set((n - 1, n - 1), T::one());
V.set(n - 1, n - 1, T::one());
e[0] = T::zero();
}
fn tql2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape();
for i in 1..n {
e[i - 1] = e[i];
@@ -265,9 +264,9 @@ fn tql2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T
d[i + 1] = h + s * (c * g + s * d[i]);
for k in 0..n {
h = *V.get((k, i + 1));
V.set((k, i + 1), s * *V.get((k, i)) + c * h);
V.set((k, i), c * *V.get((k, i)) - s * h);
h = V.get(k, i + 1);
V.set(k, i + 1, s * V.get(k, i) + c * h);
V.set(k, i, c * V.get(k, i) - s * h);
}
}
p = -s * s2 * c3 * el1 * e[l] / dl1;
@@ -296,15 +295,15 @@ fn tql2<T: Number + RealNumber, M: Array2<T>>(V: &mut M, d: &mut [T], e: &mut [T
d[k] = d[i];
d[i] = p;
for j in 0..n {
p = *V.get((j, i));
V.set((j, i), *V.get((j, k)));
V.set((j, k), p);
p = V.get(j, i);
V.set(j, i, V.get(j, k));
V.set(j, k, p);
}
}
}
}
fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
let radix = T::two();
let sqrdx = radix * radix;
@@ -322,8 +321,8 @@ fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
let mut c = T::zero();
for j in 0..n {
if j != i {
c += A.get((j, i)).abs();
r += A.get((i, j)).abs();
c += A.get(j, i).abs();
r += A.get(i, j).abs();
}
}
if c != T::zero() && r != T::zero() {
@@ -344,10 +343,10 @@ fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
g = T::one() / f;
*scale_i *= f;
for j in 0..n {
A.mul_element_mut((i, j), g);
A.mul_element_mut(i, j, g);
}
for j in 0..n {
A.mul_element_mut((j, i), f);
A.mul_element_mut(j, i, f);
}
}
}
@@ -357,7 +356,7 @@ fn balance<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<T> {
scale
}
fn elmhes<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<usize> {
fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
let (n, _) = A.shape();
let mut perm = vec![0; n];
@@ -365,31 +364,35 @@ fn elmhes<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<usize> {
let mut x = T::zero();
let mut i = m;
for j in m..n {
if A.get((j, m - 1)).abs() > x.abs() {
x = *A.get((j, m - 1));
if A.get(j, m - 1).abs() > x.abs() {
x = A.get(j, m - 1);
i = j;
}
}
*perm_m = i;
if i != m {
for j in (m - 1)..n {
A.swap((i, j), (m, j));
let swap = A.get(i, j);
A.set(i, j, A.get(m, j));
A.set(m, j, swap);
}
for j in 0..n {
A.swap((j, i), (j, m));
let swap = A.get(j, i);
A.set(j, i, A.get(j, m));
A.set(j, m, swap);
}
}
if x != T::zero() {
for i in (m + 1)..n {
let mut y = *A.get((i, m - 1));
let mut y = A.get(i, m - 1);
if y != T::zero() {
y /= x;
A.set((i, m - 1), y);
A.set(i, m - 1, y);
for j in m..n {
A.sub_element_mut((i, j), y * *A.get((m, j)));
A.sub_element_mut(i, j, y * A.get(m, j));
}
for j in 0..n {
A.add_element_mut((j, m), y * *A.get((j, i)));
A.add_element_mut(j, m, y * A.get(j, i));
}
}
}
@@ -399,24 +402,24 @@ fn elmhes<T: Number + RealNumber, M: Array2<T>>(A: &mut M) -> Vec<usize> {
perm
}
fn eltran<T: Number + RealNumber, M: Array2<T>>(A: &M, V: &mut M, perm: &[usize]) {
fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
let (n, _) = A.shape();
for mp in (1..n - 1).rev() {
for k in mp + 1..n {
V.set((k, mp), *A.get((k, mp - 1)));
V.set(k, mp, A.get(k, mp - 1));
}
let i = perm[mp];
if i != mp {
for j in mp..n {
V.set((mp, j), *V.get((i, j)));
V.set((i, j), T::zero());
V.set(mp, j, V.get(i, j));
V.set(i, j, T::zero());
}
V.set((i, mp), T::one());
V.set(i, mp, T::one());
}
}
}
fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = A.shape();
let mut z = T::zero();
let mut s = T::zero();
@@ -427,7 +430,7 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
for i in 0..n {
for j in i32::max(i as i32 - 1, 0)..n as i32 {
anorm += A.get((i, j as usize)).abs();
anorm += A.get(i, j as usize).abs();
}
}
@@ -438,43 +441,43 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
loop {
let mut l = nn;
while l > 0 {
s = A.get((l - 1, l - 1)).abs() + A.get((l, l)).abs();
s = A.get(l - 1, l - 1).abs() + A.get(l, l).abs();
if s == T::zero() {
s = anorm;
}
if A.get((l, l - 1)).abs() <= T::epsilon() * s {
A.set((l, l - 1), T::zero());
if A.get(l, l - 1).abs() <= T::epsilon() * s {
A.set(l, l - 1, T::zero());
break;
}
l -= 1;
}
let mut x = *A.get((nn, nn));
let mut x = A.get(nn, nn);
if l == nn {
d[nn] = x + t;
A.set((nn, nn), x + t);
A.set(nn, nn, x + t);
if nn == 0 {
break 'outer;
} else {
nn -= 1;
}
} else {
let mut y = *A.get((nn - 1, nn - 1));
let mut w = *A.get((nn, nn - 1)) * *A.get((nn - 1, nn));
let mut y = A.get(nn - 1, nn - 1);
let mut w = A.get(nn, nn - 1) * A.get(nn - 1, nn);
if l == nn - 1 {
p = T::half() * (y - x);
q = p * p + w;
z = q.abs().sqrt();
x += t;
A.set((nn, nn), x);
A.set((nn - 1, nn - 1), y + t);
A.set(nn, nn, x);
A.set(nn - 1, nn - 1, y + t);
if q >= T::zero() {
z = p + <T as RealNumber>::copysign(z, p);
z = p + RealNumber::copysign(z, p);
d[nn - 1] = x + z;
d[nn] = x + z;
if z != T::zero() {
d[nn] = x - w / z;
}
x = *A.get((nn, nn - 1));
x = A.get(nn, nn - 1);
s = x.abs() + z.abs();
p = x / s;
q = z / s;
@@ -482,19 +485,19 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
p /= r;
q /= r;
for j in nn - 1..n {
z = *A.get((nn - 1, j));
A.set((nn - 1, j), q * z + p * *A.get((nn, j)));
A.set((nn, j), q * *A.get((nn, j)) - p * z);
z = A.get(nn - 1, j);
A.set(nn - 1, j, q * z + p * A.get(nn, j));
A.set(nn, j, q * A.get(nn, j) - p * z);
}
for i in 0..=nn {
z = *A.get((i, nn - 1));
A.set((i, nn - 1), q * z + p * *A.get((i, nn)));
A.set((i, nn), q * *A.get((i, nn)) - p * z);
z = A.get(i, nn - 1);
A.set(i, nn - 1, q * z + p * A.get(i, nn));
A.set(i, nn, q * A.get(i, nn) - p * z);
}
for i in 0..n {
z = *V.get((i, nn - 1));
V.set((i, nn - 1), q * z + p * *V.get((i, nn)));
V.set((i, nn), q * *V.get((i, nn)) - p * z);
z = V.get(i, nn - 1);
V.set(i, nn - 1, q * z + p * V.get(i, nn));
V.set(i, nn, q * V.get(i, nn) - p * z);
}
} else {
d[nn] = x + p;
@@ -515,22 +518,22 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
if its == 10 || its == 20 {
t += x;
for i in 0..nn + 1 {
A.sub_element_mut((i, i), x);
A.sub_element_mut(i, i, x);
}
s = A.get((nn, nn - 1)).abs() + A.get((nn - 1, nn - 2)).abs();
y = T::from_f64(0.75).unwrap() * s;
x = T::from_f64(0.75).unwrap() * s;
w = T::from_f64(-0.4375).unwrap() * s * s;
s = A.get(nn, nn - 1).abs() + A.get(nn - 1, nn - 2).abs();
y = T::from(0.75).unwrap() * s;
x = T::from(0.75).unwrap() * s;
w = T::from(-0.4375).unwrap() * s * s;
}
its += 1;
let mut m = nn - 2;
while m >= l {
z = *A.get((m, m));
z = A.get(m, m);
r = x - z;
s = y - z;
p = (r * s - w) / *A.get((m + 1, m)) + *A.get((m, m + 1));
q = *A.get((m + 1, m + 1)) - z - r - s;
r = *A.get((m + 2, m + 1));
p = (r * s - w) / A.get(m + 1, m) + A.get(m, m + 1);
q = A.get(m + 1, m + 1) - z - r - s;
r = A.get(m + 2, m + 1);
s = p.abs() + q.abs() + r.abs();
p /= s;
q /= s;
@@ -538,27 +541,27 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
if m == l {
break;
}
let u = A.get((m, m - 1)).abs() * (q.abs() + r.abs());
let u = A.get(m, m - 1).abs() * (q.abs() + r.abs());
let v = p.abs()
* (A.get((m - 1, m - 1)).abs() + z.abs() + A.get((m + 1, m + 1)).abs());
* (A.get(m - 1, m - 1).abs() + z.abs() + A.get(m + 1, m + 1).abs());
if u <= T::epsilon() * v {
break;
}
m -= 1;
}
for i in m..nn - 1 {
A.set((i + 2, i), T::zero());
A.set(i + 2, i, T::zero());
if i != m {
A.set((i + 2, i - 1), T::zero());
A.set(i + 2, i - 1, T::zero());
}
}
for k in m..nn {
if k != m {
p = *A.get((k, k - 1));
q = *A.get((k + 1, k - 1));
p = A.get(k, k - 1);
q = A.get(k + 1, k - 1);
r = T::zero();
if k + 1 != nn {
r = *A.get((k + 2, k - 1));
r = A.get(k + 2, k - 1);
}
x = p.abs() + q.abs() + r.abs();
if x != T::zero() {
@@ -567,14 +570,14 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
r /= x;
}
}
let s = <T as RealNumber>::copysign((p * p + q * q + r * r).sqrt(), p);
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
if s != T::zero() {
if k == m {
if l != m {
A.set((k, k - 1), -*A.get((k, k - 1)));
A.set(k, k - 1, -A.get(k, k - 1));
}
} else {
A.set((k, k - 1), -s * x);
A.set(k, k - 1, -s * x);
}
p += s;
x = p / s;
@@ -583,33 +586,32 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
q /= p;
r /= p;
for j in k..n {
p = *A.get((k, j)) + q * *A.get((k + 1, j));
p = A.get(k, j) + q * A.get(k + 1, j);
if k + 1 != nn {
p += r * *A.get((k + 2, j));
A.sub_element_mut((k + 2, j), p * z);
p += r * A.get(k + 2, j);
A.sub_element_mut(k + 2, j, p * z);
}
A.sub_element_mut((k + 1, j), p * y);
A.sub_element_mut((k, j), p * x);
A.sub_element_mut(k + 1, j, p * y);
A.sub_element_mut(k, j, p * x);
}
let mmin = if nn < k + 3 { nn } else { k + 3 };
for i in 0..(mmin + 1) {
p = x * *A.get((i, k)) + y * *A.get((i, k + 1));
for i in 0..mmin + 1 {
p = x * A.get(i, k) + y * A.get(i, k + 1);
if k + 1 != nn {
p += z * *A.get((i, k + 2));
A.sub_element_mut((i, k + 2), p * r);
p += z * A.get(i, k + 2);
A.sub_element_mut(i, k + 2, p * r);
}
A.sub_element_mut((i, k + 1), p * q);
A.sub_element_mut((i, k), p);
A.sub_element_mut(i, k + 1, p * q);
A.sub_element_mut(i, k, p);
}
for i in 0..n {
p = x * *V.get((i, k)) + y * *V.get((i, k + 1));
p = x * V.get(i, k) + y * V.get(i, k + 1);
if k + 1 != nn {
p += z * *V.get((i, k + 2));
V.sub_element_mut((i, k + 2), p * r);
p += z * V.get(i, k + 2);
V.sub_element_mut(i, k + 2, p * r);
}
V.sub_element_mut((i, k + 1), p * q);
V.sub_element_mut((i, k), p);
V.sub_element_mut(i, k + 1, p * q);
V.sub_element_mut(i, k, p);
}
}
}
@@ -628,14 +630,14 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
let na = nn.wrapping_sub(1);
if q == T::zero() {
let mut m = nn;
A.set((nn, nn), T::one());
A.set(nn, nn, T::one());
if nn > 0 {
let mut i = nn - 1;
loop {
let w = *A.get((i, i)) - p;
let w = A.get(i, i) - p;
r = T::zero();
for j in m..=nn {
r += *A.get((i, j)) * *A.get((j, nn));
r += A.get(i, j) * A.get(j, nn);
}
if e[i] < T::zero() {
z = w;
@@ -648,23 +650,23 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
if t == T::zero() {
t = T::epsilon() * anorm;
}
A.set((i, nn), -r / t);
A.set(i, nn, -r / t);
} else {
let x = *A.get((i, i + 1));
let y = *A.get((i + 1, i));
let x = A.get(i, i + 1);
let y = A.get(i + 1, i);
q = (d[i] - p).powf(T::two()) + e[i].powf(T::two());
t = (x * s - z * r) / q;
A.set((i, nn), t);
A.set(i, nn, t);
if x.abs() > z.abs() {
A.set((i + 1, nn), (-r - w * t) / x);
A.set(i + 1, nn, (-r - w * t) / x);
} else {
A.set((i + 1, nn), (-s - y * t) / z);
A.set(i + 1, nn, (-s - y * t) / z);
}
}
t = A.get((i, nn)).abs();
t = A.get(i, nn).abs();
if T::epsilon() * t * t > T::one() {
for j in i..=nn {
A.div_element_mut((j, nn), t);
A.div_element_mut(j, nn, t);
}
}
}
@@ -677,25 +679,25 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
}
} else if q < T::zero() {
let mut m = na;
if A.get((nn, na)).abs() > A.get((na, nn)).abs() {
A.set((na, na), q / *A.get((nn, na)));
A.set((na, nn), -(*A.get((nn, nn)) - p) / *A.get((nn, na)));
if A.get(nn, na).abs() > A.get(na, nn).abs() {
A.set(na, na, q / A.get(nn, na));
A.set(na, nn, -(A.get(nn, nn) - p) / A.get(nn, na));
} else {
let temp = Complex::new(T::zero(), -*A.get((na, nn)))
/ Complex::new(*A.get((na, na)) - p, q);
A.set((na, na), temp.re);
A.set((na, nn), temp.im);
let temp = Complex::new(T::zero(), -A.get(na, nn))
/ Complex::new(A.get(na, na) - p, q);
A.set(na, na, temp.re);
A.set(na, nn, temp.im);
}
A.set((nn, na), T::zero());
A.set((nn, nn), T::one());
A.set(nn, na, T::zero());
A.set(nn, nn, T::one());
if nn >= 2 {
for i in (0..nn - 1).rev() {
let w = *A.get((i, i)) - p;
let w = A.get(i, i) - p;
let mut ra = T::zero();
let mut sa = T::zero();
for j in m..=nn {
ra += *A.get((i, j)) * *A.get((j, na));
sa += *A.get((i, j)) * *A.get((j, nn));
ra += A.get(i, j) * A.get(j, na);
sa += A.get(i, j) * A.get(j, nn);
}
if e[i] < T::zero() {
z = w;
@@ -705,11 +707,11 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
m = i;
if e[i] == T::zero() {
let temp = Complex::new(-ra, -sa) / Complex::new(w, q);
A.set((i, na), temp.re);
A.set((i, nn), temp.im);
A.set(i, na, temp.re);
A.set(i, nn, temp.im);
} else {
let x = *A.get((i, i + 1));
let y = *A.get((i + 1, i));
let x = A.get(i, i + 1);
let y = A.get(i + 1, i);
let mut vr =
(d[i] - p).powf(T::two()) + (e[i]).powf(T::two()) - q * q;
let vi = T::two() * q * (d[i] - p);
@@ -721,32 +723,33 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
let temp =
Complex::new(x * r - z * ra + q * sa, x * s - z * sa - q * ra)
/ Complex::new(vr, vi);
A.set((i, na), temp.re);
A.set((i, nn), temp.im);
A.set(i, na, temp.re);
A.set(i, nn, temp.im);
if x.abs() > z.abs() + q.abs() {
A.set(
(i + 1, na),
(-ra - w * *A.get((i, na)) + q * *A.get((i, nn))) / x,
i + 1,
na,
(-ra - w * A.get(i, na) + q * A.get(i, nn)) / x,
);
A.set(
(i + 1, nn),
(-sa - w * *A.get((i, nn)) - q * *A.get((i, na))) / x,
i + 1,
nn,
(-sa - w * A.get(i, nn) - q * A.get(i, na)) / x,
);
} else {
let temp = Complex::new(
-r - y * *A.get((i, na)),
-s - y * *A.get((i, nn)),
) / Complex::new(z, q);
A.set((i + 1, na), temp.re);
A.set((i + 1, nn), temp.im);
let temp =
Complex::new(-r - y * A.get(i, na), -s - y * A.get(i, nn))
/ Complex::new(z, q);
A.set(i + 1, na, temp.re);
A.set(i + 1, nn, temp.im);
}
}
}
t = T::max(A.get((i, na)).abs(), A.get((i, nn)).abs());
t = T::max(A.get(i, na).abs(), A.get(i, nn).abs());
if T::epsilon() * t * t > T::one() {
for j in i..=nn {
A.div_element_mut((j, na), t);
A.div_element_mut((j, nn), t);
A.div_element_mut(j, na, t);
A.div_element_mut(j, nn, t);
}
}
}
@@ -758,31 +761,31 @@ fn hqr2<T: Number + RealNumber, M: Array2<T>>(A: &mut M, V: &mut M, d: &mut [T],
for i in 0..n {
z = T::zero();
for k in 0..=j {
z += *V.get((i, k)) * *A.get((k, j));
z += V.get(i, k) * A.get(k, j);
}
V.set((i, j), z);
V.set(i, j, z);
}
}
}
}
fn balbak<T: Number + RealNumber, M: Array2<T>>(V: &mut M, scale: &[T]) {
fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
let (n, _) = V.shape();
for (i, scale_i) in scale.iter().enumerate().take(n) {
for j in 0..n {
V.mul_element_mut((i, j), *scale_i);
V.mul_element_mut(i, j, *scale_i);
}
}
}
fn sort<T: Number + RealNumber, M: Array2<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
let n = d.len();
let mut temp = vec![T::zero(); n];
for j in 1..n {
let real = d[j];
let img = e[j];
for (k, temp_k) in temp.iter_mut().enumerate().take(n) {
*temp_k = *V.get((k, j));
*temp_k = V.get(k, j);
}
let mut i = j as i32 - 1;
while i >= 0 {
@@ -792,14 +795,14 @@ fn sort<T: Number + RealNumber, M: Array2<T>>(d: &mut [T], e: &mut [T], V: &mut
d[i as usize + 1] = d[i as usize];
e[i as usize + 1] = e[i as usize];
for k in 0..n {
V.set((k, i as usize + 1), *V.get((k, i as usize)));
V.set(k, i as usize + 1, V.get(k, i as usize));
}
i -= 1;
}
d[i as usize + 1] = real;
e[i as usize + 1] = img;
for (k, temp_k) in temp.iter().enumerate().take(n) {
V.set((k, i as usize + 1), *temp_k);
V.set(k, i as usize + 1, *temp_k);
}
}
}
@@ -807,21 +810,15 @@ fn sort<T: Number + RealNumber, M: Array2<T>>(d: &mut [T], e: &mut [T], V: &mut
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000],
])
.unwrap();
]);
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -829,33 +826,26 @@ mod tests {
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588],
])
.unwrap();
]);
let evd = A.evd(true).unwrap();
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
}
for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.8000, 0.3000, 0.8000],
])
.unwrap();
]);
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
@@ -863,25 +853,19 @@ mod tests {
&[0.7178958, 0.05322098, 0.6812010],
&[0.3837711, -0.84702111, -0.1494582],
&[0.6952105, 0.43984484, -0.7036135],
])
.unwrap();
]);
let evd = A.evd(false).unwrap();
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values.len() {
assert!((eigen_values[i] - evd.d[i]).abs() < 1e-4);
}
for i in 0..eigen_values.len() {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_complex() {
let A = DenseMatrix::from_2d_array(&[
@@ -889,8 +873,7 @@ mod tests {
&[4.0, -1.0, 1.0, 1.0],
&[1.0, 1.0, 3.0, -2.0],
&[1.0, 1.0, 4.0, -1.0],
])
.unwrap();
]);
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
@@ -900,21 +883,16 @@ mod tests {
&[-0.6707, 0.1059, 0.901, 0.6289],
&[0.9159, -0.1378, 0.3816, 0.0806],
&[0.6707, 0.1059, 0.901, -0.6289],
])
.unwrap();
]);
let evd = A.evd(false).unwrap();
assert!(relative_eq!(
eigen_vectors.abs(),
evd.V.abs(),
epsilon = 1e-4
));
for (i, eigen_values_d_i) in eigen_values_d.iter().enumerate() {
assert!((eigen_values_d_i - evd.d[i]).abs() < 1e-4);
assert!(eigen_vectors.abs().approximate_eq(&evd.V.abs(), 1e-4));
for i in 0..eigen_values_d.len() {
assert!((eigen_values_d[i] - evd.d[i]).abs() < 1e-4);
}
for (i, eigen_values_e_i) in eigen_values_e.iter().enumerate() {
assert!((eigen_values_e_i - evd.e[i]).abs() < 1e-4);
for i in 0..eigen_values_e.len() {
assert!((eigen_values_e[i] - evd.e[i]).abs() < 1e-4);
}
}
}
@@ -1,20 +1,19 @@
//! In this module you will find composite of matrix operations that are used elsewhere
//! for improved efficiency.
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
/// High order matrix operations.
pub trait HighOrderOperations<T: Number>: Array2<T> {
pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
/// Y = AB
/// ```
/// use smartcore::linalg::basic::matrix::*;
/// use smartcore::linalg::traits::high_order::HighOrderOperations;
/// use smartcore::linalg::basic::arrays::Array2;
/// use smartcore::linalg::naive::dense_matrix::*;
/// use smartcore::linalg::high_order::HighOrderOperations;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]).unwrap();
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]).unwrap();
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
///
/// assert_eq!(a.ab(true, &b, false), expected);
/// ```
@@ -27,7 +26,3 @@ pub trait HighOrderOperations<T: Number>: Array2<T> {
}
}
}
mod tests {
/* TODO: Add tests */
}
+51 -55
View File
@@ -11,14 +11,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::lu::*;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::lu::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[1., 2., 3.],
//! &[0., 1., 5.],
//! &[5., 6., 0.]
//! ]).unwrap();
//! ]);
//!
//! let lu = A.lu().unwrap();
//! let lower: DenseMatrix<f64> = lu.L();
@@ -38,27 +38,26 @@ use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
#[derive(Debug, Clone)]
/// Result of LU decomposition.
pub struct LU<T: Number + RealNumber, M: Array2<T>> {
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
LU: M,
pivot: Vec<usize>,
#[allow(dead_code)]
pivot_sign: i8,
_pivot_sign: i8,
singular: bool,
phantom: PhantomData<T>,
}
impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
let (_, n) = LU.shape();
let mut singular = false;
for j in 0..n {
if LU.get((j, j)) == &T::zero() {
if LU.get(j, j) == T::zero() {
singular = true;
break;
}
@@ -67,7 +66,7 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
LU {
LU,
pivot,
pivot_sign,
_pivot_sign,
singular,
phantom: PhantomData,
}
@@ -81,9 +80,9 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
for i in 0..n_rows {
for j in 0..n_cols {
match i.cmp(&j) {
Ordering::Greater => L.set((i, j), *self.LU.get((i, j))),
Ordering::Equal => L.set((i, j), T::one()),
Ordering::Less => L.set((i, j), T::zero()),
Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
Ordering::Equal => L.set(i, j, T::one()),
Ordering::Less => L.set(i, j, T::zero()),
}
}
}
@@ -99,9 +98,9 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
for i in 0..n_rows {
for j in 0..n_cols {
if i <= j {
U.set((i, j), *self.LU.get((i, j)));
U.set(i, j, self.LU.get(i, j));
} else {
U.set((i, j), T::zero());
U.set(i, j, T::zero());
}
}
}
@@ -115,7 +114,7 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
let mut piv = M::zeros(n, n);
for i in 0..n {
piv.set((i, self.pivot[i]), T::one());
piv.set(i, self.pivot[i], T::one());
}
piv
@@ -126,13 +125,13 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
let (m, n) = self.LU.shape();
if m != n {
panic!("Matrix is not square: {m}x{n}");
panic!("Matrix is not square: {}x{}", m, n);
}
let mut inv = M::zeros(n, n);
for i in 0..n {
inv.set((i, i), T::one());
inv.set(i, i, T::one());
}
self.solve(inv)
@@ -143,7 +142,10 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
let (b_m, b_n) = b.shape();
if b_m != m {
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_m} x {b_n}");
panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_m, b_n
);
}
if self.singular {
@@ -154,33 +156,33 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
for j in 0..b_n {
for i in 0..m {
X.set((i, j), *b.get((self.pivot[i], j)));
X.set(i, j, b.get(self.pivot[i], j));
}
}
for k in 0..n {
for i in k + 1..n {
for j in 0..b_n {
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
}
}
}
for k in (0..n).rev() {
for j in 0..b_n {
X.div_element_mut((k, j), *self.LU.get((k, k)));
X.div_element_mut(k, j, self.LU.get(k, k));
}
for i in 0..k {
for j in 0..b_n {
X.sub_element_mut((i, j), *X.get((k, j)) * *self.LU.get((i, k)));
X.sub_element_mut(i, j, X.get(k, j) * self.LU.get(i, k));
}
}
}
for j in 0..b_n {
for i in 0..m {
b.set((i, j), *X.get((i, j)));
b.set(i, j, X.get(i, j));
}
}
@@ -189,7 +191,7 @@ impl<T: Number + RealNumber, M: Array2<T>> LU<T, M> {
}
/// Trait that implements LU decomposition routine for any matrix.
pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the LU decomposition of a square matrix.
fn lu(&self) -> Result<LU<T, Self>, Failed> {
self.clone().lu_mut()
@@ -207,18 +209,18 @@ pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
for j in 0..n {
for (i, LUcolj_i) in LUcolj.iter_mut().enumerate().take(m) {
*LUcolj_i = *self.get((i, j));
*LUcolj_i = self.get(i, j);
}
for i in 0..m {
let kmax = usize::min(i, j);
let mut s = T::zero();
for (k, LUcolj_k) in LUcolj.iter().enumerate().take(kmax) {
s += *self.get((i, k)) * (*LUcolj_k);
s += self.get(i, k) * (*LUcolj_k);
}
LUcolj[i] -= s;
self.set((i, j), LUcolj[i]);
self.set(i, j, LUcolj[i]);
}
let mut p = j;
@@ -229,15 +231,17 @@ pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
}
if p != j {
for k in 0..n {
self.swap((p, k), (j, k));
let t = self.get(p, k);
self.set(p, k, self.get(j, k));
self.set(j, k, t);
}
piv.swap(p, j);
pivsign = -pivsign;
}
if j < m && self.get((j, j)) != &T::zero() {
if j < m && self.get(j, j) != T::zero() {
for i in j + 1..m {
self.div_element_mut((i, j), *self.get((j, j)));
self.div_element_mut(i, j, self.get(j, j));
}
}
}
@@ -254,38 +258,30 @@ pub trait LUDecomposable<T: Number + RealNumber>: Array2<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let expected_L =
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]).unwrap();
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
let expected_U =
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]).unwrap();
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
let expected_pivot =
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]).unwrap();
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
let lu = a.lu().unwrap();
assert!(relative_eq!(lu.L(), expected_L, epsilon = 1e-4));
assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4));
assert!(relative_eq!(lu.pivot(), expected_pivot, epsilon = 1e-4));
assert!(lu.L().approximate_eq(&expected_L, 1e-4));
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn inverse() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
let expected =
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]])
.unwrap();
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
assert!(relative_eq!(a_inv, expected, epsilon = 1e-4));
assert!(a_inv.approximate_eq(&expected, 1e-4));
}
}
+762 -7
View File
@@ -1,9 +1,764 @@
/// basic data structures for linear algebra constructs: arrays and views
pub mod basic;
/// traits associated to algebraic constructs
pub mod traits;
#![allow(clippy::wrong_self_convention)]
//! # Linear Algebra and Matrix Decomposition
//!
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
//!
//! Traits [`BaseMatrix`](trait.BaseMatrix.html), [`Matrix`](trait.Matrix.html) and [`BaseVector`](trait.BaseVector.html) define
//! abstract methods that can be implemented for any two-dimensional and one-dimentional arrays (matrix and vector).
//! Functions from these traits are designed for SmartCore machine learning algorithms and should not be used directly in your code.
//! If you still want to use functions from `BaseMatrix`, `Matrix` and `BaseVector` please be aware that methods defined in these
//! traits might change in the future.
//!
//! One reason why linear algebra traits are public is to allow for different types of matrices and vectors to be plugged into SmartCore.
//! Once all methods defined in `BaseMatrix`, `Matrix` and `BaseVector` are implemented for your favourite type of matrix and vector you
//! should be able to run SmartCore algorithms on it. Please see `nalgebra_bindings` and `ndarray_bindings` modules for an example of how
//! it is done for other libraries.
//!
//! You will also find verious matrix decomposition methods that work for any matrix that extends [`Matrix`](trait.Matrix.html).
//! For example, to decompose matrix defined as [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html):
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::svd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! let svd = A.svd().unwrap();
//!
//! let s: Vec<f64> = svd.s;
//! let v: DenseMatrix<f64> = svd.V;
//! let u: DenseMatrix<f64> = svd.U;
//! ```
pub mod cholesky;
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
pub mod evd;
pub mod high_order;
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
pub mod lu;
/// Dense matrix with column-major order that wraps [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
pub mod naive;
/// [nalgebra](https://docs.rs/nalgebra/) bindings.
#[cfg(feature = "nalgebra-bindings")]
pub mod nalgebra_bindings;
/// [ndarray](https://docs.rs/ndarray) bindings.
#[cfg(feature = "ndarray-bindings")]
/// ndarray bindings
pub mod ndarray;
pub mod ndarray_bindings;
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
pub mod qr;
pub mod stats;
/// Singular value decomposition.
pub mod svd;
use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::ops::Range;
use crate::math::num::RealNumber;
use cholesky::CholeskyDecomposableMatrix;
use evd::EVDDecomposableMatrix;
use high_order::HighOrderOperations;
use lu::LUDecomposableMatrix;
use qr::QRDecomposableMatrix;
use stats::{MatrixPreprocessing, MatrixStats};
use svd::SVDDecomposableMatrix;
/// Column or row vector
pub trait BaseVector<T: RealNumber>: Clone + Debug {
/// Get an element of a vector
/// * `i` - index of an element
fn get(&self, i: usize) -> T;
/// Set an element at `i` to `x`
/// * `i` - index of an element
/// * `x` - new value
fn set(&mut self, i: usize, x: T);
/// Get number of elevemnt in the vector
fn len(&self) -> usize;
/// Returns true if the vector is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Create a new vector from a &[T]
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a: [f64; 5] = [0., 0.5, 2., 3., 4.];
/// let v: Vec<f64> = BaseVector::from_array(&a);
/// assert_eq!(v, vec![0., 0.5, 2., 3., 4.]);
/// ```
fn from_array(f: &[T]) -> Self {
let mut v = Self::zeros(f.len());
for (i, elem) in f.iter().enumerate() {
v.set(i, *elem);
}
v
}
/// Return a vector with the elements of the one-dimensional array.
fn to_vec(&self) -> Vec<T>;
/// Create new vector with zeros of size `len`.
fn zeros(len: usize) -> Self;
/// Create new vector with ones of size `len`.
fn ones(len: usize) -> Self;
/// Create new vector of size `len` where each element is set to `value`.
fn fill(len: usize, value: T) -> Self;
/// Vector dot product
fn dot(&self, other: &Self) -> T;
/// Returns True if matrices are element-wise equal within a tolerance `error`.
fn approximate_eq(&self, other: &Self, error: T) -> bool;
/// Returns [L2 norm] of the vector(https://en.wikipedia.org/wiki/Matrix_norm).
fn norm2(&self) -> T;
/// Returns [vectors norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
fn norm(&self, p: T) -> T;
/// Divide single element of the vector by `x`, write result to original vector.
fn div_element_mut(&mut self, pos: usize, x: T);
/// Multiply single element of the vector by `x`, write result to original vector.
fn mul_element_mut(&mut self, pos: usize, x: T);
/// Add single element of the vector to `x`, write result to original vector.
fn add_element_mut(&mut self, pos: usize, x: T);
/// Subtract `x` from single element of the vector, write result to original vector.
fn sub_element_mut(&mut self, pos: usize, x: T);
/// Subtract scalar
fn sub_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) - x);
}
self
}
/// Subtract scalar
fn add_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) + x);
}
self
}
/// Subtract scalar
fn mul_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) * x);
}
self
}
/// Subtract scalar
fn div_scalar_mut(&mut self, x: T) -> &Self {
for i in 0..self.len() {
self.set(i, self.get(i) / x);
}
self
}
/// Add vectors, element-wise
fn add_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.add_scalar_mut(x);
r
}
/// Subtract vectors, element-wise
fn sub_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.sub_scalar_mut(x);
r
}
/// Multiply vectors, element-wise
fn mul_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.mul_scalar_mut(x);
r
}
/// Divide vectors, element-wise
fn div_scalar(&self, x: T) -> Self {
let mut r = self.clone();
r.div_scalar_mut(x);
r
}
/// Add vectors, element-wise, overriding original vector with result.
fn add_mut(&mut self, other: &Self) -> &Self;
/// Subtract vectors, element-wise, overriding original vector with result.
fn sub_mut(&mut self, other: &Self) -> &Self;
/// Multiply vectors, element-wise, overriding original vector with result.
fn mul_mut(&mut self, other: &Self) -> &Self;
/// Divide vectors, element-wise, overriding original vector with result.
fn div_mut(&mut self, other: &Self) -> &Self;
/// Add vectors, element-wise
fn add(&self, other: &Self) -> Self {
let mut r = self.clone();
r.add_mut(other);
r
}
/// Subtract vectors, element-wise
fn sub(&self, other: &Self) -> Self {
let mut r = self.clone();
r.sub_mut(other);
r
}
/// Multiply vectors, element-wise
fn mul(&self, other: &Self) -> Self {
let mut r = self.clone();
r.mul_mut(other);
r
}
/// Divide vectors, element-wise
fn div(&self, other: &Self) -> Self {
let mut r = self.clone();
r.div_mut(other);
r
}
/// Calculates sum of all elements of the vector.
fn sum(&self) -> T;
/// Returns unique values from the vector.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = vec!(1., 2., 2., -2., -6., -7., 2., 3., 4.);
///
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
/// ```
fn unique(&self) -> Vec<T>;
/// Computes the arithmetic mean.
fn mean(&self) -> T {
self.sum() / T::from_usize(self.len()).unwrap()
}
/// Computes variance.
fn var(&self) -> T {
let n = self.len();
let mut mu = T::zero();
let mut sum = T::zero();
let div = T::from_usize(n).unwrap();
for i in 0..n {
let xi = self.get(i);
mu += xi;
sum += xi * xi;
}
mu /= div;
sum / div - mu.powi(2)
}
/// Computes the standard deviation.
fn std(&self) -> T {
self.var().sqrt()
}
/// Copies content of `other` vector.
fn copy_from(&mut self, other: &Self);
/// Take elements from an array.
fn take(&self, index: &[usize]) -> Self {
let n = index.len();
let mut result = Self::zeros(n);
for (i, idx) in index.iter().enumerate() {
result.set(i, self.get(*idx));
}
result
}
}
/// Generic matrix type.
pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
/// Row vector that is associated with this matrix type,
/// e.g. if we have an implementation of sparce matrix
/// we should have an associated sparce vector type that
/// represents a row in this matrix.
type RowVector: BaseVector<T> + Clone + Debug;
/// Transforms row vector `vec` into a 1xM matrix.
fn from_row_vector(vec: Self::RowVector) -> Self;
/// Transforms 1-d matrix of 1xM into a row vector.
fn to_row_vector(self) -> Self::RowVector;
/// Get an element of the matrix.
/// * `row` - row number
/// * `col` - column number
fn get(&self, row: usize, col: usize) -> T;
/// Get a vector with elements of the `row`'th row
/// * `row` - row number
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
/// Get the `row`'th row
/// * `row` - row number
fn get_row(&self, row: usize) -> Self::RowVector;
/// Copies a vector with elements of the `row`'th row into `result`
/// * `row` - row number
/// * `result` - receiver for the row
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
/// Get a vector with elements of the `col`'th column
/// * `col` - column number
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
/// Copies a vector with elements of the `col`'th column into `result`
/// * `col` - column number
/// * `result` - receiver for the col
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>);
/// Set an element at `col`, `row` to `x`
fn set(&mut self, row: usize, col: usize, x: T);
/// Create an identity matrix of size `size`
fn eye(size: usize) -> Self;
/// Create new matrix with zeros of size `nrows` by `ncols`.
fn zeros(nrows: usize, ncols: usize) -> Self;
/// Create new matrix with ones of size `nrows` by `ncols`.
fn ones(nrows: usize, ncols: usize) -> Self;
/// Create new matrix of size `nrows` by `ncols` where each element is set to `value`.
fn fill(nrows: usize, ncols: usize, value: T) -> Self;
/// Return the shape of an array.
fn shape(&self) -> (usize, usize);
/// Stack arrays in sequence vertically (row wise).
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
/// let b = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3., 1., 2.],
/// &[4., 5., 6., 3., 4.]
/// ]);
///
/// assert_eq!(a.h_stack(&b), expected);
/// ```
fn h_stack(&self, other: &Self) -> Self;
/// Stack arrays in sequence horizontally (column wise).
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3.],
/// &[4., 5., 6.]
/// ]);
///
/// assert_eq!(a.v_stack(&b), expected);
/// ```
fn v_stack(&self, other: &Self) -> Self;
/// Matrix product.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[7., 10.],
/// &[15., 22.]
/// ]);
///
/// assert_eq!(a.matmul(&a), expected);
/// ```
fn matmul(&self, other: &Self) -> Self;
/// Vector dot product
/// Both matrices should be of size _1xM_
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
/// let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
///
/// assert_eq!(a.dot(&b), 32.);
/// ```
fn dot(&self, other: &Self) -> T;
/// Return a slice of the matrix.
/// * `rows` - range of rows to return
/// * `cols` - range of columns to return
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let m = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3., 1.],
/// &[4., 5., 6., 3.],
/// &[7., 8., 9., 5.]
/// ]);
/// let expected = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
/// let result = m.slice(0..2, 1..3);
/// assert_eq!(result, expected);
/// ```
fn slice(&self, rows: Range<usize>, cols: Range<usize>) -> Self;
/// Returns True if matrices are element-wise equal within a tolerance `error`.
fn approximate_eq(&self, other: &Self, error: T) -> bool;
/// Add matrices, element-wise, overriding original matrix with result.
fn add_mut(&mut self, other: &Self) -> &Self;
/// Subtract matrices, element-wise, overriding original matrix with result.
fn sub_mut(&mut self, other: &Self) -> &Self;
/// Multiply matrices, element-wise, overriding original matrix with result.
fn mul_mut(&mut self, other: &Self) -> &Self;
/// Divide matrices, element-wise, overriding original matrix with result.
fn div_mut(&mut self, other: &Self) -> &Self;
/// Divide single element of the matrix by `x`, write result to original matrix.
fn div_element_mut(&mut self, row: usize, col: usize, x: T);
/// Multiply single element of the matrix by `x`, write result to original matrix.
fn mul_element_mut(&mut self, row: usize, col: usize, x: T);
/// Add single element of the matrix to `x`, write result to original matrix.
fn add_element_mut(&mut self, row: usize, col: usize, x: T);
/// Subtract `x` from single element of the matrix, write result to original matrix.
fn sub_element_mut(&mut self, row: usize, col: usize, x: T);
/// Add matrices, element-wise
fn add(&self, other: &Self) -> Self {
let mut r = self.clone();
r.add_mut(other);
r
}
/// Subtract matrices, element-wise
fn sub(&self, other: &Self) -> Self {
let mut r = self.clone();
r.sub_mut(other);
r
}
/// Multiply matrices, element-wise
fn mul(&self, other: &Self) -> Self {
let mut r = self.clone();
r.mul_mut(other);
r
}
/// Divide matrices, element-wise
fn div(&self, other: &Self) -> Self {
let mut r = self.clone();
r.div_mut(other);
r
}
/// Add `scalar` to the matrix, override original matrix with result.
fn add_scalar_mut(&mut self, scalar: T) -> &Self;
/// Subtract `scalar` from the elements of matrix, override original matrix with result.
fn sub_scalar_mut(&mut self, scalar: T) -> &Self;
/// Multiply `scalar` by the elements of matrix, override original matrix with result.
fn mul_scalar_mut(&mut self, scalar: T) -> &Self;
/// Divide elements of the matrix by `scalar`, override original matrix with result.
fn div_scalar_mut(&mut self, scalar: T) -> &Self;
/// Add `scalar` to the matrix.
fn add_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.add_scalar_mut(scalar);
r
}
/// Subtract `scalar` from the elements of matrix.
fn sub_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.sub_scalar_mut(scalar);
r
}
/// Multiply `scalar` by the elements of matrix.
fn mul_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.mul_scalar_mut(scalar);
r
}
/// Divide elements of the matrix by `scalar`.
fn div_scalar(&self, scalar: T) -> Self {
let mut r = self.clone();
r.div_scalar_mut(scalar);
r
}
/// Reverse or permute the axes of the matrix, return new matrix.
fn transpose(&self) -> Self;
/// Create new `nrows` by `ncols` matrix and populate it with random samples from a uniform distribution over [0, 1).
fn rand(nrows: usize, ncols: usize) -> Self;
/// Returns [L2 norm](https://en.wikipedia.org/wiki/Matrix_norm).
fn norm2(&self) -> T;
/// Returns [matrix norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
fn norm(&self, p: T) -> T;
/// Returns the average of the matrix columns.
fn column_mean(&self) -> Vec<T>;
/// Numerical negative, element-wise. Overrides original matrix.
fn negative_mut(&mut self);
/// Numerical negative, element-wise.
fn negative(&self) -> Self {
let mut result = self.clone();
result.negative_mut();
result
}
/// Returns new matrix of shape `nrows` by `ncols` with data copied from original matrix.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(1, 6, &[1., 2., 3., 4., 5., 6.]);
/// let expected = DenseMatrix::from_2d_array(&[
/// &[1., 2., 3.],
/// &[4., 5., 6.]
/// ]);
///
/// assert_eq!(a.reshape(2, 3), expected);
/// ```
fn reshape(&self, nrows: usize, ncols: usize) -> Self;
/// Copies content of `other` matrix.
fn copy_from(&mut self, other: &Self);
/// Calculate the absolute value element-wise. Overrides original matrix.
fn abs_mut(&mut self) -> &Self;
/// Calculate the absolute value element-wise.
fn abs(&self) -> Self {
let mut result = self.clone();
result.abs_mut();
result
}
/// Calculates sum of all elements of the matrix.
fn sum(&self) -> T;
/// Calculates max of all elements of the matrix.
fn max(&self) -> T;
/// Calculates min of all elements of the matrix.
fn min(&self) -> T;
/// Calculates max(|a - b|) of two matrices
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
///
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., 4., -5., 6.]);
/// let b = DenseMatrix::from_array(2, 3, &[2., 3., 4., 1., 0., -12.]);
///
/// assert_eq!(a.max_diff(&b), 18.);
/// assert_eq!(b.max_diff(&b), 0.);
/// ```
fn max_diff(&self, other: &Self) -> T {
self.sub(other).abs().max()
}
/// Calculates [Softmax function](https://en.wikipedia.org/wiki/Softmax_function). Overrides the matrix with result.
fn softmax_mut(&mut self);
/// Raises elements of the matrix to the power of `p`
fn pow_mut(&mut self, p: T) -> &Self;
/// Returns new matrix with elements raised to the power of `p`
fn pow(&mut self, p: T) -> Self {
let mut result = self.clone();
result.pow_mut(p);
result
}
/// Returns the indices of the maximum values in each row.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = DenseMatrix::from_array(2, 3, &[1., 2., 3., -5., -6., -7.]);
///
/// assert_eq!(a.argmax(), vec![2, 0]);
/// ```
fn argmax(&self) -> Vec<usize>;
/// Returns vector with unique values from the matrix.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// let a = DenseMatrix::from_array(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
///
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
/// ```
fn unique(&self) -> Vec<T>;
/// Calculates the covariance matrix
fn cov(&self) -> Self;
/// Take elements from an array along an axis.
fn take(&self, index: &[usize], axis: u8) -> Self {
let (n, p) = self.shape();
let k = match axis {
0 => p,
_ => n,
};
let mut result = match axis {
0 => Self::zeros(index.len(), p),
_ => Self::zeros(n, index.len()),
};
for (i, idx) in index.iter().enumerate() {
for j in 0..k {
match axis {
0 => result.set(i, j, self.get(*idx, j)),
_ => result.set(j, i, self.get(j, *idx)),
};
}
}
result
}
}
/// Generic matrix with additional mixins like various factorization methods.
pub trait Matrix<T: RealNumber>:
BaseMatrix<T>
+ SVDDecomposableMatrix<T>
+ EVDDecomposableMatrix<T>
+ QRDecomposableMatrix<T>
+ LUDecomposableMatrix<T>
+ CholeskyDecomposableMatrix<T>
+ MatrixStats<T>
+ MatrixPreprocessing<T>
+ HighOrderOperations<T>
+ PartialEq
+ Display
{
}
pub(crate) fn row_iter<F: RealNumber, M: BaseMatrix<F>>(m: &M) -> RowIter<'_, F, M> {
RowIter {
m,
pos: 0,
max_pos: m.shape().0,
phantom: PhantomData,
}
}
pub(crate) struct RowIter<'a, T: RealNumber, M: BaseMatrix<T>> {
m: &'a M,
pos: usize,
max_pos: usize,
phantom: PhantomData<&'a T>,
}
impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<T>> {
let res = if self.pos < self.max_pos {
Some(self.m.get_row_as_vec(self.pos))
} else {
None
};
self.pos += 1;
res
}
}
#[cfg(test)]
mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseMatrix;
use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean() {
let m = vec![1., 2., 3.];
assert_eq!(m.mean(), 2.0);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn std() {
let m = vec![1., 2., 3.];
assert!((m.std() - 0.81f64).abs() < 1e-2);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn var() {
let m = vec![1., 2., 3., 4.];
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn vec_take() {
let m = vec![1., 2., 3., 4., 5.];
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn take() {
let m = DenseMatrix::from_2d_array(&[
&[1.0, 2.0],
&[3.0, 4.0],
&[5.0, 6.0],
&[7.0, 8.0],
&[9.0, 10.0],
]);
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
let expected_1 = DenseMatrix::from_2d_array(&[
&[2.0, 1.0],
&[4.0, 3.0],
&[6.0, 5.0],
&[8.0, 7.0],
&[10.0, 9.0],
]);
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
}
}
File diff suppressed because it is too large Load Diff
+26
View File
@@ -0,0 +1,26 @@
//! # Simple Dense Matrix
//!
//! Implements [`BaseMatrix`](../../trait.BaseMatrix.html) and [`BaseVector`](../../trait.BaseVector.html) for [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
//! Data is stored in dense format with [column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order).
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::naive::dense_matrix::*;
//!
//! // 3x3 matrix
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9000, 0.4000, 0.7000],
//! &[0.4000, 0.5000, 0.3000],
//! &[0.7000, 0.3000, 0.8000],
//! ]);
//!
//! // row vector
//! let B = DenseMatrix::from_array(1, 3, &[0.9, 0.4, 0.7]);
//!
//! // column vector
//! let C = DenseMatrix::from_vec(3, 1, &vec!(0.9, 0.4, 0.7));
//! ```
/// Add this module to use Dense Matrix
pub mod dense_matrix;
File diff suppressed because it is too large Load Diff
-282
View File
@@ -1,282 +0,0 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{
Array as BaseArray, Array2, ArrayView1, ArrayView2, MutArray, MutArrayView2,
};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::evd::EVDDecomposable;
use crate::linalg::traits::lu::LUDecomposable;
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix2, OwnedRepr};
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
for ArrayBase<OwnedRepr<T>, Ix2>
{
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
for ArrayBase<OwnedRepr<T>, Ix2>
{
fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.as_mut_ptr();
let stride = self.strides();
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
match axis {
0 => Box::new(self.iter_mut()),
_ => Box::new((0..self.ncols()).flat_map(move |c| {
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> Array2<T> for ArrayBase<OwnedRepr<T>, Ix2> {
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(self.row(row))
}
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
Box::new(self.column(col))
}
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
Box::new(self.slice(s![rows, cols]))
}
fn slice_mut<'a>(
&'a mut self,
rows: Range<usize>,
cols: Range<usize>,
) -> Box<dyn MutArrayView2<T> + 'a>
where
Self: Sized,
{
Box::new(self.slice_mut(s![rows, cols]))
}
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
Array::from_elem([nrows, ncols], value)
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
let a = Array::from_iter(iter.take(nrows * ncols))
.into_shape((nrows, ncols))
.unwrap();
match axis {
0 => a,
_ => a.reversed_axes().into_shape((nrows, ncols)).unwrap(),
}
}
fn transpose(&self) -> Self {
self.t().to_owned()
}
}
impl<T: Number + RealNumber> QRDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> CholeskyDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn get(&self, pos: (usize, usize)) -> &T {
&self[[pos.0, pos.1]]
}
fn shape(&self) -> (usize, usize) {
(self.nrows(), self.ncols())
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(
axis == 1 || axis == 0,
"For two dimensional array `axis` should be either 0 or 1"
);
match axis {
0 => Box::new(self.iter()),
_ => Box::new(
(0..self.ncols()).flat_map(move |c| (0..self.nrows()).map(move |r| &self[[r, c]])),
),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
fn set(&mut self, pos: (usize, usize), x: T) {
self[[pos.0, pos.1]] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
let ptr = self.as_mut_ptr();
let stride = self.strides();
let (rstride, cstride) = (stride[0] as usize, stride[1] as usize);
match axis {
0 => Box::new(self.iter_mut()),
_ => Box::new((0..self.ncols()).flat_map(move |c| {
(0..self.nrows()).map(move |r| unsafe { &mut *ptr.add(r * rstride + c * cstride) })
})),
}
}
}
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array2 as NDArray2};
#[test]
fn test_get_set() {
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
assert_eq!(*BaseArray::get(&a, (1, 1)), 5);
a.set((1, 1), 9);
assert_eq!(a, arr2(&[[1, 2, 3], [4, 9, 6]]));
}
#[test]
fn test_iterator() {
let a = arr2(&[[1, 2, 3], [4, 5, 6]]);
let v: Vec<i32> = a.iterator(0).copied().collect();
assert_eq!(v, vec!(1, 2, 3, 4, 5, 6));
}
#[test]
fn test_mut_iterator() {
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
a.iterator_mut(0).enumerate().for_each(|(i, v)| *v = i);
assert_eq!(a, arr2(&[[0, 1, 2], [3, 4, 5]]));
a.iterator_mut(1).enumerate().for_each(|(i, v)| *v = i);
assert_eq!(a, arr2(&[[0, 2, 4], [1, 3, 5]]));
}
#[test]
fn test_slice() {
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
let x_slice = Array2::slice(&x, 0..2, 1..2);
assert_eq!((2, 1), x_slice.shape());
let v: Vec<i32> = x_slice.iterator(0).copied().collect();
assert_eq!(v, [2, 5]);
}
#[test]
fn test_slice_iter() {
let x = arr2(&[[1, 2, 3], [4, 5, 6]]);
let x_slice = Array2::slice(&x, 0..2, 0..3);
assert_eq!(
x_slice.iterator(0).copied().collect::<Vec<i32>>(),
vec![1, 2, 3, 4, 5, 6]
);
assert_eq!(
x_slice.iterator(1).copied().collect::<Vec<i32>>(),
vec![1, 4, 2, 5, 3, 6]
);
}
#[test]
fn test_slice_mut_iter() {
let mut x = arr2(&[[1, 2, 3], [4, 5, 6]]);
{
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
x_slice
.iterator_mut(0)
.enumerate()
.for_each(|(i, v)| *v = i);
}
assert_eq!(x, arr2(&[[0, 1, 2], [3, 4, 5]]));
{
let mut x_slice = Array2::slice_mut(&mut x, 0..2, 0..3);
x_slice
.iterator_mut(1)
.enumerate()
.for_each(|(i, v)| *v = i);
}
assert_eq!(x, arr2(&[[0, 2, 4], [1, 3, 5]]));
}
#[test]
fn test_c_from_iterator() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let a: NDArray2<i32> = Array2::from_iterator(data.clone().into_iter(), 4, 3, 0);
println!("{a}");
let a: NDArray2<i32> = Array2::from_iterator(data.into_iter(), 4, 3, 1);
println!("{a}");
}
}
-4
View File
@@ -1,4 +0,0 @@
/// matrix bindings
pub mod matrix;
/// vector bindings
pub mod vector;
-184
View File
@@ -1,184 +0,0 @@
use std::fmt::{Debug, Display};
use std::ops::Range;
use crate::linalg::basic::arrays::{
Array as BaseArray, Array1, ArrayView1, MutArray, MutArrayView1,
};
use ndarray::{s, Array, ArrayBase, ArrayView, ArrayViewMut, Ix1, OwnedRepr};
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayBase<OwnedRepr<T>, Ix1> {
fn set(&mut self, i: usize, x: T) {
self[i] = x
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn get(&self, i: usize) -> &T {
&self[i]
}
fn shape(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.len() > 0
}
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter())
}
}
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
fn set(&mut self, i: usize, x: T) {
self[i] = x;
}
fn iterator_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
assert!(axis == 0, "For one dimensional array `axis` should == 0");
Box::new(self.iter_mut())
}
}
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
Box::new(self.slice(s![range]))
}
fn slice_mut<'b>(&'b mut self, range: Range<usize>) -> Box<dyn MutArrayView1<T> + 'b> {
assert!(
range.end <= self.len(),
"`range` should be <= {}",
self.len()
);
Box::new(self.slice_mut(s![range]))
}
fn fill(len: usize, value: T) -> Self {
Array::from_elem(len, value)
}
fn from_iterator<I: Iterator<Item = T>>(iter: I, len: usize) -> Self
where
Self: Sized,
{
Array::from_iter(iter.take(len))
}
fn from_vec_slice(slice: &[T]) -> Self {
Array::from_iter(slice.iter().copied())
}
fn from_slice(slice: &dyn ArrayView1<T>) -> Self {
Array::from_iter(slice.iterator(0).copied())
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::arr1;
#[test]
fn test_get_set() {
let mut a = arr1(&[1, 2, 3]);
assert_eq!(*BaseArray::get(&a, 1), 2);
a.set(1, 9);
assert_eq!(a, arr1(&[1, 9, 3]));
}
#[test]
fn test_iterator() {
let a = arr1(&[1, 2, 3]);
let v: Vec<i32> = a.iterator(0).copied().collect();
assert_eq!(v, vec!(1, 2, 3));
}
#[test]
fn test_mut_iterator() {
let mut a = arr1(&[1, 2, 3]);
a.iterator_mut(0).for_each(|v| *v = 1);
assert_eq!(a, arr1(&[1, 1, 1]));
}
#[test]
fn test_slice() {
let x = arr1(&[1, 2, 3, 4, 5]);
let x_slice = Array1::slice(&x, 2..3);
assert_eq!(1, x_slice.shape());
assert_eq!(3, *x_slice.get(0));
}
#[test]
fn test_mut_slice() {
let mut x = arr1(&[1, 2, 3, 4, 5]);
let mut x_slice = Array1::slice_mut(&mut x, 2..4);
x_slice.set(0, 9);
assert_eq!(2, x_slice.shape());
assert_eq!(9, *x_slice.get(0));
assert_eq!(4, *x_slice.get(1));
}
}
File diff suppressed because it is too large Load Diff
+44 -55
View File
@@ -6,14 +6,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::qr::*;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::qr::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9, 0.4, 0.7],
//! &[0.4, 0.5, 0.3],
//! &[0.7, 0.3, 0.8]
//! ]).unwrap();
//! ]);
//!
//! let qr = A.qr().unwrap();
//! let orthogonal: DenseMatrix<f64> = qr.Q();
@@ -28,22 +28,20 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
use std::fmt::Debug;
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
#[derive(Debug, Clone)]
/// Results of QR decomposition.
pub struct QR<T: Number + RealNumber, M: Array2<T>> {
pub struct QR<T: RealNumber, M: BaseMatrix<T>> {
QR: M,
tau: Vec<T>,
singular: bool,
}
impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
pub(crate) fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
let mut singular = false;
for tau_elem in tau.iter() {
@@ -61,9 +59,9 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
let (_, n) = self.QR.shape();
let mut R = M::zeros(n, n);
for i in 0..n {
R.set((i, i), self.tau[i]);
R.set(i, i, self.tau[i]);
for j in i + 1..n {
R.set((i, j), *self.QR.get((i, j)));
R.set(i, j, self.QR.get(i, j));
}
}
R
@@ -75,16 +73,16 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
let mut Q = M::zeros(m, n);
let mut k = n - 1;
loop {
Q.set((k, k), T::one());
Q.set(k, k, T::one());
for j in k..n {
if self.QR.get((k, k)) != &T::zero() {
if self.QR.get(k, k) != T::zero() {
let mut s = T::zero();
for i in k..m {
s += *self.QR.get((i, k)) * *Q.get((i, j));
s += self.QR.get(i, k) * Q.get(i, j);
}
s = -s / *self.QR.get((k, k));
s = -s / self.QR.get(k, k);
for i in k..m {
Q.add_element_mut((i, j), s * *self.QR.get((i, k)));
Q.add_element_mut(i, j, s * self.QR.get(i, k));
}
}
}
@@ -102,7 +100,10 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
let (b_nrows, b_ncols) = b.shape();
if b_nrows != m {
panic!("Row dimensions do not agree: A is {m} x {n}, but B is {b_nrows} x {b_ncols}");
panic!(
"Row dimensions do not agree: A is {} x {}, but B is {} x {}",
m, n, b_nrows, b_ncols
);
}
if self.singular {
@@ -113,23 +114,23 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
for j in 0..b_ncols {
let mut s = T::zero();
for i in k..m {
s += *self.QR.get((i, k)) * *b.get((i, j));
s += self.QR.get(i, k) * b.get(i, j);
}
s = -s / *self.QR.get((k, k));
s = -s / self.QR.get(k, k);
for i in k..m {
b.add_element_mut((i, j), s * *self.QR.get((i, k)));
b.add_element_mut(i, j, s * self.QR.get(i, k));
}
}
}
for k in (0..n).rev() {
for j in 0..b_ncols {
b.set((k, j), *b.get((k, j)) / self.tau[k]);
b.set(k, j, b.get(k, j) / self.tau[k]);
}
for i in 0..k {
for j in 0..b_ncols {
b.sub_element_mut((i, j), *b.get((k, j)) * *self.QR.get((i, k)));
b.sub_element_mut(i, j, b.get(k, j) * self.QR.get(i, k));
}
}
}
@@ -139,7 +140,7 @@ impl<T: Number + RealNumber, M: Array2<T>> QR<T, M> {
}
/// Trait that implements QR decomposition routine for any matrix.
pub trait QRDecomposable<T: Number + RealNumber>: Array2<T> {
pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Compute the QR decomposition of a matrix.
fn qr(&self) -> Result<QR<T, Self>, Failed> {
self.clone().qr_mut()
@@ -155,26 +156,26 @@ pub trait QRDecomposable<T: Number + RealNumber>: Array2<T> {
for (k, r_diagonal_k) in r_diagonal.iter_mut().enumerate().take(n) {
let mut nrm = T::zero();
for i in k..m {
nrm = nrm.hypot(*self.get((i, k)));
nrm = nrm.hypot(self.get(i, k));
}
if nrm.abs() > T::epsilon() {
if self.get((k, k)) < &T::zero() {
if self.get(k, k) < T::zero() {
nrm = -nrm;
}
for i in k..m {
self.div_element_mut((i, k), nrm);
self.div_element_mut(i, k, nrm);
}
self.add_element_mut((k, k), T::one());
self.add_element_mut(k, k, T::one());
for j in k + 1..n {
let mut s = T::zero();
for i in k..m {
s += *self.get((i, k)) * *self.get((i, j));
s += self.get(i, k) * self.get(i, j);
}
s = -s / *self.get((k, k));
s = -s / self.get(k, k);
for i in k..m {
self.add_element_mut((i, j), s * *self.get((i, k)));
self.add_element_mut(i, j, s * self.get(i, k));
}
}
}
@@ -193,49 +194,37 @@ pub trait QRDecomposable<T: Number + RealNumber>: Array2<T> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let q = DenseMatrix::from_2d_array(&[
&[-0.7448, 0.2436, 0.6212],
&[-0.331, -0.9432, -0.027],
&[-0.5793, 0.2257, -0.7832],
])
.unwrap();
]);
let r = DenseMatrix::from_2d_array(&[
&[-1.2083, -0.6373, -1.0842],
&[0.0, -0.3064, 0.0682],
&[0.0, 0.0, -0.1999],
])
.unwrap();
]);
let qr = a.qr().unwrap();
assert!(relative_eq!(qr.Q().abs(), q.abs(), epsilon = 1e-4));
assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4));
assert!(qr.Q().abs().approximate_eq(&q.abs(), 1e-4));
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn qr_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let expected_w = DenseMatrix::from_2d_array(&[
&[-0.2027027, -1.2837838],
&[0.8783784, 2.2297297],
&[0.4729730, 0.6621622],
])
.unwrap();
]);
let w = a.qr_solve_mut(b).unwrap();
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
assert!(w.approximate_eq(&expected_w, 1e-2));
}
}
+207
View File
@@ -0,0 +1,207 @@
//! # Various Statistical Methods
//!
//! This module provides reference implementations for various statistical functions.
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
/// Defines baseline implementations for various statistical functions
pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
/// Computes the arithmetic mean along the specified axis.
fn mean(&self, axis: u8) -> Vec<T> {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
let div = T::from_usize(m).unwrap();
for (i, x_i) in x.iter_mut().enumerate().take(n) {
for j in 0..m {
*x_i += match axis {
0 => self.get(j, i),
_ => self.get(i, j),
};
}
*x_i /= div;
}
x
}
/// Computes variance along the specified axis.
fn var(&self, axis: u8) -> Vec<T> {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
let div = T::from_usize(m).unwrap();
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let mut mu = T::zero();
let mut sum = T::zero();
for j in 0..m {
let a = match axis {
0 => self.get(j, i),
_ => self.get(i, j),
};
mu += a;
sum += a * a;
}
mu /= div;
*x_i = sum / div - mu.powi(2);
}
x
}
/// Computes the standard deviation along the specified axis.
fn std(&self, axis: u8) -> Vec<T> {
let mut x = self.var(axis);
let n = match axis {
0 => self.shape().1,
_ => self.shape().0,
};
for x_i in x.iter_mut().take(n) {
*x_i = x_i.sqrt();
}
x
}
/// standardize values by removing the mean and scaling to unit variance
fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
for i in 0..n {
for j in 0..m {
match axis {
0 => self.set(j, i, (self.get(j, i) - mean[i]) / std[i]),
_ => self.set(i, j, (self.get(i, j) - mean[i]) / std[i]),
}
}
}
}
}
/// Defines baseline implementations for various matrix processing functions
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
/// a.binarize_mut(0.);
///
/// assert_eq!(a, expected);
/// ```
fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
for col in 0..ncols {
if self.get(row, col) > threshold {
self.set(row, col, T::one());
} else {
self.set(row, col, T::zero());
}
}
}
}
/// Returns new matrix where elements are binarized according to a given threshold.
/// ```
/// use smartcore::linalg::naive::dense_matrix::*;
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
///
/// assert_eq!(a.binarize(0.), expected);
/// ```
fn binarize(&self, threshold: T) -> Self {
let mut m = self.clone();
m.binarize_mut(threshold);
m
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
]);
let expected_0 = vec![4., 5., 6., 3., 4.];
let expected_1 = vec![1.8, 4.4, 7.];
assert_eq!(m.mean(0), expected_0);
assert_eq!(m.mean(1), expected_1);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn std() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
]);
let expected_0 = vec![2.44, 2.44, 2.44, 1.63, 1.63];
let expected_1 = vec![0.74, 1.01, 1.41];
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn var() {
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
let expected_0 = vec![4., 4., 4., 4.];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn scale() {
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]);
let expected_1 = DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]);
{
let mut m = m.clone();
m.scale_mut(&m.mean(0), &m.std(0), 0);
assert!(m.approximate_eq(&expected_0, std::f32::EPSILON));
}
m.scale_mut(&m.mean(1), &m.std(1), 1);
assert!(m.approximate_eq(&expected_1, 1e-2));
}
}
+107 -120
View File
@@ -10,14 +10,14 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::traits::svd::*;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linalg::svd::*;
//!
//! let A = DenseMatrix::from_2d_array(&[
//! &[0.9, 0.4, 0.7],
//! &[0.4, 0.5, 0.3],
//! &[0.7, 0.3, 0.8]
//! ]).unwrap();
//! ]);
//!
//! let svd = A.svd().unwrap();
//! let u: DenseMatrix<f64> = svd.U;
@@ -34,33 +34,32 @@
#![allow(non_snake_case)]
use crate::error::Failed;
use crate::linalg::basic::arrays::Array2;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseMatrix;
use crate::math::num::RealNumber;
use std::fmt::Debug;
/// Results of SVD decomposition
#[derive(Debug, Clone)]
pub struct SVD<T: Number + RealNumber, M: SVDDecomposable<T>> {
pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
/// Left-singular vectors of _A_
pub U: M,
/// Right-singular vectors of _A_
pub V: M,
/// Singular values of the original matrix
pub s: Vec<T>,
_full: bool,
m: usize,
n: usize,
/// Tolerance
tol: T,
}
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
/// Diagonal matrix with singular values
pub fn S(&self) -> M {
let mut s = M::zeros(self.U.shape().1, self.V.shape().0);
for i in 0..self.s.len() {
s.set((i, i), self.s[i]);
s.set(i, i, self.s[i]);
}
s
@@ -68,7 +67,7 @@ impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
}
/// Trait that implements SVD decomposition routine for any matrix.
pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
/// Solves Ax = b. Overrides original matrix in the process.
fn svd_solve_mut(self, b: Self) -> Result<Self, Failed> {
self.svd_mut().and_then(|svd| svd.solve(b))
@@ -107,31 +106,31 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
if i < m {
for k in i..m {
scale += U.get((k, i)).abs();
scale += U.get(k, i).abs();
}
if scale.abs() > T::epsilon() {
for k in i..m {
U.div_element_mut((k, i), scale);
s += *U.get((k, i)) * *U.get((k, i));
U.div_element_mut(k, i, scale);
s += U.get(k, i) * U.get(k, i);
}
let mut f = *U.get((i, i));
g = -<T as RealNumber>::copysign(s.sqrt(), f);
let mut f = U.get(i, i);
g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s;
U.set((i, i), f - g);
U.set(i, i, f - g);
for j in l - 1..n {
s = T::zero();
for k in i..m {
s += *U.get((k, i)) * *U.get((k, j));
s += U.get(k, i) * U.get(k, j);
}
f = s / h;
for k in i..m {
U.add_element_mut((k, j), f * *U.get((k, i)));
U.add_element_mut(k, j, f * U.get(k, i));
}
}
for k in i..m {
U.mul_element_mut((k, i), scale);
U.mul_element_mut(k, i, scale);
}
}
}
@@ -143,37 +142,37 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
if i < m && i + 1 != n {
for k in l - 1..n {
scale += U.get((i, k)).abs();
scale += U.get(i, k).abs();
}
if scale.abs() > T::epsilon() {
for k in l - 1..n {
U.div_element_mut((i, k), scale);
s += *U.get((i, k)) * *U.get((i, k));
U.div_element_mut(i, k, scale);
s += U.get(i, k) * U.get(i, k);
}
let f = *U.get((i, l - 1));
g = -<T as RealNumber>::copysign(s.sqrt(), f);
let f = U.get(i, l - 1);
g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s;
U.set((i, l - 1), f - g);
U.set(i, l - 1, f - g);
for (k, rv1_k) in rv1.iter_mut().enumerate().take(n).skip(l - 1) {
*rv1_k = *U.get((i, k)) / h;
*rv1_k = U.get(i, k) / h;
}
for j in l - 1..m {
s = T::zero();
for k in l - 1..n {
s += *U.get((j, k)) * *U.get((i, k));
s += U.get(j, k) * U.get(i, k);
}
for (k, rv1_k) in rv1.iter().enumerate().take(n).skip(l - 1) {
U.add_element_mut((j, k), s * (*rv1_k));
U.add_element_mut(j, k, s * (*rv1_k));
}
}
for k in l - 1..n {
U.mul_element_mut((i, k), scale);
U.mul_element_mut(i, k, scale);
}
}
}
@@ -185,24 +184,24 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
if i < n - 1 {
if g != T::zero() {
for j in l..n {
v.set((j, i), (*U.get((i, j)) / *U.get((i, l))) / g);
v.set(j, i, (U.get(i, j) / U.get(i, l)) / g);
}
for j in l..n {
let mut s = T::zero();
for k in l..n {
s += *U.get((i, k)) * *v.get((k, j));
s += U.get(i, k) * v.get(k, j);
}
for k in l..n {
v.add_element_mut((k, j), s * *v.get((k, i)));
v.add_element_mut(k, j, s * v.get(k, i));
}
}
}
for j in l..n {
v.set((i, j), T::zero());
v.set((j, i), T::zero());
v.set(i, j, T::zero());
v.set(j, i, T::zero());
}
}
v.set((i, i), T::one());
v.set(i, i, T::one());
g = rv1[i];
l = i;
}
@@ -211,7 +210,7 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
l = i + 1;
g = w[i];
for j in l..n {
U.set((i, j), T::zero());
U.set(i, j, T::zero());
}
if g.abs() > T::epsilon() {
@@ -219,23 +218,23 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
for j in l..n {
let mut s = T::zero();
for k in l..m {
s += *U.get((k, i)) * *U.get((k, j));
s += U.get(k, i) * U.get(k, j);
}
let f = (s / *U.get((i, i))) * g;
let f = (s / U.get(i, i)) * g;
for k in i..m {
U.add_element_mut((k, j), f * *U.get((k, i)));
U.add_element_mut(k, j, f * U.get(k, i));
}
}
for j in i..m {
U.mul_element_mut((j, i), g);
U.mul_element_mut(j, i, g);
}
} else {
for j in i..m {
U.set((j, i), T::zero());
U.set(j, i, T::zero());
}
}
U.add_element_mut((i, i), T::one());
U.add_element_mut(i, i, T::one());
}
for k in (0..n).rev() {
@@ -270,10 +269,10 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
c = g * h;
s = -f * h;
for j in 0..m {
let y = *U.get((j, nm));
let z = *U.get((j, i));
U.set((j, nm), y * c + z * s);
U.set((j, i), z * c - y * s);
let y = U.get(j, nm);
let z = U.get(j, i);
U.set(j, nm, y * c + z * s);
U.set(j, i, z * c - y * s);
}
}
}
@@ -283,7 +282,7 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
if z < T::zero() {
w[k] = -z;
for j in 0..n {
v.set((j, k), -*v.get((j, k)));
v.set(j, k, -v.get(j, k));
}
}
break;
@@ -300,8 +299,7 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
g = f.hypot(T::one());
f = ((x - z) * (x + z) + h * ((y / (f + <T as RealNumber>::copysign(g, f))) - h))
/ x;
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
let mut c = T::one();
let mut s = T::one();
@@ -321,10 +319,10 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
y *= c;
for jj in 0..n {
x = *v.get((jj, j));
z = *v.get((jj, i));
v.set((jj, j), x * c + z * s);
v.set((jj, i), z * c - x * s);
x = v.get(jj, j);
z = v.get(jj, i);
v.set(jj, j, x * c + z * s);
v.set(jj, i, z * c - x * s);
}
z = f.hypot(h);
@@ -338,10 +336,10 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
f = c * g + s * y;
x = c * y - s * g;
for jj in 0..m {
y = *U.get((jj, j));
z = *U.get((jj, i));
U.set((jj, j), y * c + z * s);
U.set((jj, i), z * c - y * s);
y = U.get(jj, j);
z = U.get(jj, i);
U.set(jj, j, y * c + z * s);
U.set(jj, i, z * c - y * s);
}
}
@@ -368,19 +366,19 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
for i in inc..n {
let sw = w[i];
for (k, su_k) in su.iter_mut().enumerate().take(m) {
*su_k = *U.get((k, i));
*su_k = U.get(k, i);
}
for (k, sv_k) in sv.iter_mut().enumerate().take(n) {
*sv_k = *v.get((k, i));
*sv_k = v.get(k, i);
}
let mut j = i;
while w[j - inc] < sw {
w[j] = w[j - inc];
for k in 0..m {
U.set((k, j), *U.get((k, j - inc)));
U.set(k, j, U.get(k, j - inc));
}
for k in 0..n {
v.set((k, j), *v.get((k, j - inc)));
v.set(k, j, v.get(k, j - inc));
}
j -= inc;
if j < inc {
@@ -389,10 +387,10 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
}
w[j] = sw;
for (k, su_k) in su.iter().enumerate().take(m) {
U.set((k, j), *su_k);
U.set(k, j, *su_k);
}
for (k, sv_k) in sv.iter().enumerate().take(n) {
v.set((k, j), *sv_k);
v.set(k, j, *sv_k);
}
}
if inc <= 1 {
@@ -403,21 +401,21 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
for k in 0..n {
let mut s = 0.;
for i in 0..m {
if U.get((i, k)) < &T::zero() {
if U.get(i, k) < T::zero() {
s += 1.;
}
}
for j in 0..n {
if v.get((j, k)) < &T::zero() {
if v.get(j, k) < T::zero() {
s += 1.;
}
}
if s > (m + n) as f64 / 2. {
for i in 0..m {
U.set((i, k), -*U.get((i, k)));
U.set(i, k, -U.get(i, k));
}
for j in 0..n {
v.set((j, k), -*v.get((j, k)));
v.set(j, k, -v.get(j, k));
}
}
}
@@ -426,12 +424,21 @@ pub trait SVDDecomposable<T: Number + RealNumber>: Array2<T> {
}
}
impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0;
let n = V.shape().0;
let _full = s.len() == m.min(n);
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD { U, V, s, m, n, tol }
SVD {
U,
V,
s,
_full,
m,
n,
tol,
}
}
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
@@ -451,7 +458,7 @@ impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
let mut r = T::zero();
if self.s[j] > self.tol {
for i in 0..self.m {
r += *self.U.get((i, j)) * *b.get((i, k));
r += self.U.get(i, j) * b.get(i, k);
}
r /= self.s[j];
}
@@ -461,9 +468,9 @@ impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
for j in 0..self.n {
let mut r = T::zero();
for (jj, tmp_jj) in tmp.iter().enumerate().take(self.n) {
r += *self.V.get((j, jj)) * (*tmp_jj);
r += self.V.get(j, jj) * (*tmp_jj);
}
b.set((j, k), r);
b.set(j, k, r);
}
}
@@ -474,21 +481,15 @@ impl<T: Number + RealNumber, M: SVDDecomposable<T>> SVD<T, M> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use approx::relative_eq;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[
&[0.9000, 0.4000, 0.7000],
&[0.4000, 0.5000, 0.3000],
&[0.7000, 0.3000, 0.8000],
])
.unwrap();
]);
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
@@ -496,28 +497,23 @@ mod tests {
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.639158],
])
.unwrap();
]);
let V = DenseMatrix::from_2d_array(&[
&[0.6881997, -0.07121225, 0.7220180],
&[0.3700456, 0.89044952, -0.2648886],
&[0.6240573, -0.44947578, -0.6391588],
])
.unwrap();
]);
let svd = A.svd().unwrap();
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
for (i, s_i) in s.iter().enumerate() {
assert!((s_i - svd.s[i]).abs() < 1e-4);
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[
@@ -578,8 +574,7 @@ mod tests {
-0.2158704,
-0.27529472,
],
])
.unwrap();
]);
let s: Vec<f64> = vec![
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
@@ -649,8 +644,7 @@ mod tests {
0.73034065,
-0.43965505,
],
])
.unwrap();
]);
let V = DenseMatrix::from_2d_array(&[
&[
@@ -710,40 +704,31 @@ mod tests {
0.1654796,
-0.32346758,
],
])
.unwrap();
]);
let svd = A.svd().unwrap();
assert!(relative_eq!(V.abs(), svd.V.abs(), epsilon = 1e-4));
assert!(relative_eq!(U.abs(), svd.U.abs(), epsilon = 1e-4));
for (i, s_i) in s.iter().enumerate() {
assert!((s_i - svd.s[i]).abs() < 1e-4);
assert!(V.abs().approximate_eq(&svd.V.abs(), 1e-4));
assert!(U.abs().approximate_eq(&svd.U.abs(), 1e-4));
for i in 0..s.len() {
assert!((s[i] - svd.s[i]).abs() < 1e-4);
}
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn solve() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
.unwrap();
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
let expected_w =
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]).unwrap();
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
let w = a.svd_solve_mut(b).unwrap();
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
assert!(w.approximate_eq(&expected_w, 1e-2));
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn decompose_restore() {
let a =
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]).unwrap();
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
let svd = a.svd().unwrap();
let u: &DenseMatrix<f32> = &svd.U; //U
let v: &DenseMatrix<f32> = &svd.V; // V
@@ -751,6 +736,8 @@ mod tests {
let a_hat = u.matmul(s).matmul(&v.transpose());
assert!(relative_eq!(a, a_hat, epsilon = 1e-3));
for (a, a_hat) in a.iter().zip(a_hat.iter()) {
assert!((a - a_hat).abs() < 1e-3)
}
}
}
-15
View File
@@ -1,15 +0,0 @@
#![allow(clippy::wrong_self_convention)]
pub mod cholesky;
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
pub mod evd;
pub mod high_order;
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
pub mod lu;
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
pub mod qr;
/// statistacal tools for DenseMatrix
pub mod stats;
/// Singular value decomposition.
pub mod svd;
-297
View File
@@ -1,297 +0,0 @@
//! # Various Statistical Methods
//!
//! This module provides reference implementations for various statistical functions.
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
//! This methods shall be used when dealing with `DenseMatrix`. Use the ones in `linalg::arrays` for `Array` types.
use crate::linalg::basic::arrays::{Array2, ArrayView2, MutArrayView2};
use crate::numbers::realnum::RealNumber;
/// Defines baseline implementations for various statistical functions
pub trait MatrixStats<T: RealNumber>: ArrayView2<T> + Array2<T> {
/// Computes the arithmetic mean along the specified axis.
fn mean(&self, axis: u8) -> Vec<T> {
let (n, _m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let vec = match axis {
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
};
*x_i = Self::_mean_of_vector(&vec[..]);
}
x
}
/// Computes variance along the specified axis.
fn var(&self, axis: u8) -> Vec<T> {
let (n, _m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
let mut x: Vec<T> = vec![T::zero(); n];
for (i, x_i) in x.iter_mut().enumerate().take(n) {
let vec = match axis {
0 => self.get_col(i).iterator(0).copied().collect::<Vec<T>>(),
_ => self.get_row(i).iterator(0).copied().collect::<Vec<T>>(),
};
*x_i = Self::_var_of_vec(&vec[..], Option::None);
}
x
}
/// Computes the standard deviation along the specified axis.
fn std(&self, axis: u8) -> Vec<T> {
let mut x = Self::var(self, axis);
let n = match axis {
0 => self.shape().1,
_ => self.shape().0,
};
for x_i in x.iter_mut().take(n) {
*x_i = x_i.sqrt();
}
x
}
/// <http://en.wikipedia.org/wiki/Arithmetic_mean>
/// Taken from `statistical`
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _mean_of_vector(v: &[T]) -> T {
let len = num::cast(v.len()).unwrap();
v.iter().fold(T::zero(), |acc: T, elem| acc + *elem) / len
}
/// Taken from statistical
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _sum_square_deviations_vec(v: &[T], c: Option<T>) -> T {
let c = match c {
Some(c) => c,
None => Self::_mean_of_vector(v),
};
let sum = v
.iter()
.map(|x| (*x - c) * (*x - c))
.fold(T::zero(), |acc, elem| acc + elem);
assert!(sum >= T::zero(), "negative sum of square root deviations");
sum
}
/// <http://en.wikipedia.org/wiki/Variance#Sample_variance>
/// Taken from statistical
/// The MIT License (MIT)
/// Copyright (c) 2015 Jeff Belgum
fn _var_of_vec(v: &[T], xbar: Option<T>) -> T {
assert!(v.len() > 1, "variance requires at least two data points");
let len: T = num::cast(v.len()).unwrap();
let sum = Self::_sum_square_deviations_vec(v, xbar);
sum / len
}
/// standardize values by removing the mean and scaling to unit variance
fn standard_scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
let (n, m) = match axis {
0 => {
let (n, m) = self.shape();
(m, n)
}
_ => self.shape(),
};
for i in 0..n {
for j in 0..m {
match axis {
0 => self.set((j, i), (*self.get((j, i)) - mean[i]) / std[i]),
_ => self.set((i, j), (*self.get((i, j)) - mean[i]) / std[i]),
}
}
}
}
}
//TODO: this is processing. Should have its own "processing.rs" module
/// Defines baseline implementations for various matrix processing functions
pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
/// ```rust
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
/// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
/// a.binarize_mut(0.);
///
/// assert_eq!(a, expected);
/// ```
fn binarize_mut(&mut self, threshold: T) {
let (nrows, ncols) = self.shape();
for row in 0..nrows {
for col in 0..ncols {
if *self.get((row, col)) > threshold {
self.set((row, col), T::one());
} else {
self.set((row, col), T::zero());
}
}
}
}
/// Returns new matrix where elements are binarized according to a given threshold.
/// ```rust
/// use smartcore::linalg::basic::matrix::DenseMatrix;
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
/// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
///
/// assert_eq!(a.binarize(0.), expected);
/// ```
fn binarize(self, threshold: T) -> Self
where
Self: Sized,
{
let mut m = self;
m.binarize_mut(threshold);
m
}
}
#[cfg(test)]
mod tests {
use crate::linalg::basic::arrays::Array1;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::traits::stats::MatrixStats;
#[test]
fn test_mean() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
])
.unwrap();
let expected_0 = vec![4., 5., 6., 3., 4.];
let expected_1 = vec![1.8, 4.4, 7.];
assert_eq!(m.mean(0), expected_0);
assert_eq!(m.mean(1), expected_1);
}
#[test]
fn test_var() {
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
let expected_0 = vec![4., 4., 4., 4.];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, 1e-6));
assert!(m.var(1).approximate_eq(&expected_1, 1e-6));
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
assert_eq!(m.mean(1), vec![2.5, 6.5]);
}
#[test]
fn test_var_other() {
let m = DenseMatrix::from_2d_array(&[
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
])
.unwrap();
let expected_0 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let expected_1 = vec![1.25, 1.25];
assert!(m.var(0).approximate_eq(&expected_0, f64::EPSILON));
assert!(m.var(1).approximate_eq(&expected_1, f64::EPSILON));
assert_eq!(
m.mean(0),
vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
);
assert_eq!(m.mean(1), vec![1.375, 1.375]);
}
#[test]
fn test_std() {
let m = DenseMatrix::from_2d_array(&[
&[1., 2., 3., 1., 2.],
&[4., 5., 6., 3., 4.],
&[7., 8., 9., 5., 6.],
])
.unwrap();
let expected_0 = vec![
2.449489742783178,
2.449489742783178,
2.449489742783178,
1.632993161855452,
1.632993161855452,
];
let expected_1 = vec![0.7483314773547883, 1.019803902718557, 1.4142135623730951];
println!("{:?}", m.var(0));
assert!(m.std(0).approximate_eq(&expected_0, f64::EPSILON));
assert!(m.std(1).approximate_eq(&expected_1, f64::EPSILON));
assert_eq!(m.mean(0), vec![4.0, 5.0, 6.0, 3.0, 4.0]);
assert_eq!(m.mean(1), vec![1.8, 4.4, 7.0]);
}
#[test]
fn test_scale() {
let m: DenseMatrix<f64> =
DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
let expected_0: DenseMatrix<f64> =
DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]).unwrap();
let expected_1: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[
-1.3416407864998738,
-0.4472135954999579,
0.4472135954999579,
1.3416407864998738,
],
&[
-1.3416407864998738,
-0.4472135954999579,
0.4472135954999579,
1.3416407864998738,
],
])
.unwrap();
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
assert_eq!(m.mean(1), vec![2.5, 6.5]);
assert_eq!(m.var(0), vec![4., 4., 4., 4.]);
assert_eq!(m.var(1), vec![1.25, 1.25]);
assert_eq!(m.std(0), vec![2., 2., 2., 2.]);
assert_eq!(m.std(1), vec![1.118033988749895, 1.118033988749895]);
{
let mut m = m.clone();
m.standard_scale_mut(&m.mean(0), &m.std(0), 0);
assert_eq!(&m, &expected_0);
}
{
let mut m = m;
m.standard_scale_mut(&m.mean(1), &m.std(1), 1);
assert_eq!(&m, &expected_1);
}
}
}
+49 -81
View File
@@ -1,43 +1,13 @@
//! This is a generic solver for Ax = b type of equation
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::arrays::Array1;
//! use smartcore::linalg::basic::arrays::Array2;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linear::bg_solver::*;
//! use smartcore::numbers::floatnum::FloatNumber;
//! use smartcore::linear::bg_solver::BiconjugateGradientSolver;
//!
//! pub struct BGSolver {}
//! impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X> for BGSolver {}
//!
//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0.,
//! 11.]]).unwrap();
//! let b = vec![40., 51., 28.];
//! let expected = vec![1.0, 2.0, 3.0];
//! let mut x = Vec::zeros(3);
//! let solver = BGSolver {};
//! let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
//! ```
//!
//! for more information take a look at [this Wikipedia article](https://en.wikipedia.org/wiki/Biconjugate_gradient_method)
//! and [this paper](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf)
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array, Array1, Array2, ArrayView1, MutArrayView1};
use crate::numbers::floatnum::FloatNumber;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// Trait for Biconjugate Gradient Solver
pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
/// Solve Ax = b
fn solve_mut(
&self,
a: &'a X,
b: &Vec<T>,
x: &mut Vec<T>,
tol: T,
max_iter: usize,
) -> Result<T, Failed> {
pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
fn solve_mut(&self, a: &M, b: &M, x: &mut M, tol: T, max_iter: usize) -> Result<T, Failed> {
if tol <= T::zero() {
return Err(Failed::fit("tolerance shoud be > 0"));
}
@@ -46,25 +16,25 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
return Err(Failed::fit("maximum number of iterations should be > 0"));
}
let n = b.shape();
let (n, _) = b.shape();
let mut r = Vec::zeros(n);
let mut rr = Vec::zeros(n);
let mut z = Vec::zeros(n);
let mut zz = Vec::zeros(n);
let mut r = M::zeros(n, 1);
let mut rr = M::zeros(n, 1);
let mut z = M::zeros(n, 1);
let mut zz = M::zeros(n, 1);
self.mat_vec_mul(a, x, &mut r);
for j in 0..n {
r[j] = b[j] - r[j];
rr[j] = r[j];
r.set(j, 0, b.get(j, 0) - r.get(j, 0));
rr.set(j, 0, r.get(j, 0));
}
let bnrm = b.norm(2f64);
self.solve_preconditioner(a, &r[..], &mut z[..]);
let bnrm = b.norm(T::two());
self.solve_preconditioner(a, &r, &mut z);
let mut p = Vec::zeros(n);
let mut pp = Vec::zeros(n);
let mut p = M::zeros(n, 1);
let mut pp = M::zeros(n, 1);
let mut bkden = T::zero();
let mut err = T::zero();
@@ -73,33 +43,35 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
self.solve_preconditioner(a, &rr, &mut zz);
for j in 0..n {
bknum += z[j] * rr[j];
bknum += z.get(j, 0) * rr.get(j, 0);
}
if iter == 1 {
p[..n].copy_from_slice(&z[..n]);
pp[..n].copy_from_slice(&zz[..n]);
for j in 0..n {
p.set(j, 0, z.get(j, 0));
pp.set(j, 0, zz.get(j, 0));
}
} else {
let bk = bknum / bkden;
for j in 0..n {
p[j] = bk * pp[j] + z[j];
pp[j] = bk * pp[j] + zz[j];
p.set(j, 0, bk * p.get(j, 0) + z.get(j, 0));
pp.set(j, 0, bk * pp.get(j, 0) + zz.get(j, 0));
}
}
bkden = bknum;
self.mat_vec_mul(a, &p, &mut z);
let mut akden = T::zero();
for j in 0..n {
akden += z[j] * pp[j];
akden += z.get(j, 0) * pp.get(j, 0);
}
let ak = bknum / akden;
self.mat_t_vec_mul(a, &pp, &mut zz);
for j in 0..n {
x[j] += ak * p[j];
r[j] -= ak * z[j];
rr[j] -= ak * zz[j];
x.set(j, 0, x.get(j, 0) + ak * p.get(j, 0));
r.set(j, 0, r.get(j, 0) - ak * z.get(j, 0));
rr.set(j, 0, rr.get(j, 0) - ak * zz.get(j, 0));
}
self.solve_preconditioner(a, &r, &mut z);
err = T::from_f64(r.norm(2f64) / bnrm).unwrap();
err = r.norm(T::two()) / bnrm;
if err <= tol {
break;
@@ -109,38 +81,36 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
Ok(err)
}
/// solve preconditioner
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
let diag = Self::diag(a);
let n = diag.len();
for (i, diag_i) in diag.iter().enumerate().take(n) {
if *diag_i != T::zero() {
x[i] = b[i] / *diag_i;
x.set(i, 0, b.get(i, 0) / *diag_i);
} else {
x[i] = b[i];
x.set(i, 0, b.get(i, 0));
}
}
}
/// y = Ax
fn mat_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
y.copy_from(&x.xa(false, a));
// y = Ax
fn mat_vec_mul(&self, a: &M, x: &M, y: &mut M) {
y.copy_from(&a.matmul(x));
}
/// y = Atx
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
y.copy_from(&x.xa(true, a));
// y = Atx
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
y.copy_from(&a.ab(true, x, false));
}
/// Extract the diagonal from a matrix
fn diag(a: &X) -> Vec<T> {
fn diag(a: &M) -> Vec<T> {
let (nrows, ncols) = a.shape();
let n = nrows.min(ncols);
let mut d = Vec::with_capacity(n);
for i in 0..n {
d.push(*a.get((i, i)));
d.push(a.get(i, i));
}
d
@@ -150,30 +120,28 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array2;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
pub struct BGSolver {}
impl<T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'_, T, X> for BGSolver {}
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn bg_solver() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
.unwrap();
let b = vec![40., 51., 28.];
let expected = [1.0, 2.0, 3.0];
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
let mut x = Vec::zeros(3);
let mut x = DenseMatrix::zeros(3, 1);
let solver = BGSolver {};
let err: f64 = solver.solve_mut(&a, &b, &mut x, 1e-6, 6).unwrap();
let err: f64 = solver
.solve_mut(&a, &b.transpose(), &mut x, 1e-6, 6)
.unwrap();
assert!(x
.iter()
.zip(expected.iter())
.all(|(&a, &b)| (a - b).abs() < 1e-4));
assert!(x.transpose().approximate_eq(&expected, 1e-4));
assert!((err - 0.0).abs() < 1e-4);
}
}
+113 -321
View File
@@ -17,7 +17,7 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linear::elastic_net::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -38,7 +38,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! ]);
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -55,39 +55,32 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
/// Elastic net parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters {
#[cfg_attr(feature = "serde", serde(default))]
pub struct ElasticNetParameters<T: RealNumber> {
/// Regularization parameter.
pub alpha: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub alpha: T,
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub l1_ratio: T,
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub tol: T,
/// The maximum number of iterations
pub max_iter: usize,
}
@@ -95,23 +88,21 @@ pub struct ElasticNetParameters {
/// Elastic net
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ElasticNet<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
}
impl ElasticNetParameters {
impl<T: RealNumber> ElasticNetParameters<T> {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: f64) -> Self {
pub fn with_alpha(mut self, alpha: T) -> Self {
self.alpha = alpha;
self
}
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub fn with_l1_ratio(mut self, l1_ratio: f64) -> Self {
pub fn with_l1_ratio(mut self, l1_ratio: T) -> Self {
self.l1_ratio = l1_ratio;
self
}
@@ -121,7 +112,7 @@ impl ElasticNetParameters {
self
}
/// The tolerance for the optimization
pub fn with_tol(mut self, tol: f64) -> Self {
pub fn with_tol(mut self, tol: T) -> Self {
self.tol = tol;
self
}
@@ -132,205 +123,61 @@ impl ElasticNetParameters {
}
}
impl Default for ElasticNetParameters {
impl<T: RealNumber> Default for ElasticNetParameters<T> {
fn default() -> Self {
ElasticNetParameters {
alpha: 1.0,
l1_ratio: 0.5,
alpha: T::one(),
l1_ratio: T::half(),
normalize: true,
tol: 1e-4,
tol: T::from_f64(1e-4).unwrap(),
max_iter: 1000,
}
}
}
/// ElasticNet grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}
/// ElasticNet grid search iterator
pub struct ElasticNetSearchParametersIterator {
lasso_regression_search_parameters: ElasticNetSearchParameters,
current_alpha: usize,
current_l1_ratio: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}
impl IntoIterator for ElasticNetSearchParameters {
type Item = ElasticNetParameters;
type IntoIter = ElasticNetSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
ElasticNetSearchParametersIterator {
lasso_regression_search_parameters: self,
current_alpha: 0,
current_l1_ratio: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}
impl Iterator for ElasticNetSearchParametersIterator {
type Item = ElasticNetParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
{
return None;
}
let next = ElasticNetParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
};
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
{
self.current_alpha = 0;
self.current_l1_ratio += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
self.current_alpha += 1;
self.current_l1_ratio += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}
Some(next)
}
}
impl Default for ElasticNetSearchParameters {
fn default() -> Self {
let default_params = ElasticNetParameters::default();
ElasticNetSearchParameters {
alpha: vec![default_params.alpha],
l1_ratio: vec![default_params.l1_ratio],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for ElasticNet<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
fn eq(&self, other: &Self) -> bool {
if self.intercept() != other.intercept() {
return false;
}
if self.coefficients().shape() != other.coefficients().shape() {
return false;
}
self.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, ElasticNetParameters> for ElasticNet<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, ElasticNetParameters<T>>
for ElasticNet<T, M>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: ElasticNetParameters) -> Result<Self, Failed> {
fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters<T>) -> Result<Self, Failed> {
ElasticNet::fit(x, y, parameters)
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for ElasticNet<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
ElasticNet<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
/// Fits elastic net regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &X,
y: &Y,
parameters: ElasticNetParameters,
) -> Result<ElasticNet<TX, TY, X, Y>, Failed> {
x: &M,
y: &M::RowVector,
parameters: ElasticNetParameters<T>,
) -> Result<ElasticNet<T, M>, Failed> {
let (n, p) = x.shape();
if y.shape() != n {
if y.len() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let n_float = n as f64;
let n_float = T::from_usize(n).unwrap();
let l1_reg = TX::from_f64(parameters.alpha * parameters.l1_ratio * n_float).unwrap();
let l2_reg =
TX::from_f64(parameters.alpha * (1.0 - parameters.l1_ratio) * n_float).unwrap();
let l1_reg = parameters.alpha * parameters.l1_ratio * n_float;
let l2_reg = parameters.alpha * (T::one() - parameters.l1_ratio) * n_float;
let y_mean = TX::from_f64(y.mean_by()).unwrap();
let y_mean = y.mean();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
@@ -339,95 +186,72 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
let mut optimizer = InteriorPointOptimizer::new(&x, p);
let mut w = optimizer.optimize(
&x,
&y,
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
let mut w =
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
for i in 0..p {
w.set(i, gamma * *w.get(i) / col_std[i]);
w.set(i, 0, gamma * w.get(i, 0) / col_std[i]);
}
let mut b = TX::zero();
let mut b = T::zero();
for i in 0..p {
b += *w.get(i) * col_mean[i];
b += w.get(i, 0) * col_mean[i];
}
b = y_mean - b;
(X::from_column(&w), b)
(w, b)
} else {
let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);
let mut optimizer = InteriorPointOptimizer::new(&x, p);
let mut w = optimizer.optimize(
&x,
&y,
l1_reg * gamma,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
true,
)?;
let mut w =
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
for i in 0..p {
w.set(i, gamma * *w.get(i));
w.set(i, 0, gamma * w.get(i, 0));
}
(X::from_column(&w), y_mean)
(w, y_mean)
};
Ok(ElasticNet {
intercept: Some(b),
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
intercept: b,
coefficients: w,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(self.coefficients.as_ref().unwrap());
let bias = X::fill(nrows, 1, self.intercept.unwrap());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
pub fn coefficients(&self) -> &M {
&self.coefficients
}
/// Get estimate of intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
pub fn intercept(&self) -> T {
self.intercept
}
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
for i in 0..col_std.len() {
if (col_std[i] - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
}
}
@@ -436,25 +260,25 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
Ok((scaled_x, col_mean, col_std))
}
fn augment_x_and_y(x: &X, y: &Y, l2_reg: TX) -> (X, Vec<TX>, TX) {
fn augment_x_and_y(x: &M, y: &M::RowVector, l2_reg: T) -> (M, M::RowVector, T) {
let (n, p) = x.shape();
let gamma = TX::one() / (TX::one() + l2_reg).sqrt();
let gamma = T::one() / (T::one() + l2_reg).sqrt();
let padding = gamma * l2_reg.sqrt();
let mut y2 = Vec::<TX>::zeros(n + p);
for i in 0..y.shape() {
y2.set(i, TX::from(*y.get(i)).unwrap());
let mut y2 = M::RowVector::zeros(n + p);
for i in 0..y.len() {
y2.set(i, y.get(i));
}
let mut x2 = X::zeros(n + p, p);
let mut x2 = M::zeros(n + p, p);
for j in 0..p {
for i in 0..n {
x2.set((i, j), gamma * *x.get((i, j)));
x2.set(i, j, gamma * x.get(i, j));
}
x2.set((j + n, j), padding);
x2.set(j + n, j, padding);
}
(x2, y2, gamma)
@@ -464,36 +288,10 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = ElasticNetSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn elasticnet_longley() {
let x = DenseMatrix::from_2d_array(&[
@@ -513,8 +311,7 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -538,10 +335,7 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn elasticnet_fit_predict1() {
let x = DenseMatrix::from_2d_array(&[
@@ -565,8 +359,7 @@ mod tests {
&[17.0, 1918.0, 1.4054969025700674],
&[18.0, 1929.0, 1.3271699396384906],
&[19.0, 1915.0, 1.1373332337674806],
])
.unwrap();
]);
let y: Vec<f64> = vec![
1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
@@ -605,44 +398,43 @@ mod tests {
assert!(mae_l1 < 2.0);
assert!(mae_l2 < 2.0);
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((1, 0)));
assert!(l1_model.coefficients().get((0, 0)) > l1_model.coefficients().get((2, 0)));
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(1, 0));
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: ElasticNet<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
let deserialized_lr: ElasticNet<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// assert_eq!(lr, deserialized_lr);
// }
assert_eq!(lr, deserialized_lr);
}
}
+98 -362
View File
@@ -9,7 +9,7 @@
//!
//! Lasso coefficient estimates solve the problem:
//!
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
//!
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
//! but is able to solve them with high accuracy with relatively small additional computational cost.
@@ -23,54 +23,43 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::math::num::RealNumber;
/// Lasso regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoParameters {
#[cfg_attr(feature = "serde", serde(default))]
pub struct LassoParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function.
pub alpha: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub alpha: T,
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: bool,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: f64,
#[cfg_attr(feature = "serde", serde(default))]
pub tol: T,
/// The maximum number of iterations
pub max_iter: usize,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: bool,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Lasso regressor
pub struct Lasso<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
}
impl LassoParameters {
impl<T: RealNumber> LassoParameters<T> {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: f64) -> Self {
pub fn with_alpha(mut self, alpha: T) -> Self {
self.alpha = alpha;
self
}
@@ -80,7 +69,7 @@ impl LassoParameters {
self
}
/// The tolerance for the optimization
pub fn with_tol(mut self, tol: f64) -> Self {
pub fn with_tol(mut self, tol: T) -> Self {
self.tol = tol;
self
}
@@ -89,200 +78,63 @@ impl LassoParameters {
self.max_iter = max_iter;
self
}
/// If false, force the intercept parameter (beta_0) to be zero.
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
self.fit_intercept = fit_intercept;
self
}
}
impl Default for LassoParameters {
impl<T: RealNumber> Default for LassoParameters<T> {
fn default() -> Self {
LassoParameters {
alpha: 1f64,
alpha: T::one(),
normalize: true,
tol: 1e-4,
tol: T::from_f64(1e-4).unwrap(),
max_iter: 1000,
fit_intercept: true,
}
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for Lasso<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
fn eq(&self, other: &Self) -> bool {
self.intercept == other.intercept
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
SupervisedEstimator<X, Y, LassoParameters> for Lasso<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LassoParameters<T>>
for Lasso<T, M>
{
fn new() -> Self {
Self {
coefficients: None,
intercept: None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Self, Failed> {
fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters<T>) -> Result<Self, Failed> {
Lasso::fit(x, y, parameters)
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
for Lasso<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
/// Lasso grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Controls the strength of the penalty to the loss function.
pub alpha: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
#[cfg_attr(feature = "serde", serde(default))]
/// The tolerance for the optimization
pub tol: Vec<f64>,
#[cfg_attr(feature = "serde", serde(default))]
/// The maximum number of iterations
pub max_iter: Vec<usize>,
#[cfg_attr(feature = "serde", serde(default))]
/// If false, force the intercept parameter (beta_0) to be zero.
pub fit_intercept: Vec<bool>,
}
/// Lasso grid search iterator
pub struct LassoSearchParametersIterator {
lasso_search_parameters: LassoSearchParameters,
current_alpha: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
current_fit_intercept: usize,
}
impl IntoIterator for LassoSearchParameters {
type Item = LassoParameters;
type IntoIter = LassoSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator {
lasso_search_parameters: self,
current_alpha: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
current_fit_intercept: 0,
}
}
}
impl Iterator for LassoSearchParametersIterator {
type Item = LassoParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
&& self.current_tol == self.lasso_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
{
return None;
}
let next = LassoParameters {
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
tol: self.lasso_search_parameters.tol[self.current_tol],
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
};
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
self.current_alpha = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter = 0;
self.current_fit_intercept += 1;
} else {
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
self.current_fit_intercept += 1;
}
Some(next)
}
}
impl Default for LassoSearchParameters {
fn default() -> Self {
let default_params = LassoParameters::default();
LassoSearchParameters {
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
fit_intercept: vec![default_params.fit_intercept],
}
}
}
impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Lasso<TX, TY, X, Y> {
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
/// Fits Lasso regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
pub fn fit(
x: &M,
y: &M::RowVector,
parameters: LassoParameters<T>,
) -> Result<Lasso<T, M>, Failed> {
let (n, p) = x.shape();
if n < p {
if n <= p {
return Err(Failed::fit(
"Number of rows in X should be >= number of columns in X",
));
}
if parameters.alpha < 0f64 {
if parameters.alpha < T::zero() {
return Err(Failed::fit("alpha should be >= 0"));
}
if parameters.tol <= 0f64 {
if parameters.tol <= T::zero() {
return Err(Failed::fit("tol should be > 0"));
}
@@ -290,111 +142,75 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
return Err(Failed::fit("max_iter should be > 0"));
}
if y.shape() != n {
if y.len() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let y: Vec<TX> = y.iterator(0).map(|&v| TX::from(v).unwrap()).collect();
let l1_reg = TX::from_f64(parameters.alpha * n as f64).unwrap();
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
let mut w = optimizer.optimize(
&scaled_x,
&y,
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
let mut w =
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
w[j] /= *col_std_j;
w.set(j, 0, w.get(j, 0) / *col_std_j);
}
let b = if parameters.fit_intercept {
let mut xw_mean = TX::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
xw_mean += w[i] * *col_mean_i;
}
let mut b = T::zero();
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
} else {
None
};
(X::from_column(&w), b)
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += w.get(i, 0) * *col_mean_i;
}
b = y.mean() - b;
(w, b)
} else {
let mut optimizer = InteriorPointOptimizer::new(x, p);
let w = optimizer.optimize(
x,
&y,
l1_reg,
parameters.max_iter,
TX::from_f64(parameters.tol).unwrap(),
parameters.fit_intercept,
)?;
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
(
X::from_column(&w),
if parameters.fit_intercept {
Some(TX::from_f64(y.mean_by()).unwrap())
} else {
None
},
)
(w, y.mean())
};
Ok(Lasso {
intercept: b,
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
coefficients: w,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(self.coefficients());
let bias = X::fill(nrows, 1, self.intercept.unwrap());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
pub fn coefficients(&self) -> &M {
&self.coefficients
}
/// Get estimate of intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
pub fn intercept(&self) -> T {
self.intercept
}
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
if (*col_std_i - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
}
}
@@ -407,37 +223,12 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::Array;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn search_parameters() {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
fit_intercept: vec![false, true],
..Default::default()
};
let mut iter = parameters.clone().into_iter();
for current_fit_intercept in 0..parameters.fit_intercept.len() {
for current_max_iter in 0..parameters.max_iter.len() {
for current_alpha in 0..parameters.alpha.len() {
let next = iter.next().unwrap();
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
assert_eq!(
next.fit_intercept,
parameters.fit_intercept[current_fit_intercept]
);
}
}
}
assert!(iter.next().is_none());
}
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
fn lasso_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
@@ -455,25 +246,13 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
(x, y)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lasso_fit_predict() {
let (x, y) = get_example_x_y();
let y_hat = Lasso::fit(&x, &y, Default::default())
.and_then(|lr| lr.predict(&x))
.unwrap();
@@ -488,7 +267,6 @@ mod tests {
normalize: false,
tol: 1e-4,
max_iter: 1000,
fit_intercept: true,
},
)
.and_then(|lr| lr.predict(&x))
@@ -497,81 +275,39 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_full_rank_x() {
// x: randn(3,3) * 10, demean, then round to 2 decimal points
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
let param = LassoParameters::default()
.with_normalize(false)
.with_alpha(200.0);
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[-8.9, -2.24, 8.89],
&[-4.02, 8.89, 12.33],
&[12.92, -6.65, -21.22],
])
.unwrap();
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
let y = vec![-116.12, -75.41, 191.53];
let w = Lasso::fit(&x, &y, param)
.unwrap()
.coefficients()
.iterator(0)
.copied()
.collect();
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_fit_intercept() {
let (x, y) = get_example_x_y();
let fit_result = Lasso::fit(
&x,
&y,
LassoParameters {
alpha: 0.1,
normalize: false,
tol: 1e-8,
max_iter: 1000,
fit_intercept: false,
},
)
.unwrap();
let w = fit_result.coefficients().iterator(0).copied().collect();
// by sklearn LassoLars. coordinate descent doesn't converge well
let expected_w = vec![
0.18335684,
0.02106526,
0.00703214,
-1.35952542,
0.09295222,
0.,
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
assert_eq!(fit_result.intercept, None);
let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
let deserialized_lr: Lasso<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
assert_eq!(lr, deserialized_lr);
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let (x, y) = get_lasso_sample_x_y();
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// assert_eq!(lr, deserialized_lr);
// }
}
+78 -75
View File
@@ -12,22 +12,21 @@
//!
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1, MutArray, MutArrayView1};
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::linear::bg_solver::BiconjugateGradientSolver;
use crate::numbers::floatnum::FloatNumber;
use crate::math::num::RealNumber;
/// Interior Point Optimizer
pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
ata: X,
pub struct InteriorPointOptimizer<T: RealNumber, M: Matrix<T>> {
ata: M,
d1: Vec<T>,
d2: Vec<T>,
prb: Vec<T>,
prs: Vec<T>,
}
impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
/// Initialize a new Interior Point Optimizer
pub fn new(a: &X, n: usize) -> InteriorPointOptimizer<T, X> {
impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
pub fn new(a: &M, n: usize) -> InteriorPointOptimizer<T, M> {
InteriorPointOptimizer {
ata: a.ab(true, a, false),
d1: vec![T::zero(); n],
@@ -37,23 +36,20 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
}
}
/// Run the optimization
pub fn optimize(
&mut self,
x: &X,
y: &Vec<T>,
x: &M,
y: &M::RowVector,
lambda: T,
max_iter: usize,
tol: T,
fit_intercept: bool,
) -> Result<Vec<T>, Failed> {
) -> Result<M, Failed> {
let (n, p) = x.shape();
let p_f64 = T::from_usize(p).unwrap();
let lambda = lambda.max(T::epsilon());
//parameters
let max_ls_iter = 100;
let pcgmaxi = 5000;
let min_pcgtol = T::from_f64(0.1).unwrap();
let eta = T::from_f64(1E-3).unwrap();
@@ -62,56 +58,50 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
let gamma = T::from_f64(-0.25).unwrap();
let mu = T::two();
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
let y = if fit_intercept {
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
} else {
y.to_owned()
};
let y = M::from_row_vector(y.sub_scalar(y.mean())).transpose();
let mut max_ls_iter = 100;
let mut pitr = 0;
let mut w = Vec::zeros(p);
let mut w = M::zeros(p, 1);
let mut neww = w.clone();
let mut u = Vec::ones(p);
let mut u = M::ones(p, 1);
let mut newu = u.clone();
let mut f = X::fill(p, 2, -T::one());
let mut f = M::fill(p, 2, -T::one());
let mut newf = f.clone();
let mut q1 = vec![T::zero(); p];
let mut q2 = vec![T::zero(); p];
let mut dx = Vec::zeros(p);
let mut du = Vec::zeros(p);
let mut dxu = Vec::zeros(2 * p);
let mut grad = Vec::zeros(2 * p);
let mut dx = M::zeros(p, 1);
let mut du = M::zeros(p, 1);
let mut dxu = M::zeros(2 * p, 1);
let mut grad = M::zeros(2 * p, 1);
let mut nu = Vec::zeros(n);
let mut nu = M::zeros(n, 1);
let mut dobj = T::zero();
let mut s = T::infinity();
let mut t = T::one()
.max(T::one() / lambda)
.min(T::two() * p_f64 / T::from(1e-3).unwrap());
let lambda_f64 = lambda.to_f64().unwrap();
for ntiter in 0..max_iter {
let mut z = w.xa(true, x);
let mut z = x.matmul(&w);
for i in 0..n {
z[i] -= y[i];
nu[i] = T::two() * z[i];
z.set(i, 0, z.get(i, 0) - y.get(i, 0));
nu.set(i, 0, T::two() * z.get(i, 0));
}
// CALCULATE DUALITY GAP
let xnu = nu.xa(false, x);
let max_xnu = xnu.norm(f64::INFINITY);
if max_xnu > lambda_f64 {
let lnu = T::from_f64(lambda_f64 / max_xnu).unwrap();
let xnu = x.ab(true, &nu, false);
let max_xnu = xnu.norm(T::infinity());
if max_xnu > lambda {
let lnu = lambda / max_xnu;
nu.mul_scalar_mut(lnu);
}
let pobj = z.dot(&z) + lambda * T::from_f64(w.norm(1f64)).unwrap();
let pobj = z.dot(&z) + lambda * w.norm(T::one());
dobj = dobj.max(gamma * nu.dot(&nu) - nu.dot(&y));
let gap = pobj - dobj;
@@ -128,22 +118,22 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
// CALCULATE NEWTON STEP
for i in 0..p {
let q1i = T::one() / (u[i] + w[i]);
let q2i = T::one() / (u[i] - w[i]);
let q1i = T::one() / (u.get(i, 0) + w.get(i, 0));
let q2i = T::one() / (u.get(i, 0) - w.get(i, 0));
q1[i] = q1i;
q2[i] = q2i;
self.d1[i] = (q1i * q1i + q2i * q2i) / t;
self.d2[i] = (q1i * q1i - q2i * q2i) / t;
}
let mut gradphi = z.xa(false, x);
let mut gradphi = x.ab(true, &z, false);
for i in 0..p {
let g1 = T::two() * gradphi[i] - (q1[i] - q2[i]) / t;
let g1 = T::two() * gradphi.get(i, 0) - (q1[i] - q2[i]) / t;
let g2 = lambda - (q1[i] + q2[i]) / t;
gradphi[i] = g1;
grad[i] = -g1;
grad[i + p] = -g2;
gradphi.set(i, 0, g1);
grad.set(i, 0, -g1);
grad.set(i + p, 0, -g2);
}
for i in 0..p {
@@ -151,7 +141,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
}
let normg = T::from_f64(grad.norm2()).unwrap();
let normg = grad.norm2();
let mut pcgtol = min_pcgtol.min(eta * gap / T::one().min(normg));
if ntiter != 0 && pitr == 0 {
pcgtol *= min_pcgtol;
@@ -162,31 +152,29 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
pitr = pcgmaxi;
}
dx[..p].copy_from_slice(&dxu[..p]);
du[..p].copy_from_slice(&dxu[p..(p + p)]);
for i in 0..p {
dx.set(i, 0, dxu.get(i, 0));
du.set(i, 0, dxu.get(i + p, 0));
}
// BACKTRACKING LINE SEARCH
let phi = z.dot(&z) + lambda * u.sum() - Self::sumlogneg(&f) / t;
s = T::one();
let gdx = grad.dot(&dxu);
let mut lsiter = 0;
let lsiter = 0;
while lsiter < max_ls_iter {
for i in 0..p {
neww[i] = w[i] + s * dx[i];
newu[i] = u[i] + s * du[i];
newf.set((i, 0), neww[i] - newu[i]);
newf.set((i, 1), -neww[i] - newu[i]);
neww.set(i, 0, w.get(i, 0) + s * dx.get(i, 0));
newu.set(i, 0, u.get(i, 0) + s * du.get(i, 0));
newf.set(i, 0, neww.get(i, 0) - newu.get(i, 0));
newf.set(i, 1, -neww.get(i, 0) - newu.get(i, 0));
}
if newf
.iterator(0)
.fold(T::neg_infinity(), |max, v| v.max(max))
< T::zero()
{
let mut newz = neww.xa(true, x);
if newf.max() < T::zero() {
let mut newz = x.matmul(&neww);
for i in 0..n {
newz[i] -= y[i];
newz.set(i, 0, newz.get(i, 0) - y.get(i, 0));
}
let newphi = newz.dot(&newz) + lambda * newu.sum() - Self::sumlogneg(&newf) / t;
@@ -195,7 +183,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
}
}
s = beta * s;
lsiter += 1;
max_ls_iter += 1;
}
if lsiter == max_ls_iter {
@@ -212,41 +200,56 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
Ok(w)
}
fn sumlogneg(f: &X) -> T {
fn sumlogneg(f: &M) -> T {
let (n, _) = f.shape();
let mut sum = T::zero();
for i in 0..n {
sum += (-*f.get((i, 0))).ln();
sum += (-*f.get((i, 1))).ln();
sum += (-f.get(i, 0)).ln();
sum += (-f.get(i, 1)).ln();
}
sum
}
}
impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
for InteriorPointOptimizer<T, X>
impl<'a, T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M>
for InteriorPointOptimizer<T, M>
{
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
let (_, p) = a.shape();
for i in 0..p {
x[i] = (self.d1[i] * b[i] - self.d2[i] * b[i + p]) / self.prs[i];
x[i + p] = (-self.d2[i] * b[i] + self.prb[i] * b[i + p]) / self.prs[i];
x.set(
i,
0,
(self.d1[i] * b.get(i, 0) - self.d2[i] * b.get(i + p, 0)) / self.prs[i],
);
x.set(
i + p,
0,
(-self.d2[i] * b.get(i, 0) + self.prb[i] * b.get(i + p, 0)) / self.prs[i],
);
}
}
fn mat_vec_mul(&self, _: &X, x: &Vec<T>, y: &mut Vec<T>) {
fn mat_vec_mul(&self, _: &M, x: &M, y: &mut M) {
let (_, p) = self.ata.shape();
let x_slice = Vec::from_slice(x.slice(0..p).as_ref());
let atax = x_slice.xa(true, &self.ata);
let atax = self.ata.matmul(&x.slice(0..p, 0..1));
for i in 0..p {
y[i] = T::two() * atax[i] + self.d1[i] * x[i] + self.d2[i] * x[i + p];
y[i + p] = self.d2[i] * x[i] + self.d1[i] * x[i + p];
y.set(
i,
0,
T::two() * atax.get(i, 0) + self.d1[i] * x.get(i, 0) + self.d2[i] * x.get(i + p, 0),
);
y.set(
i + p,
0,
self.d2[i] * x.get(i, 0) + self.d1[i] * x.get(i + p, 0),
);
}
}
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
self.mat_vec_mul(a, x, y);
}
}
+84 -210
View File
@@ -12,14 +12,14 @@
//! \\[\hat{\beta} = (X^TX)^{-1}X^Ty \\]
//!
//! the \\((X^TX)^{-1}\\) term is both computationally expensive and numerically unstable. An alternative approach is to use a matrix decomposition to avoid this operation.
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The QR decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the QR decomposition, all matrices have an SVD decomposition.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linear::linear_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -40,7 +40,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! ]);
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -61,26 +61,21 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::qr::QRDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
QR,
#[default]
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
}
@@ -89,32 +84,17 @@ pub enum LinearRegressionSolverName {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName,
}
impl Default for LinearRegressionParameters {
fn default() -> Self {
LinearRegressionParameters {
solver: LinearRegressionSolverName::SVD,
}
}
}
/// Linear Regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
_solver: LinearRegressionSolverName,
}
impl LinearRegressionParameters {
@@ -125,134 +105,51 @@ impl LinearRegressionParameters {
}
}
/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}
/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: LinearRegressionSearchParameters,
current_solver: usize,
}
impl IntoIterator for LinearRegressionSearchParameters {
type Item = LinearRegressionParameters;
type IntoIter = LinearRegressionSearchParametersIterator;
fn into_iter(self) -> Self::IntoIter {
LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: self,
current_solver: 0,
}
}
}
impl Iterator for LinearRegressionSearchParametersIterator {
type Item = LinearRegressionParameters;
fn next(&mut self) -> Option<Self::Item> {
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
return None;
}
let next = LinearRegressionParameters {
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
};
self.current_solver += 1;
Some(next)
}
}
impl Default for LinearRegressionSearchParameters {
impl Default for LinearRegressionParameters {
fn default() -> Self {
let default_params = LinearRegressionParameters::default();
LinearRegressionSearchParameters {
solver: vec![default_params.solver],
LinearRegressionParameters {
solver: LinearRegressionSolverName::SVD,
}
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> PartialEq for LinearRegression<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.intercept == other.intercept
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> SupervisedEstimator<X, Y, LinearRegressionParameters> for LinearRegression<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LinearRegressionParameters>
for LinearRegression<T, M>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: LinearRegressionParameters) -> Result<Self, Failed> {
fn fit(
x: &M,
y: &M::RowVector,
parameters: LinearRegressionParameters,
) -> Result<Self, Failed> {
LinearRegression::fit(x, y, parameters)
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> Predictor<X, Y> for LinearRegression<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + QRDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> LinearRegression<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
/// Fits Linear Regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &X,
y: &Y,
x: &M,
y: &M::RowVector,
parameters: LinearRegressionParameters,
) -> Result<LinearRegression<TX, TY, X, Y>, Failed> {
let b = X::from_iterator(
y.iterator(0).map(|&v| TX::from(v).unwrap()),
y.shape(),
1,
0,
);
) -> Result<LinearRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone());
let b = y_m.transpose();
let (x_nrows, num_attributes) = x.shape();
let (y_nrows, _) = b.shape();
@@ -262,77 +159,59 @@ impl<
));
}
let a = x.h_stack(&X::ones(x_nrows, 1));
let a = x.h_stack(&M::ones(x_nrows, 1));
let w = match parameters.solver {
LinearRegressionSolverName::QR => a.qr_solve_mut(b)?,
LinearRegressionSolverName::SVD => a.svd_solve_mut(b)?,
};
let weights = X::from_slice(w.slice(0..num_attributes, 0..1).as_ref());
let wights = w.slice(0..num_attributes, 0..1);
Ok(LinearRegression {
intercept: Some(*w.get((num_attributes, 0))),
coefficients: Some(weights),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
intercept: w.get(num_attributes, 0),
coefficients: wights,
_solver: parameters.solver,
})
}
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape();
let bias = X::fill(nrows, 1, *self.intercept());
let mut y_hat = x.matmul(self.coefficients());
y_hat.add_mut(&bias);
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
pub fn coefficients(&self) -> &M {
&self.coefficients
}
/// Get estimate of intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
pub fn intercept(&self) -> T {
self.intercept
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
#[test]
fn search_parameters() {
let parameters = LinearRegressionSearchParameters {
solver: vec![
LinearRegressionSolverName::QR,
LinearRegressionSolverName::SVD,
],
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ols_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
@@ -341,11 +220,11 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let y_hat_qr = LinearRegression::fit(
@@ -372,44 +251,39 @@ mod tests {
.all(|(&a, &b)| (a - b).abs() <= 5.0));
}
// TODO: serialization for the new DenseMatrix needs to be implemented
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
let lr = LinearRegression::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: LinearRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
let deserialized_lr: LinearRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// assert_eq!(lr, deserialized_lr);
// let default = LinearRegressionParameters::default();
// let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
// assert_eq!(parameters.solver, default.solver);
// }
assert_eq!(lr, deserialized_lr);
}
}
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -20,10 +20,10 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
pub mod bg_solver;
pub(crate) mod bg_solver;
pub mod elastic_net;
pub mod lasso;
pub mod lasso_optimizer;
pub(crate) mod lasso_optimizer;
pub mod linear_regression;
pub mod logistic_regression;
pub mod ridge_regression;
+95 -258
View File
@@ -12,14 +12,14 @@
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression will produce the least squares estimates.
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates will approach zero.
//!
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linear::ridge_regression::*;
//!
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
@@ -40,7 +40,7 @@
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
//! ]).unwrap();
//! ]);
//!
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
@@ -57,25 +57,21 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::linalg::traits::cholesky::CholeskyDecomposable;
use crate::linalg::traits::svd::SVDDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
use crate::linalg::BaseVector;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Eq, PartialEq, Default)]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName {
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
#[default]
Cholesky,
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
SVD,
@@ -84,7 +80,7 @@ pub enum RidgeRegressionSolverName {
/// Ridge Regression parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: Number + RealNumber> {
pub struct RidgeRegressionParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: RidgeRegressionSolverName,
/// Controls the strength of the penalty to the loss function.
@@ -94,109 +90,16 @@ pub struct RidgeRegressionParameters<T: Number + RealNumber> {
pub normalize: bool,
}
/// Ridge Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionSearchParameters<T: Number + RealNumber> {
#[cfg_attr(feature = "serde", serde(default))]
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<RidgeRegressionSolverName>,
#[cfg_attr(feature = "serde", serde(default))]
/// Regularization parameter.
pub alpha: Vec<T>,
#[cfg_attr(feature = "serde", serde(default))]
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
}
/// Ridge Regression grid search iterator
pub struct RidgeRegressionSearchParametersIterator<T: Number + RealNumber> {
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
current_solver: usize,
current_alpha: usize,
current_normalize: usize,
}
impl<T: Number + RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
type Item = RidgeRegressionParameters<T>;
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
fn into_iter(self) -> Self::IntoIter {
RidgeRegressionSearchParametersIterator {
ridge_regression_search_parameters: self,
current_solver: 0,
current_alpha: 0,
current_normalize: 0,
}
}
}
impl<T: Number + RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
type Item = RidgeRegressionParameters<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
{
return None;
}
let next = RidgeRegressionParameters {
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
};
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
self.current_alpha = 0;
self.current_solver += 1;
} else if self.current_normalize + 1
< self.ridge_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_solver = 0;
self.current_normalize += 1;
} else {
self.current_alpha += 1;
self.current_solver += 1;
self.current_normalize += 1;
}
Some(next)
}
}
impl<T: Number + RealNumber> Default for RidgeRegressionSearchParameters<T> {
fn default() -> Self {
let default_params = RidgeRegressionParameters::default();
RidgeRegressionSearchParameters {
solver: vec![default_params.solver],
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
}
}
}
/// Ridge regression
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> {
coefficients: Option<X>,
intercept: Option<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_y: PhantomData<Y>,
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M,
intercept: T,
_solver: RidgeRegressionSolverName,
}
impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
impl<T: RealNumber> RidgeRegressionParameters<T> {
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: T) -> Self {
self.alpha = alpha;
@@ -214,83 +117,51 @@ impl<T: Number + RealNumber> RidgeRegressionParameters<T> {
}
}
impl<T: Number + RealNumber> Default for RidgeRegressionParameters<T> {
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
fn default() -> Self {
RidgeRegressionParameters {
solver: RidgeRegressionSolverName::default(),
alpha: T::from_f64(1.0).unwrap(),
solver: RidgeRegressionSolverName::Cholesky,
alpha: T::one(),
normalize: true,
}
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> PartialEq for RidgeRegression<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.intercept() == other.intercept()
&& self.coefficients().shape() == other.coefficients().shape()
&& self
.coefficients()
.iterator(0)
.zip(other.coefficients().iterator(0))
.all(|(&a, &b)| (a - b).abs() <= TX::epsilon())
self.coefficients == other.coefficients
&& (self.intercept - other.intercept).abs() <= T::epsilon()
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> SupervisedEstimator<X, Y, RidgeRegressionParameters<TX>> for RidgeRegression<TX, TY, X, Y>
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, RidgeRegressionParameters<T>>
for RidgeRegression<T, M>
{
fn new() -> Self {
Self {
coefficients: Option::None,
intercept: Option::None,
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
}
}
fn fit(x: &X, y: &Y, parameters: RidgeRegressionParameters<TX>) -> Result<Self, Failed> {
fn fit(
x: &M,
y: &M::RowVector,
parameters: RidgeRegressionParameters<T>,
) -> Result<Self, Failed> {
RidgeRegression::fit(x, y, parameters)
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> Predictor<X, Y> for RidgeRegression<TX, TY, X, Y>
{
fn predict(&self, x: &X) -> Result<Y, Failed> {
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.predict(x)
}
}
impl<
TX: Number + RealNumber,
TY: Number,
X: Array2<TX> + CholeskyDecomposable<TX> + SVDDecomposable<TX>,
Y: Array1<TY>,
> RidgeRegression<TX, TY, X, Y>
{
impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
/// Fits ridge regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - target values
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
pub fn fit(
x: &X,
y: &Y,
parameters: RidgeRegressionParameters<TX>,
) -> Result<RidgeRegression<TX, TY, X, Y>, Failed> {
x: &M,
y: &M::RowVector,
parameters: RidgeRegressionParameters<T>,
) -> Result<RidgeRegression<T, M>, Failed> {
//w = inv(X^t X + alpha*Id) * X.T y
let (n, p) = x.shape();
@@ -301,16 +172,11 @@ impl<
));
}
if y.shape() != n {
if y.len() != n {
return Err(Failed::fit("Number of rows in X should = len(y)"));
}
let y_column = X::from_iterator(
y.iterator(0).map(|&v| TX::from(v).unwrap()),
y.shape(),
1,
0,
);
let y_column = M::from_row_vector(y.clone()).transpose();
let (w, b) = if parameters.normalize {
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
@@ -319,7 +185,7 @@ impl<
let mut x_t_x = x_t.matmul(&scaled_x);
for i in 0..p {
x_t_x.add_element_mut((i, i), parameters.alpha);
x_t_x.add_element_mut(i, i, parameters.alpha);
}
let mut w = match parameters.solver {
@@ -328,16 +194,16 @@ impl<
};
for (i, col_std_i) in col_std.iter().enumerate().take(p) {
w.set((i, 0), *w.get((i, 0)) / *col_std_i);
w.set(i, 0, w.get(i, 0) / *col_std_i);
}
let mut b = TX::zero();
let mut b = T::zero();
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
b += *w.get((i, 0)) * *col_mean_i;
b += w.get(i, 0) * *col_mean_i;
}
let b = TX::from_f64(y.mean_by()).unwrap() - b;
let b = y.mean() - b;
(w, b)
} else {
@@ -346,7 +212,7 @@ impl<
let mut x_t_x = x_t.matmul(x);
for i in 0..p {
x_t_x.add_element_mut((i, i), parameters.alpha);
x_t_x.add_element_mut(i, i, parameters.alpha);
}
let w = match parameters.solver {
@@ -354,32 +220,26 @@ impl<
RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
};
(w, TX::zero())
(w, T::zero())
};
Ok(RidgeRegression {
intercept: Some(b),
coefficients: Some(w),
_phantom_ty: PhantomData,
_phantom_y: PhantomData,
intercept: b,
coefficients: w,
_solver: parameters.solver,
})
}
fn rescale_x(x: &X) -> Result<(X, Vec<TX>, Vec<TX>), Failed> {
let col_mean: Vec<TX> = x
.mean_by(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
let col_std: Vec<TX> = x
.std_dev(0)
.iter()
.map(|&v| TX::from_f64(v).unwrap())
.collect();
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
let col_mean = x.mean(0);
let col_std = x.std(0);
for (i, col_std_i) in col_std.iter().enumerate() {
if (*col_std_i - TX::zero()).abs() < TX::epsilon() {
return Err(Failed::fit(&format!("Cannot rescale constant column {i}")));
if (*col_std_i - T::zero()).abs() < T::epsilon() {
return Err(Failed::fit(&format!(
"Cannot rescale constant column {}",
i
)));
}
}
@@ -390,52 +250,31 @@ impl<
/// Predict target values from `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
let (nrows, _) = x.shape();
let mut y_hat = x.matmul(self.coefficients());
y_hat.add_mut(&X::fill(nrows, 1, self.intercept.unwrap()));
Ok(Y::from_iterator(
y_hat.iterator(0).map(|&v| TY::from(v).unwrap()),
nrows,
))
let mut y_hat = x.matmul(&self.coefficients);
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
Ok(y_hat.transpose().to_row_vector())
}
/// Get estimates regression coefficients
pub fn coefficients(&self) -> &X {
self.coefficients.as_ref().unwrap()
pub fn coefficients(&self) -> &M {
&self.coefficients
}
/// Get estimate of intercept
pub fn intercept(&self) -> &TX {
self.intercept.as_ref().unwrap()
pub fn intercept(&self) -> T {
self.intercept
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;
#[test]
fn search_parameters() {
let parameters = RidgeRegressionSearchParameters {
alpha: vec![0., 1.],
..Default::default()
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().alpha, 0.);
assert_eq!(
iter.next().unwrap().solver,
RidgeRegressionSolverName::Cholesky
);
assert!(iter.next().is_none());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ridge_fit_predict() {
let x = DenseMatrix::from_2d_array(&[
@@ -455,8 +294,7 @@ mod tests {
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
])
.unwrap();
]);
let y: Vec<f64> = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
@@ -492,40 +330,39 @@ mod tests {
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
}
// TODO: implement serialization for new DenseMatrix
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
// #[test]
// #[cfg(feature = "serde")]
// fn serde() {
// let x = DenseMatrix::from_2d_array(&[
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
// ]).unwrap();
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
// let y = vec![
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
// 114.2, 115.7, 116.9,
// ];
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
// let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
// let deserialized_lr: RidgeRegression<f64, f64, DenseMatrix<f64>, Vec<f64>> =
// serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
let deserialized_lr: RidgeRegression<f64, DenseMatrix<f64>> =
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
// assert_eq!(lr, deserialized_lr);
// }
assert_eq!(lr, deserialized_lr);
}
}
+70
View File
@@ -0,0 +1,70 @@
//! # Euclidian Metric Distance
//!
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
//!
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
//!
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::euclidian::Euclidian;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l2: f64 = Euclidian{}.distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian {}
impl Euclidian {
#[inline]
pub(crate) fn squared_distance<T: RealNumber>(x: &[T], y: &[T]) -> T {
if x.len() != y.len() {
panic!("Input vector sizes are different.");
}
let mut sum = T::zero();
for i in 0..x.len() {
let d = x[i] - y[i];
sum += d * d;
}
sum
}
}
impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
Euclidian::squared_distance(x, y).sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn squared_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l2: f64 = Euclidian {}.distance(&a, &b);
assert!((l2 - 5.19615242).abs() < 1e-8);
}
}
@@ -6,13 +6,13 @@
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::hamming::Hamming;
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::hamming::Hamming;
//!
//! let a = vec![1, 0, 0, 1, 0, 0, 1];
//! let b = vec![1, 1, 0, 0, 1, 0, 1];
//!
//! let h: f64 = Hamming::new().distance(&a, &b);
//! let h: f64 = Hamming {}.distance(&a, &b);
//!
//! ```
//!
@@ -21,48 +21,30 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::math::num::RealNumber;
use super::Distance;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Hamming<T: Number> {
_t: PhantomData<T>,
}
pub struct Hamming {}
impl<T: Number> Hamming<T> {
/// instatiate the initial structure
pub fn new() -> Hamming<T> {
Hamming { _t: PhantomData }
}
}
impl<T: Number> Default for Hamming<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Hamming<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> F {
if x.len() != y.len() {
panic!("Input vector sizes are different");
}
let dist: usize = x
.iterator(0)
.zip(y.iterator(0))
.map(|(a, b)| match a != b {
true => 1,
false => 0,
})
.sum();
let mut dist = 0;
for i in 0..x.len() {
if x[i] != y[i] {
dist += 1;
}
}
dist as f64 / x.shape() as f64
F::from_i64(dist).unwrap() / F::from_usize(x.len()).unwrap()
}
}
@@ -70,16 +52,13 @@ impl<T: Number, A: ArrayView1<T>> Distance<A> for Hamming<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn hamming_distance() {
let a = vec![1, 0, 0, 1, 0, 0, 1];
let b = vec![1, 1, 0, 0, 1, 0, 1];
let h: f64 = Hamming::new().distance(&a, &b);
let h: f64 = Hamming {}.distance(&a, &b);
assert!((h - 0.42857142).abs() < 1e-8);
}
@@ -14,10 +14,9 @@
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::basic::arrays::ArrayView2;
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::mahalanobis::Mahalanobis;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::mahalanobis::Mahalanobis;
//!
//! let data = DenseMatrix::from_2d_array(&[
//! &[64., 580., 29.],
@@ -25,9 +24,9 @@
//! &[68., 590., 37.],
//! &[69., 660., 46.],
//! &[73., 600., 55.],
//! ]).unwrap();
//! ]);
//!
//! let a = data.mean_by(0);
//! let a = data.column_mean();
//! let b = vec![66., 640., 44.];
//!
//! let mahalanobis = Mahalanobis::new(&data);
@@ -43,89 +42,85 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#![allow(non_snake_case)]
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
use crate::linalg::basic::arrays::{Array, Array2, ArrayView1};
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
use crate::linalg::Matrix;
/// Mahalanobis distance.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Mahalanobis<T: Number, M: Array2<f64>> {
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
/// covariance matrix of the dataset
pub sigma: M,
/// inverse of the covariance matrix
pub sigmaInv: M,
_t: PhantomData<T>,
t: PhantomData<T>,
}
impl<T: Number, M: Array2<f64> + LUDecomposable<f64>> Mahalanobis<T, M> {
impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
/// Constructs new instance of `Mahalanobis` from given dataset
/// * `data` - a matrix of _NxM_ where _N_ is number of observations and _M_ is number of attributes
pub fn new<X: Array2<T>>(data: &X) -> Mahalanobis<T, M> {
let (_, m) = data.shape();
let mut sigma = M::zeros(m, m);
data.cov(&mut sigma);
pub fn new(data: &M) -> Mahalanobis<T, M> {
let sigma = data.cov();
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis {
sigma,
sigmaInv,
_t: PhantomData,
t: PhantomData,
}
}
/// Constructs new instance of `Mahalanobis` from given covariance matrix
/// * `cov` - a covariance matrix
pub fn new_from_covariance<X: Array2<f64> + LUDecomposable<f64>>(cov: &X) -> Mahalanobis<T, X> {
pub fn new_from_covariance(cov: &M) -> Mahalanobis<T, M> {
let sigma = cov.clone();
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
Mahalanobis {
sigma,
sigmaInv,
_t: PhantomData,
t: PhantomData,
}
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Mahalanobis<T, DenseMatrix<f64>> {
fn distance(&self, x: &A, y: &A) -> f64 {
impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
let (nrows, ncols) = self.sigma.shape();
if x.shape() != nrows {
if x.len() != nrows {
panic!(
"Array x[{}] has different dimension with Sigma[{}][{}].",
x.shape(),
x.len(),
nrows,
ncols
);
}
if y.shape() != nrows {
if y.len() != nrows {
panic!(
"Array y[{}] has different dimension with Sigma[{}][{}].",
y.shape(),
y.len(),
nrows,
ncols
);
}
let n = x.shape();
let z: Vec<f64> = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap())
.collect();
let n = x.len();
let mut z = vec![T::zero(); n];
for i in 0..n {
z[i] = x[i] - y[i];
}
// np.dot(np.dot((a-b),VI),(a-b).T)
let mut s = 0f64;
let mut s = T::zero();
for j in 0..n {
for i in 0..n {
s += *self.sigmaInv.get((i, j)) * z[i] * z[j];
s += self.sigmaInv.get(i, j) * z[i] * z[j];
}
}
@@ -136,13 +131,9 @@ impl<T: Number, A: ArrayView1<T>> Distance<A> for Mahalanobis<T, DenseMatrix<f64
#[cfg(test)]
mod tests {
use super::*;
use crate::linalg::basic::arrays::ArrayView2;
use crate::linalg::basic::matrix::DenseMatrix;
use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mahalanobis_distance() {
let data = DenseMatrix::from_2d_array(&[
@@ -151,10 +142,9 @@ mod tests {
&[68., 590., 37.],
&[69., 660., 46.],
&[73., 600., 55.],
])
.unwrap();
]);
let a = data.mean_by(0);
let a = data.column_mean();
let b = vec![66., 640., 44.];
let mahalanobis = Mahalanobis::new(&data);
+61
View File
@@ -0,0 +1,61 @@
//! # Manhattan Distance
//!
//! The Manhattan distance between two points \\(x \in ^n \\) and \\( y \in ^n \\) in n-dimensional space is the sum of the distances in each dimension.
//!
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
//!
//! Example:
//!
//! ```
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::manhattan::Manhattan;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Manhattan {}.distance(&x, &y);
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber;
use super::Distance;
/// Manhattan distance
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan {}
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() {
panic!("Input vector sizes are different");
}
let mut dist = T::zero();
for i in 0..x.len() {
dist += (x[i] - y[i]).abs();
}
dist
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn manhattan_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Manhattan {}.distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
}
}
@@ -8,14 +8,14 @@
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::minkowski::Minkowski;
//! use smartcore::math::distance::Distance;
//! use smartcore::math::distance::minkowski::Minkowski;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Minkowski::new(1).distance(&x, &y);
//! let l2: f64 = Minkowski::new(2).distance(&x, &y);
//! let l1: f64 = Minkowski { p: 1 }.distance(&x, &y);
//! let l2: f64 = Minkowski { p: 2 }.distance(&x, &y);
//!
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -23,47 +23,37 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::math::num::RealNumber;
use super::Distance;
/// Defines the Minkowski distance of order `p`
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Minkowski<T: Number> {
pub struct Minkowski {
/// order, integer
pub p: u16,
_t: PhantomData<T>,
}
impl<T: Number> Minkowski<T> {
/// instatiate the initial structure
pub fn new(p: u16) -> Minkowski<T> {
Minkowski { p, _t: PhantomData }
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Minkowski<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
fn distance(&self, x: &Vec<T>, y: &Vec<T>) -> T {
if x.len() != y.len() {
panic!("Input vector sizes are different");
}
if self.p < 1 {
panic!("p must be at least 1");
}
let p_t = self.p as f64;
let mut dist = T::zero();
let p_t = T::from_u16(self.p).unwrap();
let dist: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs().powf(p_t))
.sum();
for i in 0..x.len() {
let d = (x[i] - y[i]).abs();
dist += d.powf(p_t);
}
dist.powf(1f64 / p_t)
dist.powf(T::one() / p_t)
}
}
@@ -71,18 +61,15 @@ impl<T: Number, A: ArrayView1<T>> Distance<A> for Minkowski<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn minkowski_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Minkowski::new(1).distance(&a, &b);
let l2: f64 = Minkowski::new(2).distance(&a, &b);
let l3: f64 = Minkowski::new(3).distance(&a, &b);
let l1: f64 = Minkowski { p: 1 }.distance(&a, &b);
let l2: f64 = Minkowski { p: 2 }.distance(&a, &b);
let l3: f64 = Minkowski { p: 3 }.distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
assert!((l2 - 5.19615242).abs() < 1e-8);
@@ -95,6 +82,6 @@ mod tests {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let _: f64 = Minkowski::new(0).distance(&a, &b);
let _: f64 = Minkowski { p: 0 }.distance(&a, &b);
}
}
+65
View File
@@ -0,0 +1,65 @@
//! # Collection of Distance Functions
//!
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
//!
//! for all \\(x, y, z \in Z \\)
//!
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
pub mod euclidian;
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
pub mod hamming;
/// The Mahalanobis distance is the distance between two points in multivariate space.
pub mod mahalanobis;
/// Also known as rectilinear distance, city block distance, taxicab metric.
pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T, F: RealNumber>: Clone {
/// Calculates distance between _a_ and _b_
fn distance(&self, a: &T, b: &T) -> F;
}
/// Multitude of distance metric functions
pub struct Distances {}
impl Distances {
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
pub fn euclidian() -> euclidian::Euclidian {
euclidian::Euclidian {}
}
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
/// * `p` - function order. Should be >= 1
pub fn minkowski(p: u16) -> minkowski::Minkowski {
minkowski::Minkowski { p }
}
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
pub fn manhattan() -> manhattan::Manhattan {
manhattan::Manhattan {}
}
/// Hamming distance, see [`Hamming`](hamming/index.html)
pub fn hamming() -> hamming::Hamming {
hamming::Hamming {}
}
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
pub fn mahalanobis<T: RealNumber, M: Matrix<T>>(data: &M) -> mahalanobis::Mahalanobis<T, M> {
mahalanobis::Mahalanobis::new(data)
}
}
+4
View File
@@ -0,0 +1,4 @@
/// Multitude of distance metrics are defined here
pub mod distance;
pub mod num;
pub(crate) mod vector;
+25 -54
View File
@@ -1,18 +1,28 @@
//! # Real Number
//! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, .
//! Most algorithms in SmartCore rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, .
//! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use num_traits::Float;
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
use num_traits::{Float, FromPrimitive};
use rand::prelude::*;
use std::fmt::{Debug, Display};
use std::iter::{Product, Sum};
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
/// Defines real number
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
pub trait RealNumber: Number + Float {
pub trait RealNumber:
Float
+ FromPrimitive
+ Debug
+ Display
+ Copy
+ Sum
+ Product
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
{
/// Copy sign from `sign` - another real number
fn copysign(self, sign: Self) -> Self;
@@ -36,11 +46,8 @@ pub trait RealNumber: Number + Float {
self * self
}
/// Raw transmutation to u32
fn to_f32_bits(self) -> u32;
/// Raw transmutation to u64
fn to_f64_bits(self) -> u64;
fn to_f32_bits(self) -> u32;
}
impl RealNumber for f64 {
@@ -67,12 +74,8 @@ impl RealNumber for f64 {
}
fn rand() -> f64 {
let mut small_rng = get_rng_impl(None);
let mut rngs: Vec<SmallRng> = (0..3)
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
.collect();
rngs[0].gen::<f64>()
let mut rng = rand::thread_rng();
rng.gen()
}
fn two() -> Self {
@@ -86,10 +89,6 @@ impl RealNumber for f64 {
fn to_f32_bits(self) -> u32 {
self.to_bits() as u32
}
fn to_f64_bits(self) -> u64 {
self.to_bits()
}
}
impl RealNumber for f32 {
@@ -116,12 +115,8 @@ impl RealNumber for f32 {
}
fn rand() -> f32 {
let mut small_rng = get_rng_impl(None);
let mut rngs: Vec<SmallRng> = (0..3)
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
.collect();
rngs[0].gen::<f32>()
let mut rng = rand::thread_rng();
rng.gen()
}
fn two() -> Self {
@@ -135,41 +130,17 @@ impl RealNumber for f32 {
fn to_f32_bits(self) -> u32 {
self.to_bits()
}
fn to_f64_bits(self) -> u64 {
self.to_bits() as u64
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn sigmoid() {
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
assert_eq!(41.0.sigmoid(), 1.);
assert_eq!((-41.0).sigmoid(), 0.);
}
#[test]
fn f32_from_string() {
assert_eq!(f32::from_str("1.111111").unwrap(), 1.111111)
}
#[test]
fn f64_from_string() {
assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
}
#[test]
fn f64_rand() {
f64::rand();
}
#[test]
fn f32_rand() {
f32::rand();
}
}
+42
View File
@@ -0,0 +1,42 @@
use crate::math::num::RealNumber;
use std::collections::HashMap;
use crate::linalg::BaseVector;
pub trait RealNumberVector<T: RealNumber> {
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>);
}
impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>) {
let mut unique = self.to_vec();
unique.sort_by(|a, b| a.partial_cmp(b).unwrap());
unique.dedup();
let mut index = HashMap::with_capacity(unique.len());
for (i, u) in unique.iter().enumerate() {
index.insert(u.to_i64().unwrap(), i);
}
let mut unique_index = Vec::with_capacity(self.len());
for idx in 0..self.len() {
unique_index.push(index[&self.get(idx).to_i64().unwrap()]);
}
(unique, unique_index)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn unique_with_indices() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
assert_eq!(
(vec!(0.0, 1.0, 2.0, 4.0), vec!(0, 0, 1, 1, 2, 0, 3)),
v1.unique_with_indices()
);
}
}
+17 -62
View File
@@ -8,20 +8,10 @@
//!
//! ```
//! use smartcore::metrics::accuracy::Accuracy;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
//! let y_true: Vec<f64> = vec![0., 1., 2., 3.];
//!
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
//! ```
//! With integers:
//! ```
//! use smartcore::metrics::accuracy::Accuracy;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<i64> = vec![0, 2, 1, 3];
//! let y_true: Vec<i64> = vec![0, 1, 2, 3];
//!
//! let score: f64 = Accuracy::new().get_score( &y_true, &y_pred);
//! let score: f64 = Accuracy {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -29,53 +19,37 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use std::marker::PhantomData;
use crate::metrics::Metrics;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Accuracy metric.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Accuracy<T> {
_phantom: PhantomData<T>,
}
pub struct Accuracy {}
impl<T: Number> Metrics<T> for Accuracy<T> {
/// create a typed object to call Accuracy functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl Accuracy {
/// Function that calculated accuracy score.
/// * `y_true` - cround truth (correct) labels
/// * `y_pred` - predicted labels, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let n = y_true.shape();
let n = y_true.len();
let mut positive: i32 = 0;
let mut positive = 0;
for i in 0..n {
if *y_true.get(i) == *y_pred.get(i) {
if y_true.get(i) == y_pred.get(i) {
positive += 1;
}
}
positive as f64 / n as f64
T::from_i64(positive).unwrap() / T::from_usize(n).unwrap()
}
}
@@ -83,35 +57,16 @@ impl<T: Number> Metrics<T> for Accuracy<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn accuracy_float() {
fn accuracy() {
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
let y_true: Vec<f64> = vec![0., 1., 2., 3.];
let score1: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_pred);
let score2: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_true);
let score1: f64 = Accuracy {}.get_score(&y_pred, &y_true);
let score2: f64 = Accuracy {}.get_score(&y_true, &y_true);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn accuracy_int() {
let y_pred: Vec<i32> = vec![0, 2, 1, 3];
let y_true: Vec<i32> = vec![0, 1, 2, 3];
let score1: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_pred);
let score2: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_true);
assert_eq!(score1, 0.5);
assert_eq!(score2, 1.0);
}
}
+27 -50
View File
@@ -2,17 +2,16 @@
//! Computes the area under the receiver operating characteristic (ROC) curve that is equal to the probability that a classifier will rank a
//! randomly chosen positive instance higher than a randomly chosen negative one.
//!
//! `smartcore` calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
//! SmartCore calculates ROC AUC from Wilcoxon or Mann-Whitney U test.
//!
//! Example:
//! ```
//! use smartcore::metrics::auc::AUC;
//! use smartcore::metrics::Metrics;
//!
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//! let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
//!
//! let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
//! let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
//! ```
//!
//! ## References:
@@ -21,48 +20,32 @@
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
#![allow(non_snake_case)]
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
use crate::algorithm::sort::quick_sort::QuickArgSort;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct AUC<T> {
_phantom: PhantomData<T>,
}
pub struct AUC {}
impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
/// create a typed object to call AUC functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl AUC {
/// AUC score.
/// * `y_true` - ground truth (correct) labels.
/// * `y_pred_prob` - probability estimates, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred_prob: &dyn ArrayView1<T>) -> f64 {
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred_prob: &V) -> T {
let mut pos = T::zero();
let mut neg = T::zero();
let n = y_true.shape();
let n = y_true.len();
for i in 0..n {
if y_true.get(i) == &T::zero() {
if y_true.get(i) == T::zero() {
neg += T::one();
} else if y_true.get(i) == &T::one() {
} else if y_true.get(i) == T::one() {
pos += T::one();
} else {
panic!(
@@ -72,22 +55,21 @@ impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
}
}
let y_pred: Vec<T> =
Array1::<T>::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape());
// TODO: try to use `crate::algorithm::sort::quick_sort` here
let label_idx: Vec<usize> = y_pred.argsort();
let mut y_pred = y_pred_prob.to_vec();
let mut rank = vec![0f64; n];
let label_idx = y_pred.quick_argsort_mut();
let mut rank = vec![T::zero(); n];
let mut i = 0;
while i < n {
if i == n - 1 || y_pred.get(i) != y_pred.get(i + 1) {
rank[i] = (i + 1) as f64;
if i == n - 1 || y_pred[i] != y_pred[i + 1] {
rank[i] = T::from_usize(i + 1).unwrap();
} else {
let mut j = i + 1;
while j < n && y_pred.get(j) == y_pred.get(i) {
while j < n && y_pred[j] == y_pred[i] {
j += 1;
}
let r = (i + 1 + j) as f64 / 2f64;
let r = T::from_usize(i + 1 + j).unwrap() / T::two();
for rank_k in rank.iter_mut().take(j).skip(i) {
*rank_k = r;
}
@@ -96,16 +78,14 @@ impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
i += 1;
}
let mut auc = 0f64;
let mut auc = T::zero();
for i in 0..n {
if y_true.get(label_idx[i]) == &T::one() {
if y_true.get(label_idx[i]) == T::one() {
auc += rank[i];
}
}
let pos = pos.to_f64().unwrap();
let neg = neg.to_f64().unwrap();
(auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg)
(auc - (pos * (pos + T::one()) / T::two())) / (pos * neg)
}
}
@@ -113,17 +93,14 @@ impl<T: FloatNumber + PartialOrd> Metrics<T> for AUC<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn auc() {
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
let y_pred: Vec<f64> = vec![0.1, 0.4, 0.35, 0.8];
let score1: f64 = AUC::new().get_score(&y_true, &y_pred);
let score2: f64 = AUC::new().get_score(&y_true, &y_true);
let score1: f64 = AUC {}.get_score(&y_true, &y_pred);
let score2: f64 = AUC {}.get_score(&y_true, &y_true);
assert!((score1 - 0.75).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
+31 -79
View File
@@ -1,85 +1,41 @@
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::metrics::cluster_helpers::*;
use crate::numbers::basenum::Number;
use crate::metrics::Metrics;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Homogeneity, completeness and V-Measure scores.
pub struct HCVScore<T> {
_phantom: PhantomData<T>,
homogeneity: Option<f64>,
completeness: Option<f64>,
v_measure: Option<f64>,
}
pub struct HCVScore {}
impl<T: Number + Ord> HCVScore<T> {
/// return homogenity score
pub fn homogeneity(&self) -> Option<f64> {
self.homogeneity
}
/// return completeness score
pub fn completeness(&self) -> Option<f64> {
self.completeness
}
/// return v_measure score
pub fn v_measure(&self) -> Option<f64> {
self.v_measure
}
/// run computation for measures
pub fn compute(&mut self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) {
let entropy_c: Option<f64> = entropy(y_true);
let entropy_k: Option<f64> = entropy(y_pred);
let contingency = contingency_matrix(y_true, y_pred);
let mi = mutual_info_score(&contingency);
impl HCVScore {
/// Computes Homogeneity, completeness and V-Measure scores at once.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn get_score<T: RealNumber, V: BaseVector<T>>(
&self,
labels_true: &V,
labels_pred: &V,
) -> (T, T, T) {
let labels_true = labels_true.to_vec();
let labels_pred = labels_pred.to_vec();
let entropy_c = entropy(&labels_true);
let entropy_k = entropy(&labels_pred);
let contingency = contingency_matrix(&labels_true, &labels_pred);
let mi: T = mutual_info_score(&contingency);
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or(0f64);
let completeness = entropy_k.map(|e| mi / e).unwrap_or(0f64);
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or_else(T::one);
let completeness = entropy_k.map(|e| mi / e).unwrap_or_else(T::one);
let v_measure_score = if homogeneity + completeness == 0f64 {
0f64
let v_measure_score = if homogeneity + completeness == T::zero() {
T::zero()
} else {
2.0f64 * homogeneity * completeness / (1.0f64 * homogeneity + completeness)
T::two() * homogeneity * completeness / (T::one() * homogeneity + completeness)
};
self.homogeneity = Some(homogeneity);
self.completeness = Some(completeness);
self.v_measure = Some(v_measure_score);
}
}
impl<T: Number + Ord> Metrics<T> for HCVScore<T> {
/// create a typed object to call HCVScore functions
fn new() -> Self {
Self {
_phantom: PhantomData,
homogeneity: Option::None,
completeness: Option::None,
v_measure: Option::None,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
homogeneity: Option::None,
completeness: Option::None,
v_measure: Option::None,
}
}
/// Computes Homogeneity, completeness and V-Measure scores at once.
/// * `y_true` - ground truth class labels to be used as a reference.
/// * `y_pred` - cluster labels to evaluate.
fn get_score(&self, _y_true: &dyn ArrayView1<T>, _y_pred: &dyn ArrayView1<T>) -> f64 {
// this functions should not be used for this struct
// use homogeneity(), completeness(), v_measure()
// TODO: implement Metrics -> Result<T, Failed>
0f64
(homogeneity, completeness, v_measure_score)
}
}
@@ -87,19 +43,15 @@ impl<T: Number + Ord> Metrics<T> for HCVScore<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn homogeneity_score() {
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
let mut scores = HCVScore::new();
scores.compute(&v1, &v2);
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let scores = HCVScore {}.get_score(&v1, &v2);
assert!((0.2548 - scores.homogeneity.unwrap()).abs() < 1e-4);
assert!((0.5440 - scores.completeness.unwrap()).abs() < 1e-4);
assert!((0.3471 - scores.v_measure.unwrap()).abs() < 1e-4);
assert!((0.2548f32 - scores.0).abs() < 1e-4);
assert!((0.5440f32 - scores.1).abs() < 1e-4);
assert!((0.3471f32 - scores.2).abs() < 1e-4);
}
}
+36 -46
View File
@@ -1,12 +1,12 @@
#![allow(clippy::ptr_arg)]
use std::collections::HashMap;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector;
pub fn contingency_matrix<T: Number + Ord, V: ArrayView1<T> + ?Sized>(
labels_true: &V,
labels_pred: &V,
pub fn contingency_matrix<T: RealNumber>(
labels_true: &Vec<T>,
labels_pred: &Vec<T>,
) -> Vec<Vec<usize>> {
let (classes, class_idx) = labels_true.unique_with_indices();
let (clusters, cluster_idx) = labels_pred.unique_with_indices();
@@ -24,30 +24,28 @@ pub fn contingency_matrix<T: Number + Ord, V: ArrayView1<T> + ?Sized>(
contingency_matrix
}
pub fn entropy<T: Number + Ord, V: ArrayView1<T> + ?Sized>(data: &V) -> Option<f64> {
let mut bincounts = HashMap::with_capacity(data.shape());
pub fn entropy<T: RealNumber>(data: &[T]) -> Option<T> {
let mut bincounts = HashMap::with_capacity(data.len());
for e in data.iterator(0) {
for e in data.iter() {
let k = e.to_i64().unwrap();
bincounts.insert(k, bincounts.get(&k).unwrap_or(&0) + 1);
}
let mut entropy = 0f64;
let sum: i64 = bincounts.values().sum();
let mut entropy = T::zero();
let sum = T::from_usize(bincounts.values().sum()).unwrap();
for &c in bincounts.values() {
if c > 0 {
let pi = c as f64;
let pi_ln = pi.ln();
let sum_ln = (sum as f64).ln();
entropy -= (pi / sum as f64) * (pi_ln - sum_ln);
let pi = T::from_usize(c).unwrap();
entropy -= (pi / sum) * (pi.ln() - sum.ln());
}
}
Some(entropy)
}
pub fn mutual_info_score(contingency: &[Vec<usize>]) -> f64 {
pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
let mut contingency_sum = 0;
let mut pi = vec![0; contingency.len()];
let mut pj = vec![0; contingency[0].len()];
@@ -66,50 +64,48 @@ pub fn mutual_info_score(contingency: &[Vec<usize>]) -> f64 {
}
}
let contingency_sum = contingency_sum as f64;
let contingency_sum = T::from_usize(contingency_sum).unwrap();
let contingency_sum_ln = contingency_sum.ln();
let pi_sum: usize = pi.iter().sum();
let pj_sum: usize = pj.iter().sum();
let pi_sum_l = (pi_sum as f64).ln();
let pj_sum_l = (pj_sum as f64).ln();
let pi_sum_l = T::from_usize(pi.iter().sum()).unwrap().ln();
let pj_sum_l = T::from_usize(pj.iter().sum()).unwrap().ln();
let log_contingency_nm: Vec<f64> = nz_val.iter().map(|v| (*v as f64).ln()).collect();
let contingency_nm: Vec<f64> = nz_val
let log_contingency_nm: Vec<T> = nz_val
.iter()
.map(|v| (*v as f64) / contingency_sum)
.map(|v| T::from_usize(*v).unwrap().ln())
.collect();
let contingency_nm: Vec<T> = nz_val
.iter()
.map(|v| T::from_usize(*v).unwrap() / contingency_sum)
.collect();
let outer: Vec<usize> = nzx
.iter()
.zip(nzy.iter())
.map(|(&x, &y)| pi[x] * pj[y])
.collect();
let log_outer: Vec<f64> = outer
let log_outer: Vec<T> = outer
.iter()
.map(|&o| -(o as f64).ln() + pi_sum_l + pj_sum_l)
.map(|&o| -T::from_usize(o).unwrap().ln() + pi_sum_l + pj_sum_l)
.collect();
let mut result = 0f64;
let mut result = T::zero();
for i in 0..log_outer.len() {
result += (contingency_nm[i] * (log_contingency_nm[i] - contingency_sum_ln))
+ contingency_nm[i] * log_outer[i]
}
result.max(0f64)
result.max(T::zero())
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn contingency_matrix_test() {
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
assert_eq!(
vec!(vec!(1, 2), vec!(2, 0), vec!(1, 0), vec!(1, 0)),
@@ -117,26 +113,20 @@ mod tests {
);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn entropy_test() {
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
assert!((1.2770 - entropy(&v1).unwrap()).abs() < 1e-4);
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mutual_info_score_test() {
let v1 = vec![0, 0, 1, 1, 2, 0, 4];
let v2 = vec![1, 0, 0, 0, 0, 1, 0];
let s = mutual_info_score(&contingency_matrix(&v1, &v2));
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
let v2 = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let s: f32 = mutual_info_score(&contingency_matrix(&v1, &v2));
assert!((0.3254 - s).abs() < 1e-4);
}
-219
View File
@@ -1,219 +0,0 @@
//! # Cosine Distance Metric
//!
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
//!
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
//!
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
//! are their respective magnitudes (Euclidean norms).
//!
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
//! greater angular separation.
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::cosine::Cosine;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
/// It is defined as 1 minus the cosine similarity of the vectors.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Cosine<T> {
_t: PhantomData<T>,
}
impl<T: Number> Default for Cosine<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number> Cosine<T> {
/// Instantiate the initial structure
pub fn new() -> Cosine<T> {
Cosine { _t: PhantomData }
}
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
#[inline]
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different.");
}
// Use the built-in dot product method from ArrayView1 trait
x.dot(y).to_f64().unwrap()
}
/// Calculate the squared magnitude (norm squared) of a vector
#[inline]
#[allow(dead_code)]
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
x.iterator(0)
.map(|&a| {
let val = a.to_f64().unwrap();
val * val
})
.sum()
}
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
#[inline]
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
// Use the built-in norm2 method from ArrayView1 trait
x.norm2()
}
/// Calculate cosine similarity between two vectors
#[inline]
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
let dot_product = Self::dot_product(x, y);
let magnitude_x = Self::magnitude(x);
let magnitude_y = Self::magnitude(y);
if magnitude_x == 0.0 || magnitude_y == 0.0 {
return f64::MIN;
}
dot_product / (magnitude_x * magnitude_y)
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
let similarity = Cosine::cosine_similarity(x, y);
1.0 - similarity
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_identical_vectors() {
let a = vec![1, 2, 3];
let b = vec![1, 2, 3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 0.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_orthogonal_vectors() {
let a = vec![1, 0];
let b = vec![0, 1];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_opposite_vectors() {
let a = vec![1, 2, 3];
let b = vec![-1, -2, -3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!((dist - 2.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_general_case() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![2.0, 1.0, 3.0];
let dist: f64 = Cosine::new().distance(&a, &b);
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
let expected_dist = 1.0 - (13.0 / 14.0);
assert!((dist - expected_dist).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[should_panic(expected = "Input vector sizes are different.")]
fn cosine_distance_different_sizes() {
let a = vec![1, 2];
let b = vec![1, 2, 3];
let _dist: f64 = Cosine::new().distance(&a, &b);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_zero_vector() {
let a = vec![0, 0, 0];
let b = vec![1, 2, 3];
let dist: f64 = Cosine::new().distance(&a, &b);
assert!(dist > 1e300)
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn cosine_distance_float_precision() {
let a = vec![1.0f32, 2.0, 3.0];
let b = vec![4.0f32, 5.0, 6.0];
let dist: f64 = Cosine::new().distance(&a, &b);
// Calculate expected value manually
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
let expected_similarity = dot_product / (mag_a * mag_b);
let expected_distance = 1.0 - expected_similarity;
assert!((dist - expected_distance).abs() < 1e-6);
}
}
-92
View File
@@ -1,92 +0,0 @@
//! # Euclidian Metric Distance
//!
//! The Euclidean distance (L2) between two points \\( x \\) and \\( y \\) in n-space is defined as
//!
//! \\[ d(x, y) = \sqrt{\sum_{i=1}^n (x-y)^2} \\]
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::euclidian::Euclidian;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l2: f64 = Euclidian::new().distance(&x, &y);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian<T> {
_t: PhantomData<T>,
}
impl<T: Number> Default for Euclidian<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number> Euclidian<T> {
/// instatiate the initial structure
pub fn new() -> Euclidian<T> {
Euclidian { _t: PhantomData }
}
/// return sum of squared distances
#[inline]
pub(crate) fn squared_distance<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different.");
}
let sum: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| {
let r = a - b;
(r * r).to_f64().unwrap()
})
.sum();
sum
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Euclidian<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
Euclidian::squared_distance(x, y).sqrt()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn squared_distance() {
let a = vec![1, 2, 3];
let b = vec![4, 5, 6];
let l2: f64 = Euclidian::new().distance(&a, &b);
assert!((l2 - 5.19615242).abs() < 1e-8);
}
}
-82
View File
@@ -1,82 +0,0 @@
//! # Manhattan Distance
//!
//! The Manhattan distance between two points \\(x \in ^n \\) and \\( y \in ^n \\) in n-dimensional space is the sum of the distances in each dimension.
//!
//! \\[ d(x, y) = \sum_{i=0}^n \lvert x_i - y_i \rvert \\]
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::distance::Distance;
//! use smartcore::metrics::distance::manhattan::Manhattan;
//!
//! let x = vec![1., 1.];
//! let y = vec![2., 2.];
//!
//! let l1: f64 = Manhattan::new().distance(&x, &y);
//! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use super::Distance;
/// Manhattan distance
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan<T: Number> {
_t: PhantomData<T>,
}
impl<T: Number> Manhattan<T> {
/// instatiate the initial structure
pub fn new() -> Manhattan<T> {
Manhattan { _t: PhantomData }
}
}
impl<T: Number> Default for Manhattan<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Number, A: ArrayView1<T>> Distance<A> for Manhattan<T> {
fn distance(&self, x: &A, y: &A) -> f64 {
if x.shape() != y.shape() {
panic!("Input vector sizes are different");
}
let dist: f64 = x
.iterator(0)
.zip(y.iterator(0))
.map(|(&a, &b)| (a - b).to_f64().unwrap().abs())
.sum();
dist
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn manhattan_distance() {
let a = vec![1., 2., 3.];
let b = vec![4., 5., 6.];
let l1: f64 = Manhattan::new().distance(&a, &b);
assert!((l1 - 9.0).abs() < 1e-8);
}
}
-118
View File
@@ -1,118 +0,0 @@
//! # Collection of Distance Functions
//!
//! Many algorithms in machine learning require a measure of distance between data points. Distance metric (or metric) is a function that defines a distance between a pair of point elements of a set.
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
//!
//! for all \\(x, y, z \in Z \\)
//!
//! A good distance metric helps to improve the performance of classification, clustering and information retrieval algorithms significantly.
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
/// Cosine distance
pub mod cosine;
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
pub mod euclidian;
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
pub mod hamming;
/// The Mahalanobis distance is the distance between two points in multivariate space.
pub mod mahalanobis;
/// Also known as rectilinear distance, city block distance, taxicab metric.
pub mod manhattan;
/// A generalization of both the Euclidean distance and the Manhattan distance.
pub mod minkowski;
use std::cmp::{Eq, Ordering, PartialOrd};
use crate::linalg::basic::arrays::Array2;
use crate::linalg::traits::lu::LUDecomposable;
use crate::numbers::basenum::Number;
use crate::numbers::realnum::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Distance metric, a function that calculates distance between two points
pub trait Distance<T>: Clone {
/// Calculates distance between _a_ and _b_
fn distance(&self, a: &T, b: &T) -> f64;
}
/// Multitude of distance metric functions
pub struct Distances {}
impl Distances {
/// Euclidian distance, see [`Euclidian`](euclidian/index.html)
pub fn euclidian<T: Number>() -> euclidian::Euclidian<T> {
euclidian::Euclidian::new()
}
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
/// * `p` - function order. Should be >= 1
pub fn minkowski<T: Number>(p: u16) -> minkowski::Minkowski<T> {
minkowski::Minkowski::new(p)
}
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
pub fn manhattan<T: Number>() -> manhattan::Manhattan<T> {
manhattan::Manhattan::new()
}
/// Hamming distance, see [`Hamming`](hamming/index.html)
pub fn hamming<T: Number>() -> hamming::Hamming<T> {
hamming::Hamming::new()
}
/// Mahalanobis distance, see [`Mahalanobis`](mahalanobis/index.html)
pub fn mahalanobis<T: Number, M: Array2<T>, C: Array2<f64> + LUDecomposable<f64>>(
data: &M,
) -> mahalanobis::Mahalanobis<T, C> {
mahalanobis::Mahalanobis::new(data)
}
}
///
/// ### Pairwise dissimilarities.
///
/// Representing distances as pairwise dissimilarities, so to build a
/// graph of closest neighbours. This representation can be reused for
/// different implementations
/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)).
/// The edge of the subgraph is defined by `PairwiseDistance`.
/// The calling algorithm can store a list of distances as
/// a list of these structures.
///
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy)]
pub struct PairwiseDistance<T: RealNumber> {
/// index of the vector in the original `Matrix` or list
pub node: usize,
/// index of the closest neighbor in the original `Matrix` or same list
pub neighbour: Option<usize>,
/// measure of distance, according to the algorithm distance function
/// if the distance is None, the edge has value "infinite" or max distance
/// each algorithm has to match
pub distance: Option<T>,
}
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
fn eq(&self, other: &Self) -> bool {
self.node == other.node
&& self.neighbour == other.neighbour
&& self.distance == other.distance
}
}
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
+16 -46
View File
@@ -10,71 +10,48 @@
//!
//! ```
//! use smartcore::metrics::f1::F1;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
//! let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
//!
//! let beta = 1.0; // beta default is equal 1.0 anyway
//! let score: f64 = F1::new_with(beta).get_score( &y_true, &y_pred);
//! let score: f64 = F1 {beta: 1.0}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use crate::metrics::precision::Precision;
use crate::metrics::recall::Recall;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use crate::metrics::Metrics;
/// F-measure
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct F1<T> {
pub struct F1<T: RealNumber> {
/// a positive real factor
pub beta: f64,
_phantom: PhantomData<T>,
pub beta: T,
}
impl<T: Number + RealNumber + FloatNumber> Metrics<T> for F1<T> {
fn new() -> Self {
let beta: f64 = 1f64;
Self {
beta,
_phantom: PhantomData,
}
}
/// create a typed object to call Recall functions
fn new_with(beta: f64) -> Self {
Self {
beta,
_phantom: PhantomData,
}
}
impl<T: RealNumber> F1<T> {
/// Computes F1 score
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let beta2 = self.beta * self.beta;
let p = Precision::new().get_score(y_true, y_pred);
let r = Recall::new().get_score(y_true, y_pred);
let p = Precision {}.get_score(y_true, y_pred);
let r = Recall {}.get_score(y_true, y_pred);
(1f64 + beta2) * (p * r) / ((beta2 * p) + r)
(T::one() + beta2) * (p * r) / (beta2 * p + r)
}
}
@@ -82,21 +59,14 @@ impl<T: Number + RealNumber + FloatNumber> Metrics<T> for F1<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn f1() {
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
let beta = 1.0;
let score1: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
let score2: f64 = F1::new_with(beta).get_score(&y_true, &y_true);
println!("{score1:?}");
println!("{score2:?}");
let score1: f64 = F1 { beta: 1.0 }.get_score(&y_pred, &y_true);
let score2: f64 = F1 { beta: 1.0 }.get_score(&y_true, &y_true);
assert!((score1 - 0.57142857).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
+16 -39
View File
@@ -10,65 +10,45 @@
//!
//! ```
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanAbsoluteError::new().get_score( &y_true, &y_pred);
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Absolute Error
pub struct MeanAbsoluteError<T> {
_phantom: PhantomData<T>,
}
pub struct MeanAbsoluteError {}
impl<T: Number + FloatNumber> Metrics<T> for MeanAbsoluteError<T> {
/// create a typed object to call MeanAbsoluteError functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl MeanAbsoluteError {
/// Computes mean absolute error
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let n = y_true.shape();
let mut ras: T = T::zero();
let n = y_true.len();
let mut ras = T::zero();
for i in 0..n {
let res: T = *y_true.get(i) - *y_pred.get(i);
ras += res.abs();
ras += (y_true.get(i) - y_pred.get(i)).abs();
}
ras.to_f64().unwrap() / n as f64
ras / T::from_usize(n).unwrap()
}
}
@@ -76,17 +56,14 @@ impl<T: Number + FloatNumber> Metrics<T> for MeanAbsoluteError<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean_absolute_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
let score2: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_true);
let score1: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
let score2: f64 = MeanAbsoluteError {}.get_score(&y_true, &y_true);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 0.0).abs() < 1e-8);
+15 -38
View File
@@ -10,65 +10,45 @@
//!
//! ```
//! use smartcore::metrics::mean_squared_error::MeanSquareError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanSquareError::new().get_score( &y_true, &y_pred);
//! let mse: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::metrics::Metrics;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Squared Error
pub struct MeanSquareError<T> {
_phantom: PhantomData<T>,
}
pub struct MeanSquareError {}
impl<T: Number + FloatNumber> Metrics<T> for MeanSquareError<T> {
/// create a typed object to call MeanSquareError functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl MeanSquareError {
/// Computes mean squared error
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let n = y_true.shape();
let n = y_true.len();
let mut rss = T::zero();
for i in 0..n {
let res = *y_true.get(i) - *y_pred.get(i);
rss += res * res;
rss += (y_true.get(i) - y_pred.get(i)).square();
}
rss.to_f64().unwrap() / n as f64
rss / T::from_usize(n).unwrap()
}
}
@@ -76,17 +56,14 @@ impl<T: Number + FloatNumber> Metrics<T> for MeanSquareError<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn mean_squared_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = MeanSquareError::new().get_score(&y_true, &y_pred);
let score2: f64 = MeanSquareError::new().get_score(&y_true, &y_true);
let score1: f64 = MeanSquareError {}.get_score(&y_pred, &y_true);
let score2: f64 = MeanSquareError {}.get_score(&y_true, &y_true);
assert!((score1 - 0.375).abs() < 1e-8);
assert!((score2 - 0.0).abs() < 1e-8);
+64 -142
View File
@@ -4,7 +4,7 @@
//! In a feedback loop you build your model first, then you get feedback from metrics, improve it and repeat until your model achieve desirable performance.
//! Evaluation metrics helps to explain the performance of a model and compare models based on an objective criterion.
//!
//! Choosing the right metric is crucial while evaluating machine learning models. In `smartcore` you will find metrics for these classes of ML models:
//! Choosing the right metric is crucial while evaluating machine learning models. In SmartCore you will find metrics for these classes of ML models:
//!
//! * [Classification metrics](struct.ClassificationMetrics.html)
//! * [Regression metrics](struct.RegressionMetrics.html)
@@ -12,7 +12,7 @@
//!
//! Example:
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::linalg::naive::dense_matrix::*;
//! use smartcore::linear::logistic_regression::LogisticRegression;
//! use smartcore::metrics::*;
//!
@@ -37,30 +37,27 @@
//! &[4.9, 2.4, 3.3, 1.0],
//! &[6.6, 2.9, 4.6, 1.3],
//! &[5.2, 2.7, 3.9, 1.4],
//! ]).unwrap();
//! let y: Vec<i8> = vec![
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
//! ]);
//! let y: Vec<f64> = vec![
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
//! ];
//!
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
//!
//! let y_hat = lr.predict(&x).unwrap();
//!
//! let acc = ClassificationMetricsOrd::accuracy().get_score(&y, &y_hat);
//! let acc = ClassificationMetrics::accuracy().get_score(&y, &y_hat);
//! // or
//! let acc = accuracy(&y, &y_hat);
//! ```
/// Accuracy score.
pub mod accuracy;
// TODO: reimplement AUC
// /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
/// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
pub mod auc;
/// Compute the homogeneity, completeness and V-Measure scores.
pub mod cluster_hcv;
pub(crate) mod cluster_helpers;
/// Multitude of distance metrics are defined here
pub mod distance;
/// F1 score, also known as balanced F-score or F-measure.
pub mod f1;
/// Mean absolute error regression loss.
@@ -74,225 +71,150 @@ pub mod r2;
/// Computes the recall.
pub mod recall;
use crate::linalg::basic::arrays::{Array1, ArrayView1};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
use crate::numbers::realnum::RealNumber;
use std::marker::PhantomData;
/// A trait to be implemented by all metrics
pub trait Metrics<T> {
/// instantiate a new Metrics trait-object
/// <https://doc.rust-lang.org/error-index.html#E0038>
fn new() -> Self
where
Self: Sized;
/// used to instantiate metric with a paramenter
fn new_with(_parameter: f64) -> Self
where
Self: Sized;
/// compute score realated to this metric
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64;
}
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Use these metrics to compare classification models.
pub struct ClassificationMetrics<T> {
phantom: PhantomData<T>,
}
/// Use these metrics to compare classification models for
/// numbers that require `Ord`.
pub struct ClassificationMetricsOrd<T> {
phantom: PhantomData<T>,
}
pub struct ClassificationMetrics {}
/// Metrics for regression models.
pub struct RegressionMetrics<T> {
phantom: PhantomData<T>,
}
pub struct RegressionMetrics {}
/// Cluster metrics.
pub struct ClusterMetrics<T> {
phantom: PhantomData<T>,
}
pub struct ClusterMetrics {}
impl ClassificationMetrics {
/// Accuracy score, see [accuracy](accuracy/index.html).
pub fn accuracy() -> accuracy::Accuracy {
accuracy::Accuracy {}
}
impl<T: Number + RealNumber + FloatNumber> ClassificationMetrics<T> {
/// Recall, see [recall](recall/index.html).
pub fn recall() -> recall::Recall<T> {
recall::Recall::new()
pub fn recall() -> recall::Recall {
recall::Recall {}
}
/// Precision, see [precision](precision/index.html).
pub fn precision() -> precision::Precision<T> {
precision::Precision::new()
pub fn precision() -> precision::Precision {
precision::Precision {}
}
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
pub fn f1(beta: f64) -> f1::F1<T> {
f1::F1::new_with(beta)
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
f1::F1 { beta }
}
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
pub fn roc_auc_score() -> auc::AUC<T> {
auc::AUC::<T>::new()
pub fn roc_auc_score() -> auc::AUC {
auc::AUC {}
}
}
impl<T: Number + Ord> ClassificationMetricsOrd<T> {
/// Accuracy score, see [accuracy](accuracy/index.html).
pub fn accuracy() -> accuracy::Accuracy<T> {
accuracy::Accuracy::new()
}
}
impl<T: Number + FloatNumber> RegressionMetrics<T> {
impl RegressionMetrics {
/// Mean squared error, see [mean squared error](mean_squared_error/index.html).
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError<T> {
mean_squared_error::MeanSquareError::new()
pub fn mean_squared_error() -> mean_squared_error::MeanSquareError {
mean_squared_error::MeanSquareError {}
}
/// Mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError<T> {
mean_absolute_error::MeanAbsoluteError::new()
pub fn mean_absolute_error() -> mean_absolute_error::MeanAbsoluteError {
mean_absolute_error::MeanAbsoluteError {}
}
/// Coefficient of determination (R2), see [R2](r2/index.html).
pub fn r2() -> r2::R2<T> {
r2::R2::<T>::new()
pub fn r2() -> r2::R2 {
r2::R2 {}
}
}
impl<T: Number + Ord> ClusterMetrics<T> {
impl ClusterMetrics {
/// Homogeneity and completeness and V-Measure scores at once.
pub fn hcv_score() -> cluster_hcv::HCVScore<T> {
cluster_hcv::HCVScore::<T>::new()
pub fn hcv_score() -> cluster_hcv::HCVScore {
cluster_hcv::HCVScore {}
}
}
/// Function that calculated accuracy score, see [accuracy](accuracy/index.html).
/// * `y_true` - cround truth (correct) labels
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn accuracy<T: Number + Ord, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
let obj = ClassificationMetricsOrd::<T>::accuracy();
obj.get_score(y_true, y_pred)
pub fn accuracy<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::accuracy().get_score(y_true, y_pred)
}
/// Calculated recall score, see [recall](recall/index.html)
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn recall<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::recall();
obj.get_score(y_true, y_pred)
pub fn recall<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::recall().get_score(y_true, y_pred)
}
/// Calculated precision score, see [precision](precision/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn precision<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::precision();
obj.get_score(y_true, y_pred)
pub fn precision<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
ClassificationMetrics::precision().get_score(y_true, y_pred)
}
/// Computes F1 score, see [F1](f1/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
pub fn f1<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
beta: f64,
) -> f64 {
let obj = ClassificationMetrics::<T>::f1(beta);
obj.get_score(y_true, y_pred)
pub fn f1<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V, beta: T) -> T {
ClassificationMetrics::f1(beta).get_score(y_true, y_pred)
}
/// AUC score, see [AUC](auc/index.html).
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred_probabilities` - probability estimates, as returned by a classifier.
pub fn roc_auc_score<
T: Number + RealNumber + FloatNumber + PartialOrd,
V: ArrayView1<T> + Array1<T> + Array1<T>,
>(
y_true: &V,
y_pred_probabilities: &V,
) -> f64 {
let obj = ClassificationMetrics::<T>::roc_auc_score();
obj.get_score(y_true, y_pred_probabilities)
pub fn roc_auc_score<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred_probabilities: &V) -> T {
ClassificationMetrics::roc_auc_score().get_score(y_true, y_pred_probabilities)
}
/// Computes mean squared error, see [mean squared error](mean_squared_error/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn mean_squared_error<T: Number + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
RegressionMetrics::<T>::mean_squared_error().get_score(y_true, y_pred)
pub fn mean_squared_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_squared_error().get_score(y_true, y_pred)
}
/// Computes mean absolute error, see [mean absolute error](mean_absolute_error/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn mean_absolute_error<T: Number + FloatNumber, V: ArrayView1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
RegressionMetrics::<T>::mean_absolute_error().get_score(y_true, y_pred)
pub fn mean_absolute_error<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::mean_absolute_error().get_score(y_true, y_pred)
}
/// Computes R2 score, see [R2](r2/index.html).
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
pub fn r2<T: Number + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> f64 {
RegressionMetrics::<T>::r2().get_score(y_true, y_pred)
pub fn r2<T: RealNumber, V: BaseVector<T>>(y_true: &V, y_pred: &V) -> T {
RegressionMetrics::r2().get_score(y_true, y_pred)
}
/// Homogeneity metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
/// A cluster result satisfies homogeneity if all of its clusters contain only data points which are members of a single class.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn homogeneity_score<
T: Number + FloatNumber + RealNumber + Ord,
V: ArrayView1<T> + Array1<T>,
>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.homogeneity().unwrap()
pub fn homogeneity_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.0
}
///
/// Completeness metric of a cluster labeling given a ground truth (range is between 0.0 and 1.0).
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn completeness_score<
T: Number + FloatNumber + RealNumber + Ord,
V: ArrayView1<T> + Array1<T>,
>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.completeness().unwrap()
pub fn completeness_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.1
}
/// The harmonic mean between homogeneity and completeness.
/// * `labels_true` - ground truth class labels to be used as a reference.
/// * `labels_pred` - cluster labels to evaluate.
pub fn v_measure_score<T: Number + FloatNumber + RealNumber + Ord, V: ArrayView1<T> + Array1<T>>(
y_true: &V,
y_pred: &V,
) -> f64 {
let mut obj = ClusterMetrics::<T>::hcv_score();
obj.compute(y_true, y_pred);
obj.v_measure().unwrap()
pub fn v_measure_score<T: RealNumber, V: BaseVector<T>>(labels_true: &V, labels_pred: &V) -> T {
ClusterMetrics::hcv_score()
.get_score(labels_true, labels_pred)
.2
}
+37 -145
View File
@@ -4,123 +4,72 @@
//!
//! \\[precision = \frac{tp}{tp + fp}\\]
//!
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
//!
//! Example:
//!
//! ```
//! use smartcore::metrics::precision::Precision;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//!
//! let score: f64 = Precision::new().get_score(&y_true, &y_pred);
//! let score: f64 = Precision {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::realnum::RealNumber;
use crate::metrics::Metrics;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Precision metric.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Precision<T> {
_phantom: PhantomData<T>,
}
pub struct Precision {}
impl<T: RealNumber> Metrics<T> for Precision<T> {
/// create a typed object to call Precision functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl Precision {
/// Calculated precision score
/// * `y_true` - ground truth (correct) labels.
/// * `y_true` - cround truth (correct) labels.
/// * `y_pred` - predicted labels, as returned by a classifier.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let n = y_true.shape();
let mut classes_set: HashSet<u64> = HashSet::new();
let mut tp = 0;
let mut p = 0;
let n = y_true.len();
for i in 0..n {
classes_set.insert(y_true.get(i).to_f64_bits());
}
let classes: usize = classes_set.len();
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_true.get(i)
);
}
if classes == 2 {
// Binary case: precision for positive class (assumed T::one())
let positive = T::one();
let mut tp: usize = 0;
let mut fp_count: usize = 0;
for i in 0..n {
let t = *y_true.get(i);
let p = *y_pred.get(i);
if p == t {
if t == positive {
tp += 1;
}
} else if t != positive {
fp_count += 1;
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
panic!(
"Precision can only be applied to binary classification: {}",
y_pred.get(i)
);
}
if y_pred.get(i) == T::one() {
p += 1;
if y_true.get(i) == T::one() {
tp += 1;
}
}
if tp + fp_count == 0 {
0.0
} else {
tp as f64 / (tp + fp_count) as f64
}
} else {
// Multiclass case: macro-averaged precision
let mut predicted: HashMap<u64, usize> = HashMap::new();
let mut tp_map: HashMap<u64, usize> = HashMap::new();
for i in 0..n {
let p_bits = y_pred.get(i).to_f64_bits();
*predicted.entry(p_bits).or_insert(0) += 1;
if *y_true.get(i) == *y_pred.get(i) {
*tp_map.entry(p_bits).or_insert(0) += 1;
}
}
let mut precision_sum = 0.0;
for &bits in &classes_set {
let pred_count = *predicted.get(&bits).unwrap_or(&0);
let tp = *tp_map.get(&bits).unwrap_or(&0);
let prec = if pred_count > 0 {
tp as f64 / pred_count as f64
} else {
0.0
};
precision_sum += prec;
}
if classes == 0 {
0.0
} else {
precision_sum / classes as f64
}
}
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
}
}
@@ -128,73 +77,16 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn precision() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1.];
let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
assert!((score1 - 0.5).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
assert!((score3 - 0.5).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass() {
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);
assert!((score1 - 0.333333333).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_imbalanced() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
let expected = (0.5 + 0.5 + 1.0) / 3.0;
assert!((score - expected).abs() < 1e-8);
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn precision_multiclass_unpredicted_class() {
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
// Class 1: pred=2, tp=1 -> 0.5
// Class 2: pred=2, tp=2 -> 1.0
// Class 3: pred=0, tp=0 -> 0.0
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
assert!((score - expected).abs() < 1e-8);
}
}
+26 -40
View File
@@ -10,70 +10,59 @@
//!
//! ```
//! use smartcore::metrics::mean_absolute_error::MeanAbsoluteError;
//! use smartcore::metrics::Metrics;
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanAbsoluteError::new().get_score( &y_true, &y_pred);
//! let mse: f64 = MeanAbsoluteError {}.get_score(&y_pred, &y_true);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::linalg::basic::arrays::ArrayView1;
use crate::numbers::basenum::Number;
use crate::metrics::Metrics;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
/// Coefficient of Determination (R2)
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct R2<T> {
_phantom: PhantomData<T>,
}
pub struct R2 {}
impl<T: Number> Metrics<T> for R2<T> {
/// create a typed object to call R2 functions
fn new() -> Self {
Self {
_phantom: PhantomData,
}
}
fn new_with(_parameter: f64) -> Self {
Self {
_phantom: PhantomData,
}
}
impl R2 {
/// Computes R2 score
/// * `y_true` - Ground truth (correct) target values.
/// * `y_pred` - Estimated target values.
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
if y_true.shape() != y_pred.shape() {
pub fn get_score<T: RealNumber, V: BaseVector<T>>(&self, y_true: &V, y_pred: &V) -> T {
if y_true.len() != y_pred.len() {
panic!(
"The vector sizes don't match: {} != {}",
y_true.shape(),
y_pred.shape()
y_true.len(),
y_pred.len()
);
}
let n = y_true.shape();
let n = y_true.len();
let mut mean = T::zero();
for i in 0..n {
mean += y_true.get(i);
}
mean /= T::from_usize(n).unwrap();
let mean: f64 = y_true.mean_by();
let mut ss_tot = T::zero();
let mut ss_res = T::zero();
for i in 0..n {
let y_i = *y_true.get(i);
let f_i = *y_pred.get(i);
ss_tot += (y_i - T::from(mean).unwrap()) * (y_i - T::from(mean).unwrap());
ss_res += (y_i - f_i) * (y_i - f_i);
let y_i = y_true.get(i);
let f_i = y_pred.get(i);
ss_tot += (y_i - mean).square();
ss_res += (y_i - f_i).square();
}
(T::one() - ss_res / ss_tot).to_f64().unwrap()
T::one() - (ss_res / ss_tot)
}
}
@@ -81,17 +70,14 @@ impl<T: Number> Metrics<T> for R2<T> {
mod tests {
use super::*;
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
)]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn r2() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];
let score1: f64 = R2::new().get_score(&y_true, &y_pred);
let score2: f64 = R2::new().get_score(&y_true, &y_true);
let score1: f64 = R2 {}.get_score(&y_true, &y_pred);
let score2: f64 = R2 {}.get_score(&y_true, &y_true);
assert!((score1 - 0.948608137).abs() < 1e-8);
assert!((score2 - 1.0).abs() < 1e-8);

Some files were not shown because too many files have changed in this diff Show More