Merge pull request #133 from smartcorelib/release-0.2.1

Release 0.2.1
This commit is contained in:
morenol
2022-05-10 08:57:53 -04:00
committed by GitHub
86 changed files with 2421 additions and 421 deletions
-43
View File
@@ -1,43 +0,0 @@
version: 2.1
workflows:
version: 2.1
build:
jobs:
- build
- clippy
jobs:
build:
docker:
- image: circleci/rust:latest
environment:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- checkout
- restore_cache:
key: project-cache
- run:
name: Check formatting
command: cargo fmt -- --check
- run:
name: Stable Build
command: cargo build --features "nalgebra-bindings ndarray-bindings"
- run:
name: Test
command: cargo test --features "nalgebra-bindings ndarray-bindings"
- save_cache:
key: project-cache
paths:
- "~/.cargo"
- "./target"
clippy:
docker:
- image: circleci/rust:latest
steps:
- checkout
- run:
name: Install cargo clippy
command: rustup component add clippy
- run:
name: Run cargo clippy
command: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
+11
View File
@@ -0,0 +1,11 @@
version: 2
updates:
- package-ecosystem: cargo
directory: "/"
schedule:
interval: daily
open-pull-requests-limit: 10
ignore:
- dependency-name: rand_distr
versions:
- 0.4.0
+57
View File
@@ -0,0 +1,57 @@
name: CI
on:
push:
branches: [ main, development ]
pull_request:
branches: [ development ]
jobs:
tests:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform: [
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- name: Cache .cargo and target
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
target: ${{ matrix.platform.target }}
profile: minimal
default: true
- name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build
uses: actions-rs/cargo@v1
with:
command: build
args: --all-features --target ${{ matrix.platform.target }}
- name: Tests
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
- name: Tests in WASM
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: wasm-pack test --node -- --all-features
+44
View File
@@ -0,0 +1,44 @@
name: Coverage
on:
push:
branches: [ main, development ]
pull_request:
branches: [ development ]
jobs:
coverage:
runs-on: ubuntu-latest
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- name: Cache .cargo
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
profile: minimal
default: true
- name: Install cargo-tarpaulin
uses: actions-rs/install@v0.1
with:
crate: cargo-tarpaulin
version: latest
use-tool-cache: true
- name: Run cargo-tarpaulin
uses: actions-rs/cargo@v1
with:
command: tarpaulin
args: --out Lcov --all-features -- --test-threads 1
- name: Upload to codecov.io
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true
+41
View File
@@ -0,0 +1,41 @@
name: Lint checks
on:
push:
branches: [ main, development ]
pull_request:
branches: [ development ]
jobs:
lint:
runs-on: ubuntu-latest
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
- uses: actions/checkout@v2
- name: Cache .cargo and target
uses: actions/cache@v2
with:
path: |
~/.cargo
./target
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
- name: Install Rust toolchain
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
default: true
- run: rustup component add rustfmt
- name: Check formt
uses: actions-rs/cargo@v1
with:
command: fmt
args: --all -- --check
- run: rustup component add clippy
- name: Run clippy
uses: actions-rs/cargo@v1
with:
command: clippy
args: --all-features -- -Drust-2018-idioms -Dwarnings
+60
View File
@@ -0,0 +1,60 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## Added
- L2 regularization penalty to the Logistic Regression
- Getters for the naive bayes structs
- One hot encoder
- Make moons data generator
- Support for WASM.
## Changed
- Make serde optional
## [0.2.0] - 2021-01-03
### Added
- DBSCAN
- Epsilon-SVR, SVC
- Ridge, Lasso, ElasticNet
- Bernoulli, Gaussian, Categorical and Multinomial Naive Bayes
- K-fold Cross Validation
- Singular value decomposition
- New api module
- Integration with Clippy
- Cholesky decomposition
### Changed
- ndarray upgraded to 0.14
- smartcore::error:FailedError is now non-exhaustive
- K-Means
- PCA
- Random Forest
- Linear and Logistic Regression
- KNN
- Decision Tree
## [0.1.0] - 2020-09-25
### Added
- First release of smartcore.
- KNN + distance metrics (Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis)
- Linear Regression (OLS)
- Logistic Regression
- Random Forest Classifier
- Decision Tree Classifier
- PCA
- K-Means
- Integrated with ndarray
- Abstract linear algebra methods
- RandomForest Regressor
- Decision Tree Regressor
- Serde integration
- Integrated with nalgebra
- LU, QR, SVD, EVD
- Evaluation Metrics
+12 -7
View File
@@ -2,7 +2,7 @@
name = "smartcore" name = "smartcore"
description = "The most advanced machine learning library in rust." description = "The most advanced machine learning library in rust."
homepage = "https://smartcorelib.org" homepage = "https://smartcorelib.org"
version = "0.2.0" version = "0.2.1"
authors = ["SmartCore Developers"] authors = ["SmartCore Developers"]
edition = "2018" edition = "2018"
license = "Apache-2.0" license = "Apache-2.0"
@@ -19,20 +19,25 @@ nalgebra-bindings = ["nalgebra"]
datasets = [] datasets = []
[dependencies] [dependencies]
ndarray = { version = "0.14", optional = true } ndarray = { version = "0.15", optional = true }
nalgebra = { version = "0.23.0", optional = true } nalgebra = { version = "0.23.0", optional = true }
num-traits = "0.2.12" num-traits = "0.2.12"
num = "0.3.0" num = "0.4.0"
rand = "0.7.3" rand = "0.8.3"
rand_distr = "0.3.0" rand_distr = "0.4.0"
serde = { version = "1.0.115", features = ["derive"] } serde = { version = "1.0.115", features = ["derive"], optional = true }
serde_derive = "1.0.115"
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
[dev-dependencies] [dev-dependencies]
criterion = "0.3" criterion = "0.3"
serde_json = "1.0" serde_json = "1.0"
bincode = "1.3.1" bincode = "1.3.1"
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"
[[bench]] [[bench]]
name = "distance" name = "distance"
harness = false harness = false
+3 -3
View File
@@ -9,9 +9,9 @@
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
inkscape:version="1.0 (4035a4f, 2020-05-01)" inkscape:version="1.0 (4035a4f, 2020-05-01)"
sodipodi:docname="smartcore.svg" sodipodi:docname="smartcore.svg"
width="396.01309mm" width="1280"
height="86.286003mm" height="320"
viewBox="0 0 396.0131 86.286004" viewBox="0 0 454 86.286004"
version="1.1" version="1.1"
id="svg512"> id="svg512">
<metadata <metadata

Before

Width:  |  Height:  |  Size: 2.5 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

+1
View File
@@ -314,6 +314,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn bbdtree_iris() { fn bbdtree_iris() {
let data = DenseMatrix::from_2d_array(&[ let data = DenseMatrix::from_2d_array(&[
+26 -18
View File
@@ -24,6 +24,7 @@
//! ``` //! ```
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::heap_select::HeapSelection; use crate::algorithm::sort::heap_select::HeapSelection;
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Implements Cover Tree algorithm /// Implements Cover Tree algorithm
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> { pub struct CoverTree<T, F: RealNumber, D: Distance<T, F>> {
base: F, base: F,
inv_log_base: F, inv_log_base: F,
@@ -56,16 +58,17 @@ impl<T, F: RealNumber, D: Distance<T, F>> PartialEq for CoverTree<T, F, D> {
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<F: RealNumber> { struct Node<F: RealNumber> {
idx: usize, idx: usize,
max_dist: F, max_dist: F,
parent_dist: F, parent_dist: F,
children: Vec<Node<F>>, children: Vec<Node<F>>,
scale: i64, _scale: i64,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug)]
struct DistanceSet<F: RealNumber> { struct DistanceSet<F: RealNumber> {
idx: usize, idx: usize,
dist: Vec<F>, dist: Vec<F>,
@@ -82,7 +85,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
max_dist: F::zero(), max_dist: F::zero(),
parent_dist: F::zero(), parent_dist: F::zero(),
children: Vec::new(), children: Vec::new(),
scale: 0, _scale: 0,
}; };
let mut tree = CoverTree { let mut tree = CoverTree {
base, base,
@@ -114,7 +117,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
} }
let e = self.get_data_value(self.root.idx); let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(&e, p); let mut d = self.distance.distance(e, p);
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new(); let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new(); let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
@@ -172,11 +175,14 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
if ds.0 <= upper_bound { if ds.0 <= upper_bound {
let v = self.get_data_value(ds.1.idx); let v = self.get_data_value(ds.1.idx);
if !self.identical_excluded || v != p { if !self.identical_excluded || v != p {
neighbors.push((ds.1.idx, ds.0, &v)); neighbors.push((ds.1.idx, ds.0, v));
} }
} }
} }
if neighbors.len() > k {
neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
}
Ok(neighbors.into_iter().take(k).collect()) Ok(neighbors.into_iter().take(k).collect())
} }
@@ -197,7 +203,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new(); let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
let e = self.get_data_value(self.root.idx); let e = self.get_data_value(self.root.idx);
let mut d = self.distance.distance(&e, p); let mut d = self.distance.distance(e, p);
current_cover_set.push((d, &self.root)); current_cover_set.push((d, &self.root));
while !current_cover_set.is_empty() { while !current_cover_set.is_empty() {
@@ -227,7 +233,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
for ds in zero_set { for ds in zero_set {
let v = self.get_data_value(ds.1.idx); let v = self.get_data_value(ds.1.idx);
if !self.identical_excluded || v != p { if !self.identical_excluded || v != p {
neighbors.push((ds.1.idx, ds.0, &v)); neighbors.push((ds.1.idx, ds.0, v));
} }
} }
@@ -240,7 +246,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
max_dist: F::zero(), max_dist: F::zero(),
parent_dist: F::zero(), parent_dist: F::zero(),
children: Vec::new(), children: Vec::new(),
scale: 100, _scale: 100,
} }
} }
@@ -284,7 +290,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
if point_set.is_empty() { if point_set.is_empty() {
self.new_leaf(p) self.new_leaf(p)
} else { } else {
let max_dist = self.max(&point_set); let max_dist = self.max(point_set);
let next_scale = (max_scale - 1).min(self.get_scale(max_dist)); let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
if next_scale == std::i64::MIN { if next_scale == std::i64::MIN {
let mut children: Vec<Node<F>> = Vec::new(); let mut children: Vec<Node<F>> = Vec::new();
@@ -301,7 +307,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
max_dist: F::zero(), max_dist: F::zero(),
parent_dist: F::zero(), parent_dist: F::zero(),
children, children,
scale: 100, _scale: 100,
} }
} else { } else {
let mut far: Vec<DistanceSet<F>> = Vec::new(); let mut far: Vec<DistanceSet<F>> = Vec::new();
@@ -313,8 +319,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
point_set.append(&mut far); point_set.append(&mut far);
child child
} else { } else {
let mut children: Vec<Node<F>> = Vec::new(); let mut children: Vec<Node<F>> = vec![child];
children.push(child);
let mut new_point_set: Vec<DistanceSet<F>> = Vec::new(); let mut new_point_set: Vec<DistanceSet<F>> = Vec::new();
let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new(); let mut new_consumed_set: Vec<DistanceSet<F>> = Vec::new();
@@ -371,7 +376,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
max_dist: self.max(consumed_set), max_dist: self.max(consumed_set),
parent_dist: F::zero(), parent_dist: F::zero(),
children, children,
scale: (top_scale - max_scale), _scale: (top_scale - max_scale),
} }
} }
} }
@@ -454,7 +459,8 @@ mod tests {
use super::*; use super::*;
use crate::math::distance::Distances; use crate::math::distance::Distances;
#[derive(Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {} struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance { impl Distance<i32, f64> for SimpleDistance {
@@ -463,6 +469,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cover_tree_test() { fn cover_tree_test() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
@@ -479,7 +486,7 @@ mod tests {
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect(); let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
assert_eq!(vec!(3, 4, 5, 6, 7), knn); assert_eq!(vec!(3, 4, 5, 6, 7), knn);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cover_tree_test1() { fn cover_tree_test1() {
let data = vec![ let data = vec![
@@ -498,8 +505,9 @@ mod tests {
assert_eq!(vec!(0, 1, 2), knn); assert_eq!(vec!(0, 1, 2), knn);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
+9 -5
View File
@@ -22,6 +22,7 @@
//! //!
//! ``` //! ```
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::cmp::{Ordering, PartialOrd}; use std::cmp::{Ordering, PartialOrd};
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -32,7 +33,8 @@ use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Implements Linear Search algorithm, see [KNN algorithms](../index.html) /// Implements Linear Search algorithm, see [KNN algorithms](../index.html)
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> { pub struct LinearKNNSearch<T, F: RealNumber, D: Distance<T, F>> {
distance: D, distance: D,
data: Vec<T>, data: Vec<T>,
@@ -72,7 +74,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
} }
for i in 0..self.data.len() { for i in 0..self.data.len() {
let d = self.distance.distance(&from, &self.data[i]); let d = self.distance.distance(from, &self.data[i]);
let datum = heap.peek_mut(); let datum = heap.peek_mut();
if d < datum.distance { if d < datum.distance {
datum.distance = d; datum.distance = d;
@@ -102,7 +104,7 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
let mut neighbors: Vec<(usize, F, &T)> = Vec::new(); let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
for i in 0..self.data.len() { for i in 0..self.data.len() {
let d = self.distance.distance(&from, &self.data[i]); let d = self.distance.distance(from, &self.data[i]);
if d <= radius { if d <= radius {
neighbors.push((i, d, &self.data[i])); neighbors.push((i, d, &self.data[i]));
@@ -138,7 +140,8 @@ mod tests {
use super::*; use super::*;
use crate::math::distance::Distances; use crate::math::distance::Distances;
#[derive(Debug, Serialize, Deserialize, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct SimpleDistance {} struct SimpleDistance {}
impl Distance<i32, f64> for SimpleDistance { impl Distance<i32, f64> for SimpleDistance {
@@ -147,6 +150,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_find() { fn knn_find() {
let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
@@ -193,7 +197,7 @@ mod tests {
assert_eq!(vec!(1, 2, 3), found_idxs2); assert_eq!(vec!(1, 2, 3), found_idxs2);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_point_eq() { fn knn_point_eq() {
let point1 = KNNPoint { let point1 = KNNPoint {
+5 -2
View File
@@ -35,6 +35,7 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
use crate::error::Failed; use crate::error::Failed;
use crate::math::distance::Distance; use crate::math::distance::Distance;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub(crate) mod bbd_tree; pub(crate) mod bbd_tree;
@@ -45,7 +46,8 @@ pub mod linear_search;
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries. /// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html) /// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum KNNAlgorithmName { pub enum KNNAlgorithmName {
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
LinearSearch, LinearSearch,
@@ -53,7 +55,8 @@ pub enum KNNAlgorithmName {
CoverTree, CoverTree,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> { pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
LinearSearch(LinearKNNSearch<Vec<T>, T, D>), LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
CoverTree(CoverTree<Vec<T>, T, D>), CoverTree(CoverTree<Vec<T>, T, D>),
+6 -2
View File
@@ -53,8 +53,7 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
if self.sorted { if self.sorted {
&self.heap[0] &self.heap[0]
} else { } else {
&self self.heap
.heap
.iter() .iter()
.max_by(|a, b| a.partial_cmp(b).unwrap()) .max_by(|a, b| a.partial_cmp(b).unwrap())
.unwrap() .unwrap()
@@ -96,12 +95,14 @@ impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn with_capacity() { fn with_capacity() {
let heap = HeapSelection::<i32>::with_capacity(3); let heap = HeapSelection::<i32>::with_capacity(3);
assert_eq!(3, heap.k); assert_eq!(3, heap.k);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_add() { fn test_add() {
let mut heap = HeapSelection::with_capacity(3); let mut heap = HeapSelection::with_capacity(3);
@@ -119,6 +120,7 @@ mod tests {
assert_eq!(vec![2, 0, -5], heap.get()); assert_eq!(vec![2, 0, -5], heap.get());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_add1() { fn test_add1() {
let mut heap = HeapSelection::with_capacity(3); let mut heap = HeapSelection::with_capacity(3);
@@ -133,6 +135,7 @@ mod tests {
assert_eq!(vec![0f64, -1f64, -5f64], heap.get()); assert_eq!(vec![0f64, -1f64, -5f64], heap.get());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_add2() { fn test_add2() {
let mut heap = HeapSelection::with_capacity(3); let mut heap = HeapSelection::with_capacity(3);
@@ -145,6 +148,7 @@ mod tests {
assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get()); assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_add_ordered() { fn test_add_ordered() {
let mut heap = HeapSelection::with_capacity(3); let mut heap = HeapSelection::with_capacity(3);
+1
View File
@@ -113,6 +113,7 @@ impl<T: Float> QuickArgSort for Vec<T> {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn with_capacity() { fn with_capacity() {
let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
+9 -3
View File
@@ -43,6 +43,7 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::iter::Sum; use std::iter::Sum;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
@@ -55,7 +56,8 @@ use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::which_max; use crate::tree::decision_tree_classifier::which_max;
/// DBSCAN clustering algorithm /// DBSCAN clustering algorithm
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> { pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
cluster_labels: Vec<i16>, cluster_labels: Vec<i16>,
num_classes: usize, num_classes: usize,
@@ -153,11 +155,11 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
parameters: DBSCANParameters<T, D>, parameters: DBSCANParameters<T, D>,
) -> Result<DBSCAN<T, D>, Failed> { ) -> Result<DBSCAN<T, D>, Failed> {
if parameters.min_samples < 1 { if parameters.min_samples < 1 {
return Err(Failed::fit(&"Invalid minPts".to_string())); return Err(Failed::fit("Invalid minPts"));
} }
if parameters.eps <= T::zero() { if parameters.eps <= T::zero() {
return Err(Failed::fit(&"Invalid radius: ".to_string())); return Err(Failed::fit("Invalid radius: "));
} }
let mut k = 0; let mut k = 0;
@@ -263,8 +265,10 @@ impl<T: RealNumber + Sum, D: Distance<Vec<T>, T>> DBSCAN<T, D> {
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg(feature = "serde")]
use crate::math::distance::euclidian::Euclidian; use crate::math::distance::euclidian::Euclidian;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_dbscan() { fn fit_predict_dbscan() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -296,7 +300,9 @@ mod tests {
assert_eq!(expected_labels, predicted_labels); assert_eq!(expected_labels, predicted_labels);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
+13 -7
View File
@@ -56,6 +56,7 @@ use rand::Rng;
use std::fmt::Debug; use std::fmt::Debug;
use std::iter::Sum; use std::iter::Sum;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::bbd_tree::BBDTree; use crate::algorithm::neighbour::bbd_tree::BBDTree;
@@ -66,12 +67,13 @@ use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// K-Means clustering algorithm /// K-Means clustering algorithm
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KMeans<T: RealNumber> { pub struct KMeans<T: RealNumber> {
k: usize, k: usize,
y: Vec<usize>, _y: Vec<usize>,
size: Vec<usize>, size: Vec<usize>,
distortion: T, _distortion: T,
centroids: Vec<Vec<T>>, centroids: Vec<Vec<T>>,
} }
@@ -206,9 +208,9 @@ impl<T: RealNumber + Sum> KMeans<T> {
Ok(KMeans { Ok(KMeans {
k: parameters.k, k: parameters.k,
y, _y: y,
size, size,
distortion, _distortion: distortion,
centroids, centroids,
}) })
} }
@@ -243,7 +245,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (n, m) = data.shape(); let (n, m) = data.shape();
let mut y = vec![0; n]; let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0, n)); let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
let mut d = vec![T::max_value(); n]; let mut d = vec![T::max_value(); n];
@@ -297,6 +299,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn invalid_k() { fn invalid_k() {
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
@@ -310,6 +313,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -340,11 +344,13 @@ mod tests {
let y = kmeans.predict(&x).unwrap(); let y = kmeans.predict(&x).unwrap();
for i in 0..y.len() { for i in 0..y.len() {
assert_eq!(y[i] as usize, kmeans.y[i]); assert_eq!(y[i] as usize, kmeans._y[i]);
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
+3
View File
@@ -56,9 +56,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*; use super::super::*;
use super::*; use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test] #[test]
#[ignore] #[ignore]
fn refresh_boston_dataset() { fn refresh_boston_dataset() {
@@ -67,6 +69,7 @@ mod tests {
assert!(serialize_data(&dataset, "boston.xy").is_ok()); assert!(serialize_data(&dataset, "boston.xy").is_ok());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn boston_dataset() { fn boston_dataset() {
let dataset = load_dataset(); let dataset = load_dataset();
+3
View File
@@ -66,17 +66,20 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*; use super::super::*;
use super::*; use super::*;
#[test] #[test]
#[ignore] #[ignore]
#[cfg(not(target_arch = "wasm32"))]
fn refresh_cancer_dataset() { fn refresh_cancer_dataset() {
// run this test to generate breast_cancer.xy file. // run this test to generate breast_cancer.xy file.
let dataset = load_dataset(); let dataset = load_dataset();
assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok()); assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cancer_dataset() { fn cancer_dataset() {
let dataset = load_dataset(); let dataset = load_dataset();
+3
View File
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*; use super::super::*;
use super::*; use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test] #[test]
#[ignore] #[ignore]
fn refresh_diabetes_dataset() { fn refresh_diabetes_dataset() {
@@ -61,6 +63,7 @@ mod tests {
assert!(serialize_data(&dataset, "diabetes.xy").is_ok()); assert!(serialize_data(&dataset, "diabetes.xy").is_ok());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn boston_dataset() { fn boston_dataset() {
let dataset = load_dataset(); let dataset = load_dataset();
+3 -1
View File
@@ -45,9 +45,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*; use super::super::*;
use super::*; use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test] #[test]
#[ignore] #[ignore]
fn refresh_digits_dataset() { fn refresh_digits_dataset() {
@@ -55,7 +57,7 @@ mod tests {
let dataset = load_dataset(); let dataset = load_dataset();
assert!(serialize_data(&dataset, "digits.xy").is_ok()); assert!(serialize_data(&dataset, "digits.xy").is_ok());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn digits_dataset() { fn digits_dataset() {
let dataset = load_dataset(); let dataset = load_dataset();
+52
View File
@@ -88,6 +88,43 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset<f32,
} }
} }
/// Make two interleaving half circles in 2d
pub fn make_moons(num_samples: usize, noise: f32) -> Dataset<f32, f32> {
let num_samples_out = num_samples / 2;
let num_samples_in = num_samples - num_samples_out;
let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out);
let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in);
let noise = Normal::new(0.0, noise).unwrap();
let mut rng = rand::thread_rng();
let mut x: Vec<f32> = Vec::with_capacity(num_samples * 2);
let mut y: Vec<f32> = Vec::with_capacity(num_samples);
for v in linspace_out {
x.push(v.cos() + noise.sample(&mut rng));
x.push(v.sin() + noise.sample(&mut rng));
y.push(0.0);
}
for v in linspace_in {
x.push(1.0 - v.cos() + noise.sample(&mut rng));
x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5);
y.push(1.0);
}
Dataset {
data: x,
target: y,
num_samples,
num_features: 2,
feature_names: (0..2).map(|n| n.to_string()).collect(),
target_names: vec!["label".to_string()],
description: "Two interleaving half circles in 2d".to_string(),
}
}
fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> { fn linspace(start: f32, stop: f32, num: usize) -> Vec<f32> {
let div = num as f32; let div = num as f32;
let delta = stop - start; let delta = stop - start;
@@ -100,6 +137,7 @@ mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_make_blobs() { fn test_make_blobs() {
let dataset = make_blobs(10, 2, 3); let dataset = make_blobs(10, 2, 3);
@@ -112,6 +150,7 @@ mod tests {
assert_eq!(dataset.num_samples, 10); assert_eq!(dataset.num_samples, 10);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_make_circles() { fn test_make_circles() {
let dataset = make_circles(10, 0.5, 0.05); let dataset = make_circles(10, 0.5, 0.05);
@@ -123,4 +162,17 @@ mod tests {
assert_eq!(dataset.num_features, 2); assert_eq!(dataset.num_features, 2);
assert_eq!(dataset.num_samples, 10); assert_eq!(dataset.num_samples, 10);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_make_moons() {
let dataset = make_moons(10, 0.05);
assert_eq!(
dataset.data.len(),
dataset.num_features * dataset.num_samples
);
assert_eq!(dataset.target.len(), dataset.num_samples);
assert_eq!(dataset.num_features, 2);
assert_eq!(dataset.num_samples, 10);
}
} }
+3
View File
@@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset<f32, f32> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(not(target_arch = "wasm32"))]
use super::super::*; use super::super::*;
use super::*; use super::*;
#[cfg(not(target_arch = "wasm32"))]
#[test] #[test]
#[ignore] #[ignore]
fn refresh_iris_dataset() { fn refresh_iris_dataset() {
@@ -61,6 +63,7 @@ mod tests {
assert!(serialize_data(&dataset, "iris.xy").is_ok()); assert!(serialize_data(&dataset, "iris.xy").is_ok());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn iris_dataset() { fn iris_dataset() {
let dataset = load_dataset(); let dataset = load_dataset();
+12 -5
View File
@@ -8,9 +8,12 @@ pub mod digits;
pub mod generator; pub mod generator;
pub mod iris; pub mod iris;
#[cfg(not(target_arch = "wasm32"))]
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[cfg(not(target_arch = "wasm32"))]
use std::fs::File; use std::fs::File;
use std::io; use std::io;
#[cfg(not(target_arch = "wasm32"))]
use std::io::prelude::*; use std::io::prelude::*;
/// Dataset /// Dataset
@@ -49,6 +52,8 @@ impl<X, Y> Dataset<X, Y> {
} }
} }
// Running this in wasm throws: operation not supported on this platform.
#[cfg(not(target_arch = "wasm32"))]
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>( pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
dataset: &Dataset<X, Y>, dataset: &Dataset<X, Y>,
@@ -62,14 +67,14 @@ pub(crate) fn serialize_data<X: RealNumber, Y: RealNumber>(
.data .data
.iter() .iter()
.copied() .copied()
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter()) .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
.collect(); .collect();
file.write_all(&x)?; file.write_all(&x)?;
let y: Vec<u8> = dataset let y: Vec<u8> = dataset
.target .target
.iter() .iter()
.copied() .copied()
.flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter()) .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec())
.collect(); .collect();
file.write_all(&y)?; file.write_all(&y)?;
} }
@@ -82,11 +87,12 @@ pub(crate) fn deserialize_data(
bytes: &[u8], bytes: &[u8],
) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> { ) -> Result<(Vec<f32>, Vec<f32>, usize, usize), io::Error> {
// read the same file back into a Vec of bytes // read the same file back into a Vec of bytes
const USIZE_SIZE: usize = std::mem::size_of::<usize>();
let (num_samples, num_features) = { let (num_samples, num_features) = {
let mut buffer = [0u8; 8]; let mut buffer = [0u8; USIZE_SIZE];
buffer.copy_from_slice(&bytes[0..8]); buffer.copy_from_slice(&bytes[0..USIZE_SIZE]);
let num_features = usize::from_le_bytes(buffer); let num_features = usize::from_le_bytes(buffer);
buffer.copy_from_slice(&bytes[8..16]); buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]);
let num_samples = usize::from_le_bytes(buffer); let num_samples = usize::from_le_bytes(buffer);
(num_samples, num_features) (num_samples, num_features)
}; };
@@ -115,6 +121,7 @@ pub(crate) fn deserialize_data(
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn as_matrix() { fn as_matrix() {
let dataset = Dataset { let dataset = Dataset {
+8 -3
View File
@@ -47,6 +47,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator}; use crate::api::{Transformer, UnsupervisedEstimator};
@@ -55,7 +56,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Principal components analysis algorithm /// Principal components analysis algorithm
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct PCA<T: RealNumber, M: Matrix<T>> { pub struct PCA<T: RealNumber, M: Matrix<T>> {
eigenvectors: M, eigenvectors: M,
eigenvalues: Vec<T>, eigenvalues: Vec<T>,
@@ -323,7 +325,7 @@ mod tests {
&[6.8, 161.0, 60.0, 15.6], &[6.8, 161.0, 60.0, 15.6],
]) ])
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn pca_components() { fn pca_components() {
let us_arrests = us_arrests_data(); let us_arrests = us_arrests_data();
@@ -339,7 +341,7 @@ mod tests {
assert!(expected.approximate_eq(&pca.components().abs(), 0.4)); assert!(expected.approximate_eq(&pca.components().abs(), 0.4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_covariance() { fn decompose_covariance() {
let us_arrests = us_arrests_data(); let us_arrests = us_arrests_data();
@@ -449,6 +451,7 @@ mod tests {
.approximate_eq(&expected_projection.abs(), 1e-4)); .approximate_eq(&expected_projection.abs(), 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_correlation() { fn decompose_correlation() {
let us_arrests = us_arrests_data(); let us_arrests = us_arrests_data();
@@ -564,7 +567,9 @@ mod tests {
.approximate_eq(&expected_projection.abs(), 1e-4)); .approximate_eq(&expected_projection.abs(), 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let iris = DenseMatrix::from_2d_array(&[ let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
+6 -1
View File
@@ -46,6 +46,7 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Transformer, UnsupervisedEstimator}; use crate::api::{Transformer, UnsupervisedEstimator};
@@ -54,7 +55,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// SVD /// SVD
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct SVD<T: RealNumber, M: Matrix<T>> { pub struct SVD<T: RealNumber, M: Matrix<T>> {
components: M, components: M,
phantom: PhantomData<T>, phantom: PhantomData<T>,
@@ -151,6 +153,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svd_decompose() { fn svd_decompose() {
// https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html // https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html
@@ -225,7 +228,9 @@ mod tests {
.approximate_eq(&expected, 1e-4)); .approximate_eq(&expected, 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let iris = DenseMatrix::from_2d_array(&[ let iris = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
+137 -11
View File
@@ -45,14 +45,16 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use rand::Rng; #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed; use crate::error::{Failed, FailedError};
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::{ use crate::tree::decision_tree_classifier::{
@@ -61,7 +63,8 @@ use crate::tree::decision_tree_classifier::{
/// Parameters of the Random Forest algorithm. /// Parameters of the Random Forest algorithm.
/// Some parameters here are passed directly into base estimator. /// Some parameters here are passed directly into base estimator.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RandomForestClassifierParameters { pub struct RandomForestClassifierParameters {
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
pub criterion: SplitCriterion, pub criterion: SplitCriterion,
@@ -75,14 +78,20 @@ pub struct RandomForestClassifierParameters {
pub n_trees: u16, pub n_trees: u16,
/// Number of random sample of predictors to use as split candidates. /// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>, pub m: Option<usize>,
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
} }
/// Random Forest Classifier /// Random Forest Classifier
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestClassifier<T: RealNumber> { pub struct RandomForestClassifier<T: RealNumber> {
parameters: RandomForestClassifierParameters, _parameters: RandomForestClassifierParameters,
trees: Vec<DecisionTreeClassifier<T>>, trees: Vec<DecisionTreeClassifier<T>>,
classes: Vec<T>, classes: Vec<T>,
samples: Option<Vec<Vec<bool>>>,
} }
impl RandomForestClassifierParameters { impl RandomForestClassifierParameters {
@@ -116,6 +125,18 @@ impl RandomForestClassifierParameters {
self.m = Some(m); self.m = Some(m);
self self
} }
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
/// Seed used for bootstrap sampling and feature selection for each tree.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
} }
impl<T: RealNumber> PartialEq for RandomForestClassifier<T> { impl<T: RealNumber> PartialEq for RandomForestClassifier<T> {
@@ -147,6 +168,8 @@ impl Default for RandomForestClassifierParameters {
min_samples_split: 2, min_samples_split: 2,
n_trees: 100, n_trees: 100,
m: Option::None, m: Option::None,
keep_samples: false,
seed: 0,
} }
} }
} }
@@ -198,26 +221,38 @@ impl<T: RealNumber> RandomForestClassifier<T> {
.unwrap() .unwrap()
}); });
let mut rng = StdRng::seed_from_u64(parameters.seed);
let classes = y_m.unique(); let classes = y_m.unique();
let k = classes.len(); let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new(); let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k); let samples = RandomForestClassifier::<T>::sample_with_replacement(&yi, k, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeClassifierParameters { let params = DecisionTreeClassifierParameters {
criterion: parameters.criterion.clone(), criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split, min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?; let tree =
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree); trees.push(tree);
} }
Ok(RandomForestClassifier { Ok(RandomForestClassifier {
parameters, _parameters: parameters,
trees, trees,
classes, classes,
samples: maybe_all_samples,
}) })
} }
@@ -245,8 +280,43 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result) which_max(&result)
} }
fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec<usize> { /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
let mut rng = rand::thread_rng(); pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = M::zeros(1, n);
for i in 0..n {
result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]);
}
Ok(result.to_row_vector())
}
}
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
let mut result = vec![0; self.classes.len()];
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
if !samples[row] {
result[tree.predict_for_row(x, row)] += 1;
}
}
which_max(&result)
}
fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
let class_weight = vec![1.; num_classes]; let class_weight = vec![1.; num_classes];
let nrows = y.len(); let nrows = y.len();
let mut samples = vec![0; nrows]; let mut samples = vec![0; nrows];
@@ -262,7 +332,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
let size = ((n_samples as f64) / *class_weight_l) as usize; let size = ((n_samples as f64) / *class_weight_l) as usize;
for _ in 0..size { for _ in 0..size {
let xi: usize = rng.gen_range(0, n_samples); let xi: usize = rng.gen_range(0..n_samples);
samples[index[xi]] += 1; samples[index[xi]] += 1;
} }
} }
@@ -276,6 +346,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::*; use crate::metrics::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -314,6 +385,8 @@ mod tests {
min_samples_split: 2, min_samples_split: 2,
n_trees: 100, n_trees: 100,
m: Option::None, m: Option::None,
keep_samples: false,
seed: 87,
}, },
) )
.unwrap(); .unwrap();
@@ -321,7 +394,60 @@ mod tests {
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris_oob() {
let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: true,
seed: 87,
},
)
.unwrap();
assert!(
accuracy(&y, &classifier.predict_oob(&x).unwrap())
< accuracy(&y, &classifier.predict(&x).unwrap())
);
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
+138 -12
View File
@@ -43,21 +43,24 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use rand::Rng; #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed; use crate::error::{Failed, FailedError};
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::tree::decision_tree_regressor::{ use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters, DecisionTreeRegressor, DecisionTreeRegressorParameters,
}; };
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Random Forest Regressor /// Parameters of the Random Forest Regressor
/// Some parameters here are passed directly into base estimator. /// Some parameters here are passed directly into base estimator.
pub struct RandomForestRegressorParameters { pub struct RandomForestRegressorParameters {
@@ -71,13 +74,19 @@ pub struct RandomForestRegressorParameters {
pub n_trees: usize, pub n_trees: usize,
/// Number of random sample of predictors to use as split candidates. /// Number of random sample of predictors to use as split candidates.
pub m: Option<usize>, pub m: Option<usize>,
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub keep_samples: bool,
/// Seed used for bootstrap sampling and feature selection for each tree.
pub seed: u64,
} }
/// Random Forest Regressor /// Random Forest Regressor
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RandomForestRegressor<T: RealNumber> { pub struct RandomForestRegressor<T: RealNumber> {
parameters: RandomForestRegressorParameters, _parameters: RandomForestRegressorParameters,
trees: Vec<DecisionTreeRegressor<T>>, trees: Vec<DecisionTreeRegressor<T>>,
samples: Option<Vec<Vec<bool>>>,
} }
impl RandomForestRegressorParameters { impl RandomForestRegressorParameters {
@@ -106,8 +115,19 @@ impl RandomForestRegressorParameters {
self.m = Some(m); self.m = Some(m);
self self
} }
}
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
self.keep_samples = keep_samples;
self
}
/// Seed used for bootstrap sampling and feature selection for each tree.
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
}
impl Default for RandomForestRegressorParameters { impl Default for RandomForestRegressorParameters {
fn default() -> Self { fn default() -> Self {
RandomForestRegressorParameters { RandomForestRegressorParameters {
@@ -116,6 +136,8 @@ impl Default for RandomForestRegressorParameters {
min_samples_split: 2, min_samples_split: 2,
n_trees: 10, n_trees: 10,
m: Option::None, m: Option::None,
keep_samples: false,
seed: 0,
} }
} }
} }
@@ -169,20 +191,34 @@ impl<T: RealNumber> RandomForestRegressor<T> {
.m .m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize); .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new(); let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
if parameters.keep_samples {
maybe_all_samples = Some(Vec::new());
}
for _ in 0..parameters.n_trees { for _ in 0..parameters.n_trees {
let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows); let samples = RandomForestRegressor::<T>::sample_with_replacement(n_rows, &mut rng);
if let Some(ref mut all_samples) = maybe_all_samples {
all_samples.push(samples.iter().map(|x| *x != 0).collect())
}
let params = DecisionTreeRegressorParameters { let params = DecisionTreeRegressorParameters {
max_depth: parameters.max_depth, max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf, min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split, min_samples_split: parameters.min_samples_split,
}; };
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; let tree =
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
trees.push(tree); trees.push(tree);
} }
Ok(RandomForestRegressor { parameters, trees }) Ok(RandomForestRegressor {
_parameters: parameters,
trees,
samples: maybe_all_samples,
})
} }
/// Predict class for `x` /// Predict class for `x`
@@ -211,11 +247,49 @@ impl<T: RealNumber> RandomForestRegressor<T> {
result / T::from(n_trees).unwrap() result / T::from(n_trees).unwrap()
} }
fn sample_with_replacement(nrows: usize) -> Vec<usize> { /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
let mut rng = rand::thread_rng(); pub fn predict_oob<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
let (n, _) = x.shape();
if self.samples.is_none() {
Err(Failed::because(
FailedError::PredictFailed,
"Need samples=true for OOB predictions.",
))
} else if self.samples.as_ref().unwrap()[0].len() != n {
Err(Failed::because(
FailedError::PredictFailed,
"Prediction matrix must match matrix used in training for OOB predictions.",
))
} else {
let mut result = M::zeros(1, n);
for i in 0..n {
result.set(0, i, self.predict_for_row_oob(x, i));
}
Ok(result.to_row_vector())
}
}
fn predict_for_row_oob<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
let mut n_trees = 0;
let mut result = T::zero();
for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) {
if !samples[row] {
result += tree.predict_for_row(x, row);
n_trees += 1;
}
}
// TODO: What to do if there are no oob trees?
result / T::from(n_trees).unwrap()
}
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
let mut samples = vec![0; nrows]; let mut samples = vec![0; nrows];
for _ in 0..nrows { for _ in 0..nrows {
let xi = rng.gen_range(0, nrows); let xi = rng.gen_range(0..nrows);
samples[xi] += 1; samples[xi] += 1;
} }
samples samples
@@ -228,6 +302,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_longley() { fn fit_longley() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -262,6 +337,8 @@ mod tests {
min_samples_split: 2, min_samples_split: 2,
n_trees: 1000, n_trees: 1000,
m: Option::None, m: Option::None,
keep_samples: false,
seed: 87,
}, },
) )
.and_then(|rf| rf.predict(&x)) .and_then(|rf| rf.predict(&x))
@@ -270,7 +347,56 @@ mod tests {
assert!(mean_absolute_error(&y, &y_hat) < 1.0); assert!(mean_absolute_error(&y, &y_hat) < 1.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_longley_oob() {
let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159., 107.608, 1947., 60.323],
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
&[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
&[284.599, 335.1, 165., 110.929, 1950., 61.187],
&[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
&[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
&[365.385, 187., 354.7, 115.094, 1953., 64.989],
&[363.112, 357.8, 335., 116.219, 1954., 63.761],
&[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
&[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
&[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
&[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
&[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
]);
let y = vec![
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
114.2, 115.7, 116.9,
];
let regressor = RandomForestRegressor::fit(
&x,
&y,
RandomForestRegressorParameters {
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 1000,
m: Option::None,
keep_samples: true,
seed: 87,
},
)
.unwrap();
let y_hat = regressor.predict(&x).unwrap();
let y_hat_oob = regressor.predict_oob(&x).unwrap();
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob));
}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+5 -2
View File
@@ -2,10 +2,12 @@
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Generic error to be raised when something goes wrong. /// Generic error to be raised when something goes wrong.
#[derive(Debug, Serialize, Deserialize)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Failed { pub struct Failed {
err: FailedError, err: FailedError,
msg: String, msg: String,
@@ -13,7 +15,8 @@ pub struct Failed {
/// Type of error /// Type of error
#[non_exhaustive] #[non_exhaustive]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug)]
pub enum FailedError { pub enum FailedError {
/// Can't fit algorithm to data /// Can't fit algorithm to data
FitFailed = 1, FitFailed = 1,
+7 -3
View File
@@ -1,10 +1,12 @@
#![allow( #![allow(
clippy::type_complexity, clippy::type_complexity,
clippy::too_many_arguments, clippy::too_many_arguments,
clippy::many_single_char_names clippy::many_single_char_names,
clippy::unnecessary_wraps,
clippy::upper_case_acronyms
)] )]
#![warn(missing_docs)] #![warn(missing_docs)]
#![warn(missing_doc_code_examples)] #![warn(rustdoc::missing_doc_code_examples)]
//! # SmartCore //! # SmartCore
//! //!
@@ -28,7 +30,7 @@
//! //!
//! All machine learning algorithms in SmartCore are grouped into these broad categories: //! All machine learning algorithms in SmartCore are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data. //! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Martix Decomposition](decomposition/index.html), various methods for matrix decomposition. //! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables //! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables
//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models //! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models
//! * [Tree-based Models](tree/index.html), classification and regression trees //! * [Tree-based Models](tree/index.html), classification and regression trees
@@ -91,6 +93,8 @@ pub mod naive_bayes;
/// Supervised neighbors-based learning methods /// Supervised neighbors-based learning methods
pub mod neighbors; pub mod neighbors;
pub(crate) mod optimization; pub(crate) mod optimization;
/// Preprocessing utilities
pub mod preprocessing;
/// Support Vector Machines /// Support Vector Machines
pub mod svm; pub mod svm;
/// Supervised tree-based learning methods /// Supervised tree-based learning methods
+5 -5
View File
@@ -87,8 +87,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> Cholesky<T, M> {
if bn != rn { if bn != rn {
return Err(Failed::because( return Err(Failed::because(
FailedError::SolutionFailed, FailedError::SolutionFailed,
&"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R." "Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.",
.to_string(),
)); ));
} }
@@ -128,7 +127,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if m != n { if m != n {
return Err(Failed::because( return Err(Failed::because(
FailedError::DecompositionFailed, FailedError::DecompositionFailed,
&"Can\'t do Cholesky decomposition on a non-square matrix".to_string(), "Can\'t do Cholesky decomposition on a non-square matrix",
)); ));
} }
@@ -148,7 +147,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
if d < T::zero() { if d < T::zero() {
return Err(Failed::because( return Err(Failed::because(
FailedError::DecompositionFailed, FailedError::DecompositionFailed,
&"The matrix is not positive definite.".to_string(), "The matrix is not positive definite.",
)); ));
} }
@@ -168,7 +167,7 @@ pub trait CholeskyDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cholesky_decompose() { fn cholesky_decompose() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
@@ -187,6 +186,7 @@ mod tests {
.approximate_eq(&a.abs(), 1e-4)); .approximate_eq(&a.abs(), 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cholesky_solve_mut() { fn cholesky_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
+11 -16
View File
@@ -93,11 +93,11 @@ pub trait EVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
sort(&mut d, &mut e, &mut V); sort(&mut d, &mut e, &mut V);
} }
Ok(EVD { V, d, e }) Ok(EVD { d, e, V })
} }
} }
fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) { fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for (i, d_i) in d.iter_mut().enumerate().take(n) { for (i, d_i) in d.iter_mut().enumerate().take(n) {
*d_i = V.get(n - 1, i); *d_i = V.get(n - 1, i);
@@ -195,7 +195,7 @@ fn tred2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec
e[0] = T::zero(); e[0] = T::zero();
} }
fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) { fn tql2<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = V.shape(); let (n, _) = V.shape();
for i in 1..n { for i in 1..n {
e[i - 1] = e[i]; e[i - 1] = e[i];
@@ -419,7 +419,7 @@ fn eltran<T: RealNumber, M: BaseMatrix<T>>(A: &M, V: &mut M, perm: &[usize]) {
} }
} }
fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e: &mut Vec<T>) { fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) {
let (n, _) = A.shape(); let (n, _) = A.shape();
let mut z = T::zero(); let mut z = T::zero();
let mut s = T::zero(); let mut s = T::zero();
@@ -471,7 +471,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
A.set(nn, nn, x); A.set(nn, nn, x);
A.set(nn - 1, nn - 1, y + t); A.set(nn - 1, nn - 1, y + t);
if q >= T::zero() { if q >= T::zero() {
z = p + z.copysign(p); z = p + RealNumber::copysign(z, p);
d[nn - 1] = x + z; d[nn - 1] = x + z;
d[nn] = x + z; d[nn] = x + z;
if z != T::zero() { if z != T::zero() {
@@ -570,7 +570,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
r /= x; r /= x;
} }
} }
let s = (p * p + q * q + r * r).sqrt().copysign(p); let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p);
if s != T::zero() { if s != T::zero() {
if k == m { if k == m {
if l != m { if l != m {
@@ -594,12 +594,7 @@ fn hqr2<T: RealNumber, M: BaseMatrix<T>>(A: &mut M, V: &mut M, d: &mut Vec<T>, e
A.sub_element_mut(k + 1, j, p * y); A.sub_element_mut(k + 1, j, p * y);
A.sub_element_mut(k, j, p * x); A.sub_element_mut(k, j, p * x);
} }
let mmin; let mmin = if nn < k + 3 { nn } else { k + 3 };
if nn < k + 3 {
mmin = nn;
} else {
mmin = k + 3;
}
for i in 0..mmin + 1 { for i in 0..mmin + 1 {
p = x * A.get(i, k) + y * A.get(i, k + 1); p = x * A.get(i, k) + y * A.get(i, k + 1);
if k + 1 != nn { if k + 1 != nn {
@@ -783,7 +778,7 @@ fn balbak<T: RealNumber, M: BaseMatrix<T>>(V: &mut M, scale: &[T]) {
} }
} }
fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut M) { fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
let n = d.len(); let n = d.len();
let mut temp = vec![T::zero(); n]; let mut temp = vec![T::zero(); n];
for j in 1..n { for j in 1..n {
@@ -816,7 +811,7 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut Vec<T>, e: &mut Vec<T>, V: &mut
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_symmetric() { fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[ let A = DenseMatrix::from_2d_array(&[
@@ -843,7 +838,7 @@ mod tests {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_asymmetric() { fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[ let A = DenseMatrix::from_2d_array(&[
@@ -870,7 +865,7 @@ mod tests {
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_complex() { fn decompose_complex() {
let A = DenseMatrix::from_2d_array(&[ let A = DenseMatrix::from_2d_array(&[
+5 -4
View File
@@ -46,13 +46,13 @@ use crate::math::num::RealNumber;
pub struct LU<T: RealNumber, M: BaseMatrix<T>> { pub struct LU<T: RealNumber, M: BaseMatrix<T>> {
LU: M, LU: M,
pivot: Vec<usize>, pivot: Vec<usize>,
pivot_sign: i8, _pivot_sign: i8,
singular: bool, singular: bool,
phantom: PhantomData<T>, phantom: PhantomData<T>,
} }
impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> { impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
pub(crate) fn new(LU: M, pivot: Vec<usize>, pivot_sign: i8) -> LU<T, M> { pub(crate) fn new(LU: M, pivot: Vec<usize>, _pivot_sign: i8) -> LU<T, M> {
let (_, n) = LU.shape(); let (_, n) = LU.shape();
let mut singular = false; let mut singular = false;
@@ -66,7 +66,7 @@ impl<T: RealNumber, M: BaseMatrix<T>> LU<T, M> {
LU { LU {
LU, LU,
pivot, pivot,
pivot_sign, _pivot_sign,
singular, singular,
phantom: PhantomData, phantom: PhantomData,
} }
@@ -260,6 +260,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose() { fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
@@ -274,7 +275,7 @@ mod tests {
assert!(lu.U().approximate_eq(&expected_U, 1e-4)); assert!(lu.U().approximate_eq(&expected_U, 1e-4));
assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4)); assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn inverse() { fn inverse() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
+11 -6
View File
@@ -1,3 +1,4 @@
#![allow(clippy::wrong_self_convention)]
//! # Linear Algebra and Matrix Decomposition //! # Linear Algebra and Matrix Decomposition
//! //!
//! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module. //! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module.
@@ -265,7 +266,7 @@ pub trait BaseVector<T: RealNumber>: Clone + Debug {
sum += xi * xi; sum += xi * xi;
} }
mu /= div; mu /= div;
sum / div - mu * mu sum / div - mu.powi(2)
} }
/// Computes the standard deviation. /// Computes the standard deviation.
fn std(&self) -> T { fn std(&self) -> T {
@@ -688,12 +689,11 @@ impl<'a, T: RealNumber, M: BaseMatrix<T>> Iterator for RowIter<'a, T, M> {
type Item = Vec<T>; type Item = Vec<T>;
fn next(&mut self) -> Option<Vec<T>> { fn next(&mut self) -> Option<Vec<T>> {
let res; let res = if self.pos < self.max_pos {
if self.pos < self.max_pos { Some(self.m.get_row_as_vec(self.pos))
res = Some(self.m.get_row_as_vec(self.pos))
} else { } else {
res = None None
} };
self.pos += 1; self.pos += 1;
res res
} }
@@ -705,6 +705,7 @@ mod tests {
use crate::linalg::BaseMatrix; use crate::linalg::BaseMatrix;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mean() { fn mean() {
let m = vec![1., 2., 3.]; let m = vec![1., 2., 3.];
@@ -712,6 +713,7 @@ mod tests {
assert_eq!(m.mean(), 2.0); assert_eq!(m.mean(), 2.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn std() { fn std() {
let m = vec![1., 2., 3.]; let m = vec![1., 2., 3.];
@@ -719,6 +721,7 @@ mod tests {
assert!((m.std() - 0.81f64).abs() < 1e-2); assert!((m.std() - 0.81f64).abs() < 1e-2);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn var() { fn var() {
let m = vec![1., 2., 3., 4.]; let m = vec![1., 2., 3., 4.];
@@ -726,6 +729,7 @@ mod tests {
assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON); assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_take() { fn vec_take() {
let m = vec![1., 2., 3., 4., 5.]; let m = vec![1., 2., 3., 4., 5.];
@@ -733,6 +737,7 @@ mod tests {
assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]); assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn take() { fn take() {
let m = DenseMatrix::from_2d_array(&[ let m = DenseMatrix::from_2d_array(&[
+41 -34
View File
@@ -1,11 +1,15 @@
#![allow(clippy::ptr_arg)] #![allow(clippy::ptr_arg)]
use std::fmt; use std::fmt;
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use std::marker::PhantomData; use std::marker::PhantomData;
use std::ops::Range; use std::ops::Range;
#[cfg(feature = "serde")]
use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor}; use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor};
#[cfg(feature = "serde")]
use serde::ser::{SerializeStruct, Serializer}; use serde::ser::{SerializeStruct, Serializer};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::cholesky::CholeskyDecomposableMatrix; use crate::linalg::cholesky::CholeskyDecomposableMatrix;
@@ -326,7 +330,7 @@ impl<T: RealNumber> DenseMatrix<T> {
cur_r: 0, cur_r: 0,
max_c: self.ncols, max_c: self.ncols,
max_r: self.nrows, max_r: self.nrows,
m: &self, m: self,
} }
} }
} }
@@ -349,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> {
} }
} }
#[cfg(feature = "serde")]
impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> { impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where where
@@ -434,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De
} }
} }
#[cfg(feature = "serde")]
impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> { impl<T: RealNumber + fmt::Debug + Serialize> Serialize for DenseMatrix<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where where
@@ -517,10 +523,9 @@ impl<T: RealNumber> PartialEq for DenseMatrix<T> {
true true
} }
} }
impl<T: RealNumber> From<DenseMatrix<T>> for Vec<T> {
impl<T: RealNumber> Into<Vec<T>> for DenseMatrix<T> { fn from(dense_matrix: DenseMatrix<T>) -> Vec<T> {
fn into(self) -> Vec<T> { dense_matrix.values
self.values
} }
} }
@@ -1054,14 +1059,14 @@ impl<T: RealNumber> BaseMatrix<T> for DenseMatrix<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_dot() { fn vec_dot() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
let v2 = vec![4., 5., 6.]; let v2 = vec![4., 5., 6.];
assert_eq!(32.0, BaseVector::dot(&v1, &v2)); assert_eq!(32.0, BaseVector::dot(&v1, &v2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_copy_from() { fn vec_copy_from() {
let mut v1 = vec![1., 2., 3.]; let mut v1 = vec![1., 2., 3.];
@@ -1069,7 +1074,7 @@ mod tests {
v1.copy_from(&v2); v1.copy_from(&v2);
assert_eq!(v1, v2); assert_eq!(v1, v2);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_approximate_eq() { fn vec_approximate_eq() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
@@ -1077,7 +1082,7 @@ mod tests {
assert!(a.approximate_eq(&b, 1e-4)); assert!(a.approximate_eq(&b, 1e-4));
assert!(!a.approximate_eq(&b, 1e-5)); assert!(!a.approximate_eq(&b, 1e-5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn from_array() { fn from_array() {
let vec = [1., 2., 3., 4., 5., 6.]; let vec = [1., 2., 3., 4., 5., 6.];
@@ -1090,7 +1095,7 @@ mod tests {
DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.]) DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.])
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn row_column_vec_from_array() { fn row_column_vec_from_array() {
let vec = vec![1., 2., 3., 4., 5., 6.]; let vec = vec![1., 2., 3., 4., 5., 6.];
@@ -1103,7 +1108,7 @@ mod tests {
DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.]) DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.])
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn from_to_row_vec() { fn from_to_row_vec() {
let vec = vec![1., 2., 3.]; let vec = vec![1., 2., 3.];
@@ -1116,20 +1121,20 @@ mod tests {
vec![1., 2., 3.] vec![1., 2., 3.]
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_matrix_to_row_vector() { fn col_matrix_to_row_vector() {
let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1); let m: DenseMatrix<f64> = BaseMatrix::zeros(10, 1);
assert_eq!(m.to_row_vector().len(), 10) assert_eq!(m.to_row_vector().len(), 10)
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn iter() { fn iter() {
let vec = vec![1., 2., 3., 4., 5., 6.]; let vec = vec![1., 2., 3., 4., 5., 6.];
let m = DenseMatrix::from_array(3, 2, &vec); let m = DenseMatrix::from_array(3, 2, &vec);
assert_eq!(vec, m.iter().collect::<Vec<f32>>()); assert_eq!(vec, m.iter().collect::<Vec<f32>>());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn v_stack() { fn v_stack() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
@@ -1144,7 +1149,7 @@ mod tests {
let result = a.v_stack(&b); let result = a.v_stack(&b);
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn h_stack() { fn h_stack() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
@@ -1157,13 +1162,13 @@ mod tests {
let result = a.h_stack(&b); let result = a.h_stack(&b);
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_row() { fn get_row() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
assert_eq!(vec![4., 5., 6.], a.get_row(1)); assert_eq!(vec![4., 5., 6.], a.get_row(1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn matmul() { fn matmul() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
@@ -1172,7 +1177,7 @@ mod tests {
let result = a.matmul(&b); let result = a.matmul(&b);
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn ab() { fn ab() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
@@ -1195,14 +1200,14 @@ mod tests {
DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]]) DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]])
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn dot() { fn dot() {
let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]); let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]);
let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]); let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]);
assert_eq!(a.dot(&b), 32.); assert_eq!(a.dot(&b), 32.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn copy_from() { fn copy_from() {
let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]); let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
@@ -1210,7 +1215,7 @@ mod tests {
a.copy_from(&b); a.copy_from(&b);
assert_eq!(a, b); assert_eq!(a, b);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn slice() { fn slice() {
let m = DenseMatrix::from_2d_array(&[ let m = DenseMatrix::from_2d_array(&[
@@ -1222,7 +1227,7 @@ mod tests {
let result = m.slice(0..2, 1..3); let result = m.slice(0..2, 1..3);
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn approximate_eq() { fn approximate_eq() {
let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]); let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]);
@@ -1231,7 +1236,7 @@ mod tests {
assert!(m.approximate_eq(&m_eq, 0.5)); assert!(m.approximate_eq(&m_eq, 0.5));
assert!(!m.approximate_eq(&m_neq, 0.5)); assert!(!m.approximate_eq(&m_neq, 0.5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn rand() { fn rand() {
let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3); let m: DenseMatrix<f64> = DenseMatrix::rand(3, 3);
@@ -1241,7 +1246,7 @@ mod tests {
} }
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn transpose() { fn transpose() {
let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]); let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]);
@@ -1253,7 +1258,7 @@ mod tests {
} }
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]); let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]);
@@ -1264,7 +1269,7 @@ mod tests {
assert_eq!(m_result.get(0, 1), 2.); assert_eq!(m_result.get(0, 1), 2.);
assert_eq!(m_result.get(0, 3), 4.); assert_eq!(m_result.get(0, 3), 4.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn norm() { fn norm() {
let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]); let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]);
@@ -1273,7 +1278,7 @@ mod tests {
assert_eq!(v.norm(std::f64::INFINITY), 6.); assert_eq!(v.norm(std::f64::INFINITY), 6.);
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.); assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn softmax_mut() { fn softmax_mut() {
let mut prob: DenseMatrix<f64> = DenseMatrix::row_vector_from_array(&[1., 2., 3.]); let mut prob: DenseMatrix<f64> = DenseMatrix::row_vector_from_array(&[1., 2., 3.]);
@@ -1282,14 +1287,14 @@ mod tests {
assert!((prob.get(0, 1) - 0.24).abs() < 0.01); assert!((prob.get(0, 1) - 0.24).abs() < 0.01);
assert!((prob.get(0, 2) - 0.66).abs() < 0.01); assert!((prob.get(0, 2) - 0.66).abs() < 0.01);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_mean() { fn col_mean() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
let res = a.column_mean(); let res = a.column_mean();
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn min_max_sum() { fn min_max_sum() {
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
@@ -1297,30 +1302,32 @@ mod tests {
assert_eq!(1., a.min()); assert_eq!(1., a.min());
assert_eq!(6., a.max()); assert_eq!(6., a.max());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn eye() { fn eye() {
let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]); let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]);
let res = DenseMatrix::eye(3); let res = DenseMatrix::eye(3);
assert_eq!(res, a); assert_eq!(res, a);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn to_from_json() { fn to_from_json() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = let deserialized_a: DenseMatrix<f64> =
serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a); assert_eq!(a, deserialized_a);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn to_from_bincode() { fn to_from_bincode() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
let deserialized_a: DenseMatrix<f64> = let deserialized_a: DenseMatrix<f64> =
bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap(); bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap();
assert_eq!(a, deserialized_a); assert_eq!(a, deserialized_a);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn to_string() { fn to_string() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
@@ -1329,7 +1336,7 @@ mod tests {
"[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]" "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]"
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn cov() { fn cov() {
let a = DenseMatrix::from_2d_array(&[ let a = DenseMatrix::from_2d_array(&[
+40
View File
@@ -579,6 +579,7 @@ mod tests {
use crate::linear::linear_regression::*; use crate::linear::linear_regression::*;
use nalgebra::{DMatrix, Matrix2x3, RowDVector}; use nalgebra::{DMatrix, Matrix2x3, RowDVector};
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_copy_from() { fn vec_copy_from() {
let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]); let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]);
@@ -589,12 +590,14 @@ mod tests {
assert_ne!(v2, v1); assert_ne!(v2, v1);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_len() { fn vec_len() {
let v = RowDVector::from_vec(vec![1., 2., 3.]); let v = RowDVector::from_vec(vec![1., 2., 3.]);
assert_eq!(3, v.len()); assert_eq!(3, v.len());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_set_vector() { fn get_set_vector() {
let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]); let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
@@ -607,12 +610,14 @@ mod tests {
assert_eq!(5., BaseVector::get(&v, 1)); assert_eq!(5., BaseVector::get(&v, 1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_to_vec() { fn vec_to_vec() {
let v = RowDVector::from_vec(vec![1., 2., 3.]); let v = RowDVector::from_vec(vec![1., 2., 3.]);
assert_eq!(vec![1., 2., 3.], v.to_vec()); assert_eq!(vec![1., 2., 3.], v.to_vec());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_init() { fn vec_init() {
let zeros: RowDVector<f32> = BaseVector::zeros(3); let zeros: RowDVector<f32> = BaseVector::zeros(3);
@@ -623,6 +628,7 @@ mod tests {
assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.])); assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.]));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_dot() { fn vec_dot() {
let v1 = RowDVector::from_vec(vec![1., 2., 3.]); let v1 = RowDVector::from_vec(vec![1., 2., 3.]);
@@ -630,6 +636,7 @@ mod tests {
assert_eq!(32.0, BaseVector::dot(&v1, &v2)); assert_eq!(32.0, BaseVector::dot(&v1, &v2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_approximate_eq() { fn vec_approximate_eq() {
let a = RowDVector::from_vec(vec![1., 2., 3.]); let a = RowDVector::from_vec(vec![1., 2., 3.]);
@@ -638,6 +645,7 @@ mod tests {
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_set_dynamic() { fn get_set_dynamic() {
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
@@ -650,6 +658,7 @@ mod tests {
assert_eq!(10., BaseMatrix::get(&m, 1, 1)); assert_eq!(10., BaseMatrix::get(&m, 1, 1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn zeros() { fn zeros() {
let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]); let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]);
@@ -659,6 +668,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn ones() { fn ones() {
let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]); let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]);
@@ -668,6 +678,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn eye() { fn eye() {
let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]); let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]);
@@ -675,6 +686,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn shape() { fn shape() {
let m: DMatrix<f64> = BaseMatrix::zeros(5, 10); let m: DMatrix<f64> = BaseMatrix::zeros(5, 10);
@@ -684,6 +696,7 @@ mod tests {
assert_eq!(ncols, 10); assert_eq!(ncols, 10);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn scalar_add_sub_mul_div() { fn scalar_add_sub_mul_div() {
let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
@@ -697,6 +710,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn add_sub_mul_div() { fn add_sub_mul_div() {
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]); let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
@@ -715,6 +729,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn to_from_row_vector() { fn to_from_row_vector() {
let v = RowDVector::from_vec(vec![1., 2., 3., 4.]); let v = RowDVector::from_vec(vec![1., 2., 3., 4.]);
@@ -723,12 +738,14 @@ mod tests {
assert_eq!(m.to_row_vector(), expected); assert_eq!(m.to_row_vector(), expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_matrix_to_row_vector() { fn col_matrix_to_row_vector() {
let m: DMatrix<f64> = BaseMatrix::zeros(10, 1); let m: DMatrix<f64> = BaseMatrix::zeros(10, 1);
assert_eq!(m.to_row_vector().len(), 10) assert_eq!(m.to_row_vector().len(), 10)
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_row_col_as_vec() { fn get_row_col_as_vec() {
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
@@ -737,12 +754,14 @@ mod tests {
assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.)); assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_row() { fn get_row() {
let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1)); assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn copy_row_col_as_vec() { fn copy_row_col_as_vec() {
let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
@@ -754,6 +773,7 @@ mod tests {
assert_eq!(v, vec!(2., 5., 8.)); assert_eq!(v, vec!(2., 5., 8.));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn element_add_sub_mul_div() { fn element_add_sub_mul_div() {
let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]); let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
@@ -767,6 +787,7 @@ mod tests {
assert_eq!(m, expected); assert_eq!(m, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vstack_hstack() { fn vstack_hstack() {
let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
@@ -782,6 +803,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn matmul() { fn matmul() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
@@ -791,6 +813,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn dot() { fn dot() {
let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
@@ -798,6 +821,7 @@ mod tests {
assert_eq!(14., a.dot(&b)); assert_eq!(14., a.dot(&b));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn slice() { fn slice() {
let a = DMatrix::from_row_slice( let a = DMatrix::from_row_slice(
@@ -810,6 +834,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn approximate_eq() { fn approximate_eq() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
@@ -822,6 +847,7 @@ mod tests {
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn negative_mut() { fn negative_mut() {
let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
@@ -829,6 +855,7 @@ mod tests {
assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.])); assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.]));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn transpose() { fn transpose() {
let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]); let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]);
@@ -837,6 +864,7 @@ mod tests {
assert_eq!(m_transposed, expected); assert_eq!(m_transposed, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn rand() { fn rand() {
let m: DMatrix<f64> = BaseMatrix::rand(3, 3); let m: DMatrix<f64> = BaseMatrix::rand(3, 3);
@@ -847,6 +875,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn norm() { fn norm() {
let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]);
@@ -856,6 +885,7 @@ mod tests {
assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.); assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_mean() { fn col_mean() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
@@ -863,6 +893,7 @@ mod tests {
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]); let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]);
@@ -874,6 +905,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.); assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn copy_from() { fn copy_from() {
let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
@@ -882,6 +914,7 @@ mod tests {
assert_eq!(src, dst); assert_eq!(src, dst);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn abs_mut() { fn abs_mut() {
let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]); let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]);
@@ -890,6 +923,7 @@ mod tests {
assert_eq!(a, expected); assert_eq!(a, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn min_max_sum() { fn min_max_sum() {
let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]);
@@ -898,6 +932,7 @@ mod tests {
assert_eq!(6., a.max()); assert_eq!(6., a.max());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn max_diff() { fn max_diff() {
let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]); let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]);
@@ -906,6 +941,7 @@ mod tests {
assert_eq!(a2.max_diff(&a2), 0.); assert_eq!(a2.max_diff(&a2), 0.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn softmax_mut() { fn softmax_mut() {
let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let mut prob: DMatrix<f64> = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
@@ -915,6 +951,7 @@ mod tests {
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn pow_mut() { fn pow_mut() {
let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]);
@@ -922,6 +959,7 @@ mod tests {
assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.])); assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.]));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn argmax() { fn argmax() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]);
@@ -929,6 +967,7 @@ mod tests {
assert_eq!(res, vec![2, 0, 1]); assert_eq!(res, vec![2, 0, 1]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn unique() { fn unique() {
let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]); let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]);
@@ -937,6 +976,7 @@ mod tests {
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn ols_fit_predict() { fn ols_fit_predict() {
let x = DMatrix::from_row_slice( let x = DMatrix::from_row_slice(
+49 -4
View File
@@ -178,7 +178,7 @@ impl<T: RealNumber + ScalarOperand> BaseVector<T> for ArrayBase<OwnedRepr<T>, Ix
} }
fn copy_from(&mut self, other: &Self) { fn copy_from(&mut self, other: &Self) {
self.assign(&other); self.assign(other);
} }
} }
@@ -385,7 +385,7 @@ impl<T: RealNumber + ScalarOperand + AddAssign + SubAssign + MulAssign + DivAssi
} }
fn copy_from(&mut self, other: &Self) { fn copy_from(&mut self, other: &Self) {
self.assign(&other); self.assign(other);
} }
fn abs_mut(&mut self) -> &Self { fn abs_mut(&mut self) -> &Self {
@@ -530,6 +530,7 @@ mod tests {
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
use ndarray::{arr1, arr2, Array1, Array2}; use ndarray::{arr1, arr2, Array1, Array2};
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_get_set() { fn vec_get_set() {
let mut result = arr1(&[1., 2., 3.]); let mut result = arr1(&[1., 2., 3.]);
@@ -541,6 +542,7 @@ mod tests {
assert_eq!(5., BaseVector::get(&result, 1)); assert_eq!(5., BaseVector::get(&result, 1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_copy_from() { fn vec_copy_from() {
let mut v1 = arr1(&[1., 2., 3.]); let mut v1 = arr1(&[1., 2., 3.]);
@@ -551,18 +553,21 @@ mod tests {
assert_ne!(v1, v2); assert_ne!(v1, v2);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_len() { fn vec_len() {
let v = arr1(&[1., 2., 3.]); let v = arr1(&[1., 2., 3.]);
assert_eq!(3, v.len()); assert_eq!(3, v.len());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_to_vec() { fn vec_to_vec() {
let v = arr1(&[1., 2., 3.]); let v = arr1(&[1., 2., 3.]);
assert_eq!(vec![1., 2., 3.], v.to_vec()); assert_eq!(vec![1., 2., 3.], v.to_vec());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_dot() { fn vec_dot() {
let v1 = arr1(&[1., 2., 3.]); let v1 = arr1(&[1., 2., 3.]);
@@ -570,6 +575,7 @@ mod tests {
assert_eq!(32.0, BaseVector::dot(&v1, &v2)); assert_eq!(32.0, BaseVector::dot(&v1, &v2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vec_approximate_eq() { fn vec_approximate_eq() {
let a = arr1(&[1., 2., 3.]); let a = arr1(&[1., 2., 3.]);
@@ -578,6 +584,7 @@ mod tests {
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn from_to_row_vec() { fn from_to_row_vec() {
let vec = arr1(&[1., 2., 3.]); let vec = arr1(&[1., 2., 3.]);
@@ -588,12 +595,14 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_matrix_to_row_vector() { fn col_matrix_to_row_vector() {
let m: Array2<f64> = BaseMatrix::zeros(10, 1); let m: Array2<f64> = BaseMatrix::zeros(10, 1);
assert_eq!(m.to_row_vector().len(), 10) assert_eq!(m.to_row_vector().len(), 10)
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn add_mut() { fn add_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -604,6 +613,7 @@ mod tests {
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn sub_mut() { fn sub_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -614,6 +624,7 @@ mod tests {
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mul_mut() { fn mul_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -624,6 +635,7 @@ mod tests {
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn div_mut() { fn div_mut() {
let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -634,6 +646,7 @@ mod tests {
assert_eq!(a1, a3); assert_eq!(a1, a3);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn div_element_mut() { fn div_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -642,6 +655,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&a, 1, 1), 1.); assert_eq!(BaseMatrix::get(&a, 1, 1), 1.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mul_element_mut() { fn mul_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -650,6 +664,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&a, 1, 1), 25.); assert_eq!(BaseMatrix::get(&a, 1, 1), 25.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn add_element_mut() { fn add_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -657,7 +672,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&a, 1, 1), 10.); assert_eq!(BaseMatrix::get(&a, 1, 1), 10.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn sub_element_mut() { fn sub_element_mut() {
let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -666,6 +681,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&a, 1, 1), 0.); assert_eq!(BaseMatrix::get(&a, 1, 1), 0.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn vstack_hstack() { fn vstack_hstack() {
let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -680,6 +696,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_set() { fn get_set() {
let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -691,6 +708,7 @@ mod tests {
assert_eq!(10., BaseMatrix::get(&result, 1, 1)); assert_eq!(10., BaseMatrix::get(&result, 1, 1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn matmul() { fn matmul() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -700,6 +718,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn dot() { fn dot() {
let a = arr2(&[[1., 2., 3.]]); let a = arr2(&[[1., 2., 3.]]);
@@ -707,6 +726,7 @@ mod tests {
assert_eq!(14., BaseMatrix::dot(&a, &b)); assert_eq!(14., BaseMatrix::dot(&a, &b));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn slice() { fn slice() {
let a = arr2(&[ let a = arr2(&[
@@ -719,6 +739,7 @@ mod tests {
assert_eq!(result, expected); assert_eq!(result, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn scalar_ops() { fn scalar_ops() {
let a = arr2(&[[1., 2., 3.]]); let a = arr2(&[[1., 2., 3.]]);
@@ -728,6 +749,7 @@ mod tests {
assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.)); assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn transpose() { fn transpose() {
let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]); let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]);
@@ -736,6 +758,7 @@ mod tests {
assert_eq!(m_transposed, expected); assert_eq!(m_transposed, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn norm() { fn norm() {
let v = arr2(&[[3., -2., 6.]]); let v = arr2(&[[3., -2., 6.]]);
@@ -745,6 +768,7 @@ mod tests {
assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.); assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn negative_mut() { fn negative_mut() {
let mut v = arr2(&[[3., -2., 6.]]); let mut v = arr2(&[[3., -2., 6.]]);
@@ -752,6 +776,7 @@ mod tests {
assert_eq!(v, arr2(&[[-3., 2., -6.]])); assert_eq!(v, arr2(&[[-3., 2., -6.]]));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn reshape() { fn reshape() {
let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]); let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]);
@@ -763,6 +788,7 @@ mod tests {
assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.); assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn copy_from() { fn copy_from() {
let mut src = arr2(&[[1., 2., 3.]]); let mut src = arr2(&[[1., 2., 3.]]);
@@ -771,6 +797,7 @@ mod tests {
assert_eq!(src, dst); assert_eq!(src, dst);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn min_max_sum() { fn min_max_sum() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]);
@@ -779,6 +806,7 @@ mod tests {
assert_eq!(6., a.max()); assert_eq!(6., a.max());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn max_diff() { fn max_diff() {
let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]); let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]);
@@ -787,6 +815,7 @@ mod tests {
assert_eq!(a2.max_diff(&a2), 0.); assert_eq!(a2.max_diff(&a2), 0.);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn softmax_mut() { fn softmax_mut() {
let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]); let mut prob: Array2<f64> = arr2(&[[1., 2., 3.]]);
@@ -796,6 +825,7 @@ mod tests {
assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn pow_mut() { fn pow_mut() {
let mut a = arr2(&[[1., 2., 3.]]); let mut a = arr2(&[[1., 2., 3.]]);
@@ -803,6 +833,7 @@ mod tests {
assert_eq!(a, arr2(&[[1., 8., 27.]])); assert_eq!(a, arr2(&[[1., 8., 27.]]));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn argmax() { fn argmax() {
let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]); let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]);
@@ -810,6 +841,7 @@ mod tests {
assert_eq!(res, vec![2, 0, 1]); assert_eq!(res, vec![2, 0, 1]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn unique() { fn unique() {
let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]); let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]);
@@ -818,6 +850,7 @@ mod tests {
assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_row_as_vector() { fn get_row_as_vector() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
@@ -825,12 +858,14 @@ mod tests {
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_row() { fn get_row() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1)); assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn get_col_as_vector() { fn get_col_as_vector() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
@@ -838,6 +873,7 @@ mod tests {
assert_eq!(res, vec![2., 5., 8.]); assert_eq!(res, vec![2., 5., 8.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn copy_row_col_as_vec() { fn copy_row_col_as_vec() {
let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
@@ -849,6 +885,7 @@ mod tests {
assert_eq!(v, vec!(2., 5., 8.)); assert_eq!(v, vec!(2., 5., 8.));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn col_mean() { fn col_mean() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
@@ -856,6 +893,7 @@ mod tests {
assert_eq!(res, vec![4., 5., 6.]); assert_eq!(res, vec![4., 5., 6.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn eye() { fn eye() {
let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]); let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]);
@@ -863,6 +901,7 @@ mod tests {
assert_eq!(res, a); assert_eq!(res, a);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn rand() { fn rand() {
let m: Array2<f64> = BaseMatrix::rand(3, 3); let m: Array2<f64> = BaseMatrix::rand(3, 3);
@@ -873,6 +912,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn approximate_eq() { fn approximate_eq() {
let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
@@ -881,6 +921,7 @@ mod tests {
assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); assert!(!a.approximate_eq(&(&noise + &a), 1e-5));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn abs_mut() { fn abs_mut() {
let mut a = arr2(&[[1., -2.], [3., -4.]]); let mut a = arr2(&[[1., -2.], [3., -4.]]);
@@ -889,6 +930,7 @@ mod tests {
assert_eq!(a, expected); assert_eq!(a, expected);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lr_fit_predict_iris() { fn lr_fit_predict_iris() {
let x = arr2(&[ let x = arr2(&[
@@ -924,12 +966,13 @@ mod tests {
let error: f64 = y let error: f64 = y
.into_iter() .into_iter()
.zip(y_hat.into_iter()) .zip(y_hat.into_iter())
.map(|(&a, &b)| (a - b).abs()) .map(|(a, b)| (a - b).abs())
.sum(); .sum();
assert!(error <= 1.0); assert!(error <= 1.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn my_fit_longley_ndarray() { fn my_fit_longley_ndarray() {
let x = arr2(&[ let x = arr2(&[
@@ -964,6 +1007,8 @@ mod tests {
min_samples_split: 2, min_samples_split: 2,
n_trees: 1000, n_trees: 1000,
m: Option::None, m: Option::None,
keep_samples: false,
seed: 0,
}, },
) )
.unwrap() .unwrap()
+2 -1
View File
@@ -195,7 +195,7 @@ pub trait QRDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose() { fn decompose() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
@@ -214,6 +214,7 @@ mod tests {
assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4)); assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn qr_solve_mut() { fn qr_solve_mut() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
+5 -5
View File
@@ -61,7 +61,7 @@ pub trait MatrixStats<T: RealNumber>: BaseMatrix<T> {
sum += a * a; sum += a * a;
} }
mu /= div; mu /= div;
*x_i = sum / div - mu * mu; *x_i = sum / div - mu.powi(2);
} }
x x
@@ -150,7 +150,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mean() { fn mean() {
let m = DenseMatrix::from_2d_array(&[ let m = DenseMatrix::from_2d_array(&[
@@ -164,7 +164,7 @@ mod tests {
assert_eq!(m.mean(0), expected_0); assert_eq!(m.mean(0), expected_0);
assert_eq!(m.mean(1), expected_1); assert_eq!(m.mean(1), expected_1);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn std() { fn std() {
let m = DenseMatrix::from_2d_array(&[ let m = DenseMatrix::from_2d_array(&[
@@ -178,7 +178,7 @@ mod tests {
assert!(m.std(0).approximate_eq(&expected_0, 1e-2)); assert!(m.std(0).approximate_eq(&expected_0, 1e-2));
assert!(m.std(1).approximate_eq(&expected_1, 1e-2)); assert!(m.std(1).approximate_eq(&expected_1, 1e-2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn var() { fn var() {
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]); let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
@@ -188,7 +188,7 @@ mod tests {
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON)); assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON)); assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn scale() { fn scale() {
let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
+10 -9
View File
@@ -47,7 +47,7 @@ pub struct SVD<T: RealNumber, M: SVDDecomposableMatrix<T>> {
pub V: M, pub V: M,
/// Singular values of the original matrix /// Singular values of the original matrix
pub s: Vec<T>, pub s: Vec<T>,
full: bool, _full: bool,
m: usize, m: usize,
n: usize, n: usize,
tol: T, tol: T,
@@ -116,7 +116,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
} }
let mut f = U.get(i, i); let mut f = U.get(i, i);
g = -s.sqrt().copysign(f); g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s; let h = f * g - s;
U.set(i, i, f - g); U.set(i, i, f - g);
for j in l - 1..n { for j in l - 1..n {
@@ -152,7 +152,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
} }
let f = U.get(i, l - 1); let f = U.get(i, l - 1);
g = -s.sqrt().copysign(f); g = -RealNumber::copysign(s.sqrt(), f);
let h = f * g - s; let h = f * g - s;
U.set(i, l - 1, f - g); U.set(i, l - 1, f - g);
@@ -299,7 +299,7 @@ pub trait SVDDecomposableMatrix<T: RealNumber>: BaseMatrix<T> {
let mut h = rv1[k]; let mut h = rv1[k];
let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y); let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y);
g = f.hypot(T::one()); g = f.hypot(T::one());
f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x; f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x;
let mut c = T::one(); let mut c = T::one();
let mut s = T::one(); let mut s = T::one();
@@ -428,13 +428,13 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> { pub(crate) fn new(U: M, V: M, s: Vec<T>) -> SVD<T, M> {
let m = U.shape().0; let m = U.shape().0;
let n = V.shape().0; let n = V.shape().0;
let full = s.len() == m.min(n); let _full = s.len() == m.min(n);
let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon(); let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon();
SVD { SVD {
U, U,
V, V,
s, s,
full, _full,
m, m,
n, n,
tol, tol,
@@ -482,7 +482,7 @@ impl<T: RealNumber, M: SVDDecomposableMatrix<T>> SVD<T, M> {
mod tests { mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_symmetric() { fn decompose_symmetric() {
let A = DenseMatrix::from_2d_array(&[ let A = DenseMatrix::from_2d_array(&[
@@ -513,7 +513,7 @@ mod tests {
assert!((s[i] - svd.s[i]).abs() < 1e-4); assert!((s[i] - svd.s[i]).abs() < 1e-4);
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_asymmetric() { fn decompose_asymmetric() {
let A = DenseMatrix::from_2d_array(&[ let A = DenseMatrix::from_2d_array(&[
@@ -714,7 +714,7 @@ mod tests {
assert!((s[i] - svd.s[i]).abs() < 1e-4); assert!((s[i] - svd.s[i]).abs() < 1e-4);
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn solve() { fn solve() {
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
@@ -725,6 +725,7 @@ mod tests {
assert!(w.approximate_eq(&expected_w, 1e-2)); assert!(w.approximate_eq(&expected_w, 1e-2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn decompose_restore() { fn decompose_restore() {
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]); let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
+1
View File
@@ -126,6 +126,7 @@ mod tests {
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {} impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for BGSolver {}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn bg_solver() { fn bg_solver() {
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
+9 -2
View File
@@ -56,6 +56,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -67,7 +68,8 @@ use crate::math::num::RealNumber;
use crate::linear::lasso_optimizer::InteriorPointOptimizer; use crate::linear::lasso_optimizer::InteriorPointOptimizer;
/// Elastic net parameters /// Elastic net parameters
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetParameters<T: RealNumber> { pub struct ElasticNetParameters<T: RealNumber> {
/// Regularization parameter. /// Regularization parameter.
pub alpha: T, pub alpha: T,
@@ -84,7 +86,8 @@ pub struct ElasticNetParameters<T: RealNumber> {
} }
/// Elastic net /// Elastic net
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ElasticNet<T: RealNumber, M: Matrix<T>> { pub struct ElasticNet<T: RealNumber, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: T, intercept: T,
@@ -288,6 +291,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn elasticnet_longley() { fn elasticnet_longley() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -331,6 +335,7 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 30.0); assert!(mean_absolute_error(&y_hat, &y) < 30.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn elasticnet_fit_predict1() { fn elasticnet_fit_predict1() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -397,7 +402,9 @@ mod tests {
assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0)); assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
+8 -2
View File
@@ -24,6 +24,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -34,7 +35,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Lasso regression parameters /// Lasso regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoParameters<T: RealNumber> { pub struct LassoParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function. /// Controls the strength of the penalty to the loss function.
pub alpha: T, pub alpha: T,
@@ -47,7 +49,8 @@ pub struct LassoParameters<T: RealNumber> {
pub max_iter: usize, pub max_iter: usize,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Lasso regressor /// Lasso regressor
pub struct Lasso<T: RealNumber, M: Matrix<T>> { pub struct Lasso<T: RealNumber, M: Matrix<T>> {
coefficients: M, coefficients: M,
@@ -223,6 +226,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lasso_fit_predict() { fn lasso_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -271,7 +275,9 @@ mod tests {
assert!(mean_absolute_error(&y_hat, &y) < 2.0); assert!(mean_absolute_error(&y_hat, &y) < 2.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
+1 -1
View File
@@ -138,7 +138,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
for i in 0..p { for i in 0..p {
self.prb[i] = T::two() + self.d1[i]; self.prb[i] = T::two() + self.d1[i];
self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i] * self.d2[i]; self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2);
} }
let normg = grad.norm2(); let normg = grad.norm2();
+13 -6
View File
@@ -62,6 +62,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -69,7 +70,8 @@ use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable. /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName { pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html) /// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -79,18 +81,20 @@ pub enum LinearRegressionSolverName {
} }
/// Linear Regression parameters /// Linear Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionParameters { pub struct LinearRegressionParameters {
/// Solver to use for estimation of regression coefficients. /// Solver to use for estimation of regression coefficients.
pub solver: LinearRegressionSolverName, pub solver: LinearRegressionSolverName,
} }
/// Linear Regression /// Linear Regression
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LinearRegression<T: RealNumber, M: Matrix<T>> { pub struct LinearRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: T, intercept: T,
solver: LinearRegressionSolverName, _solver: LinearRegressionSolverName,
} }
impl LinearRegressionParameters { impl LinearRegressionParameters {
@@ -151,7 +155,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
if x_nrows != y_nrows { if x_nrows != y_nrows {
return Err(Failed::fit( return Err(Failed::fit(
&"Number of rows of X doesn\'t match number of rows of Y".to_string(), "Number of rows of X doesn\'t match number of rows of Y",
)); ));
} }
@@ -167,7 +171,7 @@ impl<T: RealNumber, M: Matrix<T>> LinearRegression<T, M> {
Ok(LinearRegression { Ok(LinearRegression {
intercept: w.get(num_attributes, 0), intercept: w.get(num_attributes, 0),
coefficients: wights, coefficients: wights,
solver: parameters.solver, _solver: parameters.solver,
}) })
} }
@@ -196,6 +200,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn ols_fit_predict() { fn ols_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -246,7 +251,9 @@ mod tests {
.all(|(&a, &b)| (a - b).abs() <= 5.0)); .all(|(&a, &b)| (a - b).abs() <= 5.0));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
+153 -19
View File
@@ -54,8 +54,8 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::cmp::Ordering; use std::cmp::Ordering;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -67,12 +67,27 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
pub enum LogisticRegressionSolverName {
/// Limited-memory BroydenFletcherGoldfarbShanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
LBFGS,
}
/// Logistic Regression parameters /// Logistic Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LogisticRegressionParameters {} #[derive(Debug, Clone)]
pub struct LogisticRegressionParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients.
pub solver: LogisticRegressionSolverName,
/// Regularization parameter.
pub alpha: T,
}
/// Logistic Regression /// Logistic Regression
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> { pub struct LogisticRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: M, intercept: M,
@@ -99,12 +114,28 @@ trait ObjectiveFunction<T: RealNumber, M: Matrix<T>> {
struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> { struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize>, y: Vec<usize>,
phantom: PhantomData<&'a T>, alpha: T,
} }
impl Default for LogisticRegressionParameters { impl<T: RealNumber> LogisticRegressionParameters<T> {
/// Solver to use for estimation of regression coefficients.
pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self {
self.solver = solver;
self
}
/// Regularization parameter.
pub fn with_alpha(mut self, alpha: T) -> Self {
self.alpha = alpha;
self
}
}
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
fn default() -> Self { fn default() -> Self {
LogisticRegressionParameters {} LogisticRegressionParameters {
solver: LogisticRegressionSolverName::LBFGS,
alpha: T::zero(),
}
} }
} }
@@ -132,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
{ {
fn f(&self, w_bias: &M) -> T { fn f(&self, w_bias: &M) -> T {
let mut f = T::zero(); let mut f = T::zero();
let (n, _) = self.x.shape(); let (n, p) = self.x.shape();
for i in 0..n { for i in 0..n {
let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i); let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i);
f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx; f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx;
} }
if self.alpha > T::zero() {
let mut w_squared = T::zero();
for i in 0..p {
let w = w_bias.get(0, i);
w_squared += w * w;
}
f += T::half() * self.alpha * w_squared;
}
f f
} }
@@ -156,6 +196,13 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
} }
g.set(0, p, g.get(0, p) - dyi); g.set(0, p, g.get(0, p) - dyi);
} }
if self.alpha > T::zero() {
for i in 0..p {
let w = w_bias.get(0, i);
g.set(0, i, g.get(0, i) + self.alpha * w);
}
}
} }
} }
@@ -163,7 +210,7 @@ struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix<T>> {
x: &'a M, x: &'a M,
y: Vec<usize>, y: Vec<usize>,
k: usize, k: usize,
phantom: PhantomData<&'a T>, alpha: T,
} }
impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M> impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
@@ -185,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
f -= prob.get(0, self.y[i]).ln(); f -= prob.get(0, self.y[i]).ln();
} }
if self.alpha > T::zero() {
let mut w_squared = T::zero();
for i in 0..self.k {
for j in 0..p {
let wi = w_bias.get(0, i * (p + 1) + j);
w_squared += wi * wi;
}
}
f += T::half() * self.alpha * w_squared;
}
f f
} }
@@ -215,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix<T>> ObjectiveFunction<T, M>
g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi); g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi);
} }
} }
if self.alpha > T::zero() {
for i in 0..self.k {
for j in 0..p {
let pos = i * (p + 1);
let wi = w.get(0, pos + j);
g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi);
}
}
}
} }
} }
impl<T: RealNumber, M: Matrix<T>> SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters> impl<T: RealNumber, M: Matrix<T>>
SupervisedEstimator<M, M::RowVector, LogisticRegressionParameters<T>>
for LogisticRegression<T, M> for LogisticRegression<T, M>
{ {
fn fit( fn fit(
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
parameters: LogisticRegressionParameters, parameters: LogisticRegressionParameters<T>,
) -> Result<Self, Failed> { ) -> Result<Self, Failed> {
LogisticRegression::fit(x, y, parameters) LogisticRegression::fit(x, y, parameters)
} }
@@ -244,7 +313,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
pub fn fit( pub fn fit(
x: &M, x: &M,
y: &M::RowVector, y: &M::RowVector,
_parameters: LogisticRegressionParameters, parameters: LogisticRegressionParameters<T>,
) -> Result<LogisticRegression<T, M>, Failed> { ) -> Result<LogisticRegression<T, M>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
@@ -252,7 +321,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
if x_nrows != y_nrows { if x_nrows != y_nrows {
return Err(Failed::fit( return Err(Failed::fit(
&"Number of rows of X doesn\'t match number of rows of Y".to_string(), "Number of rows of X doesn\'t match number of rows of Y",
)); ));
} }
@@ -278,7 +347,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
let objective = BinaryObjectiveFunction { let objective = BinaryObjectiveFunction {
x, x,
y: yi, y: yi,
phantom: PhantomData, alpha: parameters.alpha,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
@@ -300,7 +369,7 @@ impl<T: RealNumber, M: Matrix<T>> LogisticRegression<T, M> {
x, x,
y: yi, y: yi,
k, k,
phantom: PhantomData, alpha: parameters.alpha,
}; };
let result = LogisticRegression::minimize(x0, objective); let result = LogisticRegression::minimize(x0, objective);
@@ -383,6 +452,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::accuracy; use crate::metrics::accuracy;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn multiclass_objective_f() { fn multiclass_objective_f() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -407,9 +477,9 @@ mod tests {
let objective = MultiClassObjectiveFunction { let objective = MultiClassObjectiveFunction {
x: &x, x: &x,
y, y: y.clone(),
k: 3, k: 3,
phantom: PhantomData, alpha: 0.0,
}; };
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 9);
@@ -430,8 +500,27 @@ mod tests {
])); ]));
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON); assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
let objective_reg = MultiClassObjectiveFunction {
x: &x,
y: y.clone(),
k: 3,
alpha: 1.0,
};
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[
1., 2., 3., 4., 5., 6., 7., 8., 9.,
]));
assert!((f - 487.5052).abs() < 1e-4);
objective_reg.df(
&mut g,
&DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]),
);
assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn binary_objective_f() { fn binary_objective_f() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -456,8 +545,8 @@ mod tests {
let objective = BinaryObjectiveFunction { let objective = BinaryObjectiveFunction {
x: &x, x: &x,
y, y: y.clone(),
phantom: PhantomData, alpha: 0.0,
}; };
let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3); let mut g: DenseMatrix<f64> = DenseMatrix::zeros(1, 3);
@@ -472,8 +561,23 @@ mod tests {
let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.])); let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON); assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
let objective_reg = BinaryObjectiveFunction {
x: &x,
y: y.clone(),
alpha: 1.0,
};
let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
assert!((f - 62.2699).abs() < 1e-4);
objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.]));
assert!((g.get(0, 0) - 27.0511).abs() < 1e-4);
assert!((g.get(0, 1) - 12.239).abs() < 1e-4);
assert!((g.get(0, 2) - 3.8693).abs() < 1e-4);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lr_fit_predict() { fn lr_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -511,6 +615,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lr_fit_predict_multiclass() { fn lr_fit_predict_multiclass() {
let blobs = make_blobs(15, 4, 3); let blobs = make_blobs(15, 4, 3);
@@ -523,8 +628,18 @@ mod tests {
let y_hat = lr.predict(&x).unwrap(); let y_hat = lr.predict(&x).unwrap();
assert!(accuracy(&y_hat, &y) > 0.9); assert!(accuracy(&y_hat, &y) > 0.9);
let lr_reg = LogisticRegression::fit(
&x,
&y,
LogisticRegressionParameters::default().with_alpha(10.0),
)
.unwrap();
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lr_fit_predict_binary() { fn lr_fit_predict_binary() {
let blobs = make_blobs(20, 4, 2); let blobs = make_blobs(20, 4, 2);
@@ -537,9 +652,20 @@ mod tests {
let y_hat = lr.predict(&x).unwrap(); let y_hat = lr.predict(&x).unwrap();
assert!(accuracy(&y_hat, &y) > 0.9); assert!(accuracy(&y_hat, &y) > 0.9);
let lr_reg = LogisticRegression::fit(
&x,
&y,
LogisticRegressionParameters::default().with_alpha(10.0),
)
.unwrap();
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[1., -5.], &[1., -5.],
@@ -568,6 +694,7 @@ mod tests {
assert_eq!(lr, deserialized_lr); assert_eq!(lr, deserialized_lr);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lr_fit_predict_iris() { fn lr_fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -597,6 +724,12 @@ mod tests {
]; ];
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap(); let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
let lr_reg = LogisticRegression::fit(
&x,
&y,
LogisticRegressionParameters::default().with_alpha(1.0),
)
.unwrap();
let y_hat = lr.predict(&x).unwrap(); let y_hat = lr.predict(&x).unwrap();
@@ -607,5 +740,6 @@ mod tests {
.sum(); .sum();
assert!(error <= 1.0); assert!(error <= 1.0);
assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum());
} }
} }
+12 -5
View File
@@ -58,6 +58,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::fmt::Debug; use std::fmt::Debug;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -66,7 +67,8 @@ use crate::linalg::BaseVector;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable. /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
pub enum RidgeRegressionSolverName { pub enum RidgeRegressionSolverName {
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html) /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
@@ -76,7 +78,8 @@ pub enum RidgeRegressionSolverName {
} }
/// Ridge Regression parameters /// Ridge Regression parameters
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RidgeRegressionParameters<T: RealNumber> { pub struct RidgeRegressionParameters<T: RealNumber> {
/// Solver to use for estimation of regression coefficients. /// Solver to use for estimation of regression coefficients.
pub solver: RidgeRegressionSolverName, pub solver: RidgeRegressionSolverName,
@@ -88,11 +91,12 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
} }
/// Ridge regression /// Ridge regression
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> { pub struct RidgeRegression<T: RealNumber, M: Matrix<T>> {
coefficients: M, coefficients: M,
intercept: T, intercept: T,
solver: RidgeRegressionSolverName, _solver: RidgeRegressionSolverName,
} }
impl<T: RealNumber> RidgeRegressionParameters<T> { impl<T: RealNumber> RidgeRegressionParameters<T> {
@@ -222,7 +226,7 @@ impl<T: RealNumber, M: Matrix<T>> RidgeRegression<T, M> {
Ok(RidgeRegression { Ok(RidgeRegression {
intercept: b, intercept: b,
coefficients: w, coefficients: w,
solver: parameters.solver, _solver: parameters.solver,
}) })
} }
@@ -270,6 +274,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error; use crate::metrics::mean_absolute_error;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn ridge_fit_predict() { fn ridge_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -325,7 +330,9 @@ mod tests {
assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0); assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
+4 -1
View File
@@ -18,6 +18,7 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -25,7 +26,8 @@ use crate::math::num::RealNumber;
use super::Distance; use super::Distance;
/// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space. /// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Euclidian {} pub struct Euclidian {}
impl Euclidian { impl Euclidian {
@@ -55,6 +57,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Euclidian {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn squared_distance() { fn squared_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
+4 -1
View File
@@ -19,6 +19,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -26,7 +27,8 @@ use crate::math::num::RealNumber;
use super::Distance; use super::Distance;
/// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different /// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Hamming {} pub struct Hamming {}
impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming { impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
@@ -50,6 +52,7 @@ impl<T: PartialEq, F: RealNumber> Distance<Vec<T>, F> for Hamming {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn hamming_distance() { fn hamming_distance() {
let a = vec![1, 0, 0, 1, 0, 0, 1]; let a = vec![1, 0, 0, 1, 0, 0, 1];
+4 -1
View File
@@ -44,6 +44,7 @@
use std::marker::PhantomData; use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -52,7 +53,8 @@ use super::Distance;
use crate::linalg::Matrix; use crate::linalg::Matrix;
/// Mahalanobis distance. /// Mahalanobis distance.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> { pub struct Mahalanobis<T: RealNumber, M: Matrix<T>> {
/// covariance matrix of the dataset /// covariance matrix of the dataset
pub sigma: M, pub sigma: M,
@@ -131,6 +133,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mahalanobis_distance() { fn mahalanobis_distance() {
let data = DenseMatrix::from_2d_array(&[ let data = DenseMatrix::from_2d_array(&[
+4 -1
View File
@@ -17,6 +17,7 @@
//! ``` //! ```
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -24,7 +25,8 @@ use crate::math::num::RealNumber;
use super::Distance; use super::Distance;
/// Manhattan distance /// Manhattan distance
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Manhattan {} pub struct Manhattan {}
impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan { impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
@@ -46,6 +48,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Manhattan {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn manhattan_distance() { fn manhattan_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
+4 -1
View File
@@ -21,6 +21,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
@@ -28,7 +29,8 @@ use crate::math::num::RealNumber;
use super::Distance; use super::Distance;
/// Defines the Minkowski distance of order `p` /// Defines the Minkowski distance of order `p`
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Minkowski { pub struct Minkowski {
/// order, integer /// order, integer
pub p: u16, pub p: u16,
@@ -59,6 +61,7 @@ impl<T: RealNumber> Distance<Vec<T>, T> for Minkowski {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn minkowski_distance() { fn minkowski_distance() {
let a = vec![1., 2., 3.]; let a = vec![1., 2., 3.];
+1
View File
@@ -136,6 +136,7 @@ impl RealNumber for f32 {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn sigmoid() { fn sigmoid() {
assert_eq!(1.0.sigmoid(), 0.7310585786300049); assert_eq!(1.0.sigmoid(), 0.7310585786300049);
+1
View File
@@ -30,6 +30,7 @@ impl<T: RealNumber, V: BaseVector<T>> RealNumberVector<T> for V {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn unique_with_indices() { fn unique_with_indices() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
+4 -1
View File
@@ -16,13 +16,15 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Accuracy metric. /// Accuracy metric.
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Accuracy {} pub struct Accuracy {}
impl Accuracy { impl Accuracy {
@@ -55,6 +57,7 @@ impl Accuracy {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn accuracy() { fn accuracy() {
let y_pred: Vec<f64> = vec![0., 2., 1., 3.]; let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
+4 -1
View File
@@ -20,6 +20,7 @@
//! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html) //! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html)
#![allow(non_snake_case)] #![allow(non_snake_case)]
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -27,7 +28,8 @@ use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Area Under the Receiver Operating Characteristic Curve (ROC AUC) /// Area Under the Receiver Operating Characteristic Curve (ROC AUC)
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct AUC {} pub struct AUC {}
impl AUC { impl AUC {
@@ -91,6 +93,7 @@ impl AUC {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn auc() { fn auc() {
let y_true: Vec<f64> = vec![0., 0., 1., 1.]; let y_true: Vec<f64> = vec![0., 0., 1., 1.];
+4 -1
View File
@@ -1,10 +1,12 @@
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::metrics::cluster_helpers::*; use crate::metrics::cluster_helpers::*;
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Homogeneity, completeness and V-Measure scores. /// Homogeneity, completeness and V-Measure scores.
pub struct HCVScore {} pub struct HCVScore {}
@@ -41,6 +43,7 @@ impl HCVScore {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn homogeneity_score() { fn homogeneity_score() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
+3
View File
@@ -101,6 +101,7 @@ pub fn mutual_info_score<T: RealNumber>(contingency: &[Vec<usize>]) -> T {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn contingency_matrix_test() { fn contingency_matrix_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
@@ -112,6 +113,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn entropy_test() { fn entropy_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
@@ -119,6 +121,7 @@ mod tests {
assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4); assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mutual_info_score_test() { fn mutual_info_score_test() {
let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0];
+4 -1
View File
@@ -18,6 +18,7 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
@@ -26,7 +27,8 @@ use crate::metrics::precision::Precision;
use crate::metrics::recall::Recall; use crate::metrics::recall::Recall;
/// F-measure /// F-measure
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct F1<T: RealNumber> { pub struct F1<T: RealNumber> {
/// a positive real factor /// a positive real factor
pub beta: T, pub beta: T,
@@ -57,6 +59,7 @@ impl<T: RealNumber> F1<T> {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn f1() { fn f1() {
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.]; let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
+4 -1
View File
@@ -18,12 +18,14 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Absolute Error /// Mean Absolute Error
pub struct MeanAbsoluteError {} pub struct MeanAbsoluteError {}
@@ -54,6 +56,7 @@ impl MeanAbsoluteError {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mean_absolute_error() { fn mean_absolute_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.]; let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+4 -1
View File
@@ -18,12 +18,14 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
/// Mean Squared Error /// Mean Squared Error
pub struct MeanSquareError {} pub struct MeanSquareError {}
@@ -54,6 +56,7 @@ impl MeanSquareError {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn mean_squared_error() { fn mean_squared_error() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.]; let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+4 -1
View File
@@ -18,13 +18,15 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Precision metric. /// Precision metric.
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Precision {} pub struct Precision {}
impl Precision { impl Precision {
@@ -75,6 +77,7 @@ impl Precision {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn precision() { fn precision() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.]; let y_true: Vec<f64> = vec![0., 1., 1., 0.];
+4 -1
View File
@@ -18,13 +18,15 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Coefficient of Determination (R2) /// Coefficient of Determination (R2)
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct R2 {} pub struct R2 {}
impl R2 { impl R2 {
@@ -68,6 +70,7 @@ impl R2 {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn r2() { fn r2() {
let y_true: Vec<f64> = vec![3., -0.5, 2., 7.]; let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
+4 -1
View File
@@ -18,13 +18,15 @@
//! //!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
/// Recall metric. /// Recall metric.
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Recall {} pub struct Recall {}
impl Recall { impl Recall {
@@ -75,6 +77,7 @@ impl Recall {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn recall() { fn recall() {
let y_true: Vec<f64> = vec![0., 1., 1., 0.]; let y_true: Vec<f64> = vec![0., 1., 1., 0.];
+7
View File
@@ -144,6 +144,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_kfold_return_test_indices_simple() { fn run_kfold_return_test_indices_simple() {
let k = KFold { let k = KFold {
@@ -158,6 +159,7 @@ mod tests {
assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>()); assert_eq!(test_indices[2], (22..33).collect::<Vec<usize>>());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_kfold_return_test_indices_odd() { fn run_kfold_return_test_indices_odd() {
let k = KFold { let k = KFold {
@@ -172,6 +174,7 @@ mod tests {
assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>()); assert_eq!(test_indices[2], (23..34).collect::<Vec<usize>>());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_kfold_return_test_mask_simple() { fn run_kfold_return_test_mask_simple() {
let k = KFold { let k = KFold {
@@ -197,6 +200,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_kfold_return_split_simple() { fn run_kfold_return_split_simple() {
let k = KFold { let k = KFold {
@@ -212,6 +216,7 @@ mod tests {
assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>()); assert_eq!(train_test_splits[1].1, (11..22).collect::<Vec<usize>>());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_kfold_return_split_simple_shuffle() { fn run_kfold_return_split_simple_shuffle() {
let k = KFold { let k = KFold {
@@ -227,6 +232,7 @@ mod tests {
assert_eq!(train_test_splits[1].1.len(), 11_usize); assert_eq!(train_test_splits[1].1.len(), 11_usize);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn numpy_parity_test() { fn numpy_parity_test() {
let k = KFold { let k = KFold {
@@ -247,6 +253,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn numpy_parity_test_shuffle() { fn numpy_parity_test_shuffle() {
let k = KFold { let k = KFold {
+4
View File
@@ -285,6 +285,7 @@ mod tests {
use crate::model_selection::kfold::KFold; use crate::model_selection::kfold::KFold;
use crate::neighbors::knn_regressor::KNNRegressor; use crate::neighbors::knn_regressor::KNNRegressor;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_train_test_split() { fn run_train_test_split() {
let n = 123; let n = 123;
@@ -308,6 +309,7 @@ mod tests {
#[derive(Clone)] #[derive(Clone)]
struct NoParameters {} struct NoParameters {}
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_cross_validate_biased() { fn test_cross_validate_biased() {
struct BiasedEstimator {} struct BiasedEstimator {}
@@ -367,6 +369,7 @@ mod tests {
assert_eq!(0.4, results.mean_train_score()); assert_eq!(0.4, results.mean_train_score());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_cross_validate_knn() { fn test_cross_validate_knn() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -411,6 +414,7 @@ mod tests {
assert!(results.mean_train_score() < results.mean_test_score()); assert!(results.mean_train_score() < results.mean_test_score());
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn test_cross_val_predict_knn() { fn test_cross_val_predict_knn() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
+139 -20
View File
@@ -42,15 +42,49 @@ use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector; use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for Bearnoulli features /// Naive Bayes classifier for Bearnoulli features
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct BernoulliNBDistribution<T: RealNumber> { struct BernoulliNBDistribution<T: RealNumber> {
/// class labels known to the classifier /// class labels known to the classifier
class_labels: Vec<T>, class_labels: Vec<T>,
/// number of training samples observed in each class
class_count: Vec<usize>,
/// probability of each class
class_priors: Vec<T>, class_priors: Vec<T>,
feature_prob: Vec<Vec<T>>, /// Number of samples encountered for each (class, feature)
feature_count: Vec<Vec<usize>>,
/// probability of features per class
feature_log_prob: Vec<Vec<T>>,
/// Number of features of each sample
n_features: usize,
}
impl<T: RealNumber> PartialEq for BernoulliNBDistribution<T> {
fn eq(&self, other: &Self) -> bool {
if self.class_labels == other.class_labels
&& self.class_count == other.class_count
&& self.class_priors == other.class_priors
&& self.feature_count == other.feature_count
&& self.n_features == other.n_features
{
for (a, b) in self
.feature_log_prob
.iter()
.zip(other.feature_log_prob.iter())
{
if !a.approximate_eq(b, T::epsilon()) {
return false;
}
}
true
} else {
false
}
}
} }
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> { impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistribution<T> {
@@ -63,9 +97,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
for feature in 0..j.len() { for feature in 0..j.len() {
let value = j.get(feature); let value = j.get(feature);
if value == T::one() { if value == T::one() {
likelihood += self.feature_prob[class_index][feature].ln(); likelihood += self.feature_log_prob[class_index][feature];
} else { } else {
likelihood += (T::one() - self.feature_prob[class_index][feature]).ln(); likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln();
} }
} }
likelihood likelihood
@@ -77,7 +111,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
} }
/// `BernoulliNB` parameters. Use `Default::default()` for default values. /// `BernoulliNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct BernoulliNBParameters<T: RealNumber> { pub struct BernoulliNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T, pub alpha: T,
@@ -154,10 +189,10 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
let y = y.to_vec(); let y = y.to_vec();
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y); let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
let mut class_count = vec![T::zero(); class_labels.len()]; let mut class_count = vec![0_usize; class_labels.len()];
for class_index in indices.iter() { for class_index in indices.iter() {
class_count[*class_index] += T::one(); class_count[*class_index] += 1;
} }
let class_priors = if let Some(class_priors) = priors { let class_priors = if let Some(class_priors) = priors {
@@ -170,25 +205,35 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
} else { } else {
class_count class_count
.iter() .iter()
.map(|&c| c / T::from(n_samples).unwrap()) .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
.collect() .collect()
}; };
let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()]; let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
for (row, class_index) in row_iter(x).zip(indices) { for (row, class_index) in row_iter(x).zip(indices) {
for (idx, row_i) in row.iter().enumerate().take(n_features) { for (idx, row_i) in row.iter().enumerate().take(n_features) {
feature_in_class_counter[class_index][idx] += *row_i; feature_in_class_counter[class_index][idx] +=
row_i.to_usize().ok_or_else(|| {
Failed::fit(&format!(
"Elements of the matrix should be 1.0 or 0.0 |found|=[{}]",
row_i
))
})?;
} }
} }
let feature_prob = feature_in_class_counter let feature_log_prob = feature_in_class_counter
.iter() .iter()
.enumerate() .enumerate()
.map(|(class_index, feature_count)| { .map(|(class_index, feature_count)| {
feature_count feature_count
.iter() .iter()
.map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two())) .map(|&count| {
((T::from(count).unwrap() + alpha)
/ (T::from(class_count[class_index]).unwrap() + alpha * T::two()))
.ln()
})
.collect() .collect()
}) })
.collect(); .collect();
@@ -196,13 +241,18 @@ impl<T: RealNumber> BernoulliNBDistribution<T> {
Ok(Self { Ok(Self {
class_labels, class_labels,
class_priors, class_priors,
feature_prob, class_count,
feature_count: feature_in_class_counter,
feature_log_prob,
n_features,
}) })
} }
} }
/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data. /// BernoulliNB implements the naive Bayes algorithm for data that follows the Bernoulli
#[derive(Serialize, Deserialize, Debug, PartialEq)] /// distribution.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> { pub struct BernoulliNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>, inner: BaseNaiveBayes<T, M, BernoulliNBDistribution<T>>,
binarize: Option<T>, binarize: Option<T>,
@@ -262,6 +312,34 @@ impl<T: RealNumber, M: Matrix<T>> BernoulliNB<T, M> {
self.inner.predict(x) self.inner.predict(x)
} }
} }
/// Class labels known to the classifier.
/// Returns a vector of size n_classes.
pub fn classes(&self) -> &Vec<T> {
&self.inner.distribution.class_labels
}
/// Number of training samples observed in each class.
/// Returns a vector of size n_classes.
pub fn class_count(&self) -> &Vec<usize> {
&self.inner.distribution.class_count
}
/// Number of features of each sample
pub fn n_features(&self) -> usize {
self.inner.distribution.n_features
}
/// Number of samples encountered for each (class, feature)
/// Returns a 2d vector of shape (n_classes, n_features)
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
&self.inner.distribution.feature_count
}
/// Empirical log probability of features given a class
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
&self.inner.distribution.feature_log_prob
}
} }
#[cfg(test)] #[cfg(test)]
@@ -269,6 +347,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_bernoulli_naive_bayes() { fn run_bernoulli_naive_bayes() {
// Tests that BernoulliNB when alpha=1.0 gives the same values as // Tests that BernoulliNB when alpha=1.0 gives the same values as
@@ -292,10 +371,24 @@ mod tests {
assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]); assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]);
assert_eq!( assert_eq!(
bnb.inner.distribution.feature_prob, bnb.feature_log_prob(),
&[ &[
&[0.4, 0.8, 0.2, 0.4, 0.4, 0.2], &[
&[1. / 3.0, 2. / 3.0, 2. / 3.0, 1. / 3.0, 1. / 3.0, 2. / 3.0] -0.916290731874155,
-0.2231435513142097,
-1.6094379124341003,
-0.916290731874155,
-0.916290731874155,
-1.6094379124341003
],
&[
-1.0986122886681098,
-0.40546510810816444,
-0.40546510810816444,
-1.0986122886681098,
-1.0986122886681098,
-0.40546510810816444
]
] ]
); );
@@ -307,6 +400,7 @@ mod tests {
assert_eq!(y_hat, &[1.]); assert_eq!(y_hat, &[1.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn bernoulli_nb_scikit_parity() { fn bernoulli_nb_scikit_parity() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
@@ -331,13 +425,36 @@ mod tests {
let y_hat = bnb.predict(&x).unwrap(); let y_hat = bnb.predict(&x).unwrap();
assert_eq!(bnb.classes(), &[0., 1., 2.]);
assert_eq!(bnb.class_count(), &[7, 3, 5]);
assert_eq!(bnb.n_features(), 10);
assert_eq!(
bnb.feature_count(),
&[
&[5, 6, 6, 7, 6, 4, 6, 7, 7, 7],
&[3, 3, 3, 1, 3, 2, 3, 2, 2, 3],
&[4, 4, 3, 4, 5, 2, 4, 5, 3, 4]
]
);
assert!(bnb assert!(bnb
.inner .inner
.distribution .distribution
.class_priors .class_priors
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2)); .approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
assert!(bnb.inner.distribution.feature_prob[1].approximate_eq( assert!(bnb.feature_log_prob()[1].approximate_eq(
&vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8), &vec![
-0.22314355,
-0.22314355,
-0.22314355,
-0.91629073,
-0.22314355,
-0.51082562,
-0.22314355,
-0.51082562,
-0.51082562,
-0.22314355
],
1e-1 1e-1
)); ));
assert!(y_hat.approximate_eq( assert!(y_hat.approximate_eq(
@@ -346,7 +463,9 @@ mod tests {
)); ));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
&[1., 1., 0., 0., 0., 0.], &[1., 1., 0., 0., 0., 0.],
+144 -25
View File
@@ -36,19 +36,38 @@ use crate::linalg::BaseVector;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for categorical features /// Naive Bayes classifier for categorical features
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct CategoricalNBDistribution<T: RealNumber> { struct CategoricalNBDistribution<T: RealNumber> {
/// number of training samples observed in each class
class_count: Vec<usize>,
/// class labels known to the classifier
class_labels: Vec<T>, class_labels: Vec<T>,
/// probability of each class
class_priors: Vec<T>, class_priors: Vec<T>,
coefficients: Vec<Vec<Vec<T>>>, coefficients: Vec<Vec<Vec<T>>>,
/// Number of features of each sample
n_features: usize,
/// Number of categories for each feature
n_categories: Vec<usize>,
/// Holds arrays of shape (n_classes, n_categories of respective feature)
/// for each feature. Each array provides the number of samples
/// encountered for each class and category of the specific feature.
category_count: Vec<Vec<Vec<usize>>>,
} }
impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> { impl<T: RealNumber> PartialEq for CategoricalNBDistribution<T> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.class_labels == other.class_labels && self.class_priors == other.class_priors { if self.class_labels == other.class_labels
&& self.class_priors == other.class_priors
&& self.n_features == other.n_features
&& self.n_categories == other.n_categories
&& self.class_count == other.class_count
{
if self.coefficients.len() != other.coefficients.len() { if self.coefficients.len() != other.coefficients.len() {
return false; return false;
} }
@@ -88,8 +107,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for CategoricalNBDistribu
let mut likelihood = T::zero(); let mut likelihood = T::zero();
for feature in 0..j.len() { for feature in 0..j.len() {
let value = j.get(feature).floor().to_usize().unwrap(); let value = j.get(feature).floor().to_usize().unwrap();
if self.coefficients[class_index][feature].len() > value { if self.coefficients[feature][class_index].len() > value {
likelihood += self.coefficients[class_index][feature][value]; likelihood += self.coefficients[feature][class_index][value];
} else { } else {
return T::zero(); return T::zero();
} }
@@ -142,17 +161,17 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
let y_max = y let y_max = y
.iter() .iter()
.max() .max()
.ok_or_else(|| Failed::fit(&"Failed to get the labels of y.".to_string()))?; .ok_or_else(|| Failed::fit("Failed to get the labels of y."))?;
let class_labels: Vec<T> = (0..*y_max + 1) let class_labels: Vec<T> = (0..*y_max + 1)
.map(|label| T::from(label).unwrap()) .map(|label| T::from(label).unwrap())
.collect(); .collect();
let mut classes_count: Vec<T> = vec![T::zero(); class_labels.len()]; let mut class_count = vec![0_usize; class_labels.len()];
for elem in y.iter() { for elem in y.iter() {
classes_count[*elem] += T::one(); class_count[*elem] += 1;
} }
let mut feature_categories: Vec<Vec<T>> = Vec::with_capacity(n_features); let mut n_categories: Vec<usize> = Vec::with_capacity(n_features);
for feature in 0..n_features { for feature in 0..n_features {
let feature_max = x let feature_max = x
.get_col_as_vec(feature) .get_col_as_vec(feature)
@@ -165,18 +184,15 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
feature feature
)) ))
})?; })?;
let feature_types = (0..feature_max + 1) n_categories.push(feature_max + 1);
.map(|feat| T::from(feat).unwrap())
.collect();
feature_categories.push(feature_types);
} }
let mut coefficients: Vec<Vec<Vec<T>>> = Vec::with_capacity(class_labels.len()); let mut coefficients: Vec<Vec<Vec<T>>> = Vec::with_capacity(class_labels.len());
for (label, label_count) in class_labels.iter().zip(classes_count.iter()) { let mut category_count: Vec<Vec<Vec<usize>>> = Vec::with_capacity(class_labels.len());
for (feature_index, &n_categories_i) in n_categories.iter().enumerate().take(n_features) {
let mut coef_i: Vec<Vec<T>> = Vec::with_capacity(n_features); let mut coef_i: Vec<Vec<T>> = Vec::with_capacity(n_features);
for (feature_index, feature_options) in let mut category_count_i: Vec<Vec<usize>> = Vec::with_capacity(n_features);
feature_categories.iter().enumerate().take(n_features) for (label, &label_count) in class_labels.iter().zip(class_count.iter()) {
{
let col = x let col = x
.get_col_as_vec(feature_index) .get_col_as_vec(feature_index)
.iter() .iter()
@@ -184,39 +200,48 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
.filter(|(i, _j)| T::from(y[*i]).unwrap() == *label) .filter(|(i, _j)| T::from(y[*i]).unwrap() == *label)
.map(|(_, j)| *j) .map(|(_, j)| *j)
.collect::<Vec<T>>(); .collect::<Vec<T>>();
let mut feat_count: Vec<T> = vec![T::zero(); feature_options.len()]; let mut feat_count: Vec<usize> = vec![0_usize; n_categories_i];
for row in col.iter() { for row in col.iter() {
let index = row.floor().to_usize().unwrap(); let index = row.floor().to_usize().unwrap();
feat_count[index] += T::one(); feat_count[index] += 1;
} }
let coef_i_j = feat_count let coef_i_j = feat_count
.iter() .iter()
.map(|c| { .map(|c| {
((*c + alpha) ((T::from(*c).unwrap() + alpha)
/ (*label_count + T::from(feature_options.len()).unwrap() * alpha)) / (T::from(label_count).unwrap()
+ T::from(n_categories_i).unwrap() * alpha))
.ln() .ln()
}) })
.collect::<Vec<T>>(); .collect::<Vec<T>>();
category_count_i.push(feat_count);
coef_i.push(coef_i_j); coef_i.push(coef_i_j);
} }
category_count.push(category_count_i);
coefficients.push(coef_i); coefficients.push(coef_i);
} }
let class_priors = classes_count let class_priors = class_count
.into_iter() .iter()
.map(|count| count / T::from(n_samples).unwrap()) .map(|&count| T::from(count).unwrap() / T::from(n_samples).unwrap())
.collect::<Vec<T>>(); .collect::<Vec<T>>();
Ok(Self { Ok(Self {
class_count,
class_labels, class_labels,
class_priors, class_priors,
coefficients, coefficients,
n_features,
n_categories,
category_count,
}) })
} }
} }
/// `CategoricalNB` parameters. Use `Default::default()` for default values. /// `CategoricalNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct CategoricalNBParameters<T: RealNumber> { pub struct CategoricalNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T, pub alpha: T,
@@ -237,7 +262,8 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
} }
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> { pub struct CategoricalNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>, inner: BaseNaiveBayes<T, M, CategoricalNBDistribution<T>>,
} }
@@ -283,6 +309,41 @@ impl<T: RealNumber, M: Matrix<T>> CategoricalNB<T, M> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> { pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.inner.predict(x) self.inner.predict(x)
} }
/// Class labels known to the classifier.
/// Returns a vector of size n_classes.
pub fn classes(&self) -> &Vec<T> {
&self.inner.distribution.class_labels
}
/// Number of training samples observed in each class.
/// Returns a vector of size n_classes.
pub fn class_count(&self) -> &Vec<usize> {
&self.inner.distribution.class_count
}
/// Number of features of each sample
pub fn n_features(&self) -> usize {
self.inner.distribution.n_features
}
/// Number of features of each sample
pub fn n_categories(&self) -> &Vec<usize> {
&self.inner.distribution.n_categories
}
/// Holds arrays of shape (n_classes, n_categories of respective feature)
/// for each feature. Each array provides the number of samples
/// encountered for each class and category of the specific feature.
pub fn category_count(&self) -> &Vec<Vec<Vec<usize>>> {
&self.inner.distribution.category_count
}
/// Holds arrays of shape (n_classes, n_categories of respective feature)
/// for each feature. Each array provides the empirical log probability
/// of categories given the respective feature and class, ``P(x_i|y)``.
pub fn feature_log_prob(&self) -> &Vec<Vec<Vec<T>>> {
&self.inner.distribution.coefficients
}
} }
#[cfg(test)] #[cfg(test)]
@@ -290,6 +351,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_categorical_naive_bayes() { fn run_categorical_naive_bayes() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -311,11 +373,66 @@ mod tests {
let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.]; let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.];
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
// checking parity with scikit
assert_eq!(cnb.classes(), &[0., 1.]);
assert_eq!(cnb.class_count(), &[5, 9]);
assert_eq!(cnb.n_features(), 4);
assert_eq!(cnb.n_categories(), &[3, 3, 2, 2]);
assert_eq!(
cnb.category_count(),
&vec![
vec![vec![3, 0, 2], vec![2, 4, 3]],
vec![vec![1, 2, 2], vec![3, 4, 2]],
vec![vec![1, 4], vec![6, 3]],
vec![vec![2, 3], vec![6, 3]]
]
);
assert_eq!(
cnb.feature_log_prob(),
&vec![
vec![
vec![
-0.6931471805599453,
-2.0794415416798357,
-0.9808292530117262
],
vec![
-1.3862943611198906,
-0.8754687373538999,
-1.0986122886681098
]
],
vec![
vec![
-1.3862943611198906,
-0.9808292530117262,
-0.9808292530117262
],
vec![
-1.0986122886681098,
-0.8754687373538999,
-1.3862943611198906
]
],
vec![
vec![-1.252762968495368, -0.3364722366212129],
vec![-0.45198512374305727, -1.0116009116784799]
],
vec![
vec![-0.8472978603872037, -0.5596157879354228],
vec![-0.45198512374305727, -1.0116009116784799]
]
]
);
let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]); let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]);
let y_hat = cnb.predict(&x_test).unwrap(); let y_hat = cnb.predict(&x_test).unwrap();
assert_eq!(y_hat, vec![0., 1.]); assert_eq!(y_hat, vec![0., 1.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_categorical_naive_bayes2() { fn run_categorical_naive_bayes2() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -344,7 +461,9 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
&[3., 4., 0., 1.], &[3., 4., 0., 1.],
+70 -27
View File
@@ -30,17 +30,21 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector; use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for categorical features /// Naive Bayes classifier using Gaussian distribution
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
struct GaussianNBDistribution<T: RealNumber> { struct GaussianNBDistribution<T: RealNumber> {
/// class labels known to the classifier /// class labels known to the classifier
class_labels: Vec<T>, class_labels: Vec<T>,
/// number of training samples observed in each class
class_count: Vec<usize>,
/// probability of each class. /// probability of each class.
class_priors: Vec<T>, class_priors: Vec<T>,
/// variance of each feature per class /// variance of each feature per class
sigma: Vec<Vec<T>>, var: Vec<Vec<T>>,
/// mean of each feature per class /// mean of each feature per class
theta: Vec<Vec<T>>, theta: Vec<Vec<T>>,
} }
@@ -55,18 +59,14 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
} }
fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T { fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T {
if class_index < self.class_labels.len() { let mut likelihood = T::zero();
let mut likelihood = T::zero(); for feature in 0..j.len() {
for feature in 0..j.len() { let value = j.get(feature);
let value = j.get(feature); let mean = self.theta[class_index][feature];
let mean = self.theta[class_index][feature]; let variance = self.var[class_index][feature];
let variance = self.sigma[class_index][feature]; likelihood += self.calculate_log_probability(value, mean, variance);
likelihood += self.calculate_log_probability(value, mean, variance);
}
likelihood
} else {
T::zero()
} }
likelihood
} }
fn classes(&self) -> &Vec<T> { fn classes(&self) -> &Vec<T> {
@@ -75,7 +75,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
} }
/// `GaussianNB` parameters. Use `Default::default()` for default values. /// `GaussianNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Default, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Default, Clone)]
pub struct GaussianNBParameters<T: RealNumber> { pub struct GaussianNBParameters<T: RealNumber> {
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
pub priors: Option<Vec<T>>, pub priors: Option<Vec<T>>,
@@ -118,12 +119,12 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
let y = y.to_vec(); let y = y.to_vec();
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y); let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
let mut class_count = vec![T::zero(); class_labels.len()]; let mut class_count = vec![0_usize; class_labels.len()];
let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()]; let mut subdataset: Vec<Vec<Vec<T>>> = vec![vec![]; class_labels.len()];
for (row, class_index) in row_iter(x).zip(indices.iter()) { for (row, class_index) in row_iter(x).zip(indices.iter()) {
class_count[*class_index] += T::one(); class_count[*class_index] += 1;
subdataset[*class_index].push(row); subdataset[*class_index].push(row);
} }
@@ -136,8 +137,8 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
class_priors class_priors
} else { } else {
class_count class_count
.into_iter() .iter()
.map(|c| c / T::from(n_samples).unwrap()) .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
.collect() .collect()
}; };
@@ -154,15 +155,16 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
}) })
.collect(); .collect();
let (sigma, theta): (Vec<Vec<T>>, Vec<Vec<T>>) = subdataset let (var, theta): (Vec<Vec<T>>, Vec<Vec<T>>) = subdataset
.iter() .iter()
.map(|data| (data.var(0), data.mean(0))) .map(|data| (data.var(0), data.mean(0)))
.unzip(); .unzip();
Ok(Self { Ok(Self {
class_labels, class_labels,
class_count,
class_priors, class_priors,
sigma, var,
theta, theta,
}) })
} }
@@ -177,8 +179,10 @@ impl<T: RealNumber> GaussianNBDistribution<T> {
} }
} }
/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data. /// GaussianNB implements the naive Bayes algorithm for data that follows the Gaussian
#[derive(Serialize, Deserialize, Debug, PartialEq)] /// distribution.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct GaussianNB<T: RealNumber, M: Matrix<T>> { pub struct GaussianNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>, inner: BaseNaiveBayes<T, M, GaussianNBDistribution<T>>,
} }
@@ -219,6 +223,36 @@ impl<T: RealNumber, M: Matrix<T>> GaussianNB<T, M> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> { pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.inner.predict(x) self.inner.predict(x)
} }
/// Class labels known to the classifier.
/// Returns a vector of size n_classes.
pub fn classes(&self) -> &Vec<T> {
&self.inner.distribution.class_labels
}
/// Number of training samples observed in each class.
/// Returns a vector of size n_classes.
pub fn class_count(&self) -> &Vec<usize> {
&self.inner.distribution.class_count
}
/// Probability of each class
/// Returns a vector of size n_classes.
pub fn class_priors(&self) -> &Vec<T> {
&self.inner.distribution.class_priors
}
/// Mean of each feature per class
/// Returns a 2d vector of shape (n_classes, n_features).
pub fn theta(&self) -> &Vec<Vec<T>> {
&self.inner.distribution.theta
}
/// Variance of each feature per class
/// Returns a 2d vector of shape (n_classes, n_features).
pub fn var(&self) -> &Vec<Vec<T>> {
&self.inner.distribution.var
}
} }
#[cfg(test)] #[cfg(test)]
@@ -226,6 +260,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_gaussian_naive_bayes() { fn run_gaussian_naive_bayes() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -241,22 +276,28 @@ mod tests {
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
let y_hat = gnb.predict(&x).unwrap(); let y_hat = gnb.predict(&x).unwrap();
assert_eq!(y_hat, y); assert_eq!(y_hat, y);
assert_eq!(gnb.classes(), &[1., 2.]);
assert_eq!(gnb.class_count(), &[3, 3]);
assert_eq!( assert_eq!(
gnb.inner.distribution.sigma, gnb.var(),
&[ &[
&[0.666666666666667, 0.22222222222222232], &[0.666666666666667, 0.22222222222222232],
&[0.666666666666667, 0.22222222222222232] &[0.666666666666667, 0.22222222222222232]
] ]
); );
assert_eq!(gnb.inner.distribution.class_priors, &[0.5, 0.5]); assert_eq!(gnb.class_priors(), &[0.5, 0.5]);
assert_eq!( assert_eq!(
gnb.inner.distribution.theta, gnb.theta(),
&[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]] &[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]]
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_gaussian_naive_bayes_with_priors() { fn run_gaussian_naive_bayes_with_priors() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -273,10 +314,12 @@ mod tests {
let parameters = GaussianNBParameters::default().with_priors(priors.clone()); let parameters = GaussianNBParameters::default().with_priors(priors.clone());
let gnb = GaussianNB::fit(&x, &y, parameters).unwrap(); let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();
assert_eq!(gnb.inner.distribution.class_priors, priors); assert_eq!(gnb.class_priors(), &priors);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
&[-1., -1.], &[-1., -1.],
+3 -1
View File
@@ -39,6 +39,7 @@ use crate::error::Failed;
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -55,7 +56,8 @@ pub(crate) trait NBDistribution<T: RealNumber, M: Matrix<T>> {
} }
/// Base struct for the Naive Bayes classifier. /// Base struct for the Naive Bayes classifier.
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> { pub(crate) struct BaseNaiveBayes<T: RealNumber, M: Matrix<T>, D: NBDistribution<T, M>> {
distribution: D, distribution: D,
_phantom_t: PhantomData<T>, _phantom_t: PhantomData<T>,
+117 -21
View File
@@ -42,15 +42,25 @@ use crate::math::num::RealNumber;
use crate::math::vector::RealNumberVector; use crate::math::vector::RealNumberVector;
use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Naive Bayes classifier for Multinomial features /// Naive Bayes classifier for Multinomial features
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
struct MultinomialNBDistribution<T: RealNumber> { struct MultinomialNBDistribution<T: RealNumber> {
/// class labels known to the classifier /// class labels known to the classifier
class_labels: Vec<T>, class_labels: Vec<T>,
/// number of training samples observed in each class
class_count: Vec<usize>,
/// probability of each class
class_priors: Vec<T>, class_priors: Vec<T>,
feature_prob: Vec<Vec<T>>, /// Empirical log probability of features given a class
feature_log_prob: Vec<Vec<T>>,
/// Number of samples encountered for each (class, feature)
feature_count: Vec<Vec<usize>>,
/// Number of features of each sample
n_features: usize,
} }
impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> { impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribution<T> {
@@ -62,7 +72,7 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
let mut likelihood = T::zero(); let mut likelihood = T::zero();
for feature in 0..j.len() { for feature in 0..j.len() {
let value = j.get(feature); let value = j.get(feature);
likelihood += value * self.feature_prob[class_index][feature].ln(); likelihood += value * self.feature_log_prob[class_index][feature];
} }
likelihood likelihood
} }
@@ -73,7 +83,8 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
} }
/// `MultinomialNB` parameters. Use `Default::default()` for default values. /// `MultinomialNB` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct MultinomialNBParameters<T: RealNumber> { pub struct MultinomialNBParameters<T: RealNumber> {
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
pub alpha: T, pub alpha: T,
@@ -141,10 +152,10 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
let y = y.to_vec(); let y = y.to_vec();
let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y); let (class_labels, indices) = <Vec<T> as RealNumberVector<T>>::unique_with_indices(&y);
let mut class_count = vec![T::zero(); class_labels.len()]; let mut class_count = vec![0_usize; class_labels.len()];
for class_index in indices.iter() { for class_index in indices.iter() {
class_count[*class_index] += T::one(); class_count[*class_index] += 1;
} }
let class_priors = if let Some(class_priors) = priors { let class_priors = if let Some(class_priors) = priors {
@@ -157,39 +168,53 @@ impl<T: RealNumber> MultinomialNBDistribution<T> {
} else { } else {
class_count class_count
.iter() .iter()
.map(|&c| c / T::from(n_samples).unwrap()) .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap())
.collect() .collect()
}; };
let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()]; let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()];
for (row, class_index) in row_iter(x).zip(indices) { for (row, class_index) in row_iter(x).zip(indices) {
for (idx, row_i) in row.iter().enumerate().take(n_features) { for (idx, row_i) in row.iter().enumerate().take(n_features) {
feature_in_class_counter[class_index][idx] += *row_i; feature_in_class_counter[class_index][idx] +=
row_i.to_usize().ok_or_else(|| {
Failed::fit(&format!(
"Elements of the matrix should be convertible to usize |found|=[{}]",
row_i
))
})?;
} }
} }
let feature_prob = feature_in_class_counter let feature_log_prob = feature_in_class_counter
.iter() .iter()
.map(|feature_count| { .map(|feature_count| {
let n_c = feature_count.sum(); let n_c: usize = feature_count.iter().sum();
feature_count feature_count
.iter() .iter()
.map(|&count| (count + alpha) / (n_c + alpha * T::from(n_features).unwrap())) .map(|&count| {
((T::from(count).unwrap() + alpha)
/ (T::from(n_c).unwrap() + alpha * T::from(n_features).unwrap()))
.ln()
})
.collect() .collect()
}) })
.collect(); .collect();
Ok(Self { Ok(Self {
class_count,
class_labels, class_labels,
class_priors, class_priors,
feature_prob, feature_log_prob,
feature_count: feature_in_class_counter,
n_features,
}) })
} }
} }
/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data. /// MultinomialNB implements the naive Bayes algorithm for multinomially distributed data.
#[derive(Serialize, Deserialize, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq)]
pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> { pub struct MultinomialNB<T: RealNumber, M: Matrix<T>> {
inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>, inner: BaseNaiveBayes<T, M, MultinomialNBDistribution<T>>,
} }
@@ -236,6 +261,35 @@ impl<T: RealNumber, M: Matrix<T>> MultinomialNB<T, M> {
pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> { pub fn predict(&self, x: &M) -> Result<M::RowVector, Failed> {
self.inner.predict(x) self.inner.predict(x)
} }
/// Class labels known to the classifier.
/// Returns a vector of size n_classes.
pub fn classes(&self) -> &Vec<T> {
&self.inner.distribution.class_labels
}
/// Number of training samples observed in each class.
/// Returns a vector of size n_classes.
pub fn class_count(&self) -> &Vec<usize> {
&self.inner.distribution.class_count
}
/// Empirical log probability of features given a class, P(x_i|y).
/// Returns a 2d vector of shape (n_classes, n_features)
pub fn feature_log_prob(&self) -> &Vec<Vec<T>> {
&self.inner.distribution.feature_log_prob
}
/// Number of features of each sample
pub fn n_features(&self) -> usize {
self.inner.distribution.n_features
}
/// Number of samples encountered for each (class, feature)
/// Returns a 2d vector of shape (n_classes, n_features)
pub fn feature_count(&self) -> &Vec<Vec<usize>> {
&self.inner.distribution.feature_count
}
} }
#[cfg(test)] #[cfg(test)]
@@ -243,6 +297,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn run_multinomial_naive_bayes() { fn run_multinomial_naive_bayes() {
// Tests that MultinomialNB when alpha=1.0 gives the same values as // Tests that MultinomialNB when alpha=1.0 gives the same values as
@@ -264,12 +319,29 @@ mod tests {
let y = vec![0., 0., 0., 1.]; let y = vec![0., 0., 0., 1.];
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
assert_eq!(mnb.classes(), &[0., 1.]);
assert_eq!(mnb.class_count(), &[3, 1]);
assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]); assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]);
assert_eq!( assert_eq!(
mnb.inner.distribution.feature_prob, mnb.feature_log_prob(),
&[ &[
&[1. / 7., 3. / 7., 1. / 14., 1. / 7., 1. / 7., 1. / 14.], &[
&[1. / 9., 2. / 9.0, 2. / 9.0, 1. / 9.0, 1. / 9.0, 2. / 9.0] (1_f64 / 7_f64).ln(),
(3_f64 / 7_f64).ln(),
(1_f64 / 14_f64).ln(),
(1_f64 / 7_f64).ln(),
(1_f64 / 7_f64).ln(),
(1_f64 / 14_f64).ln()
],
&[
(1_f64 / 9_f64).ln(),
(2_f64 / 9_f64).ln(),
(2_f64 / 9_f64).ln(),
(1_f64 / 9_f64).ln(),
(1_f64 / 9_f64).ln(),
(2_f64 / 9_f64).ln()
]
] ]
); );
@@ -281,6 +353,7 @@ mod tests {
assert_eq!(y_hat, &[0.]); assert_eq!(y_hat, &[0.]);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn multinomial_nb_scikit_parity() { fn multinomial_nb_scikit_parity() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
@@ -303,6 +376,16 @@ mod tests {
let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.]; let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.];
let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
assert_eq!(nb.n_features(), 10);
assert_eq!(
nb.feature_count(),
&[
&[12, 20, 11, 24, 12, 14, 13, 17, 13, 18],
&[9, 6, 9, 4, 7, 3, 8, 5, 4, 9],
&[10, 12, 9, 9, 11, 3, 9, 18, 10, 10]
]
);
let y_hat = nb.predict(&x).unwrap(); let y_hat = nb.predict(&x).unwrap();
assert!(nb assert!(nb
@@ -310,16 +393,29 @@ mod tests {
.distribution .distribution
.class_priors .class_priors
.approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2)); .approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2));
assert!(nb.inner.distribution.feature_prob[1].approximate_eq( assert!(nb.feature_log_prob()[1].approximate_eq(
&vec!(0.07, 0.12, 0.07, 0.15, 0.07, 0.09, 0.08, 0.10, 0.08, 0.11), &vec![
1e-1 -2.00148,
-2.35815494,
-2.00148,
-2.69462718,
-2.22462355,
-2.91777073,
-2.10684052,
-2.51230562,
-2.69462718,
-2.00148
],
1e-5
)); ));
assert!(y_hat.approximate_eq( assert!(y_hat.approximate_eq(
&vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0), &vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0),
1e-5 1e-5
)); ));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::<f64>::from_2d_array(&[ let x = DenseMatrix::<f64>::from_2d_array(&[
&[1., 1., 0., 0., 0., 0.], &[1., 1., 0., 0., 0., 0.],
+9 -2
View File
@@ -33,6 +33,7 @@
//! //!
use std::marker::PhantomData; use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
@@ -45,7 +46,8 @@ use crate::math::num::RealNumber;
use crate::neighbors::KNNWeightFunction; use crate::neighbors::KNNWeightFunction;
/// `KNNClassifier` parameters. Use `Default::default()` for default values. /// `KNNClassifier` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> { pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// a function that defines a distance between each pair of point in training data. /// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
@@ -62,7 +64,8 @@ pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
} }
/// K Nearest Neighbors Classifier /// K Nearest Neighbors Classifier
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> { pub struct KNNClassifier<T: RealNumber, D: Distance<Vec<T>, T>> {
classes: Vec<T>, classes: Vec<T>,
y: Vec<usize>, y: Vec<usize>,
@@ -248,6 +251,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_fit_predict() { fn knn_fit_predict() {
let x = let x =
@@ -259,6 +263,7 @@ mod tests {
assert_eq!(y.to_vec(), y_hat); assert_eq!(y.to_vec(), y_hat);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_fit_predict_weighted() { fn knn_fit_predict_weighted() {
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]); let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
@@ -276,7 +281,9 @@ mod tests {
assert_eq!(vec![3.0], y_hat); assert_eq!(vec![3.0], y_hat);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = let x =
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
+9 -2
View File
@@ -36,6 +36,7 @@
//! //!
use std::marker::PhantomData; use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName};
@@ -48,7 +49,8 @@ use crate::math::num::RealNumber;
use crate::neighbors::KNNWeightFunction; use crate::neighbors::KNNWeightFunction;
/// `KNNRegressor` parameters. Use `Default::default()` for default values. /// `KNNRegressor` parameters. Use `Default::default()` for default values.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> { pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
/// a function that defines a distance between each pair of point in training data. /// a function that defines a distance between each pair of point in training data.
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
@@ -65,7 +67,8 @@ pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
} }
/// K Nearest Neighbors Regressor /// K Nearest Neighbors Regressor
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> { pub struct KNNRegressor<T: RealNumber, D: Distance<Vec<T>, T>> {
y: Vec<T>, y: Vec<T>,
knn_algorithm: KNNAlgorithm<T, D>, knn_algorithm: KNNAlgorithm<T, D>,
@@ -228,6 +231,7 @@ mod tests {
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::math::distance::Distances; use crate::math::distance::Distances;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_fit_predict_weighted() { fn knn_fit_predict_weighted() {
let x = let x =
@@ -251,6 +255,7 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn knn_fit_predict_uniform() { fn knn_fit_predict_uniform() {
let x = let x =
@@ -265,7 +270,9 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = let x =
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
+3 -1
View File
@@ -33,6 +33,7 @@
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// K Nearest Neighbors Classifier /// K Nearest Neighbors Classifier
@@ -48,7 +49,8 @@ pub mod knn_regressor;
pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName; pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
/// Weight function that is used to determine estimated value. /// Weight function that is used to determine estimated value.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum KNNWeightFunction { pub enum KNNWeightFunction {
/// All k nearest points are weighted equally /// All k nearest points are weighted equally
Uniform, Uniform,
@@ -50,14 +50,14 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
let f_alpha = |alpha: T| -> T { let f_alpha = |alpha: T| -> T {
let mut dx = step.clone(); let mut dx = step.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha) f(dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: T| -> T { let df_alpha = |alpha: T| -> T {
let mut dx = step.clone(); let mut dx = step.clone();
let mut dg = gvec.clone(); let mut dg = gvec.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
df(&mut dg, &dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha) df(&mut dg, dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha)
gvec.dot(&dg) gvec.dot(&dg)
}; };
@@ -66,7 +66,7 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for GradientDescent<T> {
let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0); let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0);
alpha = ls_r.alpha; alpha = ls_r.alpha;
fx = ls_r.f_x; fx = ls_r.f_x;
x.add_mut(&step.mul_scalar_mut(alpha)); x.add_mut(step.mul_scalar_mut(alpha));
df(&mut gvec, &x); df(&mut gvec, &x);
gnorm = gvec.norm2(); gnorm = gvec.norm2();
} }
@@ -88,6 +88,7 @@ mod tests {
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn gradient_descent() { fn gradient_descent() {
let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]); let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]);
+6 -3
View File
@@ -1,3 +1,4 @@
#![allow(clippy::suspicious_operation_groupings)]
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
@@ -7,6 +8,7 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
use crate::optimization::line_search::LineSearchMethod; use crate::optimization::line_search::LineSearchMethod;
use crate::optimization::{DF, F}; use crate::optimization::{DF, F};
#[allow(clippy::upper_case_acronyms)]
pub struct LBFGS<T: RealNumber> { pub struct LBFGS<T: RealNumber> {
pub max_iter: usize, pub max_iter: usize,
pub g_rtol: T, pub g_rtol: T,
@@ -116,14 +118,14 @@ impl<T: RealNumber> LBFGS<T> {
let f_alpha = |alpha: T| -> T { let f_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha) f(dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha)
}; };
let df_alpha = |alpha: T| -> T { let df_alpha = |alpha: T| -> T {
let mut dx = state.s.clone(); let mut dx = state.s.clone();
let mut dg = state.x_df.clone(); let mut dg = state.x_df.clone();
dx.mul_scalar_mut(alpha); dx.mul_scalar_mut(alpha);
df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha) df(&mut dg, dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha)
state.x_df.dot(&dg) state.x_df.dot(&dg)
}; };
@@ -205,7 +207,7 @@ impl<T: RealNumber> FirstOrderOptimizer<T> for LBFGS<T> {
) -> OptimizerResult<T, X> { ) -> OptimizerResult<T, X> {
let mut state = self.init_state(x0); let mut state = self.init_state(x0);
df(&mut state.x_df, &x0); df(&mut state.x_df, x0);
let g_converged = state.x_df.norm(T::infinity()) < self.g_atol; let g_converged = state.x_df.norm(T::infinity()) < self.g_atol;
let mut converged = g_converged; let mut converged = g_converged;
@@ -238,6 +240,7 @@ mod tests {
use crate::optimization::line_search::Backtracking; use crate::optimization::line_search::Backtracking;
use crate::optimization::FunctionOrder; use crate::optimization::FunctionOrder;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn lbfgs() { fn lbfgs() {
let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]); let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]);
+1
View File
@@ -112,6 +112,7 @@ impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn backtracking() { fn backtracking() {
let f = |x: f64| -> f64 { x.powf(2.) + x }; let f = |x: f64| -> f64 { x.powf(2.) + x };
+1
View File
@@ -4,6 +4,7 @@ pub mod line_search;
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a; pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a; pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum FunctionOrder { pub enum FunctionOrder {
SECOND, SECOND,
+333
View File
@@ -0,0 +1,333 @@
//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matrices
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents
//!
//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html)
//!
//! ### Usage Example
//! ```
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
//! use smartcore::preprocessing::categorical::{OneHotEncoder, OneHotEncoderParams};
//! let data = DenseMatrix::from_2d_array(&[
//! &[1.5, 1.0, 1.5, 3.0],
//! &[1.5, 2.0, 1.5, 4.0],
//! &[1.5, 1.0, 1.5, 5.0],
//! &[1.5, 2.0, 1.5, 6.0],
//! ]);
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
//! // Infer number of categories from data and return a reusable encoder
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
//! // Transform categorical to one-hot encoded (can transform similar)
//! let oh_data = encoder.transform(&data).unwrap();
//! // Produces the following:
//! // &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0]
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0]
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
//! ```
use std::iter;
use crate::error::Failed;
use crate::linalg::Matrix;
use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable};
use crate::preprocessing::series_encoder::CategoryMapper;
/// OneHotEncoder Parameters
#[derive(Debug, Clone)]
pub struct OneHotEncoderParams {
    /// Column numbers that contain categorical variables
    pub col_idx_categorical: Option<Vec<usize>>,
    /// (Currently not implemented) Try and infer which of the matrix columns are categorical variables
    infer_categorical: bool,
}
impl OneHotEncoderParams {
    /// Build parameters from the column numbers that hold categorical
    /// variables; automatic category inference is left disabled.
    pub fn from_cat_idx(categorical_params: &[usize]) -> Self {
        let col_idx_categorical = Some(categorical_params.to_owned());
        Self {
            col_idx_categorical,
            infer_categorical: false,
        }
    }
}
/// Calculate, for every original column index, its new position after
/// one-hot expansion of the categorical columns.
///
/// * `num_params` - number of columns in the original matrix.
/// * `cat_sizes[i]` - number of categories of the column at `cat_idxs[i]`.
/// * `cat_idxs` - indices of the categorical columns, sorted ascending
///   (callers sort them in `OneHotEncoder::fit`).
///
/// A categorical column with `v` categories expands into `v` columns, so
/// every column strictly after it shifts right by `v - 1`.
fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) -> Vec<usize> {
    (0..num_params)
        .map(|p| {
            // shift p by (size - 1) for every categorical column strictly before it
            let offset: usize = cat_idxs
                .iter()
                .zip(cat_sizes.iter())
                .filter(|&(&ci, _)| ci < p)
                .map(|(_, &size)| size - 1)
                .sum();
            p + offset
        })
        .collect()
}
/// Returns `true` only when every value in `data` passes its
/// `Categorizable::is_valid` check, i.e. the whole column can be treated
/// as categorical.
fn validate_col_is_categorical<T: Categorizable>(data: &[T]) -> bool {
    data.iter().all(|v| v.is_valid())
}
/// Encode categorical variables of a data matrix to one-hot
#[derive(Debug, Clone)]
pub struct OneHotEncoder {
    // one fitted CategoryMapper per categorical column, in column order
    category_mappers: Vec<CategoryMapper<CategoricalFloat>>,
    // sorted column indices of the categorical variables
    col_idx_categorical: Vec<usize>,
}
impl OneHotEncoder {
    /// Create an encoder instance with categories infered from data matrix
    ///
    /// * `data` - matrix to scan; only the columns named in
    ///   `params.col_idx_categorical` are inspected.
    /// * `params` - must carry explicit column indices; the infer flag is
    ///   not yet implemented, and passing both is an error.
    ///
    /// Returns `Failed::fit` when parameters are ambiguous/missing or when a
    /// requested column holds values that are not categorizable.
    pub fn fit<T, M>(data: &M, params: OneHotEncoderParams) -> Result<OneHotEncoder, Failed>
    where
        T: Categorizable,
        M: Matrix<T>,
    {
        match (params.col_idx_categorical, params.infer_categorical) {
            (None, false) => Err(Failed::fit(
                "Must pass categorical series ids or infer flag",
            )),
            (Some(_idxs), true) => Err(Failed::fit(
                "Ambigous parameters, got both infer and categroy ids",
            )),
            (Some(mut idxs), false) => {
                // make sure categories have same order as data columns
                idxs.sort_unstable();

                let (nrows, _) = data.shape();

                // col buffer to avoid allocations: reused for every categorical column
                let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();

                let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());

                for &idx in &idxs {
                    data.copy_col_as_vec(idx, &mut col_buf);
                    // reject the column if any value is not (close to) an integer
                    if !validate_col_is_categorical(&col_buf) {
                        let msg = format!(
                            "Column {} of data matrix containts non categorizable (integer) values",
                            idx
                        );
                        return Err(Failed::fit(&msg[..]));
                    }
                    // map column values to their integer category representation
                    let hashable_col = col_buf.iter().map(|v| v.to_category());
                    res.push(CategoryMapper::fit_to_iter(hashable_col));
                }

                Ok(Self {
                    category_mappers: res,
                    col_idx_categorical: idxs,
                })
            }
            (None, true) => {
                todo!("Auto-Inference for Categorical Variables not yet implemented")
            }
        }
    }

    /// Transform categorical variables to one-hot encoded and return a new matrix
    ///
    /// Non-categorical columns are copied through unchanged (shifted right to
    /// make room for the one-hot expansions). Returns `Failed::transform`
    /// when a value does not match any category seen during `fit`.
    pub fn transform<T, M>(&self, x: &M) -> Result<M, Failed>
    where
        T: Categorizable,
        M: Matrix<T>,
    {
        let (nrows, p) = x.shape();
        // number of categories per encoded column, in column order
        let additional_params: Vec<usize> = self
            .category_mappers
            .iter()
            .map(|enc| enc.num_categories())
            .collect();

        // Each category of size v adds v-1 params
        let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1);

        // map every original column index to its position in the expanded matrix
        let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]);
        let mut res = M::zeros(nrows, expandws_p);

        // write the one-hot expansion of every categorical column
        for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() {
            // starting column of this one-hot group in the result matrix
            let cidx = new_col_idx[old_cidx];

            let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category());
            let sencoder = &self.category_mappers[pidx];
            let oh_series = col_iter.map(|c| sencoder.get_one_hot::<T, Vec<T>>(&c));
            for (row, oh_vec) in oh_series.enumerate() {
                match oh_vec {
                    None => {
                        // Since we support T types, bad value in a series causes it to be invalid
                        let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx);
                        return Err(Failed::transform(&msg[..]));
                    }
                    Some(v) => {
                        // copy one hot vectors to their place in the data matrix
                        for (col_ofst, &val) in v.iter().enumerate() {
                            res.set(row, cidx + col_ofst, val);
                        }
                    }
                }
            }
        }

        // copy old data in x to their new location while skipping categorical vars (already treated)
        let mut skip_idx_iter = self.col_idx_categorical.iter();
        let mut cur_skip = skip_idx_iter.next();

        for (old_p, &new_p) in new_col_idx.iter().enumerate() {
            // if found treated variable, skip it
            if let Some(&v) = cur_skip {
                if v == old_p {
                    cur_skip = skip_idx_iter.next();
                    continue;
                }
            }

            for r in 0..nrows {
                let val = x.get(r, old_p);
                res.set(r, new_p, val);
            }
        }

        Ok(res)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::naive::dense_matrix::DenseMatrix;
    use crate::preprocessing::series_encoder::CategoryMapper;

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn adjust_idxs() {
        // no columns at all -> empty mapping
        assert_eq!(find_new_idxs(0, &[], &[]), Vec::<usize>::new());
        // [0,1,2] -> [0, 1, 1, 1, 2]
        assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]);
    }

    // Fixture with categorical columns at idx 0 and 2 (first and last),
    // plus the expected one-hot expansion.
    fn build_cat_first_and_last() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        let orig = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);
        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);
        (orig, oh_enc)
    }

    // Fixture with categorical columns at idx 1 and 3, plus the expected
    // one-hot expansion.
    fn build_fake_matrix() -> (DenseMatrix<f64>, DenseMatrix<f64>) {
        // Categorical first and last
        let orig = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 1.5, 3.0],
            &[1.5, 2.0, 1.5, 4.0],
            &[1.5, 1.0, 1.5, 5.0],
            &[1.5, 2.0, 1.5, 6.0],
        ]);
        let oh_enc = DenseMatrix::from_2d_array(&[
            &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
            &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
            &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
        ]);
        (orig, oh_enc)
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn hash_encode_f64_series() {
        let series = vec![3.0, 1.0, 2.0, 1.0];
        let hashable_series: Vec<CategoricalFloat> =
            series.iter().map(|v| v.to_category()).collect();
        let enc = CategoryMapper::from_positional_category_vec(hashable_series);
        // one-hot [0,0,1] -> positional category 2, i.e. original value 2.0
        let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]);
        let orig_val: f64 = inv.unwrap().into();
        assert_eq!(orig_val, 2.0);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_fit() {
        let (x, _) = build_fake_matrix();
        let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        // one mapper per categorical column
        assert_eq!(oh_enc.category_mappers.len(), 2);
        let num_cat: Vec<usize> = oh_enc
            .category_mappers
            .iter()
            .map(|a| a.num_categories())
            .collect();
        // column 1 has 2 distinct values, column 3 has 4
        assert_eq!(num_cat, vec![2, 4]);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_transform_test() {
        let (x, expected_x) = build_fake_matrix();
        let params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        let nm = oh_enc.transform(&x).unwrap();
        assert_eq!(nm, expected_x);

        let (x, expected_x) = build_cat_first_and_last();
        let params = OneHotEncoderParams::from_cat_idx(&[0, 2]);
        let oh_enc = OneHotEncoder::fit(&x, params).unwrap();
        let nm = oh_enc.transform(&x).unwrap();
        assert_eq!(nm, expected_x);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn fail_on_bad_category() {
        // column 1 holds 1.5 values, which are not categorizable -> fit must fail
        let m = DenseMatrix::from_2d_array(&[
            &[1.0, 1.5, 3.0],
            &[2.0, 1.5, 4.0],
            &[1.0, 1.5, 5.0],
            &[2.0, 1.5, 6.0],
        ]);
        let params = OneHotEncoderParams::from_cat_idx(&[1]);
        match OneHotEncoder::fit(&m, params) {
            Err(_) => {
                assert!(true);
            }
            _ => assert!(false),
        }
    }
}
+43
View File
@@ -0,0 +1,43 @@
//! Traits to indicate that float variables can be viewed as categorical
//! This module assumes that categorical values are stored as floats lying (within a small margin) on integer values.
use crate::math::num::RealNumber;
/// Integer representation used for categorical values (see `Categorizable`).
pub type CategoricalFloat = u16;
// pub struct CategoricalFloat(u16);

/// Maximum distance from an integer for a float to still count as a category.
const ERROR_MARGIN: f64 = 0.001;
/// A float type whose values can be reinterpreted as integer categories.
pub trait Categorizable: RealNumber {
    /// Associated categorical representation type.
    type A;

    /// Cast the value to its categorical (integer) representation.
    fn to_category(self) -> CategoricalFloat;
    /// Whether the value is close enough to an integer to be a category.
    fn is_valid(self) -> bool;
}
impl Categorizable for f32 {
    type A = CategoricalFloat;

    /// Truncating cast to the categorical (u16) representation.
    fn to_category(self) -> CategoricalFloat {
        self as CategoricalFloat
    }

    /// Valid when the round trip through `u16` moves the value by less than
    /// `ERROR_MARGIN`.
    fn is_valid(self) -> bool {
        let rounded = self.to_category() as f32;
        (rounded - self).abs() < ERROR_MARGIN as f32
    }
}
impl Categorizable for f64 {
    type A = CategoricalFloat;

    /// Truncating cast to the categorical (u16) representation.
    fn to_category(self) -> CategoricalFloat {
        self as CategoricalFloat
    }

    /// Valid when the round trip through `u16` moves the value by less than
    /// `ERROR_MARGIN`.
    fn is_valid(self) -> bool {
        let rounded = self.to_category() as f64;
        (rounded - self).abs() < ERROR_MARGIN
    }
}
+5
View File
@@ -0,0 +1,5 @@
/// Transform a data matrix by replacing all categorical variables with their one-hot vector equivalents
pub mod categorical;
mod data_traits;
/// Encode a series (column, array) of categorical variables as one-hot vectors
pub mod series_encoder;
+282
View File
@@ -0,0 +1,282 @@
#![allow(clippy::ptr_arg)]
//! # Series Encoder
//! Encode a series of categorical features as a one-hot numeric array.
use crate::error::Failed;
use crate::linalg::BaseVector;
use crate::math::num::RealNumber;
use std::collections::HashMap;
use std::hash::Hash;
/// ## Bi-directional map category <-> label num.
/// Turn Hashable objects into a one-hot vectors or ordinal values.
/// This struct encodes a single class per example
///
/// You can fit_to_iter a category enumeration by passing an iterator of categories.
/// category numbers will be assigned in the order they are encountered
///
/// Example:
/// ```
/// use std::collections::HashMap;
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
///
/// let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
/// let it = fake_categories.iter().map(|&a| a);
/// let enc = CategoryMapper::<usize>::fit_to_iter(it);
/// let oh_vec: Vec<f64> = enc.get_one_hot(&1).unwrap();
/// // notice that 1 is actually a zero-th positional category
/// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]);
/// ```
///
/// You can also pass a predefined category enumeration such as a hashmap `HashMap<C, usize>` or a vector `Vec<C>`
///
///
/// ```
/// use std::collections::HashMap;
/// use smartcore::preprocessing::series_encoder::CategoryMapper;
///
/// let category_map: HashMap<&str, usize> =
/// vec![("cat", 2), ("background",0), ("dog", 1)]
/// .into_iter()
/// .collect();
/// let category_vec = vec!["background", "dog", "cat"];
///
/// let enc_lv = CategoryMapper::<&str>::from_positional_category_vec(category_vec);
/// let enc_lm = CategoryMapper::<&str>::from_category_map(category_map);
///
/// // ["background", "dog", "cat"]
/// println!("{:?}", enc_lv.get_categories());
/// let lv: Vec<f64> = enc_lv.get_one_hot(&"dog").unwrap();
/// let lm: Vec<f64> = enc_lm.get_one_hot(&"dog").unwrap();
/// assert_eq!(lv, lm);
/// ```
#[derive(Debug, Clone)]
pub struct CategoryMapper<C> {
    // category -> class number lookup
    category_map: HashMap<C, usize>,
    // class number -> category lookup (position = class number)
    categories: Vec<C>,
    // cached count; kept equal to `categories.len()` by the constructors
    num_categories: usize,
}
impl<C> CategoryMapper<C>
where
    C: Hash + Eq + Clone,
{
    /// Number of distinct categories known to this mapper.
    pub fn num_categories(&self) -> usize {
        self.num_categories
    }

    /// Fit an encoder to an iterator of categories; class numbers are
    /// assigned in first-seen order.
    pub fn fit_to_iter(categories: impl Iterator<Item = C>) -> Self {
        let mut mapping: HashMap<C, usize> = HashMap::new();
        let mut ordered: Vec<C> = Vec::new();
        for cat in categories {
            if !mapping.contains_key(&cat) {
                // first occurrence: its class number is the next free position
                mapping.insert(cat.clone(), ordered.len());
                ordered.push(cat);
            }
        }
        Self {
            num_categories: ordered.len(),
            category_map: mapping,
            categories: ordered,
        }
    }

    /// Build an encoder from a predefined (category -> class number) map.
    pub fn from_category_map(category_map: HashMap<C, usize>) -> Self {
        let mut pairs: Vec<(C, usize)> = category_map
            .iter()
            .map(|(cat, &num)| (cat.clone(), num))
            .collect();
        // the positional vector must be ordered by class number
        pairs.sort_by(|left, right| left.1.cmp(&right.1));
        let categories: Vec<C> = pairs.into_iter().map(|(cat, _)| cat).collect();
        Self {
            num_categories: categories.len(),
            categories,
            category_map,
        }
    }

    /// Build an encoder from a predefined positional category vector
    /// (position in the vector = class number).
    pub fn from_positional_category_vec(categories: Vec<C>) -> Self {
        let mut category_map: HashMap<C, usize> = HashMap::new();
        for (num, cat) in categories.iter().enumerate() {
            category_map.insert(cat.clone(), num);
        }
        Self {
            num_categories: categories.len(),
            category_map,
            categories,
        }
    }

    /// Class number of a category, or `None` when unknown.
    pub fn get_num(&self, category: &C) -> Option<&usize> {
        self.category_map.get(category)
    }

    /// Category corresponding to a class number (panics when out of range).
    pub fn get_cat(&self, num: usize) -> &C {
        &self.categories[num]
    }

    /// All categories in class-number order.
    pub fn get_categories(&self) -> &[C] {
        self.categories.as_slice()
    }

    /// One-hot encoding of the category, or `None` when unknown.
    pub fn get_one_hot<U, V>(&self, category: &C) -> Option<V>
    where
        U: RealNumber,
        V: BaseVector<U>,
    {
        match self.get_num(category) {
            Some(&idx) => Some(make_one_hot::<U, V>(idx, self.num_categories)),
            None => None,
        }
    }

    /// Invert a one-hot vector back to the category; fails unless exactly one
    /// entry equals one.
    pub fn invert_one_hot<U, V>(&self, one_hot: V) -> Result<C, Failed>
    where
        U: RealNumber,
        V: BaseVector<U>,
    {
        // collect the positions of all entries equal to one
        let mut hits: Vec<usize> = Vec::new();
        for idx in 0..one_hot.len() {
            if one_hot.get(idx) == U::one() {
                hits.push(idx);
            }
        }
        if hits.len() == 1 {
            Ok(self.get_cat(hits[0]).clone())
        } else {
            let pos_entries = format!(
                "Expected a single positive entry, {} entires found",
                hits.len()
            );
            Err(Failed::transform(&pos_entries[..]))
        }
    }

    /// Ordinal (class-number) encoding of the category, or `None` when the
    /// category is unknown or the number cannot be represented in `U`.
    pub fn get_ordinal<U>(&self, category: &C) -> Option<U>
    where
        U: RealNumber,
    {
        self.get_num(category).and_then(|&idx| U::from_usize(idx))
    }
}
/// Make a one-hot encoded vector from a categorical variable
///
/// Example:
/// ```
/// use smartcore::preprocessing::series_encoder::make_one_hot;
/// let one_hot: Vec<f64> = make_one_hot(2, 3);
/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]);
/// ```
pub fn make_one_hot<T, V>(category_idx: usize, num_categories: usize) -> V
where
    T: RealNumber,
    V: BaseVector<T>,
{
    // all zeros except a single one at the category's position
    let mut one_hot = V::zeros(num_categories);
    one_hot.set(category_idx, T::one());
    one_hot
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn from_categories() {
        let fake_categories: Vec<usize> = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4];
        let it = fake_categories.iter().map(|&a| a);
        let enc = CategoryMapper::<usize>::fit_to_iter(it);
        // category 1 was seen first, so it maps to position 0
        let oh_vec: Vec<f64> = match enc.get_one_hot(&1) {
            None => panic!("Wrong categories"),
            Some(v) => v,
        };
        let res: Vec<f64> = vec![1f64, 0f64, 0f64, 0f64, 0f64];
        assert_eq!(oh_vec, res);
    }

    // Shared fixture: positional categories ["background", "dog", "cat"].
    fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> {
        let fake_category_pos = vec!["background", "dog", "cat"];
        let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos);
        enc
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn ordinal_encoding() {
        let enc = build_fake_str_enc();
        // "dog" is at position 1
        assert_eq!(1f64, enc.get_ordinal::<f64>(&"dog").unwrap())
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn category_map_and_vec() {
        // building from an explicit map must agree with the positional fixture
        let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)]
            .into_iter()
            .collect();
        let enc = CategoryMapper::<&str>::from_category_map(category_map);
        let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
            None => panic!("Wrong categories"),
            Some(v) => v,
        };
        let res: Vec<f64> = vec![0f64, 1f64, 0f64];
        assert_eq!(oh_vec, res);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn positional_categories_vec() {
        let enc = build_fake_str_enc();
        let oh_vec: Vec<f64> = match enc.get_one_hot(&"dog") {
            None => panic!("Wrong categories"),
            Some(v) => v,
        };
        let res: Vec<f64> = vec![0.0, 1.0, 0.0];
        assert_eq!(oh_vec, res);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invert_label_test() {
        let enc = build_fake_str_enc();
        let res: Vec<f64> = vec![0.0, 1.0, 0.0];
        let lab = enc.invert_one_hot(res).unwrap();
        assert_eq!(lab, "dog");
        // an all-zero vector has no positive entry and must fail
        if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) {
            let pos_entries = format!("Expected a single positive entry, 0 entires found");
            assert_eq!(e, Failed::transform(&pos_entries[..]));
        };
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn test_many_categorys() {
        let enc = build_fake_str_enc();
        // "fish" is unknown and should encode to None
        let cat_it = ["dog", "cat", "fish", "background"].iter().cloned();
        let res: Vec<Option<Vec<f64>>> = cat_it.map(|v| enc.get_one_hot(&v)).collect();
        let v = vec![
            Some(vec![0.0, 1.0, 0.0]),
            Some(vec![0.0, 0.0, 1.0]),
            None,
            Some(vec![1.0, 0.0, 0.0]),
        ];
        assert_eq!(res, v)
    }
}
+13 -4
View File
@@ -26,6 +26,7 @@
pub mod svc; pub mod svc;
pub mod svr; pub mod svr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::linalg::BaseVector; use crate::linalg::BaseVector;
@@ -93,18 +94,21 @@ impl Kernels {
} }
/// Linear Kernel /// Linear Kernel
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearKernel {} pub struct LinearKernel {}
/// Radial basis function (Gaussian) kernel /// Radial basis function (Gaussian) kernel
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct RBFKernel<T: RealNumber> { pub struct RBFKernel<T: RealNumber> {
/// kernel coefficient /// kernel coefficient
pub gamma: T, pub gamma: T,
} }
/// Polynomial kernel /// Polynomial kernel
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct PolynomialKernel<T: RealNumber> { pub struct PolynomialKernel<T: RealNumber> {
/// degree of the polynomial /// degree of the polynomial
pub degree: T, pub degree: T,
@@ -115,7 +119,8 @@ pub struct PolynomialKernel<T: RealNumber> {
} }
/// Sigmoid (hyperbolic tangent) kernel /// Sigmoid (hyperbolic tangent) kernel
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct SigmoidKernel<T: RealNumber> { pub struct SigmoidKernel<T: RealNumber> {
/// kernel coefficient /// kernel coefficient
pub gamma: T, pub gamma: T,
@@ -154,6 +159,7 @@ impl<T: RealNumber, V: BaseVector<T>> Kernel<T, V> for SigmoidKernel<T> {
mod tests { mod tests {
use super::*; use super::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn linear_kernel() { fn linear_kernel() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
@@ -162,6 +168,7 @@ mod tests {
assert_eq!(32f64, Kernels::linear().apply(&v1, &v2)); assert_eq!(32f64, Kernels::linear().apply(&v1, &v2));
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn rbf_kernel() { fn rbf_kernel() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
@@ -170,6 +177,7 @@ mod tests {
assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4); assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn polynomial_kernel() { fn polynomial_kernel() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
@@ -181,6 +189,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn sigmoid_kernel() { fn sigmoid_kernel() {
let v1 = vec![1., 2., 3.]; let v1 = vec![1., 2., 3.];
+28 -16
View File
@@ -57,9 +57,9 @@
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0., //! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; //! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
//! //!
//! let svr = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap(); //! let svc = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap();
//! //!
//! let y_hat = svr.predict(&x).unwrap(); //! let y_hat = svc.predict(&x).unwrap();
//! ``` //! ```
//! //!
//! ## References: //! ## References:
@@ -76,6 +76,7 @@ use std::marker::PhantomData;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -85,7 +86,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::svm::{Kernel, Kernels, LinearKernel}; use crate::svm::{Kernel, Kernels, LinearKernel};
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVC Parameters /// SVC Parameters
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> { pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
/// Number of epochs. /// Number of epochs.
@@ -100,11 +102,15 @@ pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>, m: PhantomData<M>,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[serde(bound( #[derive(Debug)]
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", #[cfg_attr(
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", feature = "serde",
))] serde(bound(
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
))
)]
/// Support Vector Classifier /// Support Vector Classifier
pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> { pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
classes: Vec<T>, classes: Vec<T>,
@@ -114,7 +120,8 @@ pub struct SVC<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
b: T, b: T,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct SupportVector<T: RealNumber, V: BaseVector<T>> { struct SupportVector<T: RealNumber, V: BaseVector<T>> {
index: usize, index: usize,
x: V, x: V,
@@ -215,7 +222,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVC<T, M, K> {
if n != y.len() { if n != y.len() {
return Err(Failed::fit( return Err(Failed::fit(
&"Number of rows of X doesn\'t match number of rows of Y".to_string(), "Number of rows of X doesn\'t match number of rows of Y",
)); ));
} }
@@ -370,7 +377,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
Optimizer { Optimizer {
x, x,
y, y,
parameters: &parameters, parameters,
svmin: 0, svmin: 0,
svmax: 0, svmax: 0,
gmin: T::max_value(), gmin: T::max_value(),
@@ -582,7 +589,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
for i in 0..self.sv.len() { for i in 0..self.sv.len() {
let v = &self.sv[i]; let v = &self.sv[i];
let z = v.grad - gm; let z = v.grad - gm;
let k = cache.get(sv1, &v); let k = cache.get(sv1, v);
let mut curv = km + v.k - T::two() * k; let mut curv = km + v.k - T::two() * k;
if curv <= T::zero() { if curv <= T::zero() {
curv = self.tau; curv = self.tau;
@@ -719,8 +726,10 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::accuracy; use crate::metrics::accuracy;
#[cfg(feature = "serde")]
use crate::svm::*; use crate::svm::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svc_fit_predict() { fn svc_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -763,6 +772,7 @@ mod tests {
assert!(accuracy(&y_hat, &y) >= 0.9); assert!(accuracy(&y_hat, &y) >= 0.9);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svc_fit_predict_rbf() { fn svc_fit_predict_rbf() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -806,7 +816,9 @@ mod tests {
assert!(accuracy(&y_hat, &y) >= 0.9); assert!(accuracy(&y_hat, &y) >= 0.9);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn svc_serde() { fn svc_serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2], &[5.1, 3.5, 1.4, 0.2],
@@ -835,11 +847,11 @@ mod tests {
-1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
]; ];
let svr = SVC::fit(&x, &y, Default::default()).unwrap(); let svc = SVC::fit(&x, &y, Default::default()).unwrap();
let deserialized_svr: SVC<f64, DenseMatrix<f64>, LinearKernel> = let deserialized_svc: SVC<f64, DenseMatrix<f64>, LinearKernel> =
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
assert_eq!(svr, deserialized_svr); assert_eq!(svc, deserialized_svc);
} }
} }
+19 -8
View File
@@ -68,6 +68,7 @@ use std::cell::{Ref, RefCell};
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
@@ -77,7 +78,8 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
use crate::svm::{Kernel, Kernels, LinearKernel}; use crate::svm::{Kernel, Kernels, LinearKernel};
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// SVR Parameters /// SVR Parameters
pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> { pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
/// Epsilon in the epsilon-SVR model. /// Epsilon in the epsilon-SVR model.
@@ -92,11 +94,15 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
m: PhantomData<M>, m: PhantomData<M>,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[serde(bound( #[derive(Debug)]
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", #[cfg_attr(
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", feature = "serde",
))] serde(bound(
serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
))
)]
/// Epsilon-Support Vector Regression /// Epsilon-Support Vector Regression
pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> { pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
@@ -106,7 +112,8 @@ pub struct SVR<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
b: T, b: T,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct SupportVector<T: RealNumber, V: BaseVector<T>> { struct SupportVector<T: RealNumber, V: BaseVector<T>> {
index: usize, index: usize,
x: V, x: V,
@@ -205,7 +212,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
if n != y.len() { if n != y.len() {
return Err(Failed::fit( return Err(Failed::fit(
&"Number of rows of X doesn\'t match number of rows of Y".to_string(), "Number of rows of X doesn\'t match number of rows of Y",
)); ));
} }
@@ -526,8 +533,10 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::*; use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_squared_error; use crate::metrics::mean_squared_error;
#[cfg(feature = "serde")]
use crate::svm::*; use crate::svm::*;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn svr_fit_predict() { fn svr_fit_predict() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -561,7 +570,9 @@ mod tests {
assert!(mean_squared_error(&y_hat, &y) < 2.5); assert!(mean_squared_error(&y_hat, &y) < 2.5);
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn svr_serde() { fn svr_serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323], &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
+36 -15
View File
@@ -68,6 +68,8 @@ use std::fmt::Debug;
use std::marker::PhantomData; use std::marker::PhantomData;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -76,7 +78,8 @@ use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Decision Tree /// Parameters of Decision Tree
pub struct DecisionTreeClassifierParameters { pub struct DecisionTreeClassifierParameters {
/// Split criteria to use when building a tree. /// Split criteria to use when building a tree.
@@ -90,7 +93,8 @@ pub struct DecisionTreeClassifierParameters {
} }
/// Decision Tree /// Decision Tree
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeClassifier<T: RealNumber> { pub struct DecisionTreeClassifier<T: RealNumber> {
nodes: Vec<Node<T>>, nodes: Vec<Node<T>>,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
@@ -100,7 +104,8 @@ pub struct DecisionTreeClassifier<T: RealNumber> {
} }
/// The function to measure the quality of a split. /// The function to measure the quality of a split.
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub enum SplitCriterion { pub enum SplitCriterion {
/// [Gini index](../decision_tree_classifier/index.html) /// [Gini index](../decision_tree_classifier/index.html)
Gini, Gini,
@@ -110,9 +115,10 @@ pub enum SplitCriterion {
ClassificationError, ClassificationError,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<T: RealNumber> { struct Node<T: RealNumber> {
index: usize, _index: usize,
output: usize, output: usize,
split_feature: usize, split_feature: usize,
split_value: Option<T>, split_value: Option<T>,
@@ -198,7 +204,7 @@ impl Default for DecisionTreeClassifierParameters {
impl<T: RealNumber> Node<T> { impl<T: RealNumber> Node<T> {
fn new(index: usize, output: usize) -> Self { fn new(index: usize, output: usize) -> Self {
Node { Node {
index, _index: index,
output, output,
split_feature: 0, split_feature: 0,
split_value: Option::None, split_value: Option::None,
@@ -323,7 +329,14 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
) -> Result<DecisionTreeClassifier<T>, Failed> { ) -> Result<DecisionTreeClassifier<T>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeClassifier::fit_weak_learner(
x,
y,
samples,
num_attributes,
parameters,
&mut rand::thread_rng(),
)
} }
pub(crate) fn fit_weak_learner<M: Matrix<T>>( pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -332,6 +345,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
samples: Vec<usize>, samples: Vec<usize>,
mtry: usize, mtry: usize,
parameters: DecisionTreeClassifierParameters, parameters: DecisionTreeClassifierParameters,
rng: &mut impl Rng,
) -> Result<DecisionTreeClassifier<T>, Failed> { ) -> Result<DecisionTreeClassifier<T>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
let (_, y_ncols) = y_m.shape(); let (_, y_ncols) = y_m.shape();
@@ -375,17 +389,17 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
depth: 0, depth: 0,
}; };
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &yi, 1);
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new(); let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) { if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
} }
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() { match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue), Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break, None => break,
}; };
} }
@@ -438,6 +452,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
&mut self, &mut self,
visitor: &mut NodeVisitor<'_, T, M>, visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize, mtry: usize,
rng: &mut impl Rng,
) -> bool { ) -> bool {
let (n_rows, n_attr) = visitor.x.shape(); let (n_rows, n_attr) = visitor.x.shape();
@@ -477,7 +492,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>(); let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr { if mtry < n_attr {
variables.shuffle(&mut rand::thread_rng()); variables.shuffle(rng);
} }
for variable in variables.iter().take(mtry) { for variable in variables.iter().take(mtry) {
@@ -499,7 +514,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor: &mut NodeVisitor<'_, T, M>, visitor: &mut NodeVisitor<'_, T, M>,
n: usize, n: usize,
count: &[usize], count: &[usize],
false_count: &mut Vec<usize>, false_count: &mut [usize],
parent_impurity: T, parent_impurity: T,
j: usize, j: usize,
) { ) {
@@ -536,7 +551,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
- T::from(tc).unwrap() / T::from(n).unwrap() - T::from(tc).unwrap() / T::from(n).unwrap()
* impurity(&self.parameters.criterion, &true_count, tc) * impurity(&self.parameters.criterion, &true_count, tc)
- T::from(fc).unwrap() / T::from(n).unwrap() - T::from(fc).unwrap() / T::from(n).unwrap()
* impurity(&self.parameters.criterion, &false_count, fc); * impurity(&self.parameters.criterion, false_count, fc);
if self.nodes[visitor.node].split_score == Option::None if self.nodes[visitor.node].split_score == Option::None
|| gain > self.nodes[visitor.node].split_score.unwrap() || gain > self.nodes[visitor.node].split_score.unwrap()
@@ -561,6 +576,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
mut visitor: NodeVisitor<'a, T, M>, mut visitor: NodeVisitor<'a, T, M>,
mtry: usize, mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
rng: &mut impl Rng,
) -> bool { ) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
@@ -609,7 +625,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1, visitor.level + 1,
); );
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
@@ -622,7 +638,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
visitor.level + 1, visitor.level + 1,
); );
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
} }
@@ -635,6 +651,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn gini_impurity() { fn gini_impurity() {
assert!( assert!(
@@ -651,6 +668,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_iris() { fn fit_predict_iris() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -703,6 +721,7 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_predict_baloons() { fn fit_predict_baloons() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -739,7 +758,9 @@ mod tests {
); );
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[1., 1., 1., 0.], &[1., 1., 1., 0.],
+30 -12
View File
@@ -63,6 +63,8 @@ use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::algorithm::sort::quick_sort::QuickArgSort; use crate::algorithm::sort::quick_sort::QuickArgSort;
@@ -71,7 +73,8 @@ use crate::error::Failed;
use crate::linalg::Matrix; use crate::linalg::Matrix;
use crate::math::num::RealNumber; use crate::math::num::RealNumber;
#[derive(Serialize, Deserialize, Debug, Clone)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Regression Tree /// Parameters of Regression Tree
pub struct DecisionTreeRegressorParameters { pub struct DecisionTreeRegressorParameters {
/// The maximum depth of the tree. /// The maximum depth of the tree.
@@ -83,16 +86,18 @@ pub struct DecisionTreeRegressorParameters {
} }
/// Regression Tree /// Regression Tree
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct DecisionTreeRegressor<T: RealNumber> { pub struct DecisionTreeRegressor<T: RealNumber> {
nodes: Vec<Node<T>>, nodes: Vec<Node<T>>,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
depth: u16, depth: u16,
} }
#[derive(Serialize, Deserialize, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
struct Node<T: RealNumber> { struct Node<T: RealNumber> {
index: usize, _index: usize,
output: T, output: T,
split_feature: usize, split_feature: usize,
split_value: Option<T>, split_value: Option<T>,
@@ -132,7 +137,7 @@ impl Default for DecisionTreeRegressorParameters {
impl<T: RealNumber> Node<T> { impl<T: RealNumber> Node<T> {
fn new(index: usize, output: T) -> Self { fn new(index: usize, output: T) -> Self {
Node { Node {
index, _index: index,
output, output,
split_feature: 0, split_feature: 0,
split_value: Option::None, split_value: Option::None,
@@ -238,7 +243,14 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
) -> Result<DecisionTreeRegressor<T>, Failed> { ) -> Result<DecisionTreeRegressor<T>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let (x_nrows, num_attributes) = x.shape();
let samples = vec![1; x_nrows]; let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) DecisionTreeRegressor::fit_weak_learner(
x,
y,
samples,
num_attributes,
parameters,
&mut rand::thread_rng(),
)
} }
pub(crate) fn fit_weak_learner<M: Matrix<T>>( pub(crate) fn fit_weak_learner<M: Matrix<T>>(
@@ -247,6 +259,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
samples: Vec<usize>, samples: Vec<usize>,
mtry: usize, mtry: usize,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
rng: &mut impl Rng,
) -> Result<DecisionTreeRegressor<T>, Failed> { ) -> Result<DecisionTreeRegressor<T>, Failed> {
let y_m = M::from_row_vector(y.clone()); let y_m = M::from_row_vector(y.clone());
@@ -276,17 +289,17 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
depth: 0, depth: 0,
}; };
let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1); let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, x, &y_m, 1);
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new(); let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
if tree.find_best_cutoff(&mut visitor, mtry) { if tree.find_best_cutoff(&mut visitor, mtry, rng) {
visitor_queue.push_back(visitor); visitor_queue.push_back(visitor);
} }
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
match visitor_queue.pop_front() { match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue), Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
None => break, None => break,
}; };
} }
@@ -339,6 +352,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
&mut self, &mut self,
visitor: &mut NodeVisitor<'_, T, M>, visitor: &mut NodeVisitor<'_, T, M>,
mtry: usize, mtry: usize,
rng: &mut impl Rng,
) -> bool { ) -> bool {
let (_, n_attr) = visitor.x.shape(); let (_, n_attr) = visitor.x.shape();
@@ -353,7 +367,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
let mut variables = (0..n_attr).collect::<Vec<_>>(); let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr { if mtry < n_attr {
variables.shuffle(&mut rand::thread_rng()); variables.shuffle(rng);
} }
let parent_gain = let parent_gain =
@@ -428,6 +442,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
mut visitor: NodeVisitor<'a, T, M>, mut visitor: NodeVisitor<'a, T, M>,
mtry: usize, mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
rng: &mut impl Rng,
) -> bool { ) -> bool {
let (n, _) = visitor.x.shape(); let (n, _) = visitor.x.shape();
let mut tc = 0; let mut tc = 0;
@@ -476,7 +491,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1, visitor.level + 1,
); );
if self.find_best_cutoff(&mut true_visitor, mtry) { if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor); visitor_queue.push_back(true_visitor);
} }
@@ -489,7 +504,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
visitor.level + 1, visitor.level + 1,
); );
if self.find_best_cutoff(&mut false_visitor, mtry) { if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor); visitor_queue.push_back(false_visitor);
} }
@@ -502,6 +517,7 @@ mod tests {
use super::*; use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::naive::dense_matrix::DenseMatrix;
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
fn fit_longley() { fn fit_longley() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
@@ -576,7 +592,9 @@ mod tests {
} }
} }
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test] #[test]
#[cfg(feature = "serde")]
fn serde() { fn serde() {
let x = DenseMatrix::from_2d_array(&[ let x = DenseMatrix::from_2d_array(&[
&[234.289, 235.6, 159., 107.608, 1947., 60.323], &[234.289, 235.6, 159., 107.608, 1947., 60.323],