Compare commits
114 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
20ca5c9647 | ||
|
|
3fe916988f | ||
|
|
d305406dfd | ||
|
|
3d2f4f71fa | ||
|
|
a1c56a859e | ||
|
|
d905ebea15 | ||
|
|
b482acdc8d | ||
|
|
b4a807eb9f | ||
|
|
ff456df0a4 | ||
|
|
322610c7fb | ||
|
|
70df9a8b49 | ||
|
|
7ea620e6fd | ||
|
|
db5edcf67a | ||
|
|
8297cbe67e | ||
|
|
38c9b5ad2f | ||
|
|
820201e920 | ||
|
|
389b0e8e67 | ||
|
|
f93286ffbd | ||
|
|
12c102d02b | ||
|
|
521dab49ef | ||
|
|
3bf8813946 | ||
|
|
7830946ecb | ||
|
|
813c7ab233 | ||
|
|
4397c91570 | ||
|
|
14245e15ad | ||
|
|
d0a4ccbe20 | ||
|
|
85b9fde9a7 | ||
|
|
d239314967 | ||
|
|
4bae62ab2f | ||
|
|
e8cba343ca | ||
|
|
0b3bf946df | ||
|
|
763a8370eb | ||
|
|
1208051fb5 | ||
|
|
436d0a089f | ||
|
|
92265cc979 | ||
|
|
513d3898c9 | ||
|
|
4b654b25ac | ||
|
|
5a2e1f1262 | ||
|
|
377d5d0b06 | ||
|
|
9ce448379a | ||
|
|
c295a0d1bb | ||
|
|
703dc9688b | ||
|
|
790979a26d | ||
|
|
162bed2aa2 | ||
|
|
5ed5772a4e | ||
|
|
d9814c0918 | ||
|
|
7f44b93838 | ||
|
|
02200ae1e3 | ||
|
|
3dc5336514 | ||
|
|
abeff7926e | ||
|
|
1395cc6518 | ||
|
|
4335ee5a56 | ||
|
|
4c1dbc3327 | ||
|
|
a920959ae3 | ||
|
|
6d58dbe2a2 | ||
|
|
023b449ff1 | ||
|
|
cd44f1d515 | ||
|
|
1b42f8a396 | ||
|
|
c0be45b667 | ||
|
|
0e9c517b1a | ||
|
|
fed11f005c | ||
|
|
483a21bec0 | ||
|
|
4fb2625a33 | ||
|
|
a30802ec43 | ||
|
|
4af69878e0 | ||
|
|
745d0b570e | ||
|
|
6b5bed6092 | ||
|
|
af6ec2d402 | ||
|
|
828df4e338 | ||
|
|
374dfeceb9 | ||
|
|
3cc20fd400 | ||
|
|
700d320724 | ||
|
|
ef06f45638 | ||
|
|
237b1160b1 | ||
|
|
d31145b4fe | ||
|
|
19ff6df84c | ||
|
|
228b54baf7 | ||
|
|
03b9f76e9f | ||
|
|
a882741e12 | ||
|
|
f4b5936dcf | ||
|
|
863be5ef75 | ||
|
|
ca0816db97 | ||
|
|
2f03c1d6d7 | ||
|
|
c987d39d43 | ||
|
|
fd6b2e8014 | ||
|
|
cd5611079c | ||
|
|
dd39433ff8 | ||
|
|
3dc8a42832 | ||
|
|
3480e728af | ||
|
|
f91b1f9942 | ||
|
|
5c400f40d2 | ||
|
|
408b97d8aa | ||
|
|
6109fc5211 | ||
|
|
19088b682a | ||
|
|
244a724445 | ||
|
|
9833a2f851 | ||
|
|
68e7162fba | ||
|
|
7daf536aeb | ||
|
|
0df797cbae | ||
|
|
139bbae456 | ||
|
|
dbca6d43ce | ||
|
|
991631876e | ||
|
|
40a92ee4db | ||
|
|
87d4e9a423 | ||
|
|
bd5fbb63b1 | ||
|
|
272aabcd69 | ||
|
|
fd00bc3780 | ||
|
|
f1cf8a6f08 | ||
|
|
762986b271 | ||
|
|
e0d46f430b | ||
|
|
eb769493e7 | ||
|
|
4a941d1700 | ||
|
|
0e8166386c | ||
|
|
d91999b430 |
@@ -1,43 +0,0 @@
|
|||||||
version: 2.1
|
|
||||||
|
|
||||||
workflows:
|
|
||||||
version: 2.1
|
|
||||||
build:
|
|
||||||
jobs:
|
|
||||||
- build
|
|
||||||
- clippy
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
docker:
|
|
||||||
- image: circleci/rust:latest
|
|
||||||
environment:
|
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- restore_cache:
|
|
||||||
key: project-cache
|
|
||||||
- run:
|
|
||||||
name: Check formatting
|
|
||||||
command: cargo fmt -- --check
|
|
||||||
- run:
|
|
||||||
name: Stable Build
|
|
||||||
command: cargo build --features "nalgebra-bindings ndarray-bindings"
|
|
||||||
- run:
|
|
||||||
name: Test
|
|
||||||
command: cargo test --features "nalgebra-bindings ndarray-bindings"
|
|
||||||
- save_cache:
|
|
||||||
key: project-cache
|
|
||||||
paths:
|
|
||||||
- "~/.cargo"
|
|
||||||
- "./target"
|
|
||||||
clippy:
|
|
||||||
docker:
|
|
||||||
- image: circleci/rust:latest
|
|
||||||
steps:
|
|
||||||
- checkout
|
|
||||||
- run:
|
|
||||||
name: Install cargo clippy
|
|
||||||
command: rustup component add clippy
|
|
||||||
- run:
|
|
||||||
name: Run cargo clippy
|
|
||||||
command: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
|
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
version: 2
|
||||||
|
updates:
|
||||||
|
- package-ecosystem: cargo
|
||||||
|
directory: "/"
|
||||||
|
schedule:
|
||||||
|
interval: daily
|
||||||
|
open-pull-requests-limit: 10
|
||||||
|
ignore:
|
||||||
|
- dependency-name: rand_distr
|
||||||
|
versions:
|
||||||
|
- 0.4.0
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, development ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ development ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
tests:
|
||||||
|
runs-on: "${{ matrix.platform.os }}-latest"
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
platform: [
|
||||||
|
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||||
|
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||||
|
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||||
|
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||||
|
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||||
|
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||||
|
]
|
||||||
|
env:
|
||||||
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Cache .cargo and target
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo
|
||||||
|
./target
|
||||||
|
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
- name: Install Rust toolchain
|
||||||
|
uses: actions-rs/toolchain@v1
|
||||||
|
with:
|
||||||
|
toolchain: stable
|
||||||
|
target: ${{ matrix.platform.target }}
|
||||||
|
profile: minimal
|
||||||
|
default: true
|
||||||
|
- name: Install test runner for wasm
|
||||||
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
|
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||||
|
- name: Stable Build
|
||||||
|
uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: build
|
||||||
|
args: --all-features --target ${{ matrix.platform.target }}
|
||||||
|
- name: Tests
|
||||||
|
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||||
|
uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: test
|
||||||
|
args: --all-features
|
||||||
|
- name: Tests in WASM
|
||||||
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
|
run: wasm-pack test --node -- --all-features
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
name: Coverage
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, development ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ development ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
coverage:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Cache .cargo
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo
|
||||||
|
./target
|
||||||
|
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
- name: Install Rust toolchain
|
||||||
|
uses: actions-rs/toolchain@v1
|
||||||
|
with:
|
||||||
|
toolchain: nightly
|
||||||
|
profile: minimal
|
||||||
|
default: true
|
||||||
|
- name: Install cargo-tarpaulin
|
||||||
|
uses: actions-rs/install@v0.1
|
||||||
|
with:
|
||||||
|
crate: cargo-tarpaulin
|
||||||
|
version: latest
|
||||||
|
use-tool-cache: true
|
||||||
|
- name: Run cargo-tarpaulin
|
||||||
|
uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: tarpaulin
|
||||||
|
args: --out Lcov --all-features -- --test-threads 1
|
||||||
|
- name: Upload to codecov.io
|
||||||
|
uses: codecov/codecov-action@v1
|
||||||
|
with:
|
||||||
|
fail_ci_if_error: true
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
name: Lint checks
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, development ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ development ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Cache .cargo and target
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cargo
|
||||||
|
./target
|
||||||
|
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||||
|
- name: Install Rust toolchain
|
||||||
|
uses: actions-rs/toolchain@v1
|
||||||
|
with:
|
||||||
|
toolchain: stable
|
||||||
|
profile: minimal
|
||||||
|
default: true
|
||||||
|
- run: rustup component add rustfmt
|
||||||
|
- name: Check formt
|
||||||
|
uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: fmt
|
||||||
|
args: --all -- --check
|
||||||
|
- run: rustup component add clippy
|
||||||
|
- name: Run clippy
|
||||||
|
uses: actions-rs/cargo@v1
|
||||||
|
with:
|
||||||
|
command: clippy
|
||||||
|
args: --all-features -- -Drust-2018-idioms -Dwarnings
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
# Changelog
|
||||||
|
All notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
## Added
|
||||||
|
- L2 regularization penalty to the Logistic Regression
|
||||||
|
- Getters for the naive bayes structs
|
||||||
|
- One hot encoder
|
||||||
|
- Make moons data generator
|
||||||
|
- Support for WASM.
|
||||||
|
|
||||||
|
## Changed
|
||||||
|
- Make serde optional
|
||||||
|
|
||||||
|
## [0.2.0] - 2021-01-03
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- DBSCAN
|
||||||
|
- Epsilon-SVR, SVC
|
||||||
|
- Ridge, Lasso, ElasticNet
|
||||||
|
- Bernoulli, Gaussian, Categorical and Multinomial Naive Bayes
|
||||||
|
- K-fold Cross Validation
|
||||||
|
- Singular value decomposition
|
||||||
|
- New api module
|
||||||
|
- Integration with Clippy
|
||||||
|
- Cholesky decomposition
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- ndarray upgraded to 0.14
|
||||||
|
- smartcore::error:FailedError is now non-exhaustive
|
||||||
|
- K-Means
|
||||||
|
- PCA
|
||||||
|
- Random Forest
|
||||||
|
- Linear and Logistic Regression
|
||||||
|
- KNN
|
||||||
|
- Decision Tree
|
||||||
|
|
||||||
|
## [0.1.0] - 2020-09-25
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- First release of smartcore.
|
||||||
|
- KNN + distance metrics (Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis)
|
||||||
|
- Linear Regression (OLS)
|
||||||
|
- Logistic Regression
|
||||||
|
- Random Forest Classifier
|
||||||
|
- Decision Tree Classifier
|
||||||
|
- PCA
|
||||||
|
- K-Means
|
||||||
|
- Integrated with ndarray
|
||||||
|
- Abstract linear algebra methods
|
||||||
|
- RandomForest Regressor
|
||||||
|
- Decision Tree Regressor
|
||||||
|
- Serde integration
|
||||||
|
- Integrated with nalgebra
|
||||||
|
- LU, QR, SVD, EVD
|
||||||
|
- Evaluation Metrics
|
||||||
+21
-9
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "The most advanced machine learning library in rust."
|
description = "The most advanced machine learning library in rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.2.0"
|
version = "0.2.1"
|
||||||
authors = ["SmartCore Developers"]
|
authors = ["SmartCore Developers"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
@@ -17,22 +17,29 @@ default = ["datasets"]
|
|||||||
ndarray-bindings = ["ndarray"]
|
ndarray-bindings = ["ndarray"]
|
||||||
nalgebra-bindings = ["nalgebra"]
|
nalgebra-bindings = ["nalgebra"]
|
||||||
datasets = []
|
datasets = []
|
||||||
|
fp_bench = []
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = { version = "0.14", optional = true }
|
ndarray = { version = "0.15", optional = true }
|
||||||
nalgebra = { version = "0.23.0", optional = true }
|
nalgebra = { version = "0.31", optional = true }
|
||||||
num-traits = "0.2.12"
|
num-traits = "0.2"
|
||||||
num = "0.3.0"
|
num = "0.4"
|
||||||
rand = "0.7.3"
|
rand = "0.8"
|
||||||
rand_distr = "0.3.0"
|
rand_distr = "0.4"
|
||||||
serde = { version = "1.0.115", features = ["derive"] }
|
serde = { version = "1", features = ["derive"], optional = true }
|
||||||
serde_derive = "1.0.115"
|
itertools = "0.10.3"
|
||||||
|
|
||||||
|
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||||
|
getrandom = { version = "0.2", features = ["js"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.3"
|
criterion = "0.3"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
bincode = "1.3.1"
|
bincode = "1.3.1"
|
||||||
|
|
||||||
|
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
|
||||||
|
wasm-bindgen-test = "0.3"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "distance"
|
name = "distance"
|
||||||
harness = false
|
harness = false
|
||||||
@@ -41,3 +48,8 @@ harness = false
|
|||||||
name = "naive_bayes"
|
name = "naive_bayes"
|
||||||
harness = false
|
harness = false
|
||||||
required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
||||||
|
|
||||||
|
[[bench]]
|
||||||
|
name = "fastpair"
|
||||||
|
harness = false
|
||||||
|
required-features = ["fp_bench"]
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||||
|
|
||||||
|
// to run this bench you have to change the declaraion in mod.rs ---> pub mod fastpair;
|
||||||
|
use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||||
|
use smartcore::linalg::naive::dense_matrix::*;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
fn closest_pair_bench(n: usize, m: usize) -> () {
|
||||||
|
let x = DenseMatrix::<f64>::rand(n, m);
|
||||||
|
let fastpair = FastPair::new(&x);
|
||||||
|
let result = fastpair.unwrap();
|
||||||
|
|
||||||
|
result.closest_pair();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn closest_pair_brute_bench(n: usize, m: usize) -> () {
|
||||||
|
let x = DenseMatrix::<f64>::rand(n, m);
|
||||||
|
let fastpair = FastPair::new(&x);
|
||||||
|
let result = fastpair.unwrap();
|
||||||
|
|
||||||
|
result.closest_pair_brute();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bench_fastpair(c: &mut Criterion) {
|
||||||
|
let mut group = c.benchmark_group("FastPair");
|
||||||
|
|
||||||
|
// with full samples size (100) the test will take too long
|
||||||
|
group.significance_level(0.1).sample_size(30);
|
||||||
|
// increase from default 5.0 secs
|
||||||
|
group.measurement_time(Duration::from_secs(60));
|
||||||
|
|
||||||
|
for n_samples in [100_usize, 1000_usize].iter() {
|
||||||
|
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::from_parameter(format!(
|
||||||
|
"fastpair --- n_samples: {}, n_features: {}",
|
||||||
|
n_samples, n_features
|
||||||
|
)),
|
||||||
|
n_samples,
|
||||||
|
|b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
|
||||||
|
);
|
||||||
|
group.bench_with_input(
|
||||||
|
BenchmarkId::from_parameter(format!(
|
||||||
|
"brute --- n_samples: {}, n_features: {}",
|
||||||
|
n_samples, n_features
|
||||||
|
)),
|
||||||
|
n_samples,
|
||||||
|
|b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, bench_fastpair);
|
||||||
|
criterion_main!(benches);
|
||||||
+3
-3
@@ -9,9 +9,9 @@
|
|||||||
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||||
inkscape:version="1.0 (4035a4f, 2020-05-01)"
|
inkscape:version="1.0 (4035a4f, 2020-05-01)"
|
||||||
sodipodi:docname="smartcore.svg"
|
sodipodi:docname="smartcore.svg"
|
||||||
width="396.01309mm"
|
width="1280"
|
||||||
height="86.286003mm"
|
height="320"
|
||||||
viewBox="0 0 396.0131 86.286004"
|
viewBox="0 0 454 86.286004"
|
||||||
version="1.1"
|
version="1.1"
|
||||||
id="svg512">
|
id="svg512">
|
||||||
<metadata
|
<metadata
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
@@ -59,7 +59,7 @@ impl<T: RealNumber> BBDTree<T> {
|
|||||||
tree
|
tree
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn clustering(
|
pub(crate) fn clustering(
|
||||||
&self,
|
&self,
|
||||||
centroids: &[Vec<T>],
|
centroids: &[Vec<T>],
|
||||||
sums: &mut Vec<Vec<T>>,
|
sums: &mut Vec<Vec<T>>,
|
||||||
@@ -314,6 +314,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn bbdtree_iris() {
|
fn bbdtree_iris() {
|
||||||
let data = DenseMatrix::from_2d_array(&[
|
let data = DenseMatrix::from_2d_array(&[
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::heap_select::HeapSelection;
|
use crate::algorithm::sort::heap_select::HeapSelection;
|
||||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Implements Cover Tree algorithm
|
/// Implements Cover Tree algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
|
||||||
base: F,
|
base: F,
|
||||||
inv_log_base: F,
|
inv_log_base: F,
|
||||||
@@ -56,16 +58,17 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<F: RealNumber> {
|
struct Node<F: RealNumber> {
|
||||||
idx: usize,
|
idx: usize,
|
||||||
max_dist: F,
|
max_dist: F,
|
||||||
parent_dist: F,
|
parent_dist: F,
|
||||||
children: Vec<Node<F>>,
|
children: Vec<Node<F>>,
|
||||||
scale: i64,
|
_scale: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug)]
|
||||||
struct DistanceSet<F: RealNumber> {
|
struct DistanceSet<F: RealNumber> {
|
||||||
idx: usize,
|
idx: usize,
|
||||||
dist: Vec<F>,
|
dist: Vec<F>,
|
||||||
@@ -82,7 +85,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
max_dist: F::zero(),
|
max_dist: F::zero(),
|
||||||
parent_dist: F::zero(),
|
parent_dist: F::zero(),
|
||||||
children: Vec::new(),
|
children: Vec::new(),
|
||||||
scale: 0,
|
_scale: 0,
|
||||||
};
|
};
|
||||||
let mut tree = CoverTree {
|
let mut tree = CoverTree {
|
||||||
base,
|
base,
|
||||||
@@ -114,7 +117,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
}
|
}
|
||||||
|
|
||||||
let e = self.get_data_value(self.root.idx);
|
let e = self.get_data_value(self.root.idx);
|
||||||
let mut d = self.distance.distance(&e, p);
|
let mut d = self.distance.distance(e, p);
|
||||||
|
|
||||||
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
@@ -172,11 +175,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
if ds.0 <= upper_bound {
|
if ds.0 <= upper_bound {
|
||||||
let v = self.get_data_value(ds.1.idx);
|
let v = self.get_data_value(ds.1.idx);
|
||||||
if !self.identical_excluded || v != p {
|
if !self.identical_excluded || v != p {
|
||||||
neighbors.push((ds.1.idx, ds.0, &v));
|
neighbors.push((ds.1.idx, ds.0, v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if neighbors.len() > k {
|
||||||
|
neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||||
|
}
|
||||||
Ok(neighbors.into_iter().take(k).collect())
|
Ok(neighbors.into_iter().take(k).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,7 +203,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
|
||||||
|
|
||||||
let e = self.get_data_value(self.root.idx);
|
let e = self.get_data_value(self.root.idx);
|
||||||
let mut d = self.distance.distance(&e, p);
|
let mut d = self.distance.distance(e, p);
|
||||||
current_cover_set.push((d, &self.root));
|
current_cover_set.push((d, &self.root));
|
||||||
|
|
||||||
while !current_cover_set.is_empty() {
|
while !current_cover_set.is_empty() {
|
||||||
@@ -227,7 +233,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
for ds in zero_set {
|
for ds in zero_set {
|
||||||
let v = self.get_data_value(ds.1.idx);
|
let v = self.get_data_value(ds.1.idx);
|
||||||
if !self.identical_excluded || v != p {
|
if !self.identical_excluded || v != p {
|
||||||
neighbors.push((ds.1.idx, ds.0, &v));
|
neighbors.push((ds.1.idx, ds.0, v));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,7 +246,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
max_dist: F::zero(),
|
max_dist: F::zero(),
|
||||||
parent_dist: F::zero(),
|
parent_dist: F::zero(),
|
||||||
children: Vec::new(),
|
children: Vec::new(),
|
||||||
scale: 100,
|
_scale: 100,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,7 +290,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
if point_set.is_empty() {
|
if point_set.is_empty() {
|
||||||
self.new_leaf(p)
|
self.new_leaf(p)
|
||||||
} else {
|
} else {
|
||||||
let max_dist = self.max(&point_set);
|
let max_dist = self.max(point_set);
|
||||||
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
||||||
if next_scale == std::i64::MIN {
|
if next_scale == std::i64::MIN {
|
||||||
let mut children: Vec<Node<F>> = Vec::new();
|
let mut children: Vec<Node<F>> = Vec::new();
|
||||||
@@ -301,7 +307,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
max_dist: F::zero(),
|
max_dist: F::zero(),
|
||||||
parent_dist: F::zero(),
|
parent_dist: F::zero(),
|
||||||
children,
|
children,
|
||||||
scale: 100,
|
_scale: 100,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut far: Vec<DistanceSet<F>> = Vec::new();
|
let mut far: Vec<DistanceSet<F>> = Vec::new();
|
||||||
@@ -313,8 +319,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
point_set.append(&mut far);
|
point_set.append(&mut far);
|
||||||
child
|
child
|
||||||
} else {
|
} else {
|
||||||
let mut children: Vec<Node<F>> = Vec::new();
|
let mut children: Vec<Node<F>> = vec![child];
|
||||||
children.push(child);
|
|
||||||
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
|
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
|
||||||
|
|
||||||
@@ -371,7 +376,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
|
|||||||
max_dist: self.max(consumed_set),
|
max_dist: self.max(consumed_set),
|
||||||
parent_dist: F::zero(),
|
parent_dist: F::zero(),
|
||||||
children,
|
children,
|
||||||
scale: (top_scale - max_scale),
|
_scale: (top_scale - max_scale),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -454,7 +459,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct SimpleDistance {}
|
struct SimpleDistance {}
|
||||||
|
|
||||||
impl Distance<i32, f64> for SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
@@ -463,6 +469,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cover_tree_test() {
|
fn cover_tree_test() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
@@ -479,7 +486,7 @@ mod tests {
|
|||||||
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
|
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
|
||||||
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
|
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cover_tree_test1() {
|
fn cover_tree_test1() {
|
||||||
let data = vec![
|
let data = vec![
|
||||||
@@ -498,8 +505,9 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(vec!(0, 1, 2), knn);
|
assert_eq!(vec!(0, 1, 2), knn);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
//!
|
||||||
|
//! Dissimilarities for vector-vector distance
|
||||||
|
//!
|
||||||
|
//! Representing distances as pairwise dissimilarities, so to build a
|
||||||
|
//! graph of closest neighbours. This representation can be reused for
|
||||||
|
//! different implementations (initially used in this library for FastPair).
|
||||||
|
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
|
///
|
||||||
|
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||||
|
/// The calling algorithm can store a list of distsances as
|
||||||
|
/// a list of these structures.
|
||||||
|
///
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct PairwiseDistance<T: RealNumber> {
|
||||||
|
/// index of the vector in the original `Matrix` or list
|
||||||
|
pub node: usize,
|
||||||
|
|
||||||
|
/// index of the closest neighbor in the original `Matrix` or same list
|
||||||
|
pub neighbour: Option<usize>,
|
||||||
|
|
||||||
|
/// measure of distance, according to the algorithm distance function
|
||||||
|
/// if the distance is None, the edge has value "infinite" or max distance
|
||||||
|
/// each algorithm has to match
|
||||||
|
pub distance: Option<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.node == other.node
|
||||||
|
&& self.neighbour == other.neighbour
|
||||||
|
&& self.distance == other.distance
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
self.distance.partial_cmp(&other.distance)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,570 @@
|
|||||||
|
#![allow(non_snake_case)]
|
||||||
|
use itertools::Itertools;
|
||||||
|
///
|
||||||
|
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
||||||
|
///
|
||||||
|
/// Reference:
|
||||||
|
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||||
|
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
||||||
|
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||||
|
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
/// &[5.1, 3.5, 1.4, 0.2],
|
||||||
|
/// &[4.9, 3.0, 1.4, 0.2],
|
||||||
|
/// &[4.7, 3.2, 1.3, 0.2],
|
||||||
|
/// &[4.6, 3.1, 1.5, 0.2],
|
||||||
|
/// &[5.0, 3.6, 1.4, 0.2],
|
||||||
|
/// &[5.4, 3.9, 1.7, 0.4],
|
||||||
|
/// ]);
|
||||||
|
/// let fastpair = FastPair::new(&x);
|
||||||
|
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
|
||||||
|
/// ```
|
||||||
|
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::distance::euclidian::Euclidian;
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Inspired by Python implementation:
|
||||||
|
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
||||||
|
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
|
||||||
|
///
|
||||||
|
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
|
||||||
|
///
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct FastPair<'a, T: RealNumber, M: Matrix<T>> {
|
||||||
|
/// initial matrix
|
||||||
|
samples: &'a M,
|
||||||
|
/// closest pair hashmap (connectivity matrix for closest pairs)
|
||||||
|
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
||||||
|
/// conga line used to keep track of the closest pair
|
||||||
|
pub neighbours: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||||
|
///
|
||||||
|
/// Constructor
|
||||||
|
/// Instantiate and inizialise the algorithm
|
||||||
|
///
|
||||||
|
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||||
|
if m.shape().0 < 3 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"min number of rows should be 3",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut init = Self {
|
||||||
|
samples: m,
|
||||||
|
// to be computed in init(..)
|
||||||
|
distances: HashMap::with_capacity(m.shape().0),
|
||||||
|
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
||||||
|
};
|
||||||
|
init.init();
|
||||||
|
Ok(init)
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Initialise `FastPair` by passing a `Matrix`.
|
||||||
|
/// Build a FastPairs data-structure from a set of (new) points.
|
||||||
|
///
|
||||||
|
fn init(&mut self) {
|
||||||
|
// basic measures
|
||||||
|
let len = self.samples.shape().0;
|
||||||
|
let max_index = self.samples.shape().0 - 1;
|
||||||
|
|
||||||
|
// Store all closest neighbors
|
||||||
|
let _distances = Box::new(HashMap::with_capacity(len));
|
||||||
|
let _neighbours = Box::new(Vec::with_capacity(len));
|
||||||
|
|
||||||
|
let mut distances = *_distances;
|
||||||
|
let mut neighbours = *_neighbours;
|
||||||
|
|
||||||
|
// fill neighbours with -1 values
|
||||||
|
neighbours.extend(0..len);
|
||||||
|
|
||||||
|
// init closest neighbour pairwise data
|
||||||
|
for index_row_i in 0..(max_index) {
|
||||||
|
distances.insert(
|
||||||
|
index_row_i,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_i,
|
||||||
|
neighbour: None,
|
||||||
|
distance: Some(T::max_value()),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop through indeces and neighbours
|
||||||
|
for index_row_i in 0..(len) {
|
||||||
|
// start looking for the neighbour in the second element
|
||||||
|
let mut index_closest = index_row_i + 1; // closest neighbour index
|
||||||
|
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
||||||
|
for index_row_j in (index_row_i + 1)..len {
|
||||||
|
distances.insert(
|
||||||
|
index_row_j,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_j,
|
||||||
|
neighbour: Some(index_row_i),
|
||||||
|
distance: nbd,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let d = Euclidian::squared_distance(
|
||||||
|
&(self.samples.get_row_as_vec(index_row_i)),
|
||||||
|
&(self.samples.get_row_as_vec(index_row_j)),
|
||||||
|
);
|
||||||
|
if d < nbd.unwrap() {
|
||||||
|
// set this j-value to be the closest neighbour
|
||||||
|
index_closest = index_row_j;
|
||||||
|
nbd = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add that edge
|
||||||
|
distances.entry(index_row_i).and_modify(|e| {
|
||||||
|
e.distance = nbd;
|
||||||
|
e.neighbour = Some(index_closest);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// No more neighbors, terminate conga line.
|
||||||
|
// Last person on the line has no neigbors
|
||||||
|
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||||
|
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
||||||
|
|
||||||
|
// compute sparse matrix (connectivity matrix)
|
||||||
|
let mut sparse_matrix = M::zeros(len, len);
|
||||||
|
for (_, p) in distances.iter() {
|
||||||
|
sparse_matrix.set(p.node, p.neighbour.unwrap(), p.distance.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.distances = distances;
|
||||||
|
self.neighbours = neighbours;
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Find closest pair by scanning list of nearest neighbors.
|
||||||
|
///
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||||
|
let mut a = self.neighbours[0]; // Start with first point
|
||||||
|
let mut d = self.distances[&a].distance;
|
||||||
|
for p in self.neighbours.iter() {
|
||||||
|
if self.distances[p].distance < d {
|
||||||
|
a = *p; // Update `a` and distance `d`
|
||||||
|
d = self.distances[p].distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let b = self.distances[&a].neighbour;
|
||||||
|
PairwiseDistance {
|
||||||
|
node: a,
|
||||||
|
neighbour: b,
|
||||||
|
distance: d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Brute force algorithm, used only for comparison and testing
|
||||||
|
///
|
||||||
|
#[cfg(feature = "fp_bench")]
|
||||||
|
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||||
|
let m = self.samples.shape().0;
|
||||||
|
|
||||||
|
let mut closest_pair = PairwiseDistance {
|
||||||
|
node: 0,
|
||||||
|
neighbour: None,
|
||||||
|
distance: Some(T::max_value()),
|
||||||
|
};
|
||||||
|
for pair in (0..m).combinations(2) {
|
||||||
|
let d = Euclidian::squared_distance(
|
||||||
|
&(self.samples.get_row_as_vec(pair[0])),
|
||||||
|
&(self.samples.get_row_as_vec(pair[1])),
|
||||||
|
);
|
||||||
|
if d < closest_pair.distance.unwrap() {
|
||||||
|
closest_pair.node = pair[0];
|
||||||
|
closest_pair.neighbour = Some(pair[1]);
|
||||||
|
closest_pair.distance = Some(d);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
closest_pair
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Compute distances from input to all other points in data-structure.
|
||||||
|
// input is the row index of the sample matrix
|
||||||
|
//
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
|
||||||
|
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||||
|
for other in self.neighbours.iter() {
|
||||||
|
if index_row != *other {
|
||||||
|
distances.push(PairwiseDistance {
|
||||||
|
node: index_row,
|
||||||
|
neighbour: Some(*other),
|
||||||
|
distance: Some(Euclidian::squared_distance(
|
||||||
|
&(self.samples.get_row_as_vec(index_row)),
|
||||||
|
&(self.samples.get_row_as_vec(*other)),
|
||||||
|
)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
distances
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests_fastpair {

    use super::*;
    use crate::linalg::naive::dense_matrix::*;

    // Construction on a random matrix succeeds and the per-point
    // nearest-neighbour bookkeeping covers every sample row.
    #[test]
    fn fastpair_init() {
        let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
        let _fastpair = FastPair::new(&x);
        assert!(_fastpair.is_ok());

        let fastpair = _fastpair.unwrap();

        let distances = fastpair.distances;
        let neighbours = fastpair.neighbours;

        assert!(distances.len() != 0);
        assert!(neighbours.len() != 0);

        // one entry per input row
        assert_eq!(10, neighbours.len());
        assert_eq!(10, distances.len());
    }

    #[test]
    fn dataset_has_at_least_three_points() {
        // Create a dataset which consists of only two points:
        // A(0.0, 0.0) and B(1.0, 1.0).
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);

        // We expect an error when we run `FastPair` on this dataset,
        // because `FastPair` currently only works on a minimum of 3
        // points.
        let _fastpair = FastPair::new(&dataset);

        match _fastpair {
            Err(e) => {
                let expected_error =
                    Failed::because(FailedError::FindFailed, "min number of rows should be 3");
                assert_eq!(e, expected_error)
            }
            _ => {
                assert!(false);
            }
        }
    }

    // Minimal 1-D dataset: closest pair is rows 0 and 1 at squared distance 4.
    #[test]
    fn one_dimensional_dataset_minimal() {
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        let expected_closest_pair = PairwiseDistance {
            node: 0,
            neighbour: Some(1),
            distance: Some(4.0),
        };
        assert_eq!(closest_pair, expected_closest_pair);

        // the brute-force scan must agree with the conga-line result
        let closest_pair_brute = fastpair.closest_pair_brute();
        assert_eq!(closest_pair_brute, expected_closest_pair);
    }

    // Unsorted 1-D dataset: closest pair is rows 1 (0.0) and 3 (2.0).
    #[test]
    fn one_dimensional_dataset_2() {
        let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);

        let result = FastPair::new(&dataset);
        assert!(result.is_ok());

        let fastpair = result.unwrap();
        let closest_pair = fastpair.closest_pair();
        let expected_closest_pair = PairwiseDistance {
            node: 1,
            neighbour: Some(3),
            distance: Some(4.0),
        };
        assert_eq!(closest_pair, fastpair.closest_pair_brute());
        assert_eq!(closest_pair, expected_closest_pair);
    }

    // Validates the full nearest-neighbour map built by `init` on a small
    // iris-like dataset against hand-computed pairwise dissimilarities.
    #[test]
    fn fastpair_new() {
        // compute
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ]);
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        // unwrap results
        let result = fastpair.unwrap();

        // list of minimal pairwise dissimilarities
        // (node, expected nearest neighbour + squared distance)
        let dissimilarities = vec![
            (
                1,
                PairwiseDistance {
                    node: 1,
                    neighbour: Some(9),
                    distance: Some(0.030000000000000037),
                },
            ),
            (
                10,
                PairwiseDistance {
                    node: 10,
                    neighbour: Some(12),
                    distance: Some(0.07000000000000003),
                },
            ),
            (
                11,
                PairwiseDistance {
                    node: 11,
                    neighbour: Some(14),
                    distance: Some(0.18000000000000013),
                },
            ),
            (
                12,
                PairwiseDistance {
                    node: 12,
                    neighbour: Some(14),
                    distance: Some(0.34000000000000086),
                },
            ),
            (
                13,
                PairwiseDistance {
                    node: 13,
                    neighbour: Some(14),
                    distance: Some(1.6499999999999997),
                },
            ),
            (
                14,
                PairwiseDistance {
                    node: 14,
                    neighbour: Some(14),
                    distance: Some(f64::MAX),
                },
            ),
            (
                6,
                PairwiseDistance {
                    node: 6,
                    neighbour: Some(7),
                    distance: Some(0.18000000000000027),
                },
            ),
            (
                0,
                PairwiseDistance {
                    node: 0,
                    neighbour: Some(4),
                    distance: Some(0.01999999999999995),
                },
            ),
            (
                8,
                PairwiseDistance {
                    node: 8,
                    neighbour: Some(9),
                    distance: Some(0.3100000000000001),
                },
            ),
            (
                2,
                PairwiseDistance {
                    node: 2,
                    neighbour: Some(3),
                    distance: Some(0.0600000000000001),
                },
            ),
            (
                3,
                PairwiseDistance {
                    node: 3,
                    neighbour: Some(8),
                    distance: Some(0.08999999999999982),
                },
            ),
            (
                7,
                PairwiseDistance {
                    node: 7,
                    neighbour: Some(9),
                    distance: Some(0.10999999999999982),
                },
            ),
            (
                9,
                PairwiseDistance {
                    node: 9,
                    neighbour: Some(13),
                    distance: Some(8.69),
                },
            ),
            (
                4,
                PairwiseDistance {
                    node: 4,
                    neighbour: Some(7),
                    distance: Some(0.050000000000000086),
                },
            ),
            (
                5,
                PairwiseDistance {
                    node: 5,
                    neighbour: Some(7),
                    distance: Some(0.4900000000000002),
                },
            ),
        ];

        let expected: HashMap<_, _> = dissimilarities.into_iter().collect();

        for i in 0..(x.shape().0 - 1) {
            let input_node = result.samples.get_row_as_vec(i);
            let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
            // recompute the distance independently of the structure under test
            let distance = Euclidian::squared_distance(
                &input_node,
                &result.samples.get_row_as_vec(input_neighbour),
            );

            assert_eq!(i, expected.get(&i).unwrap().node);
            assert_eq!(
                input_neighbour,
                expected.get(&i).unwrap().neighbour.unwrap()
            );
            assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap());
        }
    }

    // Closest pair on the iris-like fixture is rows 0 and 4.
    #[test]
    fn fastpair_closest_pair() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ]);
        // compute
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarity = fastpair.unwrap().closest_pair();
        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, dissimilarity);
    }

    // The fast algorithm and the brute-force scan must agree on a larger
    // random matrix.
    #[test]
    fn fastpair_closest_pair_random_matrix() {
        let x = DenseMatrix::<f64>::rand(200, 25);
        // compute
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let result = fastpair.unwrap();

        let dissimilarity1 = result.closest_pair();
        let dissimilarity2 = result.closest_pair_brute();

        assert_eq!(dissimilarity1, dissimilarity2);
    }

    // `distances_from(0)` must contain, as its minimum, the known closest
    // pair for row 0.
    #[test]
    fn fastpair_distances() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ]);
        // compute
        let fastpair = FastPair::new(&x);
        assert!(fastpair.is_ok());

        let dissimilarities = fastpair.unwrap().distances_from(0);

        // linear scan for the minimum of the returned distances
        let mut min_dissimilarity = PairwiseDistance {
            node: 0,
            neighbour: None,
            distance: Some(f64::MAX),
        };
        for p in dissimilarities.iter() {
            if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
                min_dissimilarity = p.clone()
            }
        }

        let closest = PairwiseDistance {
            node: 0,
            neighbour: Some(4),
            distance: Some(0.01999999999999995),
        };

        assert_eq!(closest, min_dissimilarity);
    }
}
|
||||||
@@ -22,6 +22,7 @@
|
|||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::cmp::{Ordering, PartialOrd};
|
use std::cmp::{Ordering, PartialOrd};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
|
||||||
distance: D,
|
distance: D,
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
@@ -72,7 +74,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i in 0..self.data.len() {
|
for i in 0..self.data.len() {
|
||||||
let d = self.distance.distance(&from, &self.data[i]);
|
let d = self.distance.distance(from, &self.data[i]);
|
||||||
let datum = heap.peek_mut();
|
let datum = heap.peek_mut();
|
||||||
if d < datum.distance {
|
if d < datum.distance {
|
||||||
datum.distance = d;
|
datum.distance = d;
|
||||||
@@ -102,7 +104,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
|
|||||||
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
|
||||||
|
|
||||||
for i in 0..self.data.len() {
|
for i in 0..self.data.len() {
|
||||||
let d = self.distance.distance(&from, &self.data[i]);
|
let d = self.distance.distance(from, &self.data[i]);
|
||||||
|
|
||||||
if d <= radius {
|
if d <= radius {
|
||||||
neighbors.push((i, d, &self.data[i]));
|
neighbors.push((i, d, &self.data[i]));
|
||||||
@@ -138,7 +140,8 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct SimpleDistance {}
|
struct SimpleDistance {}
|
||||||
|
|
||||||
impl Distance<i32, f64> for SimpleDistance {
|
impl Distance<i32, f64> for SimpleDistance {
|
||||||
@@ -147,6 +150,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_find() {
|
fn knn_find() {
|
||||||
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
|
||||||
@@ -193,7 +197,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(vec!(1, 2, 3), found_idxs2);
|
assert_eq!(vec!(1, 2, 3), found_idxs2);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_point_eq() {
|
fn knn_point_eq() {
|
||||||
let point1 = KNNPoint {
|
let point1 = KNNPoint {
|
||||||
|
|||||||
@@ -35,17 +35,23 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
|
|||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::math::distance::Distance;
|
use crate::math::distance::Distance;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
|
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
||||||
|
pub mod distances;
|
||||||
|
/// fastpair closest neighbour algorithm
|
||||||
|
pub mod fastpair;
|
||||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
|
|
||||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum KNNAlgorithmName {
|
pub enum KNNAlgorithmName {
|
||||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||||
LinearSearch,
|
LinearSearch,
|
||||||
@@ -53,7 +59,8 @@ pub enum KNNAlgorithmName {
|
|||||||
CoverTree,
|
CoverTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
|
||||||
CoverTree(CoverTree<Vec<T>, T, D>),
|
CoverTree(CoverTree<Vec<T>, T, D>),
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ pub struct HeapSelection<T: PartialOrd + Debug> {
|
|||||||
heap: Vec<T>,
|
heap: Vec<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
impl<T: PartialOrd + Debug> HeapSelection<T> {
|
||||||
pub fn with_capacity(k: usize) -> HeapSelection<T> {
|
pub fn with_capacity(k: usize) -> HeapSelection<T> {
|
||||||
HeapSelection {
|
HeapSelection {
|
||||||
k,
|
k,
|
||||||
@@ -53,8 +53,7 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
|||||||
if self.sorted {
|
if self.sorted {
|
||||||
&self.heap[0]
|
&self.heap[0]
|
||||||
} else {
|
} else {
|
||||||
&self
|
self.heap
|
||||||
.heap
|
|
||||||
.iter()
|
.iter()
|
||||||
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
.max_by(|a, b| a.partial_cmp(b).unwrap())
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@@ -96,12 +95,14 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn with_capacity() {
|
fn with_capacity() {
|
||||||
let heap = HeapSelection::<i32>::with_capacity(3);
|
let heap = HeapSelection::<i32>::with_capacity(3);
|
||||||
assert_eq!(3, heap.k);
|
assert_eq!(3, heap.k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add() {
|
fn test_add() {
|
||||||
let mut heap = HeapSelection::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
@@ -119,6 +120,7 @@ mod tests {
|
|||||||
assert_eq!(vec![2, 0, -5], heap.get());
|
assert_eq!(vec![2, 0, -5], heap.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add1() {
|
fn test_add1() {
|
||||||
let mut heap = HeapSelection::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
@@ -133,6 +135,7 @@ mod tests {
|
|||||||
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
|
assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add2() {
|
fn test_add2() {
|
||||||
let mut heap = HeapSelection::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
@@ -145,6 +148,7 @@ mod tests {
|
|||||||
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
|
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_add_ordered() {
|
fn test_add_ordered() {
|
||||||
let mut heap = HeapSelection::with_capacity(3);
|
let mut heap = HeapSelection::with_capacity(3);
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn with_capacity() {
|
fn with_capacity() {
|
||||||
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
|
||||||
|
|||||||
@@ -43,6 +43,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -55,7 +56,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::tree::decision_tree_classifier::which_max;
|
use crate::tree::decision_tree_classifier::which_max;
|
||||||
|
|
||||||
/// DBSCAN clustering algorithm
|
/// DBSCAN clustering algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
cluster_labels: Vec<i16>,
|
cluster_labels: Vec<i16>,
|
||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
@@ -153,11 +155,11 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
|||||||
parameters: DBSCANParameters<T, D>,
|
parameters: DBSCANParameters<T, D>,
|
||||||
) -> Result<DBSCAN<T, D>, Failed> {
|
) -> Result<DBSCAN<T, D>, Failed> {
|
||||||
if parameters.min_samples < 1 {
|
if parameters.min_samples < 1 {
|
||||||
return Err(Failed::fit(&"Invalid minPts".to_string()));
|
return Err(Failed::fit("Invalid minPts"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if parameters.eps <= T::zero() {
|
if parameters.eps <= T::zero() {
|
||||||
return Err(Failed::fit(&"Invalid radius: ".to_string()));
|
return Err(Failed::fit("Invalid radius: "));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut k = 0;
|
let mut k = 0;
|
||||||
@@ -263,8 +265,10 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::math::distance::euclidian::Euclidian;
|
use crate::math::distance::euclidian::Euclidian;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_dbscan() {
|
fn fit_predict_dbscan() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -296,7 +300,9 @@ mod tests {
|
|||||||
assert_eq!(expected_labels, predicted_labels);
|
assert_eq!(expected_labels, predicted_labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
/// # Hierarchical clustering
|
||||||
|
///
|
||||||
|
/// Implement hierarchical clustering methods:
|
||||||
|
/// * Agglomerative clustering (current)
|
||||||
|
/// * Bisecting K-Means (future)
|
||||||
|
/// * Fastcluster (future)
|
||||||
|
///
|
||||||
|
|
||||||
|
/*
|
||||||
|
class AgglomerativeClustering():
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_clusters : int or None, default=2
|
||||||
|
The number of clusters to find. It must be ``None`` if
|
||||||
|
``distance_threshold`` is not ``None``.
|
||||||
|
affinity : str or callable, default='euclidean'
|
||||||
|
If linkage is "ward", only "euclidean" is accepted.
|
||||||
|
linkage : {'ward',}, default='ward'
|
||||||
|
Which linkage criterion to use. The linkage criterion determines which
|
||||||
|
distance to use between sets of observation. The algorithm will merge
|
||||||
|
the pairs of cluster that minimize this criterion.
|
||||||
|
- 'ward' minimizes the variance of the clusters being merged.
|
||||||
|
compute_distances : bool, default=False
|
||||||
|
Computes distances between clusters even if `distance_threshold` is not
|
||||||
|
used. This can be used to make dendrogram visualization, but introduces
|
||||||
|
a computational and memory overhead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fit(X):
|
||||||
|
# compute tree
|
||||||
|
# <https://github.com/scikit-learn/scikit-learn/blob/02ebf9e68fe1fc7687d9e1047b9e465ae0fd945e/sklearn/cluster/_agglomerative.py#L172>
|
||||||
|
parents, children = ward_tree(X, ....)
|
||||||
|
# compute clusters
|
||||||
|
# <https://github.com/scikit-learn/scikit-learn/blob/70c495250fea7fa3c8c1a4631e6ddcddc9f22451/sklearn/cluster/_hierarchical_fast.pyx#L98>
|
||||||
|
labels = _hierarchical.hc_get_heads(parents)
|
||||||
|
# assign cluster numbers
|
||||||
|
self.labels_ = np.searchsorted(np.unique(labels), labels)
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// implement ward tree
|
||||||
|
// use scipy.cluster.hierarchy.ward
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L738>
|
||||||
|
// use linkage
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/hierarchy.py#L837>
|
||||||
|
// use nn_chain
|
||||||
|
// <https://github.com/scipy/scipy/blob/main/scipy/cluster/_hierarchy.pyx#L906>
|
||||||
|
|
||||||
|
// implement hc_get_heads
|
||||||
|
|
||||||
|
|
||||||
|
mod tests {
    // Reference behaviour from scikit-learn, kept as the target for a
    // future `AgglomerativeClustering` test once the Rust implementation
    // exists:
    // >>> from sklearn.cluster import AgglomerativeClustering
    // >>> import numpy as np
    // >>> X = np.array([[1, 2], [1, 4], [1, 0],
    // ...               [4, 2], [4, 4], [4, 0]])
    // >>> clustering = AgglomerativeClustering().fit(X)
    // >>> clustering
    // AgglomerativeClustering()
    // >>> clustering.labels_
    // array([1, 1, 1, 0, 0, 0])
}
|
||||||
+13
-7
@@ -56,6 +56,7 @@ use rand::Rng;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||||
@@ -66,12 +67,13 @@ use crate::math::distance::euclidian::*;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// K-Means clustering algorithm
|
/// K-Means clustering algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KMeans<T: RealNumber> {
|
pub struct KMeans<T: RealNumber> {
|
||||||
k: usize,
|
k: usize,
|
||||||
y: Vec<usize>,
|
_y: Vec<usize>,
|
||||||
size: Vec<usize>,
|
size: Vec<usize>,
|
||||||
distortion: T,
|
_distortion: T,
|
||||||
centroids: Vec<Vec<T>>,
|
centroids: Vec<Vec<T>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,9 +208,9 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
|
|
||||||
Ok(KMeans {
|
Ok(KMeans {
|
||||||
k: parameters.k,
|
k: parameters.k,
|
||||||
y,
|
_y: y,
|
||||||
size,
|
size,
|
||||||
distortion,
|
_distortion: distortion,
|
||||||
centroids,
|
centroids,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -243,7 +245,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
|||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::thread_rng();
|
||||||
let (n, m) = data.shape();
|
let (n, m) = data.shape();
|
||||||
let mut y = vec![0; n];
|
let mut y = vec![0; n];
|
||||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n));
|
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||||
|
|
||||||
let mut d = vec![T::max_value(); n];
|
let mut d = vec![T::max_value(); n];
|
||||||
|
|
||||||
@@ -297,6 +299,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn invalid_k() {
|
fn invalid_k() {
|
||||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
@@ -310,6 +313,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -340,11 +344,13 @@ mod tests {
|
|||||||
let y = kmeans.predict(&x).unwrap();
|
let y = kmeans.predict(&x).unwrap();
|
||||||
|
|
||||||
for i in 0..y.len() {
|
for i in 0..y.len() {
|
||||||
assert_eq!(y[i] as usize, kmeans.y[i]);
|
assert_eq!(y[i] as usize, kmeans._y[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -56,9 +56,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn refresh_boston_dataset() {
|
fn refresh_boston_dataset() {
|
||||||
@@ -67,6 +69,7 @@ mod tests {
|
|||||||
assert!(serialize_data(&dataset, "boston.xy").is_ok());
|
assert!(serialize_data(&dataset, "boston.xy").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn boston_dataset() {
|
fn boston_dataset() {
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
|
|||||||
@@ -66,17 +66,20 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
fn refresh_cancer_dataset() {
|
fn refresh_cancer_dataset() {
|
||||||
// run this test to generate breast_cancer.xy file.
|
// run this test to generate breast_cancer.xy file.
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
|
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cancer_dataset() {
|
fn cancer_dataset() {
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn refresh_diabetes_dataset() {
|
fn refresh_diabetes_dataset() {
|
||||||
@@ -61,6 +63,7 @@ mod tests {
|
|||||||
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
|
assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn boston_dataset() {
|
fn boston_dataset() {
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
|
|||||||
@@ -45,9 +45,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn refresh_digits_dataset() {
|
fn refresh_digits_dataset() {
|
||||||
@@ -55,7 +57,7 @@ mod tests {
|
|||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
assert!(serialize_data(&dataset, "digits.xy").is_ok());
|
assert!(serialize_data(&dataset, "digits.xy").is_ok());
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn digits_dataset() {
|
fn digits_dataset() {
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
|
|||||||
@@ -88,6 +88,43 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Make two interleaving half circles in 2d
|
||||||
|
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
|
||||||
|
let num_samples_out = num_samples / 2;
|
||||||
|
let num_samples_in = num_samples - num_samples_out;
|
||||||
|
|
||||||
|
let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out);
|
||||||
|
let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in);
|
||||||
|
|
||||||
|
let noise = Normal::new(0.0, noise).unwrap();
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
|
||||||
|
let mut x: Vec<f32> = Vec::with_capacity(num_samples * 2);
|
||||||
|
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
|
||||||
|
|
||||||
|
for v in linspace_out {
|
||||||
|
x.push(v.cos() + noise.sample(&mut rng));
|
||||||
|
x.push(v.sin() + noise.sample(&mut rng));
|
||||||
|
y.push(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for v in linspace_in {
|
||||||
|
x.push(1.0 - v.cos() + noise.sample(&mut rng));
|
||||||
|
x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5);
|
||||||
|
y.push(1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
Dataset {
|
||||||
|
data: x,
|
||||||
|
target: y,
|
||||||
|
num_samples,
|
||||||
|
num_features: 2,
|
||||||
|
feature_names: (0..2).map(|n| n.to_string()).collect(),
|
||||||
|
target_names: vec!["label".to_string()],
|
||||||
|
description: "Two interleaving half circles in 2d".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
|
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
|
||||||
let div = num as f32;
|
let div = num as f32;
|
||||||
let delta = stop - start;
|
let delta = stop - start;
|
||||||
@@ -100,6 +137,7 @@ mod tests {
|
|||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_make_blobs() {
|
fn test_make_blobs() {
|
||||||
let dataset = make_blobs(10, 2, 3);
|
let dataset = make_blobs(10, 2, 3);
|
||||||
@@ -112,6 +150,7 @@ mod tests {
|
|||||||
assert_eq!(dataset.num_samples, 10);
|
assert_eq!(dataset.num_samples, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_make_circles() {
|
fn test_make_circles() {
|
||||||
let dataset = make_circles(10, 0.5, 0.05);
|
let dataset = make_circles(10, 0.5, 0.05);
|
||||||
@@ -123,4 +162,17 @@ mod tests {
|
|||||||
assert_eq!(dataset.num_features, 2);
|
assert_eq!(dataset.num_features, 2);
|
||||||
assert_eq!(dataset.num_samples, 10);
|
assert_eq!(dataset.num_samples, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn test_make_moons() {
|
||||||
|
let dataset = make_moons(10, 0.05);
|
||||||
|
assert_eq!(
|
||||||
|
dataset.data.len(),
|
||||||
|
dataset.num_features * dataset.num_samples
|
||||||
|
);
|
||||||
|
assert_eq!(dataset.target.len(), dataset.num_samples);
|
||||||
|
assert_eq!(dataset.num_features, 2);
|
||||||
|
assert_eq!(dataset.num_samples, 10);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use super::super::*;
|
use super::super::*;
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
#[ignore]
|
||||||
fn refresh_iris_dataset() {
|
fn refresh_iris_dataset() {
|
||||||
@@ -61,6 +63,7 @@ mod tests {
|
|||||||
assert!(serialize_data(&dataset, "iris.xy").is_ok());
|
assert!(serialize_data(&dataset, "iris.xy").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn iris_dataset() {
|
fn iris_dataset() {
|
||||||
let dataset = load_dataset();
|
let dataset = load_dataset();
|
||||||
|
|||||||
+12
-5
@@ -8,9 +8,12 @@ pub mod digits;
|
|||||||
pub mod generator;
|
pub mod generator;
|
||||||
pub mod iris;
|
pub mod iris;
|
||||||
|
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
use std::io::prelude::*;
|
use std::io::prelude::*;
|
||||||
|
|
||||||
/// Dataset
|
/// Dataset
|
||||||
@@ -49,6 +52,8 @@ impl<X, Y> Dataset<X, Y> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Running this in wasm throws: operation not supported on this platform.
|
||||||
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
||||||
dataset: &Dataset<X, Y>,
|
dataset: &Dataset<X, Y>,
|
||||||
@@ -62,14 +67,14 @@ pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
|
|||||||
.data
|
.data
|
||||||
.iter()
|
.iter()
|
||||||
.copied()
|
.copied()
|
||||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter())
|
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
|
||||||
.collect();
|
.collect();
|
||||||
file.write_all(&x)?;
|
file.write_all(&x)?;
|
||||||
let y: Vec<u8> = dataset
|
let y: Vec<u8> = dataset
|
||||||
.target
|
.target
|
||||||
.iter()
|
.iter()
|
||||||
.copied()
|
.copied()
|
||||||
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter())
|
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
|
||||||
.collect();
|
.collect();
|
||||||
file.write_all(&y)?;
|
file.write_all(&y)?;
|
||||||
}
|
}
|
||||||
@@ -82,11 +87,12 @@ pub(crate) fn deserialize_data(
|
|||||||
bytes: &[u8],
|
bytes: &[u8],
|
||||||
) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> {
|
) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> {
|
||||||
// read the same file back into a Vec of bytes
|
// read the same file back into a Vec of bytes
|
||||||
|
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
|
||||||
let (num_samples, num_features) = {
|
let (num_samples, num_features) = {
|
||||||
let mut buffer = [0u8; 8];
|
let mut buffer = [0u8; USIZE_SIZE];
|
||||||
buffer.copy_from_slice(&bytes[0..8]);
|
buffer.copy_from_slice(&bytes[0..USIZE_SIZE]);
|
||||||
let num_features = usize::from_le_bytes(buffer);
|
let num_features = usize::from_le_bytes(buffer);
|
||||||
buffer.copy_from_slice(&bytes[8..16]);
|
buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]);
|
||||||
let num_samples = usize::from_le_bytes(buffer);
|
let num_samples = usize::from_le_bytes(buffer);
|
||||||
(num_samples, num_features)
|
(num_samples, num_features)
|
||||||
};
|
};
|
||||||
@@ -115,6 +121,7 @@ pub(crate) fn deserialize_data(
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn as_matrix() {
|
fn as_matrix() {
|
||||||
let dataset = Dataset {
|
let dataset = Dataset {
|
||||||
|
|||||||
@@ -47,6 +47,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||||
@@ -55,7 +56,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Principal components analysis algorithm
|
/// Principal components analysis algorithm
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
pub struct PCA<T: RealNumber, M: Matrix<T>> {
|
||||||
eigenvectors: M,
|
eigenvectors: M,
|
||||||
eigenvalues: Vec<T>,
|
eigenvalues: Vec<T>,
|
||||||
@@ -323,7 +325,7 @@ mod tests {
|
|||||||
&[6.8, 161.0, 60.0, 15.6],
|
&[6.8, 161.0, 60.0, 15.6],
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn pca_components() {
|
fn pca_components() {
|
||||||
let us_arrests = us_arrests_data();
|
let us_arrests = us_arrests_data();
|
||||||
@@ -339,7 +341,7 @@ mod tests {
|
|||||||
|
|
||||||
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
|
assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_covariance() {
|
fn decompose_covariance() {
|
||||||
let us_arrests = us_arrests_data();
|
let us_arrests = us_arrests_data();
|
||||||
@@ -449,6 +451,7 @@ mod tests {
|
|||||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_correlation() {
|
fn decompose_correlation() {
|
||||||
let us_arrests = us_arrests_data();
|
let us_arrests = us_arrests_data();
|
||||||
@@ -564,7 +567,9 @@ mod tests {
|
|||||||
.approximate_eq(&expected_projection.abs(), 1e-4));
|
.approximate_eq(&expected_projection.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let iris = DenseMatrix::from_2d_array(&[
|
let iris = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -46,6 +46,7 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||||
@@ -54,7 +55,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// SVD
|
/// SVD
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
pub struct SVD<T: RealNumber, M: Matrix<T>> {
|
||||||
components: M,
|
components: M,
|
||||||
phantom: PhantomData<T>,
|
phantom: PhantomData<T>,
|
||||||
@@ -151,6 +153,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svd_decompose() {
|
fn svd_decompose() {
|
||||||
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
|
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
|
||||||
@@ -225,7 +228,9 @@ mod tests {
|
|||||||
.approximate_eq(&expected, 1e-4));
|
.approximate_eq(&expected, 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let iris = DenseMatrix::from_2d_array(&[
|
let iris = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -45,14 +45,16 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_classifier::{
|
use crate::tree::decision_tree_classifier::{
|
||||||
@@ -61,7 +63,8 @@ use crate::tree::decision_tree_classifier::{
|
|||||||
|
|
||||||
/// Parameters of the Random Forest algorithm.
|
/// Parameters of the Random Forest algorithm.
|
||||||
/// Some parameters here are passed directly into base estimator.
|
/// Some parameters here are passed directly into base estimator.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RandomForestClassifierParameters {
|
pub struct RandomForestClassifierParameters {
|
||||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
@@ -75,14 +78,20 @@ pub struct RandomForestClassifierParameters {
|
|||||||
pub n_trees: u16,
|
pub n_trees: u16,
|
||||||
/// Number of random sample of predictors to use as split candidates.
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
pub m: Option<usize>,
|
pub m: Option<usize>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: bool,
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub seed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Classifier
|
/// Random Forest Classifier
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RandomForestClassifier<T: RealNumber> {
|
pub struct RandomForestClassifier<T: RealNumber> {
|
||||||
parameters: RandomForestClassifierParameters,
|
_parameters: RandomForestClassifierParameters,
|
||||||
trees: Vec<DecisionTreeClassifier<T>>,
|
trees: Vec<DecisionTreeClassifier<T>>,
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
|
samples: Option<Vec<Vec<bool>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RandomForestClassifierParameters {
|
impl RandomForestClassifierParameters {
|
||||||
@@ -116,6 +125,18 @@ impl RandomForestClassifierParameters {
|
|||||||
self.m = Some(m);
|
self.m = Some(m);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||||
|
self.keep_samples = keep_samples;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||||
|
self.seed = seed;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
|
||||||
@@ -147,6 +168,8 @@ impl Default for RandomForestClassifierParameters {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,26 +221,38 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||||
|
|
||||||
|
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||||
|
if parameters.keep_samples {
|
||||||
|
maybe_all_samples = Some(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
for _ in 0..parameters.n_trees {
|
||||||
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k);
|
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
|
||||||
|
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||||
|
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||||
|
}
|
||||||
|
|
||||||
let params = DecisionTreeClassifierParameters {
|
let params = DecisionTreeClassifierParameters {
|
||||||
criterion: parameters.criterion.clone(),
|
criterion: parameters.criterion.clone(),
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
let tree =
|
||||||
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(RandomForestClassifier {
|
Ok(RandomForestClassifier {
|
||||||
parameters,
|
_parameters: parameters,
|
||||||
trees,
|
trees,
|
||||||
classes,
|
classes,
|
||||||
|
samples: maybe_all_samples,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,8 +280,43 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
which_max(&result)
|
which_max(&result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> {
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
let mut rng = rand::thread_rng();
|
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
if self.samples.is_none() {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Need samples=true for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result.to_row_vector())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||||
|
let mut result = vec![0; self.classes.len()];
|
||||||
|
|
||||||
|
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||||
|
if !samples[row] {
|
||||||
|
result[tree.predict_for_row(x, row)] += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
which_max(&result)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
let class_weight = vec![1.; num_classes];
|
let class_weight = vec![1.; num_classes];
|
||||||
let nrows = y.len();
|
let nrows = y.len();
|
||||||
let mut samples = vec![0; nrows];
|
let mut samples = vec![0; nrows];
|
||||||
@@ -262,7 +332,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
|||||||
|
|
||||||
let size = ((n_samples as f64) / *class_weight_l) as usize;
|
let size = ((n_samples as f64) / *class_weight_l) as usize;
|
||||||
for _ in 0..size {
|
for _ in 0..size {
|
||||||
let xi: usize = rng.gen_range(0, n_samples);
|
let xi: usize = rng.gen_range(0..n_samples);
|
||||||
samples[index[xi]] += 1;
|
samples[index[xi]] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -276,6 +346,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::*;
|
use crate::metrics::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -314,6 +385,8 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 100,
|
n_trees: 100,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 87,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -321,7 +394,60 @@ mod tests {
|
|||||||
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
fn fit_predict_iris_oob() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
&[4.9, 3.0, 1.4, 0.2],
|
||||||
|
&[4.7, 3.2, 1.3, 0.2],
|
||||||
|
&[4.6, 3.1, 1.5, 0.2],
|
||||||
|
&[5.0, 3.6, 1.4, 0.2],
|
||||||
|
&[5.4, 3.9, 1.7, 0.4],
|
||||||
|
&[4.6, 3.4, 1.4, 0.3],
|
||||||
|
&[5.0, 3.4, 1.5, 0.2],
|
||||||
|
&[4.4, 2.9, 1.4, 0.2],
|
||||||
|
&[4.9, 3.1, 1.5, 0.1],
|
||||||
|
&[7.0, 3.2, 4.7, 1.4],
|
||||||
|
&[6.4, 3.2, 4.5, 1.5],
|
||||||
|
&[6.9, 3.1, 4.9, 1.5],
|
||||||
|
&[5.5, 2.3, 4.0, 1.3],
|
||||||
|
&[6.5, 2.8, 4.6, 1.5],
|
||||||
|
&[5.7, 2.8, 4.5, 1.3],
|
||||||
|
&[6.3, 3.3, 4.7, 1.6],
|
||||||
|
&[4.9, 2.4, 3.3, 1.0],
|
||||||
|
&[6.6, 2.9, 4.6, 1.3],
|
||||||
|
&[5.2, 2.7, 3.9, 1.4],
|
||||||
|
]);
|
||||||
|
let y = vec![
|
||||||
|
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
|
];
|
||||||
|
|
||||||
|
let classifier = RandomForestClassifier::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestClassifierParameters {
|
||||||
|
criterion: SplitCriterion::Gini,
|
||||||
|
max_depth: None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 100,
|
||||||
|
m: Option::None,
|
||||||
|
keep_samples: true,
|
||||||
|
seed: 87,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
accuracy(&y, &classifier.predict_oob(&x).unwrap())
|
||||||
|
< accuracy(&y, &classifier.predict(&x).unwrap())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
|
|||||||
@@ -43,21 +43,24 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
use rand::rngs::StdRng;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::Rng;
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::{Failed, FailedError};
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::tree::decision_tree_regressor::{
|
use crate::tree::decision_tree_regressor::{
|
||||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of the Random Forest Regressor
|
/// Parameters of the Random Forest Regressor
|
||||||
/// Some parameters here are passed directly into base estimator.
|
/// Some parameters here are passed directly into base estimator.
|
||||||
pub struct RandomForestRegressorParameters {
|
pub struct RandomForestRegressorParameters {
|
||||||
@@ -71,13 +74,19 @@ pub struct RandomForestRegressorParameters {
|
|||||||
pub n_trees: usize,
|
pub n_trees: usize,
|
||||||
/// Number of random sample of predictors to use as split candidates.
|
/// Number of random sample of predictors to use as split candidates.
|
||||||
pub m: Option<usize>,
|
pub m: Option<usize>,
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub keep_samples: bool,
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub seed: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Random Forest Regressor
|
/// Random Forest Regressor
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RandomForestRegressor<T: RealNumber> {
|
pub struct RandomForestRegressor<T: RealNumber> {
|
||||||
parameters: RandomForestRegressorParameters,
|
_parameters: RandomForestRegressorParameters,
|
||||||
trees: Vec<DecisionTreeRegressor<T>>,
|
trees: Vec<DecisionTreeRegressor<T>>,
|
||||||
|
samples: Option<Vec<Vec<bool>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RandomForestRegressorParameters {
|
impl RandomForestRegressorParameters {
|
||||||
@@ -106,8 +115,19 @@ impl RandomForestRegressorParameters {
|
|||||||
self.m = Some(m);
|
self.m = Some(m);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||||
|
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
|
||||||
|
self.keep_samples = keep_samples;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||||
|
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||||
|
self.seed = seed;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
impl Default for RandomForestRegressorParameters {
|
impl Default for RandomForestRegressorParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
RandomForestRegressorParameters {
|
RandomForestRegressorParameters {
|
||||||
@@ -116,6 +136,8 @@ impl Default for RandomForestRegressorParameters {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 10,
|
n_trees: 10,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -169,20 +191,34 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
.m
|
.m
|
||||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||||
|
|
||||||
|
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||||
|
|
||||||
|
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||||
|
if parameters.keep_samples {
|
||||||
|
maybe_all_samples = Some(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
for _ in 0..parameters.n_trees {
|
||||||
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows);
|
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
|
||||||
|
if let Some(ref mut all_samples) = maybe_all_samples {
|
||||||
|
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
||||||
|
}
|
||||||
let params = DecisionTreeRegressorParameters {
|
let params = DecisionTreeRegressorParameters {
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
min_samples_split: parameters.min_samples_split,
|
min_samples_split: parameters.min_samples_split,
|
||||||
};
|
};
|
||||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
let tree =
|
||||||
|
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(RandomForestRegressor { parameters, trees })
|
Ok(RandomForestRegressor {
|
||||||
|
_parameters: parameters,
|
||||||
|
trees,
|
||||||
|
samples: maybe_all_samples,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
@@ -211,11 +247,49 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
|||||||
result / T::from(n_trees).unwrap()
|
result / T::from(n_trees).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_with_replacement(nrows: usize) -> Vec<usize> {
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
let mut rng = rand::thread_rng();
|
pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
if self.samples.is_none() {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Need samples=true for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
||||||
|
Err(Failed::because(
|
||||||
|
FailedError::PredictFailed,
|
||||||
|
"Prediction matrix must match matrix used in training for OOB predictions.",
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
let mut result = M::zeros(1, n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(0, i, self.predict_for_row_oob(x, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result.to_row_vector())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
|
let mut n_trees = 0;
|
||||||
|
let mut result = T::zero();
|
||||||
|
|
||||||
|
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
|
||||||
|
if !samples[row] {
|
||||||
|
result += tree.predict_for_row(x, row);
|
||||||
|
n_trees += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: What to do if there are no oob trees?
|
||||||
|
result / T::from(n_trees).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
||||||
let mut samples = vec![0; nrows];
|
let mut samples = vec![0; nrows];
|
||||||
for _ in 0..nrows {
|
for _ in 0..nrows {
|
||||||
let xi = rng.gen_range(0, nrows);
|
let xi = rng.gen_range(0..nrows);
|
||||||
samples[xi] += 1;
|
samples[xi] += 1;
|
||||||
}
|
}
|
||||||
samples
|
samples
|
||||||
@@ -228,6 +302,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -262,6 +337,8 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 87,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|rf| rf.predict(&x))
|
.and_then(|rf| rf.predict(&x))
|
||||||
@@ -270,7 +347,56 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
fn fit_predict_longley_oob() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
&[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||||
|
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
&[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||||
|
&[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||||
|
&[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||||
|
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
&[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
&[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||||
|
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
|
]);
|
||||||
|
let y = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
|
let regressor = RandomForestRegressor::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
RandomForestRegressorParameters {
|
||||||
|
max_depth: None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 1000,
|
||||||
|
m: Option::None,
|
||||||
|
keep_samples: true,
|
||||||
|
seed: 87,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let y_hat = regressor.predict(&x).unwrap();
|
||||||
|
let y_hat_oob = regressor.predict_oob(&x).unwrap();
|
||||||
|
|
||||||
|
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
|||||||
+5
-2
@@ -2,10 +2,12 @@
|
|||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Generic error to be raised when something goes wrong.
|
/// Generic error to be raised when something goes wrong.
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Failed {
|
pub struct Failed {
|
||||||
err: FailedError,
|
err: FailedError,
|
||||||
msg: String,
|
msg: String,
|
||||||
@@ -13,7 +15,8 @@ pub struct Failed {
|
|||||||
|
|
||||||
/// Type of error
|
/// Type of error
|
||||||
#[non_exhaustive]
|
#[non_exhaustive]
|
||||||
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Copy, Clone, Debug)]
|
||||||
pub enum FailedError {
|
pub enum FailedError {
|
||||||
/// Can't fit algorithm to data
|
/// Can't fit algorithm to data
|
||||||
FitFailed = 1,
|
FitFailed = 1,
|
||||||
|
|||||||
+7
-3
@@ -1,10 +1,12 @@
|
|||||||
#![allow(
|
#![allow(
|
||||||
clippy::type_complexity,
|
clippy::type_complexity,
|
||||||
clippy::too_many_arguments,
|
clippy::too_many_arguments,
|
||||||
clippy::many_single_char_names
|
clippy::many_single_char_names,
|
||||||
|
clippy::unnecessary_wraps,
|
||||||
|
clippy::upper_case_acronyms
|
||||||
)]
|
)]
|
||||||
#![warn(missing_docs)]
|
#![warn(missing_docs)]
|
||||||
#![warn(missing_doc_code_examples)]
|
#![warn(rustdoc::missing_doc_code_examples)]
|
||||||
|
|
||||||
//! # SmartCore
|
//! # SmartCore
|
||||||
//!
|
//!
|
||||||
@@ -28,7 +30,7 @@
|
|||||||
//!
|
//!
|
||||||
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
//! All machine learning algorithms in SmartCore are grouped into these broad categories:
|
||||||
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
|
||||||
//! * [Martix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
|
||||||
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
|
||||||
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
|
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
|
||||||
//! * [Tree-based Models](tree/index.html), classification and regression trees
|
//! * [Tree-based Models](tree/index.html), classification and regression trees
|
||||||
@@ -91,6 +93,8 @@ pub mod naive_bayes;
|
|||||||
/// Supervised neighbors-based learning methods
|
/// Supervised neighbors-based learning methods
|
||||||
pub mod neighbors;
|
pub mod neighbors;
|
||||||
pub(crate) mod optimization;
|
pub(crate) mod optimization;
|
||||||
|
/// Preprocessing utilities
|
||||||
|
pub mod preprocessing;
|
||||||
/// Support Vector Machines
|
/// Support Vector Machines
|
||||||
pub mod svm;
|
pub mod svm;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
|
|||||||
@@ -87,8 +87,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
|
|||||||
if bn != rn {
|
if bn != rn {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
FailedError::SolutionFailed,
|
FailedError::SolutionFailed,
|
||||||
&"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R."
|
"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
|
||||||
.to_string(),
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,7 +127,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
if m != n {
|
if m != n {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
FailedError::DecompositionFailed,
|
FailedError::DecompositionFailed,
|
||||||
&"Can\'t do Cholesky decomposition on a non-square matrix".to_string(),
|
"Can\'t do Cholesky decomposition on a non-square matrix",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,7 +147,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
if d < T::zero() {
|
if d < T::zero() {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
FailedError::DecompositionFailed,
|
FailedError::DecompositionFailed,
|
||||||
&"The matrix is not positive definite.".to_string(),
|
"The matrix is not positive definite.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,7 +167,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cholesky_decompose() {
|
fn cholesky_decompose() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||||
@@ -187,6 +186,7 @@ mod tests {
|
|||||||
.approximate_eq(&a.abs(), 1e-4));
|
.approximate_eq(&a.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cholesky_solve_mut() {
|
fn cholesky_solve_mut() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||||
|
|||||||
+27
-19
@@ -25,6 +25,19 @@
|
|||||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||||
//! let eigenvalues: Vec<f64> = evd.d;
|
//! let eigenvalues: Vec<f64> = evd.d;
|
||||||
//! ```
|
//! ```
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||||
|
//! use smartcore::linalg::evd::*;
|
||||||
|
//!
|
||||||
|
//! let A = DenseMatrix::from_2d_array(&[
|
||||||
|
//! &[-5.0, 2.0],
|
||||||
|
//! &[-7.0, 4.0],
|
||||||
|
//! ]);
|
||||||
|
//!
|
||||||
|
//! let evd = A.evd(false).unwrap();
|
||||||
|
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||||
|
//! let eigenvalues: Vec<f64> = evd.d;
|
||||||
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/)
|
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/)
|
||||||
@@ -93,11 +106,11 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
sort(&mut d, &mut e, &mut V);
|
sort(&mut d, &mut e, &mut V);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(EVD { V, d, e })
|
Ok(EVD { d, e, V })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||||
let (n, _) = V.shape();
|
let (n, _) = V.shape();
|
||||||
for (i, d_i) in d.iter_mut().enumerate().take(n) {
|
for (i, d_i) in d.iter_mut().enumerate().take(n) {
|
||||||
*d_i = V.get(n - 1, i);
|
*d_i = V.get(n - 1, i);
|
||||||
@@ -195,7 +208,7 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
|
|||||||
e[0] = T::zero();
|
e[0] = T::zero();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||||
let (n, _) = V.shape();
|
let (n, _) = V.shape();
|
||||||
for i in 1..n {
|
for i in 1..n {
|
||||||
e[i - 1] = e[i];
|
e[i - 1] = e[i];
|
||||||
@@ -419,7 +432,7 @@ fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) {
|
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
|
||||||
let (n, _) = A.shape();
|
let (n, _) = A.shape();
|
||||||
let mut z = T::zero();
|
let mut z = T::zero();
|
||||||
let mut s = T::zero();
|
let mut s = T::zero();
|
||||||
@@ -471,7 +484,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
|||||||
A.set(nn, nn, x);
|
A.set(nn, nn, x);
|
||||||
A.set(nn - 1, nn - 1, y + t);
|
A.set(nn - 1, nn - 1, y + t);
|
||||||
if q >= T::zero() {
|
if q >= T::zero() {
|
||||||
z = p + z.copysign(p);
|
z = p + RealNumber::copysign(z, p);
|
||||||
d[nn - 1] = x + z;
|
d[nn - 1] = x + z;
|
||||||
d[nn] = x + z;
|
d[nn] = x + z;
|
||||||
if z != T::zero() {
|
if z != T::zero() {
|
||||||
@@ -570,7 +583,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
|||||||
r /= x;
|
r /= x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let s = (p * p + q * q + r * r).sqrt().copysign(p);
|
let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
|
||||||
if s != T::zero() {
|
if s != T::zero() {
|
||||||
if k == m {
|
if k == m {
|
||||||
if l != m {
|
if l != m {
|
||||||
@@ -594,12 +607,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
|
|||||||
A.sub_element_mut(k + 1, j, p * y);
|
A.sub_element_mut(k + 1, j, p * y);
|
||||||
A.sub_element_mut(k, j, p * x);
|
A.sub_element_mut(k, j, p * x);
|
||||||
}
|
}
|
||||||
let mmin;
|
let mmin = if nn < k + 3 { nn } else { k + 3 };
|
||||||
if nn < k + 3 {
|
|
||||||
mmin = nn;
|
|
||||||
} else {
|
|
||||||
mmin = k + 3;
|
|
||||||
}
|
|
||||||
for i in 0..mmin + 1 {
|
for i in 0..mmin + 1 {
|
||||||
p = x * A.get(i, k) + y * A.get(i, k + 1);
|
p = x * A.get(i, k) + y * A.get(i, k + 1);
|
||||||
if k + 1 != nn {
|
if k + 1 != nn {
|
||||||
@@ -783,7 +791,7 @@ fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) {
|
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||||
let n = d.len();
|
let n = d.len();
|
||||||
let mut temp = vec![T::zero(); n];
|
let mut temp = vec![T::zero(); n];
|
||||||
for j in 1..n {
|
for j in 1..n {
|
||||||
@@ -804,10 +812,10 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut
|
|||||||
}
|
}
|
||||||
i -= 1;
|
i -= 1;
|
||||||
}
|
}
|
||||||
d[i as usize + 1] = real;
|
d[(i + 1) as usize] = real;
|
||||||
e[i as usize + 1] = img;
|
e[(i + 1) as usize] = img;
|
||||||
for (k, temp_k) in temp.iter().enumerate().take(n) {
|
for (k, temp_k) in temp.iter().enumerate().take(n) {
|
||||||
V.set(k, i as usize + 1, *temp_k);
|
V.set(k, (i + 1) as usize, *temp_k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -816,7 +824,7 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_symmetric() {
|
fn decompose_symmetric() {
|
||||||
let A = DenseMatrix::from_2d_array(&[
|
let A = DenseMatrix::from_2d_array(&[
|
||||||
@@ -843,7 +851,7 @@ mod tests {
|
|||||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_asymmetric() {
|
fn decompose_asymmetric() {
|
||||||
let A = DenseMatrix::from_2d_array(&[
|
let A = DenseMatrix::from_2d_array(&[
|
||||||
@@ -870,7 +878,7 @@ mod tests {
|
|||||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_complex() {
|
fn decompose_complex() {
|
||||||
let A = DenseMatrix::from_2d_array(&[
|
let A = DenseMatrix::from_2d_array(&[
|
||||||
|
|||||||
+5
-4
@@ -46,13 +46,13 @@ use crate::math::num::RealNumber;
|
|||||||
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
|
pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
|
||||||
LU: M,
|
LU: M,
|
||||||
pivot: Vec<usize>,
|
pivot: Vec<usize>,
|
||||||
pivot_sign: i8,
|
_pivot_sign: i8,
|
||||||
singular: bool,
|
singular: bool,
|
||||||
phantom: PhantomData<T>,
|
phantom: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
||||||
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> {
|
pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
|
||||||
let (_, n) = LU.shape();
|
let (_, n) = LU.shape();
|
||||||
|
|
||||||
let mut singular = false;
|
let mut singular = false;
|
||||||
@@ -66,7 +66,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
|
|||||||
LU {
|
LU {
|
||||||
LU,
|
LU,
|
||||||
pivot,
|
pivot,
|
||||||
pivot_sign,
|
_pivot_sign,
|
||||||
singular,
|
singular,
|
||||||
phantom: PhantomData,
|
phantom: PhantomData,
|
||||||
}
|
}
|
||||||
@@ -260,6 +260,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose() {
|
fn decompose() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||||
@@ -274,7 +275,7 @@ mod tests {
|
|||||||
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
assert!(lu.U().approximate_eq(&expected_U, 1e-4));
|
||||||
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn inverse() {
|
fn inverse() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||||
|
|||||||
+32
-6
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::wrong_self_convention)]
|
||||||
//! # Linear Algebra and Matrix Decomposition
|
//! # Linear Algebra and Matrix Decomposition
|
||||||
//!
|
//!
|
||||||
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
|
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
|
||||||
@@ -265,7 +266,7 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
|||||||
sum += xi * xi;
|
sum += xi * xi;
|
||||||
}
|
}
|
||||||
mu /= div;
|
mu /= div;
|
||||||
sum / div - mu * mu
|
sum / div - mu.powi(2)
|
||||||
}
|
}
|
||||||
/// Computes the standard deviation.
|
/// Computes the standard deviation.
|
||||||
fn std(&self) -> T {
|
fn std(&self) -> T {
|
||||||
@@ -650,6 +651,10 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
|||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
/// Take an individual column from the matrix.
|
||||||
|
fn take_column(&self, column_index: usize) -> Self {
|
||||||
|
self.take(&[column_index], 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generic matrix with additional mixins like various factorization methods.
|
/// Generic matrix with additional mixins like various factorization methods.
|
||||||
@@ -688,12 +693,11 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
|
|||||||
type Item = Vec<T>;
|
type Item = Vec<T>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Vec<T>> {
|
fn next(&mut self) -> Option<Vec<T>> {
|
||||||
let res;
|
let res = if self.pos < self.max_pos {
|
||||||
if self.pos < self.max_pos {
|
Some(self.m.get_row_as_vec(self.pos))
|
||||||
res = Some(self.m.get_row_as_vec(self.pos))
|
|
||||||
} else {
|
} else {
|
||||||
res = None
|
None
|
||||||
}
|
};
|
||||||
self.pos += 1;
|
self.pos += 1;
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
@@ -705,6 +709,7 @@ mod tests {
|
|||||||
use crate::linalg::BaseMatrix;
|
use crate::linalg::BaseMatrix;
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mean() {
|
fn mean() {
|
||||||
let m = vec![1., 2., 3.];
|
let m = vec![1., 2., 3.];
|
||||||
@@ -712,6 +717,7 @@ mod tests {
|
|||||||
assert_eq!(m.mean(), 2.0);
|
assert_eq!(m.mean(), 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn std() {
|
fn std() {
|
||||||
let m = vec![1., 2., 3.];
|
let m = vec![1., 2., 3.];
|
||||||
@@ -719,6 +725,7 @@ mod tests {
|
|||||||
assert!((m.std() - 0.81f64).abs() < 1e-2);
|
assert!((m.std() - 0.81f64).abs() < 1e-2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn var() {
|
fn var() {
|
||||||
let m = vec![1., 2., 3., 4.];
|
let m = vec![1., 2., 3., 4.];
|
||||||
@@ -726,6 +733,7 @@ mod tests {
|
|||||||
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
|
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_take() {
|
fn vec_take() {
|
||||||
let m = vec![1., 2., 3., 4., 5.];
|
let m = vec![1., 2., 3., 4., 5.];
|
||||||
@@ -733,6 +741,7 @@ mod tests {
|
|||||||
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
|
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn take() {
|
fn take() {
|
||||||
let m = DenseMatrix::from_2d_array(&[
|
let m = DenseMatrix::from_2d_array(&[
|
||||||
@@ -756,4 +765,21 @@ mod tests {
|
|||||||
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
|
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
|
||||||
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
|
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn take_second_column_from_matrix() {
|
||||||
|
let four_columns: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||||
|
&[0.0, 1.0, 2.0, 3.0],
|
||||||
|
&[0.0, 1.0, 2.0, 3.0],
|
||||||
|
&[0.0, 1.0, 2.0, 3.0],
|
||||||
|
&[0.0, 1.0, 2.0, 3.0],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let second_column = four_columns.take_column(1);
|
||||||
|
assert_eq!(
|
||||||
|
second_column,
|
||||||
|
DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]),
|
||||||
|
"The second column was not extracted correctly"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
#![allow(clippy::ptr_arg)]
|
#![allow(clippy::ptr_arg)]
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::ser::{SerializeStruct, Serializer};
|
use serde::ser::{SerializeStruct, Serializer};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||||
@@ -326,7 +330,7 @@ impl<T: RealNumber> DenseMatrix<T> {
|
|||||||
cur_r: 0,
|
cur_r: 0,
|
||||||
max_c: self.ncols,
|
max_c: self.ncols,
|
||||||
max_r: self.nrows,
|
max_r: self.nrows,
|
||||||
m: &self,
|
m: self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -349,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
where
|
where
|
||||||
@@ -434,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
@@ -517,10 +523,9 @@ impl<T: RealNumber> PartialEq for DenseMatrix<T> {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl<T: RealNumber> From<DenseMatrix<T>> for Vec<T> {
|
||||||
impl<T: RealNumber> Into<Vec<T>> for DenseMatrix<T> {
|
fn from(dense_matrix: DenseMatrix<T>) -> Vec<T> {
|
||||||
fn into(self) -> Vec<T> {
|
dense_matrix.values
|
||||||
self.values
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1054,14 +1059,14 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_dot() {
|
fn vec_dot() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
let v2 = vec![4., 5., 6.];
|
let v2 = vec![4., 5., 6.];
|
||||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_copy_from() {
|
fn vec_copy_from() {
|
||||||
let mut v1 = vec![1., 2., 3.];
|
let mut v1 = vec![1., 2., 3.];
|
||||||
@@ -1069,7 +1074,7 @@ mod tests {
|
|||||||
v1.copy_from(&v2);
|
v1.copy_from(&v2);
|
||||||
assert_eq!(v1, v2);
|
assert_eq!(v1, v2);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_approximate_eq() {
|
fn vec_approximate_eq() {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
@@ -1077,7 +1082,7 @@ mod tests {
|
|||||||
assert!(a.approximate_eq(&b, 1e-4));
|
assert!(a.approximate_eq(&b, 1e-4));
|
||||||
assert!(!a.approximate_eq(&b, 1e-5));
|
assert!(!a.approximate_eq(&b, 1e-5));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn from_array() {
|
fn from_array() {
|
||||||
let vec = [1., 2., 3., 4., 5., 6.];
|
let vec = [1., 2., 3., 4., 5., 6.];
|
||||||
@@ -1090,7 +1095,7 @@ mod tests {
|
|||||||
DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.])
|
DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn row_column_vec_from_array() {
|
fn row_column_vec_from_array() {
|
||||||
let vec = vec![1., 2., 3., 4., 5., 6.];
|
let vec = vec![1., 2., 3., 4., 5., 6.];
|
||||||
@@ -1103,7 +1108,7 @@ mod tests {
|
|||||||
DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.])
|
DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn from_to_row_vec() {
|
fn from_to_row_vec() {
|
||||||
let vec = vec![1., 2., 3.];
|
let vec = vec![1., 2., 3.];
|
||||||
@@ -1116,20 +1121,20 @@ mod tests {
|
|||||||
vec![1., 2., 3.]
|
vec![1., 2., 3.]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_matrix_to_row_vector() {
|
fn col_matrix_to_row_vector() {
|
||||||
let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1);
|
let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1);
|
||||||
assert_eq!(m.to_row_vector().len(), 10)
|
assert_eq!(m.to_row_vector().len(), 10)
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn iter() {
|
fn iter() {
|
||||||
let vec = vec![1., 2., 3., 4., 5., 6.];
|
let vec = vec![1., 2., 3., 4., 5., 6.];
|
||||||
let m = DenseMatrix::from_array(3, 2, &vec);
|
let m = DenseMatrix::from_array(3, 2, &vec);
|
||||||
assert_eq!(vec, m.iter().collect::<Vec<f32>>());
|
assert_eq!(vec, m.iter().collect::<Vec<f32>>());
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn v_stack() {
|
fn v_stack() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
@@ -1144,7 +1149,7 @@ mod tests {
|
|||||||
let result = a.v_stack(&b);
|
let result = a.v_stack(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn h_stack() {
|
fn h_stack() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
@@ -1157,13 +1162,13 @@ mod tests {
|
|||||||
let result = a.h_stack(&b);
|
let result = a.h_stack(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row() {
|
fn get_row() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
assert_eq!(vec![4., 5., 6.], a.get_row(1));
|
assert_eq!(vec![4., 5., 6.], a.get_row(1));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn matmul() {
|
fn matmul() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
@@ -1172,7 +1177,7 @@ mod tests {
|
|||||||
let result = a.matmul(&b);
|
let result = a.matmul(&b);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ab() {
|
fn ab() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
@@ -1195,14 +1200,14 @@ mod tests {
|
|||||||
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
|
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
|
||||||
let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
|
let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
|
||||||
assert_eq!(a.dot(&b), 32.);
|
assert_eq!(a.dot(&b), 32.);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn copy_from() {
|
fn copy_from() {
|
||||||
let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||||
@@ -1210,7 +1215,7 @@ mod tests {
|
|||||||
a.copy_from(&b);
|
a.copy_from(&b);
|
||||||
assert_eq!(a, b);
|
assert_eq!(a, b);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
let m = DenseMatrix::from_2d_array(&[
|
let m = DenseMatrix::from_2d_array(&[
|
||||||
@@ -1222,7 +1227,7 @@ mod tests {
|
|||||||
let result = m.slice(0..2, 1..3);
|
let result = m.slice(0..2, 1..3);
|
||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
|
let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
|
||||||
@@ -1231,7 +1236,7 @@ mod tests {
|
|||||||
assert!(m.approximate_eq(&m_eq, 0.5));
|
assert!(m.approximate_eq(&m_eq, 0.5));
|
||||||
assert!(!m.approximate_eq(&m_neq, 0.5));
|
assert!(!m.approximate_eq(&m_neq, 0.5));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn rand() {
|
fn rand() {
|
||||||
let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3);
|
let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3);
|
||||||
@@ -1241,7 +1246,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn transpose() {
|
fn transpose() {
|
||||||
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
|
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
|
||||||
@@ -1253,7 +1258,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn reshape() {
|
fn reshape() {
|
||||||
let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]);
|
let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]);
|
||||||
@@ -1264,7 +1269,7 @@ mod tests {
|
|||||||
assert_eq!(m_result.get(0, 1), 2.);
|
assert_eq!(m_result.get(0, 1), 2.);
|
||||||
assert_eq!(m_result.get(0, 3), 4.);
|
assert_eq!(m_result.get(0, 3), 4.);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn norm() {
|
fn norm() {
|
||||||
let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]);
|
let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]);
|
||||||
@@ -1273,7 +1278,7 @@ mod tests {
|
|||||||
assert_eq!(v.norm(std::f64::INFINITY), 6.);
|
assert_eq!(v.norm(std::f64::INFINITY), 6.);
|
||||||
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax_mut() {
|
fn softmax_mut() {
|
||||||
let mut prob: DenseMatrix<f64> = DenseMatrix::row_vector_from_array(&[1., 2., 3.]);
|
let mut prob: DenseMatrix<f64> = DenseMatrix::row_vector_from_array(&[1., 2., 3.]);
|
||||||
@@ -1282,14 +1287,14 @@ mod tests {
|
|||||||
assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
|
assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
|
||||||
assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
|
assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_mean() {
|
fn col_mean() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||||
let res = a.column_mean();
|
let res = a.column_mean();
|
||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn min_max_sum() {
|
fn min_max_sum() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
@@ -1297,30 +1302,32 @@ mod tests {
|
|||||||
assert_eq!(1., a.min());
|
assert_eq!(1., a.min());
|
||||||
assert_eq!(6., a.max());
|
assert_eq!(6., a.max());
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn eye() {
|
fn eye() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
|
let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
|
||||||
let res = DenseMatrix::eye(3);
|
let res = DenseMatrix::eye(3);
|
||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn to_from_json() {
|
fn to_from_json() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> =
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
|
||||||
assert_eq!(a, deserialized_a);
|
assert_eq!(a, deserialized_a);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn to_from_bincode() {
|
fn to_from_bincode() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
let deserialized_a: DenseMatrix<f64> =
|
let deserialized_a: DenseMatrix<f64> =
|
||||||
bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
|
||||||
assert_eq!(a, deserialized_a);
|
assert_eq!(a, deserialized_a);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn to_string() {
|
fn to_string() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
@@ -1329,7 +1336,7 @@ mod tests {
|
|||||||
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
|
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn cov() {
|
fn cov() {
|
||||||
let a = DenseMatrix::from_2d_array(&[
|
let a = DenseMatrix::from_2d_array(&[
|
||||||
|
|||||||
@@ -40,7 +40,7 @@
|
|||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
use std::ops::{AddAssign, DivAssign, MulAssign, Range, SubAssign};
|
||||||
|
|
||||||
use nalgebra::{DMatrix, Dynamic, Matrix, MatrixMN, RowDVector, Scalar, VecStorage, U1};
|
use nalgebra::{Const, DMatrix, Dynamic, Matrix, OMatrix, RowDVector, Scalar, VecStorage, U1};
|
||||||
|
|
||||||
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
use crate::linalg::cholesky::CholeskyDecomposableMatrix;
|
||||||
use crate::linalg::evd::EVDDecomposableMatrix;
|
use crate::linalg::evd::EVDDecomposableMatrix;
|
||||||
@@ -53,7 +53,7 @@ use crate::linalg::Matrix as SmartCoreMatrix;
|
|||||||
use crate::linalg::{BaseMatrix, BaseVector};
|
use crate::linalg::{BaseMatrix, BaseVector};
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
impl<T: RealNumber + 'static> BaseVector<T> for MatrixMN<T, U1, Dynamic> {
|
impl<T: RealNumber + 'static> BaseVector<T> for OMatrix<T, U1, Dynamic> {
|
||||||
fn get(&self, i: usize) -> T {
|
fn get(&self, i: usize) -> T {
|
||||||
*self.get((0, i)).unwrap()
|
*self.get((0, i)).unwrap()
|
||||||
}
|
}
|
||||||
@@ -198,7 +198,7 @@ impl<T: RealNumber + Scalar + AddAssign + SubAssign + MulAssign + DivAssign + Su
|
|||||||
|
|
||||||
fn to_row_vector(self) -> Self::RowVector {
|
fn to_row_vector(self) -> Self::RowVector {
|
||||||
let (nrows, ncols) = self.shape();
|
let (nrows, ncols) = self.shape();
|
||||||
self.reshape_generic(U1, Dynamic::new(nrows * ncols))
|
self.reshape_generic(Const::<1>, Dynamic::new(nrows * ncols))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get(&self, row: usize, col: usize) -> T {
|
fn get(&self, row: usize, col: usize) -> T {
|
||||||
@@ -579,6 +579,7 @@ mod tests {
|
|||||||
use crate::linear::linear_regression::*;
|
use crate::linear::linear_regression::*;
|
||||||
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
use nalgebra::{DMatrix, Matrix2x3, RowDVector};
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_copy_from() {
|
fn vec_copy_from() {
|
||||||
let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
@@ -589,12 +590,14 @@ mod tests {
|
|||||||
assert_ne!(v2, v1);
|
assert_ne!(v2, v1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_len() {
|
fn vec_len() {
|
||||||
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
assert_eq!(3, v.len());
|
assert_eq!(3, v.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_set_vector() {
|
fn get_set_vector() {
|
||||||
let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||||
@@ -607,12 +610,14 @@ mod tests {
|
|||||||
assert_eq!(5., BaseVector::get(&v, 1));
|
assert_eq!(5., BaseVector::get(&v, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_to_vec() {
|
fn vec_to_vec() {
|
||||||
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
let v = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_init() {
|
fn vec_init() {
|
||||||
let zeros: RowDVector<f32> = BaseVector::zeros(3);
|
let zeros: RowDVector<f32> = BaseVector::zeros(3);
|
||||||
@@ -623,6 +628,7 @@ mod tests {
|
|||||||
assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
|
assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_dot() {
|
fn vec_dot() {
|
||||||
let v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
let v1 = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
@@ -630,6 +636,7 @@ mod tests {
|
|||||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_approximate_eq() {
|
fn vec_approximate_eq() {
|
||||||
let a = RowDVector::from_vec(vec![1., 2., 3.]);
|
let a = RowDVector::from_vec(vec![1., 2., 3.]);
|
||||||
@@ -638,6 +645,7 @@ mod tests {
|
|||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_set_dynamic() {
|
fn get_set_dynamic() {
|
||||||
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||||
@@ -650,6 +658,7 @@ mod tests {
|
|||||||
assert_eq!(10., BaseMatrix::get(&m, 1, 1));
|
assert_eq!(10., BaseMatrix::get(&m, 1, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn zeros() {
|
fn zeros() {
|
||||||
let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
|
let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
|
||||||
@@ -659,6 +668,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ones() {
|
fn ones() {
|
||||||
let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
|
let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
|
||||||
@@ -668,6 +678,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn eye() {
|
fn eye() {
|
||||||
let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]);
|
let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]);
|
||||||
@@ -675,6 +686,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn shape() {
|
fn shape() {
|
||||||
let m: DMatrix<f64> = BaseMatrix::zeros(5, 10);
|
let m: DMatrix<f64> = BaseMatrix::zeros(5, 10);
|
||||||
@@ -684,6 +696,7 @@ mod tests {
|
|||||||
assert_eq!(ncols, 10);
|
assert_eq!(ncols, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn scalar_add_sub_mul_div() {
|
fn scalar_add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
|
||||||
@@ -697,6 +710,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn add_sub_mul_div() {
|
fn add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
@@ -715,6 +729,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn to_from_row_vector() {
|
fn to_from_row_vector() {
|
||||||
let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
|
||||||
@@ -723,12 +738,14 @@ mod tests {
|
|||||||
assert_eq!(m.to_row_vector(), expected);
|
assert_eq!(m.to_row_vector(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_matrix_to_row_vector() {
|
fn col_matrix_to_row_vector() {
|
||||||
let m: DMatrix<f64> = BaseMatrix::zeros(10, 1);
|
let m: DMatrix<f64> = BaseMatrix::zeros(10, 1);
|
||||||
assert_eq!(m.to_row_vector().len(), 10)
|
assert_eq!(m.to_row_vector().len(), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row_col_as_vec() {
|
fn get_row_col_as_vec() {
|
||||||
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||||
@@ -737,12 +754,14 @@ mod tests {
|
|||||||
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
|
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row() {
|
fn get_row() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||||
assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1));
|
assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn copy_row_col_as_vec() {
|
fn copy_row_col_as_vec() {
|
||||||
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
|
||||||
@@ -754,6 +773,7 @@ mod tests {
|
|||||||
assert_eq!(v, vec!(2., 5., 8.));
|
assert_eq!(v, vec!(2., 5., 8.));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn element_add_sub_mul_div() {
|
fn element_add_sub_mul_div() {
|
||||||
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
|
||||||
@@ -767,6 +787,7 @@ mod tests {
|
|||||||
assert_eq!(m, expected);
|
assert_eq!(m, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vstack_hstack() {
|
fn vstack_hstack() {
|
||||||
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||||
@@ -782,6 +803,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn matmul() {
|
fn matmul() {
|
||||||
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||||
@@ -791,6 +813,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||||
@@ -798,6 +821,7 @@ mod tests {
|
|||||||
assert_eq!(14., a.dot(&b));
|
assert_eq!(14., a.dot(&b));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
let a = DMatrix::from_row_slice(
|
let a = DMatrix::from_row_slice(
|
||||||
@@ -810,6 +834,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||||
@@ -822,6 +847,7 @@ mod tests {
|
|||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn negative_mut() {
|
fn negative_mut() {
|
||||||
let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
||||||
@@ -829,6 +855,7 @@ mod tests {
|
|||||||
assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.]));
|
assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn transpose() {
|
fn transpose() {
|
||||||
let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
|
let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
|
||||||
@@ -837,6 +864,7 @@ mod tests {
|
|||||||
assert_eq!(m_transposed, expected);
|
assert_eq!(m_transposed, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn rand() {
|
fn rand() {
|
||||||
let m: DMatrix<f64> = BaseMatrix::rand(3, 3);
|
let m: DMatrix<f64> = BaseMatrix::rand(3, 3);
|
||||||
@@ -847,6 +875,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn norm() {
|
fn norm() {
|
||||||
let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
|
||||||
@@ -856,6 +885,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.);
|
assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_mean() {
|
fn col_mean() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||||
@@ -863,6 +893,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn reshape() {
|
fn reshape() {
|
||||||
let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]);
|
let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]);
|
||||||
@@ -874,6 +905,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn copy_from() {
|
fn copy_from() {
|
||||||
let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||||
@@ -882,6 +914,7 @@ mod tests {
|
|||||||
assert_eq!(src, dst);
|
assert_eq!(src, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn abs_mut() {
|
fn abs_mut() {
|
||||||
let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]);
|
let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]);
|
||||||
@@ -890,6 +923,7 @@ mod tests {
|
|||||||
assert_eq!(a, expected);
|
assert_eq!(a, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn min_max_sum() {
|
fn min_max_sum() {
|
||||||
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
|
||||||
@@ -898,6 +932,7 @@ mod tests {
|
|||||||
assert_eq!(6., a.max());
|
assert_eq!(6., a.max());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn max_diff() {
|
fn max_diff() {
|
||||||
let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]);
|
let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]);
|
||||||
@@ -906,6 +941,7 @@ mod tests {
|
|||||||
assert_eq!(a2.max_diff(&a2), 0.);
|
assert_eq!(a2.max_diff(&a2), 0.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax_mut() {
|
fn softmax_mut() {
|
||||||
let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||||
@@ -915,13 +951,15 @@ mod tests {
|
|||||||
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn pow_mut() {
|
fn pow_mut() {
|
||||||
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
|
||||||
a.pow_mut(3.);
|
BaseMatrix::pow_mut(&mut a, 3.);
|
||||||
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
|
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn argmax() {
|
fn argmax() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]);
|
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]);
|
||||||
@@ -929,6 +967,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![2, 0, 1]);
|
assert_eq!(res, vec![2, 0, 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn unique() {
|
fn unique() {
|
||||||
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
|
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
|
||||||
@@ -937,6 +976,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
let x = DMatrix::from_row_slice(
|
let x = DMatrix::from_row_slice(
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ impl<T: RealNumber + ScalarOperand> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn copy_from(&mut self, other: &Self) {
|
fn copy_from(&mut self, other: &Self) {
|
||||||
self.assign(&other);
|
self.assign(other);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,7 +385,7 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn copy_from(&mut self, other: &Self) {
|
fn copy_from(&mut self, other: &Self) {
|
||||||
self.assign(&other);
|
self.assign(other);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn abs_mut(&mut self) -> &Self {
|
fn abs_mut(&mut self) -> &Self {
|
||||||
@@ -530,6 +530,7 @@ mod tests {
|
|||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
use ndarray::{arr1, arr2, Array1, Array2};
|
use ndarray::{arr1, arr2, Array1, Array2};
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_get_set() {
|
fn vec_get_set() {
|
||||||
let mut result = arr1(&[1., 2., 3.]);
|
let mut result = arr1(&[1., 2., 3.]);
|
||||||
@@ -541,6 +542,7 @@ mod tests {
|
|||||||
assert_eq!(5., BaseVector::get(&result, 1));
|
assert_eq!(5., BaseVector::get(&result, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_copy_from() {
|
fn vec_copy_from() {
|
||||||
let mut v1 = arr1(&[1., 2., 3.]);
|
let mut v1 = arr1(&[1., 2., 3.]);
|
||||||
@@ -551,18 +553,21 @@ mod tests {
|
|||||||
assert_ne!(v1, v2);
|
assert_ne!(v1, v2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_len() {
|
fn vec_len() {
|
||||||
let v = arr1(&[1., 2., 3.]);
|
let v = arr1(&[1., 2., 3.]);
|
||||||
assert_eq!(3, v.len());
|
assert_eq!(3, v.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_to_vec() {
|
fn vec_to_vec() {
|
||||||
let v = arr1(&[1., 2., 3.]);
|
let v = arr1(&[1., 2., 3.]);
|
||||||
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
assert_eq!(vec![1., 2., 3.], v.to_vec());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_dot() {
|
fn vec_dot() {
|
||||||
let v1 = arr1(&[1., 2., 3.]);
|
let v1 = arr1(&[1., 2., 3.]);
|
||||||
@@ -570,6 +575,7 @@ mod tests {
|
|||||||
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
assert_eq!(32.0, BaseVector::dot(&v1, &v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vec_approximate_eq() {
|
fn vec_approximate_eq() {
|
||||||
let a = arr1(&[1., 2., 3.]);
|
let a = arr1(&[1., 2., 3.]);
|
||||||
@@ -578,6 +584,7 @@ mod tests {
|
|||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn from_to_row_vec() {
|
fn from_to_row_vec() {
|
||||||
let vec = arr1(&[1., 2., 3.]);
|
let vec = arr1(&[1., 2., 3.]);
|
||||||
@@ -588,12 +595,14 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_matrix_to_row_vector() {
|
fn col_matrix_to_row_vector() {
|
||||||
let m: Array2<f64> = BaseMatrix::zeros(10, 1);
|
let m: Array2<f64> = BaseMatrix::zeros(10, 1);
|
||||||
assert_eq!(m.to_row_vector().len(), 10)
|
assert_eq!(m.to_row_vector().len(), 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn add_mut() {
|
fn add_mut() {
|
||||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -604,6 +613,7 @@ mod tests {
|
|||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn sub_mut() {
|
fn sub_mut() {
|
||||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -614,6 +624,7 @@ mod tests {
|
|||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mul_mut() {
|
fn mul_mut() {
|
||||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -624,6 +635,7 @@ mod tests {
|
|||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn div_mut() {
|
fn div_mut() {
|
||||||
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -634,6 +646,7 @@ mod tests {
|
|||||||
assert_eq!(a1, a3);
|
assert_eq!(a1, a3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn div_element_mut() {
|
fn div_element_mut() {
|
||||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -642,6 +655,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mul_element_mut() {
|
fn mul_element_mut() {
|
||||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -650,6 +664,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn add_element_mut() {
|
fn add_element_mut() {
|
||||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -657,7 +672,7 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn sub_element_mut() {
|
fn sub_element_mut() {
|
||||||
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -666,6 +681,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn vstack_hstack() {
|
fn vstack_hstack() {
|
||||||
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -680,6 +696,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_set() {
|
fn get_set() {
|
||||||
let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -691,6 +708,7 @@ mod tests {
|
|||||||
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
|
assert_eq!(10., BaseMatrix::get(&result, 1, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn matmul() {
|
fn matmul() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -700,6 +718,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn dot() {
|
fn dot() {
|
||||||
let a = arr2(&[[1., 2., 3.]]);
|
let a = arr2(&[[1., 2., 3.]]);
|
||||||
@@ -707,6 +726,7 @@ mod tests {
|
|||||||
assert_eq!(14., BaseMatrix::dot(&a, &b));
|
assert_eq!(14., BaseMatrix::dot(&a, &b));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn slice() {
|
fn slice() {
|
||||||
let a = arr2(&[
|
let a = arr2(&[
|
||||||
@@ -719,6 +739,7 @@ mod tests {
|
|||||||
assert_eq!(result, expected);
|
assert_eq!(result, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn scalar_ops() {
|
fn scalar_ops() {
|
||||||
let a = arr2(&[[1., 2., 3.]]);
|
let a = arr2(&[[1., 2., 3.]]);
|
||||||
@@ -728,6 +749,7 @@ mod tests {
|
|||||||
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
|
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn transpose() {
|
fn transpose() {
|
||||||
let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]);
|
let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]);
|
||||||
@@ -736,6 +758,7 @@ mod tests {
|
|||||||
assert_eq!(m_transposed, expected);
|
assert_eq!(m_transposed, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn norm() {
|
fn norm() {
|
||||||
let v = arr2(&[[3., -2., 6.]]);
|
let v = arr2(&[[3., -2., 6.]]);
|
||||||
@@ -745,6 +768,7 @@ mod tests {
|
|||||||
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn negative_mut() {
|
fn negative_mut() {
|
||||||
let mut v = arr2(&[[3., -2., 6.]]);
|
let mut v = arr2(&[[3., -2., 6.]]);
|
||||||
@@ -752,6 +776,7 @@ mod tests {
|
|||||||
assert_eq!(v, arr2(&[[-3., 2., -6.]]));
|
assert_eq!(v, arr2(&[[-3., 2., -6.]]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn reshape() {
|
fn reshape() {
|
||||||
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
|
||||||
@@ -763,6 +788,7 @@ mod tests {
|
|||||||
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn copy_from() {
|
fn copy_from() {
|
||||||
let mut src = arr2(&[[1., 2., 3.]]);
|
let mut src = arr2(&[[1., 2., 3.]]);
|
||||||
@@ -771,6 +797,7 @@ mod tests {
|
|||||||
assert_eq!(src, dst);
|
assert_eq!(src, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn min_max_sum() {
|
fn min_max_sum() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
|
||||||
@@ -779,6 +806,7 @@ mod tests {
|
|||||||
assert_eq!(6., a.max());
|
assert_eq!(6., a.max());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn max_diff() {
|
fn max_diff() {
|
||||||
let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]);
|
let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]);
|
||||||
@@ -787,6 +815,7 @@ mod tests {
|
|||||||
assert_eq!(a2.max_diff(&a2), 0.);
|
assert_eq!(a2.max_diff(&a2), 0.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn softmax_mut() {
|
fn softmax_mut() {
|
||||||
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
|
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
|
||||||
@@ -796,6 +825,7 @@ mod tests {
|
|||||||
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn pow_mut() {
|
fn pow_mut() {
|
||||||
let mut a = arr2(&[[1., 2., 3.]]);
|
let mut a = arr2(&[[1., 2., 3.]]);
|
||||||
@@ -803,6 +833,7 @@ mod tests {
|
|||||||
assert_eq!(a, arr2(&[[1., 8., 27.]]));
|
assert_eq!(a, arr2(&[[1., 8., 27.]]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn argmax() {
|
fn argmax() {
|
||||||
let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]);
|
let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]);
|
||||||
@@ -810,6 +841,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![2, 0, 1]);
|
assert_eq!(res, vec![2, 0, 1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn unique() {
|
fn unique() {
|
||||||
let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]);
|
let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]);
|
||||||
@@ -818,6 +850,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row_as_vector() {
|
fn get_row_as_vector() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
@@ -825,12 +858,14 @@ mod tests {
|
|||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_row() {
|
fn get_row() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1));
|
assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn get_col_as_vector() {
|
fn get_col_as_vector() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
@@ -838,6 +873,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![2., 5., 8.]);
|
assert_eq!(res, vec![2., 5., 8.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn copy_row_col_as_vec() {
|
fn copy_row_col_as_vec() {
|
||||||
let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
@@ -849,6 +885,7 @@ mod tests {
|
|||||||
assert_eq!(v, vec!(2., 5., 8.));
|
assert_eq!(v, vec!(2., 5., 8.));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn col_mean() {
|
fn col_mean() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
@@ -856,6 +893,7 @@ mod tests {
|
|||||||
assert_eq!(res, vec![4., 5., 6.]);
|
assert_eq!(res, vec![4., 5., 6.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn eye() {
|
fn eye() {
|
||||||
let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
|
let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
|
||||||
@@ -863,6 +901,7 @@ mod tests {
|
|||||||
assert_eq!(res, a);
|
assert_eq!(res, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn rand() {
|
fn rand() {
|
||||||
let m: Array2<f64> = BaseMatrix::rand(3, 3);
|
let m: Array2<f64> = BaseMatrix::rand(3, 3);
|
||||||
@@ -873,6 +912,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn approximate_eq() {
|
fn approximate_eq() {
|
||||||
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
|
||||||
@@ -881,6 +921,7 @@ mod tests {
|
|||||||
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn abs_mut() {
|
fn abs_mut() {
|
||||||
let mut a = arr2(&[[1., -2.], [3., -4.]]);
|
let mut a = arr2(&[[1., -2.], [3., -4.]]);
|
||||||
@@ -889,6 +930,7 @@ mod tests {
|
|||||||
assert_eq!(a, expected);
|
assert_eq!(a, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_iris() {
|
fn lr_fit_predict_iris() {
|
||||||
let x = arr2(&[
|
let x = arr2(&[
|
||||||
@@ -924,12 +966,13 @@ mod tests {
|
|||||||
let error: f64 = y
|
let error: f64 = y
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(y_hat.into_iter())
|
.zip(y_hat.into_iter())
|
||||||
.map(|(&a, &b)| (a - b).abs())
|
.map(|(a, b)| (a - b).abs())
|
||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
assert!(error <= 1.0);
|
assert!(error <= 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn my_fit_longley_ndarray() {
|
fn my_fit_longley_ndarray() {
|
||||||
let x = arr2(&[
|
let x = arr2(&[
|
||||||
@@ -964,6 +1007,8 @@ mod tests {
|
|||||||
min_samples_split: 2,
|
min_samples_split: 2,
|
||||||
n_trees: 1000,
|
n_trees: 1000,
|
||||||
m: Option::None,
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
+2
-1
@@ -195,7 +195,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose() {
|
fn decompose() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
@@ -214,6 +214,7 @@ mod tests {
|
|||||||
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn qr_solve_mut() {
|
fn qr_solve_mut() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
|
|||||||
+5
-5
@@ -61,7 +61,7 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
|
|||||||
sum += a * a;
|
sum += a * a;
|
||||||
}
|
}
|
||||||
mu /= div;
|
mu /= div;
|
||||||
*x_i = sum / div - mu * mu;
|
*x_i = sum / div - mu.powi(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
x
|
x
|
||||||
@@ -150,7 +150,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mean() {
|
fn mean() {
|
||||||
let m = DenseMatrix::from_2d_array(&[
|
let m = DenseMatrix::from_2d_array(&[
|
||||||
@@ -164,7 +164,7 @@ mod tests {
|
|||||||
assert_eq!(m.mean(0), expected_0);
|
assert_eq!(m.mean(0), expected_0);
|
||||||
assert_eq!(m.mean(1), expected_1);
|
assert_eq!(m.mean(1), expected_1);
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn std() {
|
fn std() {
|
||||||
let m = DenseMatrix::from_2d_array(&[
|
let m = DenseMatrix::from_2d_array(&[
|
||||||
@@ -178,7 +178,7 @@ mod tests {
|
|||||||
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
|
assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
|
||||||
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
|
assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn var() {
|
fn var() {
|
||||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
||||||
@@ -188,7 +188,7 @@ mod tests {
|
|||||||
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
|
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
|
||||||
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
|
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn scale() {
|
fn scale() {
|
||||||
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||||
|
|||||||
+10
-9
@@ -47,7 +47,7 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
|
|||||||
pub V: M,
|
pub V: M,
|
||||||
/// Singular values of the original matrix
|
/// Singular values of the original matrix
|
||||||
pub s: Vec<T>,
|
pub s: Vec<T>,
|
||||||
full: bool,
|
_full: bool,
|
||||||
m: usize,
|
m: usize,
|
||||||
n: usize,
|
n: usize,
|
||||||
tol: T,
|
tol: T,
|
||||||
@@ -116,7 +116,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut f = U.get(i, i);
|
let mut f = U.get(i, i);
|
||||||
g = -s.sqrt().copysign(f);
|
g = -RealNumber::copysign(s.sqrt(), f);
|
||||||
let h = f * g - s;
|
let h = f * g - s;
|
||||||
U.set(i, i, f - g);
|
U.set(i, i, f - g);
|
||||||
for j in l - 1..n {
|
for j in l - 1..n {
|
||||||
@@ -152,7 +152,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let f = U.get(i, l - 1);
|
let f = U.get(i, l - 1);
|
||||||
g = -s.sqrt().copysign(f);
|
g = -RealNumber::copysign(s.sqrt(), f);
|
||||||
let h = f * g - s;
|
let h = f * g - s;
|
||||||
U.set(i, l - 1, f - g);
|
U.set(i, l - 1, f - g);
|
||||||
|
|
||||||
@@ -299,7 +299,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
|
|||||||
let mut h = rv1[k];
|
let mut h = rv1[k];
|
||||||
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
|
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
|
||||||
g = f.hypot(T::one());
|
g = f.hypot(T::one());
|
||||||
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x;
|
f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
|
||||||
let mut c = T::one();
|
let mut c = T::one();
|
||||||
let mut s = T::one();
|
let mut s = T::one();
|
||||||
|
|
||||||
@@ -428,13 +428,13 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
|
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
|
||||||
let m = U.shape().0;
|
let m = U.shape().0;
|
||||||
let n = V.shape().0;
|
let n = V.shape().0;
|
||||||
let full = s.len() == m.min(n);
|
let _full = s.len() == m.min(n);
|
||||||
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
|
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
|
||||||
SVD {
|
SVD {
|
||||||
U,
|
U,
|
||||||
V,
|
V,
|
||||||
s,
|
s,
|
||||||
full,
|
_full,
|
||||||
m,
|
m,
|
||||||
n,
|
n,
|
||||||
tol,
|
tol,
|
||||||
@@ -482,7 +482,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_symmetric() {
|
fn decompose_symmetric() {
|
||||||
let A = DenseMatrix::from_2d_array(&[
|
let A = DenseMatrix::from_2d_array(&[
|
||||||
@@ -513,7 +513,7 @@ mod tests {
|
|||||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_asymmetric() {
|
fn decompose_asymmetric() {
|
||||||
let A = DenseMatrix::from_2d_array(&[
|
let A = DenseMatrix::from_2d_array(&[
|
||||||
@@ -714,7 +714,7 @@ mod tests {
|
|||||||
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
assert!((s[i] - svd.s[i]).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn solve() {
|
fn solve() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||||
@@ -725,6 +725,7 @@ mod tests {
|
|||||||
assert!(w.approximate_eq(&expected_w, 1e-2));
|
assert!(w.approximate_eq(&expected_w, 1e-2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn decompose_restore() {
|
fn decompose_restore() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ mod tests {
|
|||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
|
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn bg_solver() {
|
fn bg_solver() {
|
||||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||||
|
|||||||
@@ -56,6 +56,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -67,7 +68,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||||
|
|
||||||
/// Elastic net parameters
|
/// Elastic net parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct ElasticNetParameters<T: RealNumber> {
|
pub struct ElasticNetParameters<T: RealNumber> {
|
||||||
/// Regularization parameter.
|
/// Regularization parameter.
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -84,7 +86,8 @@ pub struct ElasticNetParameters<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Elastic net
|
/// Elastic net
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
@@ -288,6 +291,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn elasticnet_longley() {
|
fn elasticnet_longley() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -331,6 +335,7 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
|
assert!(mean_absolute_error(&y_hat, &y) < 30.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn elasticnet_fit_predict1() {
|
fn elasticnet_fit_predict1() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -397,7 +402,9 @@ mod tests {
|
|||||||
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
|
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
+8
-2
@@ -24,6 +24,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -34,7 +35,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Lasso regression parameters
|
/// Lasso regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LassoParameters<T: RealNumber> {
|
pub struct LassoParameters<T: RealNumber> {
|
||||||
/// Controls the strength of the penalty to the loss function.
|
/// Controls the strength of the penalty to the loss function.
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -47,7 +49,8 @@ pub struct LassoParameters<T: RealNumber> {
|
|||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Lasso regressor
|
/// Lasso regressor
|
||||||
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
pub struct Lasso<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
@@ -223,6 +226,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lasso_fit_predict() {
|
fn lasso_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -271,7 +275,9 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
|||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
self.prb[i] = T::two() + self.d1[i];
|
self.prb[i] = T::two() + self.d1[i];
|
||||||
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i] * self.d2[i];
|
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
let normg = grad.norm2();
|
let normg = grad.norm2();
|
||||||
@@ -211,9 +211,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M>
|
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for InteriorPointOptimizer<T, M> {
|
||||||
for InteriorPointOptimizer<T, M>
|
|
||||||
{
|
|
||||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||||
let (_, p) = a.shape();
|
let (_, p) = a.shape();
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -69,7 +70,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||||
pub enum LinearRegressionSolverName {
|
pub enum LinearRegressionSolverName {
|
||||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||||
@@ -79,18 +81,20 @@ pub enum LinearRegressionSolverName {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Regression parameters
|
/// Linear Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LinearRegressionParameters {
|
pub struct LinearRegressionParameters {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub solver: LinearRegressionSolverName,
|
pub solver: LinearRegressionSolverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Regression
|
/// Linear Regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
solver: LinearRegressionSolverName,
|
_solver: LinearRegressionSolverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LinearRegressionParameters {
|
impl LinearRegressionParameters {
|
||||||
@@ -151,7 +155,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
&"Number of rows of X doesn\'t match number of rows of Y".to_string(),
|
"Number of rows of X doesn\'t match number of rows of Y",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +171,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
|
|||||||
Ok(LinearRegression {
|
Ok(LinearRegression {
|
||||||
intercept: w.get(num_attributes, 0),
|
intercept: w.get(num_attributes, 0),
|
||||||
coefficients: wights,
|
coefficients: wights,
|
||||||
solver: parameters.solver,
|
_solver: parameters.solver,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,6 +200,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ols_fit_predict() {
|
fn ols_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -246,7 +251,9 @@ mod tests {
|
|||||||
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
.all(|(&a, &b)| (a - b).abs() <= 5.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -54,8 +54,8 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::cmp::Ordering;
|
use std::cmp::Ordering;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -67,12 +67,27 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
|||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||||
|
pub enum LogisticRegressionSolverName {
|
||||||
|
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||||
|
LBFGS,
|
||||||
|
}
|
||||||
|
|
||||||
/// Logistic Regression parameters
|
/// Logistic Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
pub struct LogisticRegressionParameters {}
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LogisticRegressionParameters<T: RealNumber> {
|
||||||
|
/// Solver to use for estimation of regression coefficients.
|
||||||
|
pub solver: LogisticRegressionSolverName,
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub alpha: T,
|
||||||
|
}
|
||||||
|
|
||||||
/// Logistic Regression
|
/// Logistic Regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: M,
|
intercept: M,
|
||||||
@@ -99,12 +114,28 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
|
|||||||
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
phantom: PhantomData<&'a T>,
|
alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LogisticRegressionParameters {
|
impl<T: RealNumber> LogisticRegressionParameters<T> {
|
||||||
|
/// Solver to use for estimation of regression coefficients.
|
||||||
|
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
|
||||||
|
self.solver = solver;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
/// Regularization parameter.
|
||||||
|
pub fn with_alpha(mut self, alpha: T) -> Self {
|
||||||
|
self.alpha = alpha;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
LogisticRegressionParameters {}
|
LogisticRegressionParameters {
|
||||||
|
solver: LogisticRegressionSolverName::LBFGS,
|
||||||
|
alpha: T::zero(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
{
|
{
|
||||||
fn f(&self, w_bias: &M) -> T {
|
fn f(&self, w_bias: &M) -> T {
|
||||||
let mut f = T::zero();
|
let mut f = T::zero();
|
||||||
let (n, _) = self.x.shape();
|
let (n, p) = self.x.shape();
|
||||||
|
|
||||||
for i in 0..n {
|
for i in 0..n {
|
||||||
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
|
||||||
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
|
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
let mut w_squared = T::zero();
|
||||||
|
for i in 0..p {
|
||||||
|
let w = w_bias.get(0, i);
|
||||||
|
w_squared += w * w;
|
||||||
|
}
|
||||||
|
f += T::half() * self.alpha * w_squared;
|
||||||
|
}
|
||||||
|
|
||||||
f
|
f
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,6 +196,13 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
}
|
}
|
||||||
g.set(0, p, g.get(0, p) - dyi);
|
g.set(0, p, g.get(0, p) - dyi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
for i in 0..p {
|
||||||
|
let w = w_bias.get(0, i);
|
||||||
|
g.set(0, i, g.get(0, i) + self.alpha * w);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +210,7 @@ struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
|
|||||||
x: &'a M,
|
x: &'a M,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
k: usize,
|
k: usize,
|
||||||
phantom: PhantomData<&'a T>,
|
alpha: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
||||||
@@ -185,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
f -= prob.get(0, self.y[i]).ln();
|
f -= prob.get(0, self.y[i]).ln();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
let mut w_squared = T::zero();
|
||||||
|
for i in 0..self.k {
|
||||||
|
for j in 0..p {
|
||||||
|
let wi = w_bias.get(0, i * (p + 1) + j);
|
||||||
|
w_squared += wi * wi;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f += T::half() * self.alpha * w_squared;
|
||||||
|
}
|
||||||
|
|
||||||
f
|
f
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
|
|||||||
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.alpha > T::zero() {
|
||||||
|
for i in 0..self.k {
|
||||||
|
for j in 0..p {
|
||||||
|
let pos = i * (p + 1);
|
||||||
|
let wi = w.get(0, pos + j);
|
||||||
|
g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters>
|
impl<T: RealNumber, M: Matrix<T>>
|
||||||
|
SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters<T>>
|
||||||
for LogisticRegression<T, M>
|
for LogisticRegression<T, M>
|
||||||
{
|
{
|
||||||
fn fit(
|
fn fit(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
parameters: LogisticRegressionParameters,
|
parameters: LogisticRegressionParameters<T>,
|
||||||
) -> Result<Self, Failed> {
|
) -> Result<Self, Failed> {
|
||||||
LogisticRegression::fit(x, y, parameters)
|
LogisticRegression::fit(x, y, parameters)
|
||||||
}
|
}
|
||||||
@@ -244,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
pub fn fit(
|
pub fn fit(
|
||||||
x: &M,
|
x: &M,
|
||||||
y: &M::RowVector,
|
y: &M::RowVector,
|
||||||
_parameters: LogisticRegressionParameters,
|
parameters: LogisticRegressionParameters<T>,
|
||||||
) -> Result<LogisticRegression<T, M>, Failed> {
|
) -> Result<LogisticRegression<T, M>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
@@ -252,7 +321,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
|
|
||||||
if x_nrows != y_nrows {
|
if x_nrows != y_nrows {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
&"Number of rows of X doesn\'t match number of rows of Y".to_string(),
|
"Number of rows of X doesn\'t match number of rows of Y",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x,
|
x,
|
||||||
y: yi,
|
y: yi,
|
||||||
phantom: PhantomData,
|
alpha: parameters.alpha,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -300,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
|
|||||||
x,
|
x,
|
||||||
y: yi,
|
y: yi,
|
||||||
k,
|
k,
|
||||||
phantom: PhantomData,
|
alpha: parameters.alpha,
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = LogisticRegression::minimize(x0, objective);
|
let result = LogisticRegression::minimize(x0, objective);
|
||||||
@@ -383,6 +452,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::accuracy;
|
use crate::metrics::accuracy;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn multiclass_objective_f() {
|
fn multiclass_objective_f() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -407,9 +477,9 @@ mod tests {
|
|||||||
|
|
||||||
let objective = MultiClassObjectiveFunction {
|
let objective = MultiClassObjectiveFunction {
|
||||||
x: &x,
|
x: &x,
|
||||||
y,
|
y: y.clone(),
|
||||||
k: 3,
|
k: 3,
|
||||||
phantom: PhantomData,
|
alpha: 0.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
|
||||||
@@ -430,8 +500,27 @@ mod tests {
|
|||||||
]));
|
]));
|
||||||
|
|
||||||
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
||||||
|
|
||||||
|
let objective_reg = MultiClassObjectiveFunction {
|
||||||
|
x: &x,
|
||||||
|
y: y.clone(),
|
||||||
|
k: 3,
|
||||||
|
alpha: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[
|
||||||
|
1., 2., 3., 4., 5., 6., 7., 8., 9.,
|
||||||
|
]));
|
||||||
|
assert!((f - 487.5052).abs() < 1e-4);
|
||||||
|
|
||||||
|
objective_reg.df(
|
||||||
|
&mut g,
|
||||||
|
&DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
|
||||||
|
);
|
||||||
|
assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn binary_objective_f() {
|
fn binary_objective_f() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -456,8 +545,8 @@ mod tests {
|
|||||||
|
|
||||||
let objective = BinaryObjectiveFunction {
|
let objective = BinaryObjectiveFunction {
|
||||||
x: &x,
|
x: &x,
|
||||||
y,
|
y: y.clone(),
|
||||||
phantom: PhantomData,
|
alpha: 0.0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
|
||||||
@@ -472,8 +561,23 @@ mod tests {
|
|||||||
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
|
||||||
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
||||||
|
|
||||||
|
let objective_reg = BinaryObjectiveFunction {
|
||||||
|
x: &x,
|
||||||
|
y: y.clone(),
|
||||||
|
alpha: 1.0,
|
||||||
|
};
|
||||||
|
|
||||||
|
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
assert!((f - 62.2699).abs() < 1e-4);
|
||||||
|
|
||||||
|
objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
|
||||||
|
assert!((g.get(0, 0) - 27.0511).abs() < 1e-4);
|
||||||
|
assert!((g.get(0, 1) - 12.239).abs() < 1e-4);
|
||||||
|
assert!((g.get(0, 2) - 3.8693).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict() {
|
fn lr_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -511,6 +615,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_multiclass() {
|
fn lr_fit_predict_multiclass() {
|
||||||
let blobs = make_blobs(15, 4, 3);
|
let blobs = make_blobs(15, 4, 3);
|
||||||
@@ -523,8 +628,18 @@ mod tests {
|
|||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||||
|
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_binary() {
|
fn lr_fit_predict_binary() {
|
||||||
let blobs = make_blobs(20, 4, 2);
|
let blobs = make_blobs(20, 4, 2);
|
||||||
@@ -537,9 +652,20 @@ mod tests {
|
|||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
assert!(accuracy(&y_hat, &y) > 0.9);
|
assert!(accuracy(&y_hat, &y) > 0.9);
|
||||||
|
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(10.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[1., -5.],
|
&[1., -5.],
|
||||||
@@ -568,6 +694,7 @@ mod tests {
|
|||||||
assert_eq!(lr, deserialized_lr);
|
assert_eq!(lr, deserialized_lr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lr_fit_predict_iris() {
|
fn lr_fit_predict_iris() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -597,6 +724,12 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
let lr_reg = LogisticRegression::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LogisticRegressionParameters::default().with_alpha(1.0),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let y_hat = lr.predict(&x).unwrap();
|
let y_hat = lr.predict(&x).unwrap();
|
||||||
|
|
||||||
@@ -607,5 +740,6 @@ mod tests {
|
|||||||
.sum();
|
.sum();
|
||||||
|
|
||||||
assert!(error <= 1.0);
|
assert!(error <= 1.0);
|
||||||
|
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,6 +58,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -66,7 +67,8 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||||
pub enum RidgeRegressionSolverName {
|
pub enum RidgeRegressionSolverName {
|
||||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||||
@@ -76,7 +78,8 @@ pub enum RidgeRegressionSolverName {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Ridge Regression parameters
|
/// Ridge Regression parameters
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RidgeRegressionParameters<T: RealNumber> {
|
pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||||
/// Solver to use for estimation of regression coefficients.
|
/// Solver to use for estimation of regression coefficients.
|
||||||
pub solver: RidgeRegressionSolverName,
|
pub solver: RidgeRegressionSolverName,
|
||||||
@@ -88,11 +91,12 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Ridge regression
|
/// Ridge regression
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
|
||||||
coefficients: M,
|
coefficients: M,
|
||||||
intercept: T,
|
intercept: T,
|
||||||
solver: RidgeRegressionSolverName,
|
_solver: RidgeRegressionSolverName,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> RidgeRegressionParameters<T> {
|
impl<T: RealNumber> RidgeRegressionParameters<T> {
|
||||||
@@ -222,7 +226,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
|
|||||||
Ok(RidgeRegression {
|
Ok(RidgeRegression {
|
||||||
intercept: b,
|
intercept: b,
|
||||||
coefficients: w,
|
coefficients: w,
|
||||||
solver: parameters.solver,
|
_solver: parameters.solver,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,6 +274,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ridge_fit_predict() {
|
fn ridge_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -325,7 +330,9 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
|
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -25,7 +26,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Euclidian {}
|
pub struct Euclidian {}
|
||||||
|
|
||||||
impl Euclidian {
|
impl Euclidian {
|
||||||
@@ -55,6 +57,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn squared_distance() {
|
fn squared_distance() {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -26,7 +27,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Hamming {}
|
pub struct Hamming {}
|
||||||
|
|
||||||
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
||||||
@@ -50,6 +52,7 @@ impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn hamming_distance() {
|
fn hamming_distance() {
|
||||||
let a = vec![1, 0, 0, 1, 0, 0, 1];
|
let a = vec![1, 0, 0, 1, 0, 0, 1];
|
||||||
|
|||||||
@@ -44,6 +44,7 @@
|
|||||||
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -52,7 +53,8 @@ use super::Distance;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
|
|
||||||
/// Mahalanobis distance.
|
/// Mahalanobis distance.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
|
||||||
/// covariance matrix of the dataset
|
/// covariance matrix of the dataset
|
||||||
pub sigma: M,
|
pub sigma: M,
|
||||||
@@ -131,6 +133,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mahalanobis_distance() {
|
fn mahalanobis_distance() {
|
||||||
let data = DenseMatrix::from_2d_array(&[
|
let data = DenseMatrix::from_2d_array(&[
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -24,7 +25,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Manhattan distance
|
/// Manhattan distance
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Manhattan {}
|
pub struct Manhattan {}
|
||||||
|
|
||||||
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
||||||
@@ -46,6 +48,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn manhattan_distance() {
|
fn manhattan_distance() {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
@@ -28,7 +29,8 @@ use crate::math::num::RealNumber;
|
|||||||
use super::Distance;
|
use super::Distance;
|
||||||
|
|
||||||
/// Defines the Minkowski distance of order `p`
|
/// Defines the Minkowski distance of order `p`
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Minkowski {
|
pub struct Minkowski {
|
||||||
/// order, integer
|
/// order, integer
|
||||||
pub p: u16,
|
pub p: u16,
|
||||||
@@ -59,6 +61,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn minkowski_distance() {
|
fn minkowski_distance() {
|
||||||
let a = vec![1., 2., 3.];
|
let a = vec![1., 2., 3.];
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ impl RealNumber for f32 {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn sigmoid() {
|
fn sigmoid() {
|
||||||
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
|
assert_eq!(1.0.sigmoid(), 0.7310585786300049);
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn unique_with_indices() {
|
fn unique_with_indices() {
|
||||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||||
|
|||||||
@@ -16,13 +16,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Accuracy metric.
|
/// Accuracy metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Accuracy {}
|
pub struct Accuracy {}
|
||||||
|
|
||||||
impl Accuracy {
|
impl Accuracy {
|
||||||
@@ -55,6 +57,7 @@ impl Accuracy {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn accuracy() {
|
fn accuracy() {
|
||||||
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
|
||||||
|
|||||||
+4
-1
@@ -20,6 +20,7 @@
|
|||||||
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
|
||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -27,7 +28,8 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct AUC {}
|
pub struct AUC {}
|
||||||
|
|
||||||
impl AUC {
|
impl AUC {
|
||||||
@@ -91,6 +93,7 @@ impl AUC {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn auc() {
|
fn auc() {
|
||||||
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
let y_true: Vec<f64> = vec![0., 0., 1., 1.];
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::metrics::cluster_helpers::*;
|
use crate::metrics::cluster_helpers::*;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Homogeneity, completeness and V-Measure scores.
|
/// Homogeneity, completeness and V-Measure scores.
|
||||||
pub struct HCVScore {}
|
pub struct HCVScore {}
|
||||||
|
|
||||||
@@ -41,6 +43,7 @@ impl HCVScore {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn homogeneity_score() {
|
fn homogeneity_score() {
|
||||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn contingency_matrix_test() {
|
fn contingency_matrix_test() {
|
||||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||||
@@ -112,6 +113,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn entropy_test() {
|
fn entropy_test() {
|
||||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||||
@@ -119,6 +121,7 @@ mod tests {
|
|||||||
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
|
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mutual_info_score_test() {
|
fn mutual_info_score_test() {
|
||||||
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
|
||||||
|
|||||||
+4
-1
@@ -18,6 +18,7 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
@@ -26,7 +27,8 @@ use crate::metrics::precision::Precision;
|
|||||||
use crate::metrics::recall::Recall;
|
use crate::metrics::recall::Recall;
|
||||||
|
|
||||||
/// F-measure
|
/// F-measure
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct F1<T: RealNumber> {
|
pub struct F1<T: RealNumber> {
|
||||||
/// a positive real factor
|
/// a positive real factor
|
||||||
pub beta: T,
|
pub beta: T,
|
||||||
@@ -57,6 +59,7 @@ impl<T: RealNumber> F1<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn f1() {
|
fn f1() {
|
||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||||
|
|||||||
@@ -18,12 +18,14 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Mean Absolute Error
|
/// Mean Absolute Error
|
||||||
pub struct MeanAbsoluteError {}
|
pub struct MeanAbsoluteError {}
|
||||||
|
|
||||||
@@ -54,6 +56,7 @@ impl MeanAbsoluteError {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mean_absolute_error() {
|
fn mean_absolute_error() {
|
||||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||||
|
|||||||
@@ -18,12 +18,14 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
/// Mean Squared Error
|
/// Mean Squared Error
|
||||||
pub struct MeanSquareError {}
|
pub struct MeanSquareError {}
|
||||||
|
|
||||||
@@ -54,6 +56,7 @@ impl MeanSquareError {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn mean_squared_error() {
|
fn mean_squared_error() {
|
||||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||||
|
|||||||
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Precision metric.
|
/// Precision metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Precision {}
|
pub struct Precision {}
|
||||||
|
|
||||||
impl Precision {
|
impl Precision {
|
||||||
@@ -75,6 +77,7 @@ impl Precision {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn precision() {
|
fn precision() {
|
||||||
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
||||||
|
|||||||
+4
-1
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Coefficient of Determination (R2)
|
/// Coefficient of Determination (R2)
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct R2 {}
|
pub struct R2 {}
|
||||||
|
|
||||||
impl R2 {
|
impl R2 {
|
||||||
@@ -68,6 +70,7 @@ impl R2 {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn r2() {
|
fn r2() {
|
||||||
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
|
||||||
|
|||||||
@@ -18,13 +18,15 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
/// Recall metric.
|
/// Recall metric.
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct Recall {}
|
pub struct Recall {}
|
||||||
|
|
||||||
impl Recall {
|
impl Recall {
|
||||||
@@ -75,6 +77,7 @@ impl Recall {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn recall() {
|
fn recall() {
|
||||||
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
let y_true: Vec<f64> = vec![0., 1., 1., 0.];
|
||||||
|
|||||||
@@ -144,6 +144,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_kfold_return_test_indices_simple() {
|
fn run_kfold_return_test_indices_simple() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -158,6 +159,7 @@ mod tests {
|
|||||||
assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
|
assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_kfold_return_test_indices_odd() {
|
fn run_kfold_return_test_indices_odd() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -172,6 +174,7 @@ mod tests {
|
|||||||
assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
|
assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_kfold_return_test_mask_simple() {
|
fn run_kfold_return_test_mask_simple() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -197,6 +200,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_kfold_return_split_simple() {
|
fn run_kfold_return_split_simple() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -212,6 +216,7 @@ mod tests {
|
|||||||
assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
|
assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_kfold_return_split_simple_shuffle() {
|
fn run_kfold_return_split_simple_shuffle() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -227,6 +232,7 @@ mod tests {
|
|||||||
assert_eq!(train_test_splits[1].1.len(), 11_usize);
|
assert_eq!(train_test_splits[1].1.len(), 11_usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn numpy_parity_test() {
|
fn numpy_parity_test() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
@@ -247,6 +253,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn numpy_parity_test_shuffle() {
|
fn numpy_parity_test_shuffle() {
|
||||||
let k = KFold {
|
let k = KFold {
|
||||||
|
|||||||
@@ -285,6 +285,7 @@ mod tests {
|
|||||||
use crate::model_selection::kfold::KFold;
|
use crate::model_selection::kfold::KFold;
|
||||||
use crate::neighbors::knn_regressor::KNNRegressor;
|
use crate::neighbors::knn_regressor::KNNRegressor;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_train_test_split() {
|
fn run_train_test_split() {
|
||||||
let n = 123;
|
let n = 123;
|
||||||
@@ -308,6 +309,7 @@ mod tests {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct NoParameters {}
|
struct NoParameters {}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_cross_validate_biased() {
|
fn test_cross_validate_biased() {
|
||||||
struct BiasedEstimator {}
|
struct BiasedEstimator {}
|
||||||
@@ -367,6 +369,7 @@ mod tests {
|
|||||||
assert_eq!(0.4, results.mean_train_score());
|
assert_eq!(0.4, results.mean_train_score());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_cross_validate_knn() {
|
fn test_cross_validate_knn() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -411,6 +414,7 @@ mod tests {
|
|||||||
assert!(results.mean_train_score() < results.mean_test_score());
|
assert!(results.mean_train_score() < results.mean_test_score());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_cross_val_predict_knn() {
|
fn test_cross_val_predict_knn() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
|||||||
+139
-20
@@ -42,15 +42,49 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for Bearnoulli features
|
/// Naive Bayes classifier for Bearnoulli features
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct BernoulliNBDistribution<T: RealNumber> {
|
struct BernoulliNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
|
/// number of training samples observed in each class
|
||||||
|
class_count: Vec<usize>,
|
||||||
|
/// probability of each class
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
feature_prob: Vec<Vec<T>>,
|
/// Number of samples encountered for each (class, feature)
|
||||||
|
feature_count: Vec<Vec<usize>>,
|
||||||
|
/// probability of features per class
|
||||||
|
feature_log_prob: Vec<Vec<T>>,
|
||||||
|
/// Number of features of each sample
|
||||||
|
n_features: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: RealNumber> PartialEq for BernoulliNBDistribution<T> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.class_labels == other.class_labels
|
||||||
|
&& self.class_count == other.class_count
|
||||||
|
&& self.class_priors == other.class_priors
|
||||||
|
&& self.feature_count == other.feature_count
|
||||||
|
&& self.n_features == other.n_features
|
||||||
|
{
|
||||||
|
for (a, b) in self
|
||||||
|
.feature_log_prob
|
||||||
|
.iter()
|
||||||
|
.zip(other.feature_log_prob.iter())
|
||||||
|
{
|
||||||
|
if !a.approximate_eq(b, T::epsilon()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
|
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
|
||||||
@@ -63,9 +97,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
|
|||||||
for feature in 0..j.len() {
|
for feature in 0..j.len() {
|
||||||
let value = j.get(feature);
|
let value = j.get(feature);
|
||||||
if value == T::one() {
|
if value == T::one() {
|
||||||
likelihood += self.feature_prob[class_index][feature].ln();
|
likelihood += self.feature_log_prob[class_index][feature];
|
||||||
} else {
|
} else {
|
||||||
likelihood += (T::one() - self.feature_prob[class_index][feature]).ln();
|
likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
likelihood
|
likelihood
|
||||||
@@ -77,7 +111,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
/// `BernoulliNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct BernoulliNBParameters<T: RealNumber> {
|
pub struct BernoulliNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -154,10 +189,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
|
|||||||
let y = y.to_vec();
|
let y = y.to_vec();
|
||||||
|
|
||||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||||
let mut class_count = vec![T::zero(); class_labels.len()];
|
let mut class_count = vec![0_usize; class_labels.len()];
|
||||||
|
|
||||||
for class_index in indices.iter() {
|
for class_index in indices.iter() {
|
||||||
class_count[*class_index] += T::one();
|
class_count[*class_index] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let class_priors = if let Some(class_priors) = priors {
|
let class_priors = if let Some(class_priors) = priors {
|
||||||
@@ -170,25 +205,35 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
|
|||||||
} else {
|
} else {
|
||||||
class_count
|
class_count
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&c| c / T::from(n_samples).unwrap())
|
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||||
.collect()
|
.collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()];
|
let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
|
||||||
|
|
||||||
for (row, class_index) in row_iter(x).zip(indices) {
|
for (row, class_index) in row_iter(x).zip(indices) {
|
||||||
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
||||||
feature_in_class_counter[class_index][idx] += *row_i;
|
feature_in_class_counter[class_index][idx] +=
|
||||||
|
row_i.to_usize().ok_or_else(|| {
|
||||||
|
Failed::fit(&format!(
|
||||||
|
"Elements of the matrix should be 1.0 or 0.0 |found|=[{}]",
|
||||||
|
row_i
|
||||||
|
))
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let feature_prob = feature_in_class_counter
|
let feature_log_prob = feature_in_class_counter
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(class_index, feature_count)| {
|
.map(|(class_index, feature_count)| {
|
||||||
feature_count
|
feature_count
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two()))
|
.map(|&count| {
|
||||||
|
((T::from(count).unwrap() + alpha)
|
||||||
|
/ (T::from(class_count[class_index]).unwrap() + alpha * T::two()))
|
||||||
|
.ln()
|
||||||
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -196,13 +241,18 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
class_labels,
|
class_labels,
|
||||||
class_priors,
|
class_priors,
|
||||||
feature_prob,
|
class_count,
|
||||||
|
feature_count: feature_in_class_counter,
|
||||||
|
feature_log_prob,
|
||||||
|
n_features,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// BernoulliNB implements the naive Bayes algorithm for data that follows the Bernoulli
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
/// distribution.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
|
||||||
binarize: Option<T>,
|
binarize: Option<T>,
|
||||||
@@ -262,6 +312,34 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
|
|||||||
self.inner.predict(x)
|
self.inner.predict(x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Class labels known to the classifier.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn classes(&self) -> &Vec<T> {
|
||||||
|
&self.inner.distribution.class_labels
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of training samples observed in each class.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn class_count(&self) -> &Vec<usize> {
|
||||||
|
&self.inner.distribution.class_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of features of each sample
|
||||||
|
pub fn n_features(&self) -> usize {
|
||||||
|
self.inner.distribution.n_features
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of samples encountered for each (class, feature)
|
||||||
|
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||||
|
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
|
||||||
|
&self.inner.distribution.feature_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Empirical log probability of features given a class
|
||||||
|
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
|
||||||
|
&self.inner.distribution.feature_log_prob
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -269,6 +347,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_bernoulli_naive_bayes() {
|
fn run_bernoulli_naive_bayes() {
|
||||||
// Tests that BernoulliNB when alpha=1.0 gives the same values as
|
// Tests that BernoulliNB when alpha=1.0 gives the same values as
|
||||||
@@ -292,10 +371,24 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
bnb.inner.distribution.feature_prob,
|
bnb.feature_log_prob(),
|
||||||
&[
|
&[
|
||||||
&[0.4, 0.8, 0.2, 0.4, 0.4, 0.2],
|
&[
|
||||||
&[1. / 3.0, 2. / 3.0, 2. / 3.0, 1. / 3.0, 1. / 3.0, 2. / 3.0]
|
-0.916290731874155,
|
||||||
|
-0.2231435513142097,
|
||||||
|
-1.6094379124341003,
|
||||||
|
-0.916290731874155,
|
||||||
|
-0.916290731874155,
|
||||||
|
-1.6094379124341003
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
-1.0986122886681098,
|
||||||
|
-0.40546510810816444,
|
||||||
|
-0.40546510810816444,
|
||||||
|
-1.0986122886681098,
|
||||||
|
-1.0986122886681098,
|
||||||
|
-0.40546510810816444
|
||||||
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -307,6 +400,7 @@ mod tests {
|
|||||||
assert_eq!(y_hat, &[1.]);
|
assert_eq!(y_hat, &[1.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn bernoulli_nb_scikit_parity() {
|
fn bernoulli_nb_scikit_parity() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
@@ -331,13 +425,36 @@ mod tests {
|
|||||||
|
|
||||||
let y_hat = bnb.predict(&x).unwrap();
|
let y_hat = bnb.predict(&x).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(bnb.classes(), &[0., 1., 2.]);
|
||||||
|
assert_eq!(bnb.class_count(), &[7, 3, 5]);
|
||||||
|
assert_eq!(bnb.n_features(), 10);
|
||||||
|
assert_eq!(
|
||||||
|
bnb.feature_count(),
|
||||||
|
&[
|
||||||
|
&[5, 6, 6, 7, 6, 4, 6, 7, 7, 7],
|
||||||
|
&[3, 3, 3, 1, 3, 2, 3, 2, 2, 3],
|
||||||
|
&[4, 4, 3, 4, 5, 2, 4, 5, 3, 4]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
assert!(bnb
|
assert!(bnb
|
||||||
.inner
|
.inner
|
||||||
.distribution
|
.distribution
|
||||||
.class_priors
|
.class_priors
|
||||||
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
||||||
assert!(bnb.inner.distribution.feature_prob[1].approximate_eq(
|
assert!(bnb.feature_log_prob()[1].approximate_eq(
|
||||||
&vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8),
|
&vec![
|
||||||
|
-0.22314355,
|
||||||
|
-0.22314355,
|
||||||
|
-0.22314355,
|
||||||
|
-0.91629073,
|
||||||
|
-0.22314355,
|
||||||
|
-0.51082562,
|
||||||
|
-0.22314355,
|
||||||
|
-0.51082562,
|
||||||
|
-0.51082562,
|
||||||
|
-0.22314355
|
||||||
|
],
|
||||||
1e-1
|
1e-1
|
||||||
));
|
));
|
||||||
assert!(y_hat.approximate_eq(
|
assert!(y_hat.approximate_eq(
|
||||||
@@ -346,7 +463,9 @@ mod tests {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[1., 1., 0., 0., 0., 0.],
|
&[1., 1., 0., 0., 0., 0.],
|
||||||
|
|||||||
+144
-25
@@ -36,19 +36,38 @@ use crate::linalg::BaseVector;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for categorical features
|
/// Naive Bayes classifier for categorical features
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct CategoricalNBDistribution<T: RealNumber> {
|
struct CategoricalNBDistribution<T: RealNumber> {
|
||||||
|
/// number of training samples observed in each class
|
||||||
|
class_count: Vec<usize>,
|
||||||
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
|
/// probability of each class
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
coefficients: Vec<Vec<Vec<T>>>,
|
coefficients: Vec<Vec<Vec<T>>>,
|
||||||
|
/// Number of features of each sample
|
||||||
|
n_features: usize,
|
||||||
|
/// Number of categories for each feature
|
||||||
|
n_categories: Vec<usize>,
|
||||||
|
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||||
|
/// for each feature. Each array provides the number of samples
|
||||||
|
/// encountered for each class and category of the specific feature.
|
||||||
|
category_count: Vec<Vec<Vec<usize>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
|
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.class_labels == other.class_labels && self.class_priors == other.class_priors {
|
if self.class_labels == other.class_labels
|
||||||
|
&& self.class_priors == other.class_priors
|
||||||
|
&& self.n_features == other.n_features
|
||||||
|
&& self.n_categories == other.n_categories
|
||||||
|
&& self.class_count == other.class_count
|
||||||
|
{
|
||||||
if self.coefficients.len() != other.coefficients.len() {
|
if self.coefficients.len() != other.coefficients.len() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -88,8 +107,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribu
|
|||||||
let mut likelihood = T::zero();
|
let mut likelihood = T::zero();
|
||||||
for feature in 0..j.len() {
|
for feature in 0..j.len() {
|
||||||
let value = j.get(feature).floor().to_usize().unwrap();
|
let value = j.get(feature).floor().to_usize().unwrap();
|
||||||
if self.coefficients[class_index][feature].len() > value {
|
if self.coefficients[feature][class_index].len() > value {
|
||||||
likelihood += self.coefficients[class_index][feature][value];
|
likelihood += self.coefficients[feature][class_index][value];
|
||||||
} else {
|
} else {
|
||||||
return T::zero();
|
return T::zero();
|
||||||
}
|
}
|
||||||
@@ -142,17 +161,17 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
|||||||
let y_max = y
|
let y_max = y
|
||||||
.iter()
|
.iter()
|
||||||
.max()
|
.max()
|
||||||
.ok_or_else(|| Failed::fit(&"Failed to get the labels of y.".to_string()))?;
|
.ok_or_else(|| Failed::fit("Failed to get the labels of y."))?;
|
||||||
|
|
||||||
let class_labels: Vec<T> = (0..*y_max + 1)
|
let class_labels: Vec<T> = (0..*y_max + 1)
|
||||||
.map(|label| T::from(label).unwrap())
|
.map(|label| T::from(label).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
let mut classes_count: Vec<T> = vec![T::zero(); class_labels.len()];
|
let mut class_count = vec![0_usize; class_labels.len()];
|
||||||
for elem in y.iter() {
|
for elem in y.iter() {
|
||||||
classes_count[*elem] += T::one();
|
class_count[*elem] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut feature_categories: Vec<Vec<T>> = Vec::with_capacity(n_features);
|
let mut n_categories: Vec<usize> = Vec::with_capacity(n_features);
|
||||||
for feature in 0..n_features {
|
for feature in 0..n_features {
|
||||||
let feature_max = x
|
let feature_max = x
|
||||||
.get_col_as_vec(feature)
|
.get_col_as_vec(feature)
|
||||||
@@ -165,18 +184,15 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
|||||||
feature
|
feature
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
let feature_types = (0..feature_max + 1)
|
n_categories.push(feature_max + 1);
|
||||||
.map(|feat| T::from(feat).unwrap())
|
|
||||||
.collect();
|
|
||||||
feature_categories.push(feature_types);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut coefficients: Vec<Vec<Vec<T>>> = Vec::with_capacity(class_labels.len());
|
let mut coefficients: Vec<Vec<Vec<T>>> = Vec::with_capacity(class_labels.len());
|
||||||
for (label, label_count) in class_labels.iter().zip(classes_count.iter()) {
|
let mut category_count: Vec<Vec<Vec<usize>>> = Vec::with_capacity(class_labels.len());
|
||||||
|
for (feature_index, &n_categories_i) in n_categories.iter().enumerate().take(n_features) {
|
||||||
let mut coef_i: Vec<Vec<T>> = Vec::with_capacity(n_features);
|
let mut coef_i: Vec<Vec<T>> = Vec::with_capacity(n_features);
|
||||||
for (feature_index, feature_options) in
|
let mut category_count_i: Vec<Vec<usize>> = Vec::with_capacity(n_features);
|
||||||
feature_categories.iter().enumerate().take(n_features)
|
for (label, &label_count) in class_labels.iter().zip(class_count.iter()) {
|
||||||
{
|
|
||||||
let col = x
|
let col = x
|
||||||
.get_col_as_vec(feature_index)
|
.get_col_as_vec(feature_index)
|
||||||
.iter()
|
.iter()
|
||||||
@@ -184,39 +200,48 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
|||||||
.filter(|(i, _j)| T::from(y[*i]).unwrap() == *label)
|
.filter(|(i, _j)| T::from(y[*i]).unwrap() == *label)
|
||||||
.map(|(_, j)| *j)
|
.map(|(_, j)| *j)
|
||||||
.collect::<Vec<T>>();
|
.collect::<Vec<T>>();
|
||||||
let mut feat_count: Vec<T> = vec![T::zero(); feature_options.len()];
|
let mut feat_count: Vec<usize> = vec![0_usize; n_categories_i];
|
||||||
for row in col.iter() {
|
for row in col.iter() {
|
||||||
let index = row.floor().to_usize().unwrap();
|
let index = row.floor().to_usize().unwrap();
|
||||||
feat_count[index] += T::one();
|
feat_count[index] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let coef_i_j = feat_count
|
let coef_i_j = feat_count
|
||||||
.iter()
|
.iter()
|
||||||
.map(|c| {
|
.map(|c| {
|
||||||
((*c + alpha)
|
((T::from(*c).unwrap() + alpha)
|
||||||
/ (*label_count + T::from(feature_options.len()).unwrap() * alpha))
|
/ (T::from(label_count).unwrap()
|
||||||
|
+ T::from(n_categories_i).unwrap() * alpha))
|
||||||
.ln()
|
.ln()
|
||||||
})
|
})
|
||||||
.collect::<Vec<T>>();
|
.collect::<Vec<T>>();
|
||||||
|
category_count_i.push(feat_count);
|
||||||
coef_i.push(coef_i_j);
|
coef_i.push(coef_i_j);
|
||||||
}
|
}
|
||||||
|
category_count.push(category_count_i);
|
||||||
coefficients.push(coef_i);
|
coefficients.push(coef_i);
|
||||||
}
|
}
|
||||||
|
|
||||||
let class_priors = classes_count
|
let class_priors = class_count
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|count| count / T::from(n_samples).unwrap())
|
.map(|&count| T::from(count).unwrap() / T::from(n_samples).unwrap())
|
||||||
.collect::<Vec<T>>();
|
.collect::<Vec<T>>();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
class_count,
|
||||||
class_labels,
|
class_labels,
|
||||||
class_priors,
|
class_priors,
|
||||||
coefficients,
|
coefficients,
|
||||||
|
n_features,
|
||||||
|
n_categories,
|
||||||
|
category_count,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
/// `CategoricalNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct CategoricalNBParameters<T: RealNumber> {
|
pub struct CategoricalNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -237,7 +262,8 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -283,6 +309,41 @@ impl<T: RealNumber, M: Matrix<T>> CategoricalNB<T, M> {
|
|||||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
self.inner.predict(x)
|
self.inner.predict(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Class labels known to the classifier.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn classes(&self) -> &Vec<T> {
|
||||||
|
&self.inner.distribution.class_labels
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of training samples observed in each class.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn class_count(&self) -> &Vec<usize> {
|
||||||
|
&self.inner.distribution.class_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of features of each sample
|
||||||
|
pub fn n_features(&self) -> usize {
|
||||||
|
self.inner.distribution.n_features
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of features of each sample
|
||||||
|
pub fn n_categories(&self) -> &Vec<usize> {
|
||||||
|
&self.inner.distribution.n_categories
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||||
|
/// for each feature. Each array provides the number of samples
|
||||||
|
/// encountered for each class and category of the specific feature.
|
||||||
|
pub fn category_count(&self) -> &Vec<Vec<Vec<usize>>> {
|
||||||
|
&self.inner.distribution.category_count
|
||||||
|
}
|
||||||
|
/// Holds arrays of shape (n_classes, n_categories of respective feature)
|
||||||
|
/// for each feature. Each array provides the empirical log probability
|
||||||
|
/// of categories given the respective feature and class, ``P(x_i|y)``.
|
||||||
|
pub fn feature_log_prob(&self) -> &Vec<Vec<Vec<T>>> {
|
||||||
|
&self.inner.distribution.coefficients
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -290,6 +351,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_categorical_naive_bayes() {
|
fn run_categorical_naive_bayes() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -311,11 +373,66 @@ mod tests {
|
|||||||
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
|
||||||
|
|
||||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
// checking parity with scikit
|
||||||
|
assert_eq!(cnb.classes(), &[0., 1.]);
|
||||||
|
assert_eq!(cnb.class_count(), &[5, 9]);
|
||||||
|
assert_eq!(cnb.n_features(), 4);
|
||||||
|
assert_eq!(cnb.n_categories(), &[3, 3, 2, 2]);
|
||||||
|
assert_eq!(
|
||||||
|
cnb.category_count(),
|
||||||
|
&vec![
|
||||||
|
vec![vec![3, 0, 2], vec![2, 4, 3]],
|
||||||
|
vec![vec![1, 2, 2], vec![3, 4, 2]],
|
||||||
|
vec![vec![1, 4], vec![6, 3]],
|
||||||
|
vec![vec![2, 3], vec![6, 3]]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
cnb.feature_log_prob(),
|
||||||
|
&vec![
|
||||||
|
vec![
|
||||||
|
vec![
|
||||||
|
-0.6931471805599453,
|
||||||
|
-2.0794415416798357,
|
||||||
|
-0.9808292530117262
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
-1.3862943611198906,
|
||||||
|
-0.8754687373538999,
|
||||||
|
-1.0986122886681098
|
||||||
|
]
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
vec![
|
||||||
|
-1.3862943611198906,
|
||||||
|
-0.9808292530117262,
|
||||||
|
-0.9808292530117262
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
-1.0986122886681098,
|
||||||
|
-0.8754687373538999,
|
||||||
|
-1.3862943611198906
|
||||||
|
]
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
vec![-1.252762968495368, -0.3364722366212129],
|
||||||
|
vec![-0.45198512374305727, -1.0116009116784799]
|
||||||
|
],
|
||||||
|
vec![
|
||||||
|
vec![-0.8472978603872037, -0.5596157879354228],
|
||||||
|
vec![-0.45198512374305727, -1.0116009116784799]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]);
|
let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]);
|
||||||
let y_hat = cnb.predict(&x_test).unwrap();
|
let y_hat = cnb.predict(&x_test).unwrap();
|
||||||
assert_eq!(y_hat, vec![0., 1.]);
|
assert_eq!(y_hat, vec![0., 1.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_categorical_naive_bayes2() {
|
fn run_categorical_naive_bayes2() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -344,7 +461,9 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[3., 4., 0., 1.],
|
&[3., 4., 0., 1.],
|
||||||
|
|||||||
+70
-27
@@ -30,17 +30,21 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for categorical features
|
/// Naive Bayes classifier using Gaussian distribution
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct GaussianNBDistribution<T: RealNumber> {
|
struct GaussianNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
|
/// number of training samples observed in each class
|
||||||
|
class_count: Vec<usize>,
|
||||||
/// probability of each class.
|
/// probability of each class.
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
/// variance of each feature per class
|
/// variance of each feature per class
|
||||||
sigma: Vec<Vec<T>>,
|
var: Vec<Vec<T>>,
|
||||||
/// mean of each feature per class
|
/// mean of each feature per class
|
||||||
theta: Vec<Vec<T>>,
|
theta: Vec<Vec<T>>,
|
||||||
}
|
}
|
||||||
@@ -55,18 +59,14 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
|
||||||
if class_index < self.class_labels.len() {
|
let mut likelihood = T::zero();
|
||||||
let mut likelihood = T::zero();
|
for feature in 0..j.len() {
|
||||||
for feature in 0..j.len() {
|
let value = j.get(feature);
|
||||||
let value = j.get(feature);
|
let mean = self.theta[class_index][feature];
|
||||||
let mean = self.theta[class_index][feature];
|
let variance = self.var[class_index][feature];
|
||||||
let variance = self.sigma[class_index][feature];
|
likelihood += self.calculate_log_probability(value, mean, variance);
|
||||||
likelihood += self.calculate_log_probability(value, mean, variance);
|
|
||||||
}
|
|
||||||
likelihood
|
|
||||||
} else {
|
|
||||||
T::zero()
|
|
||||||
}
|
}
|
||||||
|
likelihood
|
||||||
}
|
}
|
||||||
|
|
||||||
fn classes(&self) -> &Vec<T> {
|
fn classes(&self) -> &Vec<T> {
|
||||||
@@ -75,7 +75,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct GaussianNBParameters<T: RealNumber> {
|
pub struct GaussianNBParameters<T: RealNumber> {
|
||||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||||
pub priors: Option<Vec<T>>,
|
pub priors: Option<Vec<T>>,
|
||||||
@@ -118,12 +119,12 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
|||||||
let y = y.to_vec();
|
let y = y.to_vec();
|
||||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||||
|
|
||||||
let mut class_count = vec![T::zero(); class_labels.len()];
|
let mut class_count = vec![0_usize; class_labels.len()];
|
||||||
|
|
||||||
let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()];
|
let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()];
|
||||||
|
|
||||||
for (row, class_index) in row_iter(x).zip(indices.iter()) {
|
for (row, class_index) in row_iter(x).zip(indices.iter()) {
|
||||||
class_count[*class_index] += T::one();
|
class_count[*class_index] += 1;
|
||||||
subdataset[*class_index].push(row);
|
subdataset[*class_index].push(row);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,8 +137,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
|||||||
class_priors
|
class_priors
|
||||||
} else {
|
} else {
|
||||||
class_count
|
class_count
|
||||||
.into_iter()
|
.iter()
|
||||||
.map(|c| c / T::from(n_samples).unwrap())
|
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||||
.collect()
|
.collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -154,15 +155,16 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let (sigma, theta): (Vec<Vec<T>>, Vec<Vec<T>>) = subdataset
|
let (var, theta): (Vec<Vec<T>>, Vec<Vec<T>>) = subdataset
|
||||||
.iter()
|
.iter()
|
||||||
.map(|data| (data.var(0), data.mean(0)))
|
.map(|data| (data.var(0), data.mean(0)))
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
class_labels,
|
class_labels,
|
||||||
|
class_count,
|
||||||
class_priors,
|
class_priors,
|
||||||
sigma,
|
var,
|
||||||
theta,
|
theta,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -177,8 +179,10 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// GaussianNB implements the naive Bayes algorithm for data that follows the Gaussian
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
/// distribution.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -219,6 +223,36 @@ impl<T: RealNumber, M: Matrix<T>> GaussianNB<T, M> {
|
|||||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
self.inner.predict(x)
|
self.inner.predict(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Class labels known to the classifier.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn classes(&self) -> &Vec<T> {
|
||||||
|
&self.inner.distribution.class_labels
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of training samples observed in each class.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn class_count(&self) -> &Vec<usize> {
|
||||||
|
&self.inner.distribution.class_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Probability of each class
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn class_priors(&self) -> &Vec<T> {
|
||||||
|
&self.inner.distribution.class_priors
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mean of each feature per class
|
||||||
|
/// Returns a 2d vector of shape (n_classes, n_features).
|
||||||
|
pub fn theta(&self) -> &Vec<Vec<T>> {
|
||||||
|
&self.inner.distribution.theta
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Variance of each feature per class
|
||||||
|
/// Returns a 2d vector of shape (n_classes, n_features).
|
||||||
|
pub fn var(&self) -> &Vec<Vec<T>> {
|
||||||
|
&self.inner.distribution.var
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -226,6 +260,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_gaussian_naive_bayes() {
|
fn run_gaussian_naive_bayes() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -241,22 +276,28 @@ mod tests {
|
|||||||
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||||
let y_hat = gnb.predict(&x).unwrap();
|
let y_hat = gnb.predict(&x).unwrap();
|
||||||
assert_eq!(y_hat, y);
|
assert_eq!(y_hat, y);
|
||||||
|
|
||||||
|
assert_eq!(gnb.classes(), &[1., 2.]);
|
||||||
|
|
||||||
|
assert_eq!(gnb.class_count(), &[3, 3]);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
gnb.inner.distribution.sigma,
|
gnb.var(),
|
||||||
&[
|
&[
|
||||||
&[0.666666666666667, 0.22222222222222232],
|
&[0.666666666666667, 0.22222222222222232],
|
||||||
&[0.666666666666667, 0.22222222222222232]
|
&[0.666666666666667, 0.22222222222222232]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(gnb.inner.distribution.class_priors, &[0.5, 0.5]);
|
assert_eq!(gnb.class_priors(), &[0.5, 0.5]);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
gnb.inner.distribution.theta,
|
gnb.theta(),
|
||||||
&[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]]
|
&[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_gaussian_naive_bayes_with_priors() {
|
fn run_gaussian_naive_bayes_with_priors() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -273,10 +314,12 @@ mod tests {
|
|||||||
let parameters = GaussianNBParameters::default().with_priors(priors.clone());
|
let parameters = GaussianNBParameters::default().with_priors(priors.clone());
|
||||||
let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();
|
let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();
|
||||||
|
|
||||||
assert_eq!(gnb.inner.distribution.class_priors, priors);
|
assert_eq!(gnb.class_priors(), &priors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[-1., -1.],
|
&[-1., -1.],
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
@@ -55,7 +56,8 @@ pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Base struct for the Naive Bayes classifier.
|
/// Base struct for the Naive Bayes classifier.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
|
||||||
distribution: D,
|
distribution: D,
|
||||||
_phantom_t: PhantomData<T>,
|
_phantom_t: PhantomData<T>,
|
||||||
|
|||||||
+117
-21
@@ -42,15 +42,25 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::math::vector::RealNumberVector;
|
use crate::math::vector::RealNumberVector;
|
||||||
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Naive Bayes classifier for Multinomial features
|
/// Naive Bayes classifier for Multinomial features
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
struct MultinomialNBDistribution<T: RealNumber> {
|
struct MultinomialNBDistribution<T: RealNumber> {
|
||||||
/// class labels known to the classifier
|
/// class labels known to the classifier
|
||||||
class_labels: Vec<T>,
|
class_labels: Vec<T>,
|
||||||
|
/// number of training samples observed in each class
|
||||||
|
class_count: Vec<usize>,
|
||||||
|
/// probability of each class
|
||||||
class_priors: Vec<T>,
|
class_priors: Vec<T>,
|
||||||
feature_prob: Vec<Vec<T>>,
|
/// Empirical log probability of features given a class
|
||||||
|
feature_log_prob: Vec<Vec<T>>,
|
||||||
|
/// Number of samples encountered for each (class, feature)
|
||||||
|
feature_count: Vec<Vec<usize>>,
|
||||||
|
/// Number of features of each sample
|
||||||
|
n_features: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> {
|
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> {
|
||||||
@@ -62,7 +72,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
|
|||||||
let mut likelihood = T::zero();
|
let mut likelihood = T::zero();
|
||||||
for feature in 0..j.len() {
|
for feature in 0..j.len() {
|
||||||
let value = j.get(feature);
|
let value = j.get(feature);
|
||||||
likelihood += value * self.feature_prob[class_index][feature].ln();
|
likelihood += value * self.feature_log_prob[class_index][feature];
|
||||||
}
|
}
|
||||||
likelihood
|
likelihood
|
||||||
}
|
}
|
||||||
@@ -73,7 +83,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
/// `MultinomialNB` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct MultinomialNBParameters<T: RealNumber> {
|
pub struct MultinomialNBParameters<T: RealNumber> {
|
||||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||||
pub alpha: T,
|
pub alpha: T,
|
||||||
@@ -141,10 +152,10 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
|
|||||||
let y = y.to_vec();
|
let y = y.to_vec();
|
||||||
|
|
||||||
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
|
||||||
let mut class_count = vec![T::zero(); class_labels.len()];
|
let mut class_count = vec![0_usize; class_labels.len()];
|
||||||
|
|
||||||
for class_index in indices.iter() {
|
for class_index in indices.iter() {
|
||||||
class_count[*class_index] += T::one();
|
class_count[*class_index] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let class_priors = if let Some(class_priors) = priors {
|
let class_priors = if let Some(class_priors) = priors {
|
||||||
@@ -157,39 +168,53 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
|
|||||||
} else {
|
} else {
|
||||||
class_count
|
class_count
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&c| c / T::from(n_samples).unwrap())
|
.map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
|
||||||
.collect()
|
.collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()];
|
let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
|
||||||
|
|
||||||
for (row, class_index) in row_iter(x).zip(indices) {
|
for (row, class_index) in row_iter(x).zip(indices) {
|
||||||
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
for (idx, row_i) in row.iter().enumerate().take(n_features) {
|
||||||
feature_in_class_counter[class_index][idx] += *row_i;
|
feature_in_class_counter[class_index][idx] +=
|
||||||
|
row_i.to_usize().ok_or_else(|| {
|
||||||
|
Failed::fit(&format!(
|
||||||
|
"Elements of the matrix should be convertible to usize |found|=[{}]",
|
||||||
|
row_i
|
||||||
|
))
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let feature_prob = feature_in_class_counter
|
let feature_log_prob = feature_in_class_counter
|
||||||
.iter()
|
.iter()
|
||||||
.map(|feature_count| {
|
.map(|feature_count| {
|
||||||
let n_c = feature_count.sum();
|
let n_c: usize = feature_count.iter().sum();
|
||||||
feature_count
|
feature_count
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&count| (count + alpha) / (n_c + alpha * T::from(n_features).unwrap()))
|
.map(|&count| {
|
||||||
|
((T::from(count).unwrap() + alpha)
|
||||||
|
/ (T::from(n_c).unwrap() + alpha * T::from(n_features).unwrap()))
|
||||||
|
.ln()
|
||||||
|
})
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
class_count,
|
||||||
class_labels,
|
class_labels,
|
||||||
class_priors,
|
class_priors,
|
||||||
feature_prob,
|
feature_log_prob,
|
||||||
|
feature_count: feature_in_class_counter,
|
||||||
|
n_features,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
/// MultinomialNB implements the naive Bayes algorithm for multinomially distributed data.
|
||||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
|
||||||
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
|
||||||
}
|
}
|
||||||
@@ -236,6 +261,35 @@ impl<T: RealNumber, M: Matrix<T>> MultinomialNB<T, M> {
|
|||||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
self.inner.predict(x)
|
self.inner.predict(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Class labels known to the classifier.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn classes(&self) -> &Vec<T> {
|
||||||
|
&self.inner.distribution.class_labels
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of training samples observed in each class.
|
||||||
|
/// Returns a vector of size n_classes.
|
||||||
|
pub fn class_count(&self) -> &Vec<usize> {
|
||||||
|
&self.inner.distribution.class_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Empirical log probability of features given a class, P(x_i|y).
|
||||||
|
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||||
|
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
|
||||||
|
&self.inner.distribution.feature_log_prob
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of features of each sample
|
||||||
|
pub fn n_features(&self) -> usize {
|
||||||
|
self.inner.distribution.n_features
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of samples encountered for each (class, feature)
|
||||||
|
/// Returns a 2d vector of shape (n_classes, n_features)
|
||||||
|
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
|
||||||
|
&self.inner.distribution.feature_count
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -243,6 +297,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn run_multinomial_naive_bayes() {
|
fn run_multinomial_naive_bayes() {
|
||||||
// Tests that MultinomialNB when alpha=1.0 gives the same values as
|
// Tests that MultinomialNB when alpha=1.0 gives the same values as
|
||||||
@@ -264,12 +319,29 @@ mod tests {
|
|||||||
let y = vec![0., 0., 0., 1.];
|
let y = vec![0., 0., 0., 1.];
|
||||||
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(mnb.classes(), &[0., 1.]);
|
||||||
|
assert_eq!(mnb.class_count(), &[3, 1]);
|
||||||
|
|
||||||
assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
mnb.inner.distribution.feature_prob,
|
mnb.feature_log_prob(),
|
||||||
&[
|
&[
|
||||||
&[1. / 7., 3. / 7., 1. / 14., 1. / 7., 1. / 7., 1. / 14.],
|
&[
|
||||||
&[1. / 9., 2. / 9.0, 2. / 9.0, 1. / 9.0, 1. / 9.0, 2. / 9.0]
|
(1_f64 / 7_f64).ln(),
|
||||||
|
(3_f64 / 7_f64).ln(),
|
||||||
|
(1_f64 / 14_f64).ln(),
|
||||||
|
(1_f64 / 7_f64).ln(),
|
||||||
|
(1_f64 / 7_f64).ln(),
|
||||||
|
(1_f64 / 14_f64).ln()
|
||||||
|
],
|
||||||
|
&[
|
||||||
|
(1_f64 / 9_f64).ln(),
|
||||||
|
(2_f64 / 9_f64).ln(),
|
||||||
|
(2_f64 / 9_f64).ln(),
|
||||||
|
(1_f64 / 9_f64).ln(),
|
||||||
|
(1_f64 / 9_f64).ln(),
|
||||||
|
(2_f64 / 9_f64).ln()
|
||||||
|
]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -281,6 +353,7 @@ mod tests {
|
|||||||
assert_eq!(y_hat, &[0.]);
|
assert_eq!(y_hat, &[0.]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn multinomial_nb_scikit_parity() {
|
fn multinomial_nb_scikit_parity() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
@@ -303,6 +376,16 @@ mod tests {
|
|||||||
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
|
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
|
||||||
let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(nb.n_features(), 10);
|
||||||
|
assert_eq!(
|
||||||
|
nb.feature_count(),
|
||||||
|
&[
|
||||||
|
&[12, 20, 11, 24, 12, 14, 13, 17, 13, 18],
|
||||||
|
&[9, 6, 9, 4, 7, 3, 8, 5, 4, 9],
|
||||||
|
&[10, 12, 9, 9, 11, 3, 9, 18, 10, 10]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
let y_hat = nb.predict(&x).unwrap();
|
let y_hat = nb.predict(&x).unwrap();
|
||||||
|
|
||||||
assert!(nb
|
assert!(nb
|
||||||
@@ -310,16 +393,29 @@ mod tests {
|
|||||||
.distribution
|
.distribution
|
||||||
.class_priors
|
.class_priors
|
||||||
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
|
||||||
assert!(nb.inner.distribution.feature_prob[1].approximate_eq(
|
assert!(nb.feature_log_prob()[1].approximate_eq(
|
||||||
&vec!(0.07, 0.12, 0.07, 0.15, 0.07, 0.09, 0.08, 0.10, 0.08, 0.11),
|
&vec![
|
||||||
1e-1
|
-2.00148,
|
||||||
|
-2.35815494,
|
||||||
|
-2.00148,
|
||||||
|
-2.69462718,
|
||||||
|
-2.22462355,
|
||||||
|
-2.91777073,
|
||||||
|
-2.10684052,
|
||||||
|
-2.51230562,
|
||||||
|
-2.69462718,
|
||||||
|
-2.00148
|
||||||
|
],
|
||||||
|
1e-5
|
||||||
));
|
));
|
||||||
assert!(y_hat.approximate_eq(
|
assert!(y_hat.approximate_eq(
|
||||||
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0),
|
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0),
|
||||||
1e-5
|
1e-5
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
&[1., 1., 0., 0., 0., 0.],
|
&[1., 1., 0., 0., 0., 0.],
|
||||||
|
|||||||
@@ -33,6 +33,7 @@
|
|||||||
//!
|
//!
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -45,7 +46,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::neighbors::KNNWeightFunction;
|
use crate::neighbors::KNNWeightFunction;
|
||||||
|
|
||||||
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
/// `KNNClassifier` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
/// a function that defines a distance between each pair of point in training data.
|
/// a function that defines a distance between each pair of point in training data.
|
||||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
@@ -62,7 +64,8 @@ pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// K Nearest Neighbors Classifier
|
/// K Nearest Neighbors Classifier
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
y: Vec<usize>,
|
y: Vec<usize>,
|
||||||
@@ -248,6 +251,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict() {
|
fn knn_fit_predict() {
|
||||||
let x =
|
let x =
|
||||||
@@ -259,6 +263,7 @@ mod tests {
|
|||||||
assert_eq!(y.to_vec(), y_hat);
|
assert_eq!(y.to_vec(), y_hat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict_weighted() {
|
fn knn_fit_predict_weighted() {
|
||||||
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
||||||
@@ -276,7 +281,9 @@ mod tests {
|
|||||||
assert_eq!(vec![3.0], y_hat);
|
assert_eq!(vec![3.0], y_hat);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x =
|
let x =
|
||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
|||||||
@@ -36,6 +36,7 @@
|
|||||||
//!
|
//!
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
|
||||||
@@ -48,7 +49,8 @@ use crate::math::num::RealNumber;
|
|||||||
use crate::neighbors::KNNWeightFunction;
|
use crate::neighbors::KNNWeightFunction;
|
||||||
|
|
||||||
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
/// `KNNRegressor` parameters. Use `Default::default()` for default values.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
/// a function that defines a distance between each pair of point in training data.
|
/// a function that defines a distance between each pair of point in training data.
|
||||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||||
@@ -65,7 +67,8 @@ pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// K Nearest Neighbors Regressor
|
/// K Nearest Neighbors Regressor
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||||
y: Vec<T>,
|
y: Vec<T>,
|
||||||
knn_algorithm: KNNAlgorithm<T, D>,
|
knn_algorithm: KNNAlgorithm<T, D>,
|
||||||
@@ -228,6 +231,7 @@ mod tests {
|
|||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
use crate::math::distance::Distances;
|
use crate::math::distance::Distances;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict_weighted() {
|
fn knn_fit_predict_weighted() {
|
||||||
let x =
|
let x =
|
||||||
@@ -251,6 +255,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn knn_fit_predict_uniform() {
|
fn knn_fit_predict_uniform() {
|
||||||
let x =
|
let x =
|
||||||
@@ -265,7 +270,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x =
|
let x =
|
||||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||||
|
|||||||
@@ -33,6 +33,7 @@
|
|||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// K Nearest Neighbors Classifier
|
/// K Nearest Neighbors Classifier
|
||||||
@@ -48,7 +49,8 @@ pub mod knn_regressor;
|
|||||||
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
||||||
|
|
||||||
/// Weight function that is used to determine estimated value.
|
/// Weight function that is used to determine estimated value.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum KNNWeightFunction {
|
pub enum KNNWeightFunction {
|
||||||
/// All k nearest points are weighted equally
|
/// All k nearest points are weighted equally
|
||||||
Uniform,
|
Uniform,
|
||||||
|
|||||||
@@ -50,14 +50,14 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
|||||||
let f_alpha = |alpha: T| -> T {
|
let f_alpha = |alpha: T| -> T {
|
||||||
let mut dx = step.clone();
|
let mut dx = step.clone();
|
||||||
dx.mul_scalar_mut(alpha);
|
dx.mul_scalar_mut(alpha);
|
||||||
f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
|
f(dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
|
||||||
};
|
};
|
||||||
|
|
||||||
let df_alpha = |alpha: T| -> T {
|
let df_alpha = |alpha: T| -> T {
|
||||||
let mut dx = step.clone();
|
let mut dx = step.clone();
|
||||||
let mut dg = gvec.clone();
|
let mut dg = gvec.clone();
|
||||||
dx.mul_scalar_mut(alpha);
|
dx.mul_scalar_mut(alpha);
|
||||||
df(&mut dg, &dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha)
|
df(&mut dg, dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha)
|
||||||
gvec.dot(&dg)
|
gvec.dot(&dg)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -66,7 +66,7 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
|
|||||||
let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0);
|
let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0);
|
||||||
alpha = ls_r.alpha;
|
alpha = ls_r.alpha;
|
||||||
fx = ls_r.f_x;
|
fx = ls_r.f_x;
|
||||||
x.add_mut(&step.mul_scalar_mut(alpha));
|
x.add_mut(step.mul_scalar_mut(alpha));
|
||||||
df(&mut gvec, &x);
|
df(&mut gvec, &x);
|
||||||
gnorm = gvec.norm2();
|
gnorm = gvec.norm2();
|
||||||
}
|
}
|
||||||
@@ -88,6 +88,7 @@ mod tests {
|
|||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn gradient_descent() {
|
fn gradient_descent() {
|
||||||
let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]);
|
let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]);
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::suspicious_operation_groupings)]
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
@@ -7,6 +8,7 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
|||||||
use crate::optimization::line_search::LineSearchMethod;
|
use crate::optimization::line_search::LineSearchMethod;
|
||||||
use crate::optimization::{DF, F};
|
use crate::optimization::{DF, F};
|
||||||
|
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
pub struct LBFGS<T: RealNumber> {
|
pub struct LBFGS<T: RealNumber> {
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
pub g_rtol: T,
|
pub g_rtol: T,
|
||||||
@@ -116,14 +118,14 @@ impl<T: RealNumber> LBFGS<T> {
|
|||||||
let f_alpha = |alpha: T| -> T {
|
let f_alpha = |alpha: T| -> T {
|
||||||
let mut dx = state.s.clone();
|
let mut dx = state.s.clone();
|
||||||
dx.mul_scalar_mut(alpha);
|
dx.mul_scalar_mut(alpha);
|
||||||
f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
|
f(dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
|
||||||
};
|
};
|
||||||
|
|
||||||
let df_alpha = |alpha: T| -> T {
|
let df_alpha = |alpha: T| -> T {
|
||||||
let mut dx = state.s.clone();
|
let mut dx = state.s.clone();
|
||||||
let mut dg = state.x_df.clone();
|
let mut dg = state.x_df.clone();
|
||||||
dx.mul_scalar_mut(alpha);
|
dx.mul_scalar_mut(alpha);
|
||||||
df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
|
df(&mut dg, dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
|
||||||
state.x_df.dot(&dg)
|
state.x_df.dot(&dg)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -205,7 +207,7 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for LBFGS<T> {
|
|||||||
) -> OptimizerResult<T, X> {
|
) -> OptimizerResult<T, X> {
|
||||||
let mut state = self.init_state(x0);
|
let mut state = self.init_state(x0);
|
||||||
|
|
||||||
df(&mut state.x_df, &x0);
|
df(&mut state.x_df, x0);
|
||||||
|
|
||||||
let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
|
let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
|
||||||
let mut converged = g_converged;
|
let mut converged = g_converged;
|
||||||
@@ -238,6 +240,7 @@ mod tests {
|
|||||||
use crate::optimization::line_search::Backtracking;
|
use crate::optimization::line_search::Backtracking;
|
||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn lbfgs() {
|
fn lbfgs() {
|
||||||
let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]);
|
let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]);
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn backtracking() {
|
fn backtracking() {
|
||||||
let f = |x: f64| -> f64 { x.powf(2.) + x };
|
let f = |x: f64| -> f64 { x.powf(2.) + x };
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ pub mod line_search;
|
|||||||
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
|
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
|
||||||
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub enum FunctionOrder {
|
pub enum FunctionOrder {
|
||||||
SECOND,
|
SECOND,
|
||||||
THIRD,
|
THIRD,
|
||||||
|
|||||||
@@ -0,0 +1,333 @@
|
|||||||
|
//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matricies
|
||||||
|
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents
|
||||||
|
//!
|
||||||
|
//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html)
|
||||||
|
//!
|
||||||
|
//! ### Usage Example
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
//! use smartcore::preprocessing::categorical::{OneHotEncoder, OneHotEncoderParams};
|
||||||
|
//! let data = DenseMatrix::from_2d_array(&[
|
||||||
|
//! &[1.5, 1.0, 1.5, 3.0],
|
||||||
|
//! &[1.5, 2.0, 1.5, 4.0],
|
||||||
|
//! &[1.5, 1.0, 1.5, 5.0],
|
||||||
|
//! &[1.5, 2.0, 1.5, 6.0],
|
||||||
|
//! ]);
|
||||||
|
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
|
||||||
|
//! // Infer number of categories from data and return a reusable encoder
|
||||||
|
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
|
||||||
|
//! // Transform categorical to one-hot encoded (can transform similar)
|
||||||
|
//! let oh_data = encoder.transform(&data).unwrap();
|
||||||
|
//! // Produces the following:
|
||||||
|
//! // &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0]
|
||||||
|
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0]
|
||||||
|
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
|
||||||
|
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
|
||||||
|
//! ```
|
||||||
|
use std::iter;
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
|
||||||
|
use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable};
|
||||||
|
use crate::preprocessing::series_encoder::CategoryMapper;
|
||||||
|
|
||||||
|
/// OneHotEncoder Parameters
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OneHotEncoderParams {
|
||||||
|
/// Column number that contain categorical variable
|
||||||
|
pub col_idx_categorical: Option<Vec<usize>>,
|
||||||
|
/// (Currently not implemented) Try and infer which of the matrix columns are categorical variables
|
||||||
|
infer_categorical: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OneHotEncoderParams {
|
||||||
|
/// Generate parameters from categorical variable column numbers
|
||||||
|
pub fn from_cat_idx(categorical_params: &[usize]) -> Self {
|
||||||
|
Self {
|
||||||
|
col_idx_categorical: Some(categorical_params.to_vec()),
|
||||||
|
infer_categorical: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Map each original column index to its new index after one-hot expansion.
///
/// `cat_sizes[i]` is the number of categories of the categorical column at
/// `cat_idxs[i]`; every category of size `v` pushes all later columns `v - 1`
/// positions to the right.
//
// This function materializes a vector; for a huge number of parameters an
// iterator-returning variant would be preferable (todo).
fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) -> Vec<usize> {
    // Sentinel index `num_params` closes the final constant-offset run.
    let boundaries = cat_idxs.iter().copied().chain(iter::once(num_params));

    // Length of each run of columns that share the same offset.
    let run_lengths = boundaries.scan(0, |prev, b| {
        let len = b + 1 - *prev;
        *prev = b;
        Some(len)
    });

    // Cumulative offset introduced by each one-hot expansion (first run shifts by 0).
    let cumulative = cat_sizes.iter().scan(0, |acc, &size| {
        *acc = *acc + size - 1;
        Some(*acc)
    });
    let offsets = iter::once(0).chain(cumulative);

    (0..num_params)
        .zip(
            run_lengths
                .zip(offsets)
                .flat_map(|(len, off)| iter::repeat(off).take(len)),
        )
        .map(|(idx, off)| idx + off)
        .collect()
}
|
||||||
|
|
||||||
|
fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool {
|
||||||
|
for v in data {
|
||||||
|
if !v.is_valid() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode Categorical variavbles of data matrix to one-hot
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OneHotEncoder {
|
||||||
|
category_mappers: Vec<CategoryMapper<CategoricalFloat>>,
|
||||||
|
col_idx_categorical: Vec<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OneHotEncoder {
|
||||||
|
/// Create an encoder instance with categories infered from data matrix
|
||||||
|
pub fn fit<T, M>(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed>
|
||||||
|
where
|
||||||
|
T: Categorizable,
|
||||||
|
M: Matrix<T>,
|
||||||
|
{
|
||||||
|
match (params.col_idx_categorical, params.infer_categorical) {
|
||||||
|
(None, false) => Err(Failed::fit(
|
||||||
|
"Must pass categorical series ids or infer flag",
|
||||||
|
)),
|
||||||
|
|
||||||
|
(Some(_idxs), true) => Err(Failed::fit(
|
||||||
|
"Ambigous parameters, got both infer and categroy ids",
|
||||||
|
)),
|
||||||
|
|
||||||
|
(Some(mut idxs), false) => {
|
||||||
|
// make sure categories have same order as data columns
|
||||||
|
idxs.sort_unstable();
|
||||||
|
|
||||||
|
let (nrows, _) = data.shape();
|
||||||
|
|
||||||
|
// col buffer to avoid allocations
|
||||||
|
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
|
||||||
|
|
||||||
|
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
|
||||||
|
|
||||||
|
for &idx in &idxs {
|
||||||
|
data.copy_col_as_vec(idx, &mut col_buf);
|
||||||
|
if !validate_col_is_categorical(&col_buf) {
|
||||||
|
let msg = format!(
|
||||||
|
"Column {} of data matrix containts non categorizable (integer) values",
|
||||||
|
idx
|
||||||
|
);
|
||||||
|
return Err(Failed::fit(&msg[..]));
|
||||||
|
}
|
||||||
|
let hashable_col = col_buf.iter().map(|v| v.to_category());
|
||||||
|
res.push(CategoryMapper::fit_to_iter(hashable_col));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
category_mappers: res,
|
||||||
|
col_idx_categorical: idxs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
(None, true) => {
|
||||||
|
todo!("Auto-Inference for Categorical Variables not yet implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transform categorical variables to one-hot encoded and return a new matrix
|
||||||
|
pub fn transform<T, M>(&self, x: &M) -> Result<M, Failed>
|
||||||
|
where
|
||||||
|
T: Categorizable,
|
||||||
|
M: Matrix<T>,
|
||||||
|
{
|
||||||
|
let (nrows, p) = x.shape();
|
||||||
|
let additional_params: Vec<usize> = self
|
||||||
|
.category_mappers
|
||||||
|
.iter()
|
||||||
|
.map(|enc| enc.num_categories())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Eac category of size v adds v-1 params
|
||||||
|
let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);
|
||||||
|
|
||||||
|
let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
|
||||||
|
let mut res = M::zeros(nrows, expandws_p);
|
||||||
|
|
||||||
|
for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
|
||||||
|
let cidx = new_col_idx[old_cidx];
|
||||||
|
let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category());
|
||||||
|
let sencoder = &self.category_mappers[pidx];
|
||||||
|
let oh_series = col_iter.map(|c| sencoder.get_one_hot::<T, Vec<T>>(&c));
|
||||||
|
|
||||||
|
for (row, oh_vec) in oh_series.enumerate() {
|
||||||
|
match oh_vec {
|
||||||
|
None => {
|
||||||
|
// Since we support T types, bad value in a series causes in to be invalid
|
||||||
|
let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx);
|
||||||
|
return Err(Failed::transform(&msg[..]));
|
||||||
|
}
|
||||||
|
Some(v) => {
|
||||||
|
// copy one hot vectors to their place in the data matrix;
|
||||||
|
for (col_ofst, &val) in v.iter().enumerate() {
|
||||||
|
res.set(row, cidx + col_ofst, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy old data in x to their new location while skipping catergorical vars (already treated)
|
||||||
|
let mut skip_idx_iter = self.col_idx_categorical.iter();
|
||||||
|
let mut cur_skip = skip_idx_iter.next();
|
||||||
|
|
||||||
|
for (old_p, &new_p) in new_col_idx.iter().enumerate() {
|
||||||
|
// if found treated varible, skip it
|
||||||
|
if let Some(&v) = cur_skip {
|
||||||
|
if v == old_p {
|
||||||
|
cur_skip = skip_idx_iter.next();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for r in 0..nrows {
|
||||||
|
let val = x.get(r, old_p);
|
||||||
|
res.set(r, new_p, val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;
    use crate::preprocessing::series_encoder::CategoryMapper;

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn adjust_idxs() {
        assert_eq!(find_new_idxs(0, &[], &[]), Vec::<usize>::new());
        // [0,1,2] -> [0, 1, 1, 1, 2]
        assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]);
    }

    /// Matrix with categorical columns 0 and 2, plus its expected encoding.
    fn build_cat_first_and_last() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        let orig = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);

        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);

        (orig, oh_enc)
    }

    /// Matrix with categorical columns 1 and 3, plus its expected encoding.
    fn build_fake_matrix() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        let orig = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 1.5, 3.0],
            &[1.5, 2.0, 1.5, 4.0],
            &[1.5, 1.0, 1.5, 5.0],
            &[1.5, 2.0, 1.5, 6.0],
        ]);

        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);

        (orig, oh_enc)
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn hash_encode_f64_series() {
        let series = vec![3.0, 1.0, 2.0, 1.0];
        let hashable_series: Vec<CategoricalFloat> =
            series.iter().map(|v| v.to_category()).collect();
        let enc = CategoryMapper::from_positional_category_vec(hashable_series);
        let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]);
        let orig_val: f64 = inv.unwrap().into();
        assert_eq!(orig_val, 2.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_fit() {
        let (x, _) = build_fake_matrix();
        let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        assert_eq!(oh_enc.category_mappers.len(), 2);

        let num_cat: Vec<usize> = oh_enc
            .category_mappers
            .iter()
            .map(|a| a.num_categories())
            .collect();
        assert_eq!(num_cat, vec![2, 4]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_transform_test() {
        // Categorical columns in the middle and at the end.
        let (x, expected_x) = build_fake_matrix();
        let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        assert_eq!(oh_enc.transform(&x).unwrap(), expected_x);

        // Categorical columns first and last.
        let (x, expected_x) = build_cat_first_and_last();
        let params = OneHotEncoderParams::from_cat_idx(&[0, 2]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        assert_eq!(oh_enc.transform(&x).unwrap(), expected_x);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fail_on_bad_category() {
        let m = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);

        // Column 1 holds 1.5 everywhere, which is not categorizable.
        let params = OneHotEncoderParams::from_cat_idx(&[1]);
        assert!(OneHotEncoder::fit(&m, params).is_err());
    }
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
//! Traits to indicate that float variables can be viewed as categorical
|
||||||
|
//! This module assumes
|
||||||
|
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
|
pub type CategoricalFloat = u16;
|
||||||
|
|
||||||
|
// pub struct CategoricalFloat(u16);
|
||||||
|
const ERROR_MARGIN: f64 = 0.001;
|
||||||
|
|
||||||
|
pub trait Categorizable: RealNumber {
|
||||||
|
type A;
|
||||||
|
|
||||||
|
fn to_category(self) -> CategoricalFloat;
|
||||||
|
|
||||||
|
fn is_valid(self) -> bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Categorizable for f32 {
|
||||||
|
type A = CategoricalFloat;
|
||||||
|
|
||||||
|
fn to_category(self) -> CategoricalFloat {
|
||||||
|
self as CategoricalFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_valid(self) -> bool {
|
||||||
|
let a = self.to_category();
|
||||||
|
(a as f32 - self).abs() < (ERROR_MARGIN as f32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Categorizable for f64 {
|
||||||
|
type A = CategoricalFloat;
|
||||||
|
|
||||||
|
fn to_category(self) -> CategoricalFloat {
|
||||||
|
self as CategoricalFloat
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_valid(self) -> bool {
|
||||||
|
let a = self.to_category();
|
||||||
|
(a as f64 - self).abs() < ERROR_MARGIN
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
/// Transform a data matrix by replacing all categorical variables with their one-hot vector equivalents
|
||||||
|
pub mod categorical;
|
||||||
|
mod data_traits;
|
||||||
|
/// Preprocess numerical matrices.
|
||||||
|
pub mod numerical;
|
||||||
|
/// Encode a series (column, array) of categorical variables as one-hot vectors
|
||||||
|
pub mod series_encoder;
|
||||||
@@ -0,0 +1,404 @@
|
|||||||
|
//! # Standard-Scaling For [RealNumber](../../math/num/trait.RealNumber.html) Matricies
|
||||||
|
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by removing the mean and scaling to unit variance.
|
||||||
|
//!
|
||||||
|
//! ### Usage Example
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::api::{Transformer, UnsupervisedEstimator};
|
||||||
|
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
//! use smartcore::preprocessing::numerical;
|
||||||
|
//! let data = DenseMatrix::from_2d_vec(&vec![
|
||||||
|
//! vec![0.0, 0.0],
|
||||||
|
//! vec![0.0, 0.0],
|
||||||
|
//! vec![1.0, 1.0],
|
||||||
|
//! vec![1.0, 1.0],
|
||||||
|
//! ]);
|
||||||
|
//!
|
||||||
|
//! let standard_scaler =
|
||||||
|
//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default())
|
||||||
|
//! .unwrap();
|
||||||
|
//! let transformed_data = standard_scaler.transform(&data).unwrap();
|
||||||
|
//! assert_eq!(
|
||||||
|
//! transformed_data,
|
||||||
|
//! DenseMatrix::from_2d_vec(&vec![
|
||||||
|
//! vec![-1.0, -1.0],
|
||||||
|
//! vec![-1.0, -1.0],
|
||||||
|
//! vec![1.0, 1.0],
|
||||||
|
//! vec![1.0, 1.0],
|
||||||
|
//! ])
|
||||||
|
//! );
|
||||||
|
//! ```
|
||||||
|
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::Matrix;
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
|
/// Configure the behaviour of a `StandardScaler`.
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub struct StandardScalerParameters {
    /// Optionally adjust the mean to be zero.
    with_mean: bool,
    /// Optionally adjust the standard deviation to be one.
    with_std: bool,
}
|
||||||
|
impl Default for StandardScalerParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
StandardScalerParameters {
|
||||||
|
with_mean: true,
|
||||||
|
with_std: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// With the `StandardScaler` data can be adjusted so
|
||||||
|
/// that every column has a mean of zero and a standard
|
||||||
|
/// deviation of one. This can improve model training for
|
||||||
|
/// scaling sensitive models like neural network or nearest
|
||||||
|
/// neighbors based models.
|
||||||
|
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||||
|
pub struct StandardScaler<T: RealNumber> {
|
||||||
|
means: Vec<T>,
|
||||||
|
stds: Vec<T>,
|
||||||
|
parameters: StandardScalerParameters,
|
||||||
|
}
|
||||||
|
impl<T: RealNumber> StandardScaler<T> {
|
||||||
|
/// When the mean should be adjusted, the column mean
|
||||||
|
/// should be kept. Otherwise, replace it by zero.
|
||||||
|
fn adjust_column_mean(&self, mean: T) -> T {
|
||||||
|
if self.parameters.with_mean {
|
||||||
|
mean
|
||||||
|
} else {
|
||||||
|
T::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// When the standard-deviation should be adjusted, the column
|
||||||
|
/// standard-deviation should be kept. Otherwise, replace it by one.
|
||||||
|
fn adjust_column_std(&self, std: T) -> T {
|
||||||
|
if self.parameters.with_std {
|
||||||
|
ensure_std_valid(std)
|
||||||
|
} else {
|
||||||
|
T::one()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make sure the standard deviation is valid. If it is
|
||||||
|
/// negative or zero, it should replaced by the smallest
|
||||||
|
/// positive value the type can have. That way we can savely
|
||||||
|
/// divide the columns with the resulting scalar.
|
||||||
|
fn ensure_std_valid<T: RealNumber>(value: T) -> T {
|
||||||
|
value.max(T::min_positive_value())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// During `fit` the `StandardScaler` computes the column means and standard deviation.
|
||||||
|
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, StandardScalerParameters>
|
||||||
|
for StandardScaler<T>
|
||||||
|
{
|
||||||
|
fn fit(x: &M, parameters: StandardScalerParameters) -> Result<Self, Failed> {
|
||||||
|
Ok(Self {
|
||||||
|
means: x.column_mean(),
|
||||||
|
stds: x.std(0),
|
||||||
|
parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// During `transform` the `StandardScaler` applies the summary statistics
|
||||||
|
/// computed during `fit` to set the mean of each column to zero and the
|
||||||
|
/// standard deviation to one.
|
||||||
|
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for StandardScaler<T> {
|
||||||
|
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||||
|
let (_, n_cols) = x.shape();
|
||||||
|
if n_cols != self.means.len() {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::TransformFailed,
|
||||||
|
&format!(
|
||||||
|
"Expected {} columns, but got {} columns instead.",
|
||||||
|
self.means.len(),
|
||||||
|
n_cols,
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(build_matrix_from_columns(
|
||||||
|
self.means
|
||||||
|
.iter()
|
||||||
|
.zip(self.stds.iter())
|
||||||
|
.enumerate()
|
||||||
|
.map(|(column_index, (column_mean, column_std))| {
|
||||||
|
x.take_column(column_index)
|
||||||
|
.sub_scalar(self.adjust_column_mean(*column_mean))
|
||||||
|
.div_scalar(self.adjust_column_std(*column_std))
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// From a collection of matrices, that contain columns, construct
|
||||||
|
/// a matrix by stacking the columns horizontally.
|
||||||
|
fn build_matrix_from_columns<T, M>(columns: Vec<M>) -> Option<M>
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
M: Matrix<T>,
|
||||||
|
{
|
||||||
|
if let Some(output_matrix) = columns.first().cloned() {
|
||||||
|
return Some(
|
||||||
|
columns
|
||||||
|
.iter()
|
||||||
|
.skip(1)
|
||||||
|
.fold(output_matrix, |current_matrix, new_colum| {
|
||||||
|
current_matrix.h_stack(new_colum)
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {

    mod helper_functionality {
        use super::super::{build_matrix_from_columns, ensure_std_valid};
        use crate::linalg::naive::dense_matrix::DenseMatrix;

        #[test]
        fn combine_three_columns() {
            let columns = vec![
                DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0]]),
                DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0]]),
                DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0]]),
            ];
            assert_eq!(
                build_matrix_from_columns(columns),
                Some(DenseMatrix::from_2d_vec(&vec![
                    vec![1.0, 2.0, 3.0],
                    vec![1.0, 2.0, 3.0],
                    vec![1.0, 2.0, 3.0]
                ]))
            )
        }

        #[test]
        fn negative_value_should_be_replace_with_minimal_positive_value() {
            assert_eq!(ensure_std_valid(-1.0), f64::MIN_POSITIVE)
        }

        #[test]
        fn zero_should_be_replace_with_minimal_positive_value() {
            assert_eq!(ensure_std_valid(0.0), f64::MIN_POSITIVE)
        }
    }

    mod standard_scaler {
        use super::super::{StandardScaler, StandardScalerParameters};
        use crate::api::{Transformer, UnsupervisedEstimator};
        use crate::linalg::naive::dense_matrix::DenseMatrix;
        use crate::linalg::BaseMatrix;

        /// Build a scaler with empty statistics and the given flags.
        fn scaler_with_flags(with_mean: bool, with_std: bool) -> StandardScaler<f64> {
            StandardScaler {
                means: vec![],
                stds: vec![],
                parameters: StandardScalerParameters { with_mean, with_std },
            }
        }

        #[test]
        fn dont_adjust_mean_if_used() {
            assert_eq!(scaler_with_flags(true, true).adjust_column_mean(1.0), 1.0)
        }

        #[test]
        fn replace_mean_with_zero_if_not_used() {
            assert_eq!(scaler_with_flags(false, true).adjust_column_mean(1.0), 0.0)
        }

        #[test]
        fn dont_adjust_std_if_used() {
            assert_eq!(scaler_with_flags(true, true).adjust_column_std(10.0), 10.0)
        }

        #[test]
        fn replace_std_with_one_if_not_used() {
            assert_eq!(scaler_with_flags(true, false).adjust_column_std(10.0), 1.0)
        }

        /// Helper function to apply fit as well as transform at the same time.
        fn fit_transform_with_default_standard_scaler(
            values_to_be_transformed: &DenseMatrix<f64>,
        ) -> DenseMatrix<f64> {
            StandardScaler::fit(
                values_to_be_transformed,
                StandardScalerParameters::default(),
            )
            .unwrap()
            .transform(values_to_be_transformed)
            .unwrap()
        }

        /// Fit-transform on random values; expected values taken from sklearn.
        #[test]
        fn fit_transform_random_values() {
            let transformed_values =
                fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
                    &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
                    &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
                    &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
                    &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
                ]));
            println!("{}", transformed_values);
            assert!(transformed_values.approximate_eq(
                &DenseMatrix::from_2d_array(&[
                    &[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866],
                    &[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
                    &[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
                    &[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
                ]),
                1.0
            ))
        }

        /// `fit` and `transform` for a column with zero variance.
        #[test]
        fn fit_transform_with_zero_variance() {
            assert_eq!(
                fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
                    &[1.0],
                    &[1.0],
                    &[1.0],
                    &[1.0]
                ])),
                DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]),
                "When scaling values with zero variance, zero is expected as return value"
            )
        }

        /// `fit` for columns with nice summary statistics.
        #[test]
        fn fit_for_simple_values() {
            assert_eq!(
                StandardScaler::fit(
                    &DenseMatrix::from_2d_array(&[
                        &[1.0, 1.0, 1.0],
                        &[1.0, 2.0, 5.0],
                        &[1.0, 1.0, 1.0],
                        &[1.0, 2.0, 5.0]
                    ]),
                    StandardScalerParameters::default(),
                ),
                Ok(StandardScaler {
                    means: vec![1.0, 1.5, 3.0],
                    stds: vec![0.0, 0.5, 2.0],
                    parameters: StandardScalerParameters {
                        with_mean: true,
                        with_std: true
                    }
                })
            )
        }

        /// `fit` for randomly generated values.
        #[test]
        fn fit_for_random_values() {
            let fitted_scaler = StandardScaler::fit(
                &DenseMatrix::from_2d_array(&[
                    &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
                    &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
                    &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
                    &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
                ]),
                StandardScalerParameters::default(),
            )
            .unwrap();

            assert_eq!(
                fitted_scaler.means,
                vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
            );

            assert!(
                &DenseMatrix::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq(
                    &DenseMatrix::from_2d_array(&[&[
                        0.29426447500954,
                        0.16758497615485,
                        0.20820945786863,
                        0.23329718831165
                    ],]),
                    0.00000000000001
                )
            )
        }

        /// With `with_std` disabled the values keep their spread.
        #[test]
        fn transform_without_std() {
            let standard_scaler = StandardScaler {
                means: vec![1.0, 3.0],
                stds: vec![1.0, 2.0],
                parameters: StandardScalerParameters {
                    with_mean: true,
                    with_std: false,
                },
            };

            assert_eq!(
                standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])),
                Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]))
            )
        }

        /// With `with_mean` disabled the values keep their location.
        #[test]
        fn transform_without_mean() {
            let standard_scaler = StandardScaler {
                means: vec![1.0, 2.0],
                stds: vec![2.0, 3.0],
                parameters: StandardScalerParameters {
                    with_mean: false,
                    with_std: true,
                },
            };

            assert_eq!(
                standard_scaler
                    .transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])),
                Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
            )
        }
    }
}
|
||||||
@@ -0,0 +1,282 @@
|
|||||||
|
#![allow(clippy::ptr_arg)]
|
||||||
|
//! # Series Encoder
|
||||||
|
//! Encode a series of categorical features as a one-hot numeric array.
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
|
use crate::linalg::BaseVector;
|
||||||
|
use crate::math::num::RealNumber;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
/// ## Bi-directional map category <-> label number.
/// Turns hashable objects into one-hot vectors or ordinal values;
/// encodes a single class per example.
///
/// Fitting from an iterator assigns category numbers in encounter order:
///
/// ```
/// use std::collections::HashMap;
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
///
/// let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
/// let it = fake_categories.iter().map(|&a| a);
/// let enc = CategoryMapper::<usize>::fit_to_iter(it);
/// let oh_vec: Vec<f64> = enc.get_one_hot(&1).unwrap();
/// // notice that 1 is actually a zero-th positional category
/// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]);
/// ```
///
/// A predefined enumeration can also be supplied, either as a
/// `HashMap<C, usize>` or as a positional `Vec<C>`:
///
/// ```
/// use std::collections::HashMap;
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
///
/// let category_map: HashMap<&str, usize> =
///     vec![("cat", 2), ("background", 0), ("dog", 1)]
///         .into_iter()
///         .collect();
/// let category_vec = vec!["background", "dog", "cat"];
///
/// let enc_lv = CategoryMapper::<&str>::from_positional_category_vec(category_vec);
/// let enc_lm = CategoryMapper::<&str>::from_category_map(category_map);
///
/// // ["background", "dog", "cat"]
/// println!("{:?}", enc_lv.get_categories());
/// let lv: Vec<f64> = enc_lv.get_one_hot(&"dog").unwrap();
/// let lm: Vec<f64> = enc_lm.get_one_hot(&"dog").unwrap();
/// assert_eq!(lv, lm);
/// ```
#[derive(Debug, Clone)]
pub struct CategoryMapper<C> {
    /// category -> category number lookup.
    category_map: HashMap<C, usize>,
    /// category number -> category lookup (position = number).
    categories: Vec<C>,
    /// Cached count of distinct categories.
    num_categories: usize,
}
|
||||||
|
|
||||||
|
impl<C> CategoryMapper<C>
|
||||||
|
where
|
||||||
|
C: Hash + Eq + Clone,
|
||||||
|
{
|
||||||
|
/// Get the number of categories in the mapper
|
||||||
|
pub fn num_categories(&self) -> usize {
|
||||||
|
self.num_categories
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fit an encoder to a lable iterator
|
||||||
|
pub fn fit_to_iter(categories: impl Iterator<Item = C>) -> Self {
|
||||||
|
let mut category_map: HashMap<C, usize> = HashMap::new();
|
||||||
|
let mut category_num = 0usize;
|
||||||
|
let mut unique_lables: Vec<C> = Vec::new();
|
||||||
|
|
||||||
|
for l in categories {
|
||||||
|
if !category_map.contains_key(&l) {
|
||||||
|
category_map.insert(l.clone(), category_num);
|
||||||
|
unique_lables.push(l.clone());
|
||||||
|
category_num += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Self {
|
||||||
|
category_map,
|
||||||
|
num_categories: category_num,
|
||||||
|
categories: unique_lables,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an encoder from a predefined (category -> class number) map
|
||||||
|
pub fn from_category_map(category_map: HashMap<C, usize>) -> Self {
|
||||||
|
let mut _unique_cat: Vec<(C, usize)> =
|
||||||
|
category_map.iter().map(|(k, v)| (k.clone(), *v)).collect();
|
||||||
|
_unique_cat.sort_by(|a, b| a.1.cmp(&b.1));
|
||||||
|
let categories: Vec<C> = _unique_cat.into_iter().map(|a| a.0).collect();
|
||||||
|
Self {
|
||||||
|
num_categories: categories.len(),
|
||||||
|
categories,
|
||||||
|
category_map,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build an encoder from a predefined positional category-class num vector
|
||||||
|
pub fn from_positional_category_vec(categories: Vec<C>) -> Self {
|
||||||
|
let category_map: HashMap<C, usize> = categories
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(v, k)| (k.clone(), v))
|
||||||
|
.collect();
|
||||||
|
Self {
|
||||||
|
num_categories: categories.len(),
|
||||||
|
category_map,
|
||||||
|
categories,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get label num of a category
|
||||||
|
pub fn get_num(&self, category: &C) -> Option<&usize> {
|
||||||
|
self.category_map.get(category)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return category corresponding to label num
|
||||||
|
pub fn get_cat(&self, num: usize) -> &C {
|
||||||
|
&self.categories[num]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all categories (position = category number)
|
||||||
|
pub fn get_categories(&self) -> &[C] {
|
||||||
|
&self.categories[..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get one-hot encoding of the category
|
||||||
|
pub fn get_one_hot<U, V>(&self, category: &C) -> Option<V>
|
||||||
|
where
|
||||||
|
U: RealNumber,
|
||||||
|
V: BaseVector<U>,
|
||||||
|
{
|
||||||
|
self.get_num(category)
|
||||||
|
.map(|&idx| make_one_hot::<U, V>(idx, self.num_categories))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invert one-hot vector, back to the category
|
||||||
|
pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed>
|
||||||
|
where
|
||||||
|
U: RealNumber,
|
||||||
|
V: BaseVector<U>,
|
||||||
|
{
|
||||||
|
let pos = U::one();
|
||||||
|
|
||||||
|
let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx));
|
||||||
|
|
||||||
|
let s: Vec<usize> = oh_it
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(idx, v)| if v == pos { Some(idx) } else { None })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if s.len() == 1 {
|
||||||
|
let idx = s[0];
|
||||||
|
return Ok(self.get_cat(idx).clone());
|
||||||
|
}
|
||||||
|
let pos_entries = format!(
|
||||||
|
"Expected a single positive entry, {} entires found",
|
||||||
|
s.len()
|
||||||
|
);
|
||||||
|
Err(Failed::transform(&pos_entries[..]))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get ordinal encoding of the catergory
|
||||||
|
pub fn get_ordinal<U>(&self, category: &C) -> Option<U>
|
||||||
|
where
|
||||||
|
U: RealNumber,
|
||||||
|
{
|
||||||
|
match self.get_num(category) {
|
||||||
|
None => None,
|
||||||
|
Some(&idx) => U::from_usize(idx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make a one-hot encoded vector from a categorical variable
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::preprocessing::series_encoder::make_one_hot;
|
||||||
|
/// let one_hot: Vec<f64> = make_one_hot(2, 3);
|
||||||
|
/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]);
|
||||||
|
/// ```
|
||||||
|
pub fn make_one_hot<T, V>(category_idx: usize, num_categories: usize) -> V
|
||||||
|
where
|
||||||
|
T: RealNumber,
|
||||||
|
V: BaseVector<T>,
|
||||||
|
{
|
||||||
|
let pos = T::one();
|
||||||
|
let mut z = V::zeros(num_categories);
|
||||||
|
z.set(category_idx, pos);
|
||||||
|
z
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn from_categories() {
|
||||||
|
let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
|
||||||
|
let it = fake_categories.iter().map(|&a| a);
|
||||||
|
let enc = CategoryMapper::<usize>::fit_to_iter(it);
|
||||||
|
let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
|
||||||
|
None => panic!("Wrong categories"),
|
||||||
|
Some(v) => v,
|
||||||
|
};
|
||||||
|
let res: Vec<f64> = vec![1f64, 0f64, 0f64, 0f64, 0f64];
|
||||||
|
assert_eq!(oh_vec, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
|
||||||
|
let fake_category_pos = vec!["background", "dog", "cat"];
|
||||||
|
let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
|
||||||
|
enc
|
||||||
|
}
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn ordinal_encoding() {
|
||||||
|
let enc = build_fake_str_enc();
|
||||||
|
assert_eq!(1f64, enc.get_ordinal::<f64>(&"dog").unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn category_map_and_vec() {
|
||||||
|
let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)]
|
||||||
|
.into_iter()
|
||||||
|
.collect();
|
||||||
|
let enc = CategoryMapper::<&str>::from_category_map(category_map);
|
||||||
|
let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
|
||||||
|
None => panic!("Wrong categories"),
|
||||||
|
Some(v) => v,
|
||||||
|
};
|
||||||
|
let res: Vec<f64> = vec![0f64, 1f64, 0f64];
|
||||||
|
assert_eq!(oh_vec, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn positional_categories_vec() {
|
||||||
|
let enc = build_fake_str_enc();
|
||||||
|
let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
|
||||||
|
None => panic!("Wrong categories"),
|
||||||
|
Some(v) => v,
|
||||||
|
};
|
||||||
|
let res: Vec<f64> = vec![0.0, 1.0, 0.0];
|
||||||
|
assert_eq!(oh_vec, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn invert_label_test() {
|
||||||
|
let enc = build_fake_str_enc();
|
||||||
|
let res: Vec<f64> = vec![0.0, 1.0, 0.0];
|
||||||
|
let lab = enc.invert_one_hot(res).unwrap();
|
||||||
|
assert_eq!(lab, "dog");
|
||||||
|
if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
|
||||||
|
let pos_entries = format!("Expected a single positive entry, 0 entires found");
|
||||||
|
assert_eq!(e, Failed::transform(&pos_entries[..]));
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn test_many_categorys() {
|
||||||
|
let enc = build_fake_str_enc();
|
||||||
|
let cat_it = ["dog", "cat", "fish", "background"].iter().cloned();
|
||||||
|
let res: Vec<Option<Vec<f64>>> = cat_it.map(|v| enc.get_one_hot(&v)).collect();
|
||||||
|
let v = vec![
|
||||||
|
Some(vec![0.0, 1.0, 0.0]),
|
||||||
|
Some(vec![0.0, 0.0, 1.0]),
|
||||||
|
None,
|
||||||
|
Some(vec![1.0, 0.0, 0.0]),
|
||||||
|
];
|
||||||
|
assert_eq!(res, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
+13
-4
@@ -26,6 +26,7 @@
|
|||||||
pub mod svc;
|
pub mod svc;
|
||||||
pub mod svr;
|
pub mod svr;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::linalg::BaseVector;
|
use crate::linalg::BaseVector;
|
||||||
@@ -93,18 +94,21 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Linear Kernel
|
/// Linear Kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct LinearKernel {}
|
pub struct LinearKernel {}
|
||||||
|
|
||||||
/// Radial basis function (Gaussian) kernel
|
/// Radial basis function (Gaussian) kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct RBFKernel<T: RealNumber> {
|
pub struct RBFKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Polynomial kernel
|
/// Polynomial kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct PolynomialKernel<T: RealNumber> {
|
pub struct PolynomialKernel<T: RealNumber> {
|
||||||
/// degree of the polynomial
|
/// degree of the polynomial
|
||||||
pub degree: T,
|
pub degree: T,
|
||||||
@@ -115,7 +119,8 @@ pub struct PolynomialKernel<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Sigmoid (hyperbolic tangent) kernel
|
/// Sigmoid (hyperbolic tangent) kernel
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct SigmoidKernel<T: RealNumber> {
|
pub struct SigmoidKernel<T: RealNumber> {
|
||||||
/// kernel coefficient
|
/// kernel coefficient
|
||||||
pub gamma: T,
|
pub gamma: T,
|
||||||
@@ -154,6 +159,7 @@ impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for SigmoidKernel<T> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn linear_kernel() {
|
fn linear_kernel() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
@@ -162,6 +168,7 @@ mod tests {
|
|||||||
assert_eq!(32f64, Kernels::linear().apply(&v1, &v2));
|
assert_eq!(32f64, Kernels::linear().apply(&v1, &v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn rbf_kernel() {
|
fn rbf_kernel() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
@@ -170,6 +177,7 @@ mod tests {
|
|||||||
assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4);
|
assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn polynomial_kernel() {
|
fn polynomial_kernel() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
@@ -181,6 +189,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn sigmoid_kernel() {
|
fn sigmoid_kernel() {
|
||||||
let v1 = vec![1., 2., 3.];
|
let v1 = vec![1., 2., 3.];
|
||||||
|
|||||||
+85
-26
@@ -57,9 +57,9 @@
|
|||||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||||
//!
|
//!
|
||||||
//! let svr = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap();
|
//! let svc = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap();
|
||||||
//!
|
//!
|
||||||
//! let y_hat = svr.predict(&x).unwrap();
|
//! let y_hat = svc.predict(&x).unwrap();
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
//! ## References:
|
//! ## References:
|
||||||
@@ -76,6 +76,7 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -85,7 +86,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// SVC Parameters
|
/// SVC Parameters
|
||||||
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
/// Number of epochs.
|
/// Number of epochs.
|
||||||
@@ -100,11 +102,15 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[serde(bound(
|
#[derive(Debug)]
|
||||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
#[cfg_attr(
|
||||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
feature = "serde",
|
||||||
))]
|
serde(bound(
|
||||||
|
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||||
|
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||||
|
))
|
||||||
|
)]
|
||||||
/// Support Vector Classifier
|
/// Support Vector Classifier
|
||||||
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
classes: Vec<T>,
|
classes: Vec<T>,
|
||||||
@@ -114,7 +120,8 @@ pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
|||||||
b: T,
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||||
index: usize,
|
index: usize,
|
||||||
x: V,
|
x: V,
|
||||||
@@ -215,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
|||||||
|
|
||||||
if n != y.len() {
|
if n != y.len() {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
&"Number of rows of X doesn\'t match number of rows of Y".to_string(),
|
"Number of rows of X doesn\'t match number of rows of Y",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,21 +263,33 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
|||||||
/// Predicts estimated class labels from `x`
|
/// Predicts estimated class labels from `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
let (n, _) = x.shape();
|
let mut y_hat = self.decision_function(x)?;
|
||||||
|
|
||||||
let mut y_hat = M::RowVector::zeros(n);
|
for i in 0..y_hat.len() {
|
||||||
|
let cls_idx = match y_hat.get(i) > T::zero() {
|
||||||
for i in 0..n {
|
|
||||||
let cls_idx = match self.predict_for_row(x.get_row(i)) == T::one() {
|
|
||||||
false => self.classes[0],
|
false => self.classes[0],
|
||||||
true => self.classes[1],
|
true => self.classes[1],
|
||||||
};
|
};
|
||||||
|
|
||||||
y_hat.set(i, cls_idx);
|
y_hat.set(i, cls_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(y_hat)
|
Ok(y_hat)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Evaluates the decision function for the rows in `x`
|
||||||
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
|
pub fn decision_function(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
let mut y_hat = M::RowVector::zeros(n);
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
y_hat.set(i, self.predict_for_row(x.get_row(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(y_hat)
|
||||||
|
}
|
||||||
|
|
||||||
fn predict_for_row(&self, x: M::RowVector) -> T {
|
fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||||
let mut f = self.b;
|
let mut f = self.b;
|
||||||
|
|
||||||
@@ -278,11 +297,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
|
|||||||
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
f += self.w[i] * self.kernel.apply(&x, &self.instances[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if f > T::zero() {
|
f
|
||||||
T::one()
|
|
||||||
} else {
|
|
||||||
-T::one()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +385,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
Optimizer {
|
Optimizer {
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
parameters: ¶meters,
|
parameters,
|
||||||
svmin: 0,
|
svmin: 0,
|
||||||
svmax: 0,
|
svmax: 0,
|
||||||
gmin: T::max_value(),
|
gmin: T::max_value(),
|
||||||
@@ -582,7 +597,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
|||||||
for i in 0..self.sv.len() {
|
for i in 0..self.sv.len() {
|
||||||
let v = &self.sv[i];
|
let v = &self.sv[i];
|
||||||
let z = v.grad - gm;
|
let z = v.grad - gm;
|
||||||
let k = cache.get(sv1, &v);
|
let k = cache.get(sv1, v);
|
||||||
let mut curv = km + v.k - T::two() * k;
|
let mut curv = km + v.k - T::two() * k;
|
||||||
if curv <= T::zero() {
|
if curv <= T::zero() {
|
||||||
curv = self.tau;
|
curv = self.tau;
|
||||||
@@ -719,8 +734,10 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::accuracy;
|
use crate::metrics::accuracy;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svc_fit_predict() {
|
fn svc_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -763,6 +780,46 @@ mod tests {
|
|||||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
|
#[test]
|
||||||
|
fn svc_fit_decision_function() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
|
||||||
|
|
||||||
|
let x2 = DenseMatrix::from_2d_array(&[
|
||||||
|
&[3.0, 3.0],
|
||||||
|
&[4.0, 4.0],
|
||||||
|
&[6.0, 6.0],
|
||||||
|
&[10.0, 10.0],
|
||||||
|
&[1.0, 1.0],
|
||||||
|
&[0.0, 0.0],
|
||||||
|
]);
|
||||||
|
|
||||||
|
let y: Vec<f64> = vec![0., 0., 1., 1.];
|
||||||
|
|
||||||
|
let y_hat = SVC::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
SVCParameters::default()
|
||||||
|
.with_c(200.0)
|
||||||
|
.with_kernel(Kernels::linear()),
|
||||||
|
)
|
||||||
|
.and_then(|lr| lr.decision_function(&x2))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// x can be classified by a straight line through [6.0, 0.0] and [0.0, 6.0],
|
||||||
|
// so the score should increase as points get further away from that line
|
||||||
|
println!("{:?}", y_hat);
|
||||||
|
assert!(y_hat[1] < y_hat[2]);
|
||||||
|
assert!(y_hat[2] < y_hat[3]);
|
||||||
|
|
||||||
|
// for negative scores the score should decrease
|
||||||
|
assert!(y_hat[4] > y_hat[5]);
|
||||||
|
|
||||||
|
// y_hat[0] is on the line, so its score should be close to 0
|
||||||
|
assert!(y_hat[0].abs() <= 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svc_fit_predict_rbf() {
|
fn svc_fit_predict_rbf() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -806,7 +863,9 @@ mod tests {
|
|||||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn svc_serde() {
|
fn svc_serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
@@ -835,11 +894,11 @@ mod tests {
|
|||||||
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||||
];
|
];
|
||||||
|
|
||||||
let svr = SVC::fit(&x, &y, Default::default()).unwrap();
|
let svc = SVC::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
let deserialized_svr: SVC<f64, DenseMatrix<f64>, LinearKernel> =
|
let deserialized_svc: SVC<f64, DenseMatrix<f64>, LinearKernel> =
|
||||||
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(svr, deserialized_svr);
|
assert_eq!(svc, deserialized_svc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+20
-9
@@ -68,6 +68,7 @@ use std::cell::{Ref, RefCell};
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
@@ -77,7 +78,8 @@ use crate::linalg::Matrix;
|
|||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// SVR Parameters
|
/// SVR Parameters
|
||||||
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
/// Epsilon in the epsilon-SVR model.
|
/// Epsilon in the epsilon-SVR model.
|
||||||
@@ -92,11 +94,15 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
|||||||
m: PhantomData<M>,
|
m: PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[serde(bound(
|
#[derive(Debug)]
|
||||||
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
#[cfg_attr(
|
||||||
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
feature = "serde",
|
||||||
))]
|
serde(bound(
|
||||||
|
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||||
|
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||||
|
))
|
||||||
|
)]
|
||||||
|
|
||||||
/// Epsilon-Support Vector Regression
|
/// Epsilon-Support Vector Regression
|
||||||
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||||
@@ -106,7 +112,8 @@ pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
|||||||
b: T,
|
b: T,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
struct SupportVector<T: RealNumber, V: BaseVector<T>> {
|
||||||
index: usize,
|
index: usize,
|
||||||
x: V,
|
x: V,
|
||||||
@@ -205,7 +212,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
|
|||||||
|
|
||||||
if n != y.len() {
|
if n != y.len() {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
&"Number of rows of X doesn\'t match number of rows of Y".to_string(),
|
"Number of rows of X doesn\'t match number of rows of Y",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -235,7 +242,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
|
|||||||
Ok(y_hat)
|
Ok(y_hat)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn predict_for_row(&self, x: M::RowVector) -> T {
|
pub(crate) fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||||
let mut f = self.b;
|
let mut f = self.b;
|
||||||
|
|
||||||
for i in 0..self.instances.len() {
|
for i in 0..self.instances.len() {
|
||||||
@@ -526,8 +533,10 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::*;
|
use crate::linalg::naive::dense_matrix::*;
|
||||||
use crate::metrics::mean_squared_error;
|
use crate::metrics::mean_squared_error;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use crate::svm::*;
|
use crate::svm::*;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn svr_fit_predict() {
|
fn svr_fit_predict() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -561,7 +570,9 @@ mod tests {
|
|||||||
assert!(mean_squared_error(&y_hat, &y) < 2.5);
|
assert!(mean_squared_error(&y_hat, &y) < 2.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn svr_serde() {
|
fn svr_serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
|
|||||||
@@ -68,6 +68,8 @@ use std::fmt::Debug;
|
|||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
use rand::Rng;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -76,7 +78,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of Decision Tree
|
/// Parameters of Decision Tree
|
||||||
pub struct DecisionTreeClassifierParameters {
|
pub struct DecisionTreeClassifierParameters {
|
||||||
/// Split criteria to use when building a tree.
|
/// Split criteria to use when building a tree.
|
||||||
@@ -90,7 +93,8 @@ pub struct DecisionTreeClassifierParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Decision Tree
|
/// Decision Tree
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeClassifier<T: RealNumber> {
|
pub struct DecisionTreeClassifier<T: RealNumber> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
@@ -100,7 +104,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The function to measure the quality of a split.
|
/// The function to measure the quality of a split.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
/// [Gini index](../decision_tree_classifier/index.html)
|
/// [Gini index](../decision_tree_classifier/index.html)
|
||||||
Gini,
|
Gini,
|
||||||
@@ -110,9 +115,10 @@ pub enum SplitCriterion {
|
|||||||
ClassificationError,
|
ClassificationError,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<T: RealNumber> {
|
struct Node<T: RealNumber> {
|
||||||
index: usize,
|
_index: usize,
|
||||||
output: usize,
|
output: usize,
|
||||||
split_feature: usize,
|
split_feature: usize,
|
||||||
split_value: Option<T>,
|
split_value: Option<T>,
|
||||||
@@ -198,7 +204,7 @@ impl Default for DecisionTreeClassifierParameters {
|
|||||||
impl<T: RealNumber> Node<T> {
|
impl<T: RealNumber> Node<T> {
|
||||||
fn new(index: usize, output: usize) -> Self {
|
fn new(index: usize, output: usize) -> Self {
|
||||||
Node {
|
Node {
|
||||||
index,
|
_index: index,
|
||||||
output,
|
output,
|
||||||
split_feature: 0,
|
split_feature: 0,
|
||||||
split_value: Option::None,
|
split_value: Option::None,
|
||||||
@@ -279,7 +285,7 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn which_max(x: &[usize]) -> usize {
|
pub(crate) fn which_max(x: &[usize]) -> usize {
|
||||||
let mut m = x[0];
|
let mut m = x[0];
|
||||||
let mut which = 0;
|
let mut which = 0;
|
||||||
|
|
||||||
@@ -323,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeClassifier::fit_weak_learner(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
samples,
|
||||||
|
num_attributes,
|
||||||
|
parameters,
|
||||||
|
&mut rand::thread_rng(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||||
@@ -332,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
@@ -375,17 +389,17 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
depth: 0,
|
depth: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
|
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &yi, 1);
|
||||||
|
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(visitor);
|
visitor_queue.push_back(visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||||
None => break,
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -407,7 +421,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
Ok(result.to_row_vector())
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||||
let mut result = 0;
|
let mut result = 0;
|
||||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||||
|
|
||||||
@@ -438,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
visitor: &mut NodeVisitor<'_, T, M>,
|
visitor: &mut NodeVisitor<'_, T, M>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
@@ -477,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||||
|
|
||||||
if mtry < n_attr {
|
if mtry < n_attr {
|
||||||
variables.shuffle(&mut rand::thread_rng());
|
variables.shuffle(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
for variable in variables.iter().take(mtry) {
|
for variable in variables.iter().take(mtry) {
|
||||||
@@ -499,7 +514,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
visitor: &mut NodeVisitor<'_, T, M>,
|
visitor: &mut NodeVisitor<'_, T, M>,
|
||||||
n: usize,
|
n: usize,
|
||||||
count: &[usize],
|
count: &[usize],
|
||||||
false_count: &mut Vec<usize>,
|
false_count: &mut [usize],
|
||||||
parent_impurity: T,
|
parent_impurity: T,
|
||||||
j: usize,
|
j: usize,
|
||||||
) {
|
) {
|
||||||
@@ -536,7 +551,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
- T::from(tc).unwrap() / T::from(n).unwrap()
|
- T::from(tc).unwrap() / T::from(n).unwrap()
|
||||||
* impurity(&self.parameters.criterion, &true_count, tc)
|
* impurity(&self.parameters.criterion, &true_count, tc)
|
||||||
- T::from(fc).unwrap() / T::from(n).unwrap()
|
- T::from(fc).unwrap() / T::from(n).unwrap()
|
||||||
* impurity(&self.parameters.criterion, &false_count, fc);
|
* impurity(&self.parameters.criterion, false_count, fc);
|
||||||
|
|
||||||
if self.nodes[visitor.node].split_score == Option::None
|
if self.nodes[visitor.node].split_score == Option::None
|
||||||
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
||||||
@@ -561,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
mut visitor: NodeVisitor<'a, T, M>,
|
mut visitor: NodeVisitor<'a, T, M>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let (n, _) = visitor.x.shape();
|
let (n, _) = visitor.x.shape();
|
||||||
let mut tc = 0;
|
let mut tc = 0;
|
||||||
@@ -609,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
visitor.level + 1,
|
visitor.level + 1,
|
||||||
);
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(true_visitor);
|
visitor_queue.push_back(true_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -622,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
|||||||
visitor.level + 1,
|
visitor.level + 1,
|
||||||
);
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(false_visitor);
|
visitor_queue.push_back(false_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,6 +651,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn gini_impurity() {
|
fn gini_impurity() {
|
||||||
assert!(
|
assert!(
|
||||||
@@ -651,6 +668,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_iris() {
|
fn fit_predict_iris() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -703,6 +721,7 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_predict_baloons() {
|
fn fit_predict_baloons() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -739,7 +758,9 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[1., 1., 1., 0.],
|
&[1., 1., 1., 0.],
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ use std::default::Default;
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
|
use rand::Rng;
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
@@ -71,7 +73,8 @@ use crate::error::Failed;
|
|||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::math::num::RealNumber;
|
use crate::math::num::RealNumber;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
/// Parameters of Regression Tree
|
/// Parameters of Regression Tree
|
||||||
pub struct DecisionTreeRegressorParameters {
|
pub struct DecisionTreeRegressorParameters {
|
||||||
/// The maximum depth of the tree.
|
/// The maximum depth of the tree.
|
||||||
@@ -83,16 +86,18 @@ pub struct DecisionTreeRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Regression Tree
|
/// Regression Tree
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeRegressor<T: RealNumber> {
|
pub struct DecisionTreeRegressor<T: RealNumber> {
|
||||||
nodes: Vec<Node<T>>,
|
nodes: Vec<Node<T>>,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
depth: u16,
|
depth: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct Node<T: RealNumber> {
|
struct Node<T: RealNumber> {
|
||||||
index: usize,
|
_index: usize,
|
||||||
output: T,
|
output: T,
|
||||||
split_feature: usize,
|
split_feature: usize,
|
||||||
split_value: Option<T>,
|
split_value: Option<T>,
|
||||||
@@ -132,7 +137,7 @@ impl Default for DecisionTreeRegressorParameters {
|
|||||||
impl<T: RealNumber> Node<T> {
|
impl<T: RealNumber> Node<T> {
|
||||||
fn new(index: usize, output: T) -> Self {
|
fn new(index: usize, output: T) -> Self {
|
||||||
Node {
|
Node {
|
||||||
index,
|
_index: index,
|
||||||
output,
|
output,
|
||||||
split_feature: 0,
|
split_feature: 0,
|
||||||
split_value: Option::None,
|
split_value: Option::None,
|
||||||
@@ -238,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeRegressor::fit_weak_learner(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
samples,
|
||||||
|
num_attributes,
|
||||||
|
parameters,
|
||||||
|
&mut rand::thread_rng(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||||
@@ -247,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
samples: Vec<usize>,
|
samples: Vec<usize>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
|
|
||||||
@@ -276,17 +289,17 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
depth: 0,
|
depth: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
|
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &y_m, 1);
|
||||||
|
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor, mtry) {
|
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(visitor);
|
visitor_queue.push_back(visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||||
match visitor_queue.pop_front() {
|
match visitor_queue.pop_front() {
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue),
|
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||||
None => break,
|
None => break,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -308,7 +321,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
Ok(result.to_row_vector())
|
Ok(result.to_row_vector())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||||
let mut result = T::zero();
|
let mut result = T::zero();
|
||||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||||
|
|
||||||
@@ -339,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
visitor: &mut NodeVisitor<'_, T, M>,
|
visitor: &mut NodeVisitor<'_, T, M>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let (_, n_attr) = visitor.x.shape();
|
let (_, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
@@ -353,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||||
|
|
||||||
if mtry < n_attr {
|
if mtry < n_attr {
|
||||||
variables.shuffle(&mut rand::thread_rng());
|
variables.shuffle(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
let parent_gain =
|
let parent_gain =
|
||||||
@@ -428,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
mut visitor: NodeVisitor<'a, T, M>,
|
mut visitor: NodeVisitor<'a, T, M>,
|
||||||
mtry: usize,
|
mtry: usize,
|
||||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
|
||||||
|
rng: &mut impl Rng,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
let (n, _) = visitor.x.shape();
|
let (n, _) = visitor.x.shape();
|
||||||
let mut tc = 0;
|
let mut tc = 0;
|
||||||
@@ -476,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
visitor.level + 1,
|
visitor.level + 1,
|
||||||
);
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut true_visitor, mtry) {
|
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(true_visitor);
|
visitor_queue.push_back(true_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
|||||||
visitor.level + 1,
|
visitor.level + 1,
|
||||||
);
|
);
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut false_visitor, mtry) {
|
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||||
visitor_queue.push_back(false_visitor);
|
visitor_queue.push_back(false_visitor);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -502,6 +517,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
fn fit_longley() {
|
fn fit_longley() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
@@ -576,7 +592,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
#[test]
|
#[test]
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
|||||||
Reference in New Issue
Block a user