Compare commits
226 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
765fab659c | ||
|
|
0df91706f2 | ||
|
|
2e5f88fad8 | ||
|
|
e445f0d558 | ||
|
|
4d5f64c758 | ||
|
|
28c81eb358 | ||
|
|
7f7b2edca0 | ||
|
|
d46b830bcd | ||
|
|
b6fb8191eb | ||
|
|
d305406dfd | ||
|
|
3d2f4f71fa | ||
|
|
61db4ebd90 | ||
|
|
2603a1f42b | ||
|
|
a1c56a859e | ||
|
|
d905ebea15 | ||
|
|
b482acdc8d | ||
|
|
663db0334d | ||
|
|
b4a807eb9f | ||
|
|
ff456df0a4 | ||
|
|
322610c7fb | ||
|
|
70df9a8b49 | ||
|
|
7ea620e6fd | ||
|
|
db5edcf67a | ||
|
|
8297cbe67e | ||
|
|
38c9b5ad2f | ||
|
|
820201e920 | ||
|
|
389b0e8e67 | ||
|
|
f93286ffbd | ||
|
|
12c102d02b | ||
|
|
521dab49ef | ||
|
|
3bf8813946 | ||
|
|
7830946ecb | ||
|
|
813c7ab233 | ||
|
|
4397c91570 | ||
|
|
14245e15ad | ||
|
|
d0a4ccbe20 | ||
|
|
85b9fde9a7 | ||
|
|
d239314967 | ||
|
|
4bae62ab2f | ||
|
|
e8cba343ca | ||
|
|
0b3bf946df | ||
|
|
763a8370eb | ||
|
|
1208051fb5 | ||
|
|
436d0a089f | ||
|
|
92265cc979 | ||
|
|
513d3898c9 | ||
|
|
4b654b25ac | ||
|
|
5a2e1f1262 | ||
|
|
377d5d0b06 | ||
|
|
9ce448379a | ||
|
|
c295a0d1bb | ||
|
|
703dc9688b | ||
|
|
790979a26d | ||
|
|
162bed2aa2 | ||
|
|
5ed5772a4e | ||
|
|
d9814c0918 | ||
|
|
7f44b93838 | ||
|
|
02200ae1e3 | ||
|
|
3dc5336514 | ||
|
|
abeff7926e | ||
|
|
1395cc6518 | ||
|
|
4335ee5a56 | ||
|
|
4c1dbc3327 | ||
|
|
a920959ae3 | ||
|
|
6d58dbe2a2 | ||
|
|
023b449ff1 | ||
|
|
cd44f1d515 | ||
|
|
1b42f8a396 | ||
|
|
c0be45b667 | ||
|
|
0e9c517b1a | ||
|
|
fed11f005c | ||
|
|
483a21bec0 | ||
|
|
4fb2625a33 | ||
|
|
a30802ec43 | ||
|
|
4af69878e0 | ||
|
|
745d0b570e | ||
|
|
6b5bed6092 | ||
|
|
af6ec2d402 | ||
|
|
828df4e338 | ||
|
|
374dfeceb9 | ||
|
|
3cc20fd400 | ||
|
|
700d320724 | ||
|
|
ef06f45638 | ||
|
|
237b1160b1 | ||
|
|
d31145b4fe | ||
|
|
19ff6df84c | ||
|
|
228b54baf7 | ||
|
|
03b9f76e9f | ||
|
|
a882741e12 | ||
|
|
f4b5936dcf | ||
|
|
863be5ef75 | ||
|
|
ca0816db97 | ||
|
|
2f03c1d6d7 | ||
|
|
c987d39d43 | ||
|
|
fd6b2e8014 | ||
|
|
cd5611079c | ||
|
|
dd39433ff8 | ||
|
|
3dc8a42832 | ||
|
|
3480e728af | ||
|
|
f91b1f9942 | ||
|
|
5c400f40d2 | ||
|
|
408b97d8aa | ||
|
|
6109fc5211 | ||
|
|
19088b682a | ||
|
|
244a724445 | ||
|
|
9833a2f851 | ||
|
|
68e7162fba | ||
|
|
7daf536aeb | ||
|
|
0df797cbae | ||
|
|
139bbae456 | ||
|
|
dbca6d43ce | ||
|
|
991631876e | ||
|
|
40a92ee4db | ||
|
|
87d4e9a423 | ||
|
|
bd5fbb63b1 | ||
|
|
272aabcd69 | ||
|
|
fd00bc3780 | ||
|
|
f1cf8a6f08 | ||
|
|
762986b271 | ||
|
|
e0d46f430b | ||
|
|
eb769493e7 | ||
|
|
4a941d1700 | ||
|
|
0e8166386c | ||
|
|
d91999b430 | ||
|
|
051023e4bb | ||
|
|
bb9a05b993 | ||
|
|
c5a7beaf0e | ||
|
|
9475d500db | ||
|
|
ba16c253b9 | ||
|
|
810a5c429b | ||
|
|
a69fb3aada | ||
|
|
d22be7d6ae | ||
|
|
32ae63a577 | ||
|
|
dd341f4a12 | ||
|
|
74f0d9e6fb | ||
|
|
f685f575e0 | ||
|
|
9b221979da | ||
|
|
a2be9e117f | ||
|
|
40dfca702e | ||
|
|
d8d751920b | ||
|
|
c9eb94ba93 | ||
|
|
97dece93de | ||
|
|
8ca13a76d6 | ||
|
|
5a185479a7 | ||
|
|
f76a1d1420 | ||
|
|
2c892aa603 | ||
|
|
1ce18b5296 | ||
|
|
413f1a0f55 | ||
|
|
505f495445 | ||
|
|
d39b04e549 | ||
|
|
74a7c45c75 | ||
|
|
cceb2f046d | ||
|
|
a27c29b736 | ||
|
|
78673b597f | ||
|
|
53351b2ece | ||
|
|
2650416235 | ||
|
|
f0b348dd6e | ||
|
|
4720a3a4eb | ||
|
|
c172c407d2 | ||
|
|
67e5829877 | ||
|
|
89a5136191 | ||
|
|
f9056f716a | ||
|
|
583284e66f | ||
|
|
9db993939e | ||
|
|
ad3ac49dde | ||
|
|
72e9f8293f | ||
|
|
aeddbc8a21 | ||
|
|
6587ac032b | ||
|
|
49487bccd3 | ||
|
|
900078cb04 | ||
|
|
82464f41e4 | ||
|
|
830a0d9194 | ||
|
|
f0371673a4 | ||
|
|
8f72716fe9 | ||
|
|
cc26555bfd | ||
|
|
c42fccdc22 | ||
|
|
b86c553bb1 | ||
|
|
7a4fe114d8 | ||
|
|
ca3a3a101c | ||
|
|
f46d3ba94c | ||
|
|
85d2ecd1c9 | ||
|
|
126b306681 | ||
|
|
18df9c758c | ||
|
|
d620f225ee | ||
|
|
c756496b71 | ||
|
|
3d4d5f64f6 | ||
|
|
5e887634db | ||
|
|
3c1969bdf5 | ||
|
|
0c35adf76a | ||
|
|
dd2864abe7 | ||
|
|
b780e0c289 | ||
|
|
513d916580 | ||
|
|
43584e14e5 | ||
|
|
4d75af6703 | ||
|
|
8a2da00665 | ||
|
|
54886ebd72 | ||
|
|
ea5de9758a | ||
|
|
860056c3ba | ||
|
|
8281a1620e | ||
|
|
ba03ef4678 | ||
|
|
83048dbe94 | ||
|
|
ab7f46603c | ||
|
|
4efad85f8a | ||
|
|
b8fea67fd2 | ||
|
|
6473a6c4ae | ||
|
|
7007e06c9c | ||
|
|
3732ad446c | ||
|
|
a9446c00c2 | ||
|
|
81395bcbb7 | ||
|
|
3a3f904914 | ||
|
|
797dc3c8e0 | ||
|
|
cf4f658f01 | ||
|
|
7a95378a96 | ||
|
|
1773ed0e6e | ||
|
|
bf8d0c081f | ||
|
|
aa38fc8b70 | ||
|
|
47abbbe8b6 | ||
|
|
1b9347baa1 | ||
|
|
5f2984f617 | ||
|
|
83d28dea62 | ||
|
|
5f59588eac | ||
|
|
92dad01810 | ||
|
|
20e58a8817 | ||
|
|
a2588f6f45 | ||
|
|
bb96354363 | ||
|
|
c43990e932 |
@@ -1,26 +0,0 @@
|
||||
version: 2.1
|
||||
|
||||
jobs:
|
||||
build:
|
||||
docker:
|
||||
- image: circleci/rust:latest
|
||||
environment:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
key: project-cache
|
||||
- run:
|
||||
name: Check formatting
|
||||
command: cargo fmt -- --check
|
||||
- run:
|
||||
name: Stable Build
|
||||
command: cargo build --features "nalgebra-bindings ndarray-bindings"
|
||||
- run:
|
||||
name: Test
|
||||
command: cargo test --features "nalgebra-bindings ndarray-bindings"
|
||||
- save_cache:
|
||||
key: project-cache
|
||||
paths:
|
||||
- "~/.cargo"
|
||||
- "./target"
|
||||
@@ -0,0 +1,11 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: cargo
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: daily
|
||||
open-pull-requests-limit: 10
|
||||
ignore:
|
||||
- dependency-name: rand_distr
|
||||
versions:
|
||||
- 0.4.0
|
||||
@@ -0,0 +1,57 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, development ]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
]
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
target: ${{ matrix.platform.target }}
|
||||
profile: minimal
|
||||
default: true
|
||||
- name: Install test runner for wasm
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Stable Build
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --all-features --target ${{ matrix.platform.target }}
|
||||
- name: Tests
|
||||
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --all-features
|
||||
- name: Tests in WASM
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: wasm-pack test --node -- --all-features
|
||||
@@ -0,0 +1,44 @@
|
||||
name: Coverage
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, development ]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
|
||||
jobs:
|
||||
coverage:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Cache .cargo
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
profile: minimal
|
||||
default: true
|
||||
- name: Install cargo-tarpaulin
|
||||
uses: actions-rs/install@v0.1
|
||||
with:
|
||||
crate: cargo-tarpaulin
|
||||
version: latest
|
||||
use-tool-cache: true
|
||||
- name: Run cargo-tarpaulin
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: tarpaulin
|
||||
args: --out Lcov --all-features -- --test-threads 1
|
||||
- name: Upload to codecov.io
|
||||
uses: codecov/codecov-action@v1
|
||||
with:
|
||||
fail_ci_if_error: true
|
||||
@@ -0,0 +1,41 @@
|
||||
name: Lint checks
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, development ]
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
profile: minimal
|
||||
default: true
|
||||
- run: rustup component add rustfmt
|
||||
- name: Check formt
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
- run: rustup component add clippy
|
||||
- name: Run clippy
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features -- -Drust-2018-idioms -Dwarnings
|
||||
@@ -0,0 +1,60 @@
|
||||
# Changelog
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## Added
|
||||
- L2 regularization penalty to the Logistic Regression
|
||||
- Getters for the naive bayes structs
|
||||
- One hot encoder
|
||||
- Make moons data generator
|
||||
- Support for WASM.
|
||||
|
||||
## Changed
|
||||
- Make serde optional
|
||||
|
||||
## [0.2.0] - 2021-01-03
|
||||
|
||||
### Added
|
||||
- DBSCAN
|
||||
- Epsilon-SVR, SVC
|
||||
- Ridge, Lasso, ElasticNet
|
||||
- Bernoulli, Gaussian, Categorical and Multinomial Naive Bayes
|
||||
- K-fold Cross Validation
|
||||
- Singular value decomposition
|
||||
- New api module
|
||||
- Integration with Clippy
|
||||
- Cholesky decomposition
|
||||
|
||||
### Changed
|
||||
- ndarray upgraded to 0.14
|
||||
- smartcore::error:FailedError is now non-exhaustive
|
||||
- K-Means
|
||||
- PCA
|
||||
- Random Forest
|
||||
- Linear and Logistic Regression
|
||||
- KNN
|
||||
- Decision Tree
|
||||
|
||||
## [0.1.0] - 2020-09-25
|
||||
|
||||
### Added
|
||||
- First release of smartcore.
|
||||
- KNN + distance metrics (Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis)
|
||||
- Linear Regression (OLS)
|
||||
- Logistic Regression
|
||||
- Random Forest Classifier
|
||||
- Decision Tree Classifier
|
||||
- PCA
|
||||
- K-Means
|
||||
- Integrated with ndarray
|
||||
- Abstract linear algebra methods
|
||||
- RandomForest Regressor
|
||||
- Decision Tree Regressor
|
||||
- Serde integration
|
||||
- Integrated with nalgebra
|
||||
- LU, QR, SVD, EVD
|
||||
- Evaluation Metrics
|
||||
+29
-10
@@ -2,7 +2,7 @@
|
||||
name = "smartcore"
|
||||
description = "The most advanced machine learning library in rust."
|
||||
homepage = "https://smartcorelib.org"
|
||||
version = "0.1.0"
|
||||
version = "0.2.1"
|
||||
authors = ["SmartCore Developers"]
|
||||
edition = "2018"
|
||||
license = "Apache-2.0"
|
||||
@@ -17,21 +17,40 @@ default = ["datasets"]
|
||||
ndarray-bindings = ["ndarray"]
|
||||
nalgebra-bindings = ["nalgebra"]
|
||||
datasets = []
|
||||
fp_bench = []
|
||||
|
||||
[dependencies]
|
||||
ndarray = { version = "0.13", optional = true }
|
||||
nalgebra = { version = "0.22.0", optional = true }
|
||||
num-traits = "0.2.12"
|
||||
num = "0.3.0"
|
||||
rand = "0.7.3"
|
||||
serde = { version = "1.0.115", features = ["derive"] }
|
||||
serde_derive = "1.0.115"
|
||||
ndarray = { version = "0.15", optional = true }
|
||||
nalgebra = { version = "0.31", optional = true }
|
||||
num-traits = "0.2"
|
||||
num = "0.4"
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
itertools = "0.10.3"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
smartcore = { path = ".", features = ["fp_bench"] }
|
||||
criterion = { version = "0.4", default-features = false }
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.1"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[[bench]]
|
||||
name = "distance"
|
||||
harness = false
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "naive_bayes"
|
||||
harness = false
|
||||
required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
||||
|
||||
[[bench]]
|
||||
name = "fastpair"
|
||||
harness = false
|
||||
required-features = ["fp_bench"]
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
|
||||
// to run this bench you have to change the declaraion in mod.rs ---> pub mod fastpair;
|
||||
use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
use smartcore::linalg::naive::dense_matrix::*;
|
||||
use std::time::Duration;
|
||||
|
||||
fn closest_pair_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair();
|
||||
}
|
||||
|
||||
fn closest_pair_brute_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair_brute();
|
||||
}
|
||||
|
||||
fn bench_fastpair(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("FastPair");
|
||||
|
||||
// with full samples size (100) the test will take too long
|
||||
group.significance_level(0.1).sample_size(30);
|
||||
// increase from default 5.0 secs
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
|
||||
for n_samples in [100_usize, 1000_usize].iter() {
|
||||
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"fastpair --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
|
||||
);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"brute --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
|
||||
);
|
||||
}
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_fastpair);
|
||||
criterion_main!(benches);
|
||||
@@ -0,0 +1,73 @@
|
||||
use criterion::BenchmarkId;
|
||||
use criterion::{black_box, criterion_group, criterion_main, Criterion};
|
||||
|
||||
use nalgebra::DMatrix;
|
||||
use ndarray::Array2;
|
||||
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use smartcore::linalg::BaseMatrix;
|
||||
use smartcore::linalg::BaseVector;
|
||||
use smartcore::naive_bayes::gaussian::GaussianNB;
|
||||
|
||||
pub fn gaussian_naive_bayes_fit_benchmark(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("GaussianNB::fit");
|
||||
|
||||
for n_samples in [100_usize, 1000_usize, 10000_usize].iter() {
|
||||
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||
let x = DenseMatrix::<f64>::rand(*n_samples, *n_features);
|
||||
let y: Vec<f64> = (0..*n_samples)
|
||||
.map(|i| (i % *n_samples / 5_usize) as f64)
|
||||
.collect::<Vec<f64>>();
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| {
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
pub fn gaussian_naive_matrix_datastructure(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("GaussianNB");
|
||||
let classes = (0..10000).map(|i| (i % 25) as f64).collect::<Vec<f64>>();
|
||||
|
||||
group.bench_function("DenseMatrix", |b| {
|
||||
let x = DenseMatrix::<f64>::rand(10000, 500);
|
||||
let y = <DenseMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("ndarray", |b| {
|
||||
let x = Array2::<f64>::rand(10000, 500);
|
||||
let y = <Array2<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
|
||||
group.bench_function("ndalgebra", |b| {
|
||||
let x = DMatrix::<f64>::rand(10000, 500);
|
||||
let y = <DMatrix<f64> as BaseMatrix<f64>>::RowVector::from_array(&classes);
|
||||
|
||||
b.iter(|| {
|
||||
GaussianNB::fit(black_box(&x), black_box(&y), Default::default()).unwrap();
|
||||
})
|
||||
});
|
||||
}
|
||||
criterion_group!(
|
||||
benches,
|
||||
gaussian_naive_bayes_fit_benchmark,
|
||||
gaussian_naive_matrix_datastructure
|
||||
);
|
||||
criterion_main!(benches);
|
||||
+3
-3
@@ -9,9 +9,9 @@
|
||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||
inkscape:version="1.0 (4035a4f, 2020-05-01)"
|
||||
sodipodi:docname="smartcore.svg"
|
||||
width="396.01309mm"
|
||||
height="86.286003mm"
|
||||
viewBox="0 0 396.0131 86.286004"
|
||||
width="1280"
|
||||
height="320"
|
||||
viewBox="0 0 454 86.286004"
|
||||
version="1.1"
|
||||
id="svg512">
|
||||
<metadata
|
||||
|
||||
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
@@ -44,14 +44,11 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
|
||||
let (n, _) = data.shape();
|
||||
|
||||
let mut index = vec![0; n];
|
||||
for i in 0..n {
|
||||
index[i] = i;
|
||||
}
|
||||
let index = (0..n).collect::<Vec<_>>();
|
||||
|
||||
let mut tree = BBDTree {
|
||||
nodes: nodes,
|
||||
index: index,
|
||||
nodes,
|
||||
index,
|
||||
root: 0,
|
||||
};
|
||||
|
||||
@@ -62,9 +59,9 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
tree
|
||||
}
|
||||
|
||||
pub(in crate) fn clustering(
|
||||
pub(crate) fn clustering(
|
||||
&self,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
centroids: &[Vec<T>],
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
counts: &mut Vec<usize>,
|
||||
membership: &mut Vec<usize>,
|
||||
@@ -92,8 +89,8 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
fn filter(
|
||||
&self,
|
||||
node: usize,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
candidates: &Vec<usize>,
|
||||
centroids: &[Vec<T>],
|
||||
candidates: &[usize],
|
||||
k: usize,
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
counts: &mut Vec<usize>,
|
||||
@@ -113,19 +110,19 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
}
|
||||
|
||||
if !self.nodes[node].lower.is_none() {
|
||||
if self.nodes[node].lower.is_some() {
|
||||
let mut new_candidates = vec![0; k];
|
||||
let mut newk = 0;
|
||||
|
||||
for i in 0..k {
|
||||
for candidate in candidates.iter().take(k) {
|
||||
if !BBDTree::prune(
|
||||
&self.nodes[node].center,
|
||||
&self.nodes[node].radius,
|
||||
centroids,
|
||||
closest,
|
||||
candidates[i],
|
||||
*candidate,
|
||||
) {
|
||||
new_candidates[newk] = candidates[i];
|
||||
new_candidates[newk] = *candidate;
|
||||
newk += 1;
|
||||
}
|
||||
}
|
||||
@@ -134,7 +131,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
return self.filter(
|
||||
self.nodes[node].lower.unwrap(),
|
||||
centroids,
|
||||
&mut new_candidates,
|
||||
&new_candidates,
|
||||
newk,
|
||||
sums,
|
||||
counts,
|
||||
@@ -142,7 +139,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
) + self.filter(
|
||||
self.nodes[node].upper.unwrap(),
|
||||
centroids,
|
||||
&mut new_candidates,
|
||||
&new_candidates,
|
||||
newk,
|
||||
sums,
|
||||
counts,
|
||||
@@ -152,7 +149,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
|
||||
for i in 0..d {
|
||||
sums[closest][i] = sums[closest][i] + self.nodes[node].sum[i];
|
||||
sums[closest][i] += self.nodes[node].sum[i];
|
||||
}
|
||||
|
||||
counts[closest] += self.nodes[node].count;
|
||||
@@ -166,9 +163,9 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
|
||||
fn prune(
|
||||
center: &Vec<T>,
|
||||
radius: &Vec<T>,
|
||||
centroids: &Vec<Vec<T>>,
|
||||
center: &[T],
|
||||
radius: &[T],
|
||||
centroids: &[Vec<T>],
|
||||
best_index: usize,
|
||||
test_index: usize,
|
||||
) -> bool {
|
||||
@@ -184,11 +181,11 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
let mut rhs = T::zero();
|
||||
for i in 0..d {
|
||||
let diff = test[i] - best[i];
|
||||
lhs = lhs + diff * diff;
|
||||
lhs += diff * diff;
|
||||
if diff > T::zero() {
|
||||
rhs = rhs + (center[i] + radius[i] - best[i]) * diff;
|
||||
rhs += (center[i] + radius[i] - best[i]) * diff;
|
||||
} else {
|
||||
rhs = rhs + (center[i] - radius[i] - best[i]) * diff;
|
||||
rhs += (center[i] - radius[i] - best[i]) * diff;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,7 +241,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
if end > begin + 1 {
|
||||
let len = end - begin;
|
||||
for i in 0..d {
|
||||
node.sum[i] = node.sum[i] * T::from(len).unwrap();
|
||||
node.sum[i] *= T::from(len).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,9 +258,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
let mut i2_good = data.get(self.index[i2], split_index) >= split_cutoff;
|
||||
|
||||
if !i1_good && !i2_good {
|
||||
let temp = self.index[i1];
|
||||
self.index[i1] = self.index[i2];
|
||||
self.index[i2] = temp;
|
||||
self.index.swap(i1, i2);
|
||||
i1_good = true;
|
||||
i2_good = true;
|
||||
}
|
||||
@@ -287,8 +282,8 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
}
|
||||
|
||||
let mut mean = vec![T::zero(); d];
|
||||
for i in 0..d {
|
||||
mean[i] = node.sum[i] / T::from(node.count).unwrap();
|
||||
for (i, mean_i) in mean.iter_mut().enumerate().take(d) {
|
||||
*mean_i = node.sum[i] / T::from(node.count).unwrap();
|
||||
}
|
||||
|
||||
node.cost = BBDTree::node_cost(&self.nodes[node.lower.unwrap()], &mean)
|
||||
@@ -297,12 +292,12 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
self.add_node(node)
|
||||
}
|
||||
|
||||
fn node_cost(node: &BBDTreeNode<T>, center: &Vec<T>) -> T {
|
||||
fn node_cost(node: &BBDTreeNode<T>, center: &[T]) -> T {
|
||||
let d = center.len();
|
||||
let mut scatter = T::zero();
|
||||
for i in 0..d {
|
||||
let x = (node.sum[i] / T::from(node.count).unwrap()) - center[i];
|
||||
scatter = scatter + x * x;
|
||||
for (i, center_i) in center.iter().enumerate().take(d) {
|
||||
let x = (node.sum[i] / T::from(node.count).unwrap()) - *center_i;
|
||||
scatter += x * x;
|
||||
}
|
||||
node.cost + T::from(node.count).unwrap() * scatter
|
||||
}
|
||||
@@ -319,6 +314,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn bbdtree_iris() {
|
||||
let data = DenseMatrix::from_2d_array(&[
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
//! use smartcore::algorithm::neighbour::cover_tree::*;
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct SimpleDistance {} // Our distance function
|
||||
//!
|
||||
//! impl Distance<i32, f64> for SimpleDistance {
|
||||
@@ -23,6 +24,7 @@
|
||||
//! ```
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
@@ -31,7 +33,8 @@ use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Implements Cover Tree algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||
base: F,
|
||||
inv_log_base: F,
|
||||
@@ -51,20 +54,21 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<F: RealNumber> {
|
||||
idx: usize,
|
||||
max_dist: F,
|
||||
parent_dist: F,
|
||||
children: Vec<Node<F>>,
|
||||
scale: i64,
|
||||
_scale: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug)]
|
||||
struct DistanceSet<F: RealNumber> {
|
||||
idx: usize,
|
||||
dist: Vec<F>,
|
||||
@@ -81,14 +85,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
children: Vec::new(),
|
||||
scale: 0,
|
||||
_scale: 0,
|
||||
};
|
||||
let mut tree = CoverTree {
|
||||
base: base,
|
||||
base,
|
||||
inv_log_base: F::one() / base.ln(),
|
||||
distance: distance,
|
||||
root: root,
|
||||
data: data,
|
||||
distance,
|
||||
root,
|
||||
data,
|
||||
identical_excluded: false,
|
||||
};
|
||||
|
||||
@@ -100,8 +104,8 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
/// Find k nearest neighbors of `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
if k <= 0 {
|
||||
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if k == 0 {
|
||||
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
|
||||
}
|
||||
|
||||
@@ -113,7 +117,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
let mut d = self.distance.distance(&e, p);
|
||||
let mut d = self.distance.distance(e, p);
|
||||
|
||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
@@ -147,10 +151,11 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
*heap.peek()
|
||||
};
|
||||
if d <= (upper_bound + child.max_dist) {
|
||||
if c > 0 && d < upper_bound {
|
||||
if !self.identical_excluded || self.get_data_value(child.idx) != p {
|
||||
heap.add(d);
|
||||
}
|
||||
if c > 0
|
||||
&& d < upper_bound
|
||||
&& (!self.identical_excluded || self.get_data_value(child.idx) != p)
|
||||
{
|
||||
heap.add(d);
|
||||
}
|
||||
|
||||
if !child.children.is_empty() {
|
||||
@@ -164,27 +169,84 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
current_cover_set = next_cover_set;
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F)> = Vec::new();
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
let upper_bound = *heap.peek();
|
||||
for ds in zero_set {
|
||||
if ds.0 <= upper_bound {
|
||||
let v = self.get_data_value(ds.1.idx);
|
||||
if !self.identical_excluded || v != p {
|
||||
neighbors.push((ds.1.idx, ds.0));
|
||||
neighbors.push((ds.1.idx, ds.0, v));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if neighbors.len() > k {
|
||||
neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
}
|
||||
Ok(neighbors.into_iter().take(k).collect())
|
||||
}
|
||||
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
|
||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
|
||||
let e = self.get_data_value(self.root.idx);
|
||||
let mut d = self.distance.distance(e, p);
|
||||
current_cover_set.push((d, &self.root));
|
||||
|
||||
while !current_cover_set.is_empty() {
|
||||
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||
for par in current_cover_set {
|
||||
let parent = par.1;
|
||||
for c in 0..parent.children.len() {
|
||||
let child = &parent.children[c];
|
||||
if c == 0 {
|
||||
d = par.0;
|
||||
} else {
|
||||
d = self.distance.distance(self.get_data_value(child.idx), p);
|
||||
}
|
||||
|
||||
if d <= radius + child.max_dist {
|
||||
if !child.children.is_empty() {
|
||||
next_cover_set.push((d, child));
|
||||
} else if d <= radius {
|
||||
zero_set.push((d, child));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
current_cover_set = next_cover_set;
|
||||
}
|
||||
|
||||
for ds in zero_set {
|
||||
let v = self.get_data_value(ds.1.idx);
|
||||
if !self.identical_excluded || v != p {
|
||||
neighbors.push((ds.1.idx, ds.0, v));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(neighbors)
|
||||
}
|
||||
|
||||
fn new_leaf(&self, idx: usize) -> Node<F> {
|
||||
Node {
|
||||
idx: idx,
|
||||
idx,
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
children: Vec::new(),
|
||||
scale: 100,
|
||||
_scale: 100,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,7 +290,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
if point_set.is_empty() {
|
||||
self.new_leaf(p)
|
||||
} else {
|
||||
let max_dist = self.max(&point_set);
|
||||
let max_dist = self.max(point_set);
|
||||
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
||||
if next_scale == std::i64::MIN {
|
||||
let mut children: Vec<Node<F>> = Vec::new();
|
||||
@@ -244,8 +306,8 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
idx: p,
|
||||
max_dist: F::zero(),
|
||||
parent_dist: F::zero(),
|
||||
children: children,
|
||||
scale: 100,
|
||||
children,
|
||||
_scale: 100,
|
||||
}
|
||||
} else {
|
||||
let mut far: Vec<DistanceSet<F>> = Vec::new();
|
||||
@@ -257,8 +319,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
point_set.append(&mut far);
|
||||
child
|
||||
} else {
|
||||
let mut children: Vec<Node<F>> = Vec::new();
|
||||
children.push(child);
|
||||
let mut children: Vec<Node<F>> = vec![child];
|
||||
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||
|
||||
@@ -314,8 +375,8 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
idx: p,
|
||||
max_dist: self.max(consumed_set),
|
||||
parent_dist: F::zero(),
|
||||
children: children,
|
||||
scale: (top_scale - max_scale),
|
||||
children,
|
||||
_scale: (top_scale - max_scale),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -381,14 +442,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
||||
}
|
||||
}
|
||||
|
||||
fn max(&self, distance_set: &Vec<DistanceSet<F>>) -> F {
|
||||
fn max(&self, distance_set: &[DistanceSet<F>]) -> F {
|
||||
let mut max = F::zero();
|
||||
for n in distance_set {
|
||||
if max < n.dist[n.dist.len() - 1] {
|
||||
max = n.dist[n.dist.len() - 1];
|
||||
}
|
||||
}
|
||||
return max;
|
||||
max
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +459,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
@@ -407,6 +469,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cover_tree_test() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
@@ -417,8 +480,13 @@ mod tests {
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
|
||||
assert_eq!(vec!(3, 4, 5), knn);
|
||||
}
|
||||
|
||||
let mut knn = tree.find_radius(&5, 2.0).unwrap();
|
||||
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
|
||||
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cover_tree_test1() {
|
||||
let data = vec![
|
||||
@@ -437,8 +505,9 @@ mod tests {
|
||||
|
||||
assert_eq!(vec!(0, 1, 2), knn);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
//!
|
||||
//! Dissimilarities for vector-vector distance
|
||||
//!
|
||||
//! Representing distances as pairwise dissimilarities, so to build a
|
||||
//! graph of closest neighbours. This representation can be reused for
|
||||
//! different implementations (initially used in this library for FastPair).
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
///
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distsances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// index of the vector in the original `Matrix` or list
|
||||
pub node: usize,
|
||||
|
||||
/// index of the closest neighbor in the original `Matrix` or same list
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// measure of distance, according to the algorithm distance function
|
||||
/// if the distance is None, the edge has value "infinite" or max distance
|
||||
/// each algorithm has to match
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.node == other.node
|
||||
&& self.neighbour == other.neighbour
|
||||
&& self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,570 @@
|
||||
#![allow(non_snake_case)]
|
||||
use itertools::Itertools;
|
||||
///
|
||||
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
///
|
||||
/// Reference:
|
||||
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
/// &[5.1, 3.5, 1.4, 0.2],
|
||||
/// &[4.9, 3.0, 1.4, 0.2],
|
||||
/// &[4.7, 3.2, 1.3, 0.2],
|
||||
/// &[4.6, 3.1, 1.5, 0.2],
|
||||
/// &[5.0, 3.6, 1.4, 0.2],
|
||||
/// &[5.4, 3.9, 1.7, 0.4],
|
||||
/// ]);
|
||||
/// let fastpair = FastPair::new(&x);
|
||||
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
|
||||
/// ```
|
||||
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
///
|
||||
/// Inspired by Python implementation:
|
||||
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
||||
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
|
||||
///
|
||||
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FastPair<'a, T: RealNumber, M: Matrix<T>> {
|
||||
/// initial matrix
|
||||
samples: &'a M,
|
||||
/// closest pair hashmap (connectivity matrix for closest pairs)
|
||||
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
||||
/// conga line used to keep track of the closest pair
|
||||
pub neighbours: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
///
|
||||
/// Constructor
|
||||
/// Instantiate and inizialise the algorithm
|
||||
///
|
||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||
if m.shape().0 < 3 {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"min number of rows should be 3",
|
||||
));
|
||||
}
|
||||
|
||||
let mut init = Self {
|
||||
samples: m,
|
||||
// to be computed in init(..)
|
||||
distances: HashMap::with_capacity(m.shape().0),
|
||||
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
||||
};
|
||||
init.init();
|
||||
Ok(init)
|
||||
}
|
||||
|
||||
///
|
||||
/// Initialise `FastPair` by passing a `Matrix`.
|
||||
/// Build a FastPairs data-structure from a set of (new) points.
|
||||
///
|
||||
fn init(&mut self) {
|
||||
// basic measures
|
||||
let len = self.samples.shape().0;
|
||||
let max_index = self.samples.shape().0 - 1;
|
||||
|
||||
// Store all closest neighbors
|
||||
let _distances = Box::new(HashMap::with_capacity(len));
|
||||
let _neighbours = Box::new(Vec::with_capacity(len));
|
||||
|
||||
let mut distances = *_distances;
|
||||
let mut neighbours = *_neighbours;
|
||||
|
||||
// fill neighbours with -1 values
|
||||
neighbours.extend(0..len);
|
||||
|
||||
// init closest neighbour pairwise data
|
||||
for index_row_i in 0..(max_index) {
|
||||
distances.insert(
|
||||
index_row_i,
|
||||
PairwiseDistance {
|
||||
node: index_row_i,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// loop through indeces and neighbours
|
||||
for index_row_i in 0..(len) {
|
||||
// start looking for the neighbour in the second element
|
||||
let mut index_closest = index_row_i + 1; // closest neighbour index
|
||||
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
||||
for index_row_j in (index_row_i + 1)..len {
|
||||
distances.insert(
|
||||
index_row_j,
|
||||
PairwiseDistance {
|
||||
node: index_row_j,
|
||||
neighbour: Some(index_row_i),
|
||||
distance: nbd,
|
||||
},
|
||||
);
|
||||
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row_i)),
|
||||
&(self.samples.get_row_as_vec(index_row_j)),
|
||||
);
|
||||
if d < nbd.unwrap() {
|
||||
// set this j-value to be the closest neighbour
|
||||
index_closest = index_row_j;
|
||||
nbd = Some(d);
|
||||
}
|
||||
}
|
||||
|
||||
// Add that edge
|
||||
distances.entry(index_row_i).and_modify(|e| {
|
||||
e.distance = nbd;
|
||||
e.neighbour = Some(index_closest);
|
||||
});
|
||||
}
|
||||
// No more neighbors, terminate conga line.
|
||||
// Last person on the line has no neigbors
|
||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
||||
|
||||
// compute sparse matrix (connectivity matrix)
|
||||
let mut sparse_matrix = M::zeros(len, len);
|
||||
for (_, p) in distances.iter() {
|
||||
sparse_matrix.set(p.node, p.neighbour.unwrap(), p.distance.unwrap());
|
||||
}
|
||||
|
||||
self.distances = distances;
|
||||
self.neighbours = neighbours;
|
||||
}
|
||||
|
||||
///
|
||||
/// Find closest pair by scanning list of nearest neighbors.
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||
let mut a = self.neighbours[0]; // Start with first point
|
||||
let mut d = self.distances[&a].distance;
|
||||
for p in self.neighbours.iter() {
|
||||
if self.distances[p].distance < d {
|
||||
a = *p; // Update `a` and distance `d`
|
||||
d = self.distances[p].distance;
|
||||
}
|
||||
}
|
||||
let b = self.distances[&a].neighbour;
|
||||
PairwiseDistance {
|
||||
node: a,
|
||||
neighbour: b,
|
||||
distance: d,
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
#[cfg(feature = "fp_bench")]
|
||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||
let m = self.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(pair[0])),
|
||||
&(self.samples.get_row_as_vec(pair[1])),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
}
|
||||
|
||||
//
|
||||
// Compute distances from input to all other points in data-structure.
|
||||
// input is the row index of the sample matrix
|
||||
//
|
||||
#[allow(dead_code)]
|
||||
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
|
||||
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||
for other in self.neighbours.iter() {
|
||||
if index_row != *other {
|
||||
distances.push(PairwiseDistance {
|
||||
node: index_row,
|
||||
neighbour: Some(*other),
|
||||
distance: Some(Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row)),
|
||||
&(self.samples.get_row_as_vec(*other)),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
distances
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_fastpair {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn fastpair_init() {
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let _fastpair = FastPair::new(&x);
|
||||
assert!(_fastpair.is_ok());
|
||||
|
||||
let fastpair = _fastpair.unwrap();
|
||||
|
||||
let distances = fastpair.distances;
|
||||
let neighbours = fastpair.neighbours;
|
||||
|
||||
assert!(distances.len() != 0);
|
||||
assert!(neighbours.len() != 0);
|
||||
|
||||
assert_eq!(10, neighbours.len());
|
||||
assert_eq!(10, distances.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dataset_has_at_least_three_points() {
|
||||
// Create a dataset which consists of only two points:
|
||||
// A(0.0, 0.0) and B(1.0, 1.0).
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);
|
||||
|
||||
// We expect an error when we run `FastPair` on this dataset,
|
||||
// becuase `FastPair` currently only works on a minimum of 3
|
||||
// points.
|
||||
let _fastpair = FastPair::new(&dataset);
|
||||
|
||||
match _fastpair {
|
||||
Err(e) => {
|
||||
let expected_error =
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
|
||||
assert_eq!(e, expected_error)
|
||||
}
|
||||
_ => {
|
||||
assert!(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_minimal() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
let expected_closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(1),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
|
||||
let closest_pair_brute = fastpair.closest_pair_brute();
|
||||
assert_eq!(closest_pair_brute, expected_closest_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_2() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
let expected_closest_pair = PairwiseDistance {
|
||||
node: 1,
|
||||
neighbour: Some(3),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, fastpair.closest_pair_brute());
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_new() {
|
||||
// compute
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
// unwrap results
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
// list of minimal pairwise dissimilarities
|
||||
let dissimilarities = vec![
|
||||
(
|
||||
1,
|
||||
PairwiseDistance {
|
||||
node: 1,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.030000000000000037),
|
||||
},
|
||||
),
|
||||
(
|
||||
10,
|
||||
PairwiseDistance {
|
||||
node: 10,
|
||||
neighbour: Some(12),
|
||||
distance: Some(0.07000000000000003),
|
||||
},
|
||||
),
|
||||
(
|
||||
11,
|
||||
PairwiseDistance {
|
||||
node: 11,
|
||||
neighbour: Some(14),
|
||||
distance: Some(0.18000000000000013),
|
||||
},
|
||||
),
|
||||
(
|
||||
12,
|
||||
PairwiseDistance {
|
||||
node: 12,
|
||||
neighbour: Some(14),
|
||||
distance: Some(0.34000000000000086),
|
||||
},
|
||||
),
|
||||
(
|
||||
13,
|
||||
PairwiseDistance {
|
||||
node: 13,
|
||||
neighbour: Some(14),
|
||||
distance: Some(1.6499999999999997),
|
||||
},
|
||||
),
|
||||
(
|
||||
14,
|
||||
PairwiseDistance {
|
||||
node: 14,
|
||||
neighbour: Some(14),
|
||||
distance: Some(f64::MAX),
|
||||
},
|
||||
),
|
||||
(
|
||||
6,
|
||||
PairwiseDistance {
|
||||
node: 6,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.18000000000000027),
|
||||
},
|
||||
),
|
||||
(
|
||||
0,
|
||||
PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
},
|
||||
),
|
||||
(
|
||||
8,
|
||||
PairwiseDistance {
|
||||
node: 8,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.3100000000000001),
|
||||
},
|
||||
),
|
||||
(
|
||||
2,
|
||||
PairwiseDistance {
|
||||
node: 2,
|
||||
neighbour: Some(3),
|
||||
distance: Some(0.0600000000000001),
|
||||
},
|
||||
),
|
||||
(
|
||||
3,
|
||||
PairwiseDistance {
|
||||
node: 3,
|
||||
neighbour: Some(8),
|
||||
distance: Some(0.08999999999999982),
|
||||
},
|
||||
),
|
||||
(
|
||||
7,
|
||||
PairwiseDistance {
|
||||
node: 7,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.10999999999999982),
|
||||
},
|
||||
),
|
||||
(
|
||||
9,
|
||||
PairwiseDistance {
|
||||
node: 9,
|
||||
neighbour: Some(13),
|
||||
distance: Some(8.69),
|
||||
},
|
||||
),
|
||||
(
|
||||
4,
|
||||
PairwiseDistance {
|
||||
node: 4,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.050000000000000086),
|
||||
},
|
||||
),
|
||||
(
|
||||
5,
|
||||
PairwiseDistance {
|
||||
node: 5,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.4900000000000002),
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
||||
|
||||
for i in 0..(x.shape().0 - 1) {
|
||||
let input_node = result.samples.get_row_as_vec(i);
|
||||
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
||||
let distance = Euclidian::squared_distance(
|
||||
&input_node,
|
||||
&result.samples.get_row_as_vec(input_neighbour),
|
||||
);
|
||||
|
||||
assert_eq!(i, expected.get(&i).unwrap().node);
|
||||
assert_eq!(
|
||||
input_neighbour,
|
||||
expected.get(&i).unwrap().neighbour.unwrap()
|
||||
);
|
||||
assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_closest_pair() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let dissimilarity = fastpair.unwrap().closest_pair();
|
||||
let closest = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
};
|
||||
|
||||
assert_eq!(closest, dissimilarity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_closest_pair_random_matrix() {
|
||||
let x = DenseMatrix::<f64>::rand(200, 25);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
let dissimilarity1 = result.closest_pair();
|
||||
let dissimilarity2 = result.closest_pair_brute();
|
||||
|
||||
assert_eq!(dissimilarity1, dissimilarity2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_distances() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let dissimilarities = fastpair.unwrap().distances_from(0);
|
||||
|
||||
let mut min_dissimilarity = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
distance: Some(f64::MAX),
|
||||
};
|
||||
for p in dissimilarities.iter() {
|
||||
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
|
||||
min_dissimilarity = p.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let closest = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
};
|
||||
|
||||
assert_eq!(closest, min_dissimilarity);
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@
|
||||
//! use smartcore::algorithm::neighbour::linear_search::*;
|
||||
//! use smartcore::math::distance::Distance;
|
||||
//!
|
||||
//! #[derive(Clone)]
|
||||
//! struct SimpleDistance {} // Our distance function
|
||||
//!
|
||||
//! impl Distance<i32, f64> for SimpleDistance {
|
||||
@@ -21,17 +22,19 @@
|
||||
//!
|
||||
//! ```
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::{Ordering, PartialOrd};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||
use crate::error::Failed;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
||||
distance: D,
|
||||
data: Vec<T>,
|
||||
@@ -44,8 +47,8 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
/// * `distance` - distance metric to use for searching. This function should extend [`Distance`](../../../math/distance/index.html) interface.
|
||||
pub fn new(data: Vec<T>, distance: D) -> Result<LinearKNNSearch<T, F, D>, Failed> {
|
||||
Ok(LinearKNNSearch {
|
||||
data: data,
|
||||
distance: distance,
|
||||
data,
|
||||
distance,
|
||||
f: PhantomData,
|
||||
})
|
||||
}
|
||||
@@ -53,9 +56,12 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
/// Find k nearest neighbors
|
||||
/// * `from` - look for k nearest points to `from`
|
||||
/// * `k` - the number of nearest neighbors to return
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
|
||||
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if k < 1 || k > self.data.len() {
|
||||
panic!("k should be >= 1 and <= length(data)");
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"k should be >= 1 and <= length(data)",
|
||||
));
|
||||
}
|
||||
|
||||
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
|
||||
@@ -68,7 +74,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
}
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
let d = self.distance.distance(&from, &self.data[i]);
|
||||
let d = self.distance.distance(from, &self.data[i]);
|
||||
let datum = heap.peek_mut();
|
||||
if d < datum.distance {
|
||||
datum.distance = d;
|
||||
@@ -80,9 +86,33 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
||||
Ok(heap
|
||||
.get()
|
||||
.into_iter()
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance)))
|
||||
.flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Find all nearest neighbors within radius `radius` from `p`
|
||||
/// * `p` - look for k nearest points to `p`
|
||||
/// * `radius` - radius of the search
|
||||
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
|
||||
if radius <= F::zero() {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"radius should be > 0",
|
||||
));
|
||||
}
|
||||
|
||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||
|
||||
for i in 0..self.data.len() {
|
||||
let d = self.distance.distance(from, &self.data[i]);
|
||||
|
||||
if d <= radius {
|
||||
neighbors.push((i, d, &self.data[i]));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(neighbors)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -110,6 +140,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SimpleDistance {}
|
||||
|
||||
impl Distance<i32, f64> for SimpleDistance {
|
||||
@@ -118,6 +150,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_find() {
|
||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||
@@ -130,10 +163,20 @@ mod tests {
|
||||
.iter()
|
||||
.map(|v| v.0)
|
||||
.collect();
|
||||
found_idxs1.sort();
|
||||
found_idxs1.sort_unstable();
|
||||
|
||||
assert_eq!(vec!(0, 1, 2), found_idxs1);
|
||||
|
||||
let mut found_idxs1: Vec<i32> = algorithm1
|
||||
.find_radius(&5, 3.0)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| *v.2)
|
||||
.collect();
|
||||
found_idxs1.sort_unstable();
|
||||
|
||||
assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1);
|
||||
|
||||
let data2 = vec![
|
||||
vec![1., 1.],
|
||||
vec![2., 2.],
|
||||
@@ -150,11 +193,11 @@ mod tests {
|
||||
.iter()
|
||||
.map(|v| v.0)
|
||||
.collect();
|
||||
found_idxs2.sort();
|
||||
found_idxs2.sort_unstable();
|
||||
|
||||
assert_eq!(vec!(1, 2, 3), found_idxs2);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_point_eq() {
|
||||
let point1 = KNNPoint {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::ptr_arg)]
|
||||
//! # Nearest Neighbors Search Algorithms and Data Structures
|
||||
//!
|
||||
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
|
||||
@@ -29,8 +30,75 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
||||
pub mod distances;
|
||||
/// fastpair closest neighbour algorithm
|
||||
pub mod fastpair;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
pub mod linear_search;
|
||||
|
||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KNNAlgorithmName {
|
||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||
LinearSearch,
|
||||
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
) -> Result<KNNAlgorithm<T, D>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => {
|
||||
LinearKNNSearch::new(data, distance).map(KNNAlgorithm::LinearSearch)
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => {
|
||||
CoverTree::new(data, distance).map(KNNAlgorithm::CoverTree)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find_radius(
|
||||
&self,
|
||||
from: &Vec<T>,
|
||||
radius: T,
|
||||
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ pub struct HeapSelection<T: PartialOrd + Debug> {
|
||||
heap: Vec<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
impl<T: PartialOrd + Debug> HeapSelection<T> {
|
||||
pub fn with_capacity(k: usize) -> HeapSelection<T> {
|
||||
HeapSelection {
|
||||
k: k,
|
||||
k,
|
||||
n: 0,
|
||||
sorted: false,
|
||||
heap: Vec::new(),
|
||||
@@ -41,6 +41,9 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
|
||||
pub fn heapify(&mut self) {
|
||||
let n = self.heap.len();
|
||||
if n <= 1 {
|
||||
return;
|
||||
}
|
||||
for i in (0..=(n / 2 - 1)).rev() {
|
||||
self.sift_down(i, n - 1);
|
||||
}
|
||||
@@ -48,10 +51,9 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
|
||||
pub fn peek(&self) -> &T {
|
||||
if self.sorted {
|
||||
return &self.heap[0];
|
||||
&self.heap[0]
|
||||
} else {
|
||||
&self
|
||||
.heap
|
||||
self.heap
|
||||
.iter()
|
||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||
.unwrap()
|
||||
@@ -59,11 +61,11 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
}
|
||||
|
||||
pub fn peek_mut(&mut self) -> &mut T {
|
||||
return &mut self.heap[0];
|
||||
&mut self.heap[0]
|
||||
}
|
||||
|
||||
pub fn get(self) -> Vec<T> {
|
||||
return self.heap;
|
||||
self.heap
|
||||
}
|
||||
|
||||
fn sift_down(&mut self, k: usize, n: usize) {
|
||||
@@ -93,12 +95,14 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let heap = HeapSelection::<i32>::with_capacity(3);
|
||||
assert_eq!(3, heap.k);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
@@ -116,6 +120,7 @@ mod tests {
|
||||
assert_eq!(vec![2, 0, -5], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_add1() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
@@ -130,6 +135,7 @@ mod tests {
|
||||
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_add2() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
@@ -142,6 +148,7 @@ mod tests {
|
||||
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_add_ordered() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
|
||||
@@ -113,6 +113,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn with_capacity() {
|
||||
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
//! # Common Interfaces and API
|
||||
//!
|
||||
//! This module provides interfaces and uniform API with simple conventions
|
||||
//! that are used in other modules for supervised and unsupervised learning.
|
||||
|
||||
use crate::error::Failed;
|
||||
|
||||
/// An estimator for unsupervised learning, that provides method `fit` to learn from data
|
||||
pub trait UnsupervisedEstimator<X, P> {
|
||||
/// Fit a model to a training dataset, estimate model's parameters.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `parameters` - hyperparameters of an algorithm
|
||||
fn fit(x: &X, parameters: P) -> Result<Self, Failed>
|
||||
where
|
||||
Self: Sized,
|
||||
P: Clone;
|
||||
}
|
||||
|
||||
/// An estimator for supervised learning, , that provides method `fit` to learn from data and training values
|
||||
pub trait SupervisedEstimator<X, Y, P> {
|
||||
/// Fit a model to a training dataset, estimate model's parameters.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target training values of size _N_.
|
||||
/// * `parameters` - hyperparameters of an algorithm
|
||||
fn fit(x: &X, y: &Y, parameters: P) -> Result<Self, Failed>
|
||||
where
|
||||
Self: Sized,
|
||||
P: Clone;
|
||||
}
|
||||
|
||||
/// Implements method predict that estimates target value from new data
|
||||
pub trait Predictor<X, Y> {
|
||||
/// Estimate target values from new data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
fn predict(&self, x: &X) -> Result<Y, Failed>;
|
||||
}
|
||||
|
||||
/// Implements method transform that filters or modifies input data
|
||||
pub trait Transformer<X> {
|
||||
/// Transform data by modifying or filtering it
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
fn transform(&self, x: &X) -> Result<X, Failed>;
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
//! # DBSCAN Clustering
|
||||
//!
|
||||
//! DBSCAN stands for density-based spatial clustering of applications with noise. This algorithms is good for arbitrary shaped clusters and clusters with noise.
|
||||
//! The main idea behind DBSCAN is that a point belongs to a cluster if it is close to many points from that cluster. There are two key parameters of DBSCAN:
|
||||
//!
|
||||
//! * `eps`, the maximum distance that specifies a neighborhood. Two points are considered to be neighbors if the distance between them are less than or equal to `eps`.
|
||||
//! * `min_samples`, minimum number of data points that defines a cluster.
|
||||
//!
|
||||
//! Based on these two parameters, points are classified as core point, border point, or outlier:
|
||||
//!
|
||||
//! * A point is a core point if there are at least `min_samples` number of points, including the point itself in its vicinity.
|
||||
//! * A point is a border point if it is reachable from a core point and there are less than `min_samples` number of points within its surrounding area.
|
||||
//! * All points not reachable from any other point are outliers or noise points.
|
||||
//!
|
||||
//! The algorithm starts from picking up an arbitrarily point in the dataset.
|
||||
//! If there are at least `min_samples` points within a radius of `eps` to the point then we consider all these points to be part of the same cluster.
|
||||
//! The clusters are then expanded by recursively repeating the neighborhood calculation for each neighboring point.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::cluster::dbscan::*;
|
||||
//! use smartcore::math::distance::Distances;
|
||||
//! use smartcore::neighbors::KNNAlgorithmName;
|
||||
//! use smartcore::dataset::generator;
|
||||
//!
|
||||
//! // Generate three blobs
|
||||
//! let blobs = generator::make_blobs(100, 2, 3);
|
||||
//! let x = DenseMatrix::from_vec(blobs.num_samples, blobs.num_features, &blobs.data);
|
||||
//! // Fit the algorithm and predict cluster labels
|
||||
//! let labels = DBSCAN::fit(&x, DBSCANParameters::default().with_eps(3.0)).
|
||||
//! and_then(|dbscan| dbscan.predict(&x));
|
||||
//!
|
||||
//! println!("{:?}", labels);
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["A Density-Based Algorithm for Discovering Clusters in Large Spatial Databases with Noise", Ester M., Kriegel HP., Sander J., Xu X.](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["Density-Based Clustering in Spatial Databases: The Algorithm GDBSCAN and its Applications", Sander J., Ester M., Kriegel HP., Xu X.](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.63.1629&rep=rep1&type=pdf)
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::which_max;
|
||||
|
||||
/// DBSCAN clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
cluster_labels: Vec<i16>,
|
||||
num_classes: usize,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
eps: T,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// DBSCAN clustering algorithm parameters
|
||||
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: usize,
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: T,
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance<DD: Distance<Vec<T>, T>>(self, distance: DD) -> DBSCANParameters<T, DD> {
|
||||
DBSCANParameters {
|
||||
distance,
|
||||
min_samples: self.min_samples,
|
||||
eps: self.eps,
|
||||
algorithm: self.algorithm,
|
||||
}
|
||||
}
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub fn with_min_samples(mut self, min_samples: usize) -> Self {
|
||||
self.min_samples = min_samples;
|
||||
self
|
||||
}
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub fn with_eps(mut self, eps: T) -> Self {
|
||||
self.eps = eps;
|
||||
self
|
||||
}
|
||||
/// KNN algorithm to use.
|
||||
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
|
||||
self.algorithm = algorithm;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cluster_labels.len() == other.cluster_labels.len()
|
||||
&& self.num_classes == other.num_classes
|
||||
&& self.eps == other.eps
|
||||
&& self.cluster_labels == other.cluster_labels
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
DBSCANParameters {
|
||||
distance: Distances::euclidian(),
|
||||
min_samples: 5,
|
||||
eps: T::half(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>, D: Distance<Vec<T>, T>>
|
||||
UnsupervisedEstimator<M, DBSCANParameters<T, D>> for DBSCAN<T, D>
|
||||
{
|
||||
fn fit(x: &M, parameters: DBSCANParameters<T, D>) -> Result<Self, Failed> {
|
||||
DBSCAN::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for DBSCAN<T, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `k` - number of clusters
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
parameters: DBSCANParameters<T, D>,
|
||||
) -> Result<DBSCAN<T, D>, Failed> {
|
||||
if parameters.min_samples < 1 {
|
||||
return Err(Failed::fit("Invalid minPts"));
|
||||
}
|
||||
|
||||
if parameters.eps <= T::zero() {
|
||||
return Err(Failed::fit("Invalid radius: "));
|
||||
}
|
||||
|
||||
let mut k = 0;
|
||||
let queued = -2;
|
||||
let outlier = -1;
|
||||
let undefined = -3;
|
||||
|
||||
let n = x.shape().0;
|
||||
let mut y = vec![undefined; n];
|
||||
|
||||
let algo = parameters
|
||||
.algorithm
|
||||
.fit(row_iter(x).collect(), parameters.distance)?;
|
||||
|
||||
for (i, e) in row_iter(x).enumerate() {
|
||||
if y[i] == undefined {
|
||||
let mut neighbors = algo.find_radius(&e, parameters.eps)?;
|
||||
if neighbors.len() < parameters.min_samples {
|
||||
y[i] = outlier;
|
||||
} else {
|
||||
y[i] = k;
|
||||
|
||||
for j in 0..neighbors.len() {
|
||||
if y[neighbors[j].0] == undefined {
|
||||
y[neighbors[j].0] = queued;
|
||||
}
|
||||
}
|
||||
|
||||
while !neighbors.is_empty() {
|
||||
let neighbor = neighbors.pop().unwrap();
|
||||
let index = neighbor.0;
|
||||
|
||||
if y[index] == outlier {
|
||||
y[index] = k;
|
||||
}
|
||||
|
||||
if y[index] == undefined || y[index] == queued {
|
||||
y[index] = k;
|
||||
|
||||
let secondary_neighbors =
|
||||
algo.find_radius(neighbor.2, parameters.eps)?;
|
||||
|
||||
if secondary_neighbors.len() >= parameters.min_samples {
|
||||
for j in 0..secondary_neighbors.len() {
|
||||
let label = y[secondary_neighbors[j].0];
|
||||
if label == undefined {
|
||||
y[secondary_neighbors[j].0] = queued;
|
||||
}
|
||||
|
||||
if label == undefined || label == outlier {
|
||||
neighbors.push(secondary_neighbors[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
k += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(DBSCAN {
|
||||
cluster_labels: y,
|
||||
num_classes: k as usize,
|
||||
knn_algorithm: algo,
|
||||
eps: parameters.eps,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, m) = x.shape();
|
||||
let mut result = M::zeros(1, n);
|
||||
let mut row = vec![T::zero(); m];
|
||||
|
||||
for i in 0..n {
|
||||
x.copy_row_as_vec(i, &mut row);
|
||||
let neighbors = self.knn_algorithm.find_radius(&row, self.eps)?;
|
||||
let mut label = vec![0usize; self.num_classes + 1];
|
||||
for neighbor in neighbors {
|
||||
let yi = self.cluster_labels[neighbor.0];
|
||||
if yi < 0 {
|
||||
label[self.num_classes] += 1;
|
||||
} else {
|
||||
label[yi as usize] += 1;
|
||||
}
|
||||
}
|
||||
let class = which_max(&label);
|
||||
if class != self.num_classes {
|
||||
result.set(0, i, T::from(class).unwrap());
|
||||
} else {
|
||||
result.set(0, i, -T::one());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 2.0],
|
||||
&[1.1, 2.1],
|
||||
&[0.9, 1.9],
|
||||
&[1.2, 2.2],
|
||||
&[0.8, 1.8],
|
||||
&[2.0, 1.0],
|
||||
&[2.1, 1.1],
|
||||
&[1.9, 0.9],
|
||||
&[2.2, 1.2],
|
||||
&[1.8, 0.8],
|
||||
&[3.0, 5.0],
|
||||
]);
|
||||
|
||||
let expected_labels = vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0];
|
||||
|
||||
let dbscan = DBSCAN::fit(
|
||||
&x,
|
||||
DBSCANParameters::default()
|
||||
.with_eps(0.5)
|
||||
.with_min_samples(2),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let predicted_labels = dbscan.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(expected_labels, predicted_labels);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_dbscan: DBSCAN<f64, Euclidian> =
|
||||
serde_json::from_str(&serde_json::to_string(&dbscan).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(dbscan, deserialized_dbscan);
|
||||
}
|
||||
}
|
||||
+71
-38
@@ -43,7 +43,7 @@
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap(); // Fit to data, 2 clusters
|
||||
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
|
||||
//! let y_hat = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||
//! ```
|
||||
//!
|
||||
@@ -52,27 +52,28 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||
|
||||
extern crate rand;
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KMeans<T: RealNumber> {
|
||||
k: usize,
|
||||
y: Vec<usize>,
|
||||
_y: Vec<usize>,
|
||||
size: Vec<usize>,
|
||||
distortion: T,
|
||||
_distortion: T,
|
||||
centroids: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
@@ -103,33 +104,61 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
#[derive(Debug, Clone)]
|
||||
/// K-Means clustering algorithm parameters
|
||||
pub struct KMeansParameters {
|
||||
/// Number of clusters.
|
||||
pub k: usize,
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
/// Number of clusters.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
|
||||
self.max_iter = max_iter;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KMeansParameters {
|
||||
fn default() -> Self {
|
||||
KMeansParameters { max_iter: 100 }
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum, M: Matrix<T>> UnsupervisedEstimator<M, KMeansParameters> for KMeans<T> {
|
||||
fn fit(x: &M, parameters: KMeansParameters) -> Result<Self, Failed> {
|
||||
KMeans::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for KMeans<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Sum> KMeans<T> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `k` - number of clusters
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
data: &M,
|
||||
k: usize,
|
||||
parameters: KMeansParameters,
|
||||
) -> Result<KMeans<T>, Failed> {
|
||||
pub fn fit<M: Matrix<T>>(data: &M, parameters: KMeansParameters) -> Result<KMeans<T>, Failed> {
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if k < 2 {
|
||||
return Err(Failed::fit(&format!("invalid number of clusters: {}", k)));
|
||||
if parameters.k < 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"invalid number of clusters: {}",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
if parameters.max_iter <= 0 {
|
||||
if parameters.max_iter == 0 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"invalid maximum number of iterations: {}",
|
||||
parameters.max_iter
|
||||
@@ -139,9 +168,9 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, k);
|
||||
let mut size = vec![0; k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; k];
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
|
||||
for i in 0..n {
|
||||
size[y[i]] += 1;
|
||||
@@ -149,20 +178,20 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..d {
|
||||
centroids[y[i]][j] = centroids[y[i]][j] + data.get(i, j);
|
||||
centroids[y[i]][j] += data.get(i, j);
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..k {
|
||||
for i in 0..parameters.k {
|
||||
for j in 0..d {
|
||||
centroids[i][j] = centroids[i][j] / T::from(size[i]).unwrap();
|
||||
centroids[i][j] /= T::from(size[i]).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![T::zero(); d]; k];
|
||||
let mut sums = vec![vec![T::zero(); d]; parameters.k];
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(¢roids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..k {
|
||||
for i in 0..parameters.k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..d {
|
||||
centroids[i][j] = T::from(sums[i][j]).unwrap() / T::from(size[i]).unwrap();
|
||||
@@ -178,11 +207,11 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
}
|
||||
|
||||
Ok(KMeans {
|
||||
k: k,
|
||||
y: y,
|
||||
size: size,
|
||||
distortion: distortion,
|
||||
centroids: centroids,
|
||||
k: parameters.k,
|
||||
_y: y,
|
||||
size,
|
||||
_distortion: distortion,
|
||||
centroids,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -216,7 +245,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let (n, m) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||
|
||||
let mut d = vec![T::max_value(); n];
|
||||
|
||||
@@ -235,13 +264,13 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
|
||||
let mut sum: T = T::zero();
|
||||
for i in d.iter() {
|
||||
sum = sum + *i;
|
||||
sum += *i;
|
||||
}
|
||||
let cutoff = T::from(rng.gen::<f64>()).unwrap() * sum;
|
||||
let mut cost = T::zero();
|
||||
let mut index = 0;
|
||||
while index < n {
|
||||
cost = cost + d[index];
|
||||
cost += d[index];
|
||||
if cost >= cutoff {
|
||||
break;
|
||||
}
|
||||
@@ -270,19 +299,21 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn invalid_k() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
|
||||
assert!(KMeans::fit(&x, 0, Default::default()).is_err());
|
||||
assert!(KMeans::fit(&x, KMeansParameters::default().with_k(0)).is_err());
|
||||
assert_eq!(
|
||||
"Fit failed: invalid number of clusters: 1",
|
||||
KMeans::fit(&x, 1, Default::default())
|
||||
KMeans::fit(&x, KMeansParameters::default().with_k(1))
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -308,16 +339,18 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let y = kmeans.predict(&x).unwrap();
|
||||
|
||||
for i in 0..y.len() {
|
||||
assert_eq!(y[i] as usize, kmeans.y[i]);
|
||||
assert_eq!(y[i] as usize, kmeans._y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -342,7 +375,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let kmeans = KMeans::fit(&x, 2, Default::default()).unwrap();
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let deserialized_kmeans: KMeans<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&kmeans).unwrap()).unwrap();
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
|
||||
|
||||
pub mod dbscan;
|
||||
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
|
||||
pub mod kmeans;
|
||||
|
||||
@@ -38,8 +38,8 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples: num_samples,
|
||||
num_features: num_features,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE", "DIS", "RAD", "TAX", "PTRATIO", "B",
|
||||
"LSTAT",
|
||||
@@ -56,9 +56,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_boston_dataset() {
|
||||
@@ -67,6 +69,7 @@ mod tests {
|
||||
assert!(serialize_data(&dataset, "boston.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn boston_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -40,8 +40,8 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples: num_samples,
|
||||
num_features: num_features,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"mean radius", "mean texture", "mean perimeter", "mean area",
|
||||
"mean smoothness", "mean compactness", "mean concavity",
|
||||
@@ -66,17 +66,20 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
fn refresh_cancer_dataset() {
|
||||
// run this test to generate breast_cancer.xy file.
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cancer_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -33,8 +33,8 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples: num_samples,
|
||||
num_features: num_features,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
|
||||
]
|
||||
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_diabetes_dataset() {
|
||||
@@ -61,6 +63,7 @@ mod tests {
|
||||
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn boston_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -23,8 +23,8 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples: num_samples,
|
||||
num_features: num_features,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
@@ -45,9 +45,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_digits_dataset() {
|
||||
@@ -55,7 +57,7 @@ mod tests {
|
||||
let dataset = load_dataset();
|
||||
assert!(serialize_data(&dataset, "digits.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn digits_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
//! # Dataset Generators
|
||||
//!
|
||||
use rand::distributions::Uniform;
|
||||
use rand::prelude::*;
|
||||
use rand_distr::Normal;
|
||||
|
||||
use crate::dataset::Dataset;
|
||||
|
||||
/// Generate `num_centers` clusters of normally distributed points
|
||||
pub fn make_blobs(
|
||||
num_samples: usize,
|
||||
num_features: usize,
|
||||
num_centers: usize,
|
||||
) -> Dataset<f32, f32> {
|
||||
let center_box = Uniform::from(-10.0..10.0);
|
||||
let cluster_std = 1.0;
|
||||
let mut centers: Vec<Vec<Normal<f32>>> = Vec::with_capacity(num_centers);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
for _ in 0..num_centers {
|
||||
centers.push(
|
||||
(0..num_features)
|
||||
.map(|_| Normal::new(center_box.sample(&mut rng), cluster_std).unwrap())
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
|
||||
let mut x: Vec<f32> = Vec::with_capacity(num_samples);
|
||||
|
||||
for i in 0..num_samples {
|
||||
let label = i % num_centers;
|
||||
y.push(label as f32);
|
||||
for j in 0..num_features {
|
||||
x.push(centers[label][j].sample(&mut rng));
|
||||
}
|
||||
}
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: (0..num_features).map(|n| n.to_string()).collect(),
|
||||
target_names: vec!["label".to_string()],
|
||||
description: "Isotropic Gaussian blobs".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Make a large circle containing a smaller circle in 2d.
|
||||
pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32, f32> {
|
||||
if !(0.0..1.0).contains(&factor) {
|
||||
panic!("'factor' has to be between 0 and 1.");
|
||||
}
|
||||
|
||||
let num_samples_out = num_samples / 2;
|
||||
let num_samples_in = num_samples - num_samples_out;
|
||||
|
||||
let linspace_out = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_out);
|
||||
let linspace_in = linspace(0.0, 2.0 * std::f32::consts::PI, num_samples_in);
|
||||
|
||||
let noise = Normal::new(0.0, noise).unwrap();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut x: Vec<f32> = Vec::with_capacity(num_samples * 2);
|
||||
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
|
||||
|
||||
for v in linspace_out {
|
||||
x.push(v.cos() + noise.sample(&mut rng));
|
||||
x.push(v.sin() + noise.sample(&mut rng));
|
||||
y.push(0.0);
|
||||
}
|
||||
|
||||
for v in linspace_in {
|
||||
x.push(v.cos() * factor + noise.sample(&mut rng));
|
||||
x.push(v.sin() * factor + noise.sample(&mut rng));
|
||||
y.push(1.0);
|
||||
}
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features: 2,
|
||||
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||
target_names: vec!["label".to_string()],
|
||||
description: "Large circle containing a smaller circle in 2d".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Make two interleaving half circles in 2d
|
||||
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
|
||||
let num_samples_out = num_samples / 2;
|
||||
let num_samples_in = num_samples - num_samples_out;
|
||||
|
||||
let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out);
|
||||
let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in);
|
||||
|
||||
let noise = Normal::new(0.0, noise).unwrap();
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut x: Vec<f32> = Vec::with_capacity(num_samples * 2);
|
||||
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
|
||||
|
||||
for v in linspace_out {
|
||||
x.push(v.cos() + noise.sample(&mut rng));
|
||||
x.push(v.sin() + noise.sample(&mut rng));
|
||||
y.push(0.0);
|
||||
}
|
||||
|
||||
for v in linspace_in {
|
||||
x.push(1.0 - v.cos() + noise.sample(&mut rng));
|
||||
x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5);
|
||||
y.push(1.0);
|
||||
}
|
||||
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features: 2,
|
||||
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||
target_names: vec!["label".to_string()],
|
||||
description: "Two interleaving half circles in 2d".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
|
||||
let div = num as f32;
|
||||
let delta = stop - start;
|
||||
let step = delta / div;
|
||||
(0..num).map(|v| v as f32 * step).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_make_blobs() {
|
||||
let dataset = make_blobs(10, 2, 3);
|
||||
assert_eq!(
|
||||
dataset.data.len(),
|
||||
dataset.num_features * dataset.num_samples
|
||||
);
|
||||
assert_eq!(dataset.target.len(), dataset.num_samples);
|
||||
assert_eq!(dataset.num_features, 2);
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_make_circles() {
|
||||
let dataset = make_circles(10, 0.5, 0.05);
|
||||
assert_eq!(
|
||||
dataset.data.len(),
|
||||
dataset.num_features * dataset.num_samples
|
||||
);
|
||||
assert_eq!(dataset.target.len(), dataset.num_samples);
|
||||
assert_eq!(dataset.num_features, 2);
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_make_moons() {
|
||||
let dataset = make_moons(10, 0.05);
|
||||
assert_eq!(
|
||||
dataset.data.len(),
|
||||
dataset.num_features * dataset.num_samples
|
||||
);
|
||||
assert_eq!(dataset.target.len(), dataset.num_samples);
|
||||
assert_eq!(dataset.num_features, 2);
|
||||
assert_eq!(dataset.num_samples, 10);
|
||||
}
|
||||
}
|
||||
+5
-2
@@ -28,8 +28,8 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
Dataset {
|
||||
data: x,
|
||||
target: y,
|
||||
num_samples: num_samples,
|
||||
num_features: num_features,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use super::super::*;
|
||||
use super::*;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn refresh_iris_dataset() {
|
||||
@@ -61,6 +63,7 @@ mod tests {
|
||||
assert!(serialize_data(&dataset, "iris.xy").is_ok());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn iris_dataset() {
|
||||
let dataset = load_dataset();
|
||||
|
||||
+17
-9
@@ -5,11 +5,15 @@ pub mod boston;
|
||||
pub mod breast_cancer;
|
||||
pub mod diabetes;
|
||||
pub mod digits;
|
||||
pub mod generator;
|
||||
pub mod iris;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use std::io::prelude::*;
|
||||
|
||||
/// Dataset
|
||||
@@ -48,6 +52,8 @@ impl<X, Y> Dataset<X, Y> {
|
||||
}
|
||||
}
|
||||
|
||||
// Running this in wasm throws: operation not supported on this platform.
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
||||
dataset: &Dataset<X, Y>,
|
||||
@@ -55,20 +61,20 @@ pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
||||
) -> Result<(), io::Error> {
|
||||
match File::create(filename) {
|
||||
Ok(mut file) => {
|
||||
file.write(&dataset.num_features.to_le_bytes())?;
|
||||
file.write(&dataset.num_samples.to_le_bytes())?;
|
||||
file.write_all(&dataset.num_features.to_le_bytes())?;
|
||||
file.write_all(&dataset.num_samples.to_le_bytes())?;
|
||||
let x: Vec<u8> = dataset
|
||||
.data
|
||||
.iter()
|
||||
.map(|v| *v)
|
||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter())
|
||||
.copied()
|
||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
|
||||
.collect();
|
||||
file.write_all(&x)?;
|
||||
let y: Vec<u8> = dataset
|
||||
.target
|
||||
.iter()
|
||||
.map(|v| *v)
|
||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter())
|
||||
.copied()
|
||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
|
||||
.collect();
|
||||
file.write_all(&y)?;
|
||||
}
|
||||
@@ -81,11 +87,12 @@ pub(crate) fn deserialize_data(
|
||||
bytes: &[u8],
|
||||
) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> {
|
||||
// read the same file back into a Vec of bytes
|
||||
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
|
||||
let (num_samples, num_features) = {
|
||||
let mut buffer = [0u8; 8];
|
||||
buffer.copy_from_slice(&bytes[0..8]);
|
||||
let mut buffer = [0u8; USIZE_SIZE];
|
||||
buffer.copy_from_slice(&bytes[0..USIZE_SIZE]);
|
||||
let num_features = usize::from_le_bytes(buffer);
|
||||
buffer.copy_from_slice(&bytes[8..16]);
|
||||
buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]);
|
||||
let num_samples = usize::from_le_bytes(buffer);
|
||||
(num_samples, num_features)
|
||||
};
|
||||
@@ -114,6 +121,7 @@ pub(crate) fn deserialize_data(
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn as_matrix() {
|
||||
let dataset = Dataset {
|
||||
|
||||
@@ -13,3 +13,4 @@
|
||||
|
||||
/// PCA is a popular approach for deriving a low-dimensional set of features from a large set of variables.
|
||||
pub mod pca;
|
||||
pub mod svd;
|
||||
|
||||
+91
-33
@@ -37,7 +37,7 @@
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let pca = PCA::fit(&iris, 2, Default::default()).unwrap(); // Reduce number of features to 2
|
||||
//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
//! let iris_reduced = pca.transform(&iris).unwrap();
|
||||
//!
|
||||
@@ -47,14 +47,17 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Principal components analysis algorithm
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
||||
eigenvectors: M,
|
||||
eigenvalues: Vec<T>,
|
||||
@@ -68,14 +71,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
if self.eigenvectors != other.eigenvectors
|
||||
|| self.eigenvalues.len() != other.eigenvalues.len()
|
||||
{
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.eigenvalues.len() {
|
||||
if (self.eigenvalues[i] - other.eigenvalues[i]).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -83,38 +86,70 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
#[derive(Debug, Clone)]
|
||||
/// PCA parameters
|
||||
pub struct PCAParameters {
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: bool,
|
||||
}
|
||||
|
||||
impl PCAParameters {
|
||||
/// Number of components to keep.
|
||||
pub fn with_n_components(mut self, n_components: usize) -> Self {
|
||||
self.n_components = n_components;
|
||||
self
|
||||
}
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub fn with_use_correlation_matrix(mut self, use_correlation_matrix: bool) -> Self {
|
||||
self.use_correlation_matrix = use_correlation_matrix;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PCAParameters {
|
||||
fn default() -> Self {
|
||||
PCAParameters {
|
||||
n_components: 2,
|
||||
use_correlation_matrix: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
PCA::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for PCA<T, M> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
self.transform(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
/// Fits PCA to your data.
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
data: &M,
|
||||
n_components: usize,
|
||||
parameters: PCAParameters,
|
||||
) -> Result<PCA<T, M>, Failed> {
|
||||
pub fn fit(data: &M, parameters: PCAParameters) -> Result<PCA<T, M>, Failed> {
|
||||
let (m, n) = data.shape();
|
||||
|
||||
if parameters.n_components > n {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of components, n_components should be <= number of attributes ({})",
|
||||
n
|
||||
)));
|
||||
}
|
||||
|
||||
let mu = data.column_mean();
|
||||
|
||||
let mut x = data.clone();
|
||||
|
||||
for c in 0..n {
|
||||
for (c, mu_c) in mu.iter().enumerate().take(n) {
|
||||
for r in 0..m {
|
||||
x.sub_element_mut(r, c, mu[c]);
|
||||
x.sub_element_mut(r, c, *mu_c);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,8 +159,8 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
if m > n && !parameters.use_correlation_matrix {
|
||||
let svd = x.svd()?;
|
||||
eigenvalues = svd.s;
|
||||
for i in 0..eigenvalues.len() {
|
||||
eigenvalues[i] = eigenvalues[i] * eigenvalues[i];
|
||||
for eigenvalue in &mut eigenvalues {
|
||||
*eigenvalue = *eigenvalue * (*eigenvalue);
|
||||
}
|
||||
|
||||
eigenvectors = svd.V;
|
||||
@@ -149,8 +184,8 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
|
||||
if parameters.use_correlation_matrix {
|
||||
let mut sd = vec![T::zero(); n];
|
||||
for i in 0..n {
|
||||
sd[i] = cov.get(i, i).sqrt();
|
||||
for (i, sd_i) in sd.iter_mut().enumerate().take(n) {
|
||||
*sd_i = cov.get(i, i).sqrt();
|
||||
}
|
||||
|
||||
for i in 0..n {
|
||||
@@ -166,9 +201,9 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
|
||||
eigenvectors = evd.V;
|
||||
|
||||
for i in 0..n {
|
||||
for (i, sd_i) in sd.iter().enumerate().take(n) {
|
||||
for j in 0..n {
|
||||
eigenvectors.div_element_mut(i, j, sd[i]);
|
||||
eigenvectors.div_element_mut(i, j, *sd_i);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -180,26 +215,26 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
let mut projection = M::zeros(n_components, n);
|
||||
let mut projection = M::zeros(parameters.n_components, n);
|
||||
for i in 0..n {
|
||||
for j in 0..n_components {
|
||||
for j in 0..parameters.n_components {
|
||||
projection.set(j, i, eigenvectors.get(i, j));
|
||||
}
|
||||
}
|
||||
|
||||
let mut pmu = vec![T::zero(); n_components];
|
||||
for k in 0..n {
|
||||
for i in 0..n_components {
|
||||
pmu[i] = pmu[i] + projection.get(i, k) * mu[k];
|
||||
let mut pmu = vec![T::zero(); parameters.n_components];
|
||||
for (k, mu_k) in mu.iter().enumerate().take(n) {
|
||||
for (i, pmu_i) in pmu.iter_mut().enumerate().take(parameters.n_components) {
|
||||
*pmu_i += projection.get(i, k) * (*mu_k);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(PCA {
|
||||
eigenvectors: eigenvectors,
|
||||
eigenvalues: eigenvalues,
|
||||
eigenvectors,
|
||||
eigenvalues,
|
||||
projection: projection.transpose(),
|
||||
mu: mu,
|
||||
pmu: pmu,
|
||||
mu,
|
||||
pmu,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -224,6 +259,11 @@ impl<T: RealNumber, M: Matrix<T>> PCA<T, M> {
|
||||
}
|
||||
Ok(x_transformed)
|
||||
}
|
||||
|
||||
/// Get a projection matrix
|
||||
pub fn components(&self) -> &M {
|
||||
&self.projection
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -285,7 +325,23 @@ mod tests {
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
])
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn pca_components() {
|
||||
let us_arrests = us_arrests_data();
|
||||
|
||||
let expected = DenseMatrix::from_2d_array(&[
|
||||
&[0.0417, 0.0448],
|
||||
&[0.9952, 0.0588],
|
||||
&[0.0463, 0.9769],
|
||||
&[0.0752, 0.2007],
|
||||
]);
|
||||
|
||||
let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
|
||||
|
||||
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_covariance() {
|
||||
let us_arrests = us_arrests_data();
|
||||
@@ -377,7 +433,7 @@ mod tests {
|
||||
302.04806302399646,
|
||||
];
|
||||
|
||||
let pca = PCA::fit(&us_arrests, 4, Default::default()).unwrap();
|
||||
let pca = PCA::fit(&us_arrests, PCAParameters::default().with_n_components(4)).unwrap();
|
||||
|
||||
assert!(pca
|
||||
.eigenvectors
|
||||
@@ -395,6 +451,7 @@ mod tests {
|
||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_correlation() {
|
||||
let us_arrests = us_arrests_data();
|
||||
@@ -488,10 +545,9 @@ mod tests {
|
||||
|
||||
let pca = PCA::fit(
|
||||
&us_arrests,
|
||||
4,
|
||||
PCAParameters {
|
||||
use_correlation_matrix: true,
|
||||
},
|
||||
PCAParameters::default()
|
||||
.with_n_components(4)
|
||||
.with_use_correlation_matrix(true),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -511,7 +567,9 @@ mod tests {
|
||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -536,7 +594,7 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let pca = PCA::fit(&iris, 4, Default::default()).unwrap();
|
||||
let pca = PCA::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
let deserialized_pca: PCA<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&pca).unwrap()).unwrap();
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
//! # Dimensionality reduction using SVD
|
||||
//!
|
||||
//! Similar to [`PCA`](../pca/index.html), SVD is a technique that can be used to reduce the number of input variables _p_ to a smaller number _k_, while preserving
|
||||
//! the most important structure or relationships between the variables observed in the data.
|
||||
//!
|
||||
//! Contrary to PCA, SVD does not center the data before computing the singular value decomposition.
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::decomposition::svd::*;
|
||||
//!
|
||||
//! // Iris data
|
||||
//! let iris = DenseMatrix::from_2d_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//!
|
||||
//! let svd = SVD::fit(&iris, SVDParameters::default().
|
||||
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
//! let iris_reduced = svd.transform(&iris).unwrap();
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// SVD
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
||||
components: M,
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.components
|
||||
.approximate_eq(&other.components, T::from_f64(1e-8).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVD parameters
|
||||
pub struct SVDParameters {
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
}
|
||||
|
||||
impl Default for SVDParameters {
|
||||
fn default() -> Self {
|
||||
SVDParameters { n_components: 2 }
|
||||
}
|
||||
}
|
||||
|
||||
impl SVDParameters {
|
||||
/// Number of components to keep.
|
||||
pub fn with_n_components(mut self, n_components: usize) -> Self {
|
||||
self.n_components = n_components;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
SVD::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for SVD<T, M> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
self.transform(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SVD<T, M> {
|
||||
/// Fits SVD to your data.
|
||||
/// * `data` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `n_components` - number of components to keep.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(x: &M, parameters: SVDParameters) -> Result<SVD<T, M>, Failed> {
|
||||
let (_, p) = x.shape();
|
||||
|
||||
if parameters.n_components >= p {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of components, n_components should be < number of attributes ({})",
|
||||
p
|
||||
)));
|
||||
}
|
||||
|
||||
let svd = x.svd()?;
|
||||
|
||||
let components = svd.V.slice(0..p, 0..parameters.n_components);
|
||||
|
||||
Ok(SVD {
|
||||
components,
|
||||
phantom: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run dimensionality reduction for `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
let (p_c, k) = self.components.shape();
|
||||
if p_c != p {
|
||||
return Err(Failed::transform(&format!(
|
||||
"Can not transform a {}x{} matrix into {}x{} matrix, incorrect input dimentions",
|
||||
n, p, n, k
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(x.matmul(&self.components))
|
||||
}
|
||||
|
||||
/// Get a projection matrix
|
||||
pub fn components(&self) -> &M {
|
||||
&self.components
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svd_decompose() {
|
||||
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[13.2, 236.0, 58.0, 21.2],
|
||||
&[10.0, 263.0, 48.0, 44.5],
|
||||
&[8.1, 294.0, 80.0, 31.0],
|
||||
&[8.8, 190.0, 50.0, 19.5],
|
||||
&[9.0, 276.0, 91.0, 40.6],
|
||||
&[7.9, 204.0, 78.0, 38.7],
|
||||
&[3.3, 110.0, 77.0, 11.1],
|
||||
&[5.9, 238.0, 72.0, 15.8],
|
||||
&[15.4, 335.0, 80.0, 31.9],
|
||||
&[17.4, 211.0, 60.0, 25.8],
|
||||
&[5.3, 46.0, 83.0, 20.2],
|
||||
&[2.6, 120.0, 54.0, 14.2],
|
||||
&[10.4, 249.0, 83.0, 24.0],
|
||||
&[7.2, 113.0, 65.0, 21.0],
|
||||
&[2.2, 56.0, 57.0, 11.3],
|
||||
&[6.0, 115.0, 66.0, 18.0],
|
||||
&[9.7, 109.0, 52.0, 16.3],
|
||||
&[15.4, 249.0, 66.0, 22.2],
|
||||
&[2.1, 83.0, 51.0, 7.8],
|
||||
&[11.3, 300.0, 67.0, 27.8],
|
||||
&[4.4, 149.0, 85.0, 16.3],
|
||||
&[12.1, 255.0, 74.0, 35.1],
|
||||
&[2.7, 72.0, 66.0, 14.9],
|
||||
&[16.1, 259.0, 44.0, 17.1],
|
||||
&[9.0, 178.0, 70.0, 28.2],
|
||||
&[6.0, 109.0, 53.0, 16.4],
|
||||
&[4.3, 102.0, 62.0, 16.5],
|
||||
&[12.2, 252.0, 81.0, 46.0],
|
||||
&[2.1, 57.0, 56.0, 9.5],
|
||||
&[7.4, 159.0, 89.0, 18.8],
|
||||
&[11.4, 285.0, 70.0, 32.1],
|
||||
&[11.1, 254.0, 86.0, 26.1],
|
||||
&[13.0, 337.0, 45.0, 16.1],
|
||||
&[0.8, 45.0, 44.0, 7.3],
|
||||
&[7.3, 120.0, 75.0, 21.4],
|
||||
&[6.6, 151.0, 68.0, 20.0],
|
||||
&[4.9, 159.0, 67.0, 29.3],
|
||||
&[6.3, 106.0, 72.0, 14.9],
|
||||
&[3.4, 174.0, 87.0, 8.3],
|
||||
&[14.4, 279.0, 48.0, 22.5],
|
||||
&[3.8, 86.0, 45.0, 12.8],
|
||||
&[13.2, 188.0, 59.0, 26.9],
|
||||
&[12.7, 201.0, 80.0, 25.5],
|
||||
&[3.2, 120.0, 80.0, 22.9],
|
||||
&[2.2, 48.0, 32.0, 11.2],
|
||||
&[8.5, 156.0, 63.0, 20.7],
|
||||
&[4.0, 145.0, 73.0, 26.2],
|
||||
&[5.7, 81.0, 39.0, 9.3],
|
||||
&[2.6, 53.0, 66.0, 10.8],
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
]);
|
||||
|
||||
let expected = DenseMatrix::from_2d_array(&[
|
||||
&[243.54655757, -18.76673788],
|
||||
&[268.36802004, -33.79304302],
|
||||
&[305.93972467, -15.39087376],
|
||||
&[197.28420365, -11.66808306],
|
||||
&[293.43187394, 1.91163633],
|
||||
]);
|
||||
let svd = SVD::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let x_transformed = svd.transform(&x).unwrap();
|
||||
|
||||
assert_eq!(svd.components.shape(), (x.shape().1, 2));
|
||||
|
||||
assert!(x_transformed
|
||||
.slice(0..5, 0..2)
|
||||
.approximate_eq(&expected, 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let iris = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let svd = SVD::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
let deserialized_svd: SVD<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&svd).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svd, deserialized_svd);
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,7 @@
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::ensemble::random_forest_classifier::*;
|
||||
//! use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
|
||||
//!
|
||||
//! // Iris dataset
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -45,16 +45,18 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
extern crate rand;
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::{BaseMatrix, Matrix};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
@@ -62,7 +64,8 @@ use crate::tree::decision_tree_classifier::{
|
||||
|
||||
/// Parameters of the Random Forest algorithm.
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: SplitCriterion,
|
||||
@@ -76,20 +79,71 @@ pub struct RandomForestClassifierParameters {
|
||||
pub n_trees: u16,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Random Forest Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestClassifier<T: RealNumber> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
_parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
classes: Vec<T>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
}
|
||||
|
||||
impl RandomForestClassifierParameters {
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
|
||||
self.criterion = criterion;
|
||||
self
|
||||
}
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||
self.max_depth = Some(max_depth);
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
|
||||
self.min_samples_leaf = min_samples_leaf;
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
|
||||
self.min_samples_split = min_samples_split;
|
||||
self
|
||||
}
|
||||
/// The number of trees in the forest.
|
||||
pub fn with_n_trees(mut self, n_trees: u16) -> Self {
|
||||
self.n_trees = n_trees;
|
||||
self
|
||||
}
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub fn with_m(mut self, m: usize) -> Self {
|
||||
self.m = Some(m);
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() || self.trees.len() != other.trees.len() {
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
@@ -115,10 +169,31 @@ impl Default for RandomForestClassifierParameters {
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, RandomForestClassifierParameters>
|
||||
for RandomForestClassifier<T>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RandomForestClassifierParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
RandomForestClassifier::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestClassifier<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -134,39 +209,51 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
let classes = y_m.unique();
|
||||
|
||||
for i in 0..y_ncols {
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
let yc = y_m.get(0, i);
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
let mtry = parameters.m.unwrap_or(
|
||||
let mtry = parameters.m.unwrap_or_else(|| {
|
||||
(T::from(num_attributes).unwrap())
|
||||
.sqrt()
|
||||
.floor()
|
||||
.to_usize()
|
||||
.unwrap(),
|
||||
);
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
|
||||
let params = DecisionTreeClassifierParameters {
|
||||
criterion: parameters.criterion.clone(),
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
};
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
let tree =
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
Ok(RandomForestClassifier {
|
||||
parameters: parameters,
|
||||
trees: trees,
|
||||
_parameters: parameters,
|
||||
trees,
|
||||
classes,
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -191,27 +278,93 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
return which_max(&result);
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &Vec<usize>, num_classes: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Need samples=true for OOB predictions.",
|
||||
))
|
||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
|
||||
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||
if !samples[row] {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
which_max(&result)
|
||||
}
|
||||
|
||||
/// Predict the per-class probabilties for each observation.
|
||||
/// The probability is calculated as the fraction of trees that predicted a given class
|
||||
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
|
||||
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());
|
||||
|
||||
let (n, _) = x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
let row_probs = self.predict_probs_for_row(x, i);
|
||||
|
||||
for (j, item) in row_probs.iter().enumerate() {
|
||||
result.set(i, j, *item);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
|
||||
let mut result = vec![0; self.classes.len()];
|
||||
|
||||
for tree in self.trees.iter() {
|
||||
result[tree.predict_for_row(x, row)] += 1;
|
||||
}
|
||||
|
||||
result
|
||||
.iter()
|
||||
.map(|n| *n as f64 / self.trees.len() as f64)
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let class_weight = vec![1.; num_classes];
|
||||
let nrows = y.len();
|
||||
let mut samples = vec![0; nrows];
|
||||
for l in 0..num_classes {
|
||||
for (l, class_weight_l) in class_weight.iter().enumerate().take(num_classes) {
|
||||
let mut n_samples = 0;
|
||||
let mut index: Vec<usize> = Vec::new();
|
||||
for i in 0..nrows {
|
||||
if y[i] == l {
|
||||
for (i, y_i) in y.iter().enumerate().take(nrows) {
|
||||
if *y_i == l {
|
||||
index.push(i);
|
||||
n_samples += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let size = ((n_samples as f64) / class_weight[l]) as usize;
|
||||
let size = ((n_samples as f64) / *class_weight_l) as usize;
|
||||
for _ in 0..size {
|
||||
let xi: usize = rng.gen_range(0, n_samples);
|
||||
let xi: usize = rng.gen_range(0..n_samples);
|
||||
samples[index[xi]] += 1;
|
||||
}
|
||||
}
|
||||
@@ -220,11 +373,12 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod tests_prob {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -263,6 +417,8 @@ mod tests {
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@@ -270,7 +426,60 @@ mod tests {
|
||||
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris_oob() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: true,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
accuracy(&y, &classifier.predict_oob(&x).unwrap())
|
||||
< accuracy(&y, &classifier.predict(&x).unwrap())
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
@@ -305,4 +514,69 @@ mod tests {
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fit_predict_probabilities() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 100,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
println!("{:?}", classifier.classes);
|
||||
|
||||
let results = classifier.predict_probs(&x).unwrap();
|
||||
println!("{:?}", x.shape());
|
||||
println!("{:?}", results);
|
||||
println!("{:?}", results.shape());
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
DenseMatrix::<f64>::from_array(
|
||||
20,
|
||||
2,
|
||||
&[
|
||||
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
|
||||
0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0,
|
||||
0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98
|
||||
]
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,22 +42,25 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
extern crate rand;
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::tree::decision_tree_regressor::{
|
||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of the Random Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct RandomForestRegressorParameters {
|
||||
@@ -71,15 +74,60 @@ pub struct RandomForestRegressorParameters {
|
||||
pub n_trees: usize,
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
|
||||
/// Random Forest Regressor
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForestRegressor<T: RealNumber> {
|
||||
parameters: RandomForestRegressorParameters,
|
||||
_parameters: RandomForestRegressorParameters,
|
||||
trees: Vec<DecisionTreeRegressor<T>>,
|
||||
samples: Option<Vec<Vec<bool>>>,
|
||||
}
|
||||
|
||||
impl RandomForestRegressorParameters {
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||
self.max_depth = Some(max_depth);
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
|
||||
self.min_samples_leaf = min_samples_leaf;
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
|
||||
self.min_samples_split = min_samples_split;
|
||||
self
|
||||
}
|
||||
/// The number of trees in the forest.
|
||||
pub fn with_n_trees(mut self, n_trees: usize) -> Self {
|
||||
self.n_trees = n_trees;
|
||||
self
|
||||
}
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub fn with_m(mut self, m: usize) -> Self {
|
||||
self.m = Some(m);
|
||||
self
|
||||
}
|
||||
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||
self.keep_samples = keep_samples;
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Default for RandomForestRegressorParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestRegressorParameters {
|
||||
@@ -88,6 +136,8 @@ impl Default for RandomForestRegressorParameters {
|
||||
min_samples_split: 2,
|
||||
n_trees: 10,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,7 +145,7 @@ impl Default for RandomForestRegressorParameters {
|
||||
impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.trees.len() != other.trees.len() {
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
@@ -107,6 +157,25 @@ impl<T: RealNumber> PartialEq for RandomForestRegressor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, RandomForestRegressorParameters>
|
||||
for RandomForestRegressor<T>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RandomForestRegressorParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
RandomForestRegressor::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestRegressor<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -122,22 +191,33 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
if parameters.keep_samples {
|
||||
maybe_all_samples = Some(Vec::new());
|
||||
}
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
|
||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
|
||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||
}
|
||||
let params = DecisionTreeRegressorParameters {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
};
|
||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
let tree =
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
Ok(RandomForestRegressor {
|
||||
parameters: parameters,
|
||||
trees: trees,
|
||||
_parameters: parameters,
|
||||
trees,
|
||||
samples: maybe_all_samples,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -161,17 +241,55 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
let mut result = T::zero();
|
||||
|
||||
for tree in self.trees.iter() {
|
||||
result = result + tree.predict_for_row(x, row);
|
||||
result += tree.predict_for_row(x, row);
|
||||
}
|
||||
|
||||
result / T::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
if self.samples.is_none() {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Need samples=true for OOB predictions.",
|
||||
))
|
||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||
Err(Failed::because(
|
||||
FailedError::PredictFailed,
|
||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||
))
|
||||
} else {
|
||||
let mut result = M::zeros(1, n);
|
||||
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.predict_for_row_oob(x, i));
|
||||
}
|
||||
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
}
|
||||
|
||||
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
let mut n_trees = 0;
|
||||
let mut result = T::zero();
|
||||
|
||||
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||
if !samples[row] {
|
||||
result += tree.predict_for_row(x, row);
|
||||
n_trees += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: What to do if there are no oob trees?
|
||||
result / T::from(n_trees).unwrap()
|
||||
}
|
||||
|
||||
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||
let mut samples = vec![0; nrows];
|
||||
for _ in 0..nrows {
|
||||
let xi = rng.gen_range(0, nrows);
|
||||
let xi = rng.gen_range(0..nrows);
|
||||
samples[xi] += 1;
|
||||
}
|
||||
samples
|
||||
@@ -184,6 +302,7 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -218,6 +337,8 @@ mod tests {
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.and_then(|rf| rf.predict(&x))
|
||||
@@ -226,7 +347,56 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_longley_oob() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
&[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let regressor = RandomForestRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
RandomForestRegressorParameters {
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: true,
|
||||
seed: 87,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let y_hat = regressor.predict(&x).unwrap();
|
||||
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
|
||||
+12
-5
@@ -2,17 +2,21 @@
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Generic error to be raised when something goes wrong.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Failed {
|
||||
err: FailedError,
|
||||
msg: String,
|
||||
}
|
||||
|
||||
/// Type of error
|
||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
|
||||
#[non_exhaustive]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum FailedError {
|
||||
/// Can't fit algorithm to data
|
||||
FitFailed = 1,
|
||||
@@ -24,6 +28,8 @@ pub enum FailedError {
|
||||
FindFailed,
|
||||
/// Can't decompose a matrix
|
||||
DecompositionFailed,
|
||||
/// Can't solve for x
|
||||
SolutionFailed,
|
||||
}
|
||||
|
||||
impl Failed {
|
||||
@@ -59,7 +65,7 @@ impl Failed {
|
||||
/// new instance of `err`
|
||||
pub fn because(err: FailedError, msg: &str) -> Self {
|
||||
Failed {
|
||||
err: err,
|
||||
err,
|
||||
msg: msg.to_string(),
|
||||
}
|
||||
}
|
||||
@@ -80,20 +86,21 @@ impl PartialEq for Failed {
|
||||
}
|
||||
|
||||
impl fmt::Display for FailedError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let failed_err_str = match self {
|
||||
FailedError::FitFailed => "Fit failed",
|
||||
FailedError::PredictFailed => "Predict failed",
|
||||
FailedError::TransformFailed => "Transform failed",
|
||||
FailedError::FindFailed => "Find failed",
|
||||
FailedError::DecompositionFailed => "Decomposition failed",
|
||||
FailedError::SolutionFailed => "Can't find solution",
|
||||
};
|
||||
write!(f, "{}", failed_err_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Failed {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}: {}", self.err, self.msg)
|
||||
}
|
||||
}
|
||||
|
||||
+26
-17
@@ -1,20 +1,22 @@
|
||||
#![allow(
|
||||
clippy::type_complexity,
|
||||
clippy::too_many_arguments,
|
||||
clippy::many_single_char_names,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::upper_case_acronyms
|
||||
)]
|
||||
#![warn(missing_docs)]
|
||||
#![warn(missing_doc_code_examples)]
|
||||
#![warn(rustdoc::missing_doc_code_examples)]
|
||||
|
||||
//! # SmartCore
|
||||
//!
|
||||
//! Welcome to SmartCore, the most advanced machine learning library in Rust!
|
||||
//!
|
||||
//! In SmartCore you will find implementation of these ML algorithms:
|
||||
//! * __Regression__: Linear Regression (OLS), Decision Tree Regressor, Random Forest Regressor, K Nearest Neighbors
|
||||
//! * __Classification__: Logistic Regressor, Decision Tree Classifier, Random Forest Classifier, Supervised Nearest Neighbors (KNN)
|
||||
//! * __Clustering__: K-Means
|
||||
//! * __Matrix Decomposition__: PCA, LU, QR, SVD, EVD
|
||||
//! * __Distance Metrics__: Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis
|
||||
//! * __Evaluation Metrics__: Accuracy, AUC, Recall, Precision, F1, Mean Absolute Error, Mean Squared Error, R2
|
||||
//! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
|
||||
//! as well as tools for model selection and model evaluation.
|
||||
//!
|
||||
//! Most of algorithms implemented in SmartCore operate on n-dimentional arrays. While you can use Rust vectors with all functions defined in this library
|
||||
//! we do recommend to go with one of the popular linear algebra libraries available in Rust. At this moment we support these packages:
|
||||
//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment,
|
||||
//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages:
|
||||
//! * [ndarray](https://docs.rs/ndarray)
|
||||
//! * [nalgebra](https://docs.rs/nalgebra/)
|
||||
//!
|
||||
@@ -23,21 +25,21 @@
|
||||
//! To start using SmartCore simply add the following to your Cargo.toml file:
|
||||
//! ```ignore
|
||||
//! [dependencies]
|
||||
//! smartcore = "0.1.0"
|
||||
//! smartcore = "0.2.0"
|
||||
//! ```
|
||||
//!
|
||||
//! All ML algorithms in SmartCore are grouped into these generic categories:
|
||||
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
||||
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
||||
//! * [Martix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
||||
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
||||
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
||||
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
|
||||
//! * [Tree-based Models](tree/index.html), classification and regression trees
|
||||
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
|
||||
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem
|
||||
//! * [SVM](svm/index.html), support vector machines
|
||||
//!
|
||||
//! Each category is assigned to a separate module.
|
||||
//!
|
||||
//! For example, KNN classifier is defined in [smartcore::neighbors::knn_classifier](neighbors/knn_classifier/index.html). To train and run it using standard Rust vectors you will
|
||||
//! run this code:
|
||||
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector:
|
||||
//!
|
||||
//! ```
|
||||
//! // DenseMatrix defenition
|
||||
@@ -58,7 +60,7 @@
|
||||
//! let y = vec![2., 2., 2., 3., 3.];
|
||||
//!
|
||||
//! // Train classifier
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Predict classes
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
@@ -66,6 +68,7 @@
|
||||
|
||||
/// Various algorithms and helper methods that are used elsewhere in SmartCore
|
||||
pub mod algorithm;
|
||||
pub mod api;
|
||||
/// Algorithms for clustering of unlabeled data
|
||||
pub mod cluster;
|
||||
/// Various datasets
|
||||
@@ -85,8 +88,14 @@ pub mod math;
|
||||
/// Functions for assessing prediction error.
|
||||
pub mod metrics;
|
||||
pub mod model_selection;
|
||||
/// Supervised learning algorithms based on applying the Bayes theorem with the independence assumptions between predictors
|
||||
pub mod naive_bayes;
|
||||
/// Supervised neighbors-based learning methods
|
||||
pub mod neighbors;
|
||||
pub(crate) mod optimization;
|
||||
/// Preprocessing utilities
|
||||
pub mod preprocessing;
|
||||
/// Support Vector Machines
|
||||
pub mod svm;
|
||||
/// Supervised tree-based learning methods
|
||||
pub mod tree;
|
||||
|
||||
@@ -0,0 +1,204 @@
|
||||
//! # Cholesky Decomposition
|
||||
//!
|
||||
//! every positive definite matrix \\(A \in R^{n \times n}\\) can be factored as
|
||||
//!
|
||||
//! \\[A = R^TR\\]
|
||||
//!
|
||||
//! where \\(R\\) is upper triangular matrix with positive diagonal elements
|
||||
//!
|
||||
//! Example:
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use crate::smartcore::linalg::cholesky::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[25., 15., -5.],
|
||||
//! &[15., 18., 0.],
|
||||
//! &[-5., 0., 11.]
|
||||
//! ]);
|
||||
//!
|
||||
//! let cholesky = A.cholesky().unwrap();
|
||||
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
|
||||
//! let upper_triangular: DenseMatrix<f64> = cholesky.U();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["No bullshit guide to linear algebra", Ivan Savov, 2016, 7.6 Matrix decompositions](https://minireference.com/)
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., 2.9 Cholesky Decomposition](http://numerical.recipes/)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Results of Cholesky decomposition.
|
||||
pub struct Cholesky<T: RealNumber, M: BaseMatrix<T>> {
|
||||
R: M,
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
||||
pub(crate) fn new(R: M) -> Cholesky<T, M> {
|
||||
Cholesky { R, t: PhantomData }
|
||||
}
|
||||
|
||||
/// Get lower triangular matrix.
|
||||
pub fn L(&self) -> M {
|
||||
let (n, _) = self.R.shape();
|
||||
let mut R = M::zeros(n, n);
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if j <= i {
|
||||
R.set(i, j, self.R.get(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
R
|
||||
}
|
||||
|
||||
/// Get upper triangular matrix.
|
||||
pub fn U(&self) -> M {
|
||||
let (n, _) = self.R.shape();
|
||||
let mut R = M::zeros(n, n);
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..n {
|
||||
if j <= i {
|
||||
R.set(j, i, self.R.get(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
R
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
pub(crate) fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
let (bn, m) = b.shape();
|
||||
let (rn, _) = self.R.shape();
|
||||
|
||||
if bn != rn {
|
||||
return Err(Failed::because(
|
||||
FailedError::SolutionFailed,
|
||||
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
|
||||
));
|
||||
}
|
||||
|
||||
for k in 0..bn {
|
||||
for j in 0..m {
|
||||
for i in 0..k {
|
||||
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(k, i));
|
||||
}
|
||||
b.div_element_mut(k, j, self.R.get(k, k));
|
||||
}
|
||||
}
|
||||
|
||||
for k in (0..bn).rev() {
|
||||
for j in 0..m {
|
||||
for i in k + 1..bn {
|
||||
b.sub_element_mut(k, j, b.get(i, j) * self.R.get(i, k));
|
||||
}
|
||||
b.div_element_mut(k, j, self.R.get(k, k));
|
||||
}
|
||||
}
|
||||
Ok(b)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait that implements Cholesky decomposition routine for any matrix.
|
||||
pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Compute the Cholesky decomposition of a matrix.
|
||||
fn cholesky(&self) -> Result<Cholesky<T, Self>, Failed> {
|
||||
self.clone().cholesky_mut()
|
||||
}
|
||||
|
||||
/// Compute the Cholesky decomposition of a matrix. The input matrix
|
||||
/// will be used for factorization.
|
||||
fn cholesky_mut(mut self) -> Result<Cholesky<T, Self>, Failed> {
|
||||
let (m, n) = self.shape();
|
||||
|
||||
if m != n {
|
||||
return Err(Failed::because(
|
||||
FailedError::DecompositionFailed,
|
||||
"Can\'t do Cholesky decomposition on a non-square matrix",
|
||||
));
|
||||
}
|
||||
|
||||
for j in 0..n {
|
||||
let mut d = T::zero();
|
||||
for k in 0..j {
|
||||
let mut s = T::zero();
|
||||
for i in 0..k {
|
||||
s += self.get(k, i) * self.get(j, i);
|
||||
}
|
||||
s = (self.get(j, k) - s) / self.get(k, k);
|
||||
self.set(j, k, s);
|
||||
d += s * s;
|
||||
}
|
||||
d = self.get(j, j) - d;
|
||||
|
||||
if d < T::zero() {
|
||||
return Err(Failed::because(
|
||||
FailedError::DecompositionFailed,
|
||||
"The matrix is not positive definite.",
|
||||
));
|
||||
}
|
||||
|
||||
self.set(j, j, d.sqrt());
|
||||
}
|
||||
|
||||
Ok(Cholesky::new(self))
|
||||
}
|
||||
|
||||
/// Solves Ax = b
|
||||
fn cholesky_solve_mut(self, b: Self) -> Result<Self, Failed> {
|
||||
self.cholesky_mut().and_then(|qr| qr.solve(b))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cholesky_decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let l =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
|
||||
let u =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
assert!(cholesky.L().abs().approximate_eq(&l.abs(), 1e-4));
|
||||
assert!(cholesky.U().abs().approximate_eq(&u.abs(), 1e-4));
|
||||
assert!(cholesky
|
||||
.L()
|
||||
.matmul(&cholesky.U())
|
||||
.abs()
|
||||
.approximate_eq(&a.abs(), 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cholesky_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
|
||||
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
assert!(cholesky
|
||||
.solve(b.transpose())
|
||||
.unwrap()
|
||||
.transpose()
|
||||
.approximate_eq(&expected, 1e-4));
|
||||
}
|
||||
}
|
||||
+99
-91
@@ -25,6 +25,19 @@
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::evd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[-5.0, 2.0],
|
||||
//! &[-7.0, 4.0],
|
||||
//! ]);
|
||||
//!
|
||||
//! let evd = A.evd(false).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/)
|
||||
@@ -93,33 +106,33 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
sort(&mut d, &mut e, &mut V);
|
||||
}
|
||||
|
||||
Ok(EVD { V: V, d: d, e: e })
|
||||
Ok(EVD { d, e, V })
|
||||
}
|
||||
}
|
||||
|
||||
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
||||
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = V.shape();
|
||||
for i in 0..n {
|
||||
d[i] = V.get(n - 1, i);
|
||||
for (i, d_i) in d.iter_mut().enumerate().take(n) {
|
||||
*d_i = V.get(n - 1, i);
|
||||
}
|
||||
|
||||
for i in (1..n).rev() {
|
||||
let mut scale = T::zero();
|
||||
let mut h = T::zero();
|
||||
for k in 0..i {
|
||||
scale = scale + d[k].abs();
|
||||
for d_k in d.iter().take(i) {
|
||||
scale += d_k.abs();
|
||||
}
|
||||
if scale == T::zero() {
|
||||
e[i] = d[i - 1];
|
||||
for j in 0..i {
|
||||
d[j] = V.get(i - 1, j);
|
||||
for (j, d_j) in d.iter_mut().enumerate().take(i) {
|
||||
*d_j = V.get(i - 1, j);
|
||||
V.set(i, j, T::zero());
|
||||
V.set(j, i, T::zero());
|
||||
}
|
||||
} else {
|
||||
for k in 0..i {
|
||||
d[k] = d[k] / scale;
|
||||
h = h + d[k] * d[k];
|
||||
for d_k in d.iter_mut().take(i) {
|
||||
*d_k /= scale;
|
||||
h += (*d_k) * (*d_k);
|
||||
}
|
||||
let mut f = d[i - 1];
|
||||
let mut g = h.sqrt();
|
||||
@@ -127,10 +140,10 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
|
||||
g = -g;
|
||||
}
|
||||
e[i] = scale * g;
|
||||
h = h - f * g;
|
||||
h -= f * g;
|
||||
d[i - 1] = f - g;
|
||||
for j in 0..i {
|
||||
e[j] = T::zero();
|
||||
for e_j in e.iter_mut().take(i) {
|
||||
*e_j = T::zero();
|
||||
}
|
||||
|
||||
for j in 0..i {
|
||||
@@ -138,19 +151,19 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
|
||||
V.set(j, i, f);
|
||||
g = e[j] + V.get(j, j) * f;
|
||||
for k in j + 1..=i - 1 {
|
||||
g = g + V.get(k, j) * d[k];
|
||||
e[k] = e[k] + V.get(k, j) * f;
|
||||
g += V.get(k, j) * d[k];
|
||||
e[k] += V.get(k, j) * f;
|
||||
}
|
||||
e[j] = g;
|
||||
}
|
||||
f = T::zero();
|
||||
for j in 0..i {
|
||||
e[j] = e[j] / h;
|
||||
f = f + e[j] * d[j];
|
||||
e[j] /= h;
|
||||
f += e[j] * d[j];
|
||||
}
|
||||
let hh = f / (h + h);
|
||||
for j in 0..i {
|
||||
e[j] = e[j] - hh * d[j];
|
||||
e[j] -= hh * d[j];
|
||||
}
|
||||
for j in 0..i {
|
||||
f = d[j];
|
||||
@@ -170,16 +183,16 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
|
||||
V.set(i, i, T::one());
|
||||
let h = d[i + 1];
|
||||
if h != T::zero() {
|
||||
for k in 0..=i {
|
||||
d[k] = V.get(k, i + 1) / h;
|
||||
for (k, d_k) in d.iter_mut().enumerate().take(i + 1) {
|
||||
*d_k = V.get(k, i + 1) / h;
|
||||
}
|
||||
for j in 0..=i {
|
||||
let mut g = T::zero();
|
||||
for k in 0..=i {
|
||||
g = g + V.get(k, i + 1) * V.get(k, j);
|
||||
g += V.get(k, i + 1) * V.get(k, j);
|
||||
}
|
||||
for k in 0..=i {
|
||||
V.sub_element_mut(k, j, g * d[k]);
|
||||
for (k, d_k) in d.iter().enumerate().take(i + 1) {
|
||||
V.sub_element_mut(k, j, g * (*d_k));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -187,15 +200,15 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
|
||||
V.set(k, i + 1, T::zero());
|
||||
}
|
||||
}
|
||||
for j in 0..n {
|
||||
d[j] = V.get(n - 1, j);
|
||||
for (j, d_j) in d.iter_mut().enumerate().take(n) {
|
||||
*d_j = V.get(n - 1, j);
|
||||
V.set(n - 1, j, T::zero());
|
||||
}
|
||||
V.set(n - 1, n - 1, T::one());
|
||||
e[0] = T::zero();
|
||||
}
|
||||
|
||||
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
||||
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = V.shape();
|
||||
for i in 1..n {
|
||||
e[i - 1] = e[i];
|
||||
@@ -238,10 +251,10 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<
|
||||
d[l + 1] = e[l] * (p + r);
|
||||
let dl1 = d[l + 1];
|
||||
let mut h = g - d[l];
|
||||
for i in l + 2..n {
|
||||
d[i] = d[i] - h;
|
||||
for d_i in d.iter_mut().take(n).skip(l + 2) {
|
||||
*d_i -= h;
|
||||
}
|
||||
f = f + h;
|
||||
f += h;
|
||||
|
||||
p = d[m];
|
||||
let mut c = T::one();
|
||||
@@ -278,17 +291,17 @@ fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<
|
||||
}
|
||||
}
|
||||
}
|
||||
d[l] = d[l] + f;
|
||||
d[l] += f;
|
||||
e[l] = T::zero();
|
||||
}
|
||||
|
||||
for i in 0..n - 1 {
|
||||
let mut k = i;
|
||||
let mut p = d[i];
|
||||
for j in i + 1..n {
|
||||
if d[j] > p {
|
||||
for (j, d_j) in d.iter().enumerate().take(n).skip(i + 1) {
|
||||
if *d_j > p {
|
||||
k = j;
|
||||
p = d[j];
|
||||
p = *d_j;
|
||||
}
|
||||
}
|
||||
if k != i {
|
||||
@@ -316,13 +329,13 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
let mut done = false;
|
||||
while !done {
|
||||
done = true;
|
||||
for i in 0..n {
|
||||
for (i, scale_i) in scale.iter_mut().enumerate().take(n) {
|
||||
let mut r = T::zero();
|
||||
let mut c = T::zero();
|
||||
for j in 0..n {
|
||||
if j != i {
|
||||
c = c + A.get(j, i).abs();
|
||||
r = r + A.get(i, j).abs();
|
||||
c += A.get(j, i).abs();
|
||||
r += A.get(i, j).abs();
|
||||
}
|
||||
}
|
||||
if c != T::zero() && r != T::zero() {
|
||||
@@ -330,18 +343,18 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
let mut f = T::one();
|
||||
let s = c + r;
|
||||
while c < g {
|
||||
f = f * radix;
|
||||
c = c * sqrdx;
|
||||
f *= radix;
|
||||
c *= sqrdx;
|
||||
}
|
||||
g = r * radix;
|
||||
while c > g {
|
||||
f = f / radix;
|
||||
c = c / sqrdx;
|
||||
f /= radix;
|
||||
c /= sqrdx;
|
||||
}
|
||||
if (c + r) / f < t * s {
|
||||
done = false;
|
||||
g = T::one() / f;
|
||||
scale[i] = scale[i] * f;
|
||||
*scale_i *= f;
|
||||
for j in 0..n {
|
||||
A.mul_element_mut(i, j, g);
|
||||
}
|
||||
@@ -353,14 +366,14 @@ fn balance<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<T> {
|
||||
}
|
||||
}
|
||||
|
||||
return scale;
|
||||
scale
|
||||
}
|
||||
|
||||
fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
let (n, _) = A.shape();
|
||||
let mut perm = vec![0; n];
|
||||
|
||||
for m in 1..n - 1 {
|
||||
for (m, perm_m) in perm.iter_mut().enumerate().take(n - 1).skip(1) {
|
||||
let mut x = T::zero();
|
||||
let mut i = m;
|
||||
for j in m..n {
|
||||
@@ -369,7 +382,7 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
i = j;
|
||||
}
|
||||
}
|
||||
perm[m] = i;
|
||||
*perm_m = i;
|
||||
if i != m {
|
||||
for j in (m - 1)..n {
|
||||
let swap = A.get(i, j);
|
||||
@@ -386,7 +399,7 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
for i in (m + 1)..n {
|
||||
let mut y = A.get(i, m - 1);
|
||||
if y != T::zero() {
|
||||
y = y / x;
|
||||
y /= x;
|
||||
A.set(i, m - 1, y);
|
||||
for j in m..n {
|
||||
A.sub_element_mut(i, j, y * A.get(m, j));
|
||||
@@ -399,10 +412,10 @@ fn elmhes<T: RealNumber, M: BaseMatrix<T>>(A: &mut M) -> Vec<usize> {
|
||||
}
|
||||
}
|
||||
|
||||
return perm;
|
||||
perm
|
||||
}
|
||||
|
||||
fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>) {
|
||||
fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
|
||||
let (n, _) = A.shape();
|
||||
for mp in (1..n - 1).rev() {
|
||||
for k in mp + 1..n {
|
||||
@@ -419,7 +432,7 @@ fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &Vec<usize>)
|
||||
}
|
||||
}
|
||||
|
||||
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
||||
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||
let (n, _) = A.shape();
|
||||
let mut z = T::zero();
|
||||
let mut s = T::zero();
|
||||
@@ -430,7 +443,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
|
||||
for i in 0..n {
|
||||
for j in i32::max(i as i32 - 1, 0)..n as i32 {
|
||||
anorm = anorm + A.get(i, j as usize).abs();
|
||||
anorm += A.get(i, j as usize).abs();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -467,11 +480,11 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
p = T::half() * (y - x);
|
||||
q = p * p + w;
|
||||
z = q.abs().sqrt();
|
||||
x = x + t;
|
||||
x += t;
|
||||
A.set(nn, nn, x);
|
||||
A.set(nn - 1, nn - 1, y + t);
|
||||
if q >= T::zero() {
|
||||
z = p + z.copysign(p);
|
||||
z = p + RealNumber::copysign(z, p);
|
||||
d[nn - 1] = x + z;
|
||||
d[nn] = x + z;
|
||||
if z != T::zero() {
|
||||
@@ -482,8 +495,8 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
p = x / s;
|
||||
q = z / s;
|
||||
r = (p * p + q * q).sqrt();
|
||||
p = p / r;
|
||||
q = q / r;
|
||||
p /= r;
|
||||
q /= r;
|
||||
for j in nn - 1..n {
|
||||
z = A.get(nn - 1, j);
|
||||
A.set(nn - 1, j, q * z + p * A.get(nn, j));
|
||||
@@ -516,7 +529,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
panic!("Too many iterations in hqr");
|
||||
}
|
||||
if its == 10 || its == 20 {
|
||||
t = t + x;
|
||||
t += x;
|
||||
for i in 0..nn + 1 {
|
||||
A.sub_element_mut(i, i, x);
|
||||
}
|
||||
@@ -535,9 +548,9 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
q = A.get(m + 1, m + 1) - z - r - s;
|
||||
r = A.get(m + 2, m + 1);
|
||||
s = p.abs() + q.abs() + r.abs();
|
||||
p = p / s;
|
||||
q = q / s;
|
||||
r = r / s;
|
||||
p /= s;
|
||||
q /= s;
|
||||
r /= s;
|
||||
if m == l {
|
||||
break;
|
||||
}
|
||||
@@ -565,12 +578,12 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
}
|
||||
x = p.abs() + q.abs() + r.abs();
|
||||
if x != T::zero() {
|
||||
p = p / x;
|
||||
q = q / x;
|
||||
r = r / x;
|
||||
p /= x;
|
||||
q /= x;
|
||||
r /= x;
|
||||
}
|
||||
}
|
||||
let s = (p * p + q * q + r * r).sqrt().copysign(p);
|
||||
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
|
||||
if s != T::zero() {
|
||||
if k == m {
|
||||
if l != m {
|
||||
@@ -579,31 +592,26 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
} else {
|
||||
A.set(k, k - 1, -s * x);
|
||||
}
|
||||
p = p + s;
|
||||
p += s;
|
||||
x = p / s;
|
||||
y = q / s;
|
||||
z = r / s;
|
||||
q = q / p;
|
||||
r = r / p;
|
||||
q /= p;
|
||||
r /= p;
|
||||
for j in k..n {
|
||||
p = A.get(k, j) + q * A.get(k + 1, j);
|
||||
if k + 1 != nn {
|
||||
p = p + r * A.get(k + 2, j);
|
||||
p += r * A.get(k + 2, j);
|
||||
A.sub_element_mut(k + 2, j, p * z);
|
||||
}
|
||||
A.sub_element_mut(k + 1, j, p * y);
|
||||
A.sub_element_mut(k, j, p * x);
|
||||
}
|
||||
let mmin;
|
||||
if nn < k + 3 {
|
||||
mmin = nn;
|
||||
} else {
|
||||
mmin = k + 3;
|
||||
}
|
||||
let mmin = if nn < k + 3 { nn } else { k + 3 };
|
||||
for i in 0..mmin + 1 {
|
||||
p = x * A.get(i, k) + y * A.get(i, k + 1);
|
||||
if k + 1 != nn {
|
||||
p = p + z * A.get(i, k + 2);
|
||||
p += z * A.get(i, k + 2);
|
||||
A.sub_element_mut(i, k + 2, p * r);
|
||||
}
|
||||
A.sub_element_mut(i, k + 1, p * q);
|
||||
@@ -612,7 +620,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
for i in 0..n {
|
||||
p = x * V.get(i, k) + y * V.get(i, k + 1);
|
||||
if k + 1 != nn {
|
||||
p = p + z * V.get(i, k + 2);
|
||||
p += z * V.get(i, k + 2);
|
||||
V.sub_element_mut(i, k + 2, p * r);
|
||||
}
|
||||
V.sub_element_mut(i, k + 1, p * q);
|
||||
@@ -642,7 +650,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
let w = A.get(i, i) - p;
|
||||
r = T::zero();
|
||||
for j in m..=nn {
|
||||
r = r + A.get(i, j) * A.get(j, nn);
|
||||
r += A.get(i, j) * A.get(j, nn);
|
||||
}
|
||||
if e[i] < T::zero() {
|
||||
z = w;
|
||||
@@ -701,8 +709,8 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
let mut ra = T::zero();
|
||||
let mut sa = T::zero();
|
||||
for j in m..=nn {
|
||||
ra = ra + A.get(i, j) * A.get(j, na);
|
||||
sa = sa + A.get(i, j) * A.get(j, nn);
|
||||
ra += A.get(i, j) * A.get(j, na);
|
||||
sa += A.get(i, j) * A.get(j, nn);
|
||||
}
|
||||
if e[i] < T::zero() {
|
||||
z = w;
|
||||
@@ -766,7 +774,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
for i in 0..n {
|
||||
z = T::zero();
|
||||
for k in 0..=j {
|
||||
z = z + V.get(i, k) * A.get(k, j);
|
||||
z += V.get(i, k) * A.get(k, j);
|
||||
}
|
||||
V.set(i, j, z);
|
||||
}
|
||||
@@ -774,23 +782,23 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
||||
}
|
||||
}
|
||||
|
||||
fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &Vec<T>) {
|
||||
fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
|
||||
let (n, _) = V.shape();
|
||||
for i in 0..n {
|
||||
for (i, scale_i) in scale.iter().enumerate().take(n) {
|
||||
for j in 0..n {
|
||||
V.mul_element_mut(i, j, scale[i]);
|
||||
V.mul_element_mut(i, j, *scale_i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) {
|
||||
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
let n = d.len();
|
||||
let mut temp = vec![T::zero(); n];
|
||||
for j in 1..n {
|
||||
let real = d[j];
|
||||
let img = e[j];
|
||||
for k in 0..n {
|
||||
temp[k] = V.get(k, j);
|
||||
for (k, temp_k) in temp.iter_mut().enumerate().take(n) {
|
||||
*temp_k = V.get(k, j);
|
||||
}
|
||||
let mut i = j as i32 - 1;
|
||||
while i >= 0 {
|
||||
@@ -804,10 +812,10 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut
|
||||
}
|
||||
i -= 1;
|
||||
}
|
||||
d[i as usize + 1] = real;
|
||||
e[i as usize + 1] = img;
|
||||
for k in 0..n {
|
||||
V.set(k, i as usize + 1, temp[k]);
|
||||
d[(i + 1) as usize] = real;
|
||||
e[(i + 1) as usize] = img;
|
||||
for (k, temp_k) in temp.iter().enumerate().take(n) {
|
||||
V.set(k, (i + 1) as usize, *temp_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -816,7 +824,7 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_symmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -843,7 +851,7 @@ mod tests {
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_asymmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -870,7 +878,7 @@ mod tests {
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_complex() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
//! In this module you will find composite of matrix operations that are used elsewhere
|
||||
//! for improved efficiency.
|
||||
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// High order matrix operations.
|
||||
pub trait HighOrderOperations<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Y = AB
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use smartcore::linalg::high_order::HighOrderOperations;
|
||||
///
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
|
||||
///
|
||||
/// assert_eq!(a.ab(true, &b, false), expected);
|
||||
/// ```
|
||||
fn ab(&self, a_transpose: bool, b: &Self, b_transpose: bool) -> Self {
|
||||
match (a_transpose, b_transpose) {
|
||||
(true, true) => b.matmul(self).transpose(),
|
||||
(false, true) => self.matmul(&b.transpose()),
|
||||
(true, false) => self.transpose().matmul(b),
|
||||
(false, false) => self.matmul(b),
|
||||
}
|
||||
}
|
||||
}
|
||||
+20
-25
@@ -33,6 +33,7 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -45,13 +46,13 @@ use crate::math::num::RealNumber;
|
||||
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
|
||||
LU: M,
|
||||
pivot: Vec<usize>,
|
||||
pivot_sign: i8,
|
||||
_pivot_sign: i8,
|
||||
singular: bool,
|
||||
phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
||||
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
|
||||
let (_, n) = LU.shape();
|
||||
|
||||
let mut singular = false;
|
||||
@@ -63,10 +64,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
}
|
||||
|
||||
LU {
|
||||
LU: LU,
|
||||
pivot: pivot,
|
||||
pivot_sign: pivot_sign,
|
||||
singular: singular,
|
||||
LU,
|
||||
pivot,
|
||||
_pivot_sign,
|
||||
singular,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -78,12 +79,10 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||
|
||||
for i in 0..n_rows {
|
||||
for j in 0..n_cols {
|
||||
if i > j {
|
||||
L.set(i, j, self.LU.get(i, j));
|
||||
} else if i == j {
|
||||
L.set(i, j, T::one());
|
||||
} else {
|
||||
L.set(i, j, T::zero());
|
||||
match i.cmp(&j) {
|
||||
Ordering::Greater => L.set(i, j, self.LU.get(i, j)),
|
||||
Ordering::Equal => L.set(i, j, T::one()),
|
||||
Ordering::Less => L.set(i, j, T::zero()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -203,27 +202,24 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
fn lu_mut(mut self) -> Result<LU<T, Self>, Failed> {
|
||||
let (m, n) = self.shape();
|
||||
|
||||
let mut piv = vec![0; m];
|
||||
for i in 0..m {
|
||||
piv[i] = i;
|
||||
}
|
||||
let mut piv = (0..m).collect::<Vec<_>>();
|
||||
|
||||
let mut pivsign = 1;
|
||||
let mut LUcolj = vec![T::zero(); m];
|
||||
|
||||
for j in 0..n {
|
||||
for i in 0..m {
|
||||
LUcolj[i] = self.get(i, j);
|
||||
for (i, LUcolj_i) in LUcolj.iter_mut().enumerate().take(m) {
|
||||
*LUcolj_i = self.get(i, j);
|
||||
}
|
||||
|
||||
for i in 0..m {
|
||||
let kmax = usize::min(i, j);
|
||||
let mut s = T::zero();
|
||||
for k in 0..kmax {
|
||||
s = s + self.get(i, k) * LUcolj[k];
|
||||
for (k, LUcolj_k) in LUcolj.iter().enumerate().take(kmax) {
|
||||
s += self.get(i, k) * (*LUcolj_k);
|
||||
}
|
||||
|
||||
LUcolj[i] = LUcolj[i] - s;
|
||||
LUcolj[i] -= s;
|
||||
self.set(i, j, LUcolj[i]);
|
||||
}
|
||||
|
||||
@@ -239,9 +235,7 @@ pub trait LUDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
self.set(p, k, self.get(j, k));
|
||||
self.set(j, k, t);
|
||||
}
|
||||
let k = piv[p];
|
||||
piv[p] = piv[j];
|
||||
piv[j] = k;
|
||||
piv.swap(p, j);
|
||||
pivsign = -pivsign;
|
||||
}
|
||||
|
||||
@@ -266,6 +260,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
@@ -280,7 +275,7 @@ mod tests {
|
||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn inverse() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
|
||||
+328
-7
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::wrong_self_convention)]
|
||||
//! # Linear Algebra and Matrix Decomposition
|
||||
//!
|
||||
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
|
||||
@@ -33,8 +34,10 @@
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
//! ```
|
||||
|
||||
pub mod cholesky;
|
||||
/// The matrix is represented in terms of its eigenvalues and eigenvectors.
|
||||
pub mod evd;
|
||||
pub mod high_order;
|
||||
/// Factors a matrix as the product of a lower triangular matrix and an upper triangular matrix.
|
||||
pub mod lu;
|
||||
/// Dense matrix with column-major order that wraps [Vec](https://doc.rust-lang.org/std/vec/struct.Vec.html).
|
||||
@@ -47,6 +50,7 @@ pub mod nalgebra_bindings;
|
||||
pub mod ndarray_bindings;
|
||||
/// QR factorization that factors a matrix into a product of an orthogonal matrix and an upper triangular matrix.
|
||||
pub mod qr;
|
||||
pub mod stats;
|
||||
/// Singular value decomposition.
|
||||
pub mod svd;
|
||||
|
||||
@@ -55,9 +59,12 @@ use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
use cholesky::CholeskyDecomposableMatrix;
|
||||
use evd::EVDDecomposableMatrix;
|
||||
use high_order::HighOrderOperations;
|
||||
use lu::LUDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
use stats::{MatrixPreprocessing, MatrixStats};
|
||||
use svd::SVDDecomposableMatrix;
|
||||
|
||||
/// Column or row vector
|
||||
@@ -74,6 +81,26 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||
/// Get number of elevemnt in the vector
|
||||
fn len(&self) -> usize;
|
||||
|
||||
/// Returns true if the vector is empty.
|
||||
fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
|
||||
/// Create a new vector from a &[T]
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a: [f64; 5] = [0., 0.5, 2., 3., 4.];
|
||||
/// let v: Vec<f64> = BaseVector::from_array(&a);
|
||||
/// assert_eq!(v, vec![0., 0.5, 2., 3., 4.]);
|
||||
/// ```
|
||||
fn from_array(f: &[T]) -> Self {
|
||||
let mut v = Self::zeros(f.len());
|
||||
for (i, elem) in f.iter().enumerate() {
|
||||
v.set(i, *elem);
|
||||
}
|
||||
v
|
||||
}
|
||||
|
||||
/// Return a vector with the elements of the one-dimensional array.
|
||||
fn to_vec(&self) -> Vec<T>;
|
||||
|
||||
@@ -85,6 +112,182 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||
|
||||
/// Create new vector of size `len` where each element is set to `value`.
|
||||
fn fill(len: usize, value: T) -> Self;
|
||||
|
||||
/// Vector dot product
|
||||
fn dot(&self, other: &Self) -> T;
|
||||
|
||||
/// Returns True if matrices are element-wise equal within a tolerance `error`.
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool;
|
||||
|
||||
/// Returns [L2 norm] of the vector(https://en.wikipedia.org/wiki/Matrix_norm).
|
||||
fn norm2(&self) -> T;
|
||||
|
||||
/// Returns [vectors norm](https://en.wikipedia.org/wiki/Matrix_norm) of order `p`.
|
||||
fn norm(&self, p: T) -> T;
|
||||
|
||||
/// Divide single element of the vector by `x`, write result to original vector.
|
||||
fn div_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Multiply single element of the vector by `x`, write result to original vector.
|
||||
fn mul_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Add single element of the vector to `x`, write result to original vector.
|
||||
fn add_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Subtract `x` from single element of the vector, write result to original vector.
|
||||
fn sub_element_mut(&mut self, pos: usize, x: T);
|
||||
|
||||
/// Subtract scalar
|
||||
fn sub_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) - x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn add_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) + x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn mul_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) * x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Subtract scalar
|
||||
fn div_scalar_mut(&mut self, x: T) -> &Self {
|
||||
for i in 0..self.len() {
|
||||
self.set(i, self.get(i) / x);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Add vectors, element-wise
|
||||
fn add_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract vectors, element-wise
|
||||
fn sub_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply vectors, element-wise
|
||||
fn mul_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide vectors, element-wise
|
||||
fn div_scalar(&self, x: T) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_scalar_mut(x);
|
||||
r
|
||||
}
|
||||
|
||||
/// Add vectors, element-wise, overriding original vector with result.
|
||||
fn add_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Subtract vectors, element-wise, overriding original vector with result.
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Multiply vectors, element-wise, overriding original vector with result.
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Divide vectors, element-wise, overriding original vector with result.
|
||||
fn div_mut(&mut self, other: &Self) -> &Self;
|
||||
|
||||
/// Add vectors, element-wise
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.add_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Subtract vectors, element-wise
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.sub_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Multiply vectors, element-wise
|
||||
fn mul(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.mul_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Divide vectors, element-wise
|
||||
fn div(&self, other: &Self) -> Self {
|
||||
let mut r = self.clone();
|
||||
r.div_mut(other);
|
||||
r
|
||||
}
|
||||
|
||||
/// Calculates sum of all elements of the vector.
|
||||
fn sum(&self) -> T;
|
||||
|
||||
/// Returns unique values from the vector.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// let a = vec!(1., 2., 2., -2., -6., -7., 2., 3., 4.);
|
||||
///
|
||||
///assert_eq!(a.unique(), vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
/// ```
|
||||
fn unique(&self) -> Vec<T>;
|
||||
|
||||
/// Computes the arithmetic mean.
|
||||
fn mean(&self) -> T {
|
||||
self.sum() / T::from_usize(self.len()).unwrap()
|
||||
}
|
||||
/// Computes variance.
|
||||
fn var(&self) -> T {
|
||||
let n = self.len();
|
||||
|
||||
let mut mu = T::zero();
|
||||
let mut sum = T::zero();
|
||||
let div = T::from_usize(n).unwrap();
|
||||
for i in 0..n {
|
||||
let xi = self.get(i);
|
||||
mu += xi;
|
||||
sum += xi * xi;
|
||||
}
|
||||
mu /= div;
|
||||
sum / div - mu.powi(2)
|
||||
}
|
||||
/// Computes the standard deviation.
|
||||
fn std(&self) -> T {
|
||||
self.var().sqrt()
|
||||
}
|
||||
|
||||
/// Copies content of `other` vector.
|
||||
fn copy_from(&mut self, other: &Self);
|
||||
|
||||
/// Take elements from an array.
|
||||
fn take(&self, index: &[usize]) -> Self {
|
||||
let n = index.len();
|
||||
|
||||
let mut result = Self::zeros(n);
|
||||
|
||||
for (i, idx) in index.iter().enumerate() {
|
||||
result.set(i, self.get(*idx));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic matrix type.
|
||||
@@ -110,6 +313,10 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
/// * `row` - row number
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T>;
|
||||
|
||||
/// Get the `row`'th row
|
||||
/// * `row` - row number
|
||||
fn get_row(&self, row: usize) -> Self::RowVector;
|
||||
|
||||
/// Copies a vector with elements of the `row`'th row into `result`
|
||||
/// * `row` - row number
|
||||
/// * `result` - receiver for the row
|
||||
@@ -418,6 +625,36 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
|
||||
/// Calculates the covariance matrix
|
||||
fn cov(&self) -> Self;
|
||||
|
||||
/// Take elements from an array along an axis.
|
||||
fn take(&self, index: &[usize], axis: u8) -> Self {
|
||||
let (n, p) = self.shape();
|
||||
|
||||
let k = match axis {
|
||||
0 => p,
|
||||
_ => n,
|
||||
};
|
||||
|
||||
let mut result = match axis {
|
||||
0 => Self::zeros(index.len(), p),
|
||||
_ => Self::zeros(n, index.len()),
|
||||
};
|
||||
|
||||
for (i, idx) in index.iter().enumerate() {
|
||||
for j in 0..k {
|
||||
match axis {
|
||||
0 => result.set(i, j, self.get(*idx, j)),
|
||||
_ => result.set(j, i, self.get(j, *idx)),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
/// Take an individual column from the matrix.
|
||||
fn take_column(&self, column_index: usize) -> Self {
|
||||
self.take(&[column_index], 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic matrix with additional mixins like various factorization methods.
|
||||
@@ -427,14 +664,18 @@ pub trait Matrix<T: RealNumber>:
|
||||
+ EVDDecomposableMatrix<T>
|
||||
+ QRDecomposableMatrix<T>
|
||||
+ LUDecomposableMatrix<T>
|
||||
+ CholeskyDecomposableMatrix<T>
|
||||
+ MatrixStats<T>
|
||||
+ MatrixPreprocessing<T>
|
||||
+ HighOrderOperations<T>
|
||||
+ PartialEq
|
||||
+ Display
|
||||
{
|
||||
}
|
||||
|
||||
pub(crate) fn row_iter<F: RealNumber, M: BaseMatrix<F>>(m: &M) -> RowIter<F, M> {
|
||||
pub(crate) fn row_iter<F: RealNumber, M: BaseMatrix<F>>(m: &M) -> RowIter<'_, F, M> {
|
||||
RowIter {
|
||||
m: m,
|
||||
m,
|
||||
pos: 0,
|
||||
max_pos: m.shape().0,
|
||||
phantom: PhantomData,
|
||||
@@ -452,13 +693,93 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
||||
type Item = Vec<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Vec<T>> {
|
||||
let res;
|
||||
if self.pos < self.max_pos {
|
||||
res = Some(self.m.get_row_as_vec(self.pos))
|
||||
let res = if self.pos < self.max_pos {
|
||||
Some(self.m.get_row_as_vec(self.pos))
|
||||
} else {
|
||||
res = None
|
||||
}
|
||||
None
|
||||
};
|
||||
self.pos += 1;
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::linalg::BaseVector;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean() {
|
||||
let m = vec![1., 2., 3.];
|
||||
|
||||
assert_eq!(m.mean(), 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn std() {
|
||||
let m = vec![1., 2., 3.];
|
||||
|
||||
assert!((m.std() - 0.81f64).abs() < 1e-2);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn var() {
|
||||
let m = vec![1., 2., 3., 4.];
|
||||
|
||||
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_take() {
|
||||
let m = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn take() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 2.0],
|
||||
&[3.0, 4.0],
|
||||
&[5.0, 6.0],
|
||||
&[7.0, 8.0],
|
||||
&[9.0, 10.0],
|
||||
]);
|
||||
|
||||
let expected_0 = DenseMatrix::from_2d_array(&[&[3.0, 4.0], &[3.0, 4.0], &[7.0, 8.0]]);
|
||||
|
||||
let expected_1 = DenseMatrix::from_2d_array(&[
|
||||
&[2.0, 1.0],
|
||||
&[4.0, 3.0],
|
||||
&[6.0, 5.0],
|
||||
&[8.0, 7.0],
|
||||
&[10.0, 9.0],
|
||||
]);
|
||||
|
||||
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
|
||||
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn take_second_column_from_matrix() {
|
||||
let four_columns: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
]);
|
||||
|
||||
let second_column = four_columns.take_column(1);
|
||||
assert_eq!(
|
||||
second_column,
|
||||
DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]),
|
||||
"The second column was not extracted correctly"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
extern crate num;
|
||||
#![allow(clippy::ptr_arg)]
|
||||
use std::fmt;
|
||||
use std::fmt::Debug;
|
||||
#[cfg(feature = "serde")]
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Range;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::ser::{SerializeStruct, Serializer};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
pub use crate::linalg::{BaseMatrix, BaseVector};
|
||||
@@ -29,8 +36,7 @@ impl<T: RealNumber> BaseVector<T> for Vec<T> {
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<T> {
|
||||
let v = self.clone();
|
||||
v
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn zeros(len: usize) -> Self {
|
||||
@@ -44,6 +50,149 @@ impl<T: RealNumber> BaseVector<T> for Vec<T> {
|
||||
fn fill(len: usize, value: T) -> Self {
|
||||
vec![value; len]
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> T {
|
||||
if self.len() != other.len() {
|
||||
panic!("A and B should have the same size");
|
||||
}
|
||||
|
||||
let mut result = T::zero();
|
||||
for i in 0..self.len() {
|
||||
result += self[i] * other[i];
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn norm2(&self) -> T {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm += *xi * *xi;
|
||||
}
|
||||
|
||||
norm.sqrt()
|
||||
}
|
||||
|
||||
fn norm(&self, p: T) -> T {
|
||||
if p.is_infinite() && p.is_sign_positive() {
|
||||
self.iter()
|
||||
.map(|x| x.abs())
|
||||
.fold(T::neg_infinity(), |a, b| a.max(b))
|
||||
} else if p.is_infinite() && p.is_sign_negative() {
|
||||
self.iter()
|
||||
.map(|x| x.abs())
|
||||
.fold(T::infinity(), |a, b| a.min(b))
|
||||
} else {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
}
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] /= x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] *= x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] += x
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] -= x;
|
||||
}
|
||||
|
||||
fn add_mut(&mut self, other: &Self) -> &Self {
|
||||
if self.len() != other.len() {
|
||||
panic!("A and B should have the same shape");
|
||||
}
|
||||
for i in 0..self.len() {
|
||||
self.add_element_mut(i, other.get(i));
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self {
|
||||
if self.len() != other.len() {
|
||||
panic!("A and B should have the same shape");
|
||||
}
|
||||
for i in 0..self.len() {
|
||||
self.sub_element_mut(i, other.get(i));
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self {
|
||||
if self.len() != other.len() {
|
||||
panic!("A and B should have the same shape");
|
||||
}
|
||||
for i in 0..self.len() {
|
||||
self.mul_element_mut(i, other.get(i));
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn div_mut(&mut self, other: &Self) -> &Self {
|
||||
if self.len() != other.len() {
|
||||
panic!("A and B should have the same shape");
|
||||
}
|
||||
for i in 0..self.len() {
|
||||
self.div_element_mut(i, other.get(i));
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
||||
if self.len() != other.len() {
|
||||
false
|
||||
} else {
|
||||
for i in 0..other.len() {
|
||||
if (self[i] - other[i]).abs() > error {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
fn sum(&self) -> T {
|
||||
let mut sum = T::zero();
|
||||
for self_i in self.iter() {
|
||||
sum += *self_i;
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
fn unique(&self) -> Vec<T> {
|
||||
let mut result = self.clone();
|
||||
result.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
result.dedup();
|
||||
result
|
||||
}
|
||||
|
||||
fn copy_from(&mut self, other: &Self) {
|
||||
if self.len() != other.len() {
|
||||
panic!(
|
||||
"Can't copy vector of length {} into a vector of length {}.",
|
||||
self.len(),
|
||||
other.len()
|
||||
);
|
||||
}
|
||||
|
||||
self[..].clone_from_slice(&other[..]);
|
||||
}
|
||||
}
|
||||
|
||||
/// Column-major, dense matrix. See [Simple Dense Matrix](../index.html).
|
||||
@@ -65,7 +214,7 @@ pub struct DenseMatrixIterator<'a, T: RealNumber> {
|
||||
}
|
||||
|
||||
impl<T: RealNumber> fmt::Display for DenseMatrix<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut rows: Vec<Vec<f64>> = Vec::new();
|
||||
for r in 0..self.nrows {
|
||||
rows.push(
|
||||
@@ -84,15 +233,15 @@ impl<T: RealNumber> DenseMatrix<T> {
|
||||
/// `values` should be in column-major order.
|
||||
pub fn new(nrows: usize, ncols: usize, values: Vec<T>) -> Self {
|
||||
DenseMatrix {
|
||||
ncols: ncols,
|
||||
nrows: nrows,
|
||||
values: values,
|
||||
ncols,
|
||||
nrows,
|
||||
values,
|
||||
}
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d array.
|
||||
pub fn from_2d_array(values: &[&[T]]) -> Self {
|
||||
DenseMatrix::from_2d_vec(&values.into_iter().map(|row| Vec::from(*row)).collect())
|
||||
DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d vector.
|
||||
@@ -103,13 +252,13 @@ impl<T: RealNumber> DenseMatrix<T> {
|
||||
.unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector"))
|
||||
.len();
|
||||
let mut m = DenseMatrix {
|
||||
ncols: ncols,
|
||||
nrows: nrows,
|
||||
ncols,
|
||||
nrows,
|
||||
values: vec![T::zero(); ncols * nrows],
|
||||
};
|
||||
for row in 0..nrows {
|
||||
for col in 0..ncols {
|
||||
m.set(row, col, values[row][col]);
|
||||
for (row_index, row) in values.iter().enumerate().take(nrows) {
|
||||
for (col_index, value) in row.iter().enumerate().take(ncols) {
|
||||
m.set(row_index, col_index, *value);
|
||||
}
|
||||
}
|
||||
m
|
||||
@@ -127,10 +276,10 @@ impl<T: RealNumber> DenseMatrix<T> {
|
||||
/// * `nrows` - number of rows in new matrix.
|
||||
/// * `ncols` - number of columns in new matrix.
|
||||
/// * `values` - values to initialize the matrix.
|
||||
pub fn from_vec(nrows: usize, ncols: usize, values: &Vec<T>) -> DenseMatrix<T> {
|
||||
pub fn from_vec(nrows: usize, ncols: usize, values: &[T]) -> DenseMatrix<T> {
|
||||
let mut m = DenseMatrix {
|
||||
ncols: ncols,
|
||||
nrows: nrows,
|
||||
ncols,
|
||||
nrows,
|
||||
values: vec![T::zero(); ncols * nrows],
|
||||
};
|
||||
for row in 0..nrows {
|
||||
@@ -153,7 +302,7 @@ impl<T: RealNumber> DenseMatrix<T> {
|
||||
DenseMatrix {
|
||||
ncols: values.len(),
|
||||
nrows: 1,
|
||||
values: values,
|
||||
values,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,19 +318,19 @@ impl<T: RealNumber> DenseMatrix<T> {
|
||||
DenseMatrix {
|
||||
ncols: 1,
|
||||
nrows: values.len(),
|
||||
values: values,
|
||||
values,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates new column vector (_1xN_ matrix) from a vector.
|
||||
/// * `values` - values to initialize the matrix.
|
||||
pub fn iter<'a>(&'a self) -> DenseMatrixIterator<'a, T> {
|
||||
pub fn iter(&self) -> DenseMatrixIterator<'_, T> {
|
||||
DenseMatrixIterator {
|
||||
cur_c: 0,
|
||||
cur_r: 0,
|
||||
max_c: self.ncols,
|
||||
max_r: self.nrows,
|
||||
m: &self,
|
||||
m: self,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -204,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
@@ -224,7 +374,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
||||
impl<'a, T: RealNumber + fmt::Debug + Deserialize<'a>> Visitor<'a> for DenseMatrixVisitor<T> {
|
||||
type Value = DenseMatrix<T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
formatter.write_str("struct DenseMatrix")
|
||||
}
|
||||
|
||||
@@ -280,7 +430,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
||||
}
|
||||
}
|
||||
|
||||
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "values"];
|
||||
const FIELDS: &[&str] = &["nrows", "ncols", "values"];
|
||||
deserializer.deserialize_struct(
|
||||
"DenseMatrix",
|
||||
FIELDS,
|
||||
@@ -289,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
@@ -311,6 +462,43 @@ impl<T: RealNumber> QRDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> LUDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> CholeskyDecomposableMatrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> HighOrderOperations<T> for DenseMatrix<T> {
|
||||
fn ab(&self, a_transpose: bool, b: &Self, b_transpose: bool) -> Self {
|
||||
if !a_transpose && !b_transpose {
|
||||
self.matmul(b)
|
||||
} else {
|
||||
let (d1, d2, d3, d4) = match (a_transpose, b_transpose) {
|
||||
(true, false) => (self.nrows, self.ncols, b.ncols, b.nrows),
|
||||
(false, true) => (self.ncols, self.nrows, b.nrows, b.ncols),
|
||||
_ => (self.nrows, self.ncols, b.nrows, b.ncols),
|
||||
};
|
||||
if d1 != d4 {
|
||||
panic!("Can not multiply {}x{} by {}x{} matrices", d2, d1, d4, d3);
|
||||
}
|
||||
let mut result = Self::zeros(d2, d3);
|
||||
for r in 0..d2 {
|
||||
for c in 0..d3 {
|
||||
let mut s = T::zero();
|
||||
for i in 0..d1 {
|
||||
match (a_transpose, b_transpose) {
|
||||
(true, false) => s += self.get(i, r) * b.get(i, c),
|
||||
(false, true) => s += self.get(r, i) * b.get(c, i),
|
||||
_ => s += self.get(i, r) * b.get(c, i),
|
||||
}
|
||||
}
|
||||
result.set(r, c, s);
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
|
||||
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> Matrix<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for DenseMatrix<T> {
|
||||
@@ -335,10 +523,9 @@ impl<T: RealNumber> PartialEq for DenseMatrix<T> {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Into<Vec<T>> for DenseMatrix<T> {
|
||||
fn into(self) -> Vec<T> {
|
||||
self.values
|
||||
impl<T: RealNumber> From<DenseMatrix<T>> for Vec<T> {
|
||||
fn from(dense_matrix: DenseMatrix<T>) -> Vec<T> {
|
||||
dense_matrix.values
|
||||
}
|
||||
}
|
||||
|
||||
@@ -371,31 +558,41 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
self.values[col * self.nrows + row]
|
||||
}
|
||||
|
||||
fn get_row(&self, row: usize) -> Self::RowVector {
|
||||
let mut v = vec![T::zero(); self.ncols];
|
||||
|
||||
for (c, v_c) in v.iter_mut().enumerate().take(self.ncols) {
|
||||
*v_c = self.get(row, c);
|
||||
}
|
||||
|
||||
v
|
||||
}
|
||||
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T> {
|
||||
let mut result = vec![T::zero(); self.ncols];
|
||||
for c in 0..self.ncols {
|
||||
result[c] = self.get(row, c);
|
||||
for (c, result_c) in result.iter_mut().enumerate().take(self.ncols) {
|
||||
*result_c = self.get(row, c);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>) {
|
||||
for c in 0..self.ncols {
|
||||
result[c] = self.get(row, c);
|
||||
for (c, result_c) in result.iter_mut().enumerate().take(self.ncols) {
|
||||
*result_c = self.get(row, c);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T> {
|
||||
let mut result = vec![T::zero(); self.nrows];
|
||||
for r in 0..self.nrows {
|
||||
result[r] = self.get(r, col);
|
||||
for (r, result_r) in result.iter_mut().enumerate().take(self.nrows) {
|
||||
*result_r = self.get(r, col);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
|
||||
for r in 0..self.nrows {
|
||||
result[r] = self.get(r, col);
|
||||
for (r, result_r) in result.iter_mut().enumerate().take(self.nrows) {
|
||||
*result_r = self.get(r, col);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -418,7 +615,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
matrix.set(i, i, T::one());
|
||||
}
|
||||
|
||||
return matrix;
|
||||
matrix
|
||||
}
|
||||
|
||||
fn shape(&self) -> (usize, usize) {
|
||||
@@ -470,7 +667,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
for c in 0..other.ncols {
|
||||
let mut s = T::zero();
|
||||
for i in 0..inner_d {
|
||||
s = s + self.get(r, i) * other.get(i, c);
|
||||
s += self.get(r, i) * other.get(i, c);
|
||||
}
|
||||
result.set(r, c, s);
|
||||
}
|
||||
@@ -480,8 +677,8 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> T {
|
||||
if self.nrows != 1 && other.nrows != 1 {
|
||||
panic!("A and B should both be 1-dimentional vectors.");
|
||||
if (self.nrows != 1 && other.nrows != 1) && (self.ncols != 1 && other.ncols != 1) {
|
||||
panic!("A and B should both be either a row or a column vector.");
|
||||
}
|
||||
if self.nrows * self.ncols != other.nrows * other.ncols {
|
||||
panic!("A and B should have the same size");
|
||||
@@ -489,7 +686,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
|
||||
let mut result = T::zero();
|
||||
for i in 0..(self.nrows * self.ncols) {
|
||||
result = result + self.values[i] * other.values[i];
|
||||
result += self.values[i] * other.values[i];
|
||||
}
|
||||
|
||||
result
|
||||
@@ -583,19 +780,19 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self.values[col * self.nrows + row] = self.values[col * self.nrows + row] / x;
|
||||
self.values[col * self.nrows + row] /= x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self.values[col * self.nrows + row] = self.values[col * self.nrows + row] * x;
|
||||
self.values[col * self.nrows + row] *= x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self.values[col * self.nrows + row] = self.values[col * self.nrows + row] + x
|
||||
self.values[col * self.nrows + row] += x
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self.values[col * self.nrows + row] = self.values[col * self.nrows + row] - x;
|
||||
self.values[col * self.nrows + row] -= x;
|
||||
}
|
||||
|
||||
fn transpose(&self) -> Self {
|
||||
@@ -615,9 +812,9 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
fn rand(nrows: usize, ncols: usize) -> Self {
|
||||
let values: Vec<T> = (0..nrows * ncols).map(|_| T::rand()).collect();
|
||||
DenseMatrix {
|
||||
ncols: ncols,
|
||||
nrows: nrows,
|
||||
values: values,
|
||||
ncols,
|
||||
nrows,
|
||||
values,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,7 +822,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.values.iter() {
|
||||
norm = norm + *xi * *xi;
|
||||
norm += *xi * *xi;
|
||||
}
|
||||
|
||||
norm.sqrt()
|
||||
@@ -646,7 +843,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.values.iter() {
|
||||
norm = norm + xi.abs().powf(p);
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
@@ -657,13 +854,13 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
let mut mean = vec![T::zero(); self.ncols];
|
||||
|
||||
for r in 0..self.nrows {
|
||||
for c in 0..self.ncols {
|
||||
mean[c] = mean[c] + self.get(r, c);
|
||||
for (c, mean_c) in mean.iter_mut().enumerate().take(self.ncols) {
|
||||
*mean_c += self.get(r, c);
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..mean.len() {
|
||||
mean[i] = mean[i] / T::from(self.nrows).unwrap();
|
||||
for mean_i in mean.iter_mut() {
|
||||
*mean_i /= T::from(self.nrows).unwrap();
|
||||
}
|
||||
|
||||
mean
|
||||
@@ -671,28 +868,28 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
|
||||
fn add_scalar_mut(&mut self, scalar: T) -> &Self {
|
||||
for i in 0..self.values.len() {
|
||||
self.values[i] = self.values[i] + scalar;
|
||||
self.values[i] += scalar;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn sub_scalar_mut(&mut self, scalar: T) -> &Self {
|
||||
for i in 0..self.values.len() {
|
||||
self.values[i] = self.values[i] - scalar;
|
||||
self.values[i] -= scalar;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn mul_scalar_mut(&mut self, scalar: T) -> &Self {
|
||||
for i in 0..self.values.len() {
|
||||
self.values[i] = self.values[i] * scalar;
|
||||
self.values[i] *= scalar;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn div_scalar_mut(&mut self, scalar: T) -> &Self {
|
||||
for i in 0..self.values.len() {
|
||||
self.values[i] = self.values[i] / scalar;
|
||||
self.values[i] /= scalar;
|
||||
}
|
||||
self
|
||||
}
|
||||
@@ -735,9 +932,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
);
|
||||
}
|
||||
|
||||
for i in 0..self.values.len() {
|
||||
self.values[i] = other.values[i];
|
||||
}
|
||||
self.values[..].clone_from_slice(&other.values[..]);
|
||||
}
|
||||
|
||||
fn abs_mut(&mut self) -> &Self {
|
||||
@@ -758,7 +953,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
fn sum(&self) -> T {
|
||||
let mut sum = T::zero();
|
||||
for i in 0..self.values.len() {
|
||||
sum = sum + self.values[i];
|
||||
sum += self.values[i];
|
||||
}
|
||||
sum
|
||||
}
|
||||
@@ -790,7 +985,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
for c in 0..self.ncols {
|
||||
let p = (self.get(r, c) - max).exp();
|
||||
self.set(r, c, p);
|
||||
z = z + p;
|
||||
z += p;
|
||||
}
|
||||
}
|
||||
for r in 0..self.nrows {
|
||||
@@ -810,7 +1005,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
fn argmax(&self) -> Vec<usize> {
|
||||
let mut res = vec![0usize; self.nrows];
|
||||
|
||||
for r in 0..self.nrows {
|
||||
for (r, res_r) in res.iter_mut().enumerate().take(self.nrows) {
|
||||
let mut max = T::neg_infinity();
|
||||
let mut max_pos = 0usize;
|
||||
for c in 0..self.ncols {
|
||||
@@ -820,7 +1015,7 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
max_pos = c;
|
||||
}
|
||||
}
|
||||
res[r] = max_pos;
|
||||
*res_r = max_pos;
|
||||
}
|
||||
|
||||
res
|
||||
@@ -864,7 +1059,30 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_dot() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_copy_from() {
|
||||
let mut v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
v1.copy_from(&v2);
|
||||
assert_eq!(v1, v2);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_approximate_eq() {
|
||||
let a = vec![1., 2., 3.];
|
||||
let b = vec![1. + 1e-5, 2. + 2e-5, 3. + 3e-5];
|
||||
assert!(a.approximate_eq(&b, 1e-4));
|
||||
assert!(!a.approximate_eq(&b, 1e-5));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn from_array() {
|
||||
let vec = [1., 2., 3., 4., 5., 6.];
|
||||
@@ -877,7 +1095,7 @@ mod tests {
|
||||
DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.])
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn row_column_vec_from_array() {
|
||||
let vec = vec![1., 2., 3., 4., 5., 6.];
|
||||
@@ -890,7 +1108,7 @@ mod tests {
|
||||
DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.])
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn from_to_row_vec() {
|
||||
let vec = vec![1., 2., 3.];
|
||||
@@ -899,18 +1117,24 @@ mod tests {
|
||||
DenseMatrix::new(1, 3, vec![1., 2., 3.])
|
||||
);
|
||||
assert_eq!(
|
||||
DenseMatrix::from_row_vector(vec.clone()).to_row_vector(),
|
||||
DenseMatrix::from_row_vector(vec).to_row_vector(),
|
||||
vec![1., 2., 3.]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_matrix_to_row_vector() {
|
||||
let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1);
|
||||
assert_eq!(m.to_row_vector().len(), 10)
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn iter() {
|
||||
let vec = vec![1., 2., 3., 4., 5., 6.];
|
||||
let m = DenseMatrix::from_array(3, 2, &vec);
|
||||
assert_eq!(vec, m.iter().collect::<Vec<f32>>());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn v_stack() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
@@ -925,7 +1149,7 @@ mod tests {
|
||||
let result = a.v_stack(&b);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn h_stack() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
@@ -938,7 +1162,13 @@ mod tests {
|
||||
let result = a.h_stack(&b);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_row() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
assert_eq!(vec![4., 5., 6.], a.get_row(1));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn matmul() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
@@ -947,14 +1177,45 @@ mod tests {
|
||||
let result = a.matmul(&b);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ab() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let c = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
assert_eq!(
|
||||
a.ab(false, &b, false),
|
||||
DenseMatrix::from_2d_array(&[&[46., 52.], &[109., 124.]])
|
||||
);
|
||||
assert_eq!(
|
||||
c.ab(true, &b, false),
|
||||
DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]])
|
||||
);
|
||||
assert_eq!(
|
||||
b.ab(false, &c, true),
|
||||
DenseMatrix::from_2d_array(&[&[17., 39., 61.], &[23., 53., 83.,], &[29., 67., 105.]])
|
||||
);
|
||||
assert_eq!(
|
||||
a.ab(true, &b, true),
|
||||
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
|
||||
);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn dot() {
|
||||
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
||||
let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
|
||||
assert_eq!(a.dot(&b), 32.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn copy_from() {
|
||||
let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[7., 8.], &[9., 10.], &[11., 12.]]);
|
||||
a.copy_from(&b);
|
||||
assert_eq!(a, b);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn slice() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
@@ -966,7 +1227,7 @@ mod tests {
|
||||
let result = m.slice(0..2, 1..3);
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn approximate_eq() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
|
||||
@@ -975,7 +1236,7 @@ mod tests {
|
||||
assert!(m.approximate_eq(&m_eq, 0.5));
|
||||
assert!(!m.approximate_eq(&m_neq, 0.5));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn rand() {
|
||||
let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3);
|
||||
@@ -985,7 +1246,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn transpose() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
|
||||
@@ -997,7 +1258,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn reshape() {
|
||||
let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]);
|
||||
@@ -1008,7 +1269,7 @@ mod tests {
|
||||
assert_eq!(m_result.get(0, 1), 2.);
|
||||
assert_eq!(m_result.get(0, 3), 4.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn norm() {
|
||||
let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]);
|
||||
@@ -1017,7 +1278,7 @@ mod tests {
|
||||
assert_eq!(v.norm(std::f64::INFINITY), 6.);
|
||||
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn softmax_mut() {
|
||||
let mut prob: DenseMatrix<f64> = DenseMatrix::row_vector_from_array(&[1., 2., 3.]);
|
||||
@@ -1026,14 +1287,14 @@ mod tests {
|
||||
assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
|
||||
assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_mean() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
let res = a.column_mean();
|
||||
assert_eq!(res, vec![4., 5., 6.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn min_max_sum() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
@@ -1041,30 +1302,32 @@ mod tests {
|
||||
assert_eq!(1., a.min());
|
||||
assert_eq!(6., a.max());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn eye() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
|
||||
let res = DenseMatrix::eye(3);
|
||||
assert_eq!(res, a);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn to_from_json() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn to_from_bincode() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let deserialized_a: DenseMatrix<f64> =
|
||||
bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||
assert_eq!(a, deserialized_a);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn to_string() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
@@ -1073,7 +1336,7 @@ mod tests {
|
||||
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn cov() {
|
||||
let a = DenseMatrix::from_2d_array(&[
|
||||
|
||||
+222
-18
@@ -40,17 +40,20 @@
|
||||
use std::iter::Sum;
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||
|
||||
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};
|
||||
use nalgebra::{Const, DMatrix, Dynamic, Matrix, OMatrix, RowDVector, Scalar, VecStorage, U1};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix as SmartCoreMatrix;
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
||||
impl<T: RealNumber + 'static> BaseVector<T> for OMatrix<T, U1, Dynamic> {
|
||||
fn get(&self, i: usize) -> T {
|
||||
*self.get((0, i)).unwrap()
|
||||
}
|
||||
@@ -63,7 +66,7 @@ impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<T> {
|
||||
self.row(0).iter().map(|v| *v).collect()
|
||||
self.row(0).iter().copied().collect()
|
||||
}
|
||||
|
||||
fn zeros(len: usize) -> Self {
|
||||
@@ -79,19 +82,123 @@ impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
||||
m.fill(value);
|
||||
m
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> T {
|
||||
self.dot(other)
|
||||
}
|
||||
|
||||
fn norm2(&self) -> T {
|
||||
self.iter().map(|x| *x * *x).sum::<T>().sqrt()
|
||||
}
|
||||
|
||||
fn norm(&self, p: T) -> T {
|
||||
if p.is_infinite() && p.is_sign_positive() {
|
||||
self.iter().fold(T::neg_infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f > v {
|
||||
f
|
||||
} else {
|
||||
v
|
||||
}
|
||||
})
|
||||
} else if p.is_infinite() && p.is_sign_negative() {
|
||||
self.iter().fold(T::infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f < v {
|
||||
f
|
||||
} else {
|
||||
v
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
}
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, pos: usize, x: T) {
|
||||
*self.get_mut(pos).unwrap() = *self.get(pos).unwrap() / x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, pos: usize, x: T) {
|
||||
*self.get_mut(pos).unwrap() = *self.get(pos).unwrap() * x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, pos: usize, x: T) {
|
||||
*self.get_mut(pos).unwrap() = *self.get(pos).unwrap() + x;
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, pos: usize, x: T) {
|
||||
*self.get_mut(pos).unwrap() = *self.get(pos).unwrap() - x;
|
||||
}
|
||||
|
||||
fn add_mut(&mut self, other: &Self) -> &Self {
|
||||
*self += other;
|
||||
self
|
||||
}
|
||||
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self {
|
||||
*self -= other;
|
||||
self
|
||||
}
|
||||
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self {
|
||||
self.component_mul_assign(other);
|
||||
self
|
||||
}
|
||||
|
||||
fn div_mut(&mut self, other: &Self) -> &Self {
|
||||
self.component_div_assign(other);
|
||||
self
|
||||
}
|
||||
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
||||
if self.shape() != other.shape() {
|
||||
false
|
||||
} else {
|
||||
self.iter()
|
||||
.zip(other.iter())
|
||||
.all(|(a, b)| (*a - *b).abs() <= error)
|
||||
}
|
||||
}
|
||||
|
||||
fn sum(&self) -> T {
|
||||
let mut sum = T::zero();
|
||||
for v in self.iter() {
|
||||
sum += *v;
|
||||
}
|
||||
sum
|
||||
}
|
||||
|
||||
fn unique(&self) -> Vec<T> {
|
||||
let mut result: Vec<T> = self.iter().copied().collect();
|
||||
result.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
result.dedup();
|
||||
result
|
||||
}
|
||||
|
||||
fn copy_from(&mut self, other: &Self) {
|
||||
Matrix::copy_from(self, other);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
BaseMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
type RowVector = MatrixMN<T, U1, Dynamic>;
|
||||
type RowVector = RowDVector<T>;
|
||||
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self {
|
||||
Matrix::from_rows(&[vec])
|
||||
}
|
||||
|
||||
fn to_row_vector(self) -> Self::RowVector {
|
||||
self.row(0).into_owned()
|
||||
let (nrows, ncols) = self.shape();
|
||||
self.reshape_generic(Const::<1>, Dynamic::new(nrows * ncols))
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> T {
|
||||
@@ -99,26 +206,26 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
}
|
||||
|
||||
fn get_row_as_vec(&self, row: usize) -> Vec<T> {
|
||||
self.row(row).iter().map(|v| *v).collect()
|
||||
self.row(row).iter().copied().collect()
|
||||
}
|
||||
|
||||
fn get_row(&self, row: usize) -> Self::RowVector {
|
||||
self.row(row).into_owned()
|
||||
}
|
||||
|
||||
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>) {
|
||||
let mut r = 0;
|
||||
for e in self.row(row).iter() {
|
||||
for (r, e) in self.row(row).iter().enumerate() {
|
||||
result[r] = *e;
|
||||
r += 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T> {
|
||||
self.column(col).iter().map(|v| *v).collect()
|
||||
self.column(col).iter().copied().collect()
|
||||
}
|
||||
|
||||
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
|
||||
let mut r = 0;
|
||||
for e in self.column(col).iter() {
|
||||
result[r] = *e;
|
||||
r += 1;
|
||||
for (c, e) in self.column(col).iter().enumerate() {
|
||||
result[c] = *e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +371,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm = norm + xi.abs().powf(p);
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
@@ -373,7 +480,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
for c in 0..self.ncols() {
|
||||
let p = (self[(r, c)] - max).exp();
|
||||
self.set(r, c, p);
|
||||
z = z + p;
|
||||
z += p;
|
||||
}
|
||||
}
|
||||
for r in 0..self.nrows() {
|
||||
@@ -410,7 +517,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
}
|
||||
|
||||
fn unique(&self) -> Vec<T> {
|
||||
let mut result: Vec<T> = self.iter().map(|v| *v).collect();
|
||||
let mut result: Vec<T> = self.iter().copied().collect();
|
||||
result.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
result.dedup();
|
||||
result
|
||||
@@ -441,6 +548,26 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
CholeskyDecomposableMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
MatrixStats<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
MatrixPreprocessing<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
HighOrderOperations<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Sum + 'static>
|
||||
SmartCoreMatrix<T> for Matrix<T, Dynamic, Dynamic, VecStorage<T, Dynamic, Dynamic>>
|
||||
{
|
||||
@@ -452,12 +579,25 @@ mod tests {
|
||||
use crate::linear::linear_regression::*;
|
||||
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_copy_from() {
|
||||
let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||
let mut v2 = RowDVector::from_vec(vec![4., 5., 6.]);
|
||||
v1.copy_from(&v2);
|
||||
assert_eq!(v2, v1);
|
||||
v2[0] = 10.0;
|
||||
assert_ne!(v2, v1);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_len() {
|
||||
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||
assert_eq!(3, v.len());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_set_vector() {
|
||||
let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||
@@ -470,12 +610,14 @@ mod tests {
|
||||
assert_eq!(5., BaseVector::get(&v, 1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_to_vec() {
|
||||
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_init() {
|
||||
let zeros: RowDVector<f32> = BaseVector::zeros(3);
|
||||
@@ -486,6 +628,24 @@ mod tests {
|
||||
assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_dot() {
|
||||
let v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||
let v2 = RowDVector::from_vec(vec![4., 5., 6.]);
|
||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_approximate_eq() {
|
||||
let a = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||
let noise = RowDVector::from_vec(vec![1e-5, 2e-5, 3e-5]);
|
||||
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_set_dynamic() {
|
||||
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
@@ -498,6 +658,7 @@ mod tests {
|
||||
assert_eq!(10., BaseMatrix::get(&m, 1, 1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn zeros() {
|
||||
let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
|
||||
@@ -507,6 +668,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ones() {
|
||||
let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
|
||||
@@ -516,6 +678,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn eye() {
|
||||
let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]);
|
||||
@@ -523,6 +686,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn shape() {
|
||||
let m: DMatrix<f64> = BaseMatrix::zeros(5, 10);
|
||||
@@ -532,6 +696,7 @@ mod tests {
|
||||
assert_eq!(ncols, 10);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn scalar_add_sub_mul_div() {
|
||||
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||
@@ -545,6 +710,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn add_sub_mul_div() {
|
||||
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||
@@ -563,6 +729,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn to_from_row_vector() {
|
||||
let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||
@@ -571,6 +738,14 @@ mod tests {
|
||||
assert_eq!(m.to_row_vector(), expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_matrix_to_row_vector() {
|
||||
let m: DMatrix<f64> = BaseMatrix::zeros(10, 1);
|
||||
assert_eq!(m.to_row_vector().len(), 10)
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_row_col_as_vec() {
|
||||
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||
@@ -579,6 +754,14 @@ mod tests {
|
||||
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_row() {
|
||||
let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||
assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn copy_row_col_as_vec() {
|
||||
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||
@@ -590,6 +773,7 @@ mod tests {
|
||||
assert_eq!(v, vec!(2., 5., 8.));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn element_add_sub_mul_div() {
|
||||
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||
@@ -603,6 +787,7 @@ mod tests {
|
||||
assert_eq!(m, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vstack_hstack() {
|
||||
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||
@@ -618,6 +803,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn matmul() {
|
||||
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||
@@ -627,6 +813,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn dot() {
|
||||
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||
@@ -634,6 +821,7 @@ mod tests {
|
||||
assert_eq!(14., a.dot(&b));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn slice() {
|
||||
let a = DMatrix::from_row_slice(
|
||||
@@ -646,6 +834,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn approximate_eq() {
|
||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||
@@ -658,6 +847,7 @@ mod tests {
|
||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn negative_mut() {
|
||||
let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
||||
@@ -665,6 +855,7 @@ mod tests {
|
||||
assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.]));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn transpose() {
|
||||
let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
|
||||
@@ -673,6 +864,7 @@ mod tests {
|
||||
assert_eq!(m_transposed, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn rand() {
|
||||
let m: DMatrix<f64> = BaseMatrix::rand(3, 3);
|
||||
@@ -683,6 +875,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn norm() {
|
||||
let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
||||
@@ -692,6 +885,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_mean() {
|
||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||
@@ -699,6 +893,7 @@ mod tests {
|
||||
assert_eq!(res, vec![4., 5., 6.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn reshape() {
|
||||
let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]);
|
||||
@@ -710,6 +905,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn copy_from() {
|
||||
let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||
@@ -718,6 +914,7 @@ mod tests {
|
||||
assert_eq!(src, dst);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn abs_mut() {
|
||||
let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]);
|
||||
@@ -726,6 +923,7 @@ mod tests {
|
||||
assert_eq!(a, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn min_max_sum() {
|
||||
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||
@@ -734,6 +932,7 @@ mod tests {
|
||||
assert_eq!(6., a.max());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn max_diff() {
|
||||
let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]);
|
||||
@@ -742,6 +941,7 @@ mod tests {
|
||||
assert_eq!(a2.max_diff(&a2), 0.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn softmax_mut() {
|
||||
let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||
@@ -751,13 +951,15 @@ mod tests {
|
||||
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn pow_mut() {
|
||||
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||
a.pow_mut(3.);
|
||||
BaseMatrix::pow_mut(&mut a, 3.);
|
||||
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn argmax() {
|
||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]);
|
||||
@@ -765,6 +967,7 @@ mod tests {
|
||||
assert_eq!(res, vec![2, 0, 1]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn unique() {
|
||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
|
||||
@@ -773,6 +976,7 @@ mod tests {
|
||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
let x = DMatrix::from_row_slice(
|
||||
|
||||
+220
-22
@@ -36,7 +36,7 @@
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.
|
||||
//! ]);
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
use std::iter::Sum;
|
||||
@@ -47,17 +47,20 @@ use std::ops::Range;
|
||||
use std::ops::SubAssign;
|
||||
|
||||
use ndarray::ScalarOperand;
|
||||
use ndarray::{s, stack, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr};
|
||||
use ndarray::{concatenate, s, Array, ArrayBase, Axis, Ix1, Ix2, OwnedRepr};
|
||||
|
||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||
use crate::linalg::high_order::HighOrderOperations;
|
||||
use crate::linalg::lu::LUDecomposableMatrix;
|
||||
use crate::linalg::qr::QRDecomposableMatrix;
|
||||
use crate::linalg::stats::{MatrixPreprocessing, MatrixStats};
|
||||
use crate::linalg::svd::SVDDecomposableMatrix;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
impl<T: RealNumber> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
impl<T: RealNumber + ScalarOperand> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn get(&self, i: usize) -> T {
|
||||
self[i]
|
||||
}
|
||||
@@ -84,6 +87,99 @@ impl<T: RealNumber> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn fill(len: usize, value: T) -> Self {
|
||||
Array::from_elem(len, value)
|
||||
}
|
||||
|
||||
fn dot(&self, other: &Self) -> T {
|
||||
self.dot(other)
|
||||
}
|
||||
|
||||
fn norm2(&self) -> T {
|
||||
self.iter().map(|x| *x * *x).sum::<T>().sqrt()
|
||||
}
|
||||
|
||||
fn norm(&self, p: T) -> T {
|
||||
if p.is_infinite() && p.is_sign_positive() {
|
||||
self.iter().fold(T::neg_infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f > v {
|
||||
f
|
||||
} else {
|
||||
v
|
||||
}
|
||||
})
|
||||
} else if p.is_infinite() && p.is_sign_negative() {
|
||||
self.iter().fold(T::infinity(), |f, &val| {
|
||||
let v = val.abs();
|
||||
if f < v {
|
||||
f
|
||||
} else {
|
||||
v
|
||||
}
|
||||
})
|
||||
} else {
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
}
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] /= x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] *= x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] += x;
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, pos: usize, x: T) {
|
||||
self[pos] -= x;
|
||||
}
|
||||
|
||||
fn approximate_eq(&self, other: &Self, error: T) -> bool {
|
||||
(self - other).iter().all(|v| v.abs() <= error)
|
||||
}
|
||||
|
||||
fn add_mut(&mut self, other: &Self) -> &Self {
|
||||
*self += other;
|
||||
self
|
||||
}
|
||||
|
||||
fn sub_mut(&mut self, other: &Self) -> &Self {
|
||||
*self -= other;
|
||||
self
|
||||
}
|
||||
|
||||
fn mul_mut(&mut self, other: &Self) -> &Self {
|
||||
*self *= other;
|
||||
self
|
||||
}
|
||||
|
||||
fn div_mut(&mut self, other: &Self) -> &Self {
|
||||
*self /= other;
|
||||
self
|
||||
}
|
||||
|
||||
fn sum(&self) -> T {
|
||||
self.sum()
|
||||
}
|
||||
|
||||
fn unique(&self) -> Vec<T> {
|
||||
let mut result = self.clone().into_raw_vec();
|
||||
result.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
result.dedup();
|
||||
result
|
||||
}
|
||||
|
||||
fn copy_from(&mut self, other: &Self) {
|
||||
self.assign(other);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
@@ -109,11 +205,13 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
self.row(row).to_vec()
|
||||
}
|
||||
|
||||
fn get_row(&self, row: usize) -> Self::RowVector {
|
||||
self.row(row).to_owned()
|
||||
}
|
||||
|
||||
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>) {
|
||||
let mut r = 0;
|
||||
for e in self.row(row).iter() {
|
||||
for (r, e) in self.row(row).iter().enumerate() {
|
||||
result[r] = *e;
|
||||
r += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,10 +220,8 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
}
|
||||
|
||||
fn copy_col_as_vec(&self, col: usize, result: &mut Vec<T>) {
|
||||
let mut r = 0;
|
||||
for e in self.column(col).iter() {
|
||||
result[r] = *e;
|
||||
r += 1;
|
||||
for (c, e) in self.column(col).iter().enumerate() {
|
||||
result[c] = *e;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,11 +250,11 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
}
|
||||
|
||||
fn h_stack(&self, other: &Self) -> Self {
|
||||
stack(Axis(1), &[self.view(), other.view()]).unwrap()
|
||||
concatenate(Axis(1), &[self.view(), other.view()]).unwrap()
|
||||
}
|
||||
|
||||
fn v_stack(&self, other: &Self) -> Self {
|
||||
stack(Axis(0), &[self.view(), other.view()]).unwrap()
|
||||
concatenate(Axis(0), &[self.view(), other.view()]).unwrap()
|
||||
}
|
||||
|
||||
fn matmul(&self, other: &Self) -> Self {
|
||||
@@ -253,7 +349,7 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
let mut norm = T::zero();
|
||||
|
||||
for xi in self.iter() {
|
||||
norm = norm + xi.abs().powf(p);
|
||||
norm += xi.abs().powf(p);
|
||||
}
|
||||
|
||||
norm.powf(T::one() / p)
|
||||
@@ -265,19 +361,19 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
}
|
||||
|
||||
fn div_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self[[row, col]] = self[[row, col]] / x;
|
||||
self[[row, col]] /= x;
|
||||
}
|
||||
|
||||
fn mul_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self[[row, col]] = self[[row, col]] * x;
|
||||
self[[row, col]] *= x;
|
||||
}
|
||||
|
||||
fn add_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self[[row, col]] = self[[row, col]] + x;
|
||||
self[[row, col]] += x;
|
||||
}
|
||||
|
||||
fn sub_element_mut(&mut self, row: usize, col: usize, x: T) {
|
||||
self[[row, col]] = self[[row, col]] - x;
|
||||
self[[row, col]] -= x;
|
||||
}
|
||||
|
||||
fn negative_mut(&mut self) {
|
||||
@@ -289,7 +385,7 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
}
|
||||
|
||||
fn copy_from(&mut self, other: &Self) {
|
||||
self.assign(&other);
|
||||
self.assign(other);
|
||||
}
|
||||
|
||||
fn abs_mut(&mut self) -> &Self {
|
||||
@@ -331,7 +427,7 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
for c in 0..self.ncols() {
|
||||
let p = (self[(r, c)] - max).exp();
|
||||
self.set(r, c, p);
|
||||
z = z + p;
|
||||
z += p;
|
||||
}
|
||||
}
|
||||
for r in 0..self.nrows() {
|
||||
@@ -401,6 +497,26 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
CholeskyDecomposableMatrix<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
MatrixStats<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
MatrixPreprocessing<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum>
|
||||
HighOrderOperations<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssign + Sum> Matrix<T>
|
||||
for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
{
|
||||
@@ -414,6 +530,7 @@ mod tests {
|
||||
use crate::metrics::mean_absolute_error;
|
||||
use ndarray::{arr1, arr2, Array1, Array2};
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_get_set() {
|
||||
let mut result = arr1(&[1., 2., 3.]);
|
||||
@@ -425,18 +542,49 @@ mod tests {
|
||||
assert_eq!(5., BaseVector::get(&result, 1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_copy_from() {
|
||||
let mut v1 = arr1(&[1., 2., 3.]);
|
||||
let mut v2 = arr1(&[4., 5., 6.]);
|
||||
v1.copy_from(&v2);
|
||||
assert_eq!(v1, v2);
|
||||
v2[0] = 10.0;
|
||||
assert_ne!(v1, v2);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_len() {
|
||||
let v = arr1(&[1., 2., 3.]);
|
||||
assert_eq!(3, v.len());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_to_vec() {
|
||||
let v = arr1(&[1., 2., 3.]);
|
||||
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_dot() {
|
||||
let v1 = arr1(&[1., 2., 3.]);
|
||||
let v2 = arr1(&[4., 5., 6.]);
|
||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vec_approximate_eq() {
|
||||
let a = arr1(&[1., 2., 3.]);
|
||||
let noise = arr1(&[1e-5, 2e-5, 3e-5]);
|
||||
assert!(a.approximate_eq(&(&noise + &a), 1e-4));
|
||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn from_to_row_vec() {
|
||||
let vec = arr1(&[1., 2., 3.]);
|
||||
@@ -447,6 +595,14 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_matrix_to_row_vector() {
|
||||
let m: Array2<f64> = BaseMatrix::zeros(10, 1);
|
||||
assert_eq!(m.to_row_vector().len(), 10)
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn add_mut() {
|
||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -457,6 +613,7 @@ mod tests {
|
||||
assert_eq!(a1, a3);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn sub_mut() {
|
||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -467,6 +624,7 @@ mod tests {
|
||||
assert_eq!(a1, a3);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mul_mut() {
|
||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -477,6 +635,7 @@ mod tests {
|
||||
assert_eq!(a1, a3);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn div_mut() {
|
||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -487,6 +646,7 @@ mod tests {
|
||||
assert_eq!(a1, a3);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn div_element_mut() {
|
||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -495,6 +655,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mul_element_mut() {
|
||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -503,6 +664,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn add_element_mut() {
|
||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -510,7 +672,7 @@ mod tests {
|
||||
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn sub_element_mut() {
|
||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -519,6 +681,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn vstack_hstack() {
|
||||
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -533,6 +696,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_set() {
|
||||
let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -544,6 +708,7 @@ mod tests {
|
||||
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn matmul() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -553,6 +718,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn dot() {
|
||||
let a = arr2(&[[1., 2., 3.]]);
|
||||
@@ -560,6 +726,7 @@ mod tests {
|
||||
assert_eq!(14., BaseMatrix::dot(&a, &b));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn slice() {
|
||||
let a = arr2(&[
|
||||
@@ -572,6 +739,7 @@ mod tests {
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn scalar_ops() {
|
||||
let a = arr2(&[[1., 2., 3.]]);
|
||||
@@ -581,6 +749,7 @@ mod tests {
|
||||
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn transpose() {
|
||||
let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]);
|
||||
@@ -589,6 +758,7 @@ mod tests {
|
||||
assert_eq!(m_transposed, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn norm() {
|
||||
let v = arr2(&[[3., -2., 6.]]);
|
||||
@@ -598,6 +768,7 @@ mod tests {
|
||||
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn negative_mut() {
|
||||
let mut v = arr2(&[[3., -2., 6.]]);
|
||||
@@ -605,6 +776,7 @@ mod tests {
|
||||
assert_eq!(v, arr2(&[[-3., 2., -6.]]));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn reshape() {
|
||||
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
||||
@@ -616,6 +788,7 @@ mod tests {
|
||||
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn copy_from() {
|
||||
let mut src = arr2(&[[1., 2., 3.]]);
|
||||
@@ -624,6 +797,7 @@ mod tests {
|
||||
assert_eq!(src, dst);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn min_max_sum() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||
@@ -632,6 +806,7 @@ mod tests {
|
||||
assert_eq!(6., a.max());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn max_diff() {
|
||||
let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]);
|
||||
@@ -640,6 +815,7 @@ mod tests {
|
||||
assert_eq!(a2.max_diff(&a2), 0.);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn softmax_mut() {
|
||||
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
|
||||
@@ -649,6 +825,7 @@ mod tests {
|
||||
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn pow_mut() {
|
||||
let mut a = arr2(&[[1., 2., 3.]]);
|
||||
@@ -656,6 +833,7 @@ mod tests {
|
||||
assert_eq!(a, arr2(&[[1., 8., 27.]]));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn argmax() {
|
||||
let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]);
|
||||
@@ -663,6 +841,7 @@ mod tests {
|
||||
assert_eq!(res, vec![2, 0, 1]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn unique() {
|
||||
let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]);
|
||||
@@ -671,6 +850,7 @@ mod tests {
|
||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_row_as_vector() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
@@ -678,6 +858,14 @@ mod tests {
|
||||
assert_eq!(res, vec![4., 5., 6.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_row() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn get_col_as_vector() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
@@ -685,6 +873,7 @@ mod tests {
|
||||
assert_eq!(res, vec![2., 5., 8.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn copy_row_col_as_vec() {
|
||||
let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
@@ -696,6 +885,7 @@ mod tests {
|
||||
assert_eq!(v, vec!(2., 5., 8.));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn col_mean() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
@@ -703,6 +893,7 @@ mod tests {
|
||||
assert_eq!(res, vec![4., 5., 6.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn eye() {
|
||||
let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
|
||||
@@ -710,6 +901,7 @@ mod tests {
|
||||
assert_eq!(res, a);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn rand() {
|
||||
let m: Array2<f64> = BaseMatrix::rand(3, 3);
|
||||
@@ -720,6 +912,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn approximate_eq() {
|
||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||
@@ -728,6 +921,7 @@ mod tests {
|
||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn abs_mut() {
|
||||
let mut a = arr2(&[[1., -2.], [3., -4.]]);
|
||||
@@ -736,6 +930,7 @@ mod tests {
|
||||
assert_eq!(a, expected);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lr_fit_predict_iris() {
|
||||
let x = arr2(&[
|
||||
@@ -764,19 +959,20 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
]);
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
let error: f64 = y
|
||||
.into_iter()
|
||||
.zip(y_hat.into_iter())
|
||||
.map(|(&a, &b)| (a - b).abs())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.sum();
|
||||
|
||||
assert!(error <= 1.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn my_fit_longley_ndarray() {
|
||||
let x = arr2(&[
|
||||
@@ -811,6 +1007,8 @@ mod tests {
|
||||
min_samples_split: 2,
|
||||
n_trees: 1000,
|
||||
m: Option::None,
|
||||
keep_samples: false,
|
||||
seed: 0,
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
|
||||
+12
-15
@@ -44,18 +44,14 @@ pub struct QR<T: RealNumber, M: BaseMatrix<T>> {
|
||||
impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
pub(crate) fn new(QR: M, tau: Vec<T>) -> QR<T, M> {
|
||||
let mut singular = false;
|
||||
for j in 0..tau.len() {
|
||||
if tau[j] == T::zero() {
|
||||
for tau_elem in tau.iter() {
|
||||
if *tau_elem == T::zero() {
|
||||
singular = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
QR {
|
||||
QR: QR,
|
||||
tau: tau,
|
||||
singular: singular,
|
||||
}
|
||||
QR { QR, tau, singular }
|
||||
}
|
||||
|
||||
/// Get upper triangular matrix.
|
||||
@@ -68,7 +64,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
R.set(i, j, self.QR.get(i, j));
|
||||
}
|
||||
}
|
||||
return R;
|
||||
R
|
||||
}
|
||||
|
||||
/// Get an orthogonal matrix.
|
||||
@@ -82,7 +78,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
if self.QR.get(k, k) != T::zero() {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s = s + self.QR.get(i, k) * Q.get(i, j);
|
||||
s += self.QR.get(i, k) * Q.get(i, j);
|
||||
}
|
||||
s = -s / self.QR.get(k, k);
|
||||
for i in k..m {
|
||||
@@ -96,7 +92,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
k -= 1;
|
||||
}
|
||||
}
|
||||
return Q;
|
||||
Q
|
||||
}
|
||||
|
||||
fn solve(&self, mut b: M) -> Result<M, Failed> {
|
||||
@@ -118,7 +114,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> QR<T, M> {
|
||||
for j in 0..b_ncols {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s = s + self.QR.get(i, k) * b.get(i, j);
|
||||
s += self.QR.get(i, k) * b.get(i, j);
|
||||
}
|
||||
s = -s / self.QR.get(k, k);
|
||||
for i in k..m {
|
||||
@@ -157,7 +153,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
|
||||
let mut r_diagonal: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
for k in 0..n {
|
||||
for (k, r_diagonal_k) in r_diagonal.iter_mut().enumerate().take(n) {
|
||||
let mut nrm = T::zero();
|
||||
for i in k..m {
|
||||
nrm = nrm.hypot(self.get(i, k));
|
||||
@@ -175,7 +171,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for j in k + 1..n {
|
||||
let mut s = T::zero();
|
||||
for i in k..m {
|
||||
s = s + self.get(i, k) * self.get(i, j);
|
||||
s += self.get(i, k) * self.get(i, j);
|
||||
}
|
||||
s = -s / self.get(k, k);
|
||||
for i in k..m {
|
||||
@@ -183,7 +179,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
r_diagonal[k] = -nrm;
|
||||
*r_diagonal_k = -nrm;
|
||||
}
|
||||
|
||||
Ok(QR::new(self, r_diagonal))
|
||||
@@ -199,7 +195,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
@@ -218,6 +214,7 @@ mod tests {
|
||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn qr_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
//! # Various Statistical Methods
|
||||
//!
|
||||
//! This module provides reference implementations for various statistical functions.
|
||||
//! Concrete implementations of the `BaseMatrix` trait are free to override these methods for better performance.
|
||||
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Defines baseline implementations for various statistical functions
|
||||
pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Computes the arithmetic mean along the specified axis.
|
||||
fn mean(&self, axis: u8) -> Vec<T> {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
let div = T::from_usize(m).unwrap();
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
for j in 0..m {
|
||||
*x_i += match axis {
|
||||
0 => self.get(j, i),
|
||||
_ => self.get(i, j),
|
||||
};
|
||||
}
|
||||
*x_i /= div;
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes variance along the specified axis.
|
||||
fn var(&self, axis: u8) -> Vec<T> {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
let mut x: Vec<T> = vec![T::zero(); n];
|
||||
|
||||
let div = T::from_usize(m).unwrap();
|
||||
|
||||
for (i, x_i) in x.iter_mut().enumerate().take(n) {
|
||||
let mut mu = T::zero();
|
||||
let mut sum = T::zero();
|
||||
for j in 0..m {
|
||||
let a = match axis {
|
||||
0 => self.get(j, i),
|
||||
_ => self.get(i, j),
|
||||
};
|
||||
mu += a;
|
||||
sum += a * a;
|
||||
}
|
||||
mu /= div;
|
||||
*x_i = sum / div - mu.powi(2);
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// Computes the standard deviation along the specified axis.
|
||||
fn std(&self, axis: u8) -> Vec<T> {
|
||||
let mut x = self.var(axis);
|
||||
|
||||
let n = match axis {
|
||||
0 => self.shape().1,
|
||||
_ => self.shape().0,
|
||||
};
|
||||
|
||||
for x_i in x.iter_mut().take(n) {
|
||||
*x_i = x_i.sqrt();
|
||||
}
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
/// standardize values by removing the mean and scaling to unit variance
|
||||
fn scale_mut(&mut self, mean: &[T], std: &[T], axis: u8) {
|
||||
let (n, m) = match axis {
|
||||
0 => {
|
||||
let (n, m) = self.shape();
|
||||
(m, n)
|
||||
}
|
||||
_ => self.shape(),
|
||||
};
|
||||
|
||||
for i in 0..n {
|
||||
for j in 0..m {
|
||||
match axis {
|
||||
0 => self.set(j, i, (self.get(j, i) - mean[i]) / std[i]),
|
||||
_ => self.set(i, j, (self.get(i, j) - mean[i]) / std[i]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Defines baseline implementations for various matrix processing functions
|
||||
pub trait MatrixPreprocessing<T: RealNumber>: BaseMatrix<T> {
|
||||
/// Each element of the matrix greater than the threshold becomes 1, while values less than or equal to the threshold become 0
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let mut a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
/// a.binarize_mut(0.);
|
||||
///
|
||||
/// assert_eq!(a, expected);
|
||||
/// ```
|
||||
|
||||
fn binarize_mut(&mut self, threshold: T) {
|
||||
let (nrows, ncols) = self.shape();
|
||||
for row in 0..nrows {
|
||||
for col in 0..ncols {
|
||||
if self.get(row, col) > threshold {
|
||||
self.set(row, col, T::one());
|
||||
} else {
|
||||
self.set(row, col, T::zero());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns new matrix where elements are binarized according to a given threshold.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::*;
|
||||
/// use crate::smartcore::linalg::stats::MatrixPreprocessing;
|
||||
/// let a = DenseMatrix::from_array(2, 3, &[0., 2., 3., -5., -6., -7.]);
|
||||
/// let expected = DenseMatrix::from_array(2, 3, &[0., 1., 1., 0., 0., 0.]);
|
||||
///
|
||||
/// assert_eq!(a.binarize(0.), expected);
|
||||
/// ```
|
||||
fn binarize(&self, threshold: T) -> Self {
|
||||
let mut m = self.clone();
|
||||
m.binarize_mut(threshold);
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseVector;
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
let expected_0 = vec![4., 5., 6., 3., 4.];
|
||||
let expected_1 = vec![1.8, 4.4, 7.];
|
||||
|
||||
assert_eq!(m.mean(0), expected_0);
|
||||
assert_eq!(m.mean(1), expected_1);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn std() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
let expected_0 = vec![2.44, 2.44, 2.44, 1.63, 1.63];
|
||||
let expected_1 = vec![0.74, 1.01, 1.41];
|
||||
|
||||
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
|
||||
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn var() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
||||
let expected_0 = vec![4., 4., 4., 4.];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn scale() {
|
||||
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let expected_0 = DenseMatrix::from_2d_array(&[&[-1., -1., -1.], &[1., 1., 1.]]);
|
||||
let expected_1 = DenseMatrix::from_2d_array(&[&[-1.22, 0.0, 1.22], &[-1.22, 0.0, 1.22]]);
|
||||
|
||||
{
|
||||
let mut m = m.clone();
|
||||
m.scale_mut(&m.mean(0), &m.std(0), 0);
|
||||
assert!(m.approximate_eq(&expected_0, std::f32::EPSILON));
|
||||
}
|
||||
|
||||
m.scale_mut(&m.mean(1), &m.std(1), 1);
|
||||
assert!(m.approximate_eq(&expected_1, 1e-2));
|
||||
}
|
||||
}
|
||||
+44
-43
@@ -47,7 +47,7 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
||||
pub V: M,
|
||||
/// Singular values of the original matrix
|
||||
pub s: Vec<T>,
|
||||
full: bool,
|
||||
_full: bool,
|
||||
m: usize,
|
||||
n: usize,
|
||||
tol: T,
|
||||
@@ -106,23 +106,23 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
|
||||
if i < m {
|
||||
for k in i..m {
|
||||
scale = scale + U.get(k, i).abs();
|
||||
scale += U.get(k, i).abs();
|
||||
}
|
||||
|
||||
if scale.abs() > T::epsilon() {
|
||||
for k in i..m {
|
||||
U.div_element_mut(k, i, scale);
|
||||
s = s + U.get(k, i) * U.get(k, i);
|
||||
s += U.get(k, i) * U.get(k, i);
|
||||
}
|
||||
|
||||
let mut f = U.get(i, i);
|
||||
g = -s.sqrt().copysign(f);
|
||||
g = -RealNumber::copysign(s.sqrt(), f);
|
||||
let h = f * g - s;
|
||||
U.set(i, i, f - g);
|
||||
for j in l - 1..n {
|
||||
s = T::zero();
|
||||
for k in i..m {
|
||||
s = s + U.get(k, i) * U.get(k, j);
|
||||
s += U.get(k, i) * U.get(k, j);
|
||||
}
|
||||
f = s / h;
|
||||
for k in i..m {
|
||||
@@ -140,34 +140,34 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
let mut s = T::zero();
|
||||
scale = T::zero();
|
||||
|
||||
if i + 1 <= m && i + 1 != n {
|
||||
if i < m && i + 1 != n {
|
||||
for k in l - 1..n {
|
||||
scale = scale + U.get(i, k).abs();
|
||||
scale += U.get(i, k).abs();
|
||||
}
|
||||
|
||||
if scale.abs() > T::epsilon() {
|
||||
for k in l - 1..n {
|
||||
U.div_element_mut(i, k, scale);
|
||||
s = s + U.get(i, k) * U.get(i, k);
|
||||
s += U.get(i, k) * U.get(i, k);
|
||||
}
|
||||
|
||||
let f = U.get(i, l - 1);
|
||||
g = -s.sqrt().copysign(f);
|
||||
g = -RealNumber::copysign(s.sqrt(), f);
|
||||
let h = f * g - s;
|
||||
U.set(i, l - 1, f - g);
|
||||
|
||||
for k in l - 1..n {
|
||||
rv1[k] = U.get(i, k) / h;
|
||||
for (k, rv1_k) in rv1.iter_mut().enumerate().take(n).skip(l - 1) {
|
||||
*rv1_k = U.get(i, k) / h;
|
||||
}
|
||||
|
||||
for j in l - 1..m {
|
||||
s = T::zero();
|
||||
for k in l - 1..n {
|
||||
s = s + U.get(j, k) * U.get(i, k);
|
||||
s += U.get(j, k) * U.get(i, k);
|
||||
}
|
||||
|
||||
for k in l - 1..n {
|
||||
U.add_element_mut(j, k, s * rv1[k]);
|
||||
for (k, rv1_k) in rv1.iter().enumerate().take(n).skip(l - 1) {
|
||||
U.add_element_mut(j, k, s * (*rv1_k));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,7 +189,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for j in l..n {
|
||||
let mut s = T::zero();
|
||||
for k in l..n {
|
||||
s = s + U.get(i, k) * v.get(k, j);
|
||||
s += U.get(i, k) * v.get(k, j);
|
||||
}
|
||||
for k in l..n {
|
||||
v.add_element_mut(k, j, s * v.get(k, i));
|
||||
@@ -218,7 +218,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
for j in l..n {
|
||||
let mut s = T::zero();
|
||||
for k in l..m {
|
||||
s = s + U.get(k, i) * U.get(k, j);
|
||||
s += U.get(k, i) * U.get(k, j);
|
||||
}
|
||||
let f = (s / U.get(i, i)) * g;
|
||||
for k in i..m {
|
||||
@@ -299,7 +299,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
let mut h = rv1[k];
|
||||
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
|
||||
g = f.hypot(T::one());
|
||||
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
|
||||
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
|
||||
let mut c = T::one();
|
||||
let mut s = T::one();
|
||||
|
||||
@@ -316,7 +316,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
f = x * c + g * s;
|
||||
g = g * c - x * s;
|
||||
h = y * s;
|
||||
y = y * c;
|
||||
y *= c;
|
||||
|
||||
for jj in 0..n {
|
||||
x = v.get(jj, j);
|
||||
@@ -365,11 +365,11 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
inc /= 3;
|
||||
for i in inc..n {
|
||||
let sw = w[i];
|
||||
for k in 0..m {
|
||||
su[k] = U.get(k, i);
|
||||
for (k, su_k) in su.iter_mut().enumerate().take(m) {
|
||||
*su_k = U.get(k, i);
|
||||
}
|
||||
for k in 0..n {
|
||||
sv[k] = v.get(k, i);
|
||||
for (k, sv_k) in sv.iter_mut().enumerate().take(n) {
|
||||
*sv_k = v.get(k, i);
|
||||
}
|
||||
let mut j = i;
|
||||
while w[j - inc] < sw {
|
||||
@@ -386,11 +386,11 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
||||
}
|
||||
}
|
||||
w[j] = sw;
|
||||
for k in 0..m {
|
||||
U.set(k, j, su[k]);
|
||||
for (k, su_k) in su.iter().enumerate().take(m) {
|
||||
U.set(k, j, *su_k);
|
||||
}
|
||||
for k in 0..n {
|
||||
v.set(k, j, sv[k]);
|
||||
for (k, sv_k) in sv.iter().enumerate().take(n) {
|
||||
v.set(k, j, *sv_k);
|
||||
}
|
||||
}
|
||||
if inc <= 1 {
|
||||
@@ -428,16 +428,16 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
|
||||
let m = U.shape().0;
|
||||
let n = V.shape().0;
|
||||
let full = s.len() == m.min(n);
|
||||
let _full = s.len() == m.min(n);
|
||||
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
|
||||
SVD {
|
||||
U: U,
|
||||
V: V,
|
||||
s: s,
|
||||
full: full,
|
||||
m: m,
|
||||
n: n,
|
||||
tol: tol,
|
||||
U,
|
||||
V,
|
||||
s,
|
||||
_full,
|
||||
m,
|
||||
n,
|
||||
tol,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -454,21 +454,21 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
|
||||
for k in 0..p {
|
||||
let mut tmp = vec![T::zero(); self.n];
|
||||
for j in 0..self.n {
|
||||
for (j, tmp_j) in tmp.iter_mut().enumerate().take(self.n) {
|
||||
let mut r = T::zero();
|
||||
if self.s[j] > self.tol {
|
||||
for i in 0..self.m {
|
||||
r = r + self.U.get(i, j) * b.get(i, k);
|
||||
r += self.U.get(i, j) * b.get(i, k);
|
||||
}
|
||||
r = r / self.s[j];
|
||||
r /= self.s[j];
|
||||
}
|
||||
tmp[j] = r;
|
||||
*tmp_j = r;
|
||||
}
|
||||
|
||||
for j in 0..self.n {
|
||||
let mut r = T::zero();
|
||||
for jj in 0..self.n {
|
||||
r = r + self.V.get(j, jj) * tmp[jj];
|
||||
for (jj, tmp_jj) in tmp.iter().enumerate().take(self.n) {
|
||||
r += self.V.get(j, jj) * (*tmp_jj);
|
||||
}
|
||||
b.set(j, k, r);
|
||||
}
|
||||
@@ -482,7 +482,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_symmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -513,7 +513,7 @@ mod tests {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_asymmetric() {
|
||||
let A = DenseMatrix::from_2d_array(&[
|
||||
@@ -714,7 +714,7 @@ mod tests {
|
||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn solve() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
@@ -725,6 +725,7 @@ mod tests {
|
||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn decompose_restore() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
//! This is a generic solver for Ax = b type of equation
|
||||
//!
|
||||
//! for more information take a look at [this Wikipedia article](https://en.wikipedia.org/wiki/Biconjugate_gradient_method)
|
||||
//! and [this paper](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf)
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
pub trait BiconjugateGradientSolver<T: RealNumber, M: Matrix<T>> {
|
||||
fn solve_mut(&self, a: &M, b: &M, x: &mut M, tol: T, max_iter: usize) -> Result<T, Failed> {
|
||||
if tol <= T::zero() {
|
||||
return Err(Failed::fit("tolerance shoud be > 0"));
|
||||
}
|
||||
|
||||
if max_iter == 0 {
|
||||
return Err(Failed::fit("maximum number of iterations should be > 0"));
|
||||
}
|
||||
|
||||
let (n, _) = b.shape();
|
||||
|
||||
let mut r = M::zeros(n, 1);
|
||||
let mut rr = M::zeros(n, 1);
|
||||
let mut z = M::zeros(n, 1);
|
||||
let mut zz = M::zeros(n, 1);
|
||||
|
||||
self.mat_vec_mul(a, x, &mut r);
|
||||
|
||||
for j in 0..n {
|
||||
r.set(j, 0, b.get(j, 0) - r.get(j, 0));
|
||||
rr.set(j, 0, r.get(j, 0));
|
||||
}
|
||||
|
||||
let bnrm = b.norm(T::two());
|
||||
self.solve_preconditioner(a, &r, &mut z);
|
||||
|
||||
let mut p = M::zeros(n, 1);
|
||||
let mut pp = M::zeros(n, 1);
|
||||
let mut bkden = T::zero();
|
||||
let mut err = T::zero();
|
||||
|
||||
for iter in 1..max_iter {
|
||||
let mut bknum = T::zero();
|
||||
|
||||
self.solve_preconditioner(a, &rr, &mut zz);
|
||||
for j in 0..n {
|
||||
bknum += z.get(j, 0) * rr.get(j, 0);
|
||||
}
|
||||
if iter == 1 {
|
||||
for j in 0..n {
|
||||
p.set(j, 0, z.get(j, 0));
|
||||
pp.set(j, 0, zz.get(j, 0));
|
||||
}
|
||||
} else {
|
||||
let bk = bknum / bkden;
|
||||
for j in 0..n {
|
||||
p.set(j, 0, bk * p.get(j, 0) + z.get(j, 0));
|
||||
pp.set(j, 0, bk * pp.get(j, 0) + zz.get(j, 0));
|
||||
}
|
||||
}
|
||||
bkden = bknum;
|
||||
self.mat_vec_mul(a, &p, &mut z);
|
||||
let mut akden = T::zero();
|
||||
for j in 0..n {
|
||||
akden += z.get(j, 0) * pp.get(j, 0);
|
||||
}
|
||||
let ak = bknum / akden;
|
||||
self.mat_t_vec_mul(a, &pp, &mut zz);
|
||||
for j in 0..n {
|
||||
x.set(j, 0, x.get(j, 0) + ak * p.get(j, 0));
|
||||
r.set(j, 0, r.get(j, 0) - ak * z.get(j, 0));
|
||||
rr.set(j, 0, rr.get(j, 0) - ak * zz.get(j, 0));
|
||||
}
|
||||
self.solve_preconditioner(a, &r, &mut z);
|
||||
err = r.norm(T::two()) / bnrm;
|
||||
|
||||
if err <= tol {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||
let diag = Self::diag(a);
|
||||
let n = diag.len();
|
||||
|
||||
for (i, diag_i) in diag.iter().enumerate().take(n) {
|
||||
if *diag_i != T::zero() {
|
||||
x.set(i, 0, b.get(i, 0) / *diag_i);
|
||||
} else {
|
||||
x.set(i, 0, b.get(i, 0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// y = Ax
|
||||
fn mat_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
y.copy_from(&a.matmul(x));
|
||||
}
|
||||
|
||||
// y = Atx
|
||||
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
y.copy_from(&a.ab(true, x, false));
|
||||
}
|
||||
|
||||
fn diag(a: &M) -> Vec<T> {
|
||||
let (nrows, ncols) = a.shape();
|
||||
let n = nrows.min(ncols);
|
||||
|
||||
let mut d = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
d.push(a.get(i, i));
|
||||
}
|
||||
|
||||
d
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
pub struct BGSolver {}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn bg_solver() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
|
||||
|
||||
let mut x = DenseMatrix::zeros(3, 1);
|
||||
|
||||
let solver = BGSolver {};
|
||||
|
||||
let err: f64 = solver
|
||||
.solve_mut(&a, &b.transpose(), &mut x, 1e-6, 6)
|
||||
.unwrap();
|
||||
|
||||
assert!(x.transpose().approximate_eq(&expected, 1e-4));
|
||||
assert!((err - 0.0).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,440 @@
|
||||
#![allow(clippy::needless_range_loop)]
|
||||
//! # Elastic Net
|
||||
//!
|
||||
//! Elastic net is an extension of [linear regression](../linear_regression/index.html) that adds regularization penalties to the loss function during training.
|
||||
//! Just like in ordinary linear regression you assume a linear relationship between input variables and the target variable.
|
||||
//! Unlike linear regression elastic net adds regularization penalties to the loss function during training.
|
||||
//! In particular, the elastic net coefficient estimates \\(\beta\\) are the values that minimize
|
||||
//!
|
||||
//! \\[L(\alpha, \beta) = \vert \boldsymbol{y} - \boldsymbol{X}\beta\vert^2 + \lambda_1 \vert \beta \vert^2 + \lambda_2 \vert \beta \vert_1\\]
|
||||
//!
|
||||
//! where \\(\lambda_1 = \\alpha l_{1r}\\), \\(\lambda_2 = \\alpha (1 - l_{1r})\\) and \\(l_{1r}\\) is the l1 ratio, elastic net mixing parameter.
|
||||
//!
|
||||
//! In essense, elastic net combines both the [L1](../lasso/index.html) and [L2](../ridge_regression/index.html) penalties during training,
|
||||
//! which can result in better performance than a model with either one or the other penalty on some problems.
|
||||
//! The elastic net is particularly useful when the number of predictors (p) is much bigger than the number of observations (n).
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linear::elastic_net::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
//! &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
//!
|
||||
//! let y_hat = ElasticNet::fit(&x, &y, Default::default()).
|
||||
//! and_then(|lr| lr.predict(&x)).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 6.2. Shrinkage Methods](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["Regularization and variable selection via the elastic net", Hui Zou and Trevor Hastie](https://web.stanford.edu/~hastie/Papers/B67.2%20(2005)%20301-320%20Zou%20&%20Hastie.pdf)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
|
||||
/// Elastic net parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetParameters<T: RealNumber> {
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: T,
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
/// Elastic net
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> ElasticNetParameters<T> {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub fn with_l1_ratio(mut self, l1_ratio: T) -> Self {
|
||||
self.l1_ratio = l1_ratio;
|
||||
self
|
||||
}
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub fn with_normalize(mut self, normalize: bool) -> Self {
|
||||
self.normalize = normalize;
|
||||
self
|
||||
}
|
||||
/// The tolerance for the optimization
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
/// The maximum number of iterations
|
||||
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
|
||||
self.max_iter = max_iter;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for ElasticNetParameters<T> {
|
||||
fn default() -> Self {
|
||||
ElasticNetParameters {
|
||||
alpha: T::one(),
|
||||
l1_ratio: T::half(),
|
||||
normalize: true,
|
||||
tol: T::from_f64(1e-4).unwrap(),
|
||||
max_iter: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, ElasticNetParameters<T>>
|
||||
for ElasticNet<T, M>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: ElasticNetParameters<T>) -> Result<Self, Failed> {
|
||||
ElasticNet::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for ElasticNet<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> ElasticNet<T, M> {
|
||||
/// Fits elastic net regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: ElasticNetParameters<T>,
|
||||
) -> Result<ElasticNet<T, M>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if y.len() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let n_float = T::from_usize(n).unwrap();
|
||||
|
||||
let l1_reg = parameters.alpha * parameters.l1_ratio * n_float;
|
||||
let l2_reg = parameters.alpha * (T::one() - parameters.l1_ratio) * n_float;
|
||||
|
||||
let y_mean = y.mean();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
|
||||
let (x, y, gamma) = Self::augment_x_and_y(&scaled_x, y, l2_reg);
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
for i in 0..p {
|
||||
w.set(i, 0, gamma * w.get(i, 0) / col_std[i]);
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
|
||||
for i in 0..p {
|
||||
b += w.get(i, 0) * col_mean[i];
|
||||
}
|
||||
|
||||
b = y_mean - b;
|
||||
|
||||
(w, b)
|
||||
} else {
|
||||
let (x, y, gamma) = Self::augment_x_and_y(x, y, l2_reg);
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&x, &y, l1_reg * gamma, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
for i in 0..p {
|
||||
w.set(i, 0, gamma * w.get(i, 0));
|
||||
}
|
||||
|
||||
(w, y_mean)
|
||||
};
|
||||
|
||||
Ok(ElasticNet {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
|
||||
for i in 0..col_std.len() {
|
||||
if (col_std[i] - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut scaled_x = x.clone();
|
||||
scaled_x.scale_mut(&col_mean, &col_std, 0);
|
||||
Ok((scaled_x, col_mean, col_std))
|
||||
}
|
||||
|
||||
fn augment_x_and_y(x: &M, y: &M::RowVector, l2_reg: T) -> (M, M::RowVector, T) {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
let gamma = T::one() / (T::one() + l2_reg).sqrt();
|
||||
let padding = gamma * l2_reg.sqrt();
|
||||
|
||||
let mut y2 = M::RowVector::zeros(n + p);
|
||||
for i in 0..y.len() {
|
||||
y2.set(i, y.get(i));
|
||||
}
|
||||
|
||||
let mut x2 = M::zeros(n + p, p);
|
||||
|
||||
for j in 0..p {
|
||||
for i in 0..n {
|
||||
x2.set(i, j, gamma * x.get(i, j));
|
||||
}
|
||||
|
||||
x2.set(j + n, j, padding);
|
||||
}
|
||||
|
||||
(x2, y2, gamma)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn elasticnet_longley() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat = ElasticNet::fit(
|
||||
&x,
|
||||
&y,
|
||||
ElasticNetParameters {
|
||||
alpha: 1.0,
|
||||
l1_ratio: 0.5,
|
||||
normalize: false,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn elasticnet_fit_predict1() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 1931.0, 1.2232755825400514],
|
||||
&[1.0, 1933.0, 1.1379726120972395],
|
||||
&[2.0, 1920.0, 1.4366265120543429],
|
||||
&[3.0, 1918.0, 1.206005737827858],
|
||||
&[4.0, 1934.0, 1.436613542400669],
|
||||
&[5.0, 1918.0, 1.1594588621640636],
|
||||
&[6.0, 1933.0, 1.19809994745985],
|
||||
&[7.0, 1918.0, 1.3396363871645678],
|
||||
&[8.0, 1931.0, 1.2535342096493207],
|
||||
&[9.0, 1933.0, 1.3101281563456293],
|
||||
&[10.0, 1922.0, 1.3585833349920762],
|
||||
&[11.0, 1930.0, 1.4830786699709897],
|
||||
&[12.0, 1916.0, 1.4919891143094546],
|
||||
&[13.0, 1915.0, 1.259655137451551],
|
||||
&[14.0, 1932.0, 1.3979191428724789],
|
||||
&[15.0, 1917.0, 1.3686634746782371],
|
||||
&[16.0, 1932.0, 1.381658454569724],
|
||||
&[17.0, 1918.0, 1.4054969025700674],
|
||||
&[18.0, 1929.0, 1.3271699396384906],
|
||||
&[19.0, 1915.0, 1.1373332337674806],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
|
||||
10.2, 7.92, 7.62, 8.06, 9.06, 9.29,
|
||||
];
|
||||
|
||||
let l1_model = ElasticNet::fit(
|
||||
&x,
|
||||
&y,
|
||||
ElasticNetParameters {
|
||||
alpha: 1.0,
|
||||
l1_ratio: 1.0,
|
||||
normalize: true,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let l2_model = ElasticNet::fit(
|
||||
&x,
|
||||
&y,
|
||||
ElasticNetParameters {
|
||||
alpha: 1.0,
|
||||
l1_ratio: 0.0,
|
||||
normalize: true,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mae_l1 = mean_absolute_error(&l1_model.predict(&x).unwrap(), &y);
|
||||
let mae_l2 = mean_absolute_error(&l2_model.predict(&x).unwrap(), &y);
|
||||
|
||||
assert!(mae_l1 < 2.0);
|
||||
assert!(mae_l2 < 2.0);
|
||||
|
||||
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(1, 0));
|
||||
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let lr = ElasticNet::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: ElasticNet<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
//! # Lasso
|
||||
//!
|
||||
//! [Linear regression](../linear_regression/index.html) is the standard algorithm for predicting a quantitative response \\(y\\) on the basis of a linear combination of explanatory variables \\(X\\)
|
||||
//! that assumes that there is approximately a linear relationship between \\(X\\) and \\(y\\).
|
||||
//! Lasso is an extension to linear regression that adds L1 regularization term to the loss function during training.
|
||||
//!
|
||||
//! Similar to [ridge regression](../ridge_regression/index.html), the lasso shrinks the coefficient estimates towards zero when. However, in the case of the lasso, the l1 penalty has the effect of
|
||||
//! forcing some of the coefficient estimates to be exactly equal to zero when the tuning parameter \\(\alpha\\) is sufficiently large.
|
||||
//!
|
||||
//! Lasso coefficient estimates solve the problem:
|
||||
//!
|
||||
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||
//!
|
||||
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
||||
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 6.2. Shrinkage Methods](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["An Interior-Point Method for Large-Scale l1-Regularized Least Squares", K. Koh, M. Lustig, S. Boyd, D. Gorinevsky](https://web.stanford.edu/~boyd/papers/pdf/l1_ls.pdf)
|
||||
//! * [Simple Matlab Solver for l1-regularized Least Squares Problems](https://web.stanford.edu/~boyd/l1_ls/)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Lasso regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoParameters<T: RealNumber> {
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Lasso regressor
|
||||
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> LassoParameters<T> {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub fn with_normalize(mut self, normalize: bool) -> Self {
|
||||
self.normalize = normalize;
|
||||
self
|
||||
}
|
||||
/// The tolerance for the optimization
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
/// The maximum number of iterations
|
||||
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
|
||||
self.max_iter = max_iter;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LassoParameters<T> {
|
||||
fn default() -> Self {
|
||||
LassoParameters {
|
||||
alpha: T::one(),
|
||||
normalize: true,
|
||||
tol: T::from_f64(1e-4).unwrap(),
|
||||
max_iter: 1000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for Lasso<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LassoParameters<T>>
|
||||
for Lasso<T, M>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: LassoParameters<T>) -> Result<Self, Failed> {
|
||||
Lasso::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
/// Fits Lasso regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LassoParameters<T>,
|
||||
) -> Result<Lasso<T, M>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if n <= p {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows in X should be >= number of columns in X",
|
||||
));
|
||||
}
|
||||
|
||||
if parameters.alpha < T::zero() {
|
||||
return Err(Failed::fit("alpha should be >= 0"));
|
||||
}
|
||||
|
||||
if parameters.tol <= T::zero() {
|
||||
return Err(Failed::fit("tol should be > 0"));
|
||||
}
|
||||
|
||||
if parameters.max_iter == 0 {
|
||||
return Err(Failed::fit("max_iter should be > 0"));
|
||||
}
|
||||
|
||||
if y.len() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let l1_reg = parameters.alpha * T::from_usize(n).unwrap();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
|
||||
let mut optimizer = InteriorPointOptimizer::new(&scaled_x, p);
|
||||
|
||||
let mut w =
|
||||
optimizer.optimize(&scaled_x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||
w.set(j, 0, w.get(j, 0) / *col_std_j);
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
b += w.get(i, 0) * *col_mean_i;
|
||||
}
|
||||
|
||||
b = y.mean() - b;
|
||||
(w, b)
|
||||
} else {
|
||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||
|
||||
let w = optimizer.optimize(x, y, l1_reg, parameters.max_iter, parameters.tol)?;
|
||||
|
||||
(w, y.mean())
|
||||
};
|
||||
|
||||
Ok(Lasso {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate() {
|
||||
if (*col_std_i - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut scaled_x = x.clone();
|
||||
scaled_x.scale_mut(&col_mean, &col_std, 0);
|
||||
Ok((scaled_x, col_mean, col_std))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
|
||||
let y_hat = Lasso::fit(
|
||||
&x,
|
||||
&y,
|
||||
LassoParameters {
|
||||
alpha: 0.1,
|
||||
normalize: false,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: Lasso<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
//! An Interior-Point Method for Large-Scale l1-Regularized Least Squares
|
||||
//!
|
||||
//! This is a specialized interior-point method for solving large-scale 1-regularized LSPs that uses the
|
||||
//! preconditioned conjugate gradients algorithm to compute the search direction.
|
||||
//!
|
||||
//! The interior-point method can solve large sparse problems, with a million variables and observations, in a few tens of minutes on a PC.
|
||||
//! It can efficiently solve large dense problems, that arise in sparse signal recovery with orthogonal transforms, by exploiting fast algorithms for these transforms.
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["An Interior-Point Method for Large-Scale l1-Regularized Least Squares", K. Koh, M. Lustig, S. Boyd, D. Gorinevsky](https://web.stanford.edu/~boyd/papers/pdf/l1_ls.pdf)
|
||||
//! * [Simple Matlab Solver for l1-regularized Least Squares Problems](https://web.stanford.edu/~boyd/l1_ls/)
|
||||
//!
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::linear::bg_solver::BiconjugateGradientSolver;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
pub struct InteriorPointOptimizer<T: RealNumber, M: Matrix<T>> {
|
||||
ata: M,
|
||||
d1: Vec<T>,
|
||||
d2: Vec<T>,
|
||||
prb: Vec<T>,
|
||||
prs: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
pub fn new(a: &M, n: usize) -> InteriorPointOptimizer<T, M> {
|
||||
InteriorPointOptimizer {
|
||||
ata: a.ab(true, a, false),
|
||||
d1: vec![T::zero(); n],
|
||||
d2: vec![T::zero(); n],
|
||||
prb: vec![T::zero(); n],
|
||||
prs: vec![T::zero(); n],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(
|
||||
&mut self,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
lambda: T,
|
||||
max_iter: usize,
|
||||
tol: T,
|
||||
) -> Result<M, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
let p_f64 = T::from_usize(p).unwrap();
|
||||
|
||||
let lambda = lambda.max(T::epsilon());
|
||||
|
||||
//parameters
|
||||
let pcgmaxi = 5000;
|
||||
let min_pcgtol = T::from_f64(0.1).unwrap();
|
||||
let eta = T::from_f64(1E-3).unwrap();
|
||||
let alpha = T::from_f64(0.01).unwrap();
|
||||
let beta = T::from_f64(0.5).unwrap();
|
||||
let gamma = T::from_f64(-0.25).unwrap();
|
||||
let mu = T::two();
|
||||
|
||||
let y = M::from_row_vector(y.sub_scalar(y.mean())).transpose();
|
||||
|
||||
let mut max_ls_iter = 100;
|
||||
let mut pitr = 0;
|
||||
let mut w = M::zeros(p, 1);
|
||||
let mut neww = w.clone();
|
||||
let mut u = M::ones(p, 1);
|
||||
let mut newu = u.clone();
|
||||
|
||||
let mut f = M::fill(p, 2, -T::one());
|
||||
let mut newf = f.clone();
|
||||
|
||||
let mut q1 = vec![T::zero(); p];
|
||||
let mut q2 = vec![T::zero(); p];
|
||||
|
||||
let mut dx = M::zeros(p, 1);
|
||||
let mut du = M::zeros(p, 1);
|
||||
let mut dxu = M::zeros(2 * p, 1);
|
||||
let mut grad = M::zeros(2 * p, 1);
|
||||
|
||||
let mut nu = M::zeros(n, 1);
|
||||
let mut dobj = T::zero();
|
||||
let mut s = T::infinity();
|
||||
let mut t = T::one()
|
||||
.max(T::one() / lambda)
|
||||
.min(T::two() * p_f64 / T::from(1e-3).unwrap());
|
||||
|
||||
for ntiter in 0..max_iter {
|
||||
let mut z = x.matmul(&w);
|
||||
|
||||
for i in 0..n {
|
||||
z.set(i, 0, z.get(i, 0) - y.get(i, 0));
|
||||
nu.set(i, 0, T::two() * z.get(i, 0));
|
||||
}
|
||||
|
||||
// CALCULATE DUALITY GAP
|
||||
let xnu = x.ab(true, &nu, false);
|
||||
let max_xnu = xnu.norm(T::infinity());
|
||||
if max_xnu > lambda {
|
||||
let lnu = lambda / max_xnu;
|
||||
nu.mul_scalar_mut(lnu);
|
||||
}
|
||||
|
||||
let pobj = z.dot(&z) + lambda * w.norm(T::one());
|
||||
dobj = dobj.max(gamma * nu.dot(&nu) - nu.dot(&y));
|
||||
|
||||
let gap = pobj - dobj;
|
||||
|
||||
// STOPPING CRITERION
|
||||
if gap / dobj < tol {
|
||||
break;
|
||||
}
|
||||
|
||||
// UPDATE t
|
||||
if s >= T::half() {
|
||||
t = t.max((T::two() * p_f64 * mu / gap).min(mu * t));
|
||||
}
|
||||
|
||||
// CALCULATE NEWTON STEP
|
||||
for i in 0..p {
|
||||
let q1i = T::one() / (u.get(i, 0) + w.get(i, 0));
|
||||
let q2i = T::one() / (u.get(i, 0) - w.get(i, 0));
|
||||
q1[i] = q1i;
|
||||
q2[i] = q2i;
|
||||
self.d1[i] = (q1i * q1i + q2i * q2i) / t;
|
||||
self.d2[i] = (q1i * q1i - q2i * q2i) / t;
|
||||
}
|
||||
|
||||
let mut gradphi = x.ab(true, &z, false);
|
||||
|
||||
for i in 0..p {
|
||||
let g1 = T::two() * gradphi.get(i, 0) - (q1[i] - q2[i]) / t;
|
||||
let g2 = lambda - (q1[i] + q2[i]) / t;
|
||||
gradphi.set(i, 0, g1);
|
||||
grad.set(i, 0, -g1);
|
||||
grad.set(i + p, 0, -g2);
|
||||
}
|
||||
|
||||
for i in 0..p {
|
||||
self.prb[i] = T::two() + self.d1[i];
|
||||
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
|
||||
}
|
||||
|
||||
let normg = grad.norm2();
|
||||
let mut pcgtol = min_pcgtol.min(eta * gap / T::one().min(normg));
|
||||
if ntiter != 0 && pitr == 0 {
|
||||
pcgtol *= min_pcgtol;
|
||||
}
|
||||
|
||||
let error = self.solve_mut(x, &grad, &mut dxu, pcgtol, pcgmaxi)?;
|
||||
if error > pcgtol {
|
||||
pitr = pcgmaxi;
|
||||
}
|
||||
|
||||
for i in 0..p {
|
||||
dx.set(i, 0, dxu.get(i, 0));
|
||||
du.set(i, 0, dxu.get(i + p, 0));
|
||||
}
|
||||
|
||||
// BACKTRACKING LINE SEARCH
|
||||
let phi = z.dot(&z) + lambda * u.sum() - Self::sumlogneg(&f) / t;
|
||||
s = T::one();
|
||||
let gdx = grad.dot(&dxu);
|
||||
|
||||
let lsiter = 0;
|
||||
while lsiter < max_ls_iter {
|
||||
for i in 0..p {
|
||||
neww.set(i, 0, w.get(i, 0) + s * dx.get(i, 0));
|
||||
newu.set(i, 0, u.get(i, 0) + s * du.get(i, 0));
|
||||
newf.set(i, 0, neww.get(i, 0) - newu.get(i, 0));
|
||||
newf.set(i, 1, -neww.get(i, 0) - newu.get(i, 0));
|
||||
}
|
||||
|
||||
if newf.max() < T::zero() {
|
||||
let mut newz = x.matmul(&neww);
|
||||
for i in 0..n {
|
||||
newz.set(i, 0, newz.get(i, 0) - y.get(i, 0));
|
||||
}
|
||||
|
||||
let newphi = newz.dot(&newz) + lambda * newu.sum() - Self::sumlogneg(&newf) / t;
|
||||
if newphi - phi <= alpha * s * gdx {
|
||||
break;
|
||||
}
|
||||
}
|
||||
s = beta * s;
|
||||
max_ls_iter += 1;
|
||||
}
|
||||
|
||||
if lsiter == max_ls_iter {
|
||||
return Err(Failed::fit(
|
||||
"Exceeded maximum number of iteration for interior point optimizer",
|
||||
));
|
||||
}
|
||||
|
||||
w.copy_from(&neww);
|
||||
u.copy_from(&newu);
|
||||
f.copy_from(&newf);
|
||||
}
|
||||
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
fn sumlogneg(f: &M) -> T {
|
||||
let (n, _) = f.shape();
|
||||
let mut sum = T::zero();
|
||||
for i in 0..n {
|
||||
sum += (-f.get(i, 0)).ln();
|
||||
sum += (-f.get(i, 1)).ln();
|
||||
}
|
||||
sum
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for InteriorPointOptimizer<T, M> {
|
||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||
let (_, p) = a.shape();
|
||||
|
||||
for i in 0..p {
|
||||
x.set(
|
||||
i,
|
||||
0,
|
||||
(self.d1[i] * b.get(i, 0) - self.d2[i] * b.get(i + p, 0)) / self.prs[i],
|
||||
);
|
||||
x.set(
|
||||
i + p,
|
||||
0,
|
||||
(-self.d2[i] * b.get(i, 0) + self.prb[i] * b.get(i + p, 0)) / self.prs[i],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_vec_mul(&self, _: &M, x: &M, y: &mut M) {
|
||||
let (_, p) = self.ata.shape();
|
||||
let atax = self.ata.matmul(&x.slice(0..p, 0..1));
|
||||
|
||||
for i in 0..p {
|
||||
y.set(
|
||||
i,
|
||||
0,
|
||||
T::two() * atax.get(i, 0) + self.d1[i] * x.get(i, 0) + self.d2[i] * x.get(i + p, 0),
|
||||
);
|
||||
y.set(
|
||||
i + p,
|
||||
0,
|
||||
self.d2[i] * x.get(i, 0) + self.d1[i] * x.get(i + p, 0),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn mat_t_vec_mul(&self, a: &M, x: &M, y: &mut M) {
|
||||
self.mat_vec_mul(a, x, y);
|
||||
}
|
||||
}
|
||||
@@ -45,9 +45,9 @@
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
//!
|
||||
//! let lr = LinearRegression::fit(&x, &y, LinearRegressionParameters {
|
||||
//! solver: LinearRegressionSolverName::QR, // or SVD
|
||||
//! }).unwrap();
|
||||
//! let lr = LinearRegression::fit(&x, &y,
|
||||
//! LinearRegressionParameters::default().
|
||||
//! with_solver(LinearRegressionSolverName::QR)).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
@@ -62,13 +62,16 @@
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
@@ -78,18 +81,28 @@ pub enum LinearRegressionSolverName {
|
||||
}
|
||||
|
||||
/// Linear Regression parameters
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
/// Linear Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
solver: LinearRegressionSolverName,
|
||||
_solver: LinearRegressionSolverName,
|
||||
}
|
||||
|
||||
impl LinearRegressionParameters {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: LinearRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionParameters {
|
||||
@@ -107,6 +120,24 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LinearRegressionParameters>
|
||||
for LinearRegression<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LinearRegressionParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
LinearRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LinearRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
/// Fits Linear Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -123,9 +154,9 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
let (y_nrows, _) = b.shape();
|
||||
|
||||
if x_nrows != y_nrows {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of rows of X doesn't match number of rows of Y"
|
||||
)));
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let a = x.h_stack(&M::ones(x_nrows, 1));
|
||||
@@ -140,7 +171,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
Ok(LinearRegression {
|
||||
intercept: w.get(num_attributes, 0),
|
||||
coefficients: wights,
|
||||
solver: parameters.solver,
|
||||
_solver: parameters.solver,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -154,8 +185,8 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> M {
|
||||
self.coefficients.clone()
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
@@ -169,6 +200,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -219,7 +251,9 @@ mod tests {
|
||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//! ```
|
||||
@@ -52,11 +52,13 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::cmp::Ordering;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -65,10 +67,30 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||
pub enum LogisticRegressionSolverName {
|
||||
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||
LBFGS,
|
||||
}
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LogisticRegressionSolverName,
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||
weights: M,
|
||||
coefficients: M,
|
||||
intercept: M,
|
||||
classes: Vec<T>,
|
||||
num_attributes: usize,
|
||||
num_classes: usize,
|
||||
@@ -82,7 +104,7 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
|
||||
let mut sum = T::zero();
|
||||
let p = x.shape().1;
|
||||
for i in 0..p {
|
||||
sum = sum + x.get(m_row, i) * w.get(0, i + v_col);
|
||||
sum += x.get(m_row, i) * w.get(0, i + v_col);
|
||||
}
|
||||
|
||||
sum + w.get(0, p + v_col)
|
||||
@@ -92,7 +114,29 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
|
||||
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
phantom: PhantomData<&'a T>,
|
||||
alpha: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> LogisticRegressionParameters<T> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionParameters {
|
||||
solver: LogisticRegressionSolverName::LBFGS,
|
||||
alpha: T::zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
@@ -101,7 +145,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
|| self.num_attributes != other.num_attributes
|
||||
|| self.classes.len() != other.classes.len()
|
||||
{
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
@@ -109,7 +153,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for LogisticRegression<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
return self.weights == other.weights;
|
||||
self.coefficients == other.coefficients && self.intercept == other.intercept
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -119,11 +163,20 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
{
|
||||
fn f(&self, w_bias: &M) -> T {
|
||||
let mut f = T::zero();
|
||||
let (n, _) = self.x.shape();
|
||||
let (n, p) = self.x.shape();
|
||||
|
||||
for i in 0..n {
|
||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||
f = f + (wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx);
|
||||
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
|
||||
}
|
||||
|
||||
if self.alpha > T::zero() {
|
||||
let mut w_squared = T::zero();
|
||||
for i in 0..p {
|
||||
let w = w_bias.get(0, i);
|
||||
w_squared += w * w;
|
||||
}
|
||||
f += T::half() * self.alpha * w_squared;
|
||||
}
|
||||
|
||||
f
|
||||
@@ -143,6 +196,13 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
}
|
||||
g.set(0, p, g.get(0, p) - dyi);
|
||||
}
|
||||
|
||||
if self.alpha > T::zero() {
|
||||
for i in 0..p {
|
||||
let w = w_bias.get(0, i);
|
||||
g.set(0, i, g.get(0, i) + self.alpha * w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +210,7 @@ struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: Vec<usize>,
|
||||
k: usize,
|
||||
phantom: PhantomData<&'a T>,
|
||||
alpha: T,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
@@ -169,7 +229,18 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
);
|
||||
}
|
||||
prob.softmax_mut();
|
||||
f = f - prob.get(0, self.y[i]).ln();
|
||||
f -= prob.get(0, self.y[i]).ln();
|
||||
}
|
||||
|
||||
if self.alpha > T::zero() {
|
||||
let mut w_squared = T::zero();
|
||||
for i in 0..self.k {
|
||||
for j in 0..p {
|
||||
let wi = w_bias.get(0, i * (p + 1) + j);
|
||||
w_squared += wi * wi;
|
||||
}
|
||||
}
|
||||
f += T::half() * self.alpha * w_squared;
|
||||
}
|
||||
|
||||
f
|
||||
@@ -202,6 +273,35 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
||||
}
|
||||
}
|
||||
|
||||
if self.alpha > T::zero() {
|
||||
for i in 0..self.k {
|
||||
for j in 0..p {
|
||||
let pos = i * (p + 1);
|
||||
let wi = w.get(0, pos + j);
|
||||
g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters<T>>
|
||||
for LogisticRegression<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LogisticRegressionParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
LogisticRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for LogisticRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,15 +309,20 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
/// Fits Logistic Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target class values
|
||||
pub fn fit(x: &M, y: &M::RowVector) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: LogisticRegressionParameters<T>,
|
||||
) -> Result<LogisticRegression<T, M>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let (_, y_nrows) = y_m.shape();
|
||||
|
||||
if x_nrows != y_nrows {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Number of rows of X doesn't match number of rows of Y"
|
||||
)));
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let classes = y_m.unique();
|
||||
@@ -226,53 +331,58 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
let mut yi: Vec<usize> = vec![0; y_nrows];
|
||||
|
||||
for i in 0..y_nrows {
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_nrows) {
|
||||
let yc = y_m.get(0, i);
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
if k < 2 {
|
||||
Err(Failed::fit(&format!(
|
||||
match k.cmp(&2) {
|
||||
Ordering::Less => Err(Failed::fit(&format!(
|
||||
"incorrect number of classes: {}. Should be >= 2.",
|
||||
k
|
||||
)))
|
||||
} else if k == 2 {
|
||||
let x0 = M::zeros(1, num_attributes + 1);
|
||||
))),
|
||||
Ordering::Equal => {
|
||||
let x0 = M::zeros(1, num_attributes + 1);
|
||||
|
||||
let objective = BinaryObjectiveFunction {
|
||||
x: x,
|
||||
y: yi,
|
||||
phantom: PhantomData,
|
||||
};
|
||||
let objective = BinaryObjectiveFunction {
|
||||
x,
|
||||
y: yi,
|
||||
alpha: parameters.alpha,
|
||||
};
|
||||
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
|
||||
Ok(LogisticRegression {
|
||||
weights: result.x,
|
||||
classes: classes,
|
||||
num_attributes: num_attributes,
|
||||
num_classes: k,
|
||||
})
|
||||
} else {
|
||||
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
||||
let weights = result.x;
|
||||
|
||||
let objective = MultiClassObjectiveFunction {
|
||||
x: x,
|
||||
y: yi,
|
||||
k: k,
|
||||
phantom: PhantomData,
|
||||
};
|
||||
Ok(LogisticRegression {
|
||||
coefficients: weights.slice(0..1, 0..num_attributes),
|
||||
intercept: weights.slice(0..1, num_attributes..num_attributes + 1),
|
||||
classes,
|
||||
num_attributes,
|
||||
num_classes: k,
|
||||
})
|
||||
}
|
||||
Ordering::Greater => {
|
||||
let x0 = M::zeros(1, (num_attributes + 1) * k);
|
||||
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
let objective = MultiClassObjectiveFunction {
|
||||
x,
|
||||
y: yi,
|
||||
k,
|
||||
alpha: parameters.alpha,
|
||||
};
|
||||
|
||||
let weights = result.x.reshape(k, num_attributes + 1);
|
||||
let result = LogisticRegression::minimize(x0, objective);
|
||||
let weights = result.x.reshape(k, num_attributes + 1);
|
||||
|
||||
Ok(LogisticRegression {
|
||||
weights: weights,
|
||||
classes: classes,
|
||||
num_attributes: num_attributes,
|
||||
num_classes: k,
|
||||
})
|
||||
Ok(LogisticRegression {
|
||||
coefficients: weights.slice(0..k, 0..num_attributes),
|
||||
intercept: weights.slice(0..k, num_attributes..num_attributes + 1),
|
||||
classes,
|
||||
num_attributes,
|
||||
num_classes: k,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,42 +392,42 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
let n = x.shape().0;
|
||||
let mut result = M::zeros(1, n);
|
||||
if self.num_classes == 2 {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
|
||||
let y_hat: Vec<T> = x_and_bias
|
||||
.matmul(&self.weights.transpose())
|
||||
.get_col_as_vec(0);
|
||||
for i in 0..n {
|
||||
let y_hat: Vec<T> = x.ab(false, &self.coefficients, true).get_col_as_vec(0);
|
||||
let intercept = self.intercept.get(0, 0);
|
||||
for (i, y_hat_i) in y_hat.iter().enumerate().take(n) {
|
||||
result.set(
|
||||
0,
|
||||
i,
|
||||
self.classes[if y_hat[i].sigmoid() > T::half() { 1 } else { 0 }],
|
||||
self.classes[if (*y_hat_i + intercept).sigmoid() > T::half() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}],
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let (nrows, _) = x.shape();
|
||||
let x_and_bias = x.h_stack(&M::ones(nrows, 1));
|
||||
let y_hat = x_and_bias.matmul(&self.weights.transpose());
|
||||
let mut y_hat = x.matmul(&self.coefficients.transpose());
|
||||
for r in 0..n {
|
||||
for c in 0..self.num_classes {
|
||||
y_hat.set(r, c, y_hat.get(r, c) + self.intercept.get(c, 0));
|
||||
}
|
||||
}
|
||||
let class_idxs = y_hat.argmax();
|
||||
for i in 0..n {
|
||||
result.set(0, i, self.classes[class_idxs[i]]);
|
||||
for (i, class_i) in class_idxs.iter().enumerate().take(n) {
|
||||
result.set(0, i, self.classes[*class_i]);
|
||||
}
|
||||
}
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> M {
|
||||
self.weights
|
||||
.slice(0..self.num_classes, 0..self.num_attributes)
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> M {
|
||||
self.weights.slice(
|
||||
0..self.num_classes,
|
||||
self.num_attributes..self.num_attributes + 1,
|
||||
)
|
||||
pub fn intercept(&self) -> &M {
|
||||
&self.intercept
|
||||
}
|
||||
|
||||
fn minimize(x0: M, objective: impl ObjectiveFunction<T, M>) -> OptimizerResult<T, M> {
|
||||
@@ -325,8 +435,10 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
|
||||
let df = |g: &mut M, w: &M| objective.df(g, w);
|
||||
|
||||
let mut ls: Backtracking<T> = Default::default();
|
||||
ls.order = FunctionOrder::THIRD;
|
||||
let ls: Backtracking<T> = Backtracking {
|
||||
order: FunctionOrder::THIRD,
|
||||
..Default::default()
|
||||
};
|
||||
let optimizer: LBFGS<T> = Default::default();
|
||||
|
||||
optimizer.optimize(&f, &df, &x0, &ls)
|
||||
@@ -336,8 +448,11 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::dataset::generator::make_blobs;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::accuracy;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn multiclass_objective_f() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -362,9 +477,9 @@ mod tests {
|
||||
|
||||
let objective = MultiClassObjectiveFunction {
|
||||
x: &x,
|
||||
y: y,
|
||||
y: y.clone(),
|
||||
k: 3,
|
||||
phantom: PhantomData,
|
||||
alpha: 0.0,
|
||||
};
|
||||
|
||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
||||
@@ -385,8 +500,27 @@ mod tests {
|
||||
]));
|
||||
|
||||
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
||||
|
||||
let objective_reg = MultiClassObjectiveFunction {
|
||||
x: &x,
|
||||
y: y.clone(),
|
||||
k: 3,
|
||||
alpha: 1.0,
|
||||
};
|
||||
|
||||
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[
|
||||
1., 2., 3., 4., 5., 6., 7., 8., 9.,
|
||||
]));
|
||||
assert!((f - 487.5052).abs() < 1e-4);
|
||||
|
||||
objective_reg.df(
|
||||
&mut g,
|
||||
&DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
|
||||
);
|
||||
assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn binary_objective_f() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -411,8 +545,8 @@ mod tests {
|
||||
|
||||
let objective = BinaryObjectiveFunction {
|
||||
x: &x,
|
||||
y: y,
|
||||
phantom: PhantomData,
|
||||
y: y.clone(),
|
||||
alpha: 0.0,
|
||||
};
|
||||
|
||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
||||
@@ -427,8 +561,23 @@ mod tests {
|
||||
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||
|
||||
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
||||
|
||||
let objective_reg = BinaryObjectiveFunction {
|
||||
x: &x,
|
||||
y: y.clone(),
|
||||
alpha: 1.0,
|
||||
};
|
||||
|
||||
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||
assert!((f - 62.2699).abs() < 1e-4);
|
||||
|
||||
objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||
assert!((g.get(0, 0) - 27.0511).abs() < 1e-4);
|
||||
assert!((g.get(0, 1) - 12.239).abs() < 1e-4);
|
||||
assert!((g.get(0, 2) - 3.8693).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lr_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -450,7 +599,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(lr.coefficients().shape(), (3, 2));
|
||||
assert_eq!(lr.intercept().shape(), (3, 1));
|
||||
@@ -466,7 +615,57 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lr_fit_predict_multiclass() {
|
||||
let blobs = make_blobs(15, 4, 3);
|
||||
|
||||
let x = DenseMatrix::from_vec(15, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||
|
||||
let lr_reg = LogisticRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lr_fit_predict_binary() {
|
||||
let blobs = make_blobs(20, 4, 2);
|
||||
|
||||
let x = DenseMatrix::from_vec(20, 4, &blobs.data);
|
||||
let y = blobs.target;
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||
|
||||
let lr_reg = LogisticRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., -5.],
|
||||
@@ -487,7 +686,7 @@ mod tests {
|
||||
]);
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1., 2., 1., 1., 0., 0., 2., 1., 1., 0., 0., 1.];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: LogisticRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
@@ -495,6 +694,7 @@ mod tests {
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lr_fit_predict_iris() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -523,7 +723,13 @@ mod tests {
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
let lr_reg = LogisticRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
LogisticRegressionParameters::default().with_alpha(1.0),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
@@ -534,5 +740,6 @@ mod tests {
|
||||
.sum();
|
||||
|
||||
assert!(error <= 1.0);
|
||||
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,5 +20,10 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
pub(crate) mod bg_solver;
|
||||
pub mod elastic_net;
|
||||
pub mod lasso;
|
||||
pub(crate) mod lasso_optimizer;
|
||||
pub mod linear_regression;
|
||||
pub mod logistic_regression;
|
||||
pub mod ridge_regression;
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
//! # Ridge Regression
|
||||
//!
|
||||
//! [Linear regression](../linear_regression/index.html) is the standard algorithm for predicting a quantitative response \\(y\\) on the basis of a linear combination of explanatory variables \\(X\\)
|
||||
//! that assumes that there is approximately a linear relationship between \\(X\\) and \\(y\\).
|
||||
//! Ridge regression is an extension to linear regression that adds L2 regularization term to the loss function during training.
|
||||
//! This term encourages simpler models that have smaller coefficient values.
|
||||
//!
|
||||
//! In ridge regression coefficients \\(\beta_0, \beta_0, ... \beta_n\\) are are estimated by solving
|
||||
//!
|
||||
//! \\[\hat{\beta} = (X^TX + \alpha I)^{-1}X^Ty \\]
|
||||
//!
|
||||
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression will produce the least squares estimates.
|
||||
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates will approach zero.
|
||||
//!
|
||||
//! SmartCore uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
|
||||
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
|
||||
//! but does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linear::ridge_regression::*;
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
//! &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
//!
|
||||
//! let y_hat = RidgeRegression::fit(&x, &y, RidgeRegressionParameters::default().with_alpha(0.1)).
|
||||
//! and_then(|lr| lr.predict(&x)).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 6.2. Shrinkage Methods](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 15.4 General Linear Least Squares](http://numerical.recipes/)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
Cholesky,
|
||||
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
||||
SVD,
|
||||
}
|
||||
|
||||
/// Ridge Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: RidgeRegressionSolverName,
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
/// Ridge regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
||||
coefficients: M,
|
||||
intercept: T,
|
||||
_solver: RidgeRegressionSolverName,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RidgeRegressionParameters<T> {
|
||||
/// Regularization parameter.
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub fn with_solver(mut self, solver: RidgeRegressionSolverName) -> Self {
|
||||
self.solver = solver;
|
||||
self
|
||||
}
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub fn with_normalize(mut self, normalize: bool) -> Self {
|
||||
self.normalize = normalize;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
|
||||
fn default() -> Self {
|
||||
RidgeRegressionParameters {
|
||||
solver: RidgeRegressionSolverName::Cholesky,
|
||||
alpha: T::one(),
|
||||
normalize: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for RidgeRegression<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
&& (self.intercept - other.intercept).abs() <= T::epsilon()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, RidgeRegressionParameters<T>>
|
||||
for RidgeRegression<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RidgeRegressionParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
RidgeRegression::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RidgeRegression<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
||||
/// Fits ridge regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: RidgeRegressionParameters<T>,
|
||||
) -> Result<RidgeRegression<T, M>, Failed> {
|
||||
//w = inv(X^t X + alpha*Id) * X.T y
|
||||
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if n <= p {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows in X should be >= number of columns in X",
|
||||
));
|
||||
}
|
||||
|
||||
if y.len() != n {
|
||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||
}
|
||||
|
||||
let y_column = M::from_row_vector(y.clone()).transpose();
|
||||
|
||||
let (w, b) = if parameters.normalize {
|
||||
let (scaled_x, col_mean, col_std) = Self::rescale_x(x)?;
|
||||
let x_t = scaled_x.transpose();
|
||||
let x_t_y = x_t.matmul(&y_column);
|
||||
let mut x_t_x = x_t.matmul(&scaled_x);
|
||||
|
||||
for i in 0..p {
|
||||
x_t_x.add_element_mut(i, i, parameters.alpha);
|
||||
}
|
||||
|
||||
let mut w = match parameters.solver {
|
||||
RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
|
||||
RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
|
||||
};
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate().take(p) {
|
||||
w.set(i, 0, w.get(i, 0) / *col_std_i);
|
||||
}
|
||||
|
||||
let mut b = T::zero();
|
||||
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
b += w.get(i, 0) * *col_mean_i;
|
||||
}
|
||||
|
||||
let b = y.mean() - b;
|
||||
|
||||
(w, b)
|
||||
} else {
|
||||
let x_t = x.transpose();
|
||||
let x_t_y = x_t.matmul(&y_column);
|
||||
let mut x_t_x = x_t.matmul(x);
|
||||
|
||||
for i in 0..p {
|
||||
x_t_x.add_element_mut(i, i, parameters.alpha);
|
||||
}
|
||||
|
||||
let w = match parameters.solver {
|
||||
RidgeRegressionSolverName::Cholesky => x_t_x.cholesky_solve_mut(x_t_y)?,
|
||||
RidgeRegressionSolverName::SVD => x_t_x.svd_solve_mut(x_t_y)?,
|
||||
};
|
||||
|
||||
(w, T::zero())
|
||||
};
|
||||
|
||||
Ok(RidgeRegression {
|
||||
intercept: b,
|
||||
coefficients: w,
|
||||
_solver: parameters.solver,
|
||||
})
|
||||
}
|
||||
|
||||
fn rescale_x(x: &M) -> Result<(M, Vec<T>, Vec<T>), Failed> {
|
||||
let col_mean = x.mean(0);
|
||||
let col_std = x.std(0);
|
||||
|
||||
for (i, col_std_i) in col_std.iter().enumerate() {
|
||||
if (*col_std_i - T::zero()).abs() < T::epsilon() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Cannot rescale constant column {}",
|
||||
i
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let mut scaled_x = x.clone();
|
||||
scaled_x.scale_mut(&col_mean, &col_std, 0);
|
||||
Ok((scaled_x, col_mean, col_std))
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (nrows, _) = x.shape();
|
||||
let mut y_hat = x.matmul(&self.coefficients);
|
||||
y_hat.add_mut(&M::fill(nrows, 1, self.intercept));
|
||||
Ok(y_hat.transpose().to_row_vector())
|
||||
}
|
||||
|
||||
/// Get estimates regression coefficients
|
||||
pub fn coefficients(&self) -> &M {
|
||||
&self.coefficients
|
||||
}
|
||||
|
||||
/// Get estimate of intercept
|
||||
pub fn intercept(&self) -> T {
|
||||
self.intercept
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ridge_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat_cholesky = RidgeRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
RidgeRegressionParameters {
|
||||
solver: RidgeRegressionSolverName::Cholesky,
|
||||
alpha: 0.1,
|
||||
normalize: true,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat_cholesky, &y) < 2.0);
|
||||
|
||||
let y_hat_svd = RidgeRegression::fit(
|
||||
&x,
|
||||
&y,
|
||||
RidgeRegressionParameters {
|
||||
solver: RidgeRegressionSolverName::SVD,
|
||||
alpha: 0.1,
|
||||
normalize: false,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let lr = RidgeRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_lr: RidgeRegression<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -25,12 +26,13 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Euclidian {}
|
||||
|
||||
impl Euclidian {
|
||||
#[inline]
|
||||
pub(crate) fn squared_distance<T: RealNumber>(x: &Vec<T>, y: &Vec<T>) -> T {
|
||||
pub(crate) fn squared_distance<T: RealNumber>(x: &[T], y: &[T]) -> T {
|
||||
if x.len() != y.len() {
|
||||
panic!("Input vector sizes are different.");
|
||||
}
|
||||
@@ -38,7 +40,7 @@ impl Euclidian {
|
||||
let mut sum = T::zero();
|
||||
for i in 0..x.len() {
|
||||
let d = x[i] - y[i];
|
||||
sum = sum + d * d;
|
||||
sum += d * d;
|
||||
}
|
||||
|
||||
sum
|
||||
@@ -55,6 +57,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn squared_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -26,7 +27,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Hamming {}
|
||||
|
||||
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||
@@ -50,6 +52,7 @@ impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn hamming_distance() {
|
||||
let a = vec![1, 0, 0, 1, 0, 0, 1];
|
||||
|
||||
@@ -44,6 +44,7 @@
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -52,7 +53,8 @@ use super::Distance;
|
||||
use crate::linalg::Matrix;
|
||||
|
||||
/// Mahalanobis distance.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
||||
/// covariance matrix of the dataset
|
||||
pub sigma: M,
|
||||
@@ -68,8 +70,8 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
|
||||
let sigma = data.cov();
|
||||
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
Mahalanobis {
|
||||
sigma: sigma,
|
||||
sigmaInv: sigmaInv,
|
||||
sigma,
|
||||
sigmaInv,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -80,8 +82,8 @@ impl<T: RealNumber, M: Matrix<T>> Mahalanobis<T, M> {
|
||||
let sigma = cov.clone();
|
||||
let sigmaInv = sigma.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
Mahalanobis {
|
||||
sigma: sigma,
|
||||
sigmaInv: sigmaInv,
|
||||
sigma,
|
||||
sigmaInv,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
@@ -118,7 +120,7 @@ impl<T: RealNumber, M: Matrix<T>> Distance<Vec<T>, T> for Mahalanobis<T, M> {
|
||||
let mut s = T::zero();
|
||||
for j in 0..n {
|
||||
for i in 0..n {
|
||||
s = s + self.sigmaInv.get(i, j) * z[i] * z[j];
|
||||
s += self.sigmaInv.get(i, j) * z[i] * z[j];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,6 +133,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mahalanobis_distance() {
|
||||
let data = DenseMatrix::from_2d_array(&[
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
//! ```
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -24,7 +25,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Manhattan distance
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Manhattan {}
|
||||
|
||||
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||
@@ -35,7 +37,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||
|
||||
let mut dist = T::zero();
|
||||
for i in 0..x.len() {
|
||||
dist = dist + (x[i] - y[i]).abs();
|
||||
dist += (x[i] - y[i]).abs();
|
||||
}
|
||||
|
||||
dist
|
||||
@@ -46,6 +48,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn manhattan_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -28,7 +29,8 @@ use crate::math::num::RealNumber;
|
||||
use super::Distance;
|
||||
|
||||
/// Defines the Minkowski distance of order `p`
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Minkowski {
|
||||
/// order, integer
|
||||
pub p: u16,
|
||||
@@ -48,7 +50,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
|
||||
|
||||
for i in 0..x.len() {
|
||||
let d = (x[i] - y[i]).abs();
|
||||
dist = dist + d.powf(p_t);
|
||||
dist += d.powf(p_t);
|
||||
}
|
||||
|
||||
dist.powf(T::one() / p_t)
|
||||
@@ -59,6 +61,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn minkowski_distance() {
|
||||
let a = vec![1., 2., 3.];
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
//! Formally, the distance can be any metric measure that is defined as \\( d(x, y) \geq 0\\) and follows three conditions:
|
||||
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
|
||||
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//!
|
||||
//! for all \\(x, y, z \in Z \\)
|
||||
//!
|
||||
@@ -28,7 +28,7 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Distance metric, a function that calculates distance between two points
|
||||
pub trait Distance<T, F: RealNumber> {
|
||||
pub trait Distance<T, F: RealNumber>: Clone {
|
||||
/// Calculates distance between _a_ and _b_
|
||||
fn distance(&self, a: &T, b: &T) -> F;
|
||||
}
|
||||
@@ -45,7 +45,7 @@ impl Distances {
|
||||
/// Minkowski distance, see [`Minkowski`](minkowski/index.html)
|
||||
/// * `p` - function order. Should be >= 1
|
||||
pub fn minkowski(p: u16) -> minkowski::Minkowski {
|
||||
minkowski::Minkowski { p: p }
|
||||
minkowski::Minkowski { p }
|
||||
}
|
||||
|
||||
/// Manhattan distance, see [`Manhattan`](manhattan/index.html)
|
||||
|
||||
+37
-12
@@ -6,10 +6,23 @@ use num_traits::{Float, FromPrimitive};
|
||||
use rand::prelude::*;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
||||
|
||||
/// Defines real number
|
||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
pub trait RealNumber: Float + FromPrimitive + Debug + Display + Copy + Sum + Product {
|
||||
pub trait RealNumber:
|
||||
Float
|
||||
+ FromPrimitive
|
||||
+ Debug
|
||||
+ Display
|
||||
+ Copy
|
||||
+ Sum
|
||||
+ Product
|
||||
+ AddAssign
|
||||
+ SubAssign
|
||||
+ MulAssign
|
||||
+ DivAssign
|
||||
{
|
||||
/// Copy sign from `sign` - another real number
|
||||
fn copysign(self, sign: Self) -> Self;
|
||||
|
||||
@@ -33,8 +46,11 @@ pub trait RealNumber: Float + FromPrimitive + Debug + Display + Copy + Sum + Pro
|
||||
self * self
|
||||
}
|
||||
|
||||
/// Raw transmutation to u64
|
||||
/// Raw transmutation to u32
|
||||
fn to_f32_bits(self) -> u32;
|
||||
|
||||
/// Raw transmutation to u64
|
||||
fn to_f64_bits(self) -> u64;
|
||||
}
|
||||
|
||||
impl RealNumber for f64 {
|
||||
@@ -44,19 +60,19 @@ impl RealNumber for f64 {
|
||||
|
||||
fn ln_1pe(self) -> f64 {
|
||||
if self > 15. {
|
||||
return self;
|
||||
self
|
||||
} else {
|
||||
return self.exp().ln_1p();
|
||||
self.exp().ln_1p()
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(self) -> f64 {
|
||||
if self < -40. {
|
||||
return 0.;
|
||||
0.
|
||||
} else if self > 40. {
|
||||
return 1.;
|
||||
1.
|
||||
} else {
|
||||
return 1. / (1. + f64::exp(-self));
|
||||
1. / (1. + f64::exp(-self))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,6 +92,10 @@ impl RealNumber for f64 {
|
||||
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits() as u32
|
||||
}
|
||||
|
||||
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits()
|
||||
}
|
||||
}
|
||||
|
||||
impl RealNumber for f32 {
|
||||
@@ -85,19 +105,19 @@ impl RealNumber for f32 {
|
||||
|
||||
fn ln_1pe(self) -> f32 {
|
||||
if self > 15. {
|
||||
return self;
|
||||
self
|
||||
} else {
|
||||
return self.exp().ln_1p();
|
||||
self.exp().ln_1p()
|
||||
}
|
||||
}
|
||||
|
||||
fn sigmoid(self) -> f32 {
|
||||
if self < -40. {
|
||||
return 0.;
|
||||
0.
|
||||
} else if self > 40. {
|
||||
return 1.;
|
||||
1.
|
||||
} else {
|
||||
return 1. / (1. + f32::exp(-self));
|
||||
1. / (1. + f32::exp(-self))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,12 +137,17 @@ impl RealNumber for f32 {
|
||||
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits()
|
||||
}
|
||||
|
||||
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits() as u64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn sigmoid() {
|
||||
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
|
||||
|
||||
+10
-8
@@ -1,13 +1,14 @@
|
||||
use crate::math::num::RealNumber;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
pub trait RealNumberVector<T: RealNumber> {
|
||||
fn unique(&self) -> (Vec<T>, Vec<usize>);
|
||||
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>);
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
|
||||
fn unique(&self) -> (Vec<T>, Vec<usize>) {
|
||||
let mut unique = self.clone();
|
||||
impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
|
||||
fn unique_with_indices(&self) -> (Vec<T>, Vec<usize>) {
|
||||
let mut unique = self.to_vec();
|
||||
unique.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
unique.dedup();
|
||||
|
||||
@@ -17,8 +18,8 @@ impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
|
||||
}
|
||||
|
||||
let mut unique_index = Vec::with_capacity(self.len());
|
||||
for e in self {
|
||||
unique_index.push(index[&e.to_i64().unwrap()]);
|
||||
for idx in 0..self.len() {
|
||||
unique_index.push(index[&self.get(idx).to_i64().unwrap()]);
|
||||
}
|
||||
|
||||
(unique, unique_index)
|
||||
@@ -29,12 +30,13 @@ impl<T: RealNumber> RealNumberVector<T> for Vec<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn unique() {
|
||||
fn unique_with_indices() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
assert_eq!(
|
||||
(vec!(0.0, 1.0, 2.0, 4.0), vec!(0, 0, 1, 1, 2, 0, 3)),
|
||||
v1.unique()
|
||||
v1.unique_with_indices()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,13 +16,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Accuracy metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Accuracy {}
|
||||
|
||||
impl Accuracy {
|
||||
@@ -55,6 +57,7 @@ impl Accuracy {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn accuracy() {
|
||||
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
||||
|
||||
+9
-6
@@ -20,6 +20,7 @@
|
||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
@@ -27,7 +28,8 @@ use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct AUC {}
|
||||
|
||||
impl AUC {
|
||||
@@ -42,9 +44,9 @@ impl AUC {
|
||||
|
||||
for i in 0..n {
|
||||
if y_true.get(i) == T::zero() {
|
||||
neg = neg + T::one();
|
||||
neg += T::one();
|
||||
} else if y_true.get(i) == T::one() {
|
||||
pos = pos + T::one();
|
||||
pos += T::one();
|
||||
} else {
|
||||
panic!(
|
||||
"AUC is only for binary classification. Invalid label: {}",
|
||||
@@ -68,8 +70,8 @@ impl AUC {
|
||||
j += 1;
|
||||
}
|
||||
let r = T::from_usize(i + 1 + j).unwrap() / T::two();
|
||||
for k in i..j {
|
||||
rank[k] = r;
|
||||
for rank_k in rank.iter_mut().take(j).skip(i) {
|
||||
*rank_k = r;
|
||||
}
|
||||
i = j - 1;
|
||||
}
|
||||
@@ -79,7 +81,7 @@ impl AUC {
|
||||
let mut auc = T::zero();
|
||||
for i in 0..n {
|
||||
if y_true.get(label_idx[i]) == T::one() {
|
||||
auc = auc + rank[i];
|
||||
auc += rank[i];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,6 +93,7 @@ impl AUC {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn auc() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::metrics::cluster_helpers::*;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Homogeneity, completeness and V-Measure scores.
|
||||
pub struct HCVScore {}
|
||||
|
||||
@@ -24,8 +26,8 @@ impl HCVScore {
|
||||
let contingency = contingency_matrix(&labels_true, &labels_pred);
|
||||
let mi: T = mutual_info_score(&contingency);
|
||||
|
||||
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or(T::one());
|
||||
let completeness = entropy_k.map(|e| mi / e).unwrap_or(T::one());
|
||||
let homogeneity = entropy_c.map(|e| mi / e).unwrap_or_else(T::one);
|
||||
let completeness = entropy_k.map(|e| mi / e).unwrap_or_else(T::one);
|
||||
|
||||
let v_measure_score = if homogeneity + completeness == T::zero() {
|
||||
T::zero()
|
||||
@@ -41,6 +43,7 @@ impl HCVScore {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn homogeneity_score() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::ptr_arg)]
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
@@ -7,8 +8,8 @@ pub fn contingency_matrix<T: RealNumber>(
|
||||
labels_true: &Vec<T>,
|
||||
labels_pred: &Vec<T>,
|
||||
) -> Vec<Vec<usize>> {
|
||||
let (classes, class_idx) = labels_true.unique();
|
||||
let (clusters, cluster_idx) = labels_pred.unique();
|
||||
let (classes, class_idx) = labels_true.unique_with_indices();
|
||||
let (clusters, cluster_idx) = labels_pred.unique_with_indices();
|
||||
|
||||
let mut contingency_matrix = Vec::with_capacity(classes.len());
|
||||
|
||||
@@ -23,7 +24,7 @@ pub fn contingency_matrix<T: RealNumber>(
|
||||
contingency_matrix
|
||||
}
|
||||
|
||||
pub fn entropy<T: RealNumber>(data: &Vec<T>) -> Option<T> {
|
||||
pub fn entropy<T: RealNumber>(data: &[T]) -> Option<T> {
|
||||
let mut bincounts = HashMap::with_capacity(data.len());
|
||||
|
||||
for e in data.iter() {
|
||||
@@ -37,24 +38,24 @@ pub fn entropy<T: RealNumber>(data: &Vec<T>) -> Option<T> {
|
||||
for &c in bincounts.values() {
|
||||
if c > 0 {
|
||||
let pi = T::from_usize(c).unwrap();
|
||||
entropy = entropy - (pi / sum) * (pi.ln() - sum.ln());
|
||||
entropy -= (pi / sum) * (pi.ln() - sum.ln());
|
||||
}
|
||||
}
|
||||
|
||||
Some(entropy)
|
||||
}
|
||||
|
||||
pub fn mutual_info_score<T: RealNumber>(contingency: &Vec<Vec<usize>>) -> T {
|
||||
pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
|
||||
let mut contingency_sum = 0;
|
||||
let mut pi = vec![0; contingency.len()];
|
||||
let mut pj = vec![0; contingency[0].len()];
|
||||
let (mut nzx, mut nzy, mut nz_val) = (Vec::new(), Vec::new(), Vec::new());
|
||||
|
||||
for r in 0..contingency.len() {
|
||||
for c in 0..contingency[0].len() {
|
||||
for (c, pj_c) in pj.iter_mut().enumerate().take(contingency[0].len()) {
|
||||
contingency_sum += contingency[r][c];
|
||||
pi[r] += contingency[r][c];
|
||||
pj[c] += contingency[r][c];
|
||||
*pj_c += contingency[r][c];
|
||||
if contingency[r][c] > 0 {
|
||||
nzx.push(r);
|
||||
nzy.push(c);
|
||||
@@ -89,9 +90,8 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &Vec<Vec<usize>>) -> T {
|
||||
let mut result = T::zero();
|
||||
|
||||
for i in 0..log_outer.len() {
|
||||
result = result
|
||||
+ ((contingency_nm[i] * (log_contingency_nm[i] - contingency_sum_ln))
|
||||
+ contingency_nm[i] * log_outer[i])
|
||||
result += (contingency_nm[i] * (log_contingency_nm[i] - contingency_sum_ln))
|
||||
+ contingency_nm[i] * log_outer[i]
|
||||
}
|
||||
|
||||
result.max(T::zero())
|
||||
@@ -101,6 +101,7 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &Vec<Vec<usize>>) -> T {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn contingency_matrix_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
@@ -112,6 +113,7 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn entropy_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
@@ -119,6 +121,7 @@ mod tests {
|
||||
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mutual_info_score_test() {
|
||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||
|
||||
+4
-1
@@ -18,6 +18,7 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
@@ -26,7 +27,8 @@ use crate::metrics::precision::Precision;
|
||||
use crate::metrics::recall::Recall;
|
||||
|
||||
/// F-measure
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct F1<T: RealNumber> {
|
||||
/// a positive real factor
|
||||
pub beta: T,
|
||||
@@ -57,6 +59,7 @@ impl<T: RealNumber> F1<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn f1() {
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Absolute Error
|
||||
pub struct MeanAbsoluteError {}
|
||||
|
||||
@@ -43,7 +45,7 @@ impl MeanAbsoluteError {
|
||||
let n = y_true.len();
|
||||
let mut ras = T::zero();
|
||||
for i in 0..n {
|
||||
ras = ras + (y_true.get(i) - y_pred.get(i)).abs();
|
||||
ras += (y_true.get(i) - y_pred.get(i)).abs();
|
||||
}
|
||||
|
||||
ras / T::from_usize(n).unwrap()
|
||||
@@ -54,6 +56,7 @@ impl MeanAbsoluteError {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean_absolute_error() {
|
||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
|
||||
@@ -18,12 +18,14 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Mean Squared Error
|
||||
pub struct MeanSquareError {}
|
||||
|
||||
@@ -43,7 +45,7 @@ impl MeanSquareError {
|
||||
let n = y_true.len();
|
||||
let mut rss = T::zero();
|
||||
for i in 0..n {
|
||||
rss = rss + (y_true.get(i) - y_pred.get(i)).square();
|
||||
rss += (y_true.get(i) - y_pred.get(i)).square();
|
||||
}
|
||||
|
||||
rss / T::from_usize(n).unwrap()
|
||||
@@ -54,6 +56,7 @@ impl MeanSquareError {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn mean_squared_error() {
|
||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
|
||||
+2
-2
@@ -42,7 +42,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let lr = LogisticRegression::fit(&x, &y).unwrap();
|
||||
//! let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = lr.predict(&x).unwrap();
|
||||
//!
|
||||
@@ -101,7 +101,7 @@ impl ClassificationMetrics {
|
||||
|
||||
/// F1 score, also known as balanced F-score or F-measure, see [F1](f1/index.html).
|
||||
pub fn f1<T: RealNumber>(beta: T) -> f1::F1<T> {
|
||||
f1::F1 { beta: beta }
|
||||
f1::F1 { beta }
|
||||
}
|
||||
|
||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html).
|
||||
|
||||
+46
-23
@@ -18,13 +18,17 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Precision metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Precision {}
|
||||
|
||||
impl Precision {
|
||||
@@ -40,34 +44,33 @@ impl Precision {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes = classes.len();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
if y_true.get(i) == T::one() {
|
||||
let mut fp = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
fp += 1;
|
||||
}
|
||||
} else {
|
||||
fp += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +78,7 @@ impl Precision {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn precision() {
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
||||
@@ -85,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.5).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn precision_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
+8
-5
@@ -18,13 +18,15 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Coefficient of Determination (R2)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct R2 {}
|
||||
|
||||
impl R2 {
|
||||
@@ -45,10 +47,10 @@ impl R2 {
|
||||
let mut mean = T::zero();
|
||||
|
||||
for i in 0..n {
|
||||
mean = mean + y_true.get(i);
|
||||
mean += y_true.get(i);
|
||||
}
|
||||
|
||||
mean = mean / T::from_usize(n).unwrap();
|
||||
mean /= T::from_usize(n).unwrap();
|
||||
|
||||
let mut ss_tot = T::zero();
|
||||
let mut ss_res = T::zero();
|
||||
@@ -56,8 +58,8 @@ impl R2 {
|
||||
for i in 0..n {
|
||||
let y_i = y_true.get(i);
|
||||
let f_i = y_pred.get(i);
|
||||
ss_tot = ss_tot + (y_i - mean).square();
|
||||
ss_res = ss_res + (y_i - f_i).square();
|
||||
ss_tot += (y_i - mean).square();
|
||||
ss_res += (y_i - f_i).square();
|
||||
}
|
||||
|
||||
T::one() - (ss_res / ss_tot)
|
||||
@@ -68,6 +70,7 @@ impl R2 {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn r2() {
|
||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||
|
||||
+47
-24
@@ -18,13 +18,18 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
use std::convert::TryInto;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Recall metric.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct Recall {}
|
||||
|
||||
impl Recall {
|
||||
@@ -40,34 +45,32 @@ impl Recall {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: i64 = classes.len().try_into().unwrap();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_true.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
let mut fne = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) != T::one() {
|
||||
fne += 1;
|
||||
}
|
||||
} else {
|
||||
fne += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +78,7 @@ impl Recall {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn recall() {
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
||||
@@ -85,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.66666666).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn recall_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
//! # KFold
|
||||
//!
|
||||
//! Defines k-fold cross validator.
|
||||
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::model_selection::BaseKFold;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
/// K-Folds cross-validator
|
||||
pub struct KFold {
|
||||
/// Number of folds. Must be at least 2.
|
||||
pub n_splits: usize, // cannot exceed std::usize::MAX
|
||||
/// Whether to shuffle the data before splitting into batches
|
||||
pub shuffle: bool,
|
||||
}
|
||||
|
||||
impl KFold {
|
||||
fn test_indices<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<usize>> {
|
||||
// number of samples (rows) in the matrix
|
||||
let n_samples: usize = x.shape().0;
|
||||
|
||||
// initialise indices
|
||||
let mut indices: Vec<usize> = (0..n_samples).collect();
|
||||
if self.shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
}
|
||||
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
|
||||
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
|
||||
|
||||
// increment by one if odd
|
||||
for fold_size in fold_sizes.iter_mut().take(n_samples % self.n_splits) {
|
||||
*fold_size += 1;
|
||||
}
|
||||
|
||||
// generate the right array of arrays for test indices
|
||||
let mut return_values: Vec<Vec<usize>> = Vec::with_capacity(self.n_splits);
|
||||
let mut current: usize = 0;
|
||||
for fold_size in fold_sizes.drain(..) {
|
||||
let stop = current + fold_size;
|
||||
return_values.push(indices[current..stop].to_vec());
|
||||
current = stop
|
||||
}
|
||||
|
||||
return_values
|
||||
}
|
||||
|
||||
fn test_masks<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Vec<Vec<bool>> {
|
||||
let mut return_values: Vec<Vec<bool>> = Vec::with_capacity(self.n_splits);
|
||||
for test_index in self.test_indices(x).drain(..) {
|
||||
// init mask
|
||||
let mut test_mask = vec![false; x.shape().0];
|
||||
// set mask's indices to true according to test indices
|
||||
for i in test_index {
|
||||
test_mask[i] = true; // can be implemented with map()
|
||||
}
|
||||
return_values.push(test_mask);
|
||||
}
|
||||
return_values
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KFold {
|
||||
fn default() -> KFold {
|
||||
KFold {
|
||||
n_splits: 3,
|
||||
shuffle: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KFold {
|
||||
/// Number of folds. Must be at least 2.
|
||||
pub fn with_n_splits(mut self, n_splits: usize) -> Self {
|
||||
self.n_splits = n_splits;
|
||||
self
|
||||
}
|
||||
/// Whether to shuffle the data before splitting into batches
|
||||
pub fn with_shuffle(mut self, shuffle: bool) -> Self {
|
||||
self.shuffle = shuffle;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
pub struct KFoldIter {
|
||||
indices: Vec<usize>,
|
||||
test_indices: Vec<Vec<bool>>,
|
||||
}
|
||||
|
||||
impl Iterator for KFoldIter {
|
||||
type Item = (Vec<usize>, Vec<usize>);
|
||||
|
||||
fn next(&mut self) -> Option<(Vec<usize>, Vec<usize>)> {
|
||||
self.test_indices.pop().map(|test_index| {
|
||||
let train_index = self
|
||||
.indices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(idx, _)| !test_index[idx])
|
||||
.map(|(idx, _)| idx)
|
||||
.collect::<Vec<usize>>(); // filter train indices out according to mask
|
||||
let test_index = self
|
||||
.indices
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|&(idx, _)| test_index[idx])
|
||||
.map(|(idx, _)| idx)
|
||||
.collect::<Vec<usize>>(); // filter tests indices out according to mask
|
||||
|
||||
(train_index, test_index)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Abstract class for all KFold functionalities
|
||||
impl BaseKFold for KFold {
|
||||
type Output = KFoldIter;
|
||||
|
||||
fn n_splits(&self) -> usize {
|
||||
self.n_splits
|
||||
}
|
||||
|
||||
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Self::Output {
|
||||
if self.n_splits < 2 {
|
||||
panic!("Number of splits is too small: {}", self.n_splits);
|
||||
}
|
||||
let n_samples: usize = x.shape().0;
|
||||
let indices: Vec<usize> = (0..n_samples).collect();
|
||||
let mut test_indices = self.test_masks(x);
|
||||
test_indices.reverse();
|
||||
|
||||
KFoldIter {
|
||||
indices,
|
||||
test_indices,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_kfold_return_test_indices_simple() {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
|
||||
assert_eq!(test_indices[0], (0..11).collect::<Vec<usize>>());
|
||||
assert_eq!(test_indices[1], (11..22).collect::<Vec<usize>>());
|
||||
assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_kfold_return_test_indices_odd() {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
|
||||
assert_eq!(test_indices[0], (0..12).collect::<Vec<usize>>());
|
||||
assert_eq!(test_indices[1], (12..23).collect::<Vec<usize>>());
|
||||
assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_kfold_return_test_mask_simple() {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let test_masks = k.test_masks(&x);
|
||||
|
||||
for t in &test_masks[0][0..11] {
|
||||
// TODO: this can be prob done better
|
||||
assert_eq!(*t, true)
|
||||
}
|
||||
for t in &test_masks[0][11..22] {
|
||||
assert_eq!(*t, false)
|
||||
}
|
||||
|
||||
for t in &test_masks[1][0..11] {
|
||||
assert_eq!(*t, false)
|
||||
}
|
||||
for t in &test_masks[1][11..22] {
|
||||
assert_eq!(*t, true)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_kfold_return_split_simple() {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
||||
|
||||
assert_eq!(train_test_splits[0].1, (0..11).collect::<Vec<usize>>());
|
||||
assert_eq!(train_test_splits[0].0, (11..22).collect::<Vec<usize>>());
|
||||
assert_eq!(train_test_splits[1].0, (0..11).collect::<Vec<usize>>());
|
||||
assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_kfold_return_split_simple_shuffle() {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
..KFold::default()
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(23, 100);
|
||||
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
||||
|
||||
assert_eq!(train_test_splits[0].1.len(), 12_usize);
|
||||
assert_eq!(train_test_splits[0].0.len(), 11_usize);
|
||||
assert_eq!(train_test_splits[1].0.len(), 12_usize);
|
||||
assert_eq!(train_test_splits[1].1.len(), 11_usize);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn numpy_parity_test() {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
||||
(vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
|
||||
(vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
|
||||
(vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
|
||||
];
|
||||
for ((train, test), (expected_train, expected_test)) in
|
||||
k.split(&x).into_iter().zip(expected)
|
||||
{
|
||||
assert_eq!(test, expected_test);
|
||||
assert_eq!(train, expected_train);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn numpy_parity_test_shuffle() {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
..KFold::default()
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
||||
(vec![4, 5, 6, 7, 8, 9], vec![0, 1, 2, 3]),
|
||||
(vec![0, 1, 2, 3, 7, 8, 9], vec![4, 5, 6]),
|
||||
(vec![0, 1, 2, 3, 4, 5, 6], vec![7, 8, 9]),
|
||||
];
|
||||
for ((train, test), (expected_train, expected_test)) in
|
||||
k.split(&x).into_iter().zip(expected)
|
||||
{
|
||||
assert_eq!(test.len(), expected_test.len());
|
||||
assert_eq!(train.len(), expected_train.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
+389
-46
@@ -1,28 +1,140 @@
|
||||
//! # Model Selection methods
|
||||
//!
|
||||
//! In statistics and machine learning we usually split our data into multiple subsets: training data and testing data (and sometimes to validate),
|
||||
//! and fit our model on the train data, in order to make predictions on the test data. We do that to avoid overfitting or underfitting model to our data.
|
||||
//! In statistics and machine learning we usually split our data into two sets: one for training and the other one for testing.
|
||||
//! We fit our model to the training data, in order to make predictions on the test data. We do that to avoid overfitting or underfitting model to our data.
|
||||
//! Overfitting is bad because the model we trained fits trained data too well and can’t make any inferences on new data.
|
||||
//! Underfitted is bad because the model is undetrained and does not fit the training data well.
|
||||
//! Splitting data into multiple subsets helps to find the right combination of hyperparameters, estimate model performance and choose the right model for
|
||||
//! your data.
|
||||
//! Splitting data into multiple subsets helps us to find the right combination of hyperparameters, estimate model performance and choose the right model for
|
||||
//! the data.
|
||||
//!
|
||||
//! In SmartCore you can split your data into training and test datasets using `train_test_split` function.
|
||||
extern crate rand;
|
||||
//! In SmartCore a random split into training and test sets can be quickly computed with the [train_test_split](./fn.train_test_split.html) helper function.
|
||||
//!
|
||||
//! ```
|
||||
//! use crate::smartcore::linalg::BaseMatrix;
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use smartcore::model_selection::train_test_split;
|
||||
//!
|
||||
//! //Iris data
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
//!
|
||||
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
|
||||
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
|
||||
//! ```
|
||||
//!
|
||||
//! When we partition the available data into two disjoint sets, we drastically reduce the number of samples that can be used for training.
|
||||
//!
|
||||
//! One way to solve this problem is to use k-fold cross-validation. With k-fold validation, the dataset is split into k disjoint sets.
|
||||
//! A model is trained using k - 1 of the folds, and the resulting model is validated on the remaining portion of the data.
|
||||
//!
|
||||
//! The simplest way to run cross-validation is to use the [cross_val_score](./fn.cross_validate.html) helper function on your estimator and the dataset.
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use smartcore::model_selection::{KFold, cross_validate};
|
||||
//! use smartcore::metrics::accuracy;
|
||||
//! use smartcore::linear::logistic_regression::LogisticRegression;
|
||||
//!
|
||||
//! //Iris data
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let cv = KFold::default().with_n_splits(3);
|
||||
//!
|
||||
//! let results = cross_validate(LogisticRegression::fit, //estimator
|
||||
//! &x, &y, //data
|
||||
//! Default::default(), //hyperparameters
|
||||
//! cv, //cross validation split
|
||||
//! &accuracy).unwrap(); //metric
|
||||
//!
|
||||
//! println!("Training accuracy: {}, test accuracy: {}",
|
||||
//! results.mean_test_score(), results.mean_train_score());
|
||||
//! ```
|
||||
//!
|
||||
//! The function [cross_val_predict](./fn.cross_val_predict.html) has a similar interface to `cross_val_score`,
|
||||
//! but instead of test error it calculates predictions for all samples in the test set.
|
||||
|
||||
use crate::api::Predictor;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use rand::Rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
pub(crate) mod kfold;
|
||||
|
||||
pub use kfold::{KFold, KFoldIter};
|
||||
|
||||
/// An interface for the K-Folds cross-validator
|
||||
pub trait BaseKFold {
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
type Output: Iterator<Item = (Vec<usize>, Vec<usize>)>;
|
||||
/// Return a tuple containing the the training set indices for that split and
|
||||
/// the testing set indices for that split.
|
||||
fn split<T: RealNumber, M: Matrix<T>>(&self, x: &M) -> Self::Output;
|
||||
/// Returns the number of splits
|
||||
fn n_splits(&self) -> usize;
|
||||
}
|
||||
|
||||
/// Splits data into 2 disjoint datasets.
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _M_
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
||||
/// * `shuffle`, - whether or not to shuffle the data before splitting
|
||||
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
test_size: f32,
|
||||
shuffle: bool,
|
||||
) -> (M, M, M::RowVector, M::RowVector) {
|
||||
if x.shape().0 != y.len() {
|
||||
panic!(
|
||||
@@ -37,63 +149,150 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
}
|
||||
|
||||
let n = y.len();
|
||||
let m = x.shape().1;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut n_test = 0;
|
||||
let mut index = vec![false; n];
|
||||
let n_test = ((n as f32) * test_size) as usize;
|
||||
|
||||
for i in 0..n {
|
||||
let p_test: f32 = rng.gen();
|
||||
if p_test <= test_size {
|
||||
index[i] = true;
|
||||
n_test += 1;
|
||||
}
|
||||
if n_test < 1 {
|
||||
panic!("number of sample is too small {}", n);
|
||||
}
|
||||
|
||||
let n_train = n - n_test;
|
||||
let mut indices: Vec<usize> = (0..n).collect();
|
||||
|
||||
let mut x_train = M::zeros(n_train, m);
|
||||
let mut x_test = M::zeros(n_test, m);
|
||||
let mut y_train = M::RowVector::zeros(n_train);
|
||||
let mut y_test = M::RowVector::zeros(n_test);
|
||||
|
||||
let mut r_train = 0;
|
||||
let mut r_test = 0;
|
||||
|
||||
for r in 0..n {
|
||||
if index[r] {
|
||||
//sample belongs to test
|
||||
for c in 0..m {
|
||||
x_test.set(r_test, c, x.get(r, c));
|
||||
y_test.set(r_test, y.get(r));
|
||||
}
|
||||
r_test += 1;
|
||||
} else {
|
||||
for c in 0..m {
|
||||
x_train.set(r_train, c, x.get(r, c));
|
||||
y_train.set(r_train, y.get(r));
|
||||
}
|
||||
r_train += 1;
|
||||
}
|
||||
if shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
}
|
||||
|
||||
let x_train = x.take(&indices[n_test..n], 0);
|
||||
let x_test = x.take(&indices[0..n_test], 0);
|
||||
let y_train = y.take(&indices[n_test..n]);
|
||||
let y_test = y.take(&indices[0..n_test]);
|
||||
|
||||
(x_train, x_test, y_train, y_test)
|
||||
}
|
||||
|
||||
/// Cross validation results.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CrossValidationResult<T: RealNumber> {
|
||||
/// Vector with test scores on each cv split
|
||||
pub test_score: Vec<T>,
|
||||
/// Vector with training scores on each cv split
|
||||
pub train_score: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> CrossValidationResult<T> {
|
||||
/// Average test score
|
||||
pub fn mean_test_score(&self) -> T {
|
||||
self.test_score.sum() / T::from_usize(self.test_score.len()).unwrap()
|
||||
}
|
||||
/// Average training score
|
||||
pub fn mean_train_score(&self) -> T {
|
||||
self.train_score.sum() / T::from_usize(self.train_score.len()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate an estimator by cross-validation using given metric.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
|
||||
pub fn cross_validate<T, M, H, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: H,
|
||||
cv: K,
|
||||
score: S,
|
||||
) -> Result<CrossValidationResult<T>, Failed>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
H: Clone,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
{
|
||||
let k = cv.n_splits();
|
||||
let mut test_score = Vec::with_capacity(k);
|
||||
let mut train_score = Vec::with_capacity(k);
|
||||
|
||||
for (train_idx, test_idx) in cv.split(x) {
|
||||
let train_x = x.take(&train_idx, 0);
|
||||
let train_y = y.take(&train_idx);
|
||||
let test_x = x.take(&test_idx, 0);
|
||||
let test_y = y.take(&test_idx);
|
||||
|
||||
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
|
||||
|
||||
train_score.push(score(&train_y, &estimator.predict(&train_x)?));
|
||||
test_score.push(score(&test_y, &estimator.predict(&test_x)?));
|
||||
}
|
||||
|
||||
Ok(CrossValidationResult {
|
||||
test_score,
|
||||
train_score,
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate cross-validated estimates for each input data point.
|
||||
/// The data is split according to the cv parameter. Each sample belongs to exactly one test set, and its prediction is computed with an estimator fitted on the corresponding training set.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameters` - parameters of selected estimator. Use `Default::default()` for default parameters.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
pub fn cross_val_predict<T, M, H, E, K, F>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: H,
|
||||
cv: K,
|
||||
) -> Result<M::RowVector, Failed>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
H: Clone,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, H) -> Result<E, Failed>,
|
||||
{
|
||||
let mut y_hat = M::RowVector::zeros(y.len());
|
||||
|
||||
for (train_idx, test_idx) in cv.split(x) {
|
||||
let train_x = x.take(&train_idx, 0);
|
||||
let train_y = y.take(&train_idx);
|
||||
let test_x = x.take(&test_idx, 0);
|
||||
|
||||
let estimator = fit_estimator(&train_x, &train_y, parameters.clone())?;
|
||||
|
||||
let y_test_hat = estimator.predict(&test_x)?;
|
||||
for (i, &idx) in test_idx.iter().enumerate() {
|
||||
y_hat.set(idx, y_test_hat.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::{accuracy, mean_absolute_error};
|
||||
use crate::model_selection::kfold::KFold;
|
||||
use crate::neighbors::knn_regressor::KNNRegressor;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_train_test_split() {
|
||||
let n = 100;
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(100, 3);
|
||||
let y = vec![0f64; 100];
|
||||
let n = 123;
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
|
||||
let y = vec![0f64; n];
|
||||
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2);
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
|
||||
assert!(
|
||||
x_train.shape().0 > (n as f64 * 0.65) as usize
|
||||
@@ -106,4 +305,148 @@ mod tests {
|
||||
assert_eq!(x_train.shape().0, y_train.len());
|
||||
assert_eq!(x_test.shape().0, y_test.len());
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct NoParameters {}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_cross_validate_biased() {
|
||||
struct BiasedEstimator {}
|
||||
|
||||
impl BiasedEstimator {
|
||||
fn fit<M: Matrix<f32>>(
|
||||
_: &M,
|
||||
_: &M::RowVector,
|
||||
_: NoParameters,
|
||||
) -> Result<BiasedEstimator, Failed> {
|
||||
Ok(BiasedEstimator {})
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Matrix<f32>> Predictor<M, M::RowVector> for BiasedEstimator {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
Ok(M::RowVector::zeros(n))
|
||||
}
|
||||
}
|
||||
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 5,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results =
|
||||
cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap();
|
||||
|
||||
assert_eq!(0.4, results.mean_test_score());
|
||||
assert_eq!(0.4, results.mean_train_score());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_cross_validate_knn() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
&[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 5,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results = cross_validate(
|
||||
KNNRegressor::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&mean_absolute_error,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!(results.mean_test_score() < 15.0);
|
||||
assert!(results.mean_train_score() < results.mean_test_score());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_cross_val_predict_knn() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||
&[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 2,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let y_hat = cross_val_predict(KNNRegressor::fit, &x, &y, Default::default(), cv).unwrap();
|
||||
|
||||
assert!(mean_absolute_error(&y, &y_hat) < 10.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,484 @@
|
||||
//! # Bernoulli Naive Bayes
|
||||
//!
|
||||
//! Bernoulli Naive Bayes classifier is a variant of [Naive Bayes](../index.html) for the data that is distributed according to multivariate Bernoulli distribution.
|
||||
//! It is used for discrete data with binary features. One example of a binary feature is a word that occurs in the text or not.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::naive_bayes::bernoulli::BernoulliNB;
|
||||
//!
|
||||
//! // Training data points are:
|
||||
//! // Chinese Beijing Chinese (class: China)
|
||||
//! // Chinese Chinese Shanghai (class: China)
|
||||
//! // Chinese Macao (class: China)
|
||||
//! // Tokyo Japan Chinese (class: Japan)
|
||||
//! let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
//! &[1., 1., 0., 0., 0., 0.],
|
||||
//! &[0., 1., 0., 0., 1., 0.],
|
||||
//! &[0., 1., 0., 1., 0., 0.],
|
||||
//! &[0., 1., 1., 0., 0., 1.],
|
||||
//! ]);
|
||||
//! let y = vec![0., 0., 0., 1.];
|
||||
//!
|
||||
//! let nb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Testing data point is:
|
||||
//! // Chinese Chinese Chinese Tokyo Japan
|
||||
//! let x_test = DenseMatrix::<f64>::from_2d_array(&[&[0., 1., 1., 0., 0., 1.]]);
|
||||
//! let y_hat = nb.predict(&x_test).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html)
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::row_iter;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for Bearnoulli features
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct BernoulliNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
/// number of training samples observed in each class
|
||||
class_count: Vec<usize>,
|
||||
/// probability of each class
|
||||
class_priors: Vec<T>,
|
||||
/// Number of samples encountered for each (class, feature)
|
||||
feature_count: Vec<Vec<usize>>,
|
||||
/// probability of features per class
|
||||
feature_log_prob: Vec<Vec<T>>,
|
||||
/// Number of features of each sample
|
||||
n_features: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for BernoulliNBDistribution<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.class_labels == other.class_labels
|
||||
&& self.class_count == other.class_count
|
||||
&& self.class_priors == other.class_priors
|
||||
&& self.feature_count == other.feature_count
|
||||
&& self.n_features == other.n_features
|
||||
{
|
||||
for (a, b) in self
|
||||
.feature_log_prob
|
||||
.iter()
|
||||
.zip(other.feature_log_prob.iter())
|
||||
{
|
||||
if !a.approximate_eq(b, T::epsilon()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
|
||||
fn prior(&self, class_index: usize) -> T {
|
||||
self.class_priors[class_index]
|
||||
}
|
||||
|
||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
||||
let mut likelihood = T::zero();
|
||||
for feature in 0..j.len() {
|
||||
let value = j.get(feature);
|
||||
if value == T::one() {
|
||||
likelihood += self.feature_log_prob[class_index][feature];
|
||||
} else {
|
||||
likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln();
|
||||
}
|
||||
}
|
||||
likelihood
|
||||
}
|
||||
|
||||
fn classes(&self) -> &Vec<T> {
|
||||
&self.class_labels
|
||||
}
|
||||
}
|
||||
|
||||
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BernoulliNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
|
||||
pub binarize: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BernoulliNBParameters<T> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub fn with_priors(mut self, priors: Vec<T>) -> Self {
|
||||
self.priors = Some(priors);
|
||||
self
|
||||
}
|
||||
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
|
||||
pub fn with_binarize(mut self, binarize: T) -> Self {
|
||||
self.binarize = Some(binarize);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for BernoulliNBParameters<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
alpha: T::one(),
|
||||
priors: None,
|
||||
binarize: Some(T::zero()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BernoulliNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||
/// * `binarize` - Threshold for binarizing.
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
alpha: T,
|
||||
priors: Option<Vec<T>>,
|
||||
) -> Result<Self, Failed> {
|
||||
let (n_samples, n_features) = x.shape();
|
||||
let y_samples = y.len();
|
||||
if y_samples != n_samples {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
n_samples, y_samples
|
||||
)));
|
||||
}
|
||||
|
||||
if n_samples == 0 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x and y should greater than 0; |x|=[{}]",
|
||||
n_samples
|
||||
)));
|
||||
}
|
||||
if alpha < T::zero() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Alpha should be greater than 0; |alpha|=[{}]",
|
||||
alpha
|
||||
)));
|
||||
}
|
||||
|
||||
let y = y.to_vec();
|
||||
|
||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||
let mut class_count = vec![0_usize; class_labels.len()];
|
||||
|
||||
for class_index in indices.iter() {
|
||||
class_count[*class_index] += 1;
|
||||
}
|
||||
|
||||
let class_priors = if let Some(class_priors) = priors {
|
||||
if class_priors.len() != class_labels.len() {
|
||||
return Err(Failed::fit(
|
||||
"Size of priors provided does not match the number of classes of the data.",
|
||||
));
|
||||
}
|
||||
class_priors
|
||||
} else {
|
||||
class_count
|
||||
.iter()
|
||||
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
|
||||
|
||||
for (row, class_index) in row_iter(x).zip(indices) {
|
||||
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
||||
feature_in_class_counter[class_index][idx] +=
|
||||
row_i.to_usize().ok_or_else(|| {
|
||||
Failed::fit(&format!(
|
||||
"Elements of the matrix should be 1.0 or 0.0 |found|=[{}]",
|
||||
row_i
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let feature_log_prob = feature_in_class_counter
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(class_index, feature_count)| {
|
||||
feature_count
|
||||
.iter()
|
||||
.map(|&count| {
|
||||
((T::from(count).unwrap() + alpha)
|
||||
/ (T::from(class_count[class_index]).unwrap() + alpha * T::two()))
|
||||
.ln()
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
class_labels,
|
||||
class_priors,
|
||||
class_count,
|
||||
feature_count: feature_in_class_counter,
|
||||
feature_log_prob,
|
||||
n_features,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// BernoulliNB implements the naive Bayes algorithm for data that follows the Bernoulli
|
||||
/// distribution.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
||||
binarize: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, BernoulliNBParameters<T>>
|
||||
for BernoulliNB<T, M>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: BernoulliNBParameters<T>) -> Result<Self, Failed> {
|
||||
BernoulliNB::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for BernoulliNB<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
|
||||
/// Fits BernoulliNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors, alpha for smoothing and
|
||||
/// binarizing threshold.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: BernoulliNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
let distribution = if let Some(threshold) = parameters.binarize {
|
||||
BernoulliNBDistribution::fit(
|
||||
&(x.binarize(threshold)),
|
||||
y,
|
||||
parameters.alpha,
|
||||
parameters.priors,
|
||||
)?
|
||||
} else {
|
||||
BernoulliNBDistribution::fit(x, y, parameters.alpha, parameters.priors)?
|
||||
};
|
||||
|
||||
let inner = BaseNaiveBayes::fit(distribution)?;
|
||||
Ok(Self {
|
||||
inner,
|
||||
binarize: parameters.binarize,
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
if let Some(threshold) = self.binarize {
|
||||
self.inner.predict(&(x.binarize(threshold)))
|
||||
} else {
|
||||
self.inner.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
/// Class labels known to the classifier.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn classes(&self) -> &Vec<T> {
|
||||
&self.inner.distribution.class_labels
|
||||
}
|
||||
|
||||
/// Number of training samples observed in each class.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn class_count(&self) -> &Vec<usize> {
|
||||
&self.inner.distribution.class_count
|
||||
}
|
||||
|
||||
/// Number of features of each sample
|
||||
pub fn n_features(&self) -> usize {
|
||||
self.inner.distribution.n_features
|
||||
}
|
||||
|
||||
/// Number of samples encountered for each (class, feature)
|
||||
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
|
||||
&self.inner.distribution.feature_count
|
||||
}
|
||||
|
||||
/// Empirical log probability of features given a class
|
||||
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
|
||||
&self.inner.distribution.feature_log_prob
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_bernoulli_naive_bayes() {
|
||||
// Tests that BernoulliNB when alpha=1.0 gives the same values as
|
||||
// those given for the toy example in Manning, Raghavan, and
|
||||
// Schuetze's "Introduction to Information Retrieval" book:
|
||||
// https://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html
|
||||
|
||||
// Training data points are:
|
||||
// Chinese Beijing Chinese (class: China)
|
||||
// Chinese Chinese Shanghai (class: China)
|
||||
// Chinese Macao (class: China)
|
||||
// Tokyo Japan Chinese (class: Japan)
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 1., 0., 0., 0., 0.],
|
||||
&[0., 1., 0., 0., 1., 0.],
|
||||
&[0., 1., 0., 1., 0., 0.],
|
||||
&[0., 1., 1., 0., 0., 1.],
|
||||
]);
|
||||
let y = vec![0., 0., 0., 1.];
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
||||
assert_eq!(
|
||||
bnb.feature_log_prob(),
|
||||
&[
|
||||
&[
|
||||
-0.916290731874155,
|
||||
-0.2231435513142097,
|
||||
-1.6094379124341003,
|
||||
-0.916290731874155,
|
||||
-0.916290731874155,
|
||||
-1.6094379124341003
|
||||
],
|
||||
&[
|
||||
-1.0986122886681098,
|
||||
-0.40546510810816444,
|
||||
-0.40546510810816444,
|
||||
-1.0986122886681098,
|
||||
-1.0986122886681098,
|
||||
-0.40546510810816444
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Testing data point is:
|
||||
// Chinese Chinese Chinese Tokyo Japan
|
||||
let x_test = DenseMatrix::<f64>::from_2d_array(&[&[0., 1., 1., 0., 0., 1.]]);
|
||||
let y_hat = bnb.predict(&x_test).unwrap();
|
||||
|
||||
assert_eq!(y_hat, &[1.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn bernoulli_nb_scikit_parity() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[2., 4., 0., 0., 2., 1., 2., 4., 2., 0.],
|
||||
&[3., 4., 0., 2., 1., 0., 1., 4., 0., 3.],
|
||||
&[1., 4., 2., 4., 1., 0., 1., 2., 3., 2.],
|
||||
&[0., 3., 3., 4., 1., 0., 3., 1., 1., 1.],
|
||||
&[0., 2., 1., 4., 3., 4., 1., 2., 3., 1.],
|
||||
&[3., 2., 4., 1., 3., 0., 2., 4., 0., 2.],
|
||||
&[3., 1., 3., 0., 2., 0., 4., 4., 3., 4.],
|
||||
&[2., 2., 2., 0., 1., 1., 2., 1., 0., 1.],
|
||||
&[3., 3., 2., 2., 0., 2., 3., 2., 2., 3.],
|
||||
&[4., 3., 4., 4., 4., 2., 2., 0., 1., 4.],
|
||||
&[3., 4., 2., 2., 1., 4., 4., 4., 1., 3.],
|
||||
&[3., 0., 1., 4., 4., 0., 0., 3., 2., 4.],
|
||||
&[2., 0., 3., 3., 1., 2., 0., 2., 4., 1.],
|
||||
&[2., 4., 0., 4., 2., 4., 1., 3., 1., 4.],
|
||||
&[0., 2., 2., 3., 4., 0., 4., 4., 4., 4.],
|
||||
]);
|
||||
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let y_hat = bnb.predict(&x).unwrap();
|
||||
|
||||
assert_eq!(bnb.classes(), &[0., 1., 2.]);
|
||||
assert_eq!(bnb.class_count(), &[7, 3, 5]);
|
||||
assert_eq!(bnb.n_features(), 10);
|
||||
assert_eq!(
|
||||
bnb.feature_count(),
|
||||
&[
|
||||
&[5, 6, 6, 7, 6, 4, 6, 7, 7, 7],
|
||||
&[3, 3, 3, 1, 3, 2, 3, 2, 2, 3],
|
||||
&[4, 4, 3, 4, 5, 2, 4, 5, 3, 4]
|
||||
]
|
||||
);
|
||||
|
||||
assert!(bnb
|
||||
.inner
|
||||
.distribution
|
||||
.class_priors
|
||||
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
||||
assert!(bnb.feature_log_prob()[1].approximate_eq(
|
||||
&vec![
|
||||
-0.22314355,
|
||||
-0.22314355,
|
||||
-0.22314355,
|
||||
-0.91629073,
|
||||
-0.22314355,
|
||||
-0.51082562,
|
||||
-0.22314355,
|
||||
-0.51082562,
|
||||
-0.51082562,
|
||||
-0.22314355
|
||||
],
|
||||
1e-1
|
||||
));
|
||||
assert!(y_hat.approximate_eq(
|
||||
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
|
||||
1e-5
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 1., 0., 0., 0., 0.],
|
||||
&[0., 1., 0., 0., 1., 0.],
|
||||
&[0., 1., 0., 1., 0., 0.],
|
||||
&[0., 1., 1., 0., 0., 1.],
|
||||
]);
|
||||
let y = vec![0., 0., 0., 1.];
|
||||
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
let deserialized_bnb: BernoulliNB<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&bnb).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(bnb, deserialized_bnb);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,493 @@
|
||||
//! # Categorical Naive Bayes
|
||||
//!
|
||||
//! Categorical Naive Bayes is a variant of [Naive Bayes](../index.html) for the categorically distributed data.
|
||||
//! It assumes that each feature has its own categorical distribution.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::naive_bayes::categorical::CategoricalNB;
|
||||
//!
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[3., 4., 0., 1.],
|
||||
//! &[3., 0., 0., 1.],
|
||||
//! &[4., 4., 1., 2.],
|
||||
//! &[4., 2., 4., 3.],
|
||||
//! &[4., 2., 4., 2.],
|
||||
//! &[4., 1., 1., 0.],
|
||||
//! &[1., 1., 1., 1.],
|
||||
//! &[0., 4., 1., 0.],
|
||||
//! &[0., 3., 2., 1.],
|
||||
//! &[0., 3., 1., 1.],
|
||||
//! &[3., 4., 0., 1.],
|
||||
//! &[3., 4., 2., 4.],
|
||||
//! &[0., 3., 1., 2.],
|
||||
//! &[0., 4., 1., 2.],
|
||||
//! ]);
|
||||
//! let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||
//!
|
||||
//! let nb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = nb.predict(&x).unwrap();
|
||||
//! ```
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for categorical features
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct CategoricalNBDistribution<T: RealNumber> {
|
||||
/// number of training samples observed in each class
|
||||
class_count: Vec<usize>,
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
/// probability of each class
|
||||
class_priors: Vec<T>,
|
||||
coefficients: Vec<Vec<Vec<T>>>,
|
||||
/// Number of features of each sample
|
||||
n_features: usize,
|
||||
/// Number of categories for each feature
|
||||
n_categories: Vec<usize>,
|
||||
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||
/// for each feature. Each array provides the number of samples
|
||||
/// encountered for each class and category of the specific feature.
|
||||
category_count: Vec<Vec<Vec<usize>>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.class_labels == other.class_labels
|
||||
&& self.class_priors == other.class_priors
|
||||
&& self.n_features == other.n_features
|
||||
&& self.n_categories == other.n_categories
|
||||
&& self.class_count == other.class_count
|
||||
{
|
||||
if self.coefficients.len() != other.coefficients.len() {
|
||||
return false;
|
||||
}
|
||||
for (a, b) in self.coefficients.iter().zip(other.coefficients.iter()) {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
for (a_i, b_i) in a.iter().zip(b.iter()) {
|
||||
if a_i.len() != b_i.len() {
|
||||
return false;
|
||||
}
|
||||
for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
|
||||
if (*a_i_j - *b_i_j).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribution<T> {
|
||||
fn prior(&self, class_index: usize) -> T {
|
||||
if class_index >= self.class_labels.len() {
|
||||
T::zero()
|
||||
} else {
|
||||
self.class_priors[class_index]
|
||||
}
|
||||
}
|
||||
|
||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
||||
if class_index < self.class_labels.len() {
|
||||
let mut likelihood = T::zero();
|
||||
for feature in 0..j.len() {
|
||||
let value = j.get(feature).floor().to_usize().unwrap();
|
||||
if self.coefficients[feature][class_index].len() > value {
|
||||
likelihood += self.coefficients[feature][class_index][value];
|
||||
} else {
|
||||
return T::zero();
|
||||
}
|
||||
}
|
||||
likelihood
|
||||
} else {
|
||||
T::zero()
|
||||
}
|
||||
}
|
||||
|
||||
fn classes(&self) -> &Vec<T> {
|
||||
&self.class_labels
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> CategoricalNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, alpha: T) -> Result<Self, Failed> {
|
||||
if alpha < T::zero() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"alpha should be >= 0, alpha=[{}]",
|
||||
alpha
|
||||
)));
|
||||
}
|
||||
|
||||
let (n_samples, n_features) = x.shape();
|
||||
let y_samples = y.len();
|
||||
if y_samples != n_samples {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
n_samples, y_samples
|
||||
)));
|
||||
}
|
||||
|
||||
if n_samples == 0 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x and y should greater than 0; |x|=[{}]",
|
||||
n_samples
|
||||
)));
|
||||
}
|
||||
let y: Vec<usize> = y
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|y_i| y_i.floor().to_usize().unwrap())
|
||||
.collect();
|
||||
|
||||
let y_max = y
|
||||
.iter()
|
||||
.max()
|
||||
.ok_or_else(|| Failed::fit("Failed to get the labels of y."))?;
|
||||
|
||||
let class_labels: Vec<T> = (0..*y_max + 1)
|
||||
.map(|label| T::from(label).unwrap())
|
||||
.collect();
|
||||
let mut class_count = vec![0_usize; class_labels.len()];
|
||||
for elem in y.iter() {
|
||||
class_count[*elem] += 1;
|
||||
}
|
||||
|
||||
let mut n_categories: Vec<usize> = Vec::with_capacity(n_features);
|
||||
for feature in 0..n_features {
|
||||
let feature_max = x
|
||||
.get_col_as_vec(feature)
|
||||
.iter()
|
||||
.map(|f_i| f_i.floor().to_usize().unwrap())
|
||||
.max()
|
||||
.ok_or_else(|| {
|
||||
Failed::fit(&format!(
|
||||
"Failed to get the categories for feature = {}",
|
||||
feature
|
||||
))
|
||||
})?;
|
||||
n_categories.push(feature_max + 1);
|
||||
}
|
||||
|
||||
let mut coefficients: Vec<Vec<Vec<T>>> = Vec::with_capacity(class_labels.len());
|
||||
let mut category_count: Vec<Vec<Vec<usize>>> = Vec::with_capacity(class_labels.len());
|
||||
for (feature_index, &n_categories_i) in n_categories.iter().enumerate().take(n_features) {
|
||||
let mut coef_i: Vec<Vec<T>> = Vec::with_capacity(n_features);
|
||||
let mut category_count_i: Vec<Vec<usize>> = Vec::with_capacity(n_features);
|
||||
for (label, &label_count) in class_labels.iter().zip(class_count.iter()) {
|
||||
let col = x
|
||||
.get_col_as_vec(feature_index)
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(i, _j)| T::from(y[*i]).unwrap() == *label)
|
||||
.map(|(_, j)| *j)
|
||||
.collect::<Vec<T>>();
|
||||
let mut feat_count: Vec<usize> = vec![0_usize; n_categories_i];
|
||||
for row in col.iter() {
|
||||
let index = row.floor().to_usize().unwrap();
|
||||
feat_count[index] += 1;
|
||||
}
|
||||
|
||||
let coef_i_j = feat_count
|
||||
.iter()
|
||||
.map(|c| {
|
||||
((T::from(*c).unwrap() + alpha)
|
||||
/ (T::from(label_count).unwrap()
|
||||
+ T::from(n_categories_i).unwrap() * alpha))
|
||||
.ln()
|
||||
})
|
||||
.collect::<Vec<T>>();
|
||||
category_count_i.push(feat_count);
|
||||
coef_i.push(coef_i_j);
|
||||
}
|
||||
category_count.push(category_count_i);
|
||||
coefficients.push(coef_i);
|
||||
}
|
||||
|
||||
let class_priors = class_count
|
||||
.iter()
|
||||
.map(|&count| T::from(count).unwrap() / T::from(n_samples).unwrap())
|
||||
.collect::<Vec<T>>();
|
||||
|
||||
Ok(Self {
|
||||
class_count,
|
||||
class_labels,
|
||||
class_priors,
|
||||
coefficients,
|
||||
n_features,
|
||||
n_categories,
|
||||
category_count,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoricalNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> CategoricalNBParameters<T> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
||||
fn default() -> Self {
|
||||
Self { alpha: T::one() }
|
||||
}
|
||||
}
|
||||
|
||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, CategoricalNBParameters<T>>
|
||||
for CategoricalNB<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: CategoricalNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
CategoricalNB::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for CategoricalNB<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> CategoricalNB<T, M> {
|
||||
/// Fits CategoricalNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like alpha for smoothing
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: CategoricalNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
let alpha = parameters.alpha;
|
||||
let distribution = CategoricalNBDistribution::fit(x, y, alpha)?;
|
||||
let inner = BaseNaiveBayes::fit(distribution)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.inner.predict(x)
|
||||
}
|
||||
|
||||
/// Class labels known to the classifier.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn classes(&self) -> &Vec<T> {
|
||||
&self.inner.distribution.class_labels
|
||||
}
|
||||
|
||||
/// Number of training samples observed in each class.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn class_count(&self) -> &Vec<usize> {
|
||||
&self.inner.distribution.class_count
|
||||
}
|
||||
|
||||
/// Number of features of each sample
|
||||
pub fn n_features(&self) -> usize {
|
||||
self.inner.distribution.n_features
|
||||
}
|
||||
|
||||
/// Number of features of each sample
|
||||
pub fn n_categories(&self) -> &Vec<usize> {
|
||||
&self.inner.distribution.n_categories
|
||||
}
|
||||
|
||||
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||
/// for each feature. Each array provides the number of samples
|
||||
/// encountered for each class and category of the specific feature.
|
||||
pub fn category_count(&self) -> &Vec<Vec<Vec<usize>>> {
|
||||
&self.inner.distribution.category_count
|
||||
}
|
||||
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||
/// for each feature. Each array provides the empirical log probability
|
||||
/// of categories given the respective feature and class, ``P(x_i|y)``.
|
||||
pub fn feature_log_prob(&self) -> &Vec<Vec<Vec<T>>> {
|
||||
&self.inner.distribution.coefficients
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_categorical_naive_bayes() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[0., 2., 1., 0.],
|
||||
&[0., 2., 1., 1.],
|
||||
&[1., 2., 1., 0.],
|
||||
&[2., 1., 1., 0.],
|
||||
&[2., 0., 0., 0.],
|
||||
&[2., 0., 0., 1.],
|
||||
&[1., 0., 0., 1.],
|
||||
&[0., 1., 1., 0.],
|
||||
&[0., 0., 0., 0.],
|
||||
&[2., 1., 0., 0.],
|
||||
&[0., 1., 0., 1.],
|
||||
&[1., 1., 1., 1.],
|
||||
&[1., 2., 0., 0.],
|
||||
&[2., 1., 1., 1.],
|
||||
]);
|
||||
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
// checking parity with scikit
|
||||
assert_eq!(cnb.classes(), &[0., 1.]);
|
||||
assert_eq!(cnb.class_count(), &[5, 9]);
|
||||
assert_eq!(cnb.n_features(), 4);
|
||||
assert_eq!(cnb.n_categories(), &[3, 3, 2, 2]);
|
||||
assert_eq!(
|
||||
cnb.category_count(),
|
||||
&vec![
|
||||
vec![vec![3, 0, 2], vec![2, 4, 3]],
|
||||
vec![vec![1, 2, 2], vec![3, 4, 2]],
|
||||
vec![vec![1, 4], vec![6, 3]],
|
||||
vec![vec![2, 3], vec![6, 3]]
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
cnb.feature_log_prob(),
|
||||
&vec![
|
||||
vec![
|
||||
vec![
|
||||
-0.6931471805599453,
|
||||
-2.0794415416798357,
|
||||
-0.9808292530117262
|
||||
],
|
||||
vec![
|
||||
-1.3862943611198906,
|
||||
-0.8754687373538999,
|
||||
-1.0986122886681098
|
||||
]
|
||||
],
|
||||
vec![
|
||||
vec![
|
||||
-1.3862943611198906,
|
||||
-0.9808292530117262,
|
||||
-0.9808292530117262
|
||||
],
|
||||
vec![
|
||||
-1.0986122886681098,
|
||||
-0.8754687373538999,
|
||||
-1.3862943611198906
|
||||
]
|
||||
],
|
||||
vec![
|
||||
vec![-1.252762968495368, -0.3364722366212129],
|
||||
vec![-0.45198512374305727, -1.0116009116784799]
|
||||
],
|
||||
vec![
|
||||
vec![-0.8472978603872037, -0.5596157879354228],
|
||||
vec![-0.45198512374305727, -1.0116009116784799]
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]);
|
||||
let y_hat = cnb.predict(&x_test).unwrap();
|
||||
assert_eq!(y_hat, vec![0., 1.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_categorical_naive_bayes2() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 0., 0., 1.],
|
||||
&[4., 4., 1., 2.],
|
||||
&[4., 2., 4., 3.],
|
||||
&[4., 2., 4., 2.],
|
||||
&[4., 1., 1., 0.],
|
||||
&[1., 1., 1., 1.],
|
||||
&[0., 4., 1., 0.],
|
||||
&[0., 3., 2., 1.],
|
||||
&[0., 3., 1., 1.],
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 4., 2., 4.],
|
||||
&[0., 3., 1., 2.],
|
||||
&[0., 4., 1., 2.],
|
||||
]);
|
||||
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = cnb.predict(&x).unwrap();
|
||||
assert_eq!(
|
||||
y_hat,
|
||||
vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 1.]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 0., 0., 1.],
|
||||
&[4., 4., 1., 2.],
|
||||
&[4., 2., 4., 3.],
|
||||
&[4., 2., 4., 2.],
|
||||
&[4., 1., 1., 0.],
|
||||
&[1., 1., 1., 1.],
|
||||
&[0., 4., 1., 0.],
|
||||
&[0., 3., 2., 1.],
|
||||
&[0., 3., 1., 1.],
|
||||
&[3., 4., 0., 1.],
|
||||
&[3., 4., 2., 4.],
|
||||
&[0., 3., 1., 2.],
|
||||
&[0., 4., 1., 2.],
|
||||
]);
|
||||
|
||||
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_cnb: CategoricalNB<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&cnb).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(cnb, deserialized_cnb);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,340 @@
|
||||
//! # Gaussian Naive Bayes
|
||||
//!
|
||||
//! Gaussian Naive Bayes is a variant of [Naive Bayes](../index.html) for the data that follows Gaussian distribution and
|
||||
//! it supports continuous valued features conforming to a normal distribution.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::naive_bayes::gaussian::GaussianNB;
|
||||
//!
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[-1., -1.],
|
||||
//! &[-2., -1.],
|
||||
//! &[-3., -2.],
|
||||
//! &[ 1., 1.],
|
||||
//! &[ 2., 1.],
|
||||
//! &[ 3., 2.],
|
||||
//! ]);
|
||||
//! let y = vec![1., 1., 1., 2., 2., 2.];
|
||||
//!
|
||||
//! let nb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = nb.predict(&x).unwrap();
|
||||
//! ```
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::row_iter;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier using Gaussian distribution
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct GaussianNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
/// number of training samples observed in each class
|
||||
class_count: Vec<usize>,
|
||||
/// probability of each class.
|
||||
class_priors: Vec<T>,
|
||||
/// variance of each feature per class
|
||||
var: Vec<Vec<T>>,
|
||||
/// mean of each feature per class
|
||||
theta: Vec<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistribution<T> {
|
||||
fn prior(&self, class_index: usize) -> T {
|
||||
if class_index >= self.class_labels.len() {
|
||||
T::zero()
|
||||
} else {
|
||||
self.class_priors[class_index]
|
||||
}
|
||||
}
|
||||
|
||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
||||
let mut likelihood = T::zero();
|
||||
for feature in 0..j.len() {
|
||||
let value = j.get(feature);
|
||||
let mean = self.theta[class_index][feature];
|
||||
let variance = self.var[class_index][feature];
|
||||
likelihood += self.calculate_log_probability(value, mean, variance);
|
||||
}
|
||||
likelihood
|
||||
}
|
||||
|
||||
fn classes(&self) -> &Vec<T> {
|
||||
&self.class_labels
|
||||
}
|
||||
}
|
||||
|
||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct GaussianNBParameters<T: RealNumber> {
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> GaussianNBParameters<T> {
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub fn with_priors(mut self, priors: Vec<T>) -> Self {
|
||||
self.priors = Some(priors);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> GaussianNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
priors: Option<Vec<T>>,
|
||||
) -> Result<Self, Failed> {
|
||||
let (n_samples, n_features) = x.shape();
|
||||
let y_samples = y.len();
|
||||
if y_samples != n_samples {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
n_samples, y_samples
|
||||
)));
|
||||
}
|
||||
|
||||
if n_samples == 0 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x and y should greater than 0; |x|=[{}]",
|
||||
n_samples
|
||||
)));
|
||||
}
|
||||
let y = y.to_vec();
|
||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||
|
||||
let mut class_count = vec![0_usize; class_labels.len()];
|
||||
|
||||
let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()];
|
||||
|
||||
for (row, class_index) in row_iter(x).zip(indices.iter()) {
|
||||
class_count[*class_index] += 1;
|
||||
subdataset[*class_index].push(row);
|
||||
}
|
||||
|
||||
let class_priors = if let Some(class_priors) = priors {
|
||||
if class_priors.len() != class_labels.len() {
|
||||
return Err(Failed::fit(
|
||||
"Size of priors provided does not match the number of classes of the data.",
|
||||
));
|
||||
}
|
||||
class_priors
|
||||
} else {
|
||||
class_count
|
||||
.iter()
|
||||
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||
.collect()
|
||||
};
|
||||
|
||||
let subdataset: Vec<M> = subdataset
|
||||
.into_iter()
|
||||
.map(|v| {
|
||||
let mut m = M::zeros(v.len(), n_features);
|
||||
for (row_i, v_i) in v.iter().enumerate() {
|
||||
for (col_j, v_i_j) in v_i.iter().enumerate().take(n_features) {
|
||||
m.set(row_i, col_j, *v_i_j);
|
||||
}
|
||||
}
|
||||
m
|
||||
})
|
||||
.collect();
|
||||
|
||||
let (var, theta): (Vec<Vec<T>>, Vec<Vec<T>>) = subdataset
|
||||
.iter()
|
||||
.map(|data| (data.var(0), data.mean(0)))
|
||||
.unzip();
|
||||
|
||||
Ok(Self {
|
||||
class_labels,
|
||||
class_count,
|
||||
class_priors,
|
||||
var,
|
||||
theta,
|
||||
})
|
||||
}
|
||||
|
||||
/// Calculate probability of x equals to a value of a Gaussian distribution given its mean and its
|
||||
/// variance.
|
||||
fn calculate_log_probability(&self, value: T, mean: T, variance: T) -> T {
|
||||
let pi = T::from(std::f64::consts::PI).unwrap();
|
||||
-((value - mean).powf(T::two()) / (T::two() * variance))
|
||||
- (T::two() * pi).ln() / T::two()
|
||||
- (variance).ln() / T::two()
|
||||
}
|
||||
}
|
||||
|
||||
/// GaussianNB implements the naive Bayes algorithm for data that follows the Gaussian
|
||||
/// distribution.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, GaussianNBParameters<T>>
|
||||
for GaussianNB<T, M>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: GaussianNBParameters<T>) -> Result<Self, Failed> {
|
||||
GaussianNB::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for GaussianNB<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> GaussianNB<T, M> {
|
||||
/// Fits GaussianNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: GaussianNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
let distribution = GaussianNBDistribution::fit(x, y, parameters.priors)?;
|
||||
let inner = BaseNaiveBayes::fit(distribution)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.inner.predict(x)
|
||||
}
|
||||
|
||||
/// Class labels known to the classifier.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn classes(&self) -> &Vec<T> {
|
||||
&self.inner.distribution.class_labels
|
||||
}
|
||||
|
||||
/// Number of training samples observed in each class.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn class_count(&self) -> &Vec<usize> {
|
||||
&self.inner.distribution.class_count
|
||||
}
|
||||
|
||||
/// Probability of each class
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn class_priors(&self) -> &Vec<T> {
|
||||
&self.inner.distribution.class_priors
|
||||
}
|
||||
|
||||
/// Mean of each feature per class
|
||||
/// Returns a 2d vector of shape (n_classes, n_features).
|
||||
pub fn theta(&self) -> &Vec<Vec<T>> {
|
||||
&self.inner.distribution.theta
|
||||
}
|
||||
|
||||
/// Variance of each feature per class
|
||||
/// Returns a 2d vector of shape (n_classes, n_features).
|
||||
pub fn var(&self) -> &Vec<Vec<T>> {
|
||||
&self.inner.distribution.var
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_gaussian_naive_bayes() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[-1., -1.],
|
||||
&[-2., -1.],
|
||||
&[-3., -2.],
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
let y = vec![1., 1., 1., 2., 2., 2.];
|
||||
|
||||
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = gnb.predict(&x).unwrap();
|
||||
assert_eq!(y_hat, y);
|
||||
|
||||
assert_eq!(gnb.classes(), &[1., 2.]);
|
||||
|
||||
assert_eq!(gnb.class_count(), &[3, 3]);
|
||||
|
||||
assert_eq!(
|
||||
gnb.var(),
|
||||
&[
|
||||
&[0.666666666666667, 0.22222222222222232],
|
||||
&[0.666666666666667, 0.22222222222222232]
|
||||
]
|
||||
);
|
||||
|
||||
assert_eq!(gnb.class_priors(), &[0.5, 0.5]);
|
||||
|
||||
assert_eq!(
|
||||
gnb.theta(),
|
||||
&[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]]
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_gaussian_naive_bayes_with_priors() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[-1., -1.],
|
||||
&[-2., -1.],
|
||||
&[-3., -2.],
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
let y = vec![1., 1., 1., 2., 2., 2.];
|
||||
|
||||
let priors = vec![0.3, 0.7];
|
||||
let parameters = GaussianNBParameters::default().with_priors(priors.clone());
|
||||
let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();
|
||||
|
||||
assert_eq!(gnb.class_priors(), &priors);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[-1., -1.],
|
||||
&[-2., -1.],
|
||||
&[-3., -2.],
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
let y = vec![1., 1., 1., 2., 2., 2.];
|
||||
|
||||
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
let deserialized_gnb: GaussianNB<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&gnb).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(gnb, deserialized_gnb);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
//! # Naive Bayes
|
||||
//!
|
||||
//! Naive Bayes (NB) is a simple but powerful machine learning algorithm.
|
||||
//! Naive Bayes classifier is based on Bayes’ Theorem with an ssumption of conditional independence
|
||||
//! between every pair of features given the value of the class variable.
|
||||
//!
|
||||
//! Bayes’ theorem can be written as
|
||||
//!
|
||||
//! \\[ P(y | X) = \frac{P(y)P(X| y)}{P(X)} \\]
|
||||
//!
|
||||
//! where
|
||||
//!
|
||||
//! * \\(X = (x_1,...x_n)\\) represents the predictors.
|
||||
//! * \\(P(y | X)\\) is the probability of class _y_ given the data X
|
||||
//! * \\(P(X| y)\\) is the probability of data X given the class _y_.
|
||||
//! * \\(P(y)\\) is the probability of class y. This is called the prior probability of y.
|
||||
//! * \\(P(y | X)\\) is the probability of the data (regardless of the class value).
|
||||
//!
|
||||
//! The naive conditional independence assumption let us rewrite this equation as
|
||||
//!
|
||||
//! \\[ P(y | x_1,...x_n) = \frac{P(y)\prod_{i=1}^nP(x_i|y)}{P(x_1,...x_n)} \\]
|
||||
//!
|
||||
//!
|
||||
//! The denominator can be removed since \\(P(x_1,...x_n)\\) is constrant for all the entries in the dataset.
|
||||
//!
|
||||
//! \\[ P(y | x_1,...x_n) \propto P(y)\prod_{i=1}^nP(x_i|y) \\]
|
||||
//!
|
||||
//! To find class y from predictors X we use this equation
|
||||
//!
|
||||
//! \\[ y = \underset{y}{argmax} P(y)\prod_{i=1}^nP(x_i|y) \\]
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Machine Learning: A Probabilistic Perspective", Kevin P. Murphy, 2012, Chapter 3 ](https://mitpress.mit.edu/books/machine-learning-1)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Distribution used in the Naive Bayes classifier.
|
||||
pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
|
||||
/// Prior of class at the given index.
|
||||
fn prior(&self, class_index: usize) -> T;
|
||||
|
||||
/// Logarithm of conditional probability of sample j given class in the specified index.
|
||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T;
|
||||
|
||||
/// Possible classes of the distribution.
|
||||
fn classes(&self) -> &Vec<T>;
|
||||
}
|
||||
|
||||
/// Base struct for the Naive Bayes classifier.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
||||
distribution: D,
|
||||
_phantom_t: PhantomData<T>,
|
||||
_phantom_m: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> BaseNaiveBayes<T, M, D> {
|
||||
/// Fits NB classifier to a given NBdistribution.
|
||||
/// * `distribution` - NBDistribution of the training data
|
||||
pub fn fit(distribution: D) -> Result<Self, Failed> {
|
||||
Ok(Self {
|
||||
distribution,
|
||||
_phantom_t: PhantomData,
|
||||
_phantom_m: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let y_classes = self.distribution.classes();
|
||||
let (rows, _) = x.shape();
|
||||
let predictions = (0..rows)
|
||||
.map(|row_index| {
|
||||
let row = x.get_row(row_index);
|
||||
let (prediction, _probability) = y_classes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(class_index, class)| {
|
||||
(
|
||||
class,
|
||||
self.distribution.log_likelihood(class_index, &row)
|
||||
+ self.distribution.prior(class_index).ln(),
|
||||
)
|
||||
})
|
||||
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
|
||||
.unwrap();
|
||||
*prediction
|
||||
})
|
||||
.collect::<Vec<T>>();
|
||||
let y_hat = M::RowVector::from_array(&predictions);
|
||||
Ok(y_hat)
|
||||
}
|
||||
}
|
||||
pub mod bernoulli;
|
||||
pub mod categorical;
|
||||
pub mod gaussian;
|
||||
pub mod multinomial;
|
||||
@@ -0,0 +1,434 @@
|
||||
//! # Multinomial Naive Bayes
|
||||
//!
|
||||
//! Multinomial Naive Bayes classifier is a variant of [Naive Bayes](../index.html) for the multinomially distributed data.
|
||||
//! It is often used for discrete data with predictors representing the number of times an event was observed in a particular instance,
|
||||
//! for example frequency of the words present in the document.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::naive_bayes::multinomial::MultinomialNB;
|
||||
//!
|
||||
//! // Training data points are:
|
||||
//! // Chinese Beijing Chinese (class: China)
|
||||
//! // Chinese Chinese Shanghai (class: China)
|
||||
//! // Chinese Macao (class: China)
|
||||
//! // Tokyo Japan Chinese (class: Japan)
|
||||
//! let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
//! &[1., 2., 0., 0., 0., 0.],
|
||||
//! &[0., 2., 0., 0., 1., 0.],
|
||||
//! &[0., 1., 0., 1., 0., 0.],
|
||||
//! &[0., 1., 1., 0., 0., 1.],
|
||||
//! ]);
|
||||
//! let y = vec![0., 0., 0., 1.];
|
||||
//! let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Testing data point is:
|
||||
//! // Chinese Chinese Chinese Tokyo Japan
|
||||
//! let x_test = DenseMatrix::<f64>::from_2d_array(&[&[0., 3., 1., 0., 0., 1.]]);
|
||||
//! let y_hat = nb.predict(&x_test).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Introduction to Information Retrieval", Manning C. D., Raghavan P., Schutze H., 2009, Chapter 13 ](https://nlp.stanford.edu/IR-book/information-retrieval-book.html)
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::row_iter;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::math::vector::RealNumberVector;
|
||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Naive Bayes classifier for Multinomial features
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct MultinomialNBDistribution<T: RealNumber> {
|
||||
/// class labels known to the classifier
|
||||
class_labels: Vec<T>,
|
||||
/// number of training samples observed in each class
|
||||
class_count: Vec<usize>,
|
||||
/// probability of each class
|
||||
class_priors: Vec<T>,
|
||||
/// Empirical log probability of features given a class
|
||||
feature_log_prob: Vec<Vec<T>>,
|
||||
/// Number of samples encountered for each (class, feature)
|
||||
feature_count: Vec<Vec<usize>>,
|
||||
/// Number of features of each sample
|
||||
n_features: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> {
|
||||
fn prior(&self, class_index: usize) -> T {
|
||||
self.class_priors[class_index]
|
||||
}
|
||||
|
||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
||||
let mut likelihood = T::zero();
|
||||
for feature in 0..j.len() {
|
||||
let value = j.get(feature);
|
||||
likelihood += value * self.feature_log_prob[class_index][feature];
|
||||
}
|
||||
likelihood
|
||||
}
|
||||
|
||||
fn classes(&self) -> &Vec<T> {
|
||||
&self.class_labels
|
||||
}
|
||||
}
|
||||
|
||||
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultinomialNBParameters<T: RealNumber> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MultinomialNBParameters<T> {
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||
self.alpha = alpha;
|
||||
self
|
||||
}
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub fn with_priors(mut self, priors: Vec<T>) -> Self {
|
||||
self.priors = Some(priors);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for MultinomialNBParameters<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
alpha: T::one(),
|
||||
priors: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MultinomialNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
alpha: T,
|
||||
priors: Option<Vec<T>>,
|
||||
) -> Result<Self, Failed> {
|
||||
let (n_samples, n_features) = x.shape();
|
||||
let y_samples = y.len();
|
||||
if y_samples != n_samples {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x should equal size of y; |x|=[{}], |y|=[{}]",
|
||||
n_samples, y_samples
|
||||
)));
|
||||
}
|
||||
|
||||
if n_samples == 0 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Size of x and y should greater than 0; |x|=[{}]",
|
||||
n_samples
|
||||
)));
|
||||
}
|
||||
if alpha < T::zero() {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Alpha should be greater than 0; |alpha|=[{}]",
|
||||
alpha
|
||||
)));
|
||||
}
|
||||
|
||||
let y = y.to_vec();
|
||||
|
||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||
let mut class_count = vec![0_usize; class_labels.len()];
|
||||
|
||||
for class_index in indices.iter() {
|
||||
class_count[*class_index] += 1;
|
||||
}
|
||||
|
||||
let class_priors = if let Some(class_priors) = priors {
|
||||
if class_priors.len() != class_labels.len() {
|
||||
return Err(Failed::fit(
|
||||
"Size of priors provided does not match the number of classes of the data.",
|
||||
));
|
||||
}
|
||||
class_priors
|
||||
} else {
|
||||
class_count
|
||||
.iter()
|
||||
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
|
||||
|
||||
for (row, class_index) in row_iter(x).zip(indices) {
|
||||
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
||||
feature_in_class_counter[class_index][idx] +=
|
||||
row_i.to_usize().ok_or_else(|| {
|
||||
Failed::fit(&format!(
|
||||
"Elements of the matrix should be convertible to usize |found|=[{}]",
|
||||
row_i
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let feature_log_prob = feature_in_class_counter
|
||||
.iter()
|
||||
.map(|feature_count| {
|
||||
let n_c: usize = feature_count.iter().sum();
|
||||
feature_count
|
||||
.iter()
|
||||
.map(|&count| {
|
||||
((T::from(count).unwrap() + alpha)
|
||||
/ (T::from(n_c).unwrap() + alpha * T::from(n_features).unwrap()))
|
||||
.ln()
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Self {
|
||||
class_count,
|
||||
class_labels,
|
||||
class_priors,
|
||||
feature_log_prob,
|
||||
feature_count: feature_in_class_counter,
|
||||
n_features,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// MultinomialNB implements the naive Bayes algorithm for multinomially distributed data.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
||||
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, MultinomialNBParameters<T>>
|
||||
for MultinomialNB<T, M>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: MultinomialNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
MultinomialNB::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for MultinomialNB<T, M> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> MultinomialNB<T, M> {
|
||||
/// Fits MultinomialNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors, alpha for smoothing and
|
||||
/// binarizing threshold.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: MultinomialNBParameters<T>,
|
||||
) -> Result<Self, Failed> {
|
||||
let distribution =
|
||||
MultinomialNBDistribution::fit(x, y, parameters.alpha, parameters.priors)?;
|
||||
let inner = BaseNaiveBayes::fit(distribution)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.inner.predict(x)
|
||||
}
|
||||
|
||||
/// Class labels known to the classifier.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn classes(&self) -> &Vec<T> {
|
||||
&self.inner.distribution.class_labels
|
||||
}
|
||||
|
||||
/// Number of training samples observed in each class.
|
||||
/// Returns a vector of size n_classes.
|
||||
pub fn class_count(&self) -> &Vec<usize> {
|
||||
&self.inner.distribution.class_count
|
||||
}
|
||||
|
||||
/// Empirical log probability of features given a class, P(x_i|y).
|
||||
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
|
||||
&self.inner.distribution.feature_log_prob
|
||||
}
|
||||
|
||||
/// Number of features of each sample
|
||||
pub fn n_features(&self) -> usize {
|
||||
self.inner.distribution.n_features
|
||||
}
|
||||
|
||||
/// Number of samples encountered for each (class, feature)
|
||||
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
|
||||
&self.inner.distribution.feature_count
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_multinomial_naive_bayes() {
|
||||
// Tests that MultinomialNB when alpha=1.0 gives the same values as
|
||||
// those given for the toy example in Manning, Raghavan, and
|
||||
// Schuetze's "Introduction to Information Retrieval" book:
|
||||
// https://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html
|
||||
|
||||
// Training data points are:
|
||||
// Chinese Beijing Chinese (class: China)
|
||||
// Chinese Chinese Shanghai (class: China)
|
||||
// Chinese Macao (class: China)
|
||||
// Tokyo Japan Chinese (class: Japan)
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 2., 0., 0., 0., 0.],
|
||||
&[0., 2., 0., 0., 1., 0.],
|
||||
&[0., 1., 0., 1., 0., 0.],
|
||||
&[0., 1., 1., 0., 0., 1.],
|
||||
]);
|
||||
let y = vec![0., 0., 0., 1.];
|
||||
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(mnb.classes(), &[0., 1.]);
|
||||
assert_eq!(mnb.class_count(), &[3, 1]);
|
||||
|
||||
assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
||||
assert_eq!(
|
||||
mnb.feature_log_prob(),
|
||||
&[
|
||||
&[
|
||||
(1_f64 / 7_f64).ln(),
|
||||
(3_f64 / 7_f64).ln(),
|
||||
(1_f64 / 14_f64).ln(),
|
||||
(1_f64 / 7_f64).ln(),
|
||||
(1_f64 / 7_f64).ln(),
|
||||
(1_f64 / 14_f64).ln()
|
||||
],
|
||||
&[
|
||||
(1_f64 / 9_f64).ln(),
|
||||
(2_f64 / 9_f64).ln(),
|
||||
(2_f64 / 9_f64).ln(),
|
||||
(1_f64 / 9_f64).ln(),
|
||||
(1_f64 / 9_f64).ln(),
|
||||
(2_f64 / 9_f64).ln()
|
||||
]
|
||||
]
|
||||
);
|
||||
|
||||
// Testing data point is:
|
||||
// Chinese Chinese Chinese Tokyo Japan
|
||||
let x_test = DenseMatrix::<f64>::from_2d_array(&[&[0., 3., 1., 0., 0., 1.]]);
|
||||
let y_hat = mnb.predict(&x_test).unwrap();
|
||||
|
||||
assert_eq!(y_hat, &[0.]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn multinomial_nb_scikit_parity() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[2., 4., 0., 0., 2., 1., 2., 4., 2., 0.],
|
||||
&[3., 4., 0., 2., 1., 0., 1., 4., 0., 3.],
|
||||
&[1., 4., 2., 4., 1., 0., 1., 2., 3., 2.],
|
||||
&[0., 3., 3., 4., 1., 0., 3., 1., 1., 1.],
|
||||
&[0., 2., 1., 4., 3., 4., 1., 2., 3., 1.],
|
||||
&[3., 2., 4., 1., 3., 0., 2., 4., 0., 2.],
|
||||
&[3., 1., 3., 0., 2., 0., 4., 4., 3., 4.],
|
||||
&[2., 2., 2., 0., 1., 1., 2., 1., 0., 1.],
|
||||
&[3., 3., 2., 2., 0., 2., 3., 2., 2., 3.],
|
||||
&[4., 3., 4., 4., 4., 2., 2., 0., 1., 4.],
|
||||
&[3., 4., 2., 2., 1., 4., 4., 4., 1., 3.],
|
||||
&[3., 0., 1., 4., 4., 0., 0., 3., 2., 4.],
|
||||
&[2., 0., 3., 3., 1., 2., 0., 2., 4., 1.],
|
||||
&[2., 4., 0., 4., 2., 4., 1., 3., 1., 4.],
|
||||
&[0., 2., 2., 3., 4., 0., 4., 4., 4., 4.],
|
||||
]);
|
||||
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
|
||||
let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
assert_eq!(nb.n_features(), 10);
|
||||
assert_eq!(
|
||||
nb.feature_count(),
|
||||
&[
|
||||
&[12, 20, 11, 24, 12, 14, 13, 17, 13, 18],
|
||||
&[9, 6, 9, 4, 7, 3, 8, 5, 4, 9],
|
||||
&[10, 12, 9, 9, 11, 3, 9, 18, 10, 10]
|
||||
]
|
||||
);
|
||||
|
||||
let y_hat = nb.predict(&x).unwrap();
|
||||
|
||||
assert!(nb
|
||||
.inner
|
||||
.distribution
|
||||
.class_priors
|
||||
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
||||
assert!(nb.feature_log_prob()[1].approximate_eq(
|
||||
&vec![
|
||||
-2.00148,
|
||||
-2.35815494,
|
||||
-2.00148,
|
||||
-2.69462718,
|
||||
-2.22462355,
|
||||
-2.91777073,
|
||||
-2.10684052,
|
||||
-2.51230562,
|
||||
-2.69462718,
|
||||
-2.00148
|
||||
],
|
||||
1e-5
|
||||
));
|
||||
assert!(y_hat.approximate_eq(
|
||||
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0),
|
||||
1e-5
|
||||
));
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[1., 1., 0., 0., 0., 0.],
|
||||
&[0., 1., 0., 0., 1., 0.],
|
||||
&[0., 1., 0., 1., 0., 0.],
|
||||
&[0., 1., 1., 0., 0., 1.],
|
||||
]);
|
||||
let y = vec![0., 0., 0., 1.];
|
||||
|
||||
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
let deserialized_mnb: MultinomialNB<f64, DenseMatrix<f64>> =
|
||||
serde_json::from_str(&serde_json::to_string(&mnb).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(mnb, deserialized_mnb);
|
||||
}
|
||||
}
|
||||
@@ -25,34 +25,47 @@
|
||||
//! &[9., 10.]]);
|
||||
//! let y = vec![2., 2., 2., 3., 3.]; //your class labels
|
||||
//!
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold a vector with estimates of class labels
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction};
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNClassifierParameters {
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
classes: Vec<T>,
|
||||
y: Vec<usize>,
|
||||
@@ -61,12 +74,47 @@ pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Default for KNNClassifierParameters {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifierParameters<T, D> {
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance<DD: Distance<Vec<T>, T>>(
|
||||
self,
|
||||
distance: DD,
|
||||
) -> KNNClassifierParameters<T, DD> {
|
||||
KNNClassifierParameters {
|
||||
distance,
|
||||
algorithm: self.algorithm,
|
||||
weight: self.weight,
|
||||
k: self.k,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
|
||||
self.algorithm = algorithm;
|
||||
self
|
||||
}
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -77,7 +125,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
|| self.k != other.k
|
||||
|| self.y.len() != other.y.len()
|
||||
{
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
@@ -94,19 +142,35 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNClassifier<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>>
|
||||
SupervisedEstimator<M, M::RowVector, KNNClassifierParameters<T, D>> for KNNClassifier<T, D>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: KNNClassifierParameters<T, D>,
|
||||
) -> Result<Self, Failed> {
|
||||
KNNClassifier::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for KNNClassifier<T, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
/// Fits KNN classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNClassifierParameters,
|
||||
parameters: KNNClassifierParameters<T, D>,
|
||||
) -> Result<KNNClassifier<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -118,9 +182,9 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
let mut yi: Vec<usize> = vec![0; y_n];
|
||||
let classes = y_m.unique();
|
||||
|
||||
for i in 0..y_n {
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_n) {
|
||||
let yc = y_m.get(0, i);
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
if x_n != y_n {
|
||||
@@ -138,10 +202,10 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
}
|
||||
|
||||
Ok(KNNClassifier {
|
||||
classes: classes,
|
||||
classes,
|
||||
y: yi,
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
knn_algorithm: parameters.algorithm.fit(data, parameters.distance)?,
|
||||
weight: parameters.weight,
|
||||
})
|
||||
}
|
||||
@@ -165,13 +229,13 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
let weights = self
|
||||
.weight
|
||||
.calc_weights(search_result.iter().map(|v| v.1).collect());
|
||||
let w_sum = weights.iter().map(|w| *w).sum();
|
||||
let w_sum = weights.iter().copied().sum();
|
||||
|
||||
let mut c = vec![T::zero(); self.classes.len()];
|
||||
let mut max_c = T::zero();
|
||||
let mut max_i = 0;
|
||||
for (r, w) in search_result.iter().zip(weights.iter()) {
|
||||
c[self.y[r.0]] = c[self.y[r.0]] + (*w / w_sum);
|
||||
c[self.y[r.0]] += *w / w_sum;
|
||||
if c[self.y[r.0]] > max_c {
|
||||
max_c = c[self.y[r.0]];
|
||||
max_i = self.y[r.0];
|
||||
@@ -186,19 +250,20 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNClassifier<T, D> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
assert_eq!(y.to_vec(), y_hat);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_fit_predict_weighted() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
||||
@@ -206,25 +271,25 @@ mod tests {
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNClassifierParameters {
|
||||
k: 5,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
KNNClassifierParameters::default()
|
||||
.with_k(5)
|
||||
.with_algorithm(KNNAlgorithmName::LinearSearch)
|
||||
.with_weight(KNNWeightFunction::Distance),
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||
assert_eq!(vec![3.0], y_hat);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![2., 2., 2., 3., 3.];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
@@ -27,34 +27,48 @@
|
||||
//! &[5., 5.]]);
|
||||
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||
//!
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold predicted value
|
||||
//!
|
||||
//!
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::{row_iter, BaseVector, Matrix};
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::distance::{Distance, Distances};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::neighbors::{KNNAlgorithm, KNNAlgorithmName, KNNWeightFunction};
|
||||
use crate::neighbors::KNNWeightFunction;
|
||||
|
||||
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct KNNRegressorParameters {
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
distance: D,
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
|
||||
/// K Nearest Neighbors Regressor
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
y: Vec<T>,
|
||||
knn_algorithm: KNNAlgorithm<T, D>,
|
||||
@@ -62,12 +76,47 @@ pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl Default for KNNRegressorParameters {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressorParameters<T, D> {
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub fn with_k(mut self, k: usize) -> Self {
|
||||
self.k = k;
|
||||
self
|
||||
}
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub fn with_distance<DD: Distance<Vec<T>, T>>(
|
||||
self,
|
||||
distance: DD,
|
||||
) -> KNNRegressorParameters<T, DD> {
|
||||
KNNRegressorParameters {
|
||||
distance,
|
||||
algorithm: self.algorithm,
|
||||
weight: self.weight,
|
||||
k: self.k,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub fn with_algorithm(mut self, algorithm: KNNAlgorithmName) -> Self {
|
||||
self.algorithm = algorithm;
|
||||
self
|
||||
}
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub fn with_weight(mut self, weight: KNNWeightFunction) -> Self {
|
||||
self.weight = weight;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNRegressorParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -75,7 +124,7 @@ impl Default for KNNRegressorParameters {
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.k != other.k || self.y.len() != other.y.len() {
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.y.len() {
|
||||
if (self.y[i] - other.y[i]).abs() > T::epsilon() {
|
||||
@@ -87,19 +136,35 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for KNNRegressor<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>>
|
||||
SupervisedEstimator<M, M::RowVector, KNNRegressorParameters<T, D>> for KNNRegressor<T, D>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: KNNRegressorParameters<T, D>,
|
||||
) -> Result<Self, Failed> {
|
||||
KNNRegressor::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, D: Distance<Vec<T>, T>> Predictor<M, M::RowVector>
|
||||
for KNNRegressor<T, D>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
/// Fits KNN regressor to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with real values
|
||||
/// * `distance` - a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
/// * `y` - vector with real values
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit<M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
distance: D,
|
||||
parameters: KNNRegressorParameters,
|
||||
parameters: KNNRegressorParameters<T, D>,
|
||||
) -> Result<KNNRegressor<T, D>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -115,9 +180,9 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
)));
|
||||
}
|
||||
|
||||
if parameters.k <= 1 {
|
||||
if parameters.k < 1 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"k should be > 1, k=[{}]",
|
||||
"k should be > 0, k=[{}]",
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
@@ -125,7 +190,7 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
Ok(KNNRegressor {
|
||||
y: y.to_vec(),
|
||||
k: parameters.k,
|
||||
knn_algorithm: parameters.algorithm.fit(data, distance)?,
|
||||
knn_algorithm: parameters.algorithm.fit(data, parameters.distance)?,
|
||||
weight: parameters.weight,
|
||||
})
|
||||
}
|
||||
@@ -150,10 +215,10 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNRegressor<T, D> {
|
||||
let weights = self
|
||||
.weight
|
||||
.calc_weights(search_result.iter().map(|v| v.1).collect());
|
||||
let w_sum = weights.iter().map(|w| *w).sum();
|
||||
let w_sum = weights.iter().copied().sum();
|
||||
|
||||
for (r, w) in search_result.iter().zip(weights.iter()) {
|
||||
result = result + self.y[r.0] * (*w / w_sum);
|
||||
result += self.y[r.0] * (*w / w_sum);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
@@ -166,6 +231,7 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::math::distance::Distances;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_fit_predict_weighted() {
|
||||
let x =
|
||||
@@ -175,12 +241,11 @@ mod tests {
|
||||
let knn = KNNRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
Distances::euclidian(),
|
||||
KNNRegressorParameters {
|
||||
k: 3,
|
||||
algorithm: KNNAlgorithmName::LinearSearch,
|
||||
weight: KNNWeightFunction::Distance,
|
||||
},
|
||||
KNNRegressorParameters::default()
|
||||
.with_k(3)
|
||||
.with_distance(Distances::euclidian())
|
||||
.with_algorithm(KNNAlgorithmName::LinearSearch)
|
||||
.with_weight(KNNWeightFunction::Distance),
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
@@ -190,13 +255,14 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn knn_fit_predict_uniform() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
for i in 0..y_hat.len() {
|
||||
@@ -204,13 +270,15 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
let y = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Distances::euclidian(), Default::default()).unwrap();
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_knn = bincode::deserialize(&bincode::serialize(&knn).unwrap()).unwrap();
|
||||
|
||||
|
||||
+9
-46
@@ -10,7 +10,7 @@
|
||||
//! and follows three conditions:
|
||||
//! 1. \\( d(x, y) = 0 \\) if and only \\( x = y \\), positive definiteness
|
||||
//! 1. \\( d(x, y) = d(y, x) \\), symmetry
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//! 1. \\( d(x, y) \leq d(x, z) + d(z, y) \\), subadditivity or triangle inequality
|
||||
//!
|
||||
//! for all \\(x, y, z \in Z \\)
|
||||
//!
|
||||
@@ -32,11 +32,8 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use crate::algorithm::neighbour::cover_tree::CoverTree;
|
||||
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
||||
use crate::error::Failed;
|
||||
use crate::math::distance::Distance;
|
||||
use crate::math::num::RealNumber;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// K Nearest Neighbors Classifier
|
||||
@@ -44,18 +41,16 @@ pub mod knn_classifier;
|
||||
/// K Nearest Neighbors Regressor
|
||||
pub mod knn_regressor;
|
||||
|
||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum KNNAlgorithmName {
|
||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||
LinearSearch,
|
||||
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
||||
CoverTree,
|
||||
}
|
||||
#[deprecated(
|
||||
since = "0.2.0",
|
||||
note = "please use `smartcore::algorithm::neighbour::KNNAlgorithmName` instead"
|
||||
)]
|
||||
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
||||
|
||||
/// Weight function that is used to determine estimated value.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum KNNWeightFunction {
|
||||
/// All k nearest points are weighted equally
|
||||
Uniform,
|
||||
@@ -63,12 +58,6 @@ pub enum KNNWeightFunction {
|
||||
Distance,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||
}
|
||||
|
||||
impl KNNWeightFunction {
|
||||
fn calc_weights<T: RealNumber>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
|
||||
match *self {
|
||||
@@ -88,29 +77,3 @@ impl KNNWeightFunction {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KNNAlgorithmName {
|
||||
fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
|
||||
&self,
|
||||
data: Vec<Vec<T>>,
|
||||
distance: D,
|
||||
) -> Result<KNNAlgorithm<T, D>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithmName::LinearSearch => {
|
||||
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
|
||||
}
|
||||
KNNAlgorithmName::CoverTree => {
|
||||
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
|
||||
fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T)>, Failed> {
|
||||
match *self {
|
||||
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
|
||||
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ impl<T: RealNumber> Default for GradientDescent<T> {
|
||||
impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &'a F<T, X>,
|
||||
df: &'a DF<X>,
|
||||
f: &'a F<'_, T, X>,
|
||||
df: &'a DF<'_, X>,
|
||||
x0: &X,
|
||||
ls: &'a LS,
|
||||
) -> OptimizerResult<T, X> {
|
||||
@@ -50,14 +50,14 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
||||
let f_alpha = |alpha: T| -> T {
|
||||
let mut dx = step.clone();
|
||||
dx.mul_scalar_mut(alpha);
|
||||
f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
|
||||
f(dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
|
||||
};
|
||||
|
||||
let df_alpha = |alpha: T| -> T {
|
||||
let mut dx = step.clone();
|
||||
let mut dg = gvec.clone();
|
||||
dx.mul_scalar_mut(alpha);
|
||||
df(&mut dg, &dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha)
|
||||
df(&mut dg, dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha)
|
||||
gvec.dot(&dg)
|
||||
};
|
||||
|
||||
@@ -66,7 +66,7 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
||||
let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0);
|
||||
alpha = ls_r.alpha;
|
||||
fx = ls_r.f_x;
|
||||
x.add_mut(&step.mul_scalar_mut(alpha));
|
||||
x.add_mut(step.mul_scalar_mut(alpha));
|
||||
df(&mut gvec, &x);
|
||||
gnorm = gvec.norm2();
|
||||
}
|
||||
@@ -74,8 +74,8 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
||||
let f_x = f(&x);
|
||||
|
||||
OptimizerResult {
|
||||
x: x,
|
||||
f_x: f_x,
|
||||
x,
|
||||
f_x,
|
||||
iterations: iter,
|
||||
}
|
||||
}
|
||||
@@ -88,6 +88,7 @@ mod tests {
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn gradient_descent() {
|
||||
let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(clippy::suspicious_operation_groupings)]
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -7,6 +8,7 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::LineSearchMethod;
|
||||
use crate::optimization::{DF, F};
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
pub struct LBFGS<T: RealNumber> {
|
||||
pub max_iter: usize,
|
||||
pub g_rtol: T,
|
||||
@@ -100,8 +102,8 @@ impl<T: RealNumber> LBFGS<T> {
|
||||
|
||||
fn update_state<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &'a F<T, X>,
|
||||
df: &'a DF<X>,
|
||||
f: &'a F<'_, T, X>,
|
||||
df: &'a DF<'_, X>,
|
||||
ls: &'a LS,
|
||||
state: &mut LBFGSState<T, X>,
|
||||
) {
|
||||
@@ -116,14 +118,14 @@ impl<T: RealNumber> LBFGS<T> {
|
||||
let f_alpha = |alpha: T| -> T {
|
||||
let mut dx = state.s.clone();
|
||||
dx.mul_scalar_mut(alpha);
|
||||
f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
|
||||
f(dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
|
||||
};
|
||||
|
||||
let df_alpha = |alpha: T| -> T {
|
||||
let mut dx = state.s.clone();
|
||||
let mut dg = state.x_df.clone();
|
||||
dx.mul_scalar_mut(alpha);
|
||||
df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
|
||||
df(&mut dg, dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
|
||||
state.x_df.dot(&dg)
|
||||
};
|
||||
|
||||
@@ -162,7 +164,7 @@ impl<T: RealNumber> LBFGS<T> {
|
||||
g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
|
||||
}
|
||||
|
||||
fn update_hessian<'a, X: Matrix<T>>(&self, _: &'a DF<X>, state: &mut LBFGSState<T, X>) {
|
||||
fn update_hessian<'a, X: Matrix<T>>(&self, _: &'a DF<'_, X>, state: &mut LBFGSState<T, X>) {
|
||||
state.dg = state.x_df.sub(&state.x_df_prev);
|
||||
let rho_iteration = T::one() / state.dx.dot(&state.dg);
|
||||
if !rho_iteration.is_infinite() {
|
||||
@@ -198,14 +200,14 @@ struct LBFGSState<T: RealNumber, X: Matrix<T>> {
|
||||
impl<T: RealNumber> FirstOrderOptimizer<T> for LBFGS<T> {
|
||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &F<T, X>,
|
||||
df: &'a DF<X>,
|
||||
f: &F<'_, T, X>,
|
||||
df: &'a DF<'_, X>,
|
||||
x0: &X,
|
||||
ls: &'a LS,
|
||||
) -> OptimizerResult<T, X> {
|
||||
let mut state = self.init_state(x0);
|
||||
|
||||
df(&mut state.x_df, &x0);
|
||||
df(&mut state.x_df, x0);
|
||||
|
||||
let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
|
||||
let mut converged = g_converged;
|
||||
@@ -238,6 +240,7 @@ mod tests {
|
||||
use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lbfgs() {
|
||||
let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]);
|
||||
|
||||
@@ -12,8 +12,8 @@ use crate::optimization::{DF, F};
|
||||
pub trait FirstOrderOptimizer<T: RealNumber> {
|
||||
fn optimize<'a, X: Matrix<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &F<T, X>,
|
||||
df: &'a DF<X>,
|
||||
f: &F<'_, T, X>,
|
||||
df: &'a DF<'_, X>,
|
||||
x0: &X,
|
||||
ls: &'a LS,
|
||||
) -> OptimizerResult<T, X>;
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::optimization::FunctionOrder;
|
||||
use num_traits::Float;
|
||||
|
||||
pub trait LineSearchMethod<T: Float> {
|
||||
fn search<'a>(
|
||||
fn search(
|
||||
&self,
|
||||
f: &(dyn Fn(T) -> T),
|
||||
df: &(dyn Fn(T) -> T),
|
||||
@@ -41,7 +41,7 @@ impl<T: Float> Default for Backtracking<T> {
|
||||
}
|
||||
|
||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||
fn search<'a>(
|
||||
fn search(
|
||||
&self,
|
||||
f: &(dyn Fn(T) -> T),
|
||||
_: &(dyn Fn(T) -> T),
|
||||
@@ -112,6 +112,7 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn backtracking() {
|
||||
let f = |x: f64| -> f64 { x.powf(2.) + x };
|
||||
|
||||
@@ -4,7 +4,8 @@ pub mod line_search;
|
||||
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
|
||||
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum FunctionOrder {
|
||||
SECOND,
|
||||
THIRD,
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matricies
|
||||
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents
|
||||
//!
|
||||
//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html)
|
||||
//!
|
||||
//! ### Usage Example
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use smartcore::preprocessing::categorical::{OneHotEncoder, OneHotEncoderParams};
|
||||
//! let data = DenseMatrix::from_2d_array(&[
|
||||
//! &[1.5, 1.0, 1.5, 3.0],
|
||||
//! &[1.5, 2.0, 1.5, 4.0],
|
||||
//! &[1.5, 1.0, 1.5, 5.0],
|
||||
//! &[1.5, 2.0, 1.5, 6.0],
|
||||
//! ]);
|
||||
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
|
||||
//! // Infer number of categories from data and return a reusable encoder
|
||||
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
|
||||
//! // Transform categorical to one-hot encoded (can transform similar)
|
||||
//! let oh_data = encoder.transform(&data).unwrap();
|
||||
//! // Produces the following:
|
||||
//! // &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0]
|
||||
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0]
|
||||
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
|
||||
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
|
||||
//! ```
|
||||
use std::iter;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
|
||||
use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable};
|
||||
use crate::preprocessing::series_encoder::CategoryMapper;
|
||||
|
||||
/// OneHotEncoder Parameters
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OneHotEncoderParams {
|
||||
/// Column number that contain categorical variable
|
||||
pub col_idx_categorical: Option<Vec<usize>>,
|
||||
/// (Currently not implemented) Try and infer which of the matrix columns are categorical variables
|
||||
infer_categorical: bool,
|
||||
}
|
||||
|
||||
impl OneHotEncoderParams {
|
||||
/// Generate parameters from categorical variable column numbers
|
||||
pub fn from_cat_idx(categorical_params: &[usize]) -> Self {
|
||||
Self {
|
||||
col_idx_categorical: Some(categorical_params.to_vec()),
|
||||
infer_categorical: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the offset to parameters to due introduction of one-hot encoding
|
||||
fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) -> Vec<usize> {
|
||||
// This functions uses iterators and returns a vector.
|
||||
// In case we get a huge amount of paramenters this might be a problem
|
||||
// todo: Change this such that it will return an iterator
|
||||
|
||||
let cat_idx = cat_idxs.iter().copied().chain((num_params..).take(1));
|
||||
|
||||
// Offset is constant between two categorical values, here we calculate the number of steps
|
||||
// that remain constant
|
||||
let repeats = cat_idx.scan(0, |a, v| {
|
||||
let im = v + 1 - *a;
|
||||
*a = v;
|
||||
Some(im)
|
||||
});
|
||||
|
||||
// Calculate the offset to parameter idx due to newly intorduced one-hot vectors
|
||||
let offset_ = cat_sizes.iter().scan(0, |a, &v| {
|
||||
*a = *a + v - 1;
|
||||
Some(*a)
|
||||
});
|
||||
let offset = (0..1).chain(offset_);
|
||||
|
||||
let new_param_idxs: Vec<usize> = (0..num_params)
|
||||
.zip(
|
||||
repeats
|
||||
.zip(offset)
|
||||
.flat_map(|(r, o)| iter::repeat(o).take(r)),
|
||||
)
|
||||
.map(|(idx, ofst)| idx + ofst)
|
||||
.collect();
|
||||
new_param_idxs
|
||||
}
|
||||
|
||||
fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool {
|
||||
for v in data {
|
||||
if !v.is_valid() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
/// Encode Categorical variavbles of data matrix to one-hot
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OneHotEncoder {
|
||||
category_mappers: Vec<CategoryMapper<CategoricalFloat>>,
|
||||
col_idx_categorical: Vec<usize>,
|
||||
}
|
||||
|
||||
impl OneHotEncoder {
|
||||
/// Create an encoder instance with categories infered from data matrix
|
||||
pub fn fit<T, M>(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed>
|
||||
where
|
||||
T: Categorizable,
|
||||
M: Matrix<T>,
|
||||
{
|
||||
match (params.col_idx_categorical, params.infer_categorical) {
|
||||
(None, false) => Err(Failed::fit(
|
||||
"Must pass categorical series ids or infer flag",
|
||||
)),
|
||||
|
||||
(Some(_idxs), true) => Err(Failed::fit(
|
||||
"Ambigous parameters, got both infer and categroy ids",
|
||||
)),
|
||||
|
||||
(Some(mut idxs), false) => {
|
||||
// make sure categories have same order as data columns
|
||||
idxs.sort_unstable();
|
||||
|
||||
let (nrows, _) = data.shape();
|
||||
|
||||
// col buffer to avoid allocations
|
||||
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
|
||||
|
||||
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
|
||||
|
||||
for &idx in &idxs {
|
||||
data.copy_col_as_vec(idx, &mut col_buf);
|
||||
if !validate_col_is_categorical(&col_buf) {
|
||||
let msg = format!(
|
||||
"Column {} of data matrix containts non categorizable (integer) values",
|
||||
idx
|
||||
);
|
||||
return Err(Failed::fit(&msg[..]));
|
||||
}
|
||||
let hashable_col = col_buf.iter().map(|v| v.to_category());
|
||||
res.push(CategoryMapper::fit_to_iter(hashable_col));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
category_mappers: res,
|
||||
col_idx_categorical: idxs,
|
||||
})
|
||||
}
|
||||
|
||||
(None, true) => {
|
||||
todo!("Auto-Inference for Categorical Variables not yet implemented")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Transform categorical variables to one-hot encoded and return a new matrix
|
||||
pub fn transform<T, M>(&self, x: &M) -> Result<M, Failed>
|
||||
where
|
||||
T: Categorizable,
|
||||
M: Matrix<T>,
|
||||
{
|
||||
let (nrows, p) = x.shape();
|
||||
let additional_params: Vec<usize> = self
|
||||
.category_mappers
|
||||
.iter()
|
||||
.map(|enc| enc.num_categories())
|
||||
.collect();
|
||||
|
||||
// Eac category of size v adds v-1 params
|
||||
let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);
|
||||
|
||||
let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
|
||||
let mut res = M::zeros(nrows, expandws_p);
|
||||
|
||||
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
|
||||
let cidx = new_col_idx[old_cidx];
|
||||
let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category());
|
||||
let sencoder = &self.category_mappers[pidx];
|
||||
let oh_series = col_iter.map(|c| sencoder.get_one_hot::<T, Vec<T>>(&c));
|
||||
|
||||
for (row, oh_vec) in oh_series.enumerate() {
|
||||
match oh_vec {
|
||||
None => {
|
||||
// Since we support T types, bad value in a series causes in to be invalid
|
||||
let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx);
|
||||
return Err(Failed::transform(&msg[..]));
|
||||
}
|
||||
Some(v) => {
|
||||
// copy one hot vectors to their place in the data matrix;
|
||||
for (col_ofst, &val) in v.iter().enumerate() {
|
||||
res.set(row, cidx + col_ofst, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy old data in x to their new location while skipping catergorical vars (already treated)
|
||||
let mut skip_idx_iter = self.col_idx_categorical.iter();
|
||||
let mut cur_skip = skip_idx_iter.next();
|
||||
|
||||
for (old_p, &new_p) in new_col_idx.iter().enumerate() {
|
||||
// if found treated varible, skip it
|
||||
if let Some(&v) = cur_skip {
|
||||
if v == old_p {
|
||||
cur_skip = skip_idx_iter.next();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for r in 0..nrows {
|
||||
let val = x.get(r, old_p);
|
||||
res.set(r, new_p, val);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::preprocessing::series_encoder::CategoryMapper;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn adjust_idxs() {
|
||||
assert_eq!(find_new_idxs(0, &[], &[]), Vec::<usize>::new());
|
||||
// [0,1,2] -> [0, 1, 1, 1, 2]
|
||||
assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]);
|
||||
}
|
||||
|
||||
fn build_cat_first_and_last() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
|
||||
let orig = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 1.5, 3.0],
|
||||
&[2.0, 1.5, 4.0],
|
||||
&[1.0, 1.5, 5.0],
|
||||
&[2.0, 1.5, 6.0],
|
||||
]);
|
||||
|
||||
let oh_enc = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
|
||||
&[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
|
||||
&[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
|
||||
&[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
|
||||
]);
|
||||
|
||||
(orig, oh_enc)
|
||||
}
|
||||
|
||||
fn build_fake_matrix() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
|
||||
// Categorical first and last
|
||||
let orig = DenseMatrix::from_2d_array(&[
|
||||
&[1.5, 1.0, 1.5, 3.0],
|
||||
&[1.5, 2.0, 1.5, 4.0],
|
||||
&[1.5, 1.0, 1.5, 5.0],
|
||||
&[1.5, 2.0, 1.5, 6.0],
|
||||
]);
|
||||
|
||||
let oh_enc = DenseMatrix::from_2d_array(&[
|
||||
&[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
|
||||
&[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
|
||||
&[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
|
||||
&[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
|
||||
]);
|
||||
|
||||
(orig, oh_enc)
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn hash_encode_f64_series() {
|
||||
let series = vec![3.0, 1.0, 2.0, 1.0];
|
||||
let hashable_series: Vec<CategoricalFloat> =
|
||||
series.iter().map(|v| v.to_category()).collect();
|
||||
let enc = CategoryMapper::from_positional_category_vec(hashable_series);
|
||||
let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]);
|
||||
let orig_val: f64 = inv.unwrap().into();
|
||||
assert_eq!(orig_val, 2.0);
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_fit() {
|
||||
let (x, _) = build_fake_matrix();
|
||||
let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
|
||||
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
|
||||
assert_eq!(oh_enc.category_mappers.len(), 2);
|
||||
|
||||
let num_cat: Vec<usize> = oh_enc
|
||||
.category_mappers
|
||||
.iter()
|
||||
.map(|a| a.num_categories())
|
||||
.collect();
|
||||
assert_eq!(num_cat, vec![2, 4]);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn matrix_transform_test() {
|
||||
let (x, expected_x) = build_fake_matrix();
|
||||
let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
|
||||
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
|
||||
let nm = oh_enc.transform(&x).unwrap();
|
||||
assert_eq!(nm, expected_x);
|
||||
|
||||
let (x, expected_x) = build_cat_first_and_last();
|
||||
let params = OneHotEncoderParams::from_cat_idx(&[0, 2]);
|
||||
let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
|
||||
let nm = oh_enc.transform(&x).unwrap();
|
||||
assert_eq!(nm, expected_x);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fail_on_bad_category() {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 1.5, 3.0],
|
||||
&[2.0, 1.5, 4.0],
|
||||
&[1.0, 1.5, 5.0],
|
||||
&[2.0, 1.5, 6.0],
|
||||
]);
|
||||
|
||||
let params = OneHotEncoderParams::from_cat_idx(&[1]);
|
||||
match OneHotEncoder::fit(&m, params) {
|
||||
Err(_) => {
|
||||
assert!(true);
|
||||
}
|
||||
_ => assert!(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
//! Traits to indicate that float variables can be viewed as categorical
|
||||
//! This module assumes
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
pub type CategoricalFloat = u16;
|
||||
|
||||
// pub struct CategoricalFloat(u16);
|
||||
const ERROR_MARGIN: f64 = 0.001;
|
||||
|
||||
pub trait Categorizable: RealNumber {
|
||||
type A;
|
||||
|
||||
fn to_category(self) -> CategoricalFloat;
|
||||
|
||||
fn is_valid(self) -> bool;
|
||||
}
|
||||
|
||||
impl Categorizable for f32 {
|
||||
type A = CategoricalFloat;
|
||||
|
||||
fn to_category(self) -> CategoricalFloat {
|
||||
self as CategoricalFloat
|
||||
}
|
||||
|
||||
fn is_valid(self) -> bool {
|
||||
let a = self.to_category();
|
||||
(a as f32 - self).abs() < (ERROR_MARGIN as f32)
|
||||
}
|
||||
}
|
||||
|
||||
impl Categorizable for f64 {
|
||||
type A = CategoricalFloat;
|
||||
|
||||
fn to_category(self) -> CategoricalFloat {
|
||||
self as CategoricalFloat
|
||||
}
|
||||
|
||||
fn is_valid(self) -> bool {
|
||||
let a = self.to_category();
|
||||
(a as f64 - self).abs() < ERROR_MARGIN
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
/// Transform a data matrix by replacing all categorical variables with their one-hot vector equivalents
|
||||
pub mod categorical;
|
||||
mod data_traits;
|
||||
/// Preprocess numerical matrices.
|
||||
pub mod numerical;
|
||||
/// Encode a series (column, array) of categorical variables as one-hot vectors
|
||||
pub mod series_encoder;
|
||||
@@ -0,0 +1,447 @@
|
||||
//! # Standard-Scaling For [RealNumber](../../math/num/trait.RealNumber.html) Matricies
|
||||
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by removing the mean and scaling to unit variance.
|
||||
//!
|
||||
//! ### Usage Example
|
||||
//! ```
|
||||
//! use smartcore::api::{Transformer, UnsupervisedEstimator};
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use smartcore::preprocessing::numerical;
|
||||
//! let data = DenseMatrix::from_2d_vec(&vec![
|
||||
//! vec![0.0, 0.0],
|
||||
//! vec![0.0, 0.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ]);
|
||||
//!
|
||||
//! let standard_scaler =
|
||||
//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default())
|
||||
//! .unwrap();
|
||||
//! let transformed_data = standard_scaler.transform(&data).unwrap();
|
||||
//! assert_eq!(
|
||||
//! transformed_data,
|
||||
//! DenseMatrix::from_2d_vec(&vec![
|
||||
//! vec![-1.0, -1.0],
|
||||
//! vec![-1.0, -1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ])
|
||||
//! );
|
||||
//! ```
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configure Behaviour of `StandardScaler`.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||
pub struct StandardScalerParameters {
|
||||
/// Optionaly adjust mean to be zero.
|
||||
with_mean: bool,
|
||||
/// Optionally adjust standard-deviation to be one.
|
||||
with_std: bool,
|
||||
}
|
||||
impl Default for StandardScalerParameters {
|
||||
fn default() -> Self {
|
||||
StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// With the `StandardScaler` data can be adjusted so
|
||||
/// that every column has a mean of zero and a standard
|
||||
/// deviation of one. This can improve model training for
|
||||
/// scaling sensitive models like neural network or nearest
|
||||
/// neighbors based models.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub struct StandardScaler<T: RealNumber> {
|
||||
means: Vec<T>,
|
||||
stds: Vec<T>,
|
||||
parameters: StandardScalerParameters,
|
||||
}
|
||||
impl<T: RealNumber> StandardScaler<T> {
|
||||
/// When the mean should be adjusted, the column mean
|
||||
/// should be kept. Otherwise, replace it by zero.
|
||||
fn adjust_column_mean(&self, mean: T) -> T {
|
||||
if self.parameters.with_mean {
|
||||
mean
|
||||
} else {
|
||||
T::zero()
|
||||
}
|
||||
}
|
||||
/// When the standard-deviation should be adjusted, the column
|
||||
/// standard-deviation should be kept. Otherwise, replace it by one.
|
||||
fn adjust_column_std(&self, std: T) -> T {
|
||||
if self.parameters.with_std {
|
||||
ensure_std_valid(std)
|
||||
} else {
|
||||
T::one()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure the standard deviation is valid. If it is
|
||||
/// negative or zero, it should replaced by the smallest
|
||||
/// positive value the type can have. That way we can savely
|
||||
/// divide the columns with the resulting scalar.
|
||||
fn ensure_std_valid<T: RealNumber>(value: T) -> T {
|
||||
value.max(T::min_positive_value())
|
||||
}
|
||||
|
||||
/// During `fit` the `StandardScaler` computes the column means and standard deviation.
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, StandardScalerParameters>
|
||||
for StandardScaler<T>
|
||||
{
|
||||
fn fit(x: &M, parameters: StandardScalerParameters) -> Result<Self, Failed> {
|
||||
Ok(Self {
|
||||
means: x.column_mean(),
|
||||
stds: x.std(0),
|
||||
parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// During `transform` the `StandardScaler` applies the summary statistics
|
||||
/// computed during `fit` to set the mean of each column to zero and the
|
||||
/// standard deviation to one.
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for StandardScaler<T> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
let (_, n_cols) = x.shape();
|
||||
if n_cols != self.means.len() {
|
||||
return Err(Failed::because(
|
||||
FailedError::TransformFailed,
|
||||
&format!(
|
||||
"Expected {} columns, but got {} columns instead.",
|
||||
self.means.len(),
|
||||
n_cols,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(build_matrix_from_columns(
|
||||
self.means
|
||||
.iter()
|
||||
.zip(self.stds.iter())
|
||||
.enumerate()
|
||||
.map(|(column_index, (column_mean, column_std))| {
|
||||
x.take_column(column_index)
|
||||
.sub_scalar(self.adjust_column_mean(*column_mean))
|
||||
.div_scalar(self.adjust_column_std(*column_std))
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// From a collection of matrices, that contain columns, construct
|
||||
/// a matrix by stacking the columns horizontally.
|
||||
fn build_matrix_from_columns<T, M>(columns: Vec<M>) -> Option<M>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
{
|
||||
if let Some(output_matrix) = columns.first().cloned() {
|
||||
return Some(
|
||||
columns
|
||||
.iter()
|
||||
.skip(1)
|
||||
.fold(output_matrix, |current_matrix, new_colum| {
|
||||
current_matrix.h_stack(new_colum)
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
mod helper_functionality {
|
||||
use super::super::{build_matrix_from_columns, ensure_std_valid};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn combine_three_columns() {
|
||||
assert_eq!(
|
||||
build_matrix_from_columns(vec![
|
||||
DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],])
|
||||
]),
|
||||
Some(DenseMatrix::from_2d_vec(&vec![
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0]
|
||||
]))
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negative_value_should_be_replace_with_minimal_positive_value() {
|
||||
assert_eq!(ensure_std_valid(-1.0), f64::MIN_POSITIVE)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_should_be_replace_with_minimal_positive_value() {
|
||||
assert_eq!(ensure_std_valid(0.0), f64::MIN_POSITIVE)
|
||||
}
|
||||
}
|
||||
mod standard_scaler {
|
||||
use super::super::{StandardScaler, StandardScalerParameters};
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
|
||||
#[test]
|
||||
fn dont_adjust_mean_if_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_mean(1.0),
|
||||
1.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn replace_mean_with_zero_if_not_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: false,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_mean(1.0),
|
||||
0.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn dont_adjust_std_if_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_std(10.0),
|
||||
10.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn replace_std_with_one_if_not_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: false
|
||||
}
|
||||
})
|
||||
.adjust_column_std(10.0),
|
||||
1.0
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper function to apply fit as well as transform at the same time.
|
||||
fn fit_transform_with_default_standard_scaler(
|
||||
values_to_be_transformed: &DenseMatrix<f64>,
|
||||
) -> DenseMatrix<f64> {
|
||||
StandardScaler::fit(
|
||||
values_to_be_transformed,
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap()
|
||||
.transform(values_to_be_transformed)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Fit transform with random generated values, expected values taken from
|
||||
/// sklearn.
|
||||
#[test]
|
||||
fn fit_transform_random_values() {
|
||||
let transformed_values =
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]));
|
||||
println!("{}", transformed_values);
|
||||
assert!(transformed_values.approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866],
|
||||
&[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
|
||||
&[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
|
||||
&[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
|
||||
]),
|
||||
1.0
|
||||
))
|
||||
}
|
||||
|
||||
/// Test `fit` and `transform` for a column with zero variance.
|
||||
#[test]
|
||||
fn fit_transform_with_zero_variance() {
|
||||
assert_eq!(
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0]
|
||||
])),
|
||||
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]),
|
||||
"When scaling values with zero variance, zero is expected as return value"
|
||||
)
|
||||
}
|
||||
|
||||
/// Test `fit` for columns with nice summary statistics.
|
||||
#[test]
|
||||
fn fit_for_simple_values() {
|
||||
assert_eq!(
|
||||
StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 1.0, 1.0],
|
||||
&[1.0, 2.0, 5.0],
|
||||
&[1.0, 1.0, 1.0],
|
||||
&[1.0, 2.0, 5.0]
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
),
|
||||
Ok(StandardScaler {
|
||||
means: vec![1.0, 1.5, 3.0],
|
||||
stds: vec![0.0, 0.5, 2.0],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
/// Test `fit` for random generated values.
|
||||
#[test]
|
||||
fn fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
fitted_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/// If `with_std` is set to `false` the values should not be
|
||||
/// adjusted to have a std of one.
|
||||
#[test]
|
||||
fn transform_without_std() {
|
||||
let standard_scaler = StandardScaler {
|
||||
means: vec![1.0, 3.0],
|
||||
stds: vec![1.0, 2.0],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: false,
|
||||
},
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]))
|
||||
)
|
||||
}
|
||||
|
||||
/// If `with_mean` is set to `false` the values should not be adjusted
|
||||
/// to have a mean of zero.
|
||||
#[test]
|
||||
fn transform_without_mean() {
|
||||
let standard_scaler = StandardScaler {
|
||||
means: vec![1.0, 2.0],
|
||||
stds: vec![2.0, 3.0],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: false,
|
||||
with_std: true,
|
||||
},
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
standard_scaler
|
||||
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||
)
|
||||
}
|
||||
|
||||
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
||||
/// serialized and deserialized.
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde_fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let deserialized_scaler: StandardScaler<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
deserialized_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
#![allow(clippy::ptr_arg)]
|
||||
//! # Series Encoder
|
||||
//! Encode a series of categorical features as a one-hot numeric array.
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
|
||||
/// ## Bi-directional map category <-> label num.
|
||||
/// Turn Hashable objects into a one-hot vectors or ordinal values.
|
||||
/// This struct encodes single class per exmample
|
||||
///
|
||||
/// You can fit_to_iter a category enumeration by passing an iterator of categories.
|
||||
/// category numbers will be assigned in the order they are encountered
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use std::collections::HashMap;
|
||||
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
|
||||
///
|
||||
/// let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
|
||||
/// let it = fake_categories.iter().map(|&a| a);
|
||||
/// let enc = CategoryMapper::<usize>::fit_to_iter(it);
|
||||
/// let oh_vec: Vec<f64> = enc.get_one_hot(&1).unwrap();
|
||||
/// // notice that 1 is actually a zero-th positional category
|
||||
/// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]);
|
||||
/// ```
|
||||
///
|
||||
/// You can also pass a predefined category enumeration such as a hashmap `HashMap<C, usize>` or a vector `Vec<C>`
|
||||
///
|
||||
///
|
||||
/// ```
|
||||
/// use std::collections::HashMap;
|
||||
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
|
||||
///
|
||||
/// let category_map: HashMap<&str, usize> =
|
||||
/// vec![("cat", 2), ("background",0), ("dog", 1)]
|
||||
/// .into_iter()
|
||||
/// .collect();
|
||||
/// let category_vec = vec!["background", "dog", "cat"];
|
||||
///
|
||||
/// let enc_lv = CategoryMapper::<&str>::from_positional_category_vec(category_vec);
|
||||
/// let enc_lm = CategoryMapper::<&str>::from_category_map(category_map);
|
||||
///
|
||||
/// // ["background", "dog", "cat"]
|
||||
/// println!("{:?}", enc_lv.get_categories());
|
||||
/// let lv: Vec<f64> = enc_lv.get_one_hot(&"dog").unwrap();
|
||||
/// let lm: Vec<f64> = enc_lm.get_one_hot(&"dog").unwrap();
|
||||
/// assert_eq!(lv, lm);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoryMapper<C> {
|
||||
category_map: HashMap<C, usize>,
|
||||
categories: Vec<C>,
|
||||
num_categories: usize,
|
||||
}
|
||||
|
||||
impl<C> CategoryMapper<C>
|
||||
where
|
||||
C: Hash + Eq + Clone,
|
||||
{
|
||||
/// Get the number of categories in the mapper
|
||||
pub fn num_categories(&self) -> usize {
|
||||
self.num_categories
|
||||
}
|
||||
|
||||
/// Fit an encoder to a lable iterator
|
||||
pub fn fit_to_iter(categories: impl Iterator<Item = C>) -> Self {
|
||||
let mut category_map: HashMap<C, usize> = HashMap::new();
|
||||
let mut category_num = 0usize;
|
||||
let mut unique_lables: Vec<C> = Vec::new();
|
||||
|
||||
for l in categories {
|
||||
if !category_map.contains_key(&l) {
|
||||
category_map.insert(l.clone(), category_num);
|
||||
unique_lables.push(l.clone());
|
||||
category_num += 1;
|
||||
}
|
||||
}
|
||||
Self {
|
||||
category_map,
|
||||
num_categories: category_num,
|
||||
categories: unique_lables,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an encoder from a predefined (category -> class number) map
|
||||
pub fn from_category_map(category_map: HashMap<C, usize>) -> Self {
|
||||
let mut _unique_cat: Vec<(C, usize)> =
|
||||
category_map.iter().map(|(k, v)| (k.clone(), *v)).collect();
|
||||
_unique_cat.sort_by(|a, b| a.1.cmp(&b.1));
|
||||
let categories: Vec<C> = _unique_cat.into_iter().map(|a| a.0).collect();
|
||||
Self {
|
||||
num_categories: categories.len(),
|
||||
categories,
|
||||
category_map,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an encoder from a predefined positional category-class num vector
|
||||
pub fn from_positional_category_vec(categories: Vec<C>) -> Self {
|
||||
let category_map: HashMap<C, usize> = categories
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(v, k)| (k.clone(), v))
|
||||
.collect();
|
||||
Self {
|
||||
num_categories: categories.len(),
|
||||
category_map,
|
||||
categories,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get label num of a category
|
||||
pub fn get_num(&self, category: &C) -> Option<&usize> {
|
||||
self.category_map.get(category)
|
||||
}
|
||||
|
||||
/// Return category corresponding to label num
|
||||
pub fn get_cat(&self, num: usize) -> &C {
|
||||
&self.categories[num]
|
||||
}
|
||||
|
||||
/// List all categories (position = category number)
|
||||
pub fn get_categories(&self) -> &[C] {
|
||||
&self.categories[..]
|
||||
}
|
||||
|
||||
/// Get one-hot encoding of the category
|
||||
pub fn get_one_hot<U, V>(&self, category: &C) -> Option<V>
|
||||
where
|
||||
U: RealNumber,
|
||||
V: BaseVector<U>,
|
||||
{
|
||||
self.get_num(category)
|
||||
.map(|&idx| make_one_hot::<U, V>(idx, self.num_categories))
|
||||
}
|
||||
|
||||
/// Invert one-hot vector, back to the category
|
||||
pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed>
|
||||
where
|
||||
U: RealNumber,
|
||||
V: BaseVector<U>,
|
||||
{
|
||||
let pos = U::one();
|
||||
|
||||
let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx));
|
||||
|
||||
let s: Vec<usize> = oh_it
|
||||
.enumerate()
|
||||
.filter_map(|(idx, v)| if v == pos { Some(idx) } else { None })
|
||||
.collect();
|
||||
|
||||
if s.len() == 1 {
|
||||
let idx = s[0];
|
||||
return Ok(self.get_cat(idx).clone());
|
||||
}
|
||||
let pos_entries = format!(
|
||||
"Expected a single positive entry, {} entires found",
|
||||
s.len()
|
||||
);
|
||||
Err(Failed::transform(&pos_entries[..]))
|
||||
}
|
||||
|
||||
/// Get ordinal encoding of the catergory
|
||||
pub fn get_ordinal<U>(&self, category: &C) -> Option<U>
|
||||
where
|
||||
U: RealNumber,
|
||||
{
|
||||
match self.get_num(category) {
|
||||
None => None,
|
||||
Some(&idx) => U::from_usize(idx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make a one-hot encoded vector from a categorical variable
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use smartcore::preprocessing::series_encoder::make_one_hot;
|
||||
/// let one_hot: Vec<f64> = make_one_hot(2, 3);
|
||||
/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]);
|
||||
/// ```
|
||||
pub fn make_one_hot<T, V>(category_idx: usize, num_categories: usize) -> V
|
||||
where
|
||||
T: RealNumber,
|
||||
V: BaseVector<T>,
|
||||
{
|
||||
let pos = T::one();
|
||||
let mut z = V::zeros(num_categories);
|
||||
z.set(category_idx, pos);
|
||||
z
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn from_categories() {
|
||||
let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
|
||||
let it = fake_categories.iter().map(|&a| a);
|
||||
let enc = CategoryMapper::<usize>::fit_to_iter(it);
|
||||
let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
|
||||
None => panic!("Wrong categories"),
|
||||
Some(v) => v,
|
||||
};
|
||||
let res: Vec<f64> = vec![1f64, 0f64, 0f64, 0f64, 0f64];
|
||||
assert_eq!(oh_vec, res);
|
||||
}
|
||||
|
||||
fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
|
||||
let fake_category_pos = vec!["background", "dog", "cat"];
|
||||
let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
|
||||
enc
|
||||
}
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ordinal_encoding() {
|
||||
let enc = build_fake_str_enc();
|
||||
assert_eq!(1f64, enc.get_ordinal::<f64>(&"dog").unwrap())
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn category_map_and_vec() {
|
||||
let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)]
|
||||
.into_iter()
|
||||
.collect();
|
||||
let enc = CategoryMapper::<&str>::from_category_map(category_map);
|
||||
let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
|
||||
None => panic!("Wrong categories"),
|
||||
Some(v) => v,
|
||||
};
|
||||
let res: Vec<f64> = vec![0f64, 1f64, 0f64];
|
||||
assert_eq!(oh_vec, res);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn positional_categories_vec() {
|
||||
let enc = build_fake_str_enc();
|
||||
let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
|
||||
None => panic!("Wrong categories"),
|
||||
Some(v) => v,
|
||||
};
|
||||
let res: Vec<f64> = vec![0.0, 1.0, 0.0];
|
||||
assert_eq!(oh_vec, res);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn invert_label_test() {
|
||||
let enc = build_fake_str_enc();
|
||||
let res: Vec<f64> = vec![0.0, 1.0, 0.0];
|
||||
let lab = enc.invert_one_hot(res).unwrap();
|
||||
assert_eq!(lab, "dog");
|
||||
if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
|
||||
let pos_entries = format!("Expected a single positive entry, 0 entires found");
|
||||
assert_eq!(e, Failed::transform(&pos_entries[..]));
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn test_many_categorys() {
|
||||
let enc = build_fake_str_enc();
|
||||
let cat_it = ["dog", "cat", "fish", "background"].iter().cloned();
|
||||
let res: Vec<Option<Vec<f64>>> = cat_it.map(|v| enc.get_one_hot(&v)).collect();
|
||||
let v = vec![
|
||||
Some(vec![0.0, 1.0, 0.0]),
|
||||
Some(vec![0.0, 0.0, 1.0]),
|
||||
None,
|
||||
Some(vec![1.0, 0.0, 0.0]),
|
||||
];
|
||||
assert_eq!(res, v)
|
||||
}
|
||||
}
|
||||
+200
@@ -0,0 +1,200 @@
|
||||
//! # Support Vector Machines
|
||||
//!
|
||||
//! Support Vector Machines (SVM) is one of the most performant off-the-shelf machine learning algorithms.
|
||||
//! SVM is based on the [Vapnik–Chervonenkiy theory](https://en.wikipedia.org/wiki/Vapnik%E2%80%93Chervonenkis_theory) that was developed during 1960–1990 by Vladimir Vapnik and Alexey Chervonenkiy.
|
||||
//!
|
||||
//! SVM splits data into two sets using a maximal-margin decision boundary, \\(f(x)\\). For regression, the algorithm uses a value of the function \\(f(x)\\) to predict a target value.
|
||||
//! To classify a new point, algorithm calculates a sign of the decision function to see where the new point is relative to the boundary.
|
||||
//!
|
||||
//! SVM is memory efficient since it uses only a subset of training data to find a decision boundary. This subset is called support vectors.
|
||||
//!
|
||||
//! In SVM distance between a data point and the support vectors is defined by the kernel function.
|
||||
//! SmartCore supports multiple kernel functions but you can always define a new kernel function by implementing the `Kernel` trait. Not all functions can be a kernel.
|
||||
//! Building a new kernel requires a good mathematical understanding of the [Mercer theorem](https://en.wikipedia.org/wiki/Mercer%27s_theorem)
|
||||
//! that gives necessary and sufficient condition for a function to be a kernel function.
|
||||
//!
|
||||
//! Pre-defined kernel functions:
|
||||
//!
|
||||
//! * *Linear*, \\( K(x, x') = \langle x, x' \rangle\\)
|
||||
//! * *Polynomial*, \\( K(x, x') = (\gamma\langle x, x' \rangle + r)^d\\), where \\(d\\) is polynomial degree, \\(\gamma\\) is a kernel coefficient and \\(r\\) is an independent term in the kernel function.
|
||||
//! * *RBF (Gaussian)*, \\( K(x, x') = e^{-\gamma \lVert x - x' \rVert ^2} \\), where \\(\gamma\\) is kernel coefficient
|
||||
//! * *Sigmoid (hyperbolic tangent)*, \\( K(x, x') = \tanh ( \gamma \langle x, x' \rangle + r ) \\), where \\(\gamma\\) is kernel coefficient and \\(r\\) is an independent term in the kernel function.
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
pub mod svc;
|
||||
pub mod svr;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Defines a kernel function
|
||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>> {
|
||||
/// Apply kernel function to x_i and x_j
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T;
|
||||
}
|
||||
|
||||
/// Pre-defined kernel functions
|
||||
pub struct Kernels {}
|
||||
|
||||
impl Kernels {
|
||||
/// Linear kernel
|
||||
pub fn linear() -> LinearKernel {
|
||||
LinearKernel {}
|
||||
}
|
||||
|
||||
/// Radial basis function kernel (Gaussian)
|
||||
pub fn rbf<T: RealNumber>(gamma: T) -> RBFKernel<T> {
|
||||
RBFKernel { gamma }
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
/// * `degree` - degree of the polynomial
|
||||
/// * `gamma` - kernel coefficient
|
||||
/// * `coef0` - independent term in kernel function
|
||||
pub fn polynomial<T: RealNumber>(degree: T, gamma: T, coef0: T) -> PolynomialKernel<T> {
|
||||
PolynomialKernel {
|
||||
degree,
|
||||
gamma,
|
||||
coef0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
/// * `degree` - degree of the polynomial
|
||||
/// * `n_features` - number of features in vector
|
||||
pub fn polynomial_with_degree<T: RealNumber>(
|
||||
degree: T,
|
||||
n_features: usize,
|
||||
) -> PolynomialKernel<T> {
|
||||
let coef0 = T::one();
|
||||
let gamma = T::one() / T::from_usize(n_features).unwrap();
|
||||
Kernels::polynomial(degree, gamma, coef0)
|
||||
}
|
||||
|
||||
/// Sigmoid kernel
|
||||
/// * `gamma` - kernel coefficient
|
||||
/// * `coef0` - independent term in kernel function
|
||||
pub fn sigmoid<T: RealNumber>(gamma: T, coef0: T) -> SigmoidKernel<T> {
|
||||
SigmoidKernel { gamma, coef0 }
|
||||
}
|
||||
|
||||
/// Sigmoid kernel
|
||||
/// * `gamma` - kernel coefficient
|
||||
pub fn sigmoid_with_gamma<T: RealNumber>(gamma: T) -> SigmoidKernel<T> {
|
||||
SigmoidKernel {
|
||||
gamma,
|
||||
coef0: T::one(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear Kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearKernel {}
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RBFKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PolynomialKernel<T: RealNumber> {
|
||||
/// degree of the polynomial
|
||||
pub degree: T,
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
/// independent term in kernel function
|
||||
pub coef0: T,
|
||||
}
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SigmoidKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
/// independent term in kernel function
|
||||
pub coef0: T,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for LinearKernel {
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T {
|
||||
x_i.dot(x_j)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for RBFKernel<T> {
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T {
|
||||
let v_diff = x_i.sub(x_j);
|
||||
(-self.gamma * v_diff.mul(&v_diff).sum()).exp()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for PolynomialKernel<T> {
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T {
|
||||
let dot = x_i.dot(x_j);
|
||||
(self.gamma * dot + self.coef0).powf(self.degree)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for SigmoidKernel<T> {
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T {
|
||||
let dot = x_i.dot(x_j);
|
||||
(self.gamma * dot + self.coef0).tanh()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn linear_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
assert_eq!(32f64, Kernels::linear().apply(&v1, &v2));
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn rbf_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn polynomial_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
assert!(
|
||||
(4913f64 - Kernels::polynomial(3.0, 0.5, 1.0).apply(&v1, &v2)).abs()
|
||||
< std::f64::EPSILON
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn sigmoid_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
assert!((0.3969f64 - Kernels::sigmoid(0.01, 0.1).apply(&v1, &v2)).abs() < 1e-4);
|
||||
}
|
||||
}
|
||||
+904
@@ -0,0 +1,904 @@
|
||||
//! # Support Vector Classifier.
|
||||
//!
|
||||
//! Support Vector Classifier (SVC) is a binary classifier that uses an optimal hyperplane to separate the points in the input variable space by their class.
|
||||
//!
|
||||
//! During training, SVC chooses a Maximal-Margin hyperplane that can separate all training instances with the largest margin.
|
||||
//! The margin is calculated as the perpendicular distance from the boundary to only the closest points. Hence, only these points are relevant in defining
|
||||
//! the hyperplane and in the construction of the classifier. These points are called the support vectors.
|
||||
//!
|
||||
//! While SVC selects a hyperplane with the largest margin it allows some points in the training data to violate the separating boundary.
|
||||
//! The parameter `C` > 0 gives you control over how SVC will handle violating points. The bigger the value of this parameter the more we penalize the algorithm
|
||||
//! for incorrectly classified points. In other words, setting this parameter to a small value will result in a classifier that allows for a big number
|
||||
//! of misclassified samples. Mathematically, SVC optimization problem can be defined as:
|
||||
//!
|
||||
//! \\[\underset{w, \zeta}{minimize} \space \space \frac{1}{2} \lVert \vec{w} \rVert^2 + C\sum_{i=1}^m \zeta_i \\]
|
||||
//!
|
||||
//! subject to:
|
||||
//!
|
||||
//! \\[y_i(\langle\vec{w}, \vec{x}_i \rangle + b) \geq 1 - \zeta_i \\]
|
||||
//! \\[\zeta_i \geq 0 for \space any \space i = 1, ... , m\\]
|
||||
//!
|
||||
//! Where \\( m \\) is a number of training samples, \\( y_i \\) is a label value (either 1 or -1) and \\(\langle\vec{w}, \vec{x}_i \rangle + b\\) is a decision boundary.
|
||||
//!
|
||||
//! To solve this optimization problem, SmartCore uses an [approximate SVM solver](https://leon.bottou.org/projects/lasvm).
|
||||
//! The optimizer reaches accuracies similar to that of a real SVM after performing two passes through the training examples. You can choose the number of passes
|
||||
//! through the data that the algorithm takes by changing the `epoch` parameter of the classifier.
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::svm::Kernels;
|
||||
//! use smartcore::svm::svc::{SVC, SVCParameters};
|
||||
//!
|
||||
//! // Iris dataset
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[5.1, 3.5, 1.4, 0.2],
|
||||
//! &[4.9, 3.0, 1.4, 0.2],
|
||||
//! &[4.7, 3.2, 1.3, 0.2],
|
||||
//! &[4.6, 3.1, 1.5, 0.2],
|
||||
//! &[5.0, 3.6, 1.4, 0.2],
|
||||
//! &[5.4, 3.9, 1.7, 0.4],
|
||||
//! &[4.6, 3.4, 1.4, 0.3],
|
||||
//! &[5.0, 3.4, 1.5, 0.2],
|
||||
//! &[4.4, 2.9, 1.4, 0.2],
|
||||
//! &[4.9, 3.1, 1.5, 0.1],
|
||||
//! &[7.0, 3.2, 4.7, 1.4],
|
||||
//! &[6.4, 3.2, 4.5, 1.5],
|
||||
//! &[6.9, 3.1, 4.9, 1.5],
|
||||
//! &[5.5, 2.3, 4.0, 1.3],
|
||||
//! &[6.5, 2.8, 4.6, 1.5],
|
||||
//! &[5.7, 2.8, 4.5, 1.3],
|
||||
//! &[6.3, 3.3, 4.7, 1.6],
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
//!
|
||||
//! let svc = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap();
|
||||
//!
|
||||
//! let y_hat = svc.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Support Vector Machines", Kowalczyk A., 2017](https://www.svm-tutorial.com/2017/10/support-vector-machines-succinctly-released/)
|
||||
//! * ["Fast Kernel Classifiers with Online and Active Learning", Bordes A., Ertekin S., Weston J., Bottou L., 2005](https://www.jmlr.org/papers/volume6/bordes05a/bordes05a.pdf)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVC Parameters
|
||||
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Number of epochs.
|
||||
pub epoch: usize,
|
||||
/// Regularization parameter.
|
||||
pub c: T,
|
||||
/// Tolerance for stopping criterion.
|
||||
pub tol: T,
|
||||
/// The kernel function.
|
||||
pub kernel: K,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
feature = "serde",
|
||||
serde(bound(
|
||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
))
|
||||
)]
|
||||
/// Support Vector Classifier
|
||||
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
classes: Vec<T>,
|
||||
kernel: K,
|
||||
instances: Vec<M::RowVector>,
|
||||
w: Vec<T>,
|
||||
b: T,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||
index: usize,
|
||||
x: V,
|
||||
alpha: T,
|
||||
grad: T,
|
||||
cmin: T,
|
||||
cmax: T,
|
||||
k: T,
|
||||
}
|
||||
|
||||
struct Cache<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
kernel: &'a K,
|
||||
data: HashMap<(usize, usize), T>,
|
||||
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
struct Optimizer<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
x: &'a M,
|
||||
y: &'a M::RowVector,
|
||||
parameters: &'a SVCParameters<T, M, K>,
|
||||
svmin: usize,
|
||||
svmax: usize,
|
||||
gmin: T,
|
||||
gmax: T,
|
||||
tau: T,
|
||||
sv: Vec<SupportVector<T, M::RowVector>>,
|
||||
kernel: &'a K,
|
||||
recalculate_minmax_grad: bool,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M, K> {
|
||||
/// Number of epochs.
|
||||
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
||||
self.epoch = epoch;
|
||||
self
|
||||
}
|
||||
/// Regularization parameter.
|
||||
pub fn with_c(mut self, c: T) -> Self {
|
||||
self.c = c;
|
||||
self
|
||||
}
|
||||
/// Tolerance for stopping criterion.
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<KK: Kernel<T, M::RowVector>>(&self, kernel: KK) -> SVCParameters<T, M, KK> {
|
||||
SVCParameters {
|
||||
epoch: self.epoch,
|
||||
c: self.c,
|
||||
tol: self.tol,
|
||||
kernel,
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
SVCParameters {
|
||||
epoch: 2,
|
||||
c: T::one(),
|
||||
tol: T::from_f64(1e-3).unwrap(),
|
||||
kernel: Kernels::linear(),
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>>
|
||||
SupervisedEstimator<M, M::RowVector, SVCParameters<T, M, K>> for SVC<T, M, K>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: SVCParameters<T, M, K>) -> Result<Self, Failed> {
|
||||
SVC::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Predictor<M, M::RowVector>
|
||||
for SVC<T, M, K>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
||||
/// Fits SVC to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - class labels
|
||||
/// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: SVCParameters<T, M, K>,
|
||||
) -> Result<SVC<T, M, K>, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
if n != y.len() {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let classes = y.unique();
|
||||
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes {}",
|
||||
classes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Make sure class labels are either 1 or -1
|
||||
let mut y = y.clone();
|
||||
for i in 0..y.len() {
|
||||
let y_v = y.get(i);
|
||||
if y_v != -T::one() || y_v != T::one() {
|
||||
match y_v == classes[0] {
|
||||
true => y.set(i, -T::one()),
|
||||
false => y.set(i, T::one()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let optimizer = Optimizer::new(x, &y, ¶meters.kernel, ¶meters);
|
||||
|
||||
let (support_vectors, weight, b) = optimizer.optimize();
|
||||
|
||||
Ok(SVC {
|
||||
classes,
|
||||
kernel: parameters.kernel,
|
||||
instances: support_vectors,
|
||||
w: weight,
|
||||
b,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predicts estimated class labels from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut y_hat = self.decision_function(x)?;
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
let cls_idx = match y_hat.get(i) > T::zero() {
|
||||
false => self.classes[0],
|
||||
true => self.classes[1],
|
||||
};
|
||||
|
||||
y_hat.set(i, cls_idx);
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
/// Evaluates the decision function for the rows in `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let mut y_hat = M::RowVector::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||
let mut f = self.b;
|
||||
|
||||
for i in 0..self.instances.len() {
|
||||
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> PartialEq for SVC<T, M, K> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if (self.b - other.b).abs() > T::epsilon() * T::two()
|
||||
|| self.w.len() != other.w.len()
|
||||
|| self.instances.len() != other.instances.len()
|
||||
{
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.w.len() {
|
||||
if (self.w[i] - other.w[i]).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for i in 0..self.instances.len() {
|
||||
if !self.instances[i].approximate_eq(&other.instances[i], T::epsilon()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> SupportVector<T, V> {
|
||||
fn new<K: Kernel<T, V>>(i: usize, x: V, y: T, g: T, c: T, k: &K) -> SupportVector<T, V> {
|
||||
let k_v = k.apply(&x, &x);
|
||||
let (cmin, cmax) = if y > T::zero() {
|
||||
(T::zero(), c)
|
||||
} else {
|
||||
(-c, T::zero())
|
||||
};
|
||||
SupportVector {
|
||||
index: i,
|
||||
x,
|
||||
grad: g,
|
||||
k: k_v,
|
||||
alpha: T::zero(),
|
||||
cmin,
|
||||
cmax,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Cache<'a, T, M, K> {
|
||||
fn new(kernel: &'a K) -> Cache<'a, T, M, K> {
|
||||
Cache {
|
||||
kernel,
|
||||
data: HashMap::new(),
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&mut self, i: &SupportVector<T, M::RowVector>, j: &SupportVector<T, M::RowVector>) -> T {
|
||||
let idx_i = i.index;
|
||||
let idx_j = j.index;
|
||||
#[allow(clippy::or_fun_call)]
|
||||
let entry = self
|
||||
.data
|
||||
.entry((idx_i, idx_j))
|
||||
.or_insert(self.kernel.apply(&i.x, &j.x));
|
||||
*entry
|
||||
}
|
||||
|
||||
fn insert(&mut self, key: (usize, usize), value: T) {
|
||||
self.data.insert(key, value);
|
||||
}
|
||||
|
||||
fn drop(&mut self, idxs_to_drop: HashSet<usize>) {
|
||||
self.data.retain(|k, _| !idxs_to_drop.contains(&k.0));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a, T, M, K> {
|
||||
fn new(
|
||||
x: &'a M,
|
||||
y: &'a M::RowVector,
|
||||
kernel: &'a K,
|
||||
parameters: &'a SVCParameters<T, M, K>,
|
||||
) -> Optimizer<'a, T, M, K> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
Optimizer {
|
||||
x,
|
||||
y,
|
||||
parameters,
|
||||
svmin: 0,
|
||||
svmax: 0,
|
||||
gmin: T::max_value(),
|
||||
gmax: T::min_value(),
|
||||
tau: T::from_f64(1e-12).unwrap(),
|
||||
sv: Vec::with_capacity(n),
|
||||
kernel,
|
||||
recalculate_minmax_grad: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn optimize(mut self) -> (Vec<M::RowVector>, Vec<T>, T) {
|
||||
let (n, _) = self.x.shape();
|
||||
|
||||
let mut cache = Cache::new(self.kernel);
|
||||
|
||||
self.initialize(&mut cache);
|
||||
|
||||
let tol = self.parameters.tol;
|
||||
let good_enough = T::from_i32(1000).unwrap();
|
||||
|
||||
for _ in 0..self.parameters.epoch {
|
||||
for i in Self::permutate(n) {
|
||||
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
self.find_min_max_gradient();
|
||||
if self.gmax - self.gmin < good_enough {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.finish(&mut cache);
|
||||
|
||||
let mut support_vectors: Vec<M::RowVector> = Vec::new();
|
||||
let mut w: Vec<T> = Vec::new();
|
||||
|
||||
let b = (self.gmax + self.gmin) / T::two();
|
||||
|
||||
for v in self.sv {
|
||||
support_vectors.push(v.x);
|
||||
w.push(v.alpha);
|
||||
}
|
||||
|
||||
(support_vectors, w, b)
|
||||
}
|
||||
|
||||
fn initialize(&mut self, cache: &mut Cache<'_, T, M, K>) {
|
||||
let (n, _) = self.x.shape();
|
||||
let few = 5;
|
||||
let mut cp = 0;
|
||||
let mut cn = 0;
|
||||
|
||||
for i in Self::permutate(n) {
|
||||
if self.y.get(i) == T::one() && cp < few {
|
||||
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
||||
cp += 1;
|
||||
}
|
||||
} else if self.y.get(i) == -T::one()
|
||||
&& cn < few
|
||||
&& self.process(i, self.x.get_row(i), self.y.get(i), cache)
|
||||
{
|
||||
cn += 1;
|
||||
}
|
||||
|
||||
if cp >= few && cn >= few {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&mut self, i: usize, x: M::RowVector, y: T, cache: &mut Cache<'_, T, M, K>) -> bool {
|
||||
for j in 0..self.sv.len() {
|
||||
if self.sv[j].index == i {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let mut g = y;
|
||||
|
||||
let mut cache_values: Vec<((usize, usize), T)> = Vec::new();
|
||||
|
||||
for v in self.sv.iter() {
|
||||
let k = self.kernel.apply(&v.x, &x);
|
||||
cache_values.push(((i, v.index), k));
|
||||
g -= v.alpha * k;
|
||||
}
|
||||
|
||||
self.find_min_max_gradient();
|
||||
|
||||
if self.gmin < self.gmax
|
||||
&& ((y > T::zero() && g < self.gmin) || (y < T::zero() && g > self.gmax))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
for v in cache_values {
|
||||
cache.insert(v.0, v.1);
|
||||
}
|
||||
|
||||
self.sv.insert(
|
||||
0,
|
||||
SupportVector::new(i, x, y, g, self.parameters.c, self.kernel),
|
||||
);
|
||||
|
||||
if y > T::zero() {
|
||||
self.smo(None, Some(0), T::zero(), cache);
|
||||
} else {
|
||||
self.smo(Some(0), None, T::zero(), cache);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn reprocess(&mut self, tol: T, cache: &mut Cache<'_, T, M, K>) -> bool {
|
||||
let status = self.smo(None, None, tol, cache);
|
||||
self.clean(cache);
|
||||
status
|
||||
}
|
||||
|
||||
fn finish(&mut self, cache: &mut Cache<'_, T, M, K>) {
|
||||
let mut max_iter = self.sv.len();
|
||||
|
||||
while self.smo(None, None, self.parameters.tol, cache) && max_iter > 0 {
|
||||
max_iter -= 1;
|
||||
}
|
||||
|
||||
self.clean(cache);
|
||||
}
|
||||
|
||||
fn find_min_max_gradient(&mut self) {
|
||||
if !self.recalculate_minmax_grad {
|
||||
return;
|
||||
}
|
||||
|
||||
self.gmin = T::max_value();
|
||||
self.gmax = T::min_value();
|
||||
|
||||
for i in 0..self.sv.len() {
|
||||
let v = &self.sv[i];
|
||||
let g = v.grad;
|
||||
let a = v.alpha;
|
||||
if g < self.gmin && a > v.cmin {
|
||||
self.gmin = g;
|
||||
self.svmin = i;
|
||||
}
|
||||
if g > self.gmax && a < v.cmax {
|
||||
self.gmax = g;
|
||||
self.svmax = i;
|
||||
}
|
||||
}
|
||||
|
||||
self.recalculate_minmax_grad = false
|
||||
}
|
||||
|
||||
fn clean(&mut self, cache: &mut Cache<'_, T, M, K>) {
|
||||
self.find_min_max_gradient();
|
||||
|
||||
let gmax = self.gmax;
|
||||
let gmin = self.gmin;
|
||||
|
||||
let mut idxs_to_drop: HashSet<usize> = HashSet::new();
|
||||
|
||||
self.sv.retain(|v| {
|
||||
if v.alpha == T::zero()
|
||||
&& ((v.grad >= gmax && T::zero() >= v.cmax)
|
||||
|| (v.grad <= gmin && T::zero() <= v.cmin))
|
||||
{
|
||||
idxs_to_drop.insert(v.index);
|
||||
return false;
|
||||
};
|
||||
true
|
||||
});
|
||||
|
||||
cache.drop(idxs_to_drop);
|
||||
self.recalculate_minmax_grad = true;
|
||||
}
|
||||
|
||||
fn permutate(n: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
}
|
||||
|
||||
fn select_pair(
|
||||
&mut self,
|
||||
idx_1: Option<usize>,
|
||||
idx_2: Option<usize>,
|
||||
cache: &mut Cache<'_, T, M, K>,
|
||||
) -> Option<(usize, usize, T)> {
|
||||
match (idx_1, idx_2) {
|
||||
(None, None) => {
|
||||
if self.gmax > -self.gmin {
|
||||
self.select_pair(None, Some(self.svmax), cache)
|
||||
} else {
|
||||
self.select_pair(Some(self.svmin), None, cache)
|
||||
}
|
||||
}
|
||||
(Some(idx_1), None) => {
|
||||
let sv1 = &self.sv[idx_1];
|
||||
let mut idx_2 = None;
|
||||
let mut k_v_12 = None;
|
||||
let km = sv1.k;
|
||||
let gm = sv1.grad;
|
||||
let mut best = T::zero();
|
||||
for i in 0..self.sv.len() {
|
||||
let v = &self.sv[i];
|
||||
let z = v.grad - gm;
|
||||
let k = cache.get(sv1, v);
|
||||
let mut curv = km + v.k - T::two() * k;
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
let mu = z / curv;
|
||||
if (mu > T::zero() && v.alpha < v.cmax) || (mu < T::zero() && v.alpha > v.cmin)
|
||||
{
|
||||
let gain = z * mu;
|
||||
if gain > best {
|
||||
best = gain;
|
||||
idx_2 = Some(i);
|
||||
k_v_12 = Some(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
idx_2.map(|idx_2| {
|
||||
(
|
||||
idx_1,
|
||||
idx_2,
|
||||
k_v_12.unwrap_or_else(|| {
|
||||
self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
(None, Some(idx_2)) => {
|
||||
let mut idx_1 = None;
|
||||
let sv2 = &self.sv[idx_2];
|
||||
let mut k_v_12 = None;
|
||||
let km = sv2.k;
|
||||
let gm = sv2.grad;
|
||||
let mut best = T::zero();
|
||||
for i in 0..self.sv.len() {
|
||||
let v = &self.sv[i];
|
||||
let z = gm - v.grad;
|
||||
let k = cache.get(sv2, v);
|
||||
let mut curv = km + v.k - T::two() * k;
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
let mu = z / curv;
|
||||
if (mu > T::zero() && v.alpha > v.cmin) || (mu < T::zero() && v.alpha < v.cmax)
|
||||
{
|
||||
let gain = z * mu;
|
||||
if gain > best {
|
||||
best = gain;
|
||||
idx_1 = Some(i);
|
||||
k_v_12 = Some(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
idx_1.map(|idx_1| {
|
||||
(
|
||||
idx_1,
|
||||
idx_2,
|
||||
k_v_12.unwrap_or_else(|| {
|
||||
self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x)
|
||||
}),
|
||||
)
|
||||
})
|
||||
}
|
||||
(Some(idx_1), Some(idx_2)) => Some((
|
||||
idx_1,
|
||||
idx_2,
|
||||
self.kernel.apply(&self.sv[idx_1].x, &self.sv[idx_2].x),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn smo(
|
||||
&mut self,
|
||||
idx_1: Option<usize>,
|
||||
idx_2: Option<usize>,
|
||||
tol: T,
|
||||
cache: &mut Cache<'_, T, M, K>,
|
||||
) -> bool {
|
||||
match self.select_pair(idx_1, idx_2, cache) {
|
||||
Some((idx_1, idx_2, k_v_12)) => {
|
||||
let mut curv = self.sv[idx_1].k + self.sv[idx_2].k - T::two() * k_v_12;
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
let mut step = (self.sv[idx_2].grad - self.sv[idx_1].grad) / curv;
|
||||
|
||||
if step >= T::zero() {
|
||||
let mut ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmin;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_2].cmax - self.sv[idx_2].alpha;
|
||||
if ostep < step {
|
||||
step = ostep;
|
||||
}
|
||||
} else {
|
||||
let mut ostep = self.sv[idx_2].cmin - self.sv[idx_2].alpha;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
}
|
||||
ostep = self.sv[idx_1].alpha - self.sv[idx_1].cmax;
|
||||
if ostep > step {
|
||||
step = ostep;
|
||||
}
|
||||
}
|
||||
|
||||
self.update(idx_1, idx_2, step, cache);
|
||||
|
||||
self.gmax - self.gmin > tol
|
||||
}
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn update(&mut self, v1: usize, v2: usize, step: T, cache: &mut Cache<'_, T, M, K>) {
|
||||
self.sv[v1].alpha -= step;
|
||||
self.sv[v2].alpha += step;
|
||||
|
||||
for i in 0..self.sv.len() {
|
||||
let k2 = cache.get(&self.sv[v2], &self.sv[i]);
|
||||
let k1 = cache.get(&self.sv[v1], &self.sv[i]);
|
||||
self.sv[i].grad -= step * (k2 - k1);
|
||||
}
|
||||
|
||||
self.recalculate_minmax_grad = true;
|
||||
self.find_min_max_gradient();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::accuracy;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let y_hat = SVC::fit(
|
||||
&x,
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(Kernels::linear()),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_decision_function() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
|
||||
|
||||
let x2 = DenseMatrix::from_2d_array(&[
|
||||
&[3.0, 3.0],
|
||||
&[4.0, 4.0],
|
||||
&[6.0, 6.0],
|
||||
&[10.0, 10.0],
|
||||
&[1.0, 1.0],
|
||||
&[0.0, 0.0],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![0., 0., 1., 1.];
|
||||
|
||||
let y_hat = SVC::fit(
|
||||
&x,
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(Kernels::linear()),
|
||||
)
|
||||
.and_then(|lr| lr.decision_function(&x2))
|
||||
.unwrap();
|
||||
|
||||
// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
|
||||
// so the score should increase as points get further away from that line
|
||||
println!("{:?}", y_hat);
|
||||
assert!(y_hat[1] < y_hat[2]);
|
||||
assert!(y_hat[2] < y_hat[3]);
|
||||
|
||||
// for negative scores the score should decrease
|
||||
assert!(y_hat[4] > y_hat[5]);
|
||||
|
||||
// y_hat[0] is on the line, so its score should be close to 0
|
||||
assert!(y_hat[0].abs() <= 0.1);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_predict_rbf() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
1.,
|
||||
];
|
||||
|
||||
let y_hat = SVC::fit(
|
||||
&x,
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(1.0)
|
||||
.with_kernel(Kernels::rbf(0.7)),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn svc_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let svc = SVC::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_svc: SVC<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svc, deserialized_svc);
|
||||
}
|
||||
}
|
||||
+608
@@ -0,0 +1,608 @@
|
||||
//! # Epsilon-Support Vector Regression.
|
||||
//!
|
||||
//! Support Vector Regression (SVR) is a popular algorithm used for regression that uses the same principle as SVM.
|
||||
//!
|
||||
//! Just like [SVC](../svc/index.html) SVR finds optimal decision boundary, \\(f(x)\\) that separates all training instances with the largest margin.
|
||||
//! Unlike SVC, in \\(\epsilon\\)-SVR regression the goal is to find a function \\(f(x)\\) that has at most \\(\epsilon\\) deviation from the
|
||||
//! known targets \\(y_i\\) for all the training data. To find this function, we need to find solution to this optimization problem:
|
||||
//!
|
||||
//! \\[\underset{w, \zeta}{minimize} \space \space \frac{1}{2} \lVert \vec{w} \rVert^2 + C\sum_{i=1}^m \zeta_i \\]
|
||||
//!
|
||||
//! subject to:
|
||||
//!
|
||||
//! \\[\lvert y_i - \langle\vec{w}, \vec{x}_i \rangle - b \rvert \leq \epsilon + \zeta_i \\]
|
||||
//! \\[\lvert \langle\vec{w}, \vec{x}_i \rangle + b - y_i \rvert \leq \epsilon + \zeta_i \\]
|
||||
//! \\[\zeta_i \geq 0 for \space any \space i = 1, ... , m\\]
|
||||
//!
|
||||
//! Where \\( m \\) is a number of training samples, \\( y_i \\) is a target value and \\(\langle\vec{w}, \vec{x}_i \rangle + b\\) is a decision boundary.
|
||||
//!
|
||||
//! The parameter `C` > 0 determines the trade-off between the flatness of \\(f(x)\\) and the amount up to which deviations larger than \\(\epsilon\\) are tolerated
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linear::linear_regression::*;
|
||||
//! use smartcore::svm::*;
|
||||
//! use smartcore::svm::svr::{SVR, SVRParameters};
|
||||
//!
|
||||
//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
//! &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
//! &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
//! &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
//! &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
//! &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
//! &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
//!
|
||||
//! let svr = SVR::fit(&x, &y, SVRParameters::default().with_eps(2.0).with_c(10.0)).unwrap();
|
||||
//!
|
||||
//! let y_hat = svr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["Support Vector Machines", Kowalczyk A., 2017](https://www.svm-tutorial.com/2017/10/support-vector-machines-succinctly-released/)
|
||||
//! * ["A Fast Algorithm for Training Support Vector Machines", Platt J.C., 1998](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tr-98-14.pdf)
|
||||
//! * ["Working Set Selection Using Second Order Information for Training Support Vector Machines", Rong-En Fan et al., 2005](https://www.jmlr.org/papers/volume6/fan05a/fan05a.pdf)
|
||||
//! * ["A tutorial on support vector regression", Smola A.J., Scholkopf B., 2003](https://alex.smola.org/papers/2004/SmoSch04.pdf)
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVR Parameters
|
||||
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: T,
|
||||
/// Regularization parameter.
|
||||
pub c: T,
|
||||
/// Tolerance for stopping criterion.
|
||||
pub tol: T,
|
||||
/// The kernel function.
|
||||
pub kernel: K,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
feature = "serde",
|
||||
serde(bound(
|
||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
))
|
||||
)]
|
||||
|
||||
/// Epsilon-Support Vector Regression
|
||||
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
kernel: K,
|
||||
instances: Vec<M::RowVector>,
|
||||
w: Vec<T>,
|
||||
b: T,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||
index: usize,
|
||||
x: V,
|
||||
alpha: [T; 2],
|
||||
grad: [T; 2],
|
||||
k: T,
|
||||
}
|
||||
|
||||
/// Sequential Minimal Optimization algorithm
|
||||
struct Optimizer<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
tol: T,
|
||||
c: T,
|
||||
svmin: usize,
|
||||
svmax: usize,
|
||||
gmin: T,
|
||||
gmax: T,
|
||||
gminindex: usize,
|
||||
gmaxindex: usize,
|
||||
tau: T,
|
||||
sv: Vec<SupportVector<T, M::RowVector>>,
|
||||
kernel: &'a K,
|
||||
}
|
||||
|
||||
struct Cache<T: Clone> {
|
||||
data: Vec<RefCell<Option<Vec<T>>>>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVRParameters<T, M, K> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub fn with_eps(mut self, eps: T) -> Self {
|
||||
self.eps = eps;
|
||||
self
|
||||
}
|
||||
/// Regularization parameter.
|
||||
pub fn with_c(mut self, c: T) -> Self {
|
||||
self.c = c;
|
||||
self
|
||||
}
|
||||
/// Tolerance for stopping criterion.
|
||||
pub fn with_tol(mut self, tol: T) -> Self {
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<KK: Kernel<T, M::RowVector>>(&self, kernel: KK) -> SVRParameters<T, M, KK> {
|
||||
SVRParameters {
|
||||
eps: self.eps,
|
||||
c: self.c,
|
||||
tol: self.tol,
|
||||
kernel,
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVRParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
SVRParameters {
|
||||
eps: T::from_f64(0.1).unwrap(),
|
||||
c: T::one(),
|
||||
tol: T::from_f64(1e-3).unwrap(),
|
||||
kernel: Kernels::linear(),
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>>
|
||||
SupervisedEstimator<M, M::RowVector, SVRParameters<T, M, K>> for SVR<T, M, K>
|
||||
{
|
||||
fn fit(x: &M, y: &M::RowVector, parameters: SVRParameters<T, M, K>) -> Result<Self, Failed> {
|
||||
SVR::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Predictor<M, M::RowVector>
|
||||
for SVR<T, M, K>
|
||||
{
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
|
||||
/// Fits SVR to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target values
|
||||
/// * `kernel` - the kernel function
|
||||
/// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: SVRParameters<T, M, K>,
|
||||
) -> Result<SVR<T, M, K>, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
if n != y.len() {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let optimizer = Optimizer::new(x, y, ¶meters.kernel, ¶meters);
|
||||
|
||||
let (support_vectors, weight, b) = optimizer.smo();
|
||||
|
||||
Ok(SVR {
|
||||
kernel: parameters.kernel,
|
||||
instances: support_vectors,
|
||||
w: weight,
|
||||
b,
|
||||
})
|
||||
}
|
||||
|
||||
/// Predict target values from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
let mut y_hat = M::RowVector::zeros(n);
|
||||
|
||||
for i in 0..n {
|
||||
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||
}
|
||||
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
pub(crate) fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||
let mut f = self.b;
|
||||
|
||||
for i in 0..self.instances.len() {
|
||||
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> PartialEq for SVR<T, M, K> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if (self.b - other.b).abs() > T::epsilon() * T::two()
|
||||
|| self.w.len() != other.w.len()
|
||||
|| self.instances.len() != other.instances.len()
|
||||
{
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.w.len() {
|
||||
if (self.w[i] - other.w[i]).abs() > T::epsilon() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for i in 0..self.instances.len() {
|
||||
if !self.instances[i].approximate_eq(&other.instances[i], T::epsilon()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, V: BaseVector<T>> SupportVector<T, V> {
|
||||
fn new<K: Kernel<T, V>>(i: usize, x: V, y: T, eps: T, k: &K) -> SupportVector<T, V> {
|
||||
let k_v = k.apply(&x, &x);
|
||||
SupportVector {
|
||||
index: i,
|
||||
x,
|
||||
grad: [eps + y, eps - y],
|
||||
k: k_v,
|
||||
alpha: [T::zero(), T::zero()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a, T, M, K> {
|
||||
fn new(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
kernel: &'a K,
|
||||
parameters: &SVRParameters<T, M, K>,
|
||||
) -> Optimizer<'a, T, M, K> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
let mut support_vectors: Vec<SupportVector<T, M::RowVector>> = Vec::with_capacity(n);
|
||||
|
||||
for i in 0..n {
|
||||
support_vectors.push(SupportVector::new(
|
||||
i,
|
||||
x.get_row(i),
|
||||
y.get(i),
|
||||
parameters.eps,
|
||||
kernel,
|
||||
));
|
||||
}
|
||||
|
||||
Optimizer {
|
||||
tol: parameters.tol,
|
||||
c: parameters.c,
|
||||
svmin: 0,
|
||||
svmax: 0,
|
||||
gmin: T::max_value(),
|
||||
gmax: T::min_value(),
|
||||
gminindex: 0,
|
||||
gmaxindex: 0,
|
||||
tau: T::from_f64(1e-12).unwrap(),
|
||||
sv: support_vectors,
|
||||
kernel,
|
||||
}
|
||||
}
|
||||
|
||||
fn find_min_max_gradient(&mut self) {
|
||||
self.gmin = T::max_value();
|
||||
self.gmax = T::min_value();
|
||||
|
||||
for i in 0..self.sv.len() {
|
||||
let v = &self.sv[i];
|
||||
let g = -v.grad[0];
|
||||
let a = v.alpha[0];
|
||||
if g < self.gmin && a > T::zero() {
|
||||
self.gmin = g;
|
||||
self.gminindex = 0;
|
||||
self.svmin = i;
|
||||
}
|
||||
if g > self.gmax && a < self.c {
|
||||
self.gmax = g;
|
||||
self.gmaxindex = 0;
|
||||
self.svmax = i;
|
||||
}
|
||||
|
||||
let g = v.grad[1];
|
||||
let a = v.alpha[1];
|
||||
if g < self.gmin && a < self.c {
|
||||
self.gmin = g;
|
||||
self.gminindex = 1;
|
||||
self.svmin = i;
|
||||
}
|
||||
if g > self.gmax && a > T::zero() {
|
||||
self.gmax = g;
|
||||
self.gmaxindex = 1;
|
||||
self.svmax = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Solvs the quadratic programming (QP) problem that arises during the training of support-vector machines (SVM) algorithm.
|
||||
/// Returns:
|
||||
/// * support vectors
|
||||
/// * hyperplane parameters: w and b
|
||||
fn smo(mut self) -> (Vec<M::RowVector>, Vec<T>, T) {
|
||||
let cache: Cache<T> = Cache::new(self.sv.len());
|
||||
|
||||
self.find_min_max_gradient();
|
||||
|
||||
while self.gmax - self.gmin > self.tol {
|
||||
let v1 = self.svmax;
|
||||
let i = self.gmaxindex;
|
||||
let old_alpha_i = self.sv[v1].alpha[i];
|
||||
|
||||
let k1 = cache.get(self.sv[v1].index, || {
|
||||
self.sv
|
||||
.iter()
|
||||
.map(|vi| self.kernel.apply(&self.sv[v1].x, &vi.x))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let mut v2 = self.svmin;
|
||||
let mut j = self.gminindex;
|
||||
let mut old_alpha_j = self.sv[v2].alpha[j];
|
||||
|
||||
let mut best = T::zero();
|
||||
let gi = if i == 0 {
|
||||
-self.sv[v1].grad[0]
|
||||
} else {
|
||||
self.sv[v1].grad[1]
|
||||
};
|
||||
for jj in 0..self.sv.len() {
|
||||
let v = &self.sv[jj];
|
||||
let mut curv = self.sv[v1].k + v.k - T::two() * k1[v.index];
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
let mut gj = -v.grad[0];
|
||||
if v.alpha[0] > T::zero() && gj < gi {
|
||||
let gain = -((gi - gj) * (gi - gj)) / curv;
|
||||
if gain < best {
|
||||
best = gain;
|
||||
v2 = jj;
|
||||
j = 0;
|
||||
old_alpha_j = self.sv[v2].alpha[0];
|
||||
}
|
||||
}
|
||||
|
||||
gj = v.grad[1];
|
||||
if v.alpha[1] < self.c && gj < gi {
|
||||
let gain = -((gi - gj) * (gi - gj)) / curv;
|
||||
if gain < best {
|
||||
best = gain;
|
||||
v2 = jj;
|
||||
j = 1;
|
||||
old_alpha_j = self.sv[v2].alpha[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let k2 = cache.get(self.sv[v2].index, || {
|
||||
self.sv
|
||||
.iter()
|
||||
.map(|vi| self.kernel.apply(&self.sv[v2].x, &vi.x))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let mut curv = self.sv[v1].k + self.sv[v2].k - T::two() * k1[self.sv[v2].index];
|
||||
if curv <= T::zero() {
|
||||
curv = self.tau;
|
||||
}
|
||||
|
||||
if i != j {
|
||||
let delta = (-self.sv[v1].grad[i] - self.sv[v2].grad[j]) / curv;
|
||||
let diff = self.sv[v1].alpha[i] - self.sv[v2].alpha[j];
|
||||
self.sv[v1].alpha[i] += delta;
|
||||
self.sv[v2].alpha[j] += delta;
|
||||
|
||||
if diff > T::zero() {
|
||||
if self.sv[v2].alpha[j] < T::zero() {
|
||||
self.sv[v2].alpha[j] = T::zero();
|
||||
self.sv[v1].alpha[i] = diff;
|
||||
}
|
||||
} else if self.sv[v1].alpha[i] < T::zero() {
|
||||
self.sv[v1].alpha[i] = T::zero();
|
||||
self.sv[v2].alpha[j] = -diff;
|
||||
}
|
||||
|
||||
if diff > T::zero() {
|
||||
if self.sv[v1].alpha[i] > self.c {
|
||||
self.sv[v1].alpha[i] = self.c;
|
||||
self.sv[v2].alpha[j] = self.c - diff;
|
||||
}
|
||||
} else if self.sv[v2].alpha[j] > self.c {
|
||||
self.sv[v2].alpha[j] = self.c;
|
||||
self.sv[v1].alpha[i] = self.c + diff;
|
||||
}
|
||||
} else {
|
||||
let delta = (self.sv[v1].grad[i] - self.sv[v2].grad[j]) / curv;
|
||||
let sum = self.sv[v1].alpha[i] + self.sv[v2].alpha[j];
|
||||
self.sv[v1].alpha[i] -= delta;
|
||||
self.sv[v2].alpha[j] += delta;
|
||||
|
||||
if sum > self.c {
|
||||
if self.sv[v1].alpha[i] > self.c {
|
||||
self.sv[v1].alpha[i] = self.c;
|
||||
self.sv[v2].alpha[j] = sum - self.c;
|
||||
}
|
||||
} else if self.sv[v2].alpha[j] < T::zero() {
|
||||
self.sv[v2].alpha[j] = T::zero();
|
||||
self.sv[v1].alpha[i] = sum;
|
||||
}
|
||||
|
||||
if sum > self.c {
|
||||
if self.sv[v2].alpha[j] > self.c {
|
||||
self.sv[v2].alpha[j] = self.c;
|
||||
self.sv[v1].alpha[i] = sum - self.c;
|
||||
}
|
||||
} else if self.sv[v1].alpha[i] < T::zero() {
|
||||
self.sv[v1].alpha[i] = T::zero();
|
||||
self.sv[v2].alpha[j] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
let delta_alpha_i = self.sv[v1].alpha[i] - old_alpha_i;
|
||||
let delta_alpha_j = self.sv[v2].alpha[j] - old_alpha_j;
|
||||
|
||||
let si = T::two() * T::from_usize(i).unwrap() - T::one();
|
||||
let sj = T::two() * T::from_usize(j).unwrap() - T::one();
|
||||
for v in self.sv.iter_mut() {
|
||||
v.grad[0] -= si * k1[v.index] * delta_alpha_i + sj * k2[v.index] * delta_alpha_j;
|
||||
v.grad[1] += si * k1[v.index] * delta_alpha_i + sj * k2[v.index] * delta_alpha_j;
|
||||
}
|
||||
|
||||
self.find_min_max_gradient();
|
||||
}
|
||||
|
||||
let b = -(self.gmax + self.gmin) / T::two();
|
||||
|
||||
let mut support_vectors: Vec<M::RowVector> = Vec::new();
|
||||
let mut w: Vec<T> = Vec::new();
|
||||
|
||||
for v in self.sv {
|
||||
if v.alpha[0] != v.alpha[1] {
|
||||
support_vectors.push(v.x);
|
||||
w.push(v.alpha[1] - v.alpha[0]);
|
||||
}
|
||||
}
|
||||
|
||||
(support_vectors, w, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> Cache<T> {
|
||||
fn new(n: usize) -> Cache<T> {
|
||||
Cache {
|
||||
data: vec![RefCell::new(None); n],
|
||||
}
|
||||
}
|
||||
|
||||
fn get<F: Fn() -> Vec<T>>(&self, i: usize, or: F) -> Ref<'_, Vec<T>> {
|
||||
if self.data[i].borrow().is_none() {
|
||||
self.data[i].replace(Some(or()));
|
||||
}
|
||||
Ref::map(self.data[i].borrow(), |v| v.as_ref().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_squared_error;
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svr_fit_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat = SVR::fit(&x, &y, SVRParameters::default().with_eps(2.0).with_c(10.0))
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(mean_squared_error(&y_hat, &y) < 2.5);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn svr_serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
&[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
&[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let svr = SVR::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svr, deserialized_svr);
|
||||
}
|
||||
}
|
||||
@@ -68,14 +68,18 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of Decision Tree
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
/// Split criteria to use when building a tree.
|
||||
@@ -89,7 +93,8 @@ pub struct DecisionTreeClassifierParameters {
|
||||
}
|
||||
|
||||
/// Decision Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
@@ -99,7 +104,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||
}
|
||||
|
||||
/// The function to measure the quality of a split.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
/// [Gini index](../decision_tree_classifier/index.html)
|
||||
Gini,
|
||||
@@ -109,9 +115,10 @@ pub enum SplitCriterion {
|
||||
ClassificationError,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
_index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
split_value: Option<T>,
|
||||
@@ -126,7 +133,7 @@ impl<T: RealNumber> PartialEq for DecisionTreeClassifier<T> {
|
||||
|| self.num_classes != other.num_classes
|
||||
|| self.nodes.len() != other.nodes.len()
|
||||
{
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
@@ -138,7 +145,7 @@ impl<T: RealNumber> PartialEq for DecisionTreeClassifier<T> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -160,6 +167,29 @@ impl<T: RealNumber> PartialEq for Node<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl DecisionTreeClassifierParameters {
|
||||
/// Split criteria to use when building a tree.
|
||||
pub fn with_criterion(mut self, criterion: SplitCriterion) -> Self {
|
||||
self.criterion = criterion;
|
||||
self
|
||||
}
|
||||
/// The maximum depth of the tree.
|
||||
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||
self.max_depth = Some(max_depth);
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to be at a leaf node.
|
||||
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
|
||||
self.min_samples_leaf = min_samples_leaf;
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
|
||||
self.min_samples_split = min_samples_split;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
DecisionTreeClassifierParameters {
|
||||
@@ -174,8 +204,8 @@ impl Default for DecisionTreeClassifierParameters {
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: usize) -> Self {
|
||||
Node {
|
||||
index: index,
|
||||
output: output,
|
||||
_index: index,
|
||||
output,
|
||||
split_feature: 0,
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
@@ -187,86 +217,105 @@ impl<T: RealNumber> Node<T> {
|
||||
|
||||
struct NodeVisitor<'a, T: RealNumber, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a Vec<usize>,
|
||||
y: &'a [usize],
|
||||
node: usize,
|
||||
samples: Vec<usize>,
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
order: &'a [Vec<usize>],
|
||||
true_child_output: usize,
|
||||
false_child_output: usize,
|
||||
level: u16,
|
||||
phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
fn impurity<T: RealNumber>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
|
||||
fn impurity<T: RealNumber>(criterion: &SplitCriterion, count: &[usize], n: usize) -> T {
|
||||
let mut impurity = T::zero();
|
||||
|
||||
match criterion {
|
||||
SplitCriterion::Gini => {
|
||||
impurity = T::one();
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
|
||||
impurity = impurity - p * p;
|
||||
for count_i in count.iter() {
|
||||
if *count_i > 0 {
|
||||
let p = T::from(*count_i).unwrap() / T::from(n).unwrap();
|
||||
impurity -= p * p;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SplitCriterion::Entropy => {
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
|
||||
impurity = impurity - p * p.log2();
|
||||
for count_i in count.iter() {
|
||||
if *count_i > 0 {
|
||||
let p = T::from(*count_i).unwrap() / T::from(n).unwrap();
|
||||
impurity -= p * p.log2();
|
||||
}
|
||||
}
|
||||
}
|
||||
SplitCriterion::ClassificationError => {
|
||||
for i in 0..count.len() {
|
||||
if count[i] > 0 {
|
||||
impurity = impurity.max(T::from(count[i]).unwrap() / T::from(n).unwrap());
|
||||
for count_i in count.iter() {
|
||||
if *count_i > 0 {
|
||||
impurity = impurity.max(T::from(*count_i).unwrap() / T::from(n).unwrap());
|
||||
}
|
||||
}
|
||||
impurity = (T::one() - impurity).abs();
|
||||
}
|
||||
}
|
||||
|
||||
return impurity;
|
||||
impurity
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
fn new(
|
||||
node_id: usize,
|
||||
samples: Vec<usize>,
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
order: &'a [Vec<usize>],
|
||||
x: &'a M,
|
||||
y: &'a Vec<usize>,
|
||||
y: &'a [usize],
|
||||
level: u16,
|
||||
) -> Self {
|
||||
NodeVisitor {
|
||||
x: x,
|
||||
y: y,
|
||||
x,
|
||||
y,
|
||||
node: node_id,
|
||||
samples: samples,
|
||||
order: order,
|
||||
samples,
|
||||
order,
|
||||
true_child_output: 0,
|
||||
false_child_output: 0,
|
||||
level: level,
|
||||
level,
|
||||
phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
|
||||
pub(crate) fn which_max(x: &[usize]) -> usize {
|
||||
let mut m = x[0];
|
||||
let mut which = 0;
|
||||
|
||||
for i in 1..x.len() {
|
||||
if x[i] > m {
|
||||
m = x[i];
|
||||
for (i, x_i) in x.iter().enumerate().skip(1) {
|
||||
if *x_i > m {
|
||||
m = *x_i;
|
||||
which = i;
|
||||
}
|
||||
}
|
||||
|
||||
return which;
|
||||
which
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, DecisionTreeClassifierParameters>
|
||||
for DecisionTreeClassifier<T>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
DecisionTreeClassifier::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for DecisionTreeClassifier<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
@@ -280,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeClassifier::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -289,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
@@ -304,9 +361,9 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
|
||||
for i in 0..y_ncols {
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
let yc = y_m.get(0, i);
|
||||
yi[i] = classes.iter().position(|c| yc == *c).unwrap();
|
||||
*yi_i = classes.iter().position(|c| yc == *c).unwrap();
|
||||
}
|
||||
|
||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||
@@ -325,24 +382,24 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
}
|
||||
|
||||
let mut tree = DecisionTreeClassifier {
|
||||
nodes: nodes,
|
||||
parameters: parameters,
|
||||
nodes,
|
||||
parameters,
|
||||
num_classes: k,
|
||||
classes: classes,
|
||||
classes,
|
||||
depth: 0,
|
||||
};
|
||||
|
||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
|
||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &yi, 1);
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<T, M>> = LinkedList::new();
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -364,7 +421,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = 0;
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -376,25 +433,26 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
let node = &self.nodes[node_id];
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else if x.get(row, node.split_feature)
|
||||
<= node.split_value.unwrap_or_else(T::nan)
|
||||
{
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
}
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
return result;
|
||||
result
|
||||
}
|
||||
|
||||
fn find_best_cutoff<M: Matrix<T>>(
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<T, M>,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n_rows, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -431,23 +489,20 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let parent_impurity = impurity(&self.parameters.criterion, &count, n);
|
||||
|
||||
let mut variables = vec![0; n_attr];
|
||||
for i in 0..n_attr {
|
||||
variables[i] = i;
|
||||
}
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
for j in 0..mtry {
|
||||
for variable in variables.iter().take(mtry) {
|
||||
self.find_best_split(
|
||||
visitor,
|
||||
n,
|
||||
&count,
|
||||
&mut false_count,
|
||||
parent_impurity,
|
||||
variables[j],
|
||||
*variable,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -456,10 +511,10 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
fn find_best_split<M: Matrix<T>>(
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<T, M>,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
n: usize,
|
||||
count: &Vec<usize>,
|
||||
false_count: &mut Vec<usize>,
|
||||
count: &[usize],
|
||||
false_count: &mut [usize],
|
||||
parent_impurity: T,
|
||||
j: usize,
|
||||
) {
|
||||
@@ -496,7 +551,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
- T::from(tc).unwrap() / T::from(n).unwrap()
|
||||
* impurity(&self.parameters.criterion, &true_count, tc)
|
||||
- T::from(fc).unwrap() / T::from(n).unwrap()
|
||||
* impurity(&self.parameters.criterion, &false_count, fc);
|
||||
* impurity(&self.parameters.criterion, false_count, fc);
|
||||
|
||||
if self.nodes[visitor.node].split_score == Option::None
|
||||
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
||||
@@ -521,19 +576,20 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
let mut fc = 0;
|
||||
let mut true_samples: Vec<usize> = vec![0; n];
|
||||
|
||||
for i in 0..n {
|
||||
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature)
|
||||
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
|
||||
<= self.nodes[visitor.node].split_value.unwrap_or_else(T::nan)
|
||||
{
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
*true_sample = visitor.samples[i];
|
||||
tc += *true_sample;
|
||||
visitor.samples[i] = 0;
|
||||
} else {
|
||||
fc += visitor.samples[i];
|
||||
@@ -569,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -582,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
@@ -595,6 +651,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
assert!(
|
||||
@@ -611,6 +668,7 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -663,6 +721,7 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_baloons() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -699,7 +758,9 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 1., 1., 0.],
|
||||
|
||||
@@ -63,14 +63,18 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of Regression Tree
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
/// The maximum depth of the tree.
|
||||
@@ -82,16 +86,18 @@ pub struct DecisionTreeRegressorParameters {
|
||||
}
|
||||
|
||||
/// Regression Tree
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
depth: u16,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<T: RealNumber> {
|
||||
index: usize,
|
||||
_index: usize,
|
||||
output: T,
|
||||
split_feature: usize,
|
||||
split_value: Option<T>,
|
||||
@@ -100,6 +106,24 @@ struct Node<T: RealNumber> {
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
|
||||
impl DecisionTreeRegressorParameters {
|
||||
/// The maximum depth of the tree.
|
||||
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||
self.max_depth = Some(max_depth);
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to be at a leaf node.
|
||||
pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
|
||||
self.min_samples_leaf = min_samples_leaf;
|
||||
self
|
||||
}
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
|
||||
self.min_samples_split = min_samples_split;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeRegressorParameters {
|
||||
fn default() -> Self {
|
||||
DecisionTreeRegressorParameters {
|
||||
@@ -113,8 +137,8 @@ impl Default for DecisionTreeRegressorParameters {
|
||||
impl<T: RealNumber> Node<T> {
|
||||
fn new(index: usize, output: T) -> Self {
|
||||
Node {
|
||||
index: index,
|
||||
output: output,
|
||||
_index: index,
|
||||
output,
|
||||
split_feature: 0,
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
@@ -144,14 +168,14 @@ impl<T: RealNumber> PartialEq for Node<T> {
|
||||
impl<T: RealNumber> PartialEq for DecisionTreeRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
|
||||
return false;
|
||||
false
|
||||
} else {
|
||||
for i in 0..self.nodes.len() {
|
||||
if self.nodes[i] != other.nodes[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -161,7 +185,7 @@ struct NodeVisitor<'a, T: RealNumber, M: Matrix<T>> {
|
||||
y: &'a M,
|
||||
node: usize,
|
||||
samples: Vec<usize>,
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
order: &'a [Vec<usize>],
|
||||
true_child_output: T,
|
||||
false_child_output: T,
|
||||
level: u16,
|
||||
@@ -171,24 +195,43 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
fn new(
|
||||
node_id: usize,
|
||||
samples: Vec<usize>,
|
||||
order: &'a Vec<Vec<usize>>,
|
||||
order: &'a [Vec<usize>],
|
||||
x: &'a M,
|
||||
y: &'a M,
|
||||
level: u16,
|
||||
) -> Self {
|
||||
NodeVisitor {
|
||||
x: x,
|
||||
y: y,
|
||||
x,
|
||||
y,
|
||||
node: node_id,
|
||||
samples: samples,
|
||||
order: order,
|
||||
samples,
|
||||
order,
|
||||
true_child_output: T::zero(),
|
||||
false_child_output: T::zero(),
|
||||
level: level,
|
||||
level,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>>
|
||||
SupervisedEstimator<M, M::RowVector, DecisionTreeRegressorParameters>
|
||||
for DecisionTreeRegressor<T>
|
||||
{
|
||||
fn fit(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
) -> Result<Self, Failed> {
|
||||
DecisionTreeRegressor::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for DecisionTreeRegressor<T> {
|
||||
fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
self.predict(x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
/// Build a decision tree regressor from the training data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -200,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeRegressor::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -209,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -219,9 +270,9 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let mut n = 0;
|
||||
let mut sum = T::zero();
|
||||
for i in 0..y_ncols {
|
||||
n += samples[i];
|
||||
sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i);
|
||||
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
|
||||
n += *sample_i;
|
||||
sum += T::from(*sample_i).unwrap() * y_m.get(0, i);
|
||||
}
|
||||
|
||||
let root = Node::new(0, sum / T::from(n).unwrap());
|
||||
@@ -233,22 +284,22 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
}
|
||||
|
||||
let mut tree = DecisionTreeRegressor {
|
||||
nodes: nodes,
|
||||
parameters: parameters,
|
||||
nodes,
|
||||
parameters,
|
||||
depth: 0,
|
||||
};
|
||||
|
||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
|
||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &y_m, 1);
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<T, M>> = LinkedList::new();
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -270,7 +321,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
let mut result = T::zero();
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -282,25 +333,26 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
let node = &self.nodes[node_id];
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else if x.get(row, node.split_feature)
|
||||
<= node.split_value.unwrap_or_else(T::nan)
|
||||
{
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
}
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
}
|
||||
}
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
return result;
|
||||
result
|
||||
}
|
||||
|
||||
fn find_best_cutoff<M: Matrix<T>>(
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<T, M>,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
mtry: usize,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (_, n_attr) = visitor.x.shape();
|
||||
|
||||
@@ -312,20 +364,17 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let sum = self.nodes[visitor.node].output * T::from(n).unwrap();
|
||||
|
||||
let mut variables = vec![0; n_attr];
|
||||
for i in 0..n_attr {
|
||||
variables[i] = i;
|
||||
}
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
variables.shuffle(rng);
|
||||
}
|
||||
|
||||
let parent_gain =
|
||||
T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
|
||||
|
||||
for j in 0..mtry {
|
||||
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
||||
for variable in variables.iter().take(mtry) {
|
||||
self.find_best_split(visitor, n, sum, parent_gain, *variable);
|
||||
}
|
||||
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
@@ -333,7 +382,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
fn find_best_split<M: Matrix<T>>(
|
||||
&mut self,
|
||||
visitor: &mut NodeVisitor<T, M>,
|
||||
visitor: &mut NodeVisitor<'_, T, M>,
|
||||
n: usize,
|
||||
sum: T,
|
||||
parent_gain: T,
|
||||
@@ -348,8 +397,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
|
||||
prevx = visitor.x.get(*i, j);
|
||||
true_count += visitor.samples[*i];
|
||||
true_sum =
|
||||
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
true_sum += T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -360,8 +408,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
{
|
||||
prevx = visitor.x.get(*i, j);
|
||||
true_count += visitor.samples[*i];
|
||||
true_sum =
|
||||
true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
true_sum += T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -384,7 +431,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
}
|
||||
|
||||
prevx = visitor.x.get(*i, j);
|
||||
true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
true_sum += T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
|
||||
true_count += visitor.samples[*i];
|
||||
}
|
||||
}
|
||||
@@ -395,19 +442,20 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
mut visitor: NodeVisitor<'a, T, M>,
|
||||
mtry: usize,
|
||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||
rng: &mut impl Rng,
|
||||
) -> bool {
|
||||
let (n, _) = visitor.x.shape();
|
||||
let mut tc = 0;
|
||||
let mut fc = 0;
|
||||
let mut true_samples: Vec<usize> = vec![0; n];
|
||||
|
||||
for i in 0..n {
|
||||
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature)
|
||||
<= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
|
||||
<= self.nodes[visitor.node].split_value.unwrap_or_else(T::nan)
|
||||
{
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
*true_sample = visitor.samples[i];
|
||||
tc += *true_sample;
|
||||
visitor.samples[i] = 0;
|
||||
} else {
|
||||
fc += visitor.samples[i];
|
||||
@@ -443,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(true_visitor);
|
||||
}
|
||||
|
||||
@@ -456,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
visitor.level + 1,
|
||||
);
|
||||
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||
visitor_queue.push_back(false_visitor);
|
||||
}
|
||||
|
||||
@@ -469,6 +517,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
@@ -543,7 +592,9 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
|
||||
Reference in New Issue
Block a user