diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 17da167..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,43 +0,0 @@ -version: 2.1 - -workflows: - version: 2.1 - build: - jobs: - - build - - clippy -jobs: - build: - docker: - - image: circleci/rust:latest - environment: - TZ: "/usr/share/zoneinfo/your/location" - steps: - - checkout - - restore_cache: - key: project-cache - - run: - name: Check formatting - command: cargo fmt -- --check - - run: - name: Stable Build - command: cargo build --features "nalgebra-bindings ndarray-bindings" - - run: - name: Test - command: cargo test --features "nalgebra-bindings ndarray-bindings" - - save_cache: - key: project-cache - paths: - - "~/.cargo" - - "./target" - clippy: - docker: - - image: circleci/rust:latest - steps: - - checkout - - run: - name: Install cargo clippy - command: rustup component add clippy - - run: - name: Run cargo clippy - command: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..95f9250 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: +- package-ecosystem: cargo + directory: "/" + schedule: + interval: daily + open-pull-requests-limit: 10 + ignore: + - dependency-name: rand_distr + versions: + - 0.4.0 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5041117 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,57 @@ +name: CI + +on: + push: + branches: [ main, development ] + pull_request: + branches: [ development ] + +jobs: + tests: + runs-on: "${{ matrix.platform.os }}-latest" + strategy: + matrix: + platform: [ + { os: "windows", target: "x86_64-pc-windows-msvc" }, + { os: "windows", target: "i686-pc-windows-msvc" }, + { os: "ubuntu", target: "x86_64-unknown-linux-gnu" }, + { os: "ubuntu", target: "i686-unknown-linux-gnu" }, + { os: "ubuntu", target: "wasm32-unknown-unknown" }, + { os: "macos", target: "aarch64-apple-darwin" }, + ] + env: + TZ: "/usr/share/zoneinfo/your/location" + steps: + - uses: actions/checkout@v2 + - name: Cache .cargo and target + uses: actions/cache@v2 + with: + path: | + ~/.cargo + ./target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: ${{ matrix.platform.target }} + profile: minimal + default: true + - name: Install test runner for wasm + if: matrix.platform.target == 'wasm32-unknown-unknown' + run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + - name: Stable Build + uses: actions-rs/cargo@v1 + with: + command: build + args: --all-features --target ${{ matrix.platform.target }} + - name: Tests + if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin' + uses: actions-rs/cargo@v1 + with: + command: test + args: --all-features + - name: Tests in WASM + if: matrix.platform.target == 'wasm32-unknown-unknown' + run: wasm-pack test --node -- --all-features diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 0000000..793e79d --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,44 @@ +name: Coverage + +on: + push: + branches: [ main, development ] + pull_request: + branches: [ development ] + +jobs: + coverage: + runs-on: ubuntu-latest + env: + TZ: "/usr/share/zoneinfo/your/location" + steps: + - uses: actions/checkout@v2 + - name: Cache .cargo + uses: actions/cache@v2 + with: + path: | + ~/.cargo + ./target + key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }} + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + profile: minimal + default: true + - name: Install cargo-tarpaulin + uses: actions-rs/install@v0.1 + with: + crate: cargo-tarpaulin + version: latest + use-tool-cache: true + - name: Run cargo-tarpaulin + uses: actions-rs/cargo@v1 + with: + command: tarpaulin + args: --out Lcov --all-features -- --test-threads 1 + - name: Upload to codecov.io + uses: codecov/codecov-action@v1 + with: + fail_ci_if_error: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..77a082f --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,41 @@ +name: Lint checks + +on: + push: + branches: [ main, development ] + pull_request: + branches: [ development ] + +jobs: + lint: + runs-on: ubuntu-latest + env: + TZ: "/usr/share/zoneinfo/your/location" + steps: + - uses: actions/checkout@v2 + - name: Cache .cargo and target + uses: actions/cache@v2 + with: + path: | + ~/.cargo + ./target + key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }} + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + profile: minimal + default: true + - run: rustup component add rustfmt + - name: Check formt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check + - run: rustup component add clippy + - name: Run clippy + uses: actions-rs/cargo@v1 + with: + command: clippy + args: --all-features -- -Drust-2018-idioms -Dwarnings diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..ade6825 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,60 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## Added +- L2 regularization penalty to the Logistic Regression +- Getters for the naive bayes structs +- One hot encoder +- Make moons data generator +- Support for WASM. + +## Changed +- Make serde optional + +## [0.2.0] - 2021-01-03 + +### Added +- DBSCAN +- Epsilon-SVR, SVC +- Ridge, Lasso, ElasticNet +- Bernoulli, Gaussian, Categorical and Multinomial Naive Bayes +- K-fold Cross Validation +- Singular value decomposition +- New api module +- Integration with Clippy +- Cholesky decomposition + +### Changed +- ndarray upgraded to 0.14 +- smartcore::error:FailedError is now non-exhaustive +- K-Means +- PCA +- Random Forest +- Linear and Logistic Regression +- KNN +- Decision Tree + +## [0.1.0] - 2020-09-25 + +### Added +- First release of smartcore. +- KNN + distance metrics (Euclidian, Minkowski, Manhattan, Hamming, Mahalanobis) +- Linear Regression (OLS) +- Logistic Regression +- Random Forest Classifier +- Decision Tree Classifier +- PCA +- K-Means +- Integrated with ndarray +- Abstract linear algebra methods +- RandomForest Regressor +- Decision Tree Regressor +- Serde integration +- Integrated with nalgebra +- LU, QR, SVD, EVD +- Evaluation Metrics diff --git a/Cargo.toml b/Cargo.toml index 5e21aef..f83889e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "smartcore" description = "The most advanced machine learning library in rust." homepage = "https://smartcorelib.org" -version = "0.2.0" +version = "0.2.1" authors = ["SmartCore Developers"] edition = "2018" license = "Apache-2.0" @@ -19,20 +19,25 @@ nalgebra-bindings = ["nalgebra"] datasets = [] [dependencies] -ndarray = { version = "0.14", optional = true } +ndarray = { version = "0.15", optional = true } nalgebra = { version = "0.23.0", optional = true } num-traits = "0.2.12" -num = "0.3.0" -rand = "0.7.3" -rand_distr = "0.3.0" -serde = { version = "1.0.115", features = ["derive"] } -serde_derive = "1.0.115" +num = "0.4.0" +rand = "0.8.3" +rand_distr = "0.4.0" +serde = { version = "1.0.115", features = ["derive"], optional = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } [dev-dependencies] criterion = "0.3" serde_json = "1.0" bincode = "1.3.1" +[target.'cfg(target_arch = "wasm32")'.dev-dependencies] +wasm-bindgen-test = "0.3" + [[bench]] name = "distance" harness = false diff --git a/smartcore.svg b/smartcore.svg index f8ff7e9..3e4c68d 100644 --- a/smartcore.svg +++ b/smartcore.svg @@ -9,9 +9,9 @@ xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" inkscape:version="1.0 (4035a4f, 2020-05-01)" sodipodi:docname="smartcore.svg" - width="396.01309mm" - height="86.286003mm" - viewBox="0 0 396.0131 86.286004" + width="1280" + height="320" + viewBox="0 0 454 86.286004" version="1.1" id="svg512"> > { base: F, inv_log_base: F, @@ -56,16 +58,17 @@ impl> PartialEq for CoverTree { } } -#[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct Node { idx: usize, max_dist: F, parent_dist: F, children: Vec>, - scale: i64, + _scale: i64, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug)] struct DistanceSet { idx: usize, dist: Vec, @@ -82,7 +85,7 @@ impl> CoverTree max_dist: F::zero(), parent_dist: F::zero(), children: Vec::new(), - scale: 0, + _scale: 0, }; let mut tree = CoverTree { base, @@ -114,7 +117,7 @@ impl> CoverTree } let e = self.get_data_value(self.root.idx); - let mut d = self.distance.distance(&e, p); + let mut d = self.distance.distance(e, p); let mut current_cover_set: Vec<(F, &Node)> = Vec::new(); let mut zero_set: Vec<(F, &Node)> = Vec::new(); @@ -172,11 +175,14 @@ impl> CoverTree if ds.0 <= upper_bound { let v = self.get_data_value(ds.1.idx); if !self.identical_excluded || v != p { - neighbors.push((ds.1.idx, ds.0, &v)); + neighbors.push((ds.1.idx, ds.0, v)); } } } + if neighbors.len() > k { + neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + } Ok(neighbors.into_iter().take(k).collect()) } @@ -197,7 +203,7 @@ impl> CoverTree let mut zero_set: Vec<(F, &Node)> = Vec::new(); let e = self.get_data_value(self.root.idx); - let mut d = self.distance.distance(&e, p); + let mut d = self.distance.distance(e, p); current_cover_set.push((d, &self.root)); while !current_cover_set.is_empty() { @@ -227,7 +233,7 @@ impl> CoverTree for ds in zero_set { let v = self.get_data_value(ds.1.idx); if !self.identical_excluded || v != p { - neighbors.push((ds.1.idx, ds.0, &v)); + neighbors.push((ds.1.idx, ds.0, v)); } } @@ -240,7 +246,7 @@ impl> CoverTree max_dist: F::zero(), parent_dist: F::zero(), children: Vec::new(), - scale: 100, + _scale: 100, } } @@ -284,7 +290,7 @@ impl> CoverTree if point_set.is_empty() { self.new_leaf(p) } else { - let max_dist = self.max(&point_set); + let max_dist = self.max(point_set); let next_scale = (max_scale - 1).min(self.get_scale(max_dist)); if next_scale == std::i64::MIN { let mut children: Vec> = Vec::new(); @@ -301,7 +307,7 @@ impl> CoverTree max_dist: F::zero(), parent_dist: F::zero(), children, - scale: 100, + _scale: 100, } } else { let mut far: Vec> = Vec::new(); @@ -313,8 +319,7 @@ impl> CoverTree point_set.append(&mut far); child } else { - let mut children: Vec> = Vec::new(); - children.push(child); + let mut children: Vec> = vec![child]; let mut new_point_set: Vec> = Vec::new(); let mut new_consumed_set: Vec> = Vec::new(); @@ -371,7 +376,7 @@ impl> CoverTree max_dist: self.max(consumed_set), parent_dist: F::zero(), children, - scale: (top_scale - max_scale), + _scale: (top_scale - max_scale), } } } @@ -454,7 +459,8 @@ mod tests { use super::*; use crate::math::distance::Distances; - #[derive(Debug, Serialize, Deserialize, Clone)] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] + #[derive(Debug, Clone)] struct SimpleDistance {} impl Distance for SimpleDistance { @@ -463,6 +469,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cover_tree_test() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; @@ -479,7 +486,7 @@ mod tests { let knn: Vec = knn.iter().map(|v| *v.2).collect(); assert_eq!(vec!(3, 4, 5, 6, 7), knn); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cover_tree_test1() { let data = vec![ @@ -498,8 +505,9 @@ mod tests { assert_eq!(vec!(0, 1, 2), knn); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; diff --git a/src/algorithm/neighbour/linear_search.rs b/src/algorithm/neighbour/linear_search.rs index 45fbd6f..e2a1b6d 100644 --- a/src/algorithm/neighbour/linear_search.rs +++ b/src/algorithm/neighbour/linear_search.rs @@ -22,6 +22,7 @@ //! //! ``` +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use std::cmp::{Ordering, PartialOrd}; use std::marker::PhantomData; @@ -32,7 +33,8 @@ use crate::math::distance::Distance; use crate::math::num::RealNumber; /// Implements Linear Search algorithm, see [KNN algorithms](../index.html) -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct LinearKNNSearch> { distance: D, data: Vec, @@ -72,7 +74,7 @@ impl> LinearKNNSearch { } for i in 0..self.data.len() { - let d = self.distance.distance(&from, &self.data[i]); + let d = self.distance.distance(from, &self.data[i]); let datum = heap.peek_mut(); if d < datum.distance { datum.distance = d; @@ -102,7 +104,7 @@ impl> LinearKNNSearch { let mut neighbors: Vec<(usize, F, &T)> = Vec::new(); for i in 0..self.data.len() { - let d = self.distance.distance(&from, &self.data[i]); + let d = self.distance.distance(from, &self.data[i]); if d <= radius { neighbors.push((i, d, &self.data[i])); @@ -138,7 +140,8 @@ mod tests { use super::*; use crate::math::distance::Distances; - #[derive(Debug, Serialize, Deserialize, Clone)] + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] + #[derive(Debug, Clone)] struct SimpleDistance {} impl Distance for SimpleDistance { @@ -147,6 +150,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_find() { let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; @@ -193,7 +197,7 @@ mod tests { assert_eq!(vec!(1, 2, 3), found_idxs2); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_point_eq() { let point1 = KNNPoint { diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index bf9e669..321ec01 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -35,6 +35,7 @@ use crate::algorithm::neighbour::linear_search::LinearKNNSearch; use crate::error::Failed; use crate::math::distance::Distance; use crate::math::num::RealNumber; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; pub(crate) mod bbd_tree; @@ -45,7 +46,8 @@ pub mod linear_search; /// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries. /// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html) -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub enum KNNAlgorithmName { /// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html) LinearSearch, @@ -53,7 +55,8 @@ pub enum KNNAlgorithmName { CoverTree, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub(crate) enum KNNAlgorithm, T>> { LinearSearch(LinearKNNSearch, T, D>), CoverTree(CoverTree, T, D>), diff --git a/src/algorithm/sort/heap_select.rs b/src/algorithm/sort/heap_select.rs index a44b2bb..beb698f 100644 --- a/src/algorithm/sort/heap_select.rs +++ b/src/algorithm/sort/heap_select.rs @@ -53,8 +53,7 @@ impl<'a, T: PartialOrd + Debug> HeapSelection { if self.sorted { &self.heap[0] } else { - &self - .heap + self.heap .iter() .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap() @@ -96,12 +95,14 @@ impl<'a, T: PartialOrd + Debug> HeapSelection { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn with_capacity() { let heap = HeapSelection::::with_capacity(3); assert_eq!(3, heap.k); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_add() { let mut heap = HeapSelection::with_capacity(3); @@ -119,6 +120,7 @@ mod tests { assert_eq!(vec![2, 0, -5], heap.get()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_add1() { let mut heap = HeapSelection::with_capacity(3); @@ -133,6 +135,7 @@ mod tests { assert_eq!(vec![0f64, -1f64, -5f64], heap.get()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_add2() { let mut heap = HeapSelection::with_capacity(3); @@ -145,6 +148,7 @@ mod tests { assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_add_ordered() { let mut heap = HeapSelection::with_capacity(3); diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs index e160ed2..ddf2503 100644 --- a/src/algorithm/sort/quick_sort.rs +++ b/src/algorithm/sort/quick_sort.rs @@ -113,6 +113,7 @@ impl QuickArgSort for Vec { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn with_capacity() { let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index c793039..7f2baef 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -43,6 +43,7 @@ use std::fmt::Debug; use std::iter::Sum; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; @@ -55,7 +56,8 @@ use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::which_max; /// DBSCAN clustering algorithm -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct DBSCAN, T>> { cluster_labels: Vec, num_classes: usize, @@ -153,11 +155,11 @@ impl, T>> DBSCAN { parameters: DBSCANParameters, ) -> Result, Failed> { if parameters.min_samples < 1 { - return Err(Failed::fit(&"Invalid minPts".to_string())); + return Err(Failed::fit("Invalid minPts")); } if parameters.eps <= T::zero() { - return Err(Failed::fit(&"Invalid radius: ".to_string())); + return Err(Failed::fit("Invalid radius: ")); } let mut k = 0; @@ -263,8 +265,10 @@ impl, T>> DBSCAN { mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg(feature = "serde")] use crate::math::distance::euclidian::Euclidian; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_dbscan() { let x = DenseMatrix::from_2d_array(&[ @@ -296,7 +300,9 @@ mod tests { assert_eq!(expected_labels, predicted_labels); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 44ce1e6..05af680 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -56,6 +56,7 @@ use rand::Rng; use std::fmt::Debug; use std::iter::Sum; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::bbd_tree::BBDTree; @@ -66,12 +67,13 @@ use crate::math::distance::euclidian::*; use crate::math::num::RealNumber; /// K-Means clustering algorithm -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct KMeans { k: usize, - y: Vec, + _y: Vec, size: Vec, - distortion: T, + _distortion: T, centroids: Vec>, } @@ -206,9 +208,9 @@ impl KMeans { Ok(KMeans { k: parameters.k, - y, + _y: y, size, - distortion, + _distortion: distortion, centroids, }) } @@ -243,7 +245,7 @@ impl KMeans { let mut rng = rand::thread_rng(); let (n, m) = data.shape(); let mut y = vec![0; n]; - let mut centroid = data.get_row_as_vec(rng.gen_range(0, n)); + let mut centroid = data.get_row_as_vec(rng.gen_range(0..n)); let mut d = vec![T::max_value(); n]; @@ -297,6 +299,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn invalid_k() { let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); @@ -310,6 +313,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -340,11 +344,13 @@ mod tests { let y = kmeans.predict(&x).unwrap(); for i in 0..y.len() { - assert_eq!(y[i] as usize, kmeans.y[i]); + assert_eq!(y[i] as usize, kmeans._y[i]); } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], diff --git a/src/dataset/boston.rs b/src/dataset/boston.rs index 33f7700..1e4ee12 100644 --- a/src/dataset/boston.rs +++ b/src/dataset/boston.rs @@ -56,9 +56,11 @@ pub fn load_dataset() -> Dataset { #[cfg(test)] mod tests { + #[cfg(not(target_arch = "wasm32"))] use super::super::*; use super::*; + #[cfg(not(target_arch = "wasm32"))] #[test] #[ignore] fn refresh_boston_dataset() { @@ -67,6 +69,7 @@ mod tests { assert!(serialize_data(&dataset, "boston.xy").is_ok()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn boston_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/breast_cancer.rs b/src/dataset/breast_cancer.rs index e469794..0e13be1 100644 --- a/src/dataset/breast_cancer.rs +++ b/src/dataset/breast_cancer.rs @@ -66,17 +66,20 @@ pub fn load_dataset() -> Dataset { #[cfg(test)] mod tests { + #[cfg(not(target_arch = "wasm32"))] use super::super::*; use super::*; #[test] #[ignore] + #[cfg(not(target_arch = "wasm32"))] fn refresh_cancer_dataset() { // run this test to generate breast_cancer.xy file. let dataset = load_dataset(); assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cancer_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/diabetes.rs b/src/dataset/diabetes.rs index 2a3e20c..cbee636 100644 --- a/src/dataset/diabetes.rs +++ b/src/dataset/diabetes.rs @@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset { #[cfg(test)] mod tests { + #[cfg(not(target_arch = "wasm32"))] use super::super::*; use super::*; + #[cfg(not(target_arch = "wasm32"))] #[test] #[ignore] fn refresh_diabetes_dataset() { @@ -61,6 +63,7 @@ mod tests { assert!(serialize_data(&dataset, "diabetes.xy").is_ok()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn boston_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/digits.rs b/src/dataset/digits.rs index fd643d5..9120e59 100644 --- a/src/dataset/digits.rs +++ b/src/dataset/digits.rs @@ -45,9 +45,11 @@ pub fn load_dataset() -> Dataset { #[cfg(test)] mod tests { + #[cfg(not(target_arch = "wasm32"))] use super::super::*; use super::*; + #[cfg(not(target_arch = "wasm32"))] #[test] #[ignore] fn refresh_digits_dataset() { @@ -55,7 +57,7 @@ mod tests { let dataset = load_dataset(); assert!(serialize_data(&dataset, "digits.xy").is_ok()); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn digits_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index 28a2224..a73f546 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -88,6 +88,43 @@ pub fn make_circles(num_samples: usize, factor: f32, noise: f32) -> Dataset Dataset { + let num_samples_out = num_samples / 2; + let num_samples_in = num_samples - num_samples_out; + + let linspace_out = linspace(0.0, std::f32::consts::PI, num_samples_out); + let linspace_in = linspace(0.0, std::f32::consts::PI, num_samples_in); + + let noise = Normal::new(0.0, noise).unwrap(); + let mut rng = rand::thread_rng(); + + let mut x: Vec = Vec::with_capacity(num_samples * 2); + let mut y: Vec = Vec::with_capacity(num_samples); + + for v in linspace_out { + x.push(v.cos() + noise.sample(&mut rng)); + x.push(v.sin() + noise.sample(&mut rng)); + y.push(0.0); + } + + for v in linspace_in { + x.push(1.0 - v.cos() + noise.sample(&mut rng)); + x.push(1.0 - v.sin() + noise.sample(&mut rng) - 0.5); + y.push(1.0); + } + + Dataset { + data: x, + target: y, + num_samples, + num_features: 2, + feature_names: (0..2).map(|n| n.to_string()).collect(), + target_names: vec!["label".to_string()], + description: "Two interleaving half circles in 2d".to_string(), + } +} + fn linspace(start: f32, stop: f32, num: usize) -> Vec { let div = num as f32; let delta = stop - start; @@ -100,6 +137,7 @@ mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_make_blobs() { let dataset = make_blobs(10, 2, 3); @@ -112,6 +150,7 @@ mod tests { assert_eq!(dataset.num_samples, 10); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_make_circles() { let dataset = make_circles(10, 0.5, 0.05); @@ -123,4 +162,17 @@ mod tests { assert_eq!(dataset.num_features, 2); assert_eq!(dataset.num_samples, 10); } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn test_make_moons() { + let dataset = make_moons(10, 0.05); + assert_eq!( + dataset.data.len(), + dataset.num_features * dataset.num_samples + ); + assert_eq!(dataset.target.len(), dataset.num_samples); + assert_eq!(dataset.num_features, 2); + assert_eq!(dataset.num_samples, 10); + } } diff --git a/src/dataset/iris.rs b/src/dataset/iris.rs index 3c92428..888d3e8 100644 --- a/src/dataset/iris.rs +++ b/src/dataset/iris.rs @@ -50,9 +50,11 @@ pub fn load_dataset() -> Dataset { #[cfg(test)] mod tests { + #[cfg(not(target_arch = "wasm32"))] use super::super::*; use super::*; + #[cfg(not(target_arch = "wasm32"))] #[test] #[ignore] fn refresh_iris_dataset() { @@ -61,6 +63,7 @@ mod tests { assert!(serialize_data(&dataset, "iris.xy").is_ok()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn iris_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index da790b4..acd7641 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -8,9 +8,12 @@ pub mod digits; pub mod generator; pub mod iris; +#[cfg(not(target_arch = "wasm32"))] use crate::math::num::RealNumber; +#[cfg(not(target_arch = "wasm32"))] use std::fs::File; use std::io; +#[cfg(not(target_arch = "wasm32"))] use std::io::prelude::*; /// Dataset @@ -49,6 +52,8 @@ impl Dataset { } } +// Running this in wasm throws: operation not supported on this platform. +#[cfg(not(target_arch = "wasm32"))] #[allow(dead_code)] pub(crate) fn serialize_data( dataset: &Dataset, @@ -62,14 +67,14 @@ pub(crate) fn serialize_data( .data .iter() .copied() - .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter()) + .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec()) .collect(); file.write_all(&x)?; let y: Vec = dataset .target .iter() .copied() - .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec().into_iter()) + .flat_map(|f| f.to_f32_bits().to_le_bytes().to_vec()) .collect(); file.write_all(&y)?; } @@ -82,11 +87,12 @@ pub(crate) fn deserialize_data( bytes: &[u8], ) -> Result<(Vec, Vec, usize, usize), io::Error> { // read the same file back into a Vec of bytes + const USIZE_SIZE: usize = std::mem::size_of::(); let (num_samples, num_features) = { - let mut buffer = [0u8; 8]; - buffer.copy_from_slice(&bytes[0..8]); + let mut buffer = [0u8; USIZE_SIZE]; + buffer.copy_from_slice(&bytes[0..USIZE_SIZE]); let num_features = usize::from_le_bytes(buffer); - buffer.copy_from_slice(&bytes[8..16]); + buffer.copy_from_slice(&bytes[8..8 + USIZE_SIZE]); let num_samples = usize::from_le_bytes(buffer); (num_samples, num_features) }; @@ -115,6 +121,7 @@ pub(crate) fn deserialize_data( mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn as_matrix() { let dataset = Dataset { diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs index 189e6de..9aebae2 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -47,6 +47,7 @@ //! use std::fmt::Debug; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Transformer, UnsupervisedEstimator}; @@ -55,7 +56,8 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; /// Principal components analysis algorithm -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct PCA> { eigenvectors: M, eigenvalues: Vec, @@ -323,7 +325,7 @@ mod tests { &[6.8, 161.0, 60.0, 15.6], ]) } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn pca_components() { let us_arrests = us_arrests_data(); @@ -339,7 +341,7 @@ mod tests { assert!(expected.approximate_eq(&pca.components().abs(), 0.4)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_covariance() { let us_arrests = us_arrests_data(); @@ -449,6 +451,7 @@ mod tests { .approximate_eq(&expected_projection.abs(), 1e-4)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_correlation() { let us_arrests = us_arrests_data(); @@ -564,7 +567,9 @@ mod tests { .approximate_eq(&expected_projection.abs(), 1e-4)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let iris = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index 595e93c..3807760 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -46,6 +46,7 @@ use std::fmt::Debug; use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Transformer, UnsupervisedEstimator}; @@ -54,7 +55,8 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; /// SVD -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct SVD> { components: M, phantom: PhantomData, @@ -151,6 +153,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svd_decompose() { // https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html @@ -225,7 +228,9 @@ mod tests { .approximate_eq(&expected, 1e-4)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let iris = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 49c4239..247b502 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -45,14 +45,16 @@ //! //! //! +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::default::Default; use std::fmt::Debug; -use rand::Rng; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::{ @@ -61,7 +63,8 @@ use crate::tree::decision_tree_classifier::{ /// Parameters of the Random Forest algorithm. /// Some parameters here are passed directly into base estimator. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct RandomForestClassifierParameters { /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub criterion: SplitCriterion, @@ -75,14 +78,20 @@ pub struct RandomForestClassifierParameters { pub n_trees: u16, /// Number of random sample of predictors to use as split candidates. pub m: Option, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: bool, + /// Seed used for bootstrap sampling and feature selection for each tree. + pub seed: u64, } /// Random Forest Classifier -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct RandomForestClassifier { - parameters: RandomForestClassifierParameters, + _parameters: RandomForestClassifierParameters, trees: Vec>, classes: Vec, + samples: Option>>, } impl RandomForestClassifierParameters { @@ -116,6 +125,18 @@ impl RandomForestClassifierParameters { self.m = Some(m); self } + + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub fn with_keep_samples(mut self, keep_samples: bool) -> Self { + self.keep_samples = keep_samples; + self + } + + /// Seed used for bootstrap sampling and feature selection for each tree. + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } } impl PartialEq for RandomForestClassifier { @@ -147,6 +168,8 @@ impl Default for RandomForestClassifierParameters { min_samples_split: 2, n_trees: 100, m: Option::None, + keep_samples: false, + seed: 0, } } } @@ -198,26 +221,38 @@ impl RandomForestClassifier { .unwrap() }); + let mut rng = StdRng::seed_from_u64(parameters.seed); let classes = y_m.unique(); let k = classes.len(); let mut trees: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { - let samples = RandomForestClassifier::::sample_with_replacement(&yi, k); + let samples = RandomForestClassifier::::sample_with_replacement(&yi, k, &mut rng); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } + let params = DecisionTreeClassifierParameters { criterion: parameters.criterion.clone(), max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, }; - let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?; + let tree = + DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; trees.push(tree); } Ok(RandomForestClassifier { - parameters, + _parameters: parameters, trees, classes, + samples: maybe_all_samples, }) } @@ -245,8 +280,43 @@ impl RandomForestClassifier { which_max(&result) } - fn sample_with_replacement(y: &[usize], num_classes: usize) -> Vec { - let mut rng = rand::thread_rng(); + /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. + pub fn predict_oob>(&self, x: &M) -> Result { + let (n, _) = x.shape(); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = M::zeros(1, n); + + for i in 0..n { + result.set(0, i, self.classes[self.predict_for_row_oob(x, i)]); + } + + Ok(result.to_row_vector()) + } + } + + fn predict_for_row_oob>(&self, x: &M, row: usize) -> usize { + let mut result = vec![0; self.classes.len()]; + + for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { + if !samples[row] { + result[tree.predict_for_row(x, row)] += 1; + } + } + + which_max(&result) + } + + fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec { let class_weight = vec![1.; num_classes]; let nrows = y.len(); let mut samples = vec![0; nrows]; @@ -262,7 +332,7 @@ impl RandomForestClassifier { let size = ((n_samples as f64) / *class_weight_l) as usize; for _ in 0..size { - let xi: usize = rng.gen_range(0, n_samples); + let xi: usize = rng.gen_range(0..n_samples); samples[index[xi]] += 1; } } @@ -276,6 +346,7 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -314,6 +385,8 @@ mod tests { min_samples_split: 2, n_trees: 100, m: Option::None, + keep_samples: false, + seed: 87, }, ) .unwrap(); @@ -321,7 +394,60 @@ mod tests { assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + fn fit_predict_iris_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![ + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, + m: Option::None, + keep_samples: true, + seed: 87, + }, + ) + .unwrap(); + + assert!( + accuracy(&y, &classifier.predict_oob(&x).unwrap()) + < accuracy(&y, &classifier.predict(&x).unwrap()) + ); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index fdeb9fc..08a7dcc 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -43,21 +43,24 @@ //! //! +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::default::Default; use std::fmt::Debug; -use rand::Rng; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; -use crate::error::Failed; +use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::tree::decision_tree_regressor::{ DecisionTreeRegressor, DecisionTreeRegressorParameters, }; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// Parameters of the Random Forest Regressor /// Some parameters here are passed directly into base estimator. pub struct RandomForestRegressorParameters { @@ -71,13 +74,19 @@ pub struct RandomForestRegressorParameters { pub n_trees: usize, /// Number of random sample of predictors to use as split candidates. pub m: Option, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: bool, + /// Seed used for bootstrap sampling and feature selection for each tree. + pub seed: u64, } /// Random Forest Regressor -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct RandomForestRegressor { - parameters: RandomForestRegressorParameters, + _parameters: RandomForestRegressorParameters, trees: Vec>, + samples: Option>>, } impl RandomForestRegressorParameters { @@ -106,8 +115,19 @@ impl RandomForestRegressorParameters { self.m = Some(m); self } -} + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub fn with_keep_samples(mut self, keep_samples: bool) -> Self { + self.keep_samples = keep_samples; + self + } + + /// Seed used for bootstrap sampling and feature selection for each tree. + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } +} impl Default for RandomForestRegressorParameters { fn default() -> Self { RandomForestRegressorParameters { @@ -116,6 +136,8 @@ impl Default for RandomForestRegressorParameters { min_samples_split: 2, n_trees: 10, m: Option::None, + keep_samples: false, + seed: 0, } } } @@ -169,20 +191,34 @@ impl RandomForestRegressor { .m .unwrap_or((num_attributes as f64).sqrt().floor() as usize); + let mut rng = StdRng::seed_from_u64(parameters.seed); let mut trees: Vec> = Vec::new(); + let mut maybe_all_samples: Option>> = Option::None; + if parameters.keep_samples { + maybe_all_samples = Some(Vec::new()); + } + for _ in 0..parameters.n_trees { - let samples = RandomForestRegressor::::sample_with_replacement(n_rows); + let samples = RandomForestRegressor::::sample_with_replacement(n_rows, &mut rng); + if let Some(ref mut all_samples) = maybe_all_samples { + all_samples.push(samples.iter().map(|x| *x != 0).collect()) + } let params = DecisionTreeRegressorParameters { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, }; - let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; + let tree = + DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; trees.push(tree); } - Ok(RandomForestRegressor { parameters, trees }) + Ok(RandomForestRegressor { + _parameters: parameters, + trees, + samples: maybe_all_samples, + }) } /// Predict class for `x` @@ -211,11 +247,49 @@ impl RandomForestRegressor { result / T::from(n_trees).unwrap() } - fn sample_with_replacement(nrows: usize) -> Vec { - let mut rng = rand::thread_rng(); + /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training. + pub fn predict_oob>(&self, x: &M) -> Result { + let (n, _) = x.shape(); + if self.samples.is_none() { + Err(Failed::because( + FailedError::PredictFailed, + "Need samples=true for OOB predictions.", + )) + } else if self.samples.as_ref().unwrap()[0].len() != n { + Err(Failed::because( + FailedError::PredictFailed, + "Prediction matrix must match matrix used in training for OOB predictions.", + )) + } else { + let mut result = M::zeros(1, n); + + for i in 0..n { + result.set(0, i, self.predict_for_row_oob(x, i)); + } + + Ok(result.to_row_vector()) + } + } + + fn predict_for_row_oob>(&self, x: &M, row: usize) -> T { + let mut n_trees = 0; + let mut result = T::zero(); + + for (tree, samples) in self.trees.iter().zip(self.samples.as_ref().unwrap()) { + if !samples[row] { + result += tree.predict_for_row(x, row); + n_trees += 1; + } + } + + // TODO: What to do if there are no oob trees? + result / T::from(n_trees).unwrap() + } + + fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec { let mut samples = vec![0; nrows]; for _ in 0..nrows { - let xi = rng.gen_range(0, nrows); + let xi = rng.gen_range(0..nrows); samples[xi] += 1; } samples @@ -228,6 +302,7 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::mean_absolute_error; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -262,6 +337,8 @@ mod tests { min_samples_split: 2, n_trees: 1000, m: Option::None, + keep_samples: false, + seed: 87, }, ) .and_then(|rf| rf.predict(&x)) @@ -270,7 +347,56 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < 1.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + fn fit_predict_longley_oob() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159., 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165., 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.27, 1952., 63.639], + &[365.385, 187., 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335., 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.18, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.95, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); + let y = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; + + let regressor = RandomForestRegressor::fit( + &x, + &y, + RandomForestRegressorParameters { + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 1000, + m: Option::None, + keep_samples: true, + seed: 87, + }, + ) + .unwrap(); + + let y_hat = regressor.predict(&x).unwrap(); + let y_hat_oob = regressor.predict_oob(&x).unwrap(); + + assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob)); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159., 107.608, 1947., 60.323], diff --git a/src/error/mod.rs b/src/error/mod.rs index 2409889..4e84f6e 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -2,10 +2,12 @@ use std::error::Error; use std::fmt; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// Generic error to be raised when something goes wrong. -#[derive(Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct Failed { err: FailedError, msg: String, @@ -13,7 +15,8 @@ pub struct Failed { /// Type of error #[non_exhaustive] -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Copy, Clone, Debug)] pub enum FailedError { /// Can't fit algorithm to data FitFailed = 1, diff --git a/src/lib.rs b/src/lib.rs index d962894..2edada4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,12 @@ #![allow( clippy::type_complexity, clippy::too_many_arguments, - clippy::many_single_char_names + clippy::many_single_char_names, + clippy::unnecessary_wraps, + clippy::upper_case_acronyms )] #![warn(missing_docs)] -#![warn(missing_doc_code_examples)] +#![warn(rustdoc::missing_doc_code_examples)] //! # SmartCore //! @@ -28,7 +30,7 @@ //! //! All machine learning algorithms in SmartCore are grouped into these broad categories: //! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data. -//! * [Martix Decomposition](decomposition/index.html), various methods for matrix decomposition. +//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition. //! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables //! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models //! * [Tree-based Models](tree/index.html), classification and regression trees @@ -91,6 +93,8 @@ pub mod naive_bayes; /// Supervised neighbors-based learning methods pub mod neighbors; pub(crate) mod optimization; +/// Preprocessing utilities +pub mod preprocessing; /// Support Vector Machines pub mod svm; /// Supervised tree-based learning methods diff --git a/src/linalg/cholesky.rs b/src/linalg/cholesky.rs index 724dc8a..9b5b9cc 100644 --- a/src/linalg/cholesky.rs +++ b/src/linalg/cholesky.rs @@ -87,8 +87,7 @@ impl> Cholesky { if bn != rn { return Err(Failed::because( FailedError::SolutionFailed, - &"Can\'t solve Ax = b for x. Number of rows in b != number of rows in R." - .to_string(), + "Can\'t solve Ax = b for x. Number of rows in b != number of rows in R.", )); } @@ -128,7 +127,7 @@ pub trait CholeskyDecomposableMatrix: BaseMatrix { if m != n { return Err(Failed::because( FailedError::DecompositionFailed, - &"Can\'t do Cholesky decomposition on a non-square matrix".to_string(), + "Can\'t do Cholesky decomposition on a non-square matrix", )); } @@ -148,7 +147,7 @@ pub trait CholeskyDecomposableMatrix: BaseMatrix { if d < T::zero() { return Err(Failed::because( FailedError::DecompositionFailed, - &"The matrix is not positive definite.".to_string(), + "The matrix is not positive definite.", )); } @@ -168,7 +167,7 @@ pub trait CholeskyDecomposableMatrix: BaseMatrix { mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cholesky_decompose() { let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); @@ -187,6 +186,7 @@ mod tests { .approximate_eq(&a.abs(), 1e-4)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cholesky_solve_mut() { let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); diff --git a/src/linalg/evd.rs b/src/linalg/evd.rs index 4c1b6c3..bf195a0 100644 --- a/src/linalg/evd.rs +++ b/src/linalg/evd.rs @@ -93,11 +93,11 @@ pub trait EVDDecomposableMatrix: BaseMatrix { sort(&mut d, &mut e, &mut V); } - Ok(EVD { V, d, e }) + Ok(EVD { d, e, V }) } } -fn tred2>(V: &mut M, d: &mut Vec, e: &mut Vec) { +fn tred2>(V: &mut M, d: &mut [T], e: &mut [T]) { let (n, _) = V.shape(); for (i, d_i) in d.iter_mut().enumerate().take(n) { *d_i = V.get(n - 1, i); @@ -195,7 +195,7 @@ fn tred2>(V: &mut M, d: &mut Vec, e: &mut Vec e[0] = T::zero(); } -fn tql2>(V: &mut M, d: &mut Vec, e: &mut Vec) { +fn tql2>(V: &mut M, d: &mut [T], e: &mut [T]) { let (n, _) = V.shape(); for i in 1..n { e[i - 1] = e[i]; @@ -419,7 +419,7 @@ fn eltran>(A: &M, V: &mut M, perm: &[usize]) { } } -fn hqr2>(A: &mut M, V: &mut M, d: &mut Vec, e: &mut Vec) { +fn hqr2>(A: &mut M, V: &mut M, d: &mut [T], e: &mut [T]) { let (n, _) = A.shape(); let mut z = T::zero(); let mut s = T::zero(); @@ -471,7 +471,7 @@ fn hqr2>(A: &mut M, V: &mut M, d: &mut Vec, e A.set(nn, nn, x); A.set(nn - 1, nn - 1, y + t); if q >= T::zero() { - z = p + z.copysign(p); + z = p + RealNumber::copysign(z, p); d[nn - 1] = x + z; d[nn] = x + z; if z != T::zero() { @@ -570,7 +570,7 @@ fn hqr2>(A: &mut M, V: &mut M, d: &mut Vec, e r /= x; } } - let s = (p * p + q * q + r * r).sqrt().copysign(p); + let s = RealNumber::copysign((p * p + q * q + r * r).sqrt(), p); if s != T::zero() { if k == m { if l != m { @@ -594,12 +594,7 @@ fn hqr2>(A: &mut M, V: &mut M, d: &mut Vec, e A.sub_element_mut(k + 1, j, p * y); A.sub_element_mut(k, j, p * x); } - let mmin; - if nn < k + 3 { - mmin = nn; - } else { - mmin = k + 3; - } + let mmin = if nn < k + 3 { nn } else { k + 3 }; for i in 0..mmin + 1 { p = x * A.get(i, k) + y * A.get(i, k + 1); if k + 1 != nn { @@ -783,7 +778,7 @@ fn balbak>(V: &mut M, scale: &[T]) { } } -fn sort>(d: &mut Vec, e: &mut Vec, V: &mut M) { +fn sort>(d: &mut [T], e: &mut [T], V: &mut M) { let n = d.len(); let mut temp = vec![T::zero(); n]; for j in 1..n { @@ -816,7 +811,7 @@ fn sort>(d: &mut Vec, e: &mut Vec, V: &mut mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_symmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -843,7 +838,7 @@ mod tests { assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_asymmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -870,7 +865,7 @@ mod tests { assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_complex() { let A = DenseMatrix::from_2d_array(&[ diff --git a/src/linalg/lu.rs b/src/linalg/lu.rs index 6daed69..cb001af 100644 --- a/src/linalg/lu.rs +++ b/src/linalg/lu.rs @@ -46,13 +46,13 @@ use crate::math::num::RealNumber; pub struct LU> { LU: M, pivot: Vec, - pivot_sign: i8, + _pivot_sign: i8, singular: bool, phantom: PhantomData, } impl> LU { - pub(crate) fn new(LU: M, pivot: Vec, pivot_sign: i8) -> LU { + pub(crate) fn new(LU: M, pivot: Vec, _pivot_sign: i8) -> LU { let (_, n) = LU.shape(); let mut singular = false; @@ -66,7 +66,7 @@ impl> LU { LU { LU, pivot, - pivot_sign, + _pivot_sign, singular, phantom: PhantomData, } @@ -260,6 +260,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); @@ -274,7 +275,7 @@ mod tests { assert!(lu.U().approximate_eq(&expected_U, 1e-4)); assert!(lu.pivot().approximate_eq(&expected_pivot, 1e-4)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn inverse() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 264815b..59b6089 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -1,3 +1,4 @@ +#![allow(clippy::wrong_self_convention)] //! # Linear Algebra and Matrix Decomposition //! //! Most machine learning algorithms in SmartCore depend on linear algebra and matrix decomposition methods from this module. @@ -265,7 +266,7 @@ pub trait BaseVector: Clone + Debug { sum += xi * xi; } mu /= div; - sum / div - mu * mu + sum / div - mu.powi(2) } /// Computes the standard deviation. fn std(&self) -> T { @@ -688,12 +689,11 @@ impl<'a, T: RealNumber, M: BaseMatrix> Iterator for RowIter<'a, T, M> { type Item = Vec; fn next(&mut self) -> Option> { - let res; - if self.pos < self.max_pos { - res = Some(self.m.get_row_as_vec(self.pos)) + let res = if self.pos < self.max_pos { + Some(self.m.get_row_as_vec(self.pos)) } else { - res = None - } + None + }; self.pos += 1; res } @@ -705,6 +705,7 @@ mod tests { use crate::linalg::BaseMatrix; use crate::linalg::BaseVector; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mean() { let m = vec![1., 2., 3.]; @@ -712,6 +713,7 @@ mod tests { assert_eq!(m.mean(), 2.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn std() { let m = vec![1., 2., 3.]; @@ -719,6 +721,7 @@ mod tests { assert!((m.std() - 0.81f64).abs() < 1e-2); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn var() { let m = vec![1., 2., 3., 4.]; @@ -726,6 +729,7 @@ mod tests { assert!((m.var() - 1.25f64).abs() < std::f64::EPSILON); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_take() { let m = vec![1., 2., 3., 4., 5.]; @@ -733,6 +737,7 @@ mod tests { assert_eq!(m.take(&vec!(0, 0, 4, 4)), vec![1., 1., 5., 5.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn take() { let m = DenseMatrix::from_2d_array(&[ diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs index a0b7bdb..1af926c 100644 --- a/src/linalg/naive/dense_matrix.rs +++ b/src/linalg/naive/dense_matrix.rs @@ -1,11 +1,15 @@ #![allow(clippy::ptr_arg)] use std::fmt; use std::fmt::Debug; +#[cfg(feature = "serde")] use std::marker::PhantomData; use std::ops::Range; +#[cfg(feature = "serde")] use serde::de::{Deserializer, MapAccess, SeqAccess, Visitor}; +#[cfg(feature = "serde")] use serde::ser::{SerializeStruct, Serializer}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::cholesky::CholeskyDecomposableMatrix; @@ -326,7 +330,7 @@ impl DenseMatrix { cur_r: 0, max_c: self.ncols, max_r: self.nrows, - m: &self, + m: self, } } } @@ -349,6 +353,7 @@ impl<'a, T: RealNumber> Iterator for DenseMatrixIterator<'a, T> { } } +#[cfg(feature = "serde")] impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for DenseMatrix { fn deserialize(deserializer: D) -> Result where @@ -434,6 +439,7 @@ impl<'de, T: RealNumber + fmt::Debug + Deserialize<'de>> Deserialize<'de> for De } } +#[cfg(feature = "serde")] impl Serialize for DenseMatrix { fn serialize(&self, serializer: S) -> Result where @@ -517,10 +523,9 @@ impl PartialEq for DenseMatrix { true } } - -impl Into> for DenseMatrix { - fn into(self) -> Vec { - self.values +impl From> for Vec { + fn from(dense_matrix: DenseMatrix) -> Vec { + dense_matrix.values } } @@ -1054,14 +1059,14 @@ impl BaseMatrix for DenseMatrix { #[cfg(test)] mod tests { use super::*; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_dot() { let v1 = vec![1., 2., 3.]; let v2 = vec![4., 5., 6.]; assert_eq!(32.0, BaseVector::dot(&v1, &v2)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_copy_from() { let mut v1 = vec![1., 2., 3.]; @@ -1069,7 +1074,7 @@ mod tests { v1.copy_from(&v2); assert_eq!(v1, v2); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_approximate_eq() { let a = vec![1., 2., 3.]; @@ -1077,7 +1082,7 @@ mod tests { assert!(a.approximate_eq(&b, 1e-4)); assert!(!a.approximate_eq(&b, 1e-5)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn from_array() { let vec = [1., 2., 3., 4., 5., 6.]; @@ -1090,7 +1095,7 @@ mod tests { DenseMatrix::new(2, 3, vec![1., 4., 2., 5., 3., 6.]) ); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn row_column_vec_from_array() { let vec = vec![1., 2., 3., 4., 5., 6.]; @@ -1103,7 +1108,7 @@ mod tests { DenseMatrix::new(6, 1, vec![1., 2., 3., 4., 5., 6.]) ); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn from_to_row_vec() { let vec = vec![1., 2., 3.]; @@ -1116,20 +1121,20 @@ mod tests { vec![1., 2., 3.] ); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_matrix_to_row_vector() { let m: DenseMatrix = BaseMatrix::zeros(10, 1); assert_eq!(m.to_row_vector().len(), 10) } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn iter() { let vec = vec![1., 2., 3., 4., 5., 6.]; let m = DenseMatrix::from_array(3, 2, &vec); assert_eq!(vec, m.iter().collect::>()); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn v_stack() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); @@ -1144,7 +1149,7 @@ mod tests { let result = a.v_stack(&b); assert_eq!(result, expected); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn h_stack() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); @@ -1157,13 +1162,13 @@ mod tests { let result = a.h_stack(&b); assert_eq!(result, expected); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_row() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); assert_eq!(vec![4., 5., 6.], a.get_row(1)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn matmul() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); @@ -1172,7 +1177,7 @@ mod tests { let result = a.matmul(&b); assert_eq!(result, expected); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ab() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); @@ -1195,14 +1200,14 @@ mod tests { DenseMatrix::from_2d_array(&[&[29., 39., 49.], &[40., 54., 68.,], &[51., 69., 87.]]) ); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn dot() { let a = DenseMatrix::from_array(1, 3, &[1., 2., 3.]); let b = DenseMatrix::from_array(1, 3, &[4., 5., 6.]); assert_eq!(a.dot(&b), 32.); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn copy_from() { let mut a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]); @@ -1210,7 +1215,7 @@ mod tests { a.copy_from(&b); assert_eq!(a, b); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn slice() { let m = DenseMatrix::from_2d_array(&[ @@ -1222,7 +1227,7 @@ mod tests { let result = m.slice(0..2, 1..3); assert_eq!(result, expected); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn approximate_eq() { let m = DenseMatrix::from_2d_array(&[&[2., 3.], &[5., 6.]]); @@ -1231,7 +1236,7 @@ mod tests { assert!(m.approximate_eq(&m_eq, 0.5)); assert!(!m.approximate_eq(&m_neq, 0.5)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn rand() { let m: DenseMatrix = DenseMatrix::rand(3, 3); @@ -1241,7 +1246,7 @@ mod tests { } } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn transpose() { let m = DenseMatrix::from_2d_array(&[&[1.0, 3.0], &[2.0, 4.0]]); @@ -1253,7 +1258,7 @@ mod tests { } } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn reshape() { let m_orig = DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6.]); @@ -1264,7 +1269,7 @@ mod tests { assert_eq!(m_result.get(0, 1), 2.); assert_eq!(m_result.get(0, 3), 4.); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn norm() { let v = DenseMatrix::row_vector_from_array(&[3., -2., 6.]); @@ -1273,7 +1278,7 @@ mod tests { assert_eq!(v.norm(std::f64::INFINITY), 6.); assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn softmax_mut() { let mut prob: DenseMatrix = DenseMatrix::row_vector_from_array(&[1., 2., 3.]); @@ -1282,14 +1287,14 @@ mod tests { assert!((prob.get(0, 1) - 0.24).abs() < 0.01); assert!((prob.get(0, 2) - 0.66).abs() < 0.01); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_mean() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]); let res = a.column_mean(); assert_eq!(res, vec![4., 5., 6.]); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn min_max_sum() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); @@ -1297,30 +1302,32 @@ mod tests { assert_eq!(1., a.min()); assert_eq!(6., a.max()); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn eye() { let a = DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0., 0., 1.]]); let res = DenseMatrix::eye(3); assert_eq!(res, a); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn to_from_json() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let deserialized_a: DenseMatrix = serde_json::from_str(&serde_json::to_string(&a).unwrap()).unwrap(); assert_eq!(a, deserialized_a); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn to_from_bincode() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); let deserialized_a: DenseMatrix = bincode::deserialize(&bincode::serialize(&a).unwrap()).unwrap(); assert_eq!(a, deserialized_a); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn to_string() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); @@ -1329,7 +1336,7 @@ mod tests { "[[0.9, 0.4, 0.7], [0.4, 0.5, 0.3], [0.7, 0.3, 0.8]]" ); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn cov() { let a = DenseMatrix::from_2d_array(&[ diff --git a/src/linalg/nalgebra_bindings.rs b/src/linalg/nalgebra_bindings.rs index b976fbd..249f21f 100644 --- a/src/linalg/nalgebra_bindings.rs +++ b/src/linalg/nalgebra_bindings.rs @@ -579,6 +579,7 @@ mod tests { use crate::linear::linear_regression::*; use nalgebra::{DMatrix, Matrix2x3, RowDVector}; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_copy_from() { let mut v1 = RowDVector::from_vec(vec![1., 2., 3.]); @@ -589,12 +590,14 @@ mod tests { assert_ne!(v2, v1); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_len() { let v = RowDVector::from_vec(vec![1., 2., 3.]); assert_eq!(3, v.len()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_set_vector() { let mut v = RowDVector::from_vec(vec![1., 2., 3., 4.]); @@ -607,12 +610,14 @@ mod tests { assert_eq!(5., BaseVector::get(&v, 1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_to_vec() { let v = RowDVector::from_vec(vec![1., 2., 3.]); assert_eq!(vec![1., 2., 3.], v.to_vec()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_init() { let zeros: RowDVector = BaseVector::zeros(3); @@ -623,6 +628,7 @@ mod tests { assert_eq!(twos, RowDVector::from_vec(vec![2., 2., 2.])); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_dot() { let v1 = RowDVector::from_vec(vec![1., 2., 3.]); @@ -630,6 +636,7 @@ mod tests { assert_eq!(32.0, BaseVector::dot(&v1, &v2)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_approximate_eq() { let a = RowDVector::from_vec(vec![1., 2., 3.]); @@ -638,6 +645,7 @@ mod tests { assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_set_dynamic() { let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); @@ -650,6 +658,7 @@ mod tests { assert_eq!(10., BaseMatrix::get(&m, 1, 1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn zeros() { let expected = DMatrix::from_row_slice(2, 2, &[0., 0., 0., 0.]); @@ -659,6 +668,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ones() { let expected = DMatrix::from_row_slice(2, 2, &[1., 1., 1., 1.]); @@ -668,6 +678,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn eye() { let expected = DMatrix::from_row_slice(3, 3, &[1., 0., 0., 0., 1., 0., 0., 0., 1.]); @@ -675,6 +686,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn shape() { let m: DMatrix = BaseMatrix::zeros(5, 10); @@ -684,6 +696,7 @@ mod tests { assert_eq!(ncols, 10); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn scalar_add_sub_mul_div() { let mut m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); @@ -697,6 +710,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn add_sub_mul_div() { let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]); @@ -715,6 +729,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn to_from_row_vector() { let v = RowDVector::from_vec(vec![1., 2., 3., 4.]); @@ -723,12 +738,14 @@ mod tests { assert_eq!(m.to_row_vector(), expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_matrix_to_row_vector() { let m: DMatrix = BaseMatrix::zeros(10, 1); assert_eq!(m.to_row_vector().len(), 10) } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_row_col_as_vec() { let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); @@ -737,12 +754,14 @@ mod tests { assert_eq!(m.get_col_as_vec(1), vec!(2., 5., 8.)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_row() { let a = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); assert_eq!(RowDVector::from_vec(vec![4., 5., 6.]), a.get_row(1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn copy_row_col_as_vec() { let m = DMatrix::from_row_slice(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); @@ -754,6 +773,7 @@ mod tests { assert_eq!(v, vec!(2., 5., 8.)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn element_add_sub_mul_div() { let mut m = DMatrix::from_row_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]); @@ -767,6 +787,7 @@ mod tests { assert_eq!(m, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vstack_hstack() { let m1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); @@ -782,6 +803,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn matmul() { let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); @@ -791,6 +813,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn dot() { let a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); @@ -798,6 +821,7 @@ mod tests { assert_eq!(14., a.dot(&b)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn slice() { let a = DMatrix::from_row_slice( @@ -810,6 +834,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn approximate_eq() { let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); @@ -822,6 +847,7 @@ mod tests { assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn negative_mut() { let mut v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); @@ -829,6 +855,7 @@ mod tests { assert_eq!(v, DMatrix::from_row_slice(1, 3, &[-3., 2., -6.])); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn transpose() { let m = DMatrix::from_row_slice(2, 2, &[1.0, 3.0, 2.0, 4.0]); @@ -837,6 +864,7 @@ mod tests { assert_eq!(m_transposed, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn rand() { let m: DMatrix = BaseMatrix::rand(3, 3); @@ -847,6 +875,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn norm() { let v = DMatrix::from_row_slice(1, 3, &[3., -2., 6.]); @@ -856,6 +885,7 @@ mod tests { assert_eq!(BaseMatrix::norm(&v, std::f64::NEG_INFINITY), 2.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_mean() { let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., 4., 5., 6., 7., 8., 9.]); @@ -863,6 +893,7 @@ mod tests { assert_eq!(res, vec![4., 5., 6.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn reshape() { let m_orig = DMatrix::from_row_slice(1, 6, &[1., 2., 3., 4., 5., 6.]); @@ -874,6 +905,7 @@ mod tests { assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn copy_from() { let mut src = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); @@ -882,6 +914,7 @@ mod tests { assert_eq!(src, dst); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn abs_mut() { let mut a = DMatrix::from_row_slice(2, 2, &[1., -2., 3., -4.]); @@ -890,6 +923,7 @@ mod tests { assert_eq!(a, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn min_max_sum() { let a = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., 5., 6.]); @@ -898,6 +932,7 @@ mod tests { assert_eq!(6., a.max()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn max_diff() { let a1 = DMatrix::from_row_slice(2, 3, &[1., 2., 3., 4., -5., 6.]); @@ -906,6 +941,7 @@ mod tests { assert_eq!(a2.max_diff(&a2), 0.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn softmax_mut() { let mut prob: DMatrix = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); @@ -915,6 +951,7 @@ mod tests { assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn pow_mut() { let mut a = DMatrix::from_row_slice(1, 3, &[1., 2., 3.]); @@ -922,6 +959,7 @@ mod tests { assert_eq!(a, DMatrix::from_row_slice(1, 3, &[1., 8., 27.])); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn argmax() { let a = DMatrix::from_row_slice(3, 3, &[1., 2., 3., -5., -6., -7., 0.1, 0.2, 0.1]); @@ -929,6 +967,7 @@ mod tests { assert_eq!(res, vec![2, 0, 1]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn unique() { let a = DMatrix::from_row_slice(3, 3, &[1., 2., 2., -2., -6., -7., 2., 3., 4.]); @@ -937,6 +976,7 @@ mod tests { assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ols_fit_predict() { let x = DMatrix::from_row_slice( diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs index 6ed40c8..99e0918 100644 --- a/src/linalg/ndarray_bindings.rs +++ b/src/linalg/ndarray_bindings.rs @@ -178,7 +178,7 @@ impl BaseVector for ArrayBase, Ix } fn copy_from(&mut self, other: &Self) { - self.assign(&other); + self.assign(other); } } @@ -385,7 +385,7 @@ impl &Self { @@ -530,6 +530,7 @@ mod tests { use crate::metrics::mean_absolute_error; use ndarray::{arr1, arr2, Array1, Array2}; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_get_set() { let mut result = arr1(&[1., 2., 3.]); @@ -541,6 +542,7 @@ mod tests { assert_eq!(5., BaseVector::get(&result, 1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_copy_from() { let mut v1 = arr1(&[1., 2., 3.]); @@ -551,18 +553,21 @@ mod tests { assert_ne!(v1, v2); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_len() { let v = arr1(&[1., 2., 3.]); assert_eq!(3, v.len()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_to_vec() { let v = arr1(&[1., 2., 3.]); assert_eq!(vec![1., 2., 3.], v.to_vec()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_dot() { let v1 = arr1(&[1., 2., 3.]); @@ -570,6 +575,7 @@ mod tests { assert_eq!(32.0, BaseVector::dot(&v1, &v2)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vec_approximate_eq() { let a = arr1(&[1., 2., 3.]); @@ -578,6 +584,7 @@ mod tests { assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn from_to_row_vec() { let vec = arr1(&[1., 2., 3.]); @@ -588,12 +595,14 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_matrix_to_row_vector() { let m: Array2 = BaseMatrix::zeros(10, 1); assert_eq!(m.to_row_vector().len(), 10) } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn add_mut() { let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -604,6 +613,7 @@ mod tests { assert_eq!(a1, a3); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn sub_mut() { let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -614,6 +624,7 @@ mod tests { assert_eq!(a1, a3); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mul_mut() { let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -624,6 +635,7 @@ mod tests { assert_eq!(a1, a3); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn div_mut() { let mut a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -634,6 +646,7 @@ mod tests { assert_eq!(a1, a3); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn div_element_mut() { let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -642,6 +655,7 @@ mod tests { assert_eq!(BaseMatrix::get(&a, 1, 1), 1.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mul_element_mut() { let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -650,6 +664,7 @@ mod tests { assert_eq!(BaseMatrix::get(&a, 1, 1), 25.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn add_element_mut() { let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -657,7 +672,7 @@ mod tests { assert_eq!(BaseMatrix::get(&a, 1, 1), 10.); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn sub_element_mut() { let mut a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -666,6 +681,7 @@ mod tests { assert_eq!(BaseMatrix::get(&a, 1, 1), 0.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn vstack_hstack() { let a1 = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -680,6 +696,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_set() { let mut result = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -691,6 +708,7 @@ mod tests { assert_eq!(10., BaseMatrix::get(&result, 1, 1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn matmul() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -700,6 +718,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn dot() { let a = arr2(&[[1., 2., 3.]]); @@ -707,6 +726,7 @@ mod tests { assert_eq!(14., BaseMatrix::dot(&a, &b)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn slice() { let a = arr2(&[ @@ -719,6 +739,7 @@ mod tests { assert_eq!(result, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn scalar_ops() { let a = arr2(&[[1., 2., 3.]]); @@ -728,6 +749,7 @@ mod tests { assert_eq!(&arr2(&[[0.5, 1., 1.5]]), a.clone().div_scalar_mut(2.)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn transpose() { let m = arr2(&[[1.0, 3.0], [2.0, 4.0]]); @@ -736,6 +758,7 @@ mod tests { assert_eq!(m_transposed, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn norm() { let v = arr2(&[[3., -2., 6.]]); @@ -745,6 +768,7 @@ mod tests { assert_eq!(v.norm(std::f64::NEG_INFINITY), 2.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn negative_mut() { let mut v = arr2(&[[3., -2., 6.]]); @@ -752,6 +776,7 @@ mod tests { assert_eq!(v, arr2(&[[-3., 2., -6.]])); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn reshape() { let m_orig = arr2(&[[1., 2., 3., 4., 5., 6.]]); @@ -763,6 +788,7 @@ mod tests { assert_eq!(BaseMatrix::get(&m_result, 0, 3), 4.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn copy_from() { let mut src = arr2(&[[1., 2., 3.]]); @@ -771,6 +797,7 @@ mod tests { assert_eq!(src, dst); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn min_max_sum() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.]]); @@ -779,6 +806,7 @@ mod tests { assert_eq!(6., a.max()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn max_diff() { let a1 = arr2(&[[1., 2., 3.], [4., -5., 6.]]); @@ -787,6 +815,7 @@ mod tests { assert_eq!(a2.max_diff(&a2), 0.); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn softmax_mut() { let mut prob: Array2 = arr2(&[[1., 2., 3.]]); @@ -796,6 +825,7 @@ mod tests { assert!((BaseMatrix::get(&prob, 0, 2) - 0.66).abs() < 0.01); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn pow_mut() { let mut a = arr2(&[[1., 2., 3.]]); @@ -803,6 +833,7 @@ mod tests { assert_eq!(a, arr2(&[[1., 8., 27.]])); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn argmax() { let a = arr2(&[[1., 2., 3.], [-5., -6., -7.], [0.1, 0.2, 0.1]]); @@ -810,6 +841,7 @@ mod tests { assert_eq!(res, vec![2, 0, 1]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn unique() { let a = arr2(&[[1., 2., 2.], [-2., -6., -7.], [2., 3., 4.]]); @@ -818,6 +850,7 @@ mod tests { assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_row_as_vector() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); @@ -825,12 +858,14 @@ mod tests { assert_eq!(res, vec![4., 5., 6.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_row() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); assert_eq!(arr1(&[4., 5., 6.]), a.get_row(1)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn get_col_as_vector() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); @@ -838,6 +873,7 @@ mod tests { assert_eq!(res, vec![2., 5., 8.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn copy_row_col_as_vec() { let m = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); @@ -849,6 +885,7 @@ mod tests { assert_eq!(v, vec!(2., 5., 8.)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn col_mean() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); @@ -856,6 +893,7 @@ mod tests { assert_eq!(res, vec![4., 5., 6.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn eye() { let a = arr2(&[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]); @@ -863,6 +901,7 @@ mod tests { assert_eq!(res, a); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn rand() { let m: Array2 = BaseMatrix::rand(3, 3); @@ -873,6 +912,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn approximate_eq() { let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]); @@ -881,6 +921,7 @@ mod tests { assert!(!a.approximate_eq(&(&noise + &a), 1e-5)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn abs_mut() { let mut a = arr2(&[[1., -2.], [3., -4.]]); @@ -889,6 +930,7 @@ mod tests { assert_eq!(a, expected); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lr_fit_predict_iris() { let x = arr2(&[ @@ -924,12 +966,13 @@ mod tests { let error: f64 = y .into_iter() .zip(y_hat.into_iter()) - .map(|(&a, &b)| (a - b).abs()) + .map(|(a, b)| (a - b).abs()) .sum(); assert!(error <= 1.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn my_fit_longley_ndarray() { let x = arr2(&[ @@ -964,6 +1007,8 @@ mod tests { min_samples_split: 2, n_trees: 1000, m: Option::None, + keep_samples: false, + seed: 0, }, ) .unwrap() diff --git a/src/linalg/qr.rs b/src/linalg/qr.rs index a06a01f..3380fb4 100644 --- a/src/linalg/qr.rs +++ b/src/linalg/qr.rs @@ -195,7 +195,7 @@ pub trait QRDecomposableMatrix: BaseMatrix { mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); @@ -214,6 +214,7 @@ mod tests { assert!(qr.R().abs().approximate_eq(&r.abs(), 1e-4)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn qr_solve_mut() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); diff --git a/src/linalg/stats.rs b/src/linalg/stats.rs index 45a17af..10a3fc4 100644 --- a/src/linalg/stats.rs +++ b/src/linalg/stats.rs @@ -61,7 +61,7 @@ pub trait MatrixStats: BaseMatrix { sum += a * a; } mu /= div; - *x_i = sum / div - mu * mu; + *x_i = sum / div - mu.powi(2); } x @@ -150,7 +150,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::linalg::BaseVector; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mean() { let m = DenseMatrix::from_2d_array(&[ @@ -164,7 +164,7 @@ mod tests { assert_eq!(m.mean(0), expected_0); assert_eq!(m.mean(1), expected_1); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn std() { let m = DenseMatrix::from_2d_array(&[ @@ -178,7 +178,7 @@ mod tests { assert!(m.std(0).approximate_eq(&expected_0, 1e-2)); assert!(m.std(1).approximate_eq(&expected_1, 1e-2)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn var() { let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]); @@ -188,7 +188,7 @@ mod tests { assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON)); assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON)); } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn scale() { let mut m = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]); diff --git a/src/linalg/svd.rs b/src/linalg/svd.rs index e370453..97d85ca 100644 --- a/src/linalg/svd.rs +++ b/src/linalg/svd.rs @@ -47,7 +47,7 @@ pub struct SVD> { pub V: M, /// Singular values of the original matrix pub s: Vec, - full: bool, + _full: bool, m: usize, n: usize, tol: T, @@ -116,7 +116,7 @@ pub trait SVDDecomposableMatrix: BaseMatrix { } let mut f = U.get(i, i); - g = -s.sqrt().copysign(f); + g = -RealNumber::copysign(s.sqrt(), f); let h = f * g - s; U.set(i, i, f - g); for j in l - 1..n { @@ -152,7 +152,7 @@ pub trait SVDDecomposableMatrix: BaseMatrix { } let f = U.get(i, l - 1); - g = -s.sqrt().copysign(f); + g = -RealNumber::copysign(s.sqrt(), f); let h = f * g - s; U.set(i, l - 1, f - g); @@ -299,7 +299,7 @@ pub trait SVDDecomposableMatrix: BaseMatrix { let mut h = rv1[k]; let mut f = ((y - z) * (y + z) + (g - h) * (g + h)) / (T::two() * h * y); g = f.hypot(T::one()); - f = ((x - z) * (x + z) + h * ((y / (f + g.copysign(f))) - h)) / x; + f = ((x - z) * (x + z) + h * ((y / (f + RealNumber::copysign(g, f))) - h)) / x; let mut c = T::one(); let mut s = T::one(); @@ -428,13 +428,13 @@ impl> SVD { pub(crate) fn new(U: M, V: M, s: Vec) -> SVD { let m = U.shape().0; let n = V.shape().0; - let full = s.len() == m.min(n); + let _full = s.len() == m.min(n); let tol = T::half() * (T::from(m + n).unwrap() + T::one()).sqrt() * s[0] * T::epsilon(); SVD { U, V, s, - full, + _full, m, n, tol, @@ -482,7 +482,7 @@ impl> SVD { mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_symmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -513,7 +513,7 @@ mod tests { assert!((s[i] - svd.s[i]).abs() < 1e-4); } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_asymmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -714,7 +714,7 @@ mod tests { assert!((s[i] - svd.s[i]).abs() < 1e-4); } } - + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn solve() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); @@ -725,6 +725,7 @@ mod tests { assert!(w.approximate_eq(&expected_w, 1e-2)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn decompose_restore() { let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]); diff --git a/src/linear/bg_solver.rs b/src/linear/bg_solver.rs index 46ef13d..28cc3d8 100644 --- a/src/linear/bg_solver.rs +++ b/src/linear/bg_solver.rs @@ -126,6 +126,7 @@ mod tests { impl> BiconjugateGradientSolver for BGSolver {} + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn bg_solver() { let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs index 2833ff1..ce13435 100644 --- a/src/linear/elastic_net.rs +++ b/src/linear/elastic_net.rs @@ -56,6 +56,7 @@ //! use std::fmt::Debug; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -67,7 +68,8 @@ use crate::math::num::RealNumber; use crate::linear::lasso_optimizer::InteriorPointOptimizer; /// Elastic net parameters -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct ElasticNetParameters { /// Regularization parameter. pub alpha: T, @@ -84,7 +86,8 @@ pub struct ElasticNetParameters { } /// Elastic net -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct ElasticNet> { coefficients: M, intercept: T, @@ -288,6 +291,7 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn elasticnet_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -331,6 +335,7 @@ mod tests { assert!(mean_absolute_error(&y_hat, &y) < 30.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn elasticnet_fit_predict1() { let x = DenseMatrix::from_2d_array(&[ @@ -397,7 +402,9 @@ mod tests { assert!(l1_model.coefficients().get(0, 0) > l1_model.coefficients().get(2, 0)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index b99ecff..7edd325 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -24,6 +24,7 @@ //! use std::fmt::Debug; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -34,7 +35,8 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer; use crate::math::num::RealNumber; /// Lasso regression parameters -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct LassoParameters { /// Controls the strength of the penalty to the loss function. pub alpha: T, @@ -47,7 +49,8 @@ pub struct LassoParameters { pub max_iter: usize, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Lasso regressor pub struct Lasso> { coefficients: M, @@ -223,6 +226,7 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lasso_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -271,7 +275,9 @@ mod tests { assert!(mean_absolute_error(&y_hat, &y) < 2.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], diff --git a/src/linear/lasso_optimizer.rs b/src/linear/lasso_optimizer.rs index 4f5011f..c4340fc 100644 --- a/src/linear/lasso_optimizer.rs +++ b/src/linear/lasso_optimizer.rs @@ -138,7 +138,7 @@ impl> InteriorPointOptimizer { for i in 0..p { self.prb[i] = T::two() + self.d1[i]; - self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i] * self.d2[i]; + self.prs[i] = self.prb[i] * self.d1[i] - self.d2[i].powi(2); } let normg = grad.norm2(); diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index 2ef03c1..b1f7c51 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -62,6 +62,7 @@ //! use std::fmt::Debug; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -69,7 +70,8 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable. pub enum LinearRegressionSolverName { /// QR decomposition, see [QR](../../linalg/qr/index.html) @@ -79,18 +81,20 @@ pub enum LinearRegressionSolverName { } /// Linear Regression parameters -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct LinearRegressionParameters { /// Solver to use for estimation of regression coefficients. pub solver: LinearRegressionSolverName, } /// Linear Regression -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct LinearRegression> { coefficients: M, intercept: T, - solver: LinearRegressionSolverName, + _solver: LinearRegressionSolverName, } impl LinearRegressionParameters { @@ -151,7 +155,7 @@ impl> LinearRegression { if x_nrows != y_nrows { return Err(Failed::fit( - &"Number of rows of X doesn\'t match number of rows of Y".to_string(), + "Number of rows of X doesn\'t match number of rows of Y", )); } @@ -167,7 +171,7 @@ impl> LinearRegression { Ok(LinearRegression { intercept: w.get(num_attributes, 0), coefficients: wights, - solver: parameters.solver, + _solver: parameters.solver, }) } @@ -196,6 +200,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ols_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -246,7 +251,9 @@ mod tests { .all(|(&a, &b)| (a - b).abs() <= 5.0)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index a71ac45..1a20077 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -54,8 +54,8 @@ //! use std::cmp::Ordering; use std::fmt::Debug; -use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -67,12 +67,27 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +/// Solver options for Logistic regression. Right now only LBFGS solver is supported. +pub enum LogisticRegressionSolverName { + /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html) + LBFGS, +} + /// Logistic Regression parameters -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct LogisticRegressionParameters {} +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct LogisticRegressionParameters { + /// Solver to use for estimation of regression coefficients. + pub solver: LogisticRegressionSolverName, + /// Regularization parameter. + pub alpha: T, +} /// Logistic Regression -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct LogisticRegression> { coefficients: M, intercept: M, @@ -99,12 +114,28 @@ trait ObjectiveFunction> { struct BinaryObjectiveFunction<'a, T: RealNumber, M: Matrix> { x: &'a M, y: Vec, - phantom: PhantomData<&'a T>, + alpha: T, } -impl Default for LogisticRegressionParameters { +impl LogisticRegressionParameters { + /// Solver to use for estimation of regression coefficients. + pub fn with_solver(mut self, solver: LogisticRegressionSolverName) -> Self { + self.solver = solver; + self + } + /// Regularization parameter. + pub fn with_alpha(mut self, alpha: T) -> Self { + self.alpha = alpha; + self + } +} + +impl Default for LogisticRegressionParameters { fn default() -> Self { - LogisticRegressionParameters {} + LogisticRegressionParameters { + solver: LogisticRegressionSolverName::LBFGS, + alpha: T::zero(), + } } } @@ -132,13 +163,22 @@ impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction { fn f(&self, w_bias: &M) -> T { let mut f = T::zero(); - let (n, _) = self.x.shape(); + let (n, p) = self.x.shape(); for i in 0..n { let wx = BinaryObjectiveFunction::partial_dot(w_bias, self.x, 0, i); f += wx.ln_1pe() - (T::from(self.y[i]).unwrap()) * wx; } + if self.alpha > T::zero() { + let mut w_squared = T::zero(); + for i in 0..p { + let w = w_bias.get(0, i); + w_squared += w * w; + } + f += T::half() * self.alpha * w_squared; + } + f } @@ -156,6 +196,13 @@ impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction } g.set(0, p, g.get(0, p) - dyi); } + + if self.alpha > T::zero() { + for i in 0..p { + let w = w_bias.get(0, i); + g.set(0, i, g.get(0, i) + self.alpha * w); + } + } } } @@ -163,7 +210,7 @@ struct MultiClassObjectiveFunction<'a, T: RealNumber, M: Matrix> { x: &'a M, y: Vec, k: usize, - phantom: PhantomData<&'a T>, + alpha: T, } impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction @@ -185,6 +232,17 @@ impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction f -= prob.get(0, self.y[i]).ln(); } + if self.alpha > T::zero() { + let mut w_squared = T::zero(); + for i in 0..self.k { + for j in 0..p { + let wi = w_bias.get(0, i * (p + 1) + j); + w_squared += wi * wi; + } + } + f += T::half() * self.alpha * w_squared; + } + f } @@ -215,16 +273,27 @@ impl<'a, T: RealNumber, M: Matrix> ObjectiveFunction g.set(0, j * (p + 1) + p, g.get(0, j * (p + 1) + p) - yi); } } + + if self.alpha > T::zero() { + for i in 0..self.k { + for j in 0..p { + let pos = i * (p + 1); + let wi = w.get(0, pos + j); + g.set(0, pos + j, g.get(0, pos + j) + self.alpha * wi); + } + } + } } } -impl> SupervisedEstimator +impl> + SupervisedEstimator> for LogisticRegression { fn fit( x: &M, y: &M::RowVector, - parameters: LogisticRegressionParameters, + parameters: LogisticRegressionParameters, ) -> Result { LogisticRegression::fit(x, y, parameters) } @@ -244,7 +313,7 @@ impl> LogisticRegression { pub fn fit( x: &M, y: &M::RowVector, - _parameters: LogisticRegressionParameters, + parameters: LogisticRegressionParameters, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); let (x_nrows, num_attributes) = x.shape(); @@ -252,7 +321,7 @@ impl> LogisticRegression { if x_nrows != y_nrows { return Err(Failed::fit( - &"Number of rows of X doesn\'t match number of rows of Y".to_string(), + "Number of rows of X doesn\'t match number of rows of Y", )); } @@ -278,7 +347,7 @@ impl> LogisticRegression { let objective = BinaryObjectiveFunction { x, y: yi, - phantom: PhantomData, + alpha: parameters.alpha, }; let result = LogisticRegression::minimize(x0, objective); @@ -300,7 +369,7 @@ impl> LogisticRegression { x, y: yi, k, - phantom: PhantomData, + alpha: parameters.alpha, }; let result = LogisticRegression::minimize(x0, objective); @@ -383,6 +452,7 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::accuracy; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn multiclass_objective_f() { let x = DenseMatrix::from_2d_array(&[ @@ -407,9 +477,9 @@ mod tests { let objective = MultiClassObjectiveFunction { x: &x, - y, + y: y.clone(), k: 3, - phantom: PhantomData, + alpha: 0.0, }; let mut g: DenseMatrix = DenseMatrix::zeros(1, 9); @@ -430,8 +500,27 @@ mod tests { ])); assert!((f - 408.0052230582765).abs() < std::f64::EPSILON); + + let objective_reg = MultiClassObjectiveFunction { + x: &x, + y: y.clone(), + k: 3, + alpha: 1.0, + }; + + let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[ + 1., 2., 3., 4., 5., 6., 7., 8., 9., + ])); + assert!((f - 487.5052).abs() < 1e-4); + + objective_reg.df( + &mut g, + &DenseMatrix::row_vector_from_array(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]), + ); + assert!((g.get(0, 0).abs() - 32.0).abs() < 1e-4); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn binary_objective_f() { let x = DenseMatrix::from_2d_array(&[ @@ -456,8 +545,8 @@ mod tests { let objective = BinaryObjectiveFunction { x: &x, - y, - phantom: PhantomData, + y: y.clone(), + alpha: 0.0, }; let mut g: DenseMatrix = DenseMatrix::zeros(1, 3); @@ -472,8 +561,23 @@ mod tests { let f = objective.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.])); assert!((f - 59.76994756647412).abs() < std::f64::EPSILON); + + let objective_reg = BinaryObjectiveFunction { + x: &x, + y: y.clone(), + alpha: 1.0, + }; + + let f = objective_reg.f(&DenseMatrix::row_vector_from_array(&[1., 2., 3.])); + assert!((f - 62.2699).abs() < 1e-4); + + objective_reg.df(&mut g, &DenseMatrix::row_vector_from_array(&[1., 2., 3.])); + assert!((g.get(0, 0) - 27.0511).abs() < 1e-4); + assert!((g.get(0, 1) - 12.239).abs() < 1e-4); + assert!((g.get(0, 2) - 3.8693).abs() < 1e-4); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lr_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -511,6 +615,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lr_fit_predict_multiclass() { let blobs = make_blobs(15, 4, 3); @@ -523,8 +628,18 @@ mod tests { let y_hat = lr.predict(&x).unwrap(); assert!(accuracy(&y_hat, &y) > 0.9); + + let lr_reg = LogisticRegression::fit( + &x, + &y, + LogisticRegressionParameters::default().with_alpha(10.0), + ) + .unwrap(); + + assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lr_fit_predict_binary() { let blobs = make_blobs(20, 4, 2); @@ -537,9 +652,20 @@ mod tests { let y_hat = lr.predict(&x).unwrap(); assert!(accuracy(&y_hat, &y) > 0.9); + + let lr_reg = LogisticRegression::fit( + &x, + &y, + LogisticRegressionParameters::default().with_alpha(10.0), + ) + .unwrap(); + + assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[1., -5.], @@ -568,6 +694,7 @@ mod tests { assert_eq!(lr, deserialized_lr); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lr_fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -597,6 +724,12 @@ mod tests { ]; let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap(); + let lr_reg = LogisticRegression::fit( + &x, + &y, + LogisticRegressionParameters::default().with_alpha(1.0), + ) + .unwrap(); let y_hat = lr.predict(&x).unwrap(); @@ -607,5 +740,6 @@ mod tests { .sum(); assert!(error <= 1.0); + assert!(lr_reg.coefficients().abs().sum() < lr.coefficients().abs().sum()); } } diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index e9ed1ff..ecad250 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -58,6 +58,7 @@ //! use std::fmt::Debug; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -66,7 +67,8 @@ use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable. pub enum RidgeRegressionSolverName { /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html) @@ -76,7 +78,8 @@ pub enum RidgeRegressionSolverName { } /// Ridge Regression parameters -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct RidgeRegressionParameters { /// Solver to use for estimation of regression coefficients. pub solver: RidgeRegressionSolverName, @@ -88,11 +91,12 @@ pub struct RidgeRegressionParameters { } /// Ridge regression -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct RidgeRegression> { coefficients: M, intercept: T, - solver: RidgeRegressionSolverName, + _solver: RidgeRegressionSolverName, } impl RidgeRegressionParameters { @@ -222,7 +226,7 @@ impl> RidgeRegression { Ok(RidgeRegression { intercept: b, coefficients: w, - solver: parameters.solver, + _solver: parameters.solver, }) } @@ -270,6 +274,7 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ridge_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -325,7 +330,9 @@ mod tests { assert!(mean_absolute_error(&y_hat_svd, &y) < 2.0); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], diff --git a/src/math/distance/euclidian.rs b/src/math/distance/euclidian.rs index 9034727..ed836f6 100644 --- a/src/math/distance/euclidian.rs +++ b/src/math/distance/euclidian.rs @@ -18,6 +18,7 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::math::num::RealNumber; @@ -25,7 +26,8 @@ use crate::math::num::RealNumber; use super::Distance; /// Euclidean distance is a measure of the true straight line distance between two points in Euclidean n-space. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Euclidian {} impl Euclidian { @@ -55,6 +57,7 @@ impl Distance, T> for Euclidian { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn squared_distance() { let a = vec![1., 2., 3.]; diff --git a/src/math/distance/hamming.rs b/src/math/distance/hamming.rs index 129fe16..da0d28f 100644 --- a/src/math/distance/hamming.rs +++ b/src/math/distance/hamming.rs @@ -19,6 +19,7 @@ //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::math::num::RealNumber; @@ -26,7 +27,8 @@ use crate::math::num::RealNumber; use super::Distance; /// While comparing two integer-valued vectors of equal length, Hamming distance is the number of bit positions in which the two bits are different -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Hamming {} impl Distance, F> for Hamming { @@ -50,6 +52,7 @@ impl Distance, F> for Hamming { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn hamming_distance() { let a = vec![1, 0, 0, 1, 0, 0, 1]; diff --git a/src/math/distance/mahalanobis.rs b/src/math/distance/mahalanobis.rs index 84aa947..5a3fae8 100644 --- a/src/math/distance/mahalanobis.rs +++ b/src/math/distance/mahalanobis.rs @@ -44,6 +44,7 @@ use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::math::num::RealNumber; @@ -52,7 +53,8 @@ use super::Distance; use crate::linalg::Matrix; /// Mahalanobis distance. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Mahalanobis> { /// covariance matrix of the dataset pub sigma: M, @@ -131,6 +133,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mahalanobis_distance() { let data = DenseMatrix::from_2d_array(&[ diff --git a/src/math/distance/manhattan.rs b/src/math/distance/manhattan.rs index 9a69184..372f524 100644 --- a/src/math/distance/manhattan.rs +++ b/src/math/distance/manhattan.rs @@ -17,6 +17,7 @@ //! ``` //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::math::num::RealNumber; @@ -24,7 +25,8 @@ use crate::math::num::RealNumber; use super::Distance; /// Manhattan distance -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Manhattan {} impl Distance, T> for Manhattan { @@ -46,6 +48,7 @@ impl Distance, T> for Manhattan { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn manhattan_distance() { let a = vec![1., 2., 3.]; diff --git a/src/math/distance/minkowski.rs b/src/math/distance/minkowski.rs index c5dd85d..bd9c1c4 100644 --- a/src/math/distance/minkowski.rs +++ b/src/math/distance/minkowski.rs @@ -21,6 +21,7 @@ //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::math::num::RealNumber; @@ -28,7 +29,8 @@ use crate::math::num::RealNumber; use super::Distance; /// Defines the Minkowski distance of order `p` -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Minkowski { /// order, integer pub p: u16, @@ -59,6 +61,7 @@ impl Distance, T> for Minkowski { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn minkowski_distance() { let a = vec![1., 2., 3.]; diff --git a/src/math/num.rs b/src/math/num.rs index 490623c..7199949 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -136,6 +136,7 @@ impl RealNumber for f32 { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn sigmoid() { assert_eq!(1.0.sigmoid(), 0.7310585786300049); diff --git a/src/math/vector.rs b/src/math/vector.rs index 62cf63b..c38c7a4 100644 --- a/src/math/vector.rs +++ b/src/math/vector.rs @@ -30,6 +30,7 @@ impl> RealNumberVector for V { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn unique_with_indices() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs index ef7028f..0c9ce06 100644 --- a/src/metrics/accuracy.rs +++ b/src/metrics/accuracy.rs @@ -16,13 +16,15 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Accuracy metric. -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct Accuracy {} impl Accuracy { @@ -55,6 +57,7 @@ impl Accuracy { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn accuracy() { let y_pred: Vec = vec![0., 2., 1., 3.]; diff --git a/src/metrics/auc.rs b/src/metrics/auc.rs index 0f8d56a..c413dc4 100644 --- a/src/metrics/auc.rs +++ b/src/metrics/auc.rs @@ -20,6 +20,7 @@ //! * ["The ROC-AUC and the Mann-Whitney U-test", Haupt, J.](https://johaupt.github.io/roc-auc/model%20evaluation/Area_under_ROC_curve.html) #![allow(non_snake_case)] +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::sort::quick_sort::QuickArgSort; @@ -27,7 +28,8 @@ use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Area Under the Receiver Operating Characteristic Curve (ROC AUC) -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct AUC {} impl AUC { @@ -91,6 +93,7 @@ impl AUC { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn auc() { let y_true: Vec = vec![0., 0., 1., 1.]; diff --git a/src/metrics/cluster_hcv.rs b/src/metrics/cluster_hcv.rs index 29a9db2..f20f448 100644 --- a/src/metrics/cluster_hcv.rs +++ b/src/metrics/cluster_hcv.rs @@ -1,10 +1,12 @@ +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; use crate::metrics::cluster_helpers::*; -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Homogeneity, completeness and V-Measure scores. pub struct HCVScore {} @@ -41,6 +43,7 @@ impl HCVScore { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn homogeneity_score() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; diff --git a/src/metrics/cluster_helpers.rs b/src/metrics/cluster_helpers.rs index a8fa7e5..05cf97c 100644 --- a/src/metrics/cluster_helpers.rs +++ b/src/metrics/cluster_helpers.rs @@ -101,6 +101,7 @@ pub fn mutual_info_score(contingency: &[Vec]) -> T { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn contingency_matrix_test() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; @@ -112,6 +113,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn entropy_test() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; @@ -119,6 +121,7 @@ mod tests { assert!((1.2770f32 - entropy(&v1).unwrap()).abs() < 1e-4); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mutual_info_score_test() { let v1 = vec![0.0, 0.0, 1.0, 1.0, 2.0, 0.0, 4.0]; diff --git a/src/metrics/f1.rs b/src/metrics/f1.rs index 5c8537c..4ad6a5d 100644 --- a/src/metrics/f1.rs +++ b/src/metrics/f1.rs @@ -18,6 +18,7 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; @@ -26,7 +27,8 @@ use crate::metrics::precision::Precision; use crate::metrics::recall::Recall; /// F-measure -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct F1 { /// a positive real factor pub beta: T, @@ -57,6 +59,7 @@ impl F1 { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn f1() { let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; diff --git a/src/metrics/mean_absolute_error.rs b/src/metrics/mean_absolute_error.rs index a069335..3e8ce85 100644 --- a/src/metrics/mean_absolute_error.rs +++ b/src/metrics/mean_absolute_error.rs @@ -18,12 +18,14 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Mean Absolute Error pub struct MeanAbsoluteError {} @@ -54,6 +56,7 @@ impl MeanAbsoluteError { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mean_absolute_error() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/mean_squared_error.rs b/src/metrics/mean_squared_error.rs index 137c8e6..dce758d 100644 --- a/src/metrics/mean_squared_error.rs +++ b/src/metrics/mean_squared_error.rs @@ -18,12 +18,14 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Mean Squared Error pub struct MeanSquareError {} @@ -54,6 +56,7 @@ impl MeanSquareError { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn mean_squared_error() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index 3524e7f..a0171aa 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -18,13 +18,15 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Precision metric. -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct Precision {} impl Precision { @@ -75,6 +77,7 @@ impl Precision { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn precision() { let y_true: Vec = vec![0., 1., 1., 0.]; diff --git a/src/metrics/r2.rs b/src/metrics/r2.rs index cbcf7e4..738aae6 100644 --- a/src/metrics/r2.rs +++ b/src/metrics/r2.rs @@ -18,13 +18,15 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Coefficient of Determination (R2) -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct R2 {} impl R2 { @@ -68,6 +70,7 @@ impl R2 { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn r2() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 4d2be95..18863ae 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -18,13 +18,15 @@ //! //! //! +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Recall metric. -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct Recall {} impl Recall { @@ -75,6 +77,7 @@ impl Recall { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn recall() { let y_true: Vec = vec![0., 1., 1., 0.]; diff --git a/src/model_selection/kfold.rs b/src/model_selection/kfold.rs index 63827c4..8706954 100644 --- a/src/model_selection/kfold.rs +++ b/src/model_selection/kfold.rs @@ -144,6 +144,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_kfold_return_test_indices_simple() { let k = KFold { @@ -158,6 +159,7 @@ mod tests { assert_eq!(test_indices[2], (22..33).collect::>()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_kfold_return_test_indices_odd() { let k = KFold { @@ -172,6 +174,7 @@ mod tests { assert_eq!(test_indices[2], (23..34).collect::>()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_kfold_return_test_mask_simple() { let k = KFold { @@ -197,6 +200,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_kfold_return_split_simple() { let k = KFold { @@ -212,6 +216,7 @@ mod tests { assert_eq!(train_test_splits[1].1, (11..22).collect::>()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_kfold_return_split_simple_shuffle() { let k = KFold { @@ -227,6 +232,7 @@ mod tests { assert_eq!(train_test_splits[1].1.len(), 11_usize); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn numpy_parity_test() { let k = KFold { @@ -247,6 +253,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn numpy_parity_test_shuffle() { let k = KFold { diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 0058367..d283176 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -285,6 +285,7 @@ mod tests { use crate::model_selection::kfold::KFold; use crate::neighbors::knn_regressor::KNNRegressor; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_train_test_split() { let n = 123; @@ -308,6 +309,7 @@ mod tests { #[derive(Clone)] struct NoParameters {} + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_cross_validate_biased() { struct BiasedEstimator {} @@ -367,6 +369,7 @@ mod tests { assert_eq!(0.4, results.mean_train_score()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_cross_validate_knn() { let x = DenseMatrix::from_2d_array(&[ @@ -411,6 +414,7 @@ mod tests { assert!(results.mean_train_score() < results.mean_test_score()); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn test_cross_val_predict_knn() { let x = DenseMatrix::from_2d_array(&[ diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 388646f..95c4d36 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -42,15 +42,49 @@ use crate::math::num::RealNumber; use crate::math::vector::RealNumberVector; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// Naive Bayes classifier for Bearnoulli features -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct BernoulliNBDistribution { /// class labels known to the classifier class_labels: Vec, + /// number of training samples observed in each class + class_count: Vec, + /// probability of each class class_priors: Vec, - feature_prob: Vec>, + /// Number of samples encountered for each (class, feature) + feature_count: Vec>, + /// probability of features per class + feature_log_prob: Vec>, + /// Number of features of each sample + n_features: usize, +} + +impl PartialEq for BernoulliNBDistribution { + fn eq(&self, other: &Self) -> bool { + if self.class_labels == other.class_labels + && self.class_count == other.class_count + && self.class_priors == other.class_priors + && self.feature_count == other.feature_count + && self.n_features == other.n_features + { + for (a, b) in self + .feature_log_prob + .iter() + .zip(other.feature_log_prob.iter()) + { + if !a.approximate_eq(b, T::epsilon()) { + return false; + } + } + true + } else { + false + } + } } impl> NBDistribution for BernoulliNBDistribution { @@ -63,9 +97,9 @@ impl> NBDistribution for BernoulliNBDistributi for feature in 0..j.len() { let value = j.get(feature); if value == T::one() { - likelihood += self.feature_prob[class_index][feature].ln(); + likelihood += self.feature_log_prob[class_index][feature]; } else { - likelihood += (T::one() - self.feature_prob[class_index][feature]).ln(); + likelihood += (T::one() - self.feature_log_prob[class_index][feature].exp()).ln(); } } likelihood @@ -77,7 +111,8 @@ impl> NBDistribution for BernoulliNBDistributi } /// `BernoulliNB` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct BernoulliNBParameters { /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, @@ -154,10 +189,10 @@ impl BernoulliNBDistribution { let y = y.to_vec(); let (class_labels, indices) = as RealNumberVector>::unique_with_indices(&y); - let mut class_count = vec![T::zero(); class_labels.len()]; + let mut class_count = vec![0_usize; class_labels.len()]; for class_index in indices.iter() { - class_count[*class_index] += T::one(); + class_count[*class_index] += 1; } let class_priors = if let Some(class_priors) = priors { @@ -170,25 +205,35 @@ impl BernoulliNBDistribution { } else { class_count .iter() - .map(|&c| c / T::from(n_samples).unwrap()) + .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap()) .collect() }; - let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()]; + let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()]; for (row, class_index) in row_iter(x).zip(indices) { for (idx, row_i) in row.iter().enumerate().take(n_features) { - feature_in_class_counter[class_index][idx] += *row_i; + feature_in_class_counter[class_index][idx] += + row_i.to_usize().ok_or_else(|| { + Failed::fit(&format!( + "Elements of the matrix should be 1.0 or 0.0 |found|=[{}]", + row_i + )) + })?; } } - let feature_prob = feature_in_class_counter + let feature_log_prob = feature_in_class_counter .iter() .enumerate() .map(|(class_index, feature_count)| { feature_count .iter() - .map(|&count| (count + alpha) / (class_count[class_index] + alpha * T::two())) + .map(|&count| { + ((T::from(count).unwrap() + alpha) + / (T::from(class_count[class_index]).unwrap() + alpha * T::two())) + .ln() + }) .collect() }) .collect(); @@ -196,13 +241,18 @@ impl BernoulliNBDistribution { Ok(Self { class_labels, class_priors, - feature_prob, + class_count, + feature_count: feature_in_class_counter, + feature_log_prob, + n_features, }) } } -/// BernoulliNB implements the categorical naive Bayes algorithm for categorically distributed data. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +/// BernoulliNB implements the naive Bayes algorithm for data that follows the Bernoulli +/// distribution. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] pub struct BernoulliNB> { inner: BaseNaiveBayes>, binarize: Option, @@ -262,6 +312,34 @@ impl> BernoulliNB { self.inner.predict(x) } } + + /// Class labels known to the classifier. + /// Returns a vector of size n_classes. + pub fn classes(&self) -> &Vec { + &self.inner.distribution.class_labels + } + + /// Number of training samples observed in each class. + /// Returns a vector of size n_classes. + pub fn class_count(&self) -> &Vec { + &self.inner.distribution.class_count + } + + /// Number of features of each sample + pub fn n_features(&self) -> usize { + self.inner.distribution.n_features + } + + /// Number of samples encountered for each (class, feature) + /// Returns a 2d vector of shape (n_classes, n_features) + pub fn feature_count(&self) -> &Vec> { + &self.inner.distribution.feature_count + } + + /// Empirical log probability of features given a class + pub fn feature_log_prob(&self) -> &Vec> { + &self.inner.distribution.feature_log_prob + } } #[cfg(test)] @@ -269,6 +347,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_bernoulli_naive_bayes() { // Tests that BernoulliNB when alpha=1.0 gives the same values as @@ -292,10 +371,24 @@ mod tests { assert_eq!(bnb.inner.distribution.class_priors, &[0.75, 0.25]); assert_eq!( - bnb.inner.distribution.feature_prob, + bnb.feature_log_prob(), &[ - &[0.4, 0.8, 0.2, 0.4, 0.4, 0.2], - &[1. / 3.0, 2. / 3.0, 2. / 3.0, 1. / 3.0, 1. / 3.0, 2. / 3.0] + &[ + -0.916290731874155, + -0.2231435513142097, + -1.6094379124341003, + -0.916290731874155, + -0.916290731874155, + -1.6094379124341003 + ], + &[ + -1.0986122886681098, + -0.40546510810816444, + -0.40546510810816444, + -1.0986122886681098, + -1.0986122886681098, + -0.40546510810816444 + ] ] ); @@ -307,6 +400,7 @@ mod tests { assert_eq!(y_hat, &[1.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn bernoulli_nb_scikit_parity() { let x = DenseMatrix::::from_2d_array(&[ @@ -331,13 +425,36 @@ mod tests { let y_hat = bnb.predict(&x).unwrap(); + assert_eq!(bnb.classes(), &[0., 1., 2.]); + assert_eq!(bnb.class_count(), &[7, 3, 5]); + assert_eq!(bnb.n_features(), 10); + assert_eq!( + bnb.feature_count(), + &[ + &[5, 6, 6, 7, 6, 4, 6, 7, 7, 7], + &[3, 3, 3, 1, 3, 2, 3, 2, 2, 3], + &[4, 4, 3, 4, 5, 2, 4, 5, 3, 4] + ] + ); + assert!(bnb .inner .distribution .class_priors .approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2)); - assert!(bnb.inner.distribution.feature_prob[1].approximate_eq( - &vec!(0.8, 0.8, 0.8, 0.4, 0.8, 0.6, 0.8, 0.6, 0.6, 0.8), + assert!(bnb.feature_log_prob()[1].approximate_eq( + &vec![ + -0.22314355, + -0.22314355, + -0.22314355, + -0.91629073, + -0.22314355, + -0.51082562, + -0.22314355, + -0.51082562, + -0.51082562, + -0.22314355 + ], 1e-1 )); assert!(y_hat.approximate_eq( @@ -346,7 +463,9 @@ mod tests { )); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::::from_2d_array(&[ &[1., 1., 0., 0., 0., 0.], diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index c6f28bd..8706702 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -36,19 +36,38 @@ use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// Naive Bayes classifier for categorical features -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct CategoricalNBDistribution { + /// number of training samples observed in each class + class_count: Vec, + /// class labels known to the classifier class_labels: Vec, + /// probability of each class class_priors: Vec, coefficients: Vec>>, + /// Number of features of each sample + n_features: usize, + /// Number of categories for each feature + n_categories: Vec, + /// Holds arrays of shape (n_classes, n_categories of respective feature) + /// for each feature. Each array provides the number of samples + /// encountered for each class and category of the specific feature. + category_count: Vec>>, } impl PartialEq for CategoricalNBDistribution { fn eq(&self, other: &Self) -> bool { - if self.class_labels == other.class_labels && self.class_priors == other.class_priors { + if self.class_labels == other.class_labels + && self.class_priors == other.class_priors + && self.n_features == other.n_features + && self.n_categories == other.n_categories + && self.class_count == other.class_count + { if self.coefficients.len() != other.coefficients.len() { return false; } @@ -88,8 +107,8 @@ impl> NBDistribution for CategoricalNBDistribu let mut likelihood = T::zero(); for feature in 0..j.len() { let value = j.get(feature).floor().to_usize().unwrap(); - if self.coefficients[class_index][feature].len() > value { - likelihood += self.coefficients[class_index][feature][value]; + if self.coefficients[feature][class_index].len() > value { + likelihood += self.coefficients[feature][class_index][value]; } else { return T::zero(); } @@ -142,17 +161,17 @@ impl CategoricalNBDistribution { let y_max = y .iter() .max() - .ok_or_else(|| Failed::fit(&"Failed to get the labels of y.".to_string()))?; + .ok_or_else(|| Failed::fit("Failed to get the labels of y."))?; let class_labels: Vec = (0..*y_max + 1) .map(|label| T::from(label).unwrap()) .collect(); - let mut classes_count: Vec = vec![T::zero(); class_labels.len()]; + let mut class_count = vec![0_usize; class_labels.len()]; for elem in y.iter() { - classes_count[*elem] += T::one(); + class_count[*elem] += 1; } - let mut feature_categories: Vec> = Vec::with_capacity(n_features); + let mut n_categories: Vec = Vec::with_capacity(n_features); for feature in 0..n_features { let feature_max = x .get_col_as_vec(feature) @@ -165,18 +184,15 @@ impl CategoricalNBDistribution { feature )) })?; - let feature_types = (0..feature_max + 1) - .map(|feat| T::from(feat).unwrap()) - .collect(); - feature_categories.push(feature_types); + n_categories.push(feature_max + 1); } let mut coefficients: Vec>> = Vec::with_capacity(class_labels.len()); - for (label, label_count) in class_labels.iter().zip(classes_count.iter()) { + let mut category_count: Vec>> = Vec::with_capacity(class_labels.len()); + for (feature_index, &n_categories_i) in n_categories.iter().enumerate().take(n_features) { let mut coef_i: Vec> = Vec::with_capacity(n_features); - for (feature_index, feature_options) in - feature_categories.iter().enumerate().take(n_features) - { + let mut category_count_i: Vec> = Vec::with_capacity(n_features); + for (label, &label_count) in class_labels.iter().zip(class_count.iter()) { let col = x .get_col_as_vec(feature_index) .iter() @@ -184,39 +200,48 @@ impl CategoricalNBDistribution { .filter(|(i, _j)| T::from(y[*i]).unwrap() == *label) .map(|(_, j)| *j) .collect::>(); - let mut feat_count: Vec = vec![T::zero(); feature_options.len()]; + let mut feat_count: Vec = vec![0_usize; n_categories_i]; for row in col.iter() { let index = row.floor().to_usize().unwrap(); - feat_count[index] += T::one(); + feat_count[index] += 1; } + let coef_i_j = feat_count .iter() .map(|c| { - ((*c + alpha) - / (*label_count + T::from(feature_options.len()).unwrap() * alpha)) + ((T::from(*c).unwrap() + alpha) + / (T::from(label_count).unwrap() + + T::from(n_categories_i).unwrap() * alpha)) .ln() }) .collect::>(); + category_count_i.push(feat_count); coef_i.push(coef_i_j); } + category_count.push(category_count_i); coefficients.push(coef_i); } - let class_priors = classes_count - .into_iter() - .map(|count| count / T::from(n_samples).unwrap()) + let class_priors = class_count + .iter() + .map(|&count| T::from(count).unwrap() / T::from(n_samples).unwrap()) .collect::>(); Ok(Self { + class_count, class_labels, class_priors, coefficients, + n_features, + n_categories, + category_count, }) } } /// `CategoricalNB` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct CategoricalNBParameters { /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, @@ -237,7 +262,8 @@ impl Default for CategoricalNBParameters { } /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] pub struct CategoricalNB> { inner: BaseNaiveBayes>, } @@ -283,6 +309,41 @@ impl> CategoricalNB { pub fn predict(&self, x: &M) -> Result { self.inner.predict(x) } + + /// Class labels known to the classifier. + /// Returns a vector of size n_classes. + pub fn classes(&self) -> &Vec { + &self.inner.distribution.class_labels + } + + /// Number of training samples observed in each class. + /// Returns a vector of size n_classes. + pub fn class_count(&self) -> &Vec { + &self.inner.distribution.class_count + } + + /// Number of features of each sample + pub fn n_features(&self) -> usize { + self.inner.distribution.n_features + } + + /// Number of features of each sample + pub fn n_categories(&self) -> &Vec { + &self.inner.distribution.n_categories + } + + /// Holds arrays of shape (n_classes, n_categories of respective feature) + /// for each feature. Each array provides the number of samples + /// encountered for each class and category of the specific feature. + pub fn category_count(&self) -> &Vec>> { + &self.inner.distribution.category_count + } + /// Holds arrays of shape (n_classes, n_categories of respective feature) + /// for each feature. Each array provides the empirical log probability + /// of categories given the respective feature and class, ``P(x_i|y)``. + pub fn feature_log_prob(&self) -> &Vec>> { + &self.inner.distribution.coefficients + } } #[cfg(test)] @@ -290,6 +351,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_categorical_naive_bayes() { let x = DenseMatrix::from_2d_array(&[ @@ -311,11 +373,66 @@ mod tests { let y = vec![0., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0.]; let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap(); + + // checking parity with scikit + assert_eq!(cnb.classes(), &[0., 1.]); + assert_eq!(cnb.class_count(), &[5, 9]); + assert_eq!(cnb.n_features(), 4); + assert_eq!(cnb.n_categories(), &[3, 3, 2, 2]); + assert_eq!( + cnb.category_count(), + &vec![ + vec![vec![3, 0, 2], vec![2, 4, 3]], + vec![vec![1, 2, 2], vec![3, 4, 2]], + vec![vec![1, 4], vec![6, 3]], + vec![vec![2, 3], vec![6, 3]] + ] + ); + + assert_eq!( + cnb.feature_log_prob(), + &vec![ + vec![ + vec![ + -0.6931471805599453, + -2.0794415416798357, + -0.9808292530117262 + ], + vec![ + -1.3862943611198906, + -0.8754687373538999, + -1.0986122886681098 + ] + ], + vec![ + vec![ + -1.3862943611198906, + -0.9808292530117262, + -0.9808292530117262 + ], + vec![ + -1.0986122886681098, + -0.8754687373538999, + -1.3862943611198906 + ] + ], + vec![ + vec![-1.252762968495368, -0.3364722366212129], + vec![-0.45198512374305727, -1.0116009116784799] + ], + vec![ + vec![-0.8472978603872037, -0.5596157879354228], + vec![-0.45198512374305727, -1.0116009116784799] + ] + ] + ); + let x_test = DenseMatrix::from_2d_array(&[&[0., 2., 1., 0.], &[2., 2., 0., 0.]]); let y_hat = cnb.predict(&x_test).unwrap(); assert_eq!(y_hat, vec![0., 1.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_categorical_naive_bayes2() { let x = DenseMatrix::from_2d_array(&[ @@ -344,7 +461,9 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::::from_2d_array(&[ &[3., 4., 0., 1.], diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index 2ac9892..bd23919 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -30,17 +30,21 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::math::vector::RealNumberVector; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// Naive Bayes classifier for categorical features -#[derive(Serialize, Deserialize, Debug, PartialEq)] +/// Naive Bayes classifier using Gaussian distribution +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] struct GaussianNBDistribution { /// class labels known to the classifier class_labels: Vec, + /// number of training samples observed in each class + class_count: Vec, /// probability of each class. class_priors: Vec, /// variance of each feature per class - sigma: Vec>, + var: Vec>, /// mean of each feature per class theta: Vec>, } @@ -55,18 +59,14 @@ impl> NBDistribution for GaussianNBDistributio } fn log_likelihood(&self, class_index: usize, j: &M::RowVector) -> T { - if class_index < self.class_labels.len() { - let mut likelihood = T::zero(); - for feature in 0..j.len() { - let value = j.get(feature); - let mean = self.theta[class_index][feature]; - let variance = self.sigma[class_index][feature]; - likelihood += self.calculate_log_probability(value, mean, variance); - } - likelihood - } else { - T::zero() + let mut likelihood = T::zero(); + for feature in 0..j.len() { + let value = j.get(feature); + let mean = self.theta[class_index][feature]; + let variance = self.var[class_index][feature]; + likelihood += self.calculate_log_probability(value, mean, variance); } + likelihood } fn classes(&self) -> &Vec { @@ -75,7 +75,8 @@ impl> NBDistribution for GaussianNBDistributio } /// `GaussianNB` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Default, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Default, Clone)] pub struct GaussianNBParameters { /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, @@ -118,12 +119,12 @@ impl GaussianNBDistribution { let y = y.to_vec(); let (class_labels, indices) = as RealNumberVector>::unique_with_indices(&y); - let mut class_count = vec![T::zero(); class_labels.len()]; + let mut class_count = vec![0_usize; class_labels.len()]; let mut subdataset: Vec>> = vec![vec![]; class_labels.len()]; for (row, class_index) in row_iter(x).zip(indices.iter()) { - class_count[*class_index] += T::one(); + class_count[*class_index] += 1; subdataset[*class_index].push(row); } @@ -136,8 +137,8 @@ impl GaussianNBDistribution { class_priors } else { class_count - .into_iter() - .map(|c| c / T::from(n_samples).unwrap()) + .iter() + .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap()) .collect() }; @@ -154,15 +155,16 @@ impl GaussianNBDistribution { }) .collect(); - let (sigma, theta): (Vec>, Vec>) = subdataset + let (var, theta): (Vec>, Vec>) = subdataset .iter() .map(|data| (data.var(0), data.mean(0))) .unzip(); Ok(Self { class_labels, + class_count, class_priors, - sigma, + var, theta, }) } @@ -177,8 +179,10 @@ impl GaussianNBDistribution { } } -/// GaussianNB implements the categorical naive Bayes algorithm for categorically distributed data. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +/// GaussianNB implements the naive Bayes algorithm for data that follows the Gaussian +/// distribution. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] pub struct GaussianNB> { inner: BaseNaiveBayes>, } @@ -219,6 +223,36 @@ impl> GaussianNB { pub fn predict(&self, x: &M) -> Result { self.inner.predict(x) } + + /// Class labels known to the classifier. + /// Returns a vector of size n_classes. + pub fn classes(&self) -> &Vec { + &self.inner.distribution.class_labels + } + + /// Number of training samples observed in each class. + /// Returns a vector of size n_classes. + pub fn class_count(&self) -> &Vec { + &self.inner.distribution.class_count + } + + /// Probability of each class + /// Returns a vector of size n_classes. + pub fn class_priors(&self) -> &Vec { + &self.inner.distribution.class_priors + } + + /// Mean of each feature per class + /// Returns a 2d vector of shape (n_classes, n_features). + pub fn theta(&self) -> &Vec> { + &self.inner.distribution.theta + } + + /// Variance of each feature per class + /// Returns a 2d vector of shape (n_classes, n_features). + pub fn var(&self) -> &Vec> { + &self.inner.distribution.var + } } #[cfg(test)] @@ -226,6 +260,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_gaussian_naive_bayes() { let x = DenseMatrix::from_2d_array(&[ @@ -241,22 +276,28 @@ mod tests { let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap(); let y_hat = gnb.predict(&x).unwrap(); assert_eq!(y_hat, y); + + assert_eq!(gnb.classes(), &[1., 2.]); + + assert_eq!(gnb.class_count(), &[3, 3]); + assert_eq!( - gnb.inner.distribution.sigma, + gnb.var(), &[ &[0.666666666666667, 0.22222222222222232], &[0.666666666666667, 0.22222222222222232] ] ); - assert_eq!(gnb.inner.distribution.class_priors, &[0.5, 0.5]); + assert_eq!(gnb.class_priors(), &[0.5, 0.5]); assert_eq!( - gnb.inner.distribution.theta, + gnb.theta(), &[&[-2., -1.3333333333333333], &[2., 1.3333333333333333]] ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_gaussian_naive_bayes_with_priors() { let x = DenseMatrix::from_2d_array(&[ @@ -273,10 +314,12 @@ mod tests { let parameters = GaussianNBParameters::default().with_priors(priors.clone()); let gnb = GaussianNB::fit(&x, &y, parameters).unwrap(); - assert_eq!(gnb.inner.distribution.class_priors, priors); + assert_eq!(gnb.class_priors(), &priors); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::::from_2d_array(&[ &[-1., -1.], diff --git a/src/naive_bayes/mod.rs b/src/naive_bayes/mod.rs index 7ab8b85..f7c8da6 100644 --- a/src/naive_bayes/mod.rs +++ b/src/naive_bayes/mod.rs @@ -39,6 +39,7 @@ use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use std::marker::PhantomData; @@ -55,7 +56,8 @@ pub(crate) trait NBDistribution> { } /// Base struct for the Naive Bayes classifier. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] pub(crate) struct BaseNaiveBayes, D: NBDistribution> { distribution: D, _phantom_t: PhantomData, diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index 4cae1f3..f42b99e 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -42,15 +42,25 @@ use crate::math::num::RealNumber; use crate::math::vector::RealNumberVector; use crate::naive_bayes::{BaseNaiveBayes, NBDistribution}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// Naive Bayes classifier for Multinomial features -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] struct MultinomialNBDistribution { /// class labels known to the classifier class_labels: Vec, + /// number of training samples observed in each class + class_count: Vec, + /// probability of each class class_priors: Vec, - feature_prob: Vec>, + /// Empirical log probability of features given a class + feature_log_prob: Vec>, + /// Number of samples encountered for each (class, feature) + feature_count: Vec>, + /// Number of features of each sample + n_features: usize, } impl> NBDistribution for MultinomialNBDistribution { @@ -62,7 +72,7 @@ impl> NBDistribution for MultinomialNBDistribu let mut likelihood = T::zero(); for feature in 0..j.len() { let value = j.get(feature); - likelihood += value * self.feature_prob[class_index][feature].ln(); + likelihood += value * self.feature_log_prob[class_index][feature]; } likelihood } @@ -73,7 +83,8 @@ impl> NBDistribution for MultinomialNBDistribu } /// `MultinomialNB` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct MultinomialNBParameters { /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, @@ -141,10 +152,10 @@ impl MultinomialNBDistribution { let y = y.to_vec(); let (class_labels, indices) = as RealNumberVector>::unique_with_indices(&y); - let mut class_count = vec![T::zero(); class_labels.len()]; + let mut class_count = vec![0_usize; class_labels.len()]; for class_index in indices.iter() { - class_count[*class_index] += T::one(); + class_count[*class_index] += 1; } let class_priors = if let Some(class_priors) = priors { @@ -157,39 +168,53 @@ impl MultinomialNBDistribution { } else { class_count .iter() - .map(|&c| c / T::from(n_samples).unwrap()) + .map(|&c| T::from(c).unwrap() / T::from(n_samples).unwrap()) .collect() }; - let mut feature_in_class_counter = vec![vec![T::zero(); n_features]; class_labels.len()]; + let mut feature_in_class_counter = vec![vec![0_usize; n_features]; class_labels.len()]; for (row, class_index) in row_iter(x).zip(indices) { for (idx, row_i) in row.iter().enumerate().take(n_features) { - feature_in_class_counter[class_index][idx] += *row_i; + feature_in_class_counter[class_index][idx] += + row_i.to_usize().ok_or_else(|| { + Failed::fit(&format!( + "Elements of the matrix should be convertible to usize |found|=[{}]", + row_i + )) + })?; } } - let feature_prob = feature_in_class_counter + let feature_log_prob = feature_in_class_counter .iter() .map(|feature_count| { - let n_c = feature_count.sum(); + let n_c: usize = feature_count.iter().sum(); feature_count .iter() - .map(|&count| (count + alpha) / (n_c + alpha * T::from(n_features).unwrap())) + .map(|&count| { + ((T::from(count).unwrap() + alpha) + / (T::from(n_c).unwrap() + alpha * T::from(n_features).unwrap())) + .ln() + }) .collect() }) .collect(); Ok(Self { + class_count, class_labels, class_priors, - feature_prob, + feature_log_prob, + feature_count: feature_in_class_counter, + n_features, }) } } -/// MultinomialNB implements the categorical naive Bayes algorithm for categorically distributed data. -#[derive(Serialize, Deserialize, Debug, PartialEq)] +/// MultinomialNB implements the naive Bayes algorithm for multinomially distributed data. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] pub struct MultinomialNB> { inner: BaseNaiveBayes>, } @@ -236,6 +261,35 @@ impl> MultinomialNB { pub fn predict(&self, x: &M) -> Result { self.inner.predict(x) } + + /// Class labels known to the classifier. + /// Returns a vector of size n_classes. + pub fn classes(&self) -> &Vec { + &self.inner.distribution.class_labels + } + + /// Number of training samples observed in each class. + /// Returns a vector of size n_classes. + pub fn class_count(&self) -> &Vec { + &self.inner.distribution.class_count + } + + /// Empirical log probability of features given a class, P(x_i|y). + /// Returns a 2d vector of shape (n_classes, n_features) + pub fn feature_log_prob(&self) -> &Vec> { + &self.inner.distribution.feature_log_prob + } + + /// Number of features of each sample + pub fn n_features(&self) -> usize { + self.inner.distribution.n_features + } + + /// Number of samples encountered for each (class, feature) + /// Returns a 2d vector of shape (n_classes, n_features) + pub fn feature_count(&self) -> &Vec> { + &self.inner.distribution.feature_count + } } #[cfg(test)] @@ -243,6 +297,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_multinomial_naive_bayes() { // Tests that MultinomialNB when alpha=1.0 gives the same values as @@ -264,12 +319,29 @@ mod tests { let y = vec![0., 0., 0., 1.]; let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); + assert_eq!(mnb.classes(), &[0., 1.]); + assert_eq!(mnb.class_count(), &[3, 1]); + assert_eq!(mnb.inner.distribution.class_priors, &[0.75, 0.25]); assert_eq!( - mnb.inner.distribution.feature_prob, + mnb.feature_log_prob(), &[ - &[1. / 7., 3. / 7., 1. / 14., 1. / 7., 1. / 7., 1. / 14.], - &[1. / 9., 2. / 9.0, 2. / 9.0, 1. / 9.0, 1. / 9.0, 2. / 9.0] + &[ + (1_f64 / 7_f64).ln(), + (3_f64 / 7_f64).ln(), + (1_f64 / 14_f64).ln(), + (1_f64 / 7_f64).ln(), + (1_f64 / 7_f64).ln(), + (1_f64 / 14_f64).ln() + ], + &[ + (1_f64 / 9_f64).ln(), + (2_f64 / 9_f64).ln(), + (2_f64 / 9_f64).ln(), + (1_f64 / 9_f64).ln(), + (1_f64 / 9_f64).ln(), + (2_f64 / 9_f64).ln() + ] ] ); @@ -281,6 +353,7 @@ mod tests { assert_eq!(y_hat, &[0.]); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn multinomial_nb_scikit_parity() { let x = DenseMatrix::::from_2d_array(&[ @@ -303,6 +376,16 @@ mod tests { let y = vec![2., 2., 0., 0., 0., 2., 1., 1., 0., 1., 0., 0., 2., 0., 2.]; let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap(); + assert_eq!(nb.n_features(), 10); + assert_eq!( + nb.feature_count(), + &[ + &[12, 20, 11, 24, 12, 14, 13, 17, 13, 18], + &[9, 6, 9, 4, 7, 3, 8, 5, 4, 9], + &[10, 12, 9, 9, 11, 3, 9, 18, 10, 10] + ] + ); + let y_hat = nb.predict(&x).unwrap(); assert!(nb @@ -310,16 +393,29 @@ mod tests { .distribution .class_priors .approximate_eq(&vec!(0.46, 0.2, 0.33), 1e-2)); - assert!(nb.inner.distribution.feature_prob[1].approximate_eq( - &vec!(0.07, 0.12, 0.07, 0.15, 0.07, 0.09, 0.08, 0.10, 0.08, 0.11), - 1e-1 + assert!(nb.feature_log_prob()[1].approximate_eq( + &vec![ + -2.00148, + -2.35815494, + -2.00148, + -2.69462718, + -2.22462355, + -2.91777073, + -2.10684052, + -2.51230562, + -2.69462718, + -2.00148 + ], + 1e-5 )); assert!(y_hat.approximate_eq( &vec!(2.0, 2.0, 0.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0, 1.0, 0.0, 2.0, 0.0, 0.0, 2.0), 1e-5 )); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::::from_2d_array(&[ &[1., 1., 0., 0., 0., 0.], diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index 97dd748..8723900 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -33,6 +33,7 @@ //! use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; @@ -45,7 +46,8 @@ use crate::math::num::RealNumber; use crate::neighbors::KNNWeightFunction; /// `KNNClassifier` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct KNNClassifierParameters, T>> { /// a function that defines a distance between each pair of point in training data. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. @@ -62,7 +64,8 @@ pub struct KNNClassifierParameters, T>> { } /// K Nearest Neighbors Classifier -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct KNNClassifier, T>> { classes: Vec, y: Vec, @@ -248,6 +251,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_fit_predict() { let x = @@ -259,6 +263,7 @@ mod tests { assert_eq!(y.to_vec(), y_hat); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_fit_predict_weighted() { let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]); @@ -276,7 +281,9 @@ mod tests { assert_eq!(vec![3.0], y_hat); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 4e73103..649cd1f 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -36,6 +36,7 @@ //! use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::neighbour::{KNNAlgorithm, KNNAlgorithmName}; @@ -48,7 +49,8 @@ use crate::math::num::RealNumber; use crate::neighbors::KNNWeightFunction; /// `KNNRegressor` parameters. Use `Default::default()` for default values. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct KNNRegressorParameters, T>> { /// a function that defines a distance between each pair of point in training data. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. @@ -65,7 +67,8 @@ pub struct KNNRegressorParameters, T>> { } /// K Nearest Neighbors Regressor -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct KNNRegressor, T>> { y: Vec, knn_algorithm: KNNAlgorithm, @@ -228,6 +231,7 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::math::distance::Distances; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_fit_predict_weighted() { let x = @@ -251,6 +255,7 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn knn_fit_predict_uniform() { let x = @@ -265,7 +270,9 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]); diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs index 85ea6b8..86b1e46 100644 --- a/src/neighbors/mod.rs +++ b/src/neighbors/mod.rs @@ -33,6 +33,7 @@ //! use crate::math::num::RealNumber; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; /// K Nearest Neighbors Classifier @@ -48,7 +49,8 @@ pub mod knn_regressor; pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName; /// Weight function that is used to determine estimated value. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub enum KNNWeightFunction { /// All k nearest points are weighted equally Uniform, diff --git a/src/optimization/first_order/gradient_descent.rs b/src/optimization/first_order/gradient_descent.rs index d57896f..a936ae4 100644 --- a/src/optimization/first_order/gradient_descent.rs +++ b/src/optimization/first_order/gradient_descent.rs @@ -50,14 +50,14 @@ impl FirstOrderOptimizer for GradientDescent { let f_alpha = |alpha: T| -> T { let mut dx = step.clone(); dx.mul_scalar_mut(alpha); - f(&dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha) + f(dx.add_mut(&x)) // f(x) = f(x .+ gvec .* alpha) }; let df_alpha = |alpha: T| -> T { let mut dx = step.clone(); let mut dg = gvec.clone(); dx.mul_scalar_mut(alpha); - df(&mut dg, &dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha) + df(&mut dg, dx.add_mut(&x)); //df(x) = df(x .+ gvec .* alpha) gvec.dot(&dg) }; @@ -66,7 +66,7 @@ impl FirstOrderOptimizer for GradientDescent { let ls_r = ls.search(&f_alpha, &df_alpha, alpha, fx, df0); alpha = ls_r.alpha; fx = ls_r.f_x; - x.add_mut(&step.mul_scalar_mut(alpha)); + x.add_mut(step.mul_scalar_mut(alpha)); df(&mut gvec, &x); gnorm = gvec.norm2(); } @@ -88,6 +88,7 @@ mod tests { use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn gradient_descent() { let x0 = DenseMatrix::row_vector_from_array(&[-1., 1.]); diff --git a/src/optimization/first_order/lbfgs.rs b/src/optimization/first_order/lbfgs.rs index 5dedfe6..1b3bfde 100644 --- a/src/optimization/first_order/lbfgs.rs +++ b/src/optimization/first_order/lbfgs.rs @@ -1,3 +1,4 @@ +#![allow(clippy::suspicious_operation_groupings)] use std::default::Default; use std::fmt::Debug; @@ -7,6 +8,7 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult}; use crate::optimization::line_search::LineSearchMethod; use crate::optimization::{DF, F}; +#[allow(clippy::upper_case_acronyms)] pub struct LBFGS { pub max_iter: usize, pub g_rtol: T, @@ -116,14 +118,14 @@ impl LBFGS { let f_alpha = |alpha: T| -> T { let mut dx = state.s.clone(); dx.mul_scalar_mut(alpha); - f(&dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha) + f(dx.add_mut(&state.x)) // f(x) = f(x .+ gvec .* alpha) }; let df_alpha = |alpha: T| -> T { let mut dx = state.s.clone(); let mut dg = state.x_df.clone(); dx.mul_scalar_mut(alpha); - df(&mut dg, &dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha) + df(&mut dg, dx.add_mut(&state.x)); //df(x) = df(x .+ gvec .* alpha) state.x_df.dot(&dg) }; @@ -205,7 +207,7 @@ impl FirstOrderOptimizer for LBFGS { ) -> OptimizerResult { let mut state = self.init_state(x0); - df(&mut state.x_df, &x0); + df(&mut state.x_df, x0); let g_converged = state.x_df.norm(T::infinity()) < self.g_atol; let mut converged = g_converged; @@ -238,6 +240,7 @@ mod tests { use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lbfgs() { let x0 = DenseMatrix::row_vector_from_array(&[0., 0.]); diff --git a/src/optimization/line_search.rs b/src/optimization/line_search.rs index 99457c9..bbaa3fc 100644 --- a/src/optimization/line_search.rs +++ b/src/optimization/line_search.rs @@ -112,6 +112,7 @@ impl LineSearchMethod for Backtracking { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn backtracking() { let f = |x: f64| -> f64 { x.powf(2.) + x }; diff --git a/src/optimization/mod.rs b/src/optimization/mod.rs index e5e58d1..b0be9d6 100644 --- a/src/optimization/mod.rs +++ b/src/optimization/mod.rs @@ -4,6 +4,7 @@ pub mod line_search; pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a; pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a; +#[allow(clippy::upper_case_acronyms)] #[derive(Debug, PartialEq)] pub enum FunctionOrder { SECOND, diff --git a/src/preprocessing/categorical.rs b/src/preprocessing/categorical.rs new file mode 100644 index 0000000..478e706 --- /dev/null +++ b/src/preprocessing/categorical.rs @@ -0,0 +1,333 @@ +//! # One-hot Encoding For [RealNumber](../../math/num/trait.RealNumber.html) Matricies +//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by replacing all categorical variables with their one-hot equivalents +//! +//! Internally OneHotEncoder treats every categorical column as a series and transforms it using [CategoryMapper](../series_encoder/struct.CategoryMapper.html) +//! +//! ### Usage Example +//! ``` +//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; +//! use smartcore::preprocessing::categorical::{OneHotEncoder, OneHotEncoderParams}; +//! let data = DenseMatrix::from_2d_array(&[ +//! &[1.5, 1.0, 1.5, 3.0], +//! &[1.5, 2.0, 1.5, 4.0], +//! &[1.5, 1.0, 1.5, 5.0], +//! &[1.5, 2.0, 1.5, 6.0], +//! ]); +//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]); +//! // Infer number of categories from data and return a reusable encoder +//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap(); +//! // Transform categorical to one-hot encoded (can transform similar) +//! let oh_data = encoder.transform(&data).unwrap(); +//! // Produces the following: +//! // &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0] +//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0] +//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0] +//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0] +//! ``` +use std::iter; + +use crate::error::Failed; +use crate::linalg::Matrix; + +use crate::preprocessing::data_traits::{CategoricalFloat, Categorizable}; +use crate::preprocessing::series_encoder::CategoryMapper; + +/// OneHotEncoder Parameters +#[derive(Debug, Clone)] +pub struct OneHotEncoderParams { + /// Column number that contain categorical variable + pub col_idx_categorical: Option>, + /// (Currently not implemented) Try and infer which of the matrix columns are categorical variables + infer_categorical: bool, +} + +impl OneHotEncoderParams { + /// Generate parameters from categorical variable column numbers + pub fn from_cat_idx(categorical_params: &[usize]) -> Self { + Self { + col_idx_categorical: Some(categorical_params.to_vec()), + infer_categorical: false, + } + } +} + +/// Calculate the offset to parameters to due introduction of one-hot encoding +fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) -> Vec { + // This functions uses iterators and returns a vector. + // In case we get a huge amount of paramenters this might be a problem + // todo: Change this such that it will return an iterator + + let cat_idx = cat_idxs.iter().copied().chain((num_params..).take(1)); + + // Offset is constant between two categorical values, here we calculate the number of steps + // that remain constant + let repeats = cat_idx.scan(0, |a, v| { + let im = v + 1 - *a; + *a = v; + Some(im) + }); + + // Calculate the offset to parameter idx due to newly intorduced one-hot vectors + let offset_ = cat_sizes.iter().scan(0, |a, &v| { + *a = *a + v - 1; + Some(*a) + }); + let offset = (0..1).chain(offset_); + + let new_param_idxs: Vec = (0..num_params) + .zip( + repeats + .zip(offset) + .flat_map(|(r, o)| iter::repeat(o).take(r)), + ) + .map(|(idx, ofst)| idx + ofst) + .collect(); + new_param_idxs +} + +fn validate_col_is_categorical(data: &[T]) -> bool { + for v in data { + if !v.is_valid() { + return false; + } + } + true +} + +/// Encode Categorical variavbles of data matrix to one-hot +#[derive(Debug, Clone)] +pub struct OneHotEncoder { + category_mappers: Vec>, + col_idx_categorical: Vec, +} + +impl OneHotEncoder { + /// Create an encoder instance with categories infered from data matrix + pub fn fit(data: &M, params: OneHotEncoderParams) -> Result + where + T: Categorizable, + M: Matrix, + { + match (params.col_idx_categorical, params.infer_categorical) { + (None, false) => Err(Failed::fit( + "Must pass categorical series ids or infer flag", + )), + + (Some(_idxs), true) => Err(Failed::fit( + "Ambigous parameters, got both infer and categroy ids", + )), + + (Some(mut idxs), false) => { + // make sure categories have same order as data columns + idxs.sort_unstable(); + + let (nrows, _) = data.shape(); + + // col buffer to avoid allocations + let mut col_buf: Vec = iter::repeat(T::zero()).take(nrows).collect(); + + let mut res: Vec> = Vec::with_capacity(idxs.len()); + + for &idx in &idxs { + data.copy_col_as_vec(idx, &mut col_buf); + if !validate_col_is_categorical(&col_buf) { + let msg = format!( + "Column {} of data matrix containts non categorizable (integer) values", + idx + ); + return Err(Failed::fit(&msg[..])); + } + let hashable_col = col_buf.iter().map(|v| v.to_category()); + res.push(CategoryMapper::fit_to_iter(hashable_col)); + } + + Ok(Self { + category_mappers: res, + col_idx_categorical: idxs, + }) + } + + (None, true) => { + todo!("Auto-Inference for Categorical Variables not yet implemented") + } + } + } + + /// Transform categorical variables to one-hot encoded and return a new matrix + pub fn transform(&self, x: &M) -> Result + where + T: Categorizable, + M: Matrix, + { + let (nrows, p) = x.shape(); + let additional_params: Vec = self + .category_mappers + .iter() + .map(|enc| enc.num_categories()) + .collect(); + + // Eac category of size v adds v-1 params + let expandws_p: usize = p + additional_params.iter().fold(0, |cs, &v| cs + v - 1); + + let new_col_idx = find_new_idxs(p, &additional_params[..], &self.col_idx_categorical[..]); + let mut res = M::zeros(nrows, expandws_p); + + for (pidx, &old_cidx) in self.col_idx_categorical.iter().enumerate() { + let cidx = new_col_idx[old_cidx]; + let col_iter = (0..nrows).map(|r| x.get(r, old_cidx).to_category()); + let sencoder = &self.category_mappers[pidx]; + let oh_series = col_iter.map(|c| sencoder.get_one_hot::>(&c)); + + for (row, oh_vec) in oh_series.enumerate() { + match oh_vec { + None => { + // Since we support T types, bad value in a series causes in to be invalid + let msg = format!("At least one value in column {} doesn't conform to category definition", old_cidx); + return Err(Failed::transform(&msg[..])); + } + Some(v) => { + // copy one hot vectors to their place in the data matrix; + for (col_ofst, &val) in v.iter().enumerate() { + res.set(row, cidx + col_ofst, val); + } + } + } + } + } + + // copy old data in x to their new location while skipping catergorical vars (already treated) + let mut skip_idx_iter = self.col_idx_categorical.iter(); + let mut cur_skip = skip_idx_iter.next(); + + for (old_p, &new_p) in new_col_idx.iter().enumerate() { + // if found treated varible, skip it + if let Some(&v) = cur_skip { + if v == old_p { + cur_skip = skip_idx_iter.next(); + continue; + } + } + + for r in 0..nrows { + let val = x.get(r, old_p); + res.set(r, new_p, val); + } + } + + Ok(res) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::preprocessing::series_encoder::CategoryMapper; + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn adjust_idxs() { + assert_eq!(find_new_idxs(0, &[], &[]), Vec::::new()); + // [0,1,2] -> [0, 1, 1, 1, 2] + assert_eq!(find_new_idxs(3, &[3], &[1]), vec![0, 1, 4]); + } + + fn build_cat_first_and_last() -> (DenseMatrix, DenseMatrix) { + let orig = DenseMatrix::from_2d_array(&[ + &[1.0, 1.5, 3.0], + &[2.0, 1.5, 4.0], + &[1.0, 1.5, 5.0], + &[2.0, 1.5, 6.0], + ]); + + let oh_enc = DenseMatrix::from_2d_array(&[ + &[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0], + &[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0], + &[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0], + &[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0], + ]); + + (orig, oh_enc) + } + + fn build_fake_matrix() -> (DenseMatrix, DenseMatrix) { + // Categorical first and last + let orig = DenseMatrix::from_2d_array(&[ + &[1.5, 1.0, 1.5, 3.0], + &[1.5, 2.0, 1.5, 4.0], + &[1.5, 1.0, 1.5, 5.0], + &[1.5, 2.0, 1.5, 6.0], + ]); + + let oh_enc = DenseMatrix::from_2d_array(&[ + &[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0], + &[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0], + &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0], + &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0], + ]); + + (orig, oh_enc) + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn hash_encode_f64_series() { + let series = vec![3.0, 1.0, 2.0, 1.0]; + let hashable_series: Vec = + series.iter().map(|v| v.to_category()).collect(); + let enc = CategoryMapper::from_positional_category_vec(hashable_series); + let inv = enc.invert_one_hot(vec![0.0, 0.0, 1.0]); + let orig_val: f64 = inv.unwrap().into(); + assert_eq!(orig_val, 2.0); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn test_fit() { + let (x, _) = build_fake_matrix(); + let params = OneHotEncoderParams::from_cat_idx(&[1, 3]); + let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + assert_eq!(oh_enc.category_mappers.len(), 2); + + let num_cat: Vec = oh_enc + .category_mappers + .iter() + .map(|a| a.num_categories()) + .collect(); + assert_eq!(num_cat, vec![2, 4]); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn matrix_transform_test() { + let (x, expected_x) = build_fake_matrix(); + let params = OneHotEncoderParams::from_cat_idx(&[1, 3]); + let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + let nm = oh_enc.transform(&x).unwrap(); + assert_eq!(nm, expected_x); + + let (x, expected_x) = build_cat_first_and_last(); + let params = OneHotEncoderParams::from_cat_idx(&[0, 2]); + let oh_enc = OneHotEncoder::fit(&x, params).unwrap(); + let nm = oh_enc.transform(&x).unwrap(); + assert_eq!(nm, expected_x); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fail_on_bad_category() { + let m = DenseMatrix::from_2d_array(&[ + &[1.0, 1.5, 3.0], + &[2.0, 1.5, 4.0], + &[1.0, 1.5, 5.0], + &[2.0, 1.5, 6.0], + ]); + + let params = OneHotEncoderParams::from_cat_idx(&[1]); + match OneHotEncoder::fit(&m, params) { + Err(_) => { + assert!(true); + } + _ => assert!(false), + } + } +} diff --git a/src/preprocessing/data_traits.rs b/src/preprocessing/data_traits.rs new file mode 100644 index 0000000..38d9e3e --- /dev/null +++ b/src/preprocessing/data_traits.rs @@ -0,0 +1,43 @@ +//! Traits to indicate that float variables can be viewed as categorical +//! This module assumes + +use crate::math::num::RealNumber; + +pub type CategoricalFloat = u16; + +// pub struct CategoricalFloat(u16); +const ERROR_MARGIN: f64 = 0.001; + +pub trait Categorizable: RealNumber { + type A; + + fn to_category(self) -> CategoricalFloat; + + fn is_valid(self) -> bool; +} + +impl Categorizable for f32 { + type A = CategoricalFloat; + + fn to_category(self) -> CategoricalFloat { + self as CategoricalFloat + } + + fn is_valid(self) -> bool { + let a = self.to_category(); + (a as f32 - self).abs() < (ERROR_MARGIN as f32) + } +} + +impl Categorizable for f64 { + type A = CategoricalFloat; + + fn to_category(self) -> CategoricalFloat { + self as CategoricalFloat + } + + fn is_valid(self) -> bool { + let a = self.to_category(); + (a as f64 - self).abs() < ERROR_MARGIN + } +} diff --git a/src/preprocessing/mod.rs b/src/preprocessing/mod.rs new file mode 100644 index 0000000..32a0cfa --- /dev/null +++ b/src/preprocessing/mod.rs @@ -0,0 +1,5 @@ +/// Transform a data matrix by replaceing all categorical variables with their one-hot vector equivalents +pub mod categorical; +mod data_traits; +/// Encode a series (column, array) of categorical variables as one-hot vectors +pub mod series_encoder; diff --git a/src/preprocessing/series_encoder.rs b/src/preprocessing/series_encoder.rs new file mode 100644 index 0000000..ab99b08 --- /dev/null +++ b/src/preprocessing/series_encoder.rs @@ -0,0 +1,282 @@ +#![allow(clippy::ptr_arg)] +//! # Series Encoder +//! Encode a series of categorical features as a one-hot numeric array. + +use crate::error::Failed; +use crate::linalg::BaseVector; +use crate::math::num::RealNumber; +use std::collections::HashMap; +use std::hash::Hash; + +/// ## Bi-directional map category <-> label num. +/// Turn Hashable objects into a one-hot vectors or ordinal values. +/// This struct encodes single class per exmample +/// +/// You can fit_to_iter a category enumeration by passing an iterator of categories. +/// category numbers will be assigned in the order they are encountered +/// +/// Example: +/// ``` +/// use std::collections::HashMap; +/// use smartcore::preprocessing::series_encoder::CategoryMapper; +/// +/// let fake_categories: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; +/// let it = fake_categories.iter().map(|&a| a); +/// let enc = CategoryMapper::::fit_to_iter(it); +/// let oh_vec: Vec = enc.get_one_hot(&1).unwrap(); +/// // notice that 1 is actually a zero-th positional category +/// assert_eq!(oh_vec, vec![1.0, 0.0, 0.0, 0.0, 0.0]); +/// ``` +/// +/// You can also pass a predefined category enumeration such as a hashmap `HashMap` or a vector `Vec` +/// +/// +/// ``` +/// use std::collections::HashMap; +/// use smartcore::preprocessing::series_encoder::CategoryMapper; +/// +/// let category_map: HashMap<&str, usize> = +/// vec![("cat", 2), ("background",0), ("dog", 1)] +/// .into_iter() +/// .collect(); +/// let category_vec = vec!["background", "dog", "cat"]; +/// +/// let enc_lv = CategoryMapper::<&str>::from_positional_category_vec(category_vec); +/// let enc_lm = CategoryMapper::<&str>::from_category_map(category_map); +/// +/// // ["background", "dog", "cat"] +/// println!("{:?}", enc_lv.get_categories()); +/// let lv: Vec = enc_lv.get_one_hot(&"dog").unwrap(); +/// let lm: Vec = enc_lm.get_one_hot(&"dog").unwrap(); +/// assert_eq!(lv, lm); +/// ``` +#[derive(Debug, Clone)] +pub struct CategoryMapper { + category_map: HashMap, + categories: Vec, + num_categories: usize, +} + +impl CategoryMapper +where + C: Hash + Eq + Clone, +{ + /// Get the number of categories in the mapper + pub fn num_categories(&self) -> usize { + self.num_categories + } + + /// Fit an encoder to a lable iterator + pub fn fit_to_iter(categories: impl Iterator) -> Self { + let mut category_map: HashMap = HashMap::new(); + let mut category_num = 0usize; + let mut unique_lables: Vec = Vec::new(); + + for l in categories { + if !category_map.contains_key(&l) { + category_map.insert(l.clone(), category_num); + unique_lables.push(l.clone()); + category_num += 1; + } + } + Self { + category_map, + num_categories: category_num, + categories: unique_lables, + } + } + + /// Build an encoder from a predefined (category -> class number) map + pub fn from_category_map(category_map: HashMap) -> Self { + let mut _unique_cat: Vec<(C, usize)> = + category_map.iter().map(|(k, v)| (k.clone(), *v)).collect(); + _unique_cat.sort_by(|a, b| a.1.cmp(&b.1)); + let categories: Vec = _unique_cat.into_iter().map(|a| a.0).collect(); + Self { + num_categories: categories.len(), + categories, + category_map, + } + } + + /// Build an encoder from a predefined positional category-class num vector + pub fn from_positional_category_vec(categories: Vec) -> Self { + let category_map: HashMap = categories + .iter() + .enumerate() + .map(|(v, k)| (k.clone(), v)) + .collect(); + Self { + num_categories: categories.len(), + category_map, + categories, + } + } + + /// Get label num of a category + pub fn get_num(&self, category: &C) -> Option<&usize> { + self.category_map.get(category) + } + + /// Return category corresponding to label num + pub fn get_cat(&self, num: usize) -> &C { + &self.categories[num] + } + + /// List all categories (position = category number) + pub fn get_categories(&self) -> &[C] { + &self.categories[..] + } + + /// Get one-hot encoding of the category + pub fn get_one_hot(&self, category: &C) -> Option + where + U: RealNumber, + V: BaseVector, + { + self.get_num(category) + .map(|&idx| make_one_hot::(idx, self.num_categories)) + } + + /// Invert one-hot vector, back to the category + pub fn invert_one_hot(&self, one_hot: V) -> Result + where + U: RealNumber, + V: BaseVector, + { + let pos = U::one(); + + let oh_it = (0..one_hot.len()).map(|idx| one_hot.get(idx)); + + let s: Vec = oh_it + .enumerate() + .filter_map(|(idx, v)| if v == pos { Some(idx) } else { None }) + .collect(); + + if s.len() == 1 { + let idx = s[0]; + return Ok(self.get_cat(idx).clone()); + } + let pos_entries = format!( + "Expected a single positive entry, {} entires found", + s.len() + ); + Err(Failed::transform(&pos_entries[..])) + } + + /// Get ordinal encoding of the catergory + pub fn get_ordinal(&self, category: &C) -> Option + where + U: RealNumber, + { + match self.get_num(category) { + None => None, + Some(&idx) => U::from_usize(idx), + } + } +} + +/// Make a one-hot encoded vector from a categorical variable +/// +/// Example: +/// ``` +/// use smartcore::preprocessing::series_encoder::make_one_hot; +/// let one_hot: Vec = make_one_hot(2, 3); +/// assert_eq!(one_hot, vec![0.0, 0.0, 1.0]); +/// ``` +pub fn make_one_hot(category_idx: usize, num_categories: usize) -> V +where + T: RealNumber, + V: BaseVector, +{ + let pos = T::one(); + let mut z = V::zeros(num_categories); + z.set(category_idx, pos); + z +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn from_categories() { + let fake_categories: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; + let it = fake_categories.iter().map(|&a| a); + let enc = CategoryMapper::::fit_to_iter(it); + let oh_vec: Vec = match enc.get_one_hot(&1) { + None => panic!("Wrong categories"), + Some(v) => v, + }; + let res: Vec = vec![1f64, 0f64, 0f64, 0f64, 0f64]; + assert_eq!(oh_vec, res); + } + + fn build_fake_str_enc<'a>() -> CategoryMapper<&'a str> { + let fake_category_pos = vec!["background", "dog", "cat"]; + let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos); + enc + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn ordinal_encoding() { + let enc = build_fake_str_enc(); + assert_eq!(1f64, enc.get_ordinal::(&"dog").unwrap()) + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn category_map_and_vec() { + let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] + .into_iter() + .collect(); + let enc = CategoryMapper::<&str>::from_category_map(category_map); + let oh_vec: Vec = match enc.get_one_hot(&"dog") { + None => panic!("Wrong categories"), + Some(v) => v, + }; + let res: Vec = vec![0f64, 1f64, 0f64]; + assert_eq!(oh_vec, res); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn positional_categories_vec() { + let enc = build_fake_str_enc(); + let oh_vec: Vec = match enc.get_one_hot(&"dog") { + None => panic!("Wrong categories"), + Some(v) => v, + }; + let res: Vec = vec![0.0, 1.0, 0.0]; + assert_eq!(oh_vec, res); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn invert_label_test() { + let enc = build_fake_str_enc(); + let res: Vec = vec![0.0, 1.0, 0.0]; + let lab = enc.invert_one_hot(res).unwrap(); + assert_eq!(lab, "dog"); + if let Err(e) = enc.invert_one_hot(vec![0.0, 0.0, 0.0]) { + let pos_entries = format!("Expected a single positive entry, 0 entires found"); + assert_eq!(e, Failed::transform(&pos_entries[..])); + }; + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn test_many_categorys() { + let enc = build_fake_str_enc(); + let cat_it = ["dog", "cat", "fish", "background"].iter().cloned(); + let res: Vec>> = cat_it.map(|v| enc.get_one_hot(&v)).collect(); + let v = vec![ + Some(vec![0.0, 1.0, 0.0]), + Some(vec![0.0, 0.0, 1.0]), + None, + Some(vec![1.0, 0.0, 0.0]), + ]; + assert_eq!(res, v) + } +} diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 1e013d2..55df584 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -26,6 +26,7 @@ pub mod svc; pub mod svr; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::BaseVector; @@ -93,18 +94,21 @@ impl Kernels { } /// Linear Kernel -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct LinearKernel {} /// Radial basis function (Gaussian) kernel -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct RBFKernel { /// kernel coefficient pub gamma: T, } /// Polynomial kernel -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct PolynomialKernel { /// degree of the polynomial pub degree: T, @@ -115,7 +119,8 @@ pub struct PolynomialKernel { } /// Sigmoid (hyperbolic tangent) kernel -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct SigmoidKernel { /// kernel coefficient pub gamma: T, @@ -154,6 +159,7 @@ impl> Kernel for SigmoidKernel { mod tests { use super::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn linear_kernel() { let v1 = vec![1., 2., 3.]; @@ -162,6 +168,7 @@ mod tests { assert_eq!(32f64, Kernels::linear().apply(&v1, &v2)); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn rbf_kernel() { let v1 = vec![1., 2., 3.]; @@ -170,6 +177,7 @@ mod tests { assert!((0.2265f64 - Kernels::rbf(0.055).apply(&v1, &v2)).abs() < 1e-4); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn polynomial_kernel() { let v1 = vec![1., 2., 3.]; @@ -181,6 +189,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn sigmoid_kernel() { let v1 = vec![1., 2., 3.]; diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 095d555..7432b9c 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -57,9 +57,9 @@ //! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0., //! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]; //! -//! let svr = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap(); +//! let svc = SVC::fit(&x, &y, SVCParameters::default().with_c(200.0)).unwrap(); //! -//! let y_hat = svr.predict(&x).unwrap(); +//! let y_hat = svc.predict(&x).unwrap(); //! ``` //! //! ## References: @@ -76,6 +76,7 @@ use std::marker::PhantomData; use rand::seq::SliceRandom; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -85,7 +86,8 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::svm::{Kernel, Kernels, LinearKernel}; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// SVC Parameters pub struct SVCParameters, K: Kernel> { /// Number of epochs. @@ -100,11 +102,15 @@ pub struct SVCParameters, K: Kernel m: PhantomData, } -#[derive(Serialize, Deserialize, Debug)] -#[serde(bound( - serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", - deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", -))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] +#[cfg_attr( + feature = "serde", + serde(bound( + serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", + deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", + )) +)] /// Support Vector Classifier pub struct SVC, K: Kernel> { classes: Vec, @@ -114,7 +120,8 @@ pub struct SVC, K: Kernel> { b: T, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct SupportVector> { index: usize, x: V, @@ -215,7 +222,7 @@ impl, K: Kernel> SVC { if n != y.len() { return Err(Failed::fit( - &"Number of rows of X doesn\'t match number of rows of Y".to_string(), + "Number of rows of X doesn\'t match number of rows of Y", )); } @@ -370,7 +377,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, Optimizer { x, y, - parameters: ¶meters, + parameters, svmin: 0, svmax: 0, gmin: T::max_value(), @@ -582,7 +589,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, for i in 0..self.sv.len() { let v = &self.sv[i]; let z = v.grad - gm; - let k = cache.get(sv1, &v); + let k = cache.get(sv1, v); let mut curv = km + v.k - T::two() * k; if curv <= T::zero() { curv = self.tau; @@ -719,8 +726,10 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; use crate::metrics::accuracy; + #[cfg(feature = "serde")] use crate::svm::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svc_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -763,6 +772,7 @@ mod tests { assert!(accuracy(&y_hat, &y) >= 0.9); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svc_fit_predict_rbf() { let x = DenseMatrix::from_2d_array(&[ @@ -806,7 +816,9 @@ mod tests { assert!(accuracy(&y_hat, &y) >= 0.9); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn svc_serde() { let x = DenseMatrix::from_2d_array(&[ &[5.1, 3.5, 1.4, 0.2], @@ -835,11 +847,11 @@ mod tests { -1., -1., -1., -1., -1., -1., -1., -1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., ]; - let svr = SVC::fit(&x, &y, Default::default()).unwrap(); + let svc = SVC::fit(&x, &y, Default::default()).unwrap(); - let deserialized_svr: SVC, LinearKernel> = - serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + let deserialized_svc: SVC, LinearKernel> = + serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap(); - assert_eq!(svr, deserialized_svr); + assert_eq!(svc, deserialized_svc); } } diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 9eb6046..3257111 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -68,6 +68,7 @@ use std::cell::{Ref, RefCell}; use std::fmt::Debug; use std::marker::PhantomData; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; @@ -77,7 +78,8 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; use crate::svm::{Kernel, Kernels, LinearKernel}; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// SVR Parameters pub struct SVRParameters, K: Kernel> { /// Epsilon in the epsilon-SVR model. @@ -92,11 +94,15 @@ pub struct SVRParameters, K: Kernel m: PhantomData, } -#[derive(Serialize, Deserialize, Debug)] -#[serde(bound( - serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", - deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", -))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] +#[cfg_attr( + feature = "serde", + serde(bound( + serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", + deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", + )) +)] /// Epsilon-Support Vector Regression pub struct SVR, K: Kernel> { @@ -106,7 +112,8 @@ pub struct SVR, K: Kernel> { b: T, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct SupportVector> { index: usize, x: V, @@ -205,7 +212,7 @@ impl, K: Kernel> SVR { if n != y.len() { return Err(Failed::fit( - &"Number of rows of X doesn\'t match number of rows of Y".to_string(), + "Number of rows of X doesn\'t match number of rows of Y", )); } @@ -526,8 +533,10 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_squared_error; + #[cfg(feature = "serde")] use crate::svm::*; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svr_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -561,7 +570,9 @@ mod tests { assert!(mean_squared_error(&y_hat, &y) < 2.5); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn svr_serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 3a92c54..d86f59a 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -68,6 +68,8 @@ use std::fmt::Debug; use std::marker::PhantomData; use rand::seq::SliceRandom; +use rand::Rng; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::sort::quick_sort::QuickArgSort; @@ -76,7 +78,8 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// Parameters of Decision Tree pub struct DecisionTreeClassifierParameters { /// Split criteria to use when building a tree. @@ -90,7 +93,8 @@ pub struct DecisionTreeClassifierParameters { } /// Decision Tree -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct DecisionTreeClassifier { nodes: Vec>, parameters: DecisionTreeClassifierParameters, @@ -100,7 +104,8 @@ pub struct DecisionTreeClassifier { } /// The function to measure the quality of a split. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub enum SplitCriterion { /// [Gini index](../decision_tree_classifier/index.html) Gini, @@ -110,9 +115,10 @@ pub enum SplitCriterion { ClassificationError, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct Node { - index: usize, + _index: usize, output: usize, split_feature: usize, split_value: Option, @@ -198,7 +204,7 @@ impl Default for DecisionTreeClassifierParameters { impl Node { fn new(index: usize, output: usize) -> Self { Node { - index, + _index: index, output, split_feature: 0, split_value: Option::None, @@ -323,7 +329,14 @@ impl DecisionTreeClassifier { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) + DecisionTreeClassifier::fit_weak_learner( + x, + y, + samples, + num_attributes, + parameters, + &mut rand::thread_rng(), + ) } pub(crate) fn fit_weak_learner>( @@ -332,6 +345,7 @@ impl DecisionTreeClassifier { samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters, + rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); let (_, y_ncols) = y_m.shape(); @@ -375,17 +389,17 @@ impl DecisionTreeClassifier { depth: 0, }; - let mut visitor = NodeVisitor::::new(0, samples, &order, &x, &yi, 1); + let mut visitor = NodeVisitor::::new(0, samples, &order, x, &yi, 1); let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry) { + if tree.find_best_cutoff(&mut visitor, mtry, rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue), + Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), None => break, }; } @@ -438,6 +452,7 @@ impl DecisionTreeClassifier { &mut self, visitor: &mut NodeVisitor<'_, T, M>, mtry: usize, + rng: &mut impl Rng, ) -> bool { let (n_rows, n_attr) = visitor.x.shape(); @@ -477,7 +492,7 @@ impl DecisionTreeClassifier { let mut variables = (0..n_attr).collect::>(); if mtry < n_attr { - variables.shuffle(&mut rand::thread_rng()); + variables.shuffle(rng); } for variable in variables.iter().take(mtry) { @@ -499,7 +514,7 @@ impl DecisionTreeClassifier { visitor: &mut NodeVisitor<'_, T, M>, n: usize, count: &[usize], - false_count: &mut Vec, + false_count: &mut [usize], parent_impurity: T, j: usize, ) { @@ -536,7 +551,7 @@ impl DecisionTreeClassifier { - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() - * impurity(&self.parameters.criterion, &false_count, fc); + * impurity(&self.parameters.criterion, false_count, fc); if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() @@ -561,6 +576,7 @@ impl DecisionTreeClassifier { mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList>, + rng: &mut impl Rng, ) -> bool { let (n, _) = visitor.x.shape(); let mut tc = 0; @@ -609,7 +625,7 @@ impl DecisionTreeClassifier { visitor.level + 1, ); - if self.find_best_cutoff(&mut true_visitor, mtry) { + if self.find_best_cutoff(&mut true_visitor, mtry, rng) { visitor_queue.push_back(true_visitor); } @@ -622,7 +638,7 @@ impl DecisionTreeClassifier { visitor.level + 1, ); - if self.find_best_cutoff(&mut false_visitor, mtry) { + if self.find_best_cutoff(&mut false_visitor, mtry, rng) { visitor_queue.push_back(false_visitor); } @@ -635,6 +651,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn gini_impurity() { assert!( @@ -651,6 +668,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -703,6 +721,7 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_baloons() { let x = DenseMatrix::from_2d_array(&[ @@ -739,7 +758,9 @@ mod tests { ); } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[1., 1., 1., 0.], diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 06ee507..94fa0f8 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -63,6 +63,8 @@ use std::default::Default; use std::fmt::Debug; use rand::seq::SliceRandom; +use rand::Rng; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::algorithm::sort::quick_sort::QuickArgSort; @@ -71,7 +73,8 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; -#[derive(Serialize, Deserialize, Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] /// Parameters of Regression Tree pub struct DecisionTreeRegressorParameters { /// The maximum depth of the tree. @@ -83,16 +86,18 @@ pub struct DecisionTreeRegressorParameters { } /// Regression Tree -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub struct DecisionTreeRegressor { nodes: Vec>, parameters: DecisionTreeRegressorParameters, depth: u16, } -#[derive(Serialize, Deserialize, Debug)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] struct Node { - index: usize, + _index: usize, output: T, split_feature: usize, split_value: Option, @@ -132,7 +137,7 @@ impl Default for DecisionTreeRegressorParameters { impl Node { fn new(index: usize, output: T) -> Self { Node { - index, + _index: index, output, split_feature: 0, split_value: Option::None, @@ -238,7 +243,14 @@ impl DecisionTreeRegressor { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) + DecisionTreeRegressor::fit_weak_learner( + x, + y, + samples, + num_attributes, + parameters, + &mut rand::thread_rng(), + ) } pub(crate) fn fit_weak_learner>( @@ -247,6 +259,7 @@ impl DecisionTreeRegressor { samples: Vec, mtry: usize, parameters: DecisionTreeRegressorParameters, + rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); @@ -276,17 +289,17 @@ impl DecisionTreeRegressor { depth: 0, }; - let mut visitor = NodeVisitor::::new(0, samples, &order, &x, &y_m, 1); + let mut visitor = NodeVisitor::::new(0, samples, &order, x, &y_m, 1); let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry) { + if tree.find_best_cutoff(&mut visitor, mtry, rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue), + Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), None => break, }; } @@ -339,6 +352,7 @@ impl DecisionTreeRegressor { &mut self, visitor: &mut NodeVisitor<'_, T, M>, mtry: usize, + rng: &mut impl Rng, ) -> bool { let (_, n_attr) = visitor.x.shape(); @@ -353,7 +367,7 @@ impl DecisionTreeRegressor { let mut variables = (0..n_attr).collect::>(); if mtry < n_attr { - variables.shuffle(&mut rand::thread_rng()); + variables.shuffle(rng); } let parent_gain = @@ -428,6 +442,7 @@ impl DecisionTreeRegressor { mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList>, + rng: &mut impl Rng, ) -> bool { let (n, _) = visitor.x.shape(); let mut tc = 0; @@ -476,7 +491,7 @@ impl DecisionTreeRegressor { visitor.level + 1, ); - if self.find_best_cutoff(&mut true_visitor, mtry) { + if self.find_best_cutoff(&mut true_visitor, mtry, rng) { visitor_queue.push_back(true_visitor); } @@ -489,7 +504,7 @@ impl DecisionTreeRegressor { visitor.level + 1, ); - if self.find_best_cutoff(&mut false_visitor, mtry) { + if self.find_best_cutoff(&mut false_visitor, mtry, rng) { visitor_queue.push_back(false_visitor); } @@ -502,6 +517,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -576,7 +592,9 @@ mod tests { } } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] + #[cfg(feature = "serde")] fn serde() { let x = DenseMatrix::from_2d_array(&[ &[234.289, 235.6, 159., 107.608, 1947., 60.323],