diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 82d0eab..e2cd825 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,19 +19,20 @@ jobs: { os: "ubuntu", target: "i686-unknown-linux-gnu" }, { os: "ubuntu", target: "wasm32-unknown-unknown" }, { os: "macos", target: "aarch64-apple-darwin" }, + { os: "ubuntu", target: "wasm32-wasi" }, ] env: TZ: "/usr/share/zoneinfo/your/location" steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Cache .cargo and target uses: actions/cache@v2 with: path: | ~/.cargo ./target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} - restore-keys: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.toml') }} + key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }} - name: Install Rust toolchain uses: actions-rs/toolchain@v1 with: @@ -42,6 +43,9 @@ jobs: - name: Install test runner for wasm if: matrix.platform.target == 'wasm32-unknown-unknown' run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + - name: Install test runner for wasi + if: matrix.platform.target == 'wasm32-wasi' + run: curl https://wasmtime.dev/install.sh -sSf | bash - name: Stable Build uses: actions-rs/cargo@v1 with: @@ -56,3 +60,40 @@ jobs: - name: Tests in WASM if: matrix.platform.target == 'wasm32-unknown-unknown' run: wasm-pack test --node -- --all-features + - name: Tests in WASI + if: matrix.platform.target == 'wasm32-wasi' + run: | + export WASMTIME_HOME="$HOME/.wasmtime" + export PATH="$WASMTIME_HOME/bin:$PATH" + cargo install cargo-wasi && cargo wasi test + + check_features: + runs-on: "${{ matrix.platform.os }}-latest" + strategy: + matrix: + platform: [{ os: "ubuntu" }] + features: ["--features serde", "--features datasets", ""] + env: + TZ: "/usr/share/zoneinfo/your/location" + steps: + - uses: actions/checkout@v3 + - name: 
Cache .cargo and target + uses: actions/cache@v2 + with: + path: | + ~/.cargo + ./target + key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }} + restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }} + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: ${{ matrix.platform.target }} + profile: minimal + default: true + - name: Stable Build + uses: actions-rs/cargo@v1 + with: + command: build + args: --no-default-features ${{ matrix.features }} diff --git a/Cargo.toml b/Cargo.toml index d048eea..c5cb4fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,38 +12,38 @@ readme = "README.md" keywords = ["machine-learning", "statistical", "ai", "optimization", "linear-algebra"] categories = ["science"] -[features] -default = ["datasets", "serde"] -ndarray-bindings = ["ndarray"] -datasets = ["rand_distr", "std"] -std = ["rand/std", "rand/std_rng"] -# wasm32 only -js = ["getrandom/js"] - [dependencies] approx = "0.5.1" cfg-if = "1.0.0" ndarray = { version = "0.15", optional = true } num-traits = "0.2.12" num = "0.4" -rand = { version = "0.8", default-features = false, features = ["small_rng"] } +rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } +[features] +default = ["serde", "datasets"] +serde = ["dep:serde"] +ndarray-bindings = ["dep:ndarray"] +datasets = ["dep:rand_distr", "std"] +std = ["rand/std_rng", "rand/std"] +# wasm32 only +js = ["getrandom/js"] + [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2", optional = true } [dev-dependencies] +itertools = "*" criterion = { version = "0.4", default-features = false } serde_json = "1.0" bincode = "1.3.1" -[target.'cfg(target_arch = "wasm32")'.dev-dependencies] +[target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies] 
wasm-bindgen-test = "0.3" -[profile.bench] -debug = true - +[workspace] resolver = "2" [profile.test] @@ -56,4 +56,4 @@ strip = true debug = 1 lto = true codegen-units = 1 -overflow-checks = true \ No newline at end of file +overflow-checks = true diff --git a/src/algorithm/neighbour/bbd_tree.rs b/src/algorithm/neighbour/bbd_tree.rs index e84f6de..44cef50 100644 --- a/src/algorithm/neighbour/bbd_tree.rs +++ b/src/algorithm/neighbour/bbd_tree.rs @@ -316,7 +316,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn bbdtree_iris() { let data = DenseMatrix::from_2d_array(&[ diff --git a/src/algorithm/neighbour/cover_tree.rs b/src/algorithm/neighbour/cover_tree.rs index 85e0d22..db062f9 100644 --- a/src/algorithm/neighbour/cover_tree.rs +++ b/src/algorithm/neighbour/cover_tree.rs @@ -468,7 +468,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn cover_tree_test() { let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; @@ -485,7 +488,10 @@ mod tests { let knn: Vec = knn.iter().map(|v| *v.2).collect(); assert_eq!(vec!(3, 4, 5, 6, 7), knn); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn cover_tree_test1() { let data = vec![ @@ -504,7 +510,10 @@ mod tests { assert_eq!(vec!(0, 1, 2), knn); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git 
a/src/algorithm/neighbour/distances.rs b/src/algorithm/neighbour/distances.rs deleted file mode 100644 index eee99ca..0000000 --- a/src/algorithm/neighbour/distances.rs +++ /dev/null @@ -1,48 +0,0 @@ -//! -//! Dissimilarities for vector-vector distance -//! -//! Representing distances as pairwise dissimilarities, so to build a -//! graph of closest neighbours. This representation can be reused for -//! different implementations (initially used in this library for FastPair). -use std::cmp::{Eq, Ordering, PartialOrd}; - -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; - -use crate::numbers::realnum::RealNumber; - -/// -/// The edge of the subgraph is defined by `PairwiseDistance`. -/// The calling algorithm can store a list of distsances as -/// a list of these structures. -/// -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone, Copy)] -pub struct PairwiseDistance { - /// index of the vector in the original `Matrix` or list - pub node: usize, - - /// index of the closest neighbor in the original `Matrix` or same list - pub neighbour: Option, - - /// measure of distance, according to the algorithm distance function - /// if the distance is None, the edge has value "infinite" or max distance - /// each algorithm has to match - pub distance: Option, -} - -impl Eq for PairwiseDistance {} - -impl PartialEq for PairwiseDistance { - fn eq(&self, other: &Self) -> bool { - self.node == other.node - && self.neighbour == other.neighbour - && self.distance == other.distance - } -} - -impl PartialOrd for PairwiseDistance { - fn partial_cmp(&self, other: &Self) -> Option { - self.distance.partial_cmp(&other.distance) - } -} diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index d676460..ab3c7a2 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -1,5 +1,5 @@ /// -/// # FastPair: Data-structure for the dynamic closest-pair problem. 
+/// ### FastPair: Data-structure for the dynamic closest-pair problem. /// /// Reference: /// Eppstein, David: Fast hierarchical clustering and other applications of @@ -7,8 +7,8 @@ /// /// Example: /// ``` -/// use smartcore::algorithm::neighbour::distances::PairwiseDistance; -/// use smartcore::linalg::naive::dense_matrix::DenseMatrix; +/// use smartcore::metrics::distance::PairwiseDistance; +/// use smartcore::linalg::basic::matrix::DenseMatrix; /// use smartcore::algorithm::neighbour::fastpair::FastPair; /// let x = DenseMatrix::::from_2d_array(&[ /// &[5.1, 3.5, 1.4, 0.2], @@ -25,12 +25,14 @@ /// use std::collections::HashMap; -use crate::algorithm::neighbour::distances::PairwiseDistance; +use num::Bounded; + use crate::error::{Failed, FailedError}; -use crate::linalg::basic::arrays::Array2; +use crate::linalg::basic::arrays::{Array1, Array2}; use crate::metrics::distance::euclidian::Euclidian; -use crate::numbers::realnum::RealNumber; +use crate::metrics::distance::PairwiseDistance; use crate::numbers::floatnum::FloatNumber; +use crate::numbers::realnum::RealNumber; /// /// Inspired by Python implementation: @@ -98,7 +100,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { PairwiseDistance { node: index_row_i, neighbour: Option::None, - distance: Some(T::MAX), + distance: Some(::max_value()), }, ); } @@ -119,13 +121,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { ); let d = Euclidian::squared_distance( - &(self.samples.get_row_as_vec(index_row_i)), - &(self.samples.get_row_as_vec(index_row_j)), + &Vec::from_iterator( + self.samples.get_row(index_row_i).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(index_row_j).iterator(0).copied(), + self.samples.shape().1, + ), ); - if d < nbd.unwrap() { + if d < nbd.unwrap().to_f64().unwrap() { // set this j-value to be the closest neighbour index_closest = index_row_j; - nbd = Some(d); + nbd = Some(T::from(d).unwrap()); 
} } @@ -138,7 +146,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { // No more neighbors, terminate conga line. // Last person on the line has no neigbors distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); - distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value()); + distances.get_mut(&(len - 1)).unwrap().distance = Some(::max_value()); // compute sparse matrix (connectivity matrix) let mut sparse_matrix = M::zeros(len, len); @@ -171,33 +179,6 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { } } - /// - /// Brute force algorithm, used only for comparison and testing - /// - #[cfg(feature = "fp_bench")] - pub fn closest_pair_brute(&self) -> PairwiseDistance { - use itertools::Itertools; - let m = self.samples.shape().0; - - let mut closest_pair = PairwiseDistance { - node: 0, - neighbour: Option::None, - distance: Some(T::max_value()), - }; - for pair in (0..m).combinations(2) { - let d = Euclidian::squared_distance( - &(self.samples.get_row_as_vec(pair[0])), - &(self.samples.get_row_as_vec(pair[1])), - ); - if d < closest_pair.distance.unwrap() { - closest_pair.node = pair[0]; - closest_pair.neighbour = Some(pair[1]); - closest_pair.distance = Some(d); - } - } - closest_pair - } - // // Compute distances from input to all other points in data-structure. 
// input is the row index of the sample matrix @@ -210,10 +191,19 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { distances.push(PairwiseDistance { node: index_row, neighbour: Some(*other), - distance: Some(Euclidian::squared_distance( - &(self.samples.get_row_as_vec(index_row)), - &(self.samples.get_row_as_vec(*other)), - )), + distance: Some( + T::from(Euclidian::squared_distance( + &Vec::from_iterator( + self.samples.get_row(index_row).iterator(0).copied(), + self.samples.shape().1, + ), + &Vec::from_iterator( + self.samples.get_row(*other).iterator(0).copied(), + self.samples.shape().1, + ), + )) + .unwrap(), + ), }) } } @@ -225,7 +215,39 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2> FastPair<'a, T, M> { mod tests_fastpair { use super::*; - use crate::linalg::naive::dense_matrix::*; + use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix}; + + /// + /// Brute force algorithm, used only for comparison and testing + /// + pub fn closest_pair_brute(fastpair: &FastPair>) -> PairwiseDistance { + use itertools::Itertools; + let m = fastpair.samples.shape().0; + + let mut closest_pair = PairwiseDistance { + node: 0, + neighbour: Option::None, + distance: Some(f64::max_value()), + }; + for pair in (0..m).combinations(2) { + let d = Euclidian::squared_distance( + &Vec::from_iterator( + fastpair.samples.get_row(pair[0]).iterator(0).copied(), + fastpair.samples.shape().1, + ), + &Vec::from_iterator( + fastpair.samples.get_row(pair[1]).iterator(0).copied(), + fastpair.samples.shape().1, + ), + ); + if d < closest_pair.distance.unwrap() { + closest_pair.node = pair[0]; + closest_pair.neighbour = Some(pair[1]); + closest_pair.distance = Some(d); + } + } + closest_pair + } #[test] fn fastpair_init() { @@ -284,7 +306,7 @@ mod tests_fastpair { }; assert_eq!(closest_pair, expected_closest_pair); - let closest_pair_brute = fastpair.closest_pair_brute(); + let closest_pair_brute = closest_pair_brute(&fastpair); assert_eq!(closest_pair_brute, 
expected_closest_pair); } @@ -302,7 +324,7 @@ mod tests_fastpair { neighbour: Some(3), distance: Some(4.0), }; - assert_eq!(closest_pair, fastpair.closest_pair_brute()); + assert_eq!(closest_pair, closest_pair_brute(&fastpair)); assert_eq!(closest_pair, expected_closest_pair); } @@ -459,11 +481,16 @@ mod tests_fastpair { let expected: HashMap<_, _> = dissimilarities.into_iter().collect(); for i in 0..(x.shape().0 - 1) { - let input_node = result.samples.get_row_as_vec(i); let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap(); let distance = Euclidian::squared_distance( - &input_node, - &result.samples.get_row_as_vec(input_neighbour), + &Vec::from_iterator( + result.samples.get_row(i).iterator(0).copied(), + result.samples.shape().1, + ), + &Vec::from_iterator( + result.samples.get_row(input_neighbour).iterator(0).copied(), + result.samples.shape().1, + ), ); assert_eq!(i, expected.get(&i).unwrap().node); @@ -518,7 +545,7 @@ mod tests_fastpair { let result = fastpair.unwrap(); let dissimilarity1 = result.closest_pair(); - let dissimilarity2 = result.closest_pair_brute(); + let dissimilarity2 = closest_pair_brute(&result); assert_eq!(dissimilarity1, dissimilarity2); } diff --git a/src/algorithm/neighbour/linear_search.rs b/src/algorithm/neighbour/linear_search.rs index ccd5c10..b1ce727 100644 --- a/src/algorithm/neighbour/linear_search.rs +++ b/src/algorithm/neighbour/linear_search.rs @@ -143,7 +143,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn knn_find() { let data1 = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; @@ -190,7 +193,10 @@ mod tests { assert_eq!(vec!(1, 2, 3), found_idxs2); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] 
fn knn_point_eq() { let point1 = KNNPoint { diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index fdfaeb7..e150d19 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -41,10 +41,8 @@ use serde::{Deserialize, Serialize}; pub(crate) mod bbd_tree; /// tree data structure for fast nearest neighbor search pub mod cover_tree; -/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair -pub mod distances; /// fastpair closest neighbour algorithm -// pub mod fastpair; +pub mod fastpair; /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. pub mod linear_search; diff --git a/src/algorithm/sort/heap_select.rs b/src/algorithm/sort/heap_select.rs index bc880bc..23d2704 100644 --- a/src/algorithm/sort/heap_select.rs +++ b/src/algorithm/sort/heap_select.rs @@ -95,14 +95,20 @@ impl HeapSelection { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn with_capacity() { let heap = HeapSelection::::with_capacity(3); assert_eq!(3, heap.k); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_add() { let mut heap = HeapSelection::with_capacity(3); @@ -120,7 +126,10 @@ mod tests { assert_eq!(vec![2, 0, -5], heap.get()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_add1() { let mut heap = HeapSelection::with_capacity(3); @@ -135,7 +144,10 @@ mod tests { assert_eq!(vec![0f64, -1f64, -5f64], heap.get()); } - #[cfg_attr(target_arch = 
"wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_add2() { let mut heap = HeapSelection::with_capacity(3); @@ -148,7 +160,10 @@ mod tests { assert_eq!(vec![5.6568, 2.8284, 0.0], heap.get()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_add_ordered() { let mut heap = HeapSelection::with_capacity(3); diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs index 7ae7cc0..97d34e7 100644 --- a/src/algorithm/sort/quick_sort.rs +++ b/src/algorithm/sort/quick_sort.rs @@ -113,7 +113,10 @@ impl QuickArgSort for Vec { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn with_capacity() { let arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8]; diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index bec45b9..2887dc2 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -425,7 +425,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_dbscan() { let x = DenseMatrix::from_2d_array(&[ @@ -457,7 +460,10 @@ mod tests { assert_eq!(expected_labels, predicted_labels); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { @@ -491,10 +497,12 @@ mod tests { assert_eq!(dbscan, 
deserialized_dbscan); } - use crate::dataset::generator; + #[cfg(feature = "datasets")] #[test] fn from_vec() { + use crate::dataset::generator; + // Generate three blobs let blobs = generator::make_blobs(100, 2, 3); let x: DenseMatrix = DenseMatrix::from_iterator(blobs.data.into_iter(), 100, 2, 0); diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index a7b9f08..9322d65 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -418,7 +418,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn invalid_k() { let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]); @@ -462,7 +465,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -497,7 +503,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/dataset/boston.rs b/src/dataset/boston.rs index 1e4ee12..f10db61 100644 --- a/src/dataset/boston.rs +++ b/src/dataset/boston.rs @@ -69,7 +69,10 @@ mod tests { assert!(serialize_data(&dataset, "boston.xy").is_ok()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn boston_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/breast_cancer.rs b/src/dataset/breast_cancer.rs index 236d69c..b88eaf9 
100644 --- a/src/dataset/breast_cancer.rs +++ b/src/dataset/breast_cancer.rs @@ -83,7 +83,10 @@ mod tests { // assert!(serialize_data(&dataset, "breast_cancer.xy").is_ok()); // } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn cancer_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/diabetes.rs b/src/dataset/diabetes.rs index f3e4156..0450522 100644 --- a/src/dataset/diabetes.rs +++ b/src/dataset/diabetes.rs @@ -67,7 +67,10 @@ mod tests { // assert!(serialize_data(&dataset, "diabetes.xy").is_ok()); // } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn boston_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/digits.rs b/src/dataset/digits.rs index b7dd2d4..6f081de 100644 --- a/src/dataset/digits.rs +++ b/src/dataset/digits.rs @@ -57,7 +57,10 @@ mod tests { let dataset = load_dataset(); assert!(serialize_data(&dataset, "digits.xy").is_ok()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn digits_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/generator.rs b/src/dataset/generator.rs index d880f37..f8e5944 100644 --- a/src/dataset/generator.rs +++ b/src/dataset/generator.rs @@ -137,7 +137,10 @@ mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_make_blobs() { let dataset = make_blobs(10, 2, 3); @@ -150,7 +153,10 @@ mod tests { assert_eq!(dataset.num_samples, 10); } - 
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_make_circles() { let dataset = make_circles(10, 0.5, 0.05); @@ -163,7 +169,10 @@ mod tests { assert_eq!(dataset.num_samples, 10); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_make_moons() { let dataset = make_moons(10, 0.05); diff --git a/src/dataset/iris.rs b/src/dataset/iris.rs index 9c81440..838f1ec 100644 --- a/src/dataset/iris.rs +++ b/src/dataset/iris.rs @@ -70,7 +70,10 @@ mod tests { // assert!(serialize_data(&dataset, "iris.xy").is_ok()); // } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn iris_dataset() { let dataset = load_dataset(); diff --git a/src/dataset/mod.rs b/src/dataset/mod.rs index 602abde..5b32d02 100644 --- a/src/dataset/mod.rs +++ b/src/dataset/mod.rs @@ -121,7 +121,10 @@ pub(crate) fn deserialize_data( mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn as_matrix() { let dataset = Dataset { diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs index 29bf551..20aee37 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -446,7 +446,10 @@ mod tests { &[6.8, 161.0, 60.0, 15.6], ]) } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn pca_components() { let us_arrests = us_arrests_data(); 
@@ -466,7 +469,10 @@ mod tests { epsilon = 1e-3 )); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_covariance() { let us_arrests = us_arrests_data(); @@ -579,7 +585,10 @@ mod tests { )); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_correlation() { let us_arrests = us_arrests_data(); @@ -700,7 +709,7 @@ mod tests { // Disable this test for now // TODO: implement deserialization for new DenseMatrix - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn pca_serde() { diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index 7b563b1..dab7099 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -237,7 +237,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn svd_decompose() { // https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/USArrests.html @@ -316,7 +319,7 @@ mod tests { // Disable this test for now // TODO: implement deserialization for new DenseMatrix - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 44bd4e3..1225082 100644 --- 
a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -664,7 +664,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ @@ -710,7 +713,10 @@ mod tests { assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_iris_oob() { let x = DenseMatrix::from_2d_array(&[ @@ -759,7 +765,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index a54ac3a..4ccdd4a 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -550,7 +550,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -595,7 +598,10 @@ mod tests { assert!(mean_absolute_error(&y, &y_hat) < 1.0); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_longley_oob() { let x = DenseMatrix::from_2d_array(&[ @@ -645,7 +651,10 @@ mod tests { 
assert!(mean_absolute_error(&y, &y_hat) < mean_absolute_error(&y, &y_hat_oob)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/lib.rs b/src/lib.rs index d665838..11c5b38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,34 +10,30 @@ //! # SmartCore //! -//! Welcome to SmartCore, the most advanced machine learning library in Rust! +//! Welcome to SmartCore, machine learning in Rust! //! //! SmartCore features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN, //! as well as tools for model selection and model evaluation. //! -//! SmartCore is well integrated with a with wide variaty of libraries that provide support for large, multi-dimensional arrays and matrices. At this moment, -//! all Smartcore's algorithms work with ordinary Rust vectors, as well as matrices and vectors defined in these packages: -//! * [ndarray](https://docs.rs/ndarray) +//! SmartCore provides its own traits system that extends the Rust standard library, to deal with linear algebra and common +//! computational models. Its API is designed using well recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray) +//! structures) are available via optional features. //! //! ## Getting Started //! //! To start using SmartCore simply add the following to your Cargo.toml file: //! ```ignore //! [dependencies] -//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "v0.5-wip" } +//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" } //! ``` //! -//! All machine learning algorithms in SmartCore are grouped into these broad categories: -//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data. -//! 
* [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition. -//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables -//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models -//! * [Tree-based Models](tree/index.html), classification and regression trees -//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression -//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem -//! * [SVM](svm/index.html), support vector machines +//! ## Using Jupyter +//! For a quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks). +//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr) +//! following [these instructions](https://depth-first.com/articles/2020/09/21/interactive-rust-in-a-repl-and-jupyter-notebook-with-evcxr/). //! //! +//! ## First Example //! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as standard Rust vector: //! //! ``` @@ -48,14 +44,14 @@ //! // Various distance metrics //! use smartcore::metrics::distance::*; //! -//! // Turn Rust vectors with samples into a matrix +//! // Turn Rust vector-slices with samples into a matrix //! let x = DenseMatrix::from_2d_array(&[ //! &[1., 2.], //! &[3., 4.], //! &[5., 6.], //! &[7., 8.], //! &[9., 10.]]); -//! // Our classes are defined as a Vector +//! // Our classes are defined as a vector //! let y = vec![2, 2, 2, 3, 3]; //! //! // Train classifier @@ -64,6 +60,17 @@ //! // Predict classes //! let y_hat = knn.predict(&x).unwrap(); //! ``` +//! +//! ## Overview +//! 
All machine learning algorithms in SmartCore are grouped into these broad categories: +//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data. +//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition. +//! * [Linear Models](linear/index.html), regression and classification methods where output is assumed to have linear relation to explanatory variables +//! * [Ensemble Models](ensemble/index.html), variety of regression and classification ensemble models +//! * [Tree-based Models](tree/index.html), classification and regression trees +//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression +//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes Theorem +//! * [SVM](svm/index.html), support vector machines /// Foundamental numbers traits pub mod numbers; diff --git a/src/linalg/basic/matrix.rs b/src/linalg/basic/matrix.rs index 149c1fc..bde6b78 100644 --- a/src/linalg/basic/matrix.rs +++ b/src/linalg/basic/matrix.rs @@ -4,6 +4,7 @@ use std::ops::Range; use std::slice::Iter; use approx::{AbsDiffEq, RelativeEq}; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use crate::linalg::basic::arrays::{ @@ -19,7 +20,8 @@ use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; /// Dense matrix -#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct DenseMatrix { ncols: usize, nrows: usize, diff --git a/src/linalg/traits/cholesky.rs b/src/linalg/traits/cholesky.rs index 22ec9a9..1394270 100644 --- a/src/linalg/traits/cholesky.rs +++ b/src/linalg/traits/cholesky.rs @@ -169,7 +169,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; use approx::relative_eq; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = 
"wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn cholesky_decompose() { let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); @@ -188,7 +191,10 @@ mod tests { )); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn cholesky_solve_mut() { let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]); diff --git a/src/linalg/traits/evd.rs b/src/linalg/traits/evd.rs index 7b017e7..c0a54df 100644 --- a/src/linalg/traits/evd.rs +++ b/src/linalg/traits/evd.rs @@ -810,7 +810,10 @@ mod tests { use crate::linalg::basic::matrix::DenseMatrix; use approx::relative_eq; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_symmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -841,7 +844,10 @@ mod tests { assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_asymmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -872,7 +878,10 @@ mod tests { assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON); } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_complex() { let A = DenseMatrix::from_2d_array(&[ diff --git a/src/linalg/traits/lu.rs b/src/linalg/traits/lu.rs index 8e54f89..020c271 100644 --- a/src/linalg/traits/lu.rs +++ b/src/linalg/traits/lu.rs @@ -260,7 +260,10 @@ mod tests { use 
crate::linalg::basic::matrix::DenseMatrix; use approx::relative_eq; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); @@ -275,7 +278,10 @@ mod tests { assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4)); assert!(relative_eq!(lu.pivot(), expected_pivot, epsilon = 1e-4)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn inverse() { let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]); diff --git a/src/linalg/traits/qr.rs b/src/linalg/traits/qr.rs index 1337fd8..da13729 100644 --- a/src/linalg/traits/qr.rs +++ b/src/linalg/traits/qr.rs @@ -198,7 +198,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; use approx::relative_eq; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); @@ -217,7 +220,10 @@ mod tests { assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn qr_solve_mut() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); diff --git a/src/linalg/traits/stats.rs b/src/linalg/traits/stats.rs index fccd293..3bd7042 100644 --- a/src/linalg/traits/stats.rs +++ b/src/linalg/traits/stats.rs @@ -71,8 +71,8 @@ pub 
trait MatrixStats: ArrayView2 + Array2 { x } - /// (reference)[http://en.wikipedia.org/wiki/Arithmetic_mean] - /// Taken from statistical + /// + /// Taken from `statistical` /// The MIT License (MIT) /// Copyright (c) 2015 Jeff Belgum fn _mean_of_vector(v: &[T]) -> T { @@ -97,7 +97,7 @@ pub trait MatrixStats: ArrayView2 + Array2 { sum } - /// (Sample variance)[http://en.wikipedia.org/wiki/Variance#Sample_variance] + /// /// Taken from statistical /// The MIT License (MIT) /// Copyright (c) 2015 Jeff Belgum diff --git a/src/linalg/traits/svd.rs b/src/linalg/traits/svd.rs index 1920f99..93c8d9a 100644 --- a/src/linalg/traits/svd.rs +++ b/src/linalg/traits/svd.rs @@ -479,7 +479,10 @@ mod tests { use crate::linalg::basic::matrix::DenseMatrix; use approx::relative_eq; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_symmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -510,7 +513,10 @@ mod tests { assert!((s[i] - svd.s[i]).abs() < 1e-4); } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_asymmetric() { let A = DenseMatrix::from_2d_array(&[ @@ -711,7 +717,10 @@ mod tests { assert!((s[i] - svd.s[i]).abs() < 1e-4); } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn solve() { let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]); @@ -722,7 +731,10 @@ mod tests { assert!(relative_eq!(w, expected_w, epsilon = 1e-2)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + 
wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn decompose_restore() { let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]); diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs index 46272ed..7d57e1d 100644 --- a/src/linear/elastic_net.rs +++ b/src/linear/elastic_net.rs @@ -491,7 +491,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn elasticnet_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -535,7 +538,10 @@ mod tests { assert!(mean_absolute_error(&y_hat, &y) < 30.0); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn elasticnet_fit_predict1() { let x = DenseMatrix::from_2d_array(&[ @@ -603,7 +609,7 @@ mod tests { } // TODO: serialization for the new DenseMatrix needs to be implemented - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 08076c6..150d5ca 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -398,7 +398,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lasso_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -448,7 +451,7 @@ mod tests { } // TODO: serialization for the new DenseMatrix needs to be implemented - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] 
+ // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index ef471db..1f7d540 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -325,7 +325,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn ols_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -372,7 +375,7 @@ mod tests { } // TODO: serialization for the new DenseMatrix needs to be implemented - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 2012ae0..6b706dd 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -577,6 +577,8 @@ impl, Y: #[cfg(test)] mod tests { use super::*; + + #[cfg(feature = "datasets")] use crate::dataset::generator::make_blobs; use crate::linalg::basic::arrays::Array; use crate::linalg::basic::matrix::DenseMatrix; @@ -596,7 +598,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn multiclass_objective_f() { let x = DenseMatrix::from_2d_array(&[ @@ -653,7 +658,10 @@ mod tests { assert!((g[0].abs() - 32.0).abs() < 1e-4); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = 
"wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn binary_objective_f() { let x = DenseMatrix::from_2d_array(&[ @@ -712,7 +720,10 @@ mod tests { assert!((g[2] - 3.8693).abs() < 1e-4); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lr_fit_predict() { let x: DenseMatrix = DenseMatrix::from_2d_array(&[ @@ -751,7 +762,11 @@ mod tests { assert_eq!(y_hat, vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg(feature = "datasets")] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lr_fit_predict_multiclass() { let blobs = make_blobs(15, 4, 3); @@ -778,7 +793,11 @@ mod tests { assert!(reg_coeff_sum < coeff); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg(feature = "datasets")] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lr_fit_predict_binary() { let blobs = make_blobs(20, 4, 2); @@ -809,7 +828,7 @@ mod tests { } // TODO: serialization for the new DenseMatrix needs to be implemented - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { @@ -840,7 +859,10 @@ mod tests { // assert_eq!(lr, deserialized_lr); // } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lr_fit_predict_iris() { let x = DenseMatrix::from_2d_array(&[ diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index 
671a8fb..914afc2 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -443,7 +443,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn ridge_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -500,7 +503,7 @@ mod tests { } // TODO: implement serialization for new DenseMatrix - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + // #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)] // #[test] // #[cfg(feature = "serde")] // fn serde() { diff --git a/src/metrics/accuracy.rs b/src/metrics/accuracy.rs index b2a454e..1279614 100644 --- a/src/metrics/accuracy.rs +++ b/src/metrics/accuracy.rs @@ -83,7 +83,10 @@ impl Metrics for Accuracy { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn accuracy_float() { let y_pred: Vec = vec![0., 2., 1., 3.]; @@ -96,7 +99,10 @@ mod tests { assert!((score2 - 1.0).abs() < 1e-8); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn accuracy_int() { let y_pred: Vec = vec![0, 2, 1, 3]; diff --git a/src/metrics/auc.rs b/src/metrics/auc.rs index a94f3a3..ecaf646 100644 --- a/src/metrics/auc.rs +++ b/src/metrics/auc.rs @@ -26,8 +26,8 @@ use std::marker::PhantomData; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -use crate::linalg::basic::arrays::{Array1, ArrayView1, MutArrayView1}; -use crate::numbers::basenum::Number; +use crate::linalg::basic::arrays::{Array1, ArrayView1}; +use 
crate::numbers::floatnum::FloatNumber; use crate::metrics::Metrics; @@ -38,14 +38,14 @@ pub struct AUC { _phantom: PhantomData, } -impl Metrics for AUC { +impl Metrics for AUC { /// create a typed object to call AUC functions fn new() -> Self { Self { _phantom: PhantomData, } } - fn new_with(_parameter: T) -> Self { + fn new_with(_parameter: f64) -> Self { Self { _phantom: PhantomData, } @@ -53,11 +53,7 @@ impl Metrics for AUC { /// AUC score. /// * `y_true` - ground truth (correct) labels. /// * `y_pred_prob` - probability estimates, as returned by a classifier. - fn get_score( - &self, - y_true: &dyn ArrayView1, - y_pred_prob: &dyn ArrayView1, - ) -> f64 { + fn get_score(&self, y_true: &dyn ArrayView1, y_pred_prob: &dyn ArrayView1) -> f64 { let mut pos = T::zero(); let mut neg = T::zero(); @@ -76,9 +72,10 @@ impl Metrics for AUC { } } - let y_pred = y_pred_prob.clone(); - - let label_idx = y_pred.argsort(); + let y_pred: Vec = + Array1::::from_iterator(y_pred_prob.iterator(0).copied(), y_pred_prob.shape()); + // TODO: try to use `crate::algorithm::sort::quick_sort` here + let label_idx: Vec = y_pred.argsort(); let mut rank = vec![0f64; n]; let mut i = 0; @@ -108,7 +105,7 @@ impl Metrics for AUC { let pos = pos.to_f64().unwrap(); let neg = neg.to_f64().unwrap(); - T::from(auc - (pos * (pos + 1f64) / 2.0)).unwrap() / T::from(pos * neg).unwrap() + (auc - (pos * (pos + 1f64) / 2f64)) / (pos * neg) } } @@ -116,7 +113,10 @@ impl Metrics for AUC { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn auc() { let y_true: Vec = vec![0., 0., 1., 1.]; diff --git a/src/metrics/cluster_hcv.rs b/src/metrics/cluster_hcv.rs index 4ee5974..ad43c94 100644 --- a/src/metrics/cluster_hcv.rs +++ b/src/metrics/cluster_hcv.rs @@ -87,7 +87,10 @@ impl Metrics for HCVScore { mod tests { use super::*; - 
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn homogeneity_score() { let v1 = vec![0, 0, 1, 1, 2, 0, 4]; diff --git a/src/metrics/cluster_helpers.rs b/src/metrics/cluster_helpers.rs index e3f1881..47d8061 100644 --- a/src/metrics/cluster_helpers.rs +++ b/src/metrics/cluster_helpers.rs @@ -102,7 +102,10 @@ pub fn mutual_info_score(contingency: &[Vec]) -> f64 { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn contingency_matrix_test() { let v1 = vec![0, 0, 1, 1, 2, 0, 4]; @@ -114,7 +117,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn entropy_test() { let v1 = vec![0, 0, 1, 1, 2, 0, 4]; @@ -122,7 +128,10 @@ mod tests { assert!((1.2770 - entropy(&v1).unwrap() as f64).abs() < 1e-4); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn mutual_info_score_test() { let v1 = vec![0, 0, 1, 1, 2, 0, 4]; diff --git a/src/metrics/distance/euclidian.rs b/src/metrics/distance/euclidian.rs index 2c8a2db..39deebf 100644 --- a/src/metrics/distance/euclidian.rs +++ b/src/metrics/distance/euclidian.rs @@ -76,7 +76,10 @@ impl> Distance for Euclidian { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn squared_distance() { let a = vec![1, 2, 3]; diff --git 
a/src/metrics/distance/hamming.rs b/src/metrics/distance/hamming.rs index 80fbc24..ac0c2c3 100644 --- a/src/metrics/distance/hamming.rs +++ b/src/metrics/distance/hamming.rs @@ -70,7 +70,10 @@ impl> Distance for Hamming { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn hamming_distance() { let a = vec![1, 0, 0, 1, 0, 0, 1]; diff --git a/src/metrics/distance/mahalanobis.rs b/src/metrics/distance/mahalanobis.rs index 1b79a0a..e526c20 100644 --- a/src/metrics/distance/mahalanobis.rs +++ b/src/metrics/distance/mahalanobis.rs @@ -139,7 +139,10 @@ mod tests { use crate::linalg::basic::arrays::ArrayView2; use crate::linalg::basic::matrix::DenseMatrix; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn mahalanobis_distance() { let data = DenseMatrix::from_2d_array(&[ diff --git a/src/metrics/distance/manhattan.rs b/src/metrics/distance/manhattan.rs index 719043f..fae7868 100644 --- a/src/metrics/distance/manhattan.rs +++ b/src/metrics/distance/manhattan.rs @@ -66,7 +66,10 @@ impl> Distance for Manhattan { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn manhattan_distance() { let a = vec![1., 2., 3.]; diff --git a/src/metrics/distance/minkowski.rs b/src/metrics/distance/minkowski.rs index 9bfde0b..93e0c93 100644 --- a/src/metrics/distance/minkowski.rs +++ b/src/metrics/distance/minkowski.rs @@ -71,7 +71,10 @@ impl> Distance for Minkowski { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = 
"wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn minkowski_distance() { let a = vec![1., 2., 3.]; diff --git a/src/metrics/distance/mod.rs b/src/metrics/distance/mod.rs index 4075e14..193d7a1 100644 --- a/src/metrics/distance/mod.rs +++ b/src/metrics/distance/mod.rs @@ -24,9 +24,15 @@ pub mod manhattan; /// A generalization of both the Euclidean distance and the Manhattan distance. pub mod minkowski; +use std::cmp::{Eq, Ordering, PartialOrd}; + use crate::linalg::basic::arrays::Array2; use crate::linalg::traits::lu::LUDecomposable; use crate::numbers::basenum::Number; +use crate::numbers::realnum::RealNumber; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; /// Distance metric, a function that calculates distance between two points pub trait Distance: Clone { @@ -66,3 +72,45 @@ impl Distances { mahalanobis::Mahalanobis::new(data) } } + +/// +/// ### Pairwise dissimilarities. +/// +/// Representing distances as pairwise dissimilarities, so to build a +/// graph of closest neighbours. This representation can be reused for +/// different implementations +/// (initially used in this library for [FastPair](algorithm/neighbour/fastpair)). +/// The edge of the subgraph is defined by `PairwiseDistance`. +/// The calling algorithm can store a list of distances as +/// a list of these structures. 
+/// +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy)] +pub struct PairwiseDistance { + /// index of the vector in the original `Matrix` or list + pub node: usize, + + /// index of the closest neighbor in the original `Matrix` or same list + pub neighbour: Option, + + /// measure of distance, according to the algorithm distance function + /// if the distance is None, the edge has value "infinite" or max distance + /// each algorithm has to match + pub distance: Option, +} + +impl Eq for PairwiseDistance {} + +impl PartialEq for PairwiseDistance { + fn eq(&self, other: &Self) -> bool { + self.node == other.node + && self.neighbour == other.neighbour + && self.distance == other.distance + } +} + +impl PartialOrd for PairwiseDistance { + fn partial_cmp(&self, other: &Self) -> Option { + self.distance.partial_cmp(&other.distance) + } +} diff --git a/src/metrics/f1.rs b/src/metrics/f1.rs index 4eb4e48..fd41019 100644 --- a/src/metrics/f1.rs +++ b/src/metrics/f1.rs @@ -82,7 +82,10 @@ impl Metrics for F1 { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn f1() { let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; diff --git a/src/metrics/mean_absolute_error.rs b/src/metrics/mean_absolute_error.rs index 74bf4c3..36e5f48 100644 --- a/src/metrics/mean_absolute_error.rs +++ b/src/metrics/mean_absolute_error.rs @@ -76,7 +76,10 @@ impl Metrics for MeanAbsoluteError { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn mean_absolute_error() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/mean_squared_error.rs b/src/metrics/mean_squared_error.rs index 7ad296a..7443857 
100644 --- a/src/metrics/mean_squared_error.rs +++ b/src/metrics/mean_squared_error.rs @@ -76,7 +76,10 @@ impl Metrics for MeanSquareError { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn mean_squared_error() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 503391c..06d44a1 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -55,7 +55,7 @@ pub mod accuracy; // TODO: reimplement AUC // /// Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. -// pub mod auc; +pub mod auc; /// Compute the homogeneity, completeness and V-Measure scores. pub mod cluster_hcv; pub(crate) mod cluster_helpers; @@ -84,7 +84,7 @@ use std::marker::PhantomData; /// A trait to be implemented by all metrics pub trait Metrics { /// instantiate a new Metrics trait-object - /// https://doc.rust-lang.org/error-index.html#E0038 + /// fn new() -> Self where Self: Sized; @@ -133,10 +133,10 @@ impl ClassificationMetrics { f1::F1::new_with(beta) } - // /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html). - // pub fn roc_auc_score() -> auc::AUC { - // auc::AUC::::new() - // } + /// Area Under the Receiver Operating Characteristic Curve (ROC AUC), see [AUC](auc/index.html). + pub fn roc_auc_score() -> auc::AUC { + auc::AUC::::new() + } } impl ClassificationMetricsOrd { @@ -212,16 +212,19 @@ pub fn f1>( obj.get_score(y_true, y_pred) } -// /// AUC score, see [AUC](auc/index.html). -// /// * `y_true` - cround truth (correct) labels. -// /// * `y_pred_probabilities` - probability estimates, as returned by a classifier. 
-// pub fn roc_auc_score + Array1 + Array1>( -// y_true: &V, -// y_pred_probabilities: &V, -// ) -> T { -// let obj = ClassificationMetrics::::roc_auc_score(); -// obj.get_score(y_true, y_pred_probabilities) -// } +/// AUC score, see [AUC](auc/index.html). +/// * `y_true` - cround truth (correct) labels. +/// * `y_pred_probabilities` - probability estimates, as returned by a classifier. +pub fn roc_auc_score< + T: Number + RealNumber + FloatNumber + PartialOrd, + V: ArrayView1 + Array1 + Array1, +>( + y_true: &V, + y_pred_probabilities: &V, +) -> f64 { + let obj = ClassificationMetrics::::roc_auc_score(); + obj.get_score(y_true, y_pred_probabilities) +} /// Computes mean squared error, see [mean squared error](mean_squared_error/index.html). /// * `y_true` - Ground truth (correct) target values. diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index 9bc0ff5..a6fcef1 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -95,7 +95,10 @@ impl Metrics for Precision { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn precision() { let y_true: Vec = vec![0., 1., 1., 0.]; @@ -114,7 +117,10 @@ mod tests { assert!((score3 - 0.5).abs() < 1e-8); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn precision_multiclass() { let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; diff --git a/src/metrics/r2.rs b/src/metrics/r2.rs index b217aed..6581abe 100644 --- a/src/metrics/r2.rs +++ b/src/metrics/r2.rs @@ -81,7 +81,10 @@ impl Metrics for R2 { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + 
wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn r2() { let y_true: Vec = vec![3., -0.5, 2., 7.]; diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 640471d..04a779a 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -96,7 +96,10 @@ impl Metrics for Recall { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn recall() { let y_true: Vec = vec![0., 1., 1., 0.]; @@ -115,7 +118,10 @@ mod tests { assert!((score3 - 0.6666666666666666).abs() < 1e-8); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn recall_multiclass() { let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; diff --git a/src/model_selection/kfold.rs b/src/model_selection/kfold.rs index 8387d7a..680d2ac 100644 --- a/src/model_selection/kfold.rs +++ b/src/model_selection/kfold.rs @@ -159,7 +159,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_kfold_return_test_indices_simple() { let k = KFold { @@ -175,7 +178,10 @@ mod tests { assert_eq!(test_indices[2], (22..33).collect::>()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_kfold_return_test_indices_odd() { let k = KFold { @@ -191,7 +197,10 @@ mod tests { assert_eq!(test_indices[2], (23..34).collect::>()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + 
all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_kfold_return_test_mask_simple() { let k = KFold { @@ -218,7 +227,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_kfold_return_split_simple() { let k = KFold { @@ -235,7 +247,10 @@ mod tests { assert_eq!(train_test_splits[1].1, (11..22).collect::>()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_kfold_return_split_simple_shuffle() { let k = KFold { @@ -251,7 +266,10 @@ mod tests { assert_eq!(train_test_splits[1].1.len(), 11_usize); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn numpy_parity_test() { let k = KFold { @@ -273,7 +291,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn numpy_parity_test_shuffle() { let k = KFold { diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 7bb8b8a..b8e4e7f 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -321,7 +321,10 @@ mod tests { use crate::neighbors::knn_regressor::{KNNRegressor, KNNRegressorParameters}; use crate::neighbors::KNNWeightFunction; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_train_test_split() { let n = 123; @@ -346,7 +349,10 @@ mod tests 
{ struct BiasedParameters {} impl NoParameters for BiasedParameters {} - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_cross_validate_biased() { struct BiasedEstimator {} @@ -412,7 +418,10 @@ mod tests { assert_eq!(0.4, results.mean_train_score()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_cross_validate_knn() { let x = DenseMatrix::from_2d_array(&[ @@ -457,7 +466,10 @@ mod tests { assert!(results.mean_train_score() < results.mean_test_score()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_cross_val_predict_knn() { let x: DenseMatrix = DenseMatrix::from_2d_array(&[ diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 1ded589..02bf330 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -496,7 +496,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_bernoulli_naive_bayes() { // Tests that BernoulliNB when alpha=1.0 gives the same values as @@ -551,7 +554,10 @@ mod tests { assert_eq!(y_hat, &[1]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn bernoulli_nb_scikit_parity() { let x = DenseMatrix::from_2d_array(&[ @@ -612,7 +618,10 @@ mod tests { assert_eq!(y_hat, vec!(2, 2, 0, 0, 0, 2, 
1, 1, 0, 0, 0, 0, 0, 0, 0)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 3196b3b..f2ae4a8 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -428,7 +428,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_categorical_naive_bayes() { let x = DenseMatrix::::from_2d_array(&[ @@ -509,7 +512,10 @@ mod tests { assert_eq!(y_hat, vec![0, 1]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_categorical_naive_bayes2() { let x = DenseMatrix::::from_2d_array(&[ @@ -535,7 +541,10 @@ mod tests { assert_eq!(y_hat, vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index c8223fd..f23ffdb 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -372,7 +372,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_gaussian_naive_bayes() { let x = DenseMatrix::from_2d_array(&[ @@ -409,7 +412,10 @@ mod 
tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_gaussian_naive_bayes_with_priors() { let x = DenseMatrix::from_2d_array(&[ @@ -429,7 +435,10 @@ mod tests { assert_eq!(gnb.class_priors(), &priors); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index f82d4fc..f3305ac 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -403,7 +403,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn run_multinomial_naive_bayes() { // Tests that MultinomialNB when alpha=1.0 gives the same values as @@ -461,7 +464,10 @@ mod tests { assert_eq!(y_hat, &[0]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn multinomial_nb_scikit_parity() { let x = DenseMatrix::::from_2d_array(&[ @@ -524,7 +530,10 @@ mod tests { assert_eq!(y_hat, vec!(2, 2, 0, 0, 0, 2, 2, 1, 0, 1, 0, 2, 0, 0, 2)); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index fb02b82..67d094a 100644 --- a/src/neighbors/knn_classifier.rs +++ 
b/src/neighbors/knn_classifier.rs @@ -305,7 +305,10 @@ mod tests { use super::*; use crate::linalg::basic::matrix::DenseMatrix; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn knn_fit_predict() { let x = @@ -317,7 +320,10 @@ mod tests { assert_eq!(y.to_vec(), y_hat); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn knn_fit_predict_weighted() { let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]); @@ -335,7 +341,10 @@ mod tests { assert_eq!(vec![3], y_hat); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index cf9b88d..3a123f7 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -289,7 +289,10 @@ mod tests { use crate::linalg::basic::matrix::DenseMatrix; use crate::metrics::distance::Distances; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn knn_fit_predict_weighted() { let x = @@ -313,7 +316,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn knn_fit_predict_uniform() { let x = @@ -328,7 +334,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", 
not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git a/src/numbers/floatnum.rs b/src/numbers/floatnum.rs index 15966cf..4ca7f73 100644 --- a/src/numbers/floatnum.rs +++ b/src/numbers/floatnum.rs @@ -1,8 +1,6 @@ -use rand::Rng; - use num_traits::{Float, Signed}; -use crate::numbers::basenum::Number; +use crate::{numbers::basenum::Number, rand_custom::get_rng_impl}; /// Defines float number /// @@ -58,7 +56,8 @@ impl FloatNumber for f64 { } fn rand() -> f64 { - let mut rng = rand::thread_rng(); + use rand::Rng; + let mut rng = get_rng_impl(None); rng.gen() } @@ -99,7 +98,8 @@ impl FloatNumber for f32 { } fn rand() -> f32 { - let mut rng = rand::thread_rng(); + use rand::Rng; + let mut rng = get_rng_impl(None); rng.gen() } diff --git a/src/numbers/realnum.rs b/src/numbers/realnum.rs index 6855e4b..8c60e47 100644 --- a/src/numbers/realnum.rs +++ b/src/numbers/realnum.rs @@ -63,6 +63,7 @@ impl RealNumber for f64 { } fn rand() -> f64 { + // TODO: to be implemented, see issue smartcore#214 1.0 } diff --git a/src/optimization/first_order/gradient_descent.rs b/src/optimization/first_order/gradient_descent.rs index 63c5c4a..5603a34 100644 --- a/src/optimization/first_order/gradient_descent.rs +++ b/src/optimization/first_order/gradient_descent.rs @@ -99,7 +99,10 @@ mod tests { use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn gradient_descent() { let x0 = vec![-1., 1.]; diff --git a/src/optimization/first_order/lbfgs.rs b/src/optimization/first_order/lbfgs.rs index 1410bac..3bd5f13 100644 --- a/src/optimization/first_order/lbfgs.rs +++ b/src/optimization/first_order/lbfgs.rs @@ -278,7 +278,10 @@ mod tests { use crate::optimization::line_search::Backtracking; use 
crate::optimization::FunctionOrder; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn lbfgs() { let x0 = vec![0., 0.]; diff --git a/src/optimization/line_search.rs b/src/optimization/line_search.rs index 3d6c012..9a2656c 100644 --- a/src/optimization/line_search.rs +++ b/src/optimization/line_search.rs @@ -129,7 +129,10 @@ impl LineSearchMethod for Backtracking { mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn backtracking() { let f = |x: f64| -> f64 { x.powf(2.) + x }; diff --git a/src/preprocessing/categorical.rs b/src/preprocessing/categorical.rs index 1316f2a..048dd26 100644 --- a/src/preprocessing/categorical.rs +++ b/src/preprocessing/categorical.rs @@ -224,7 +224,10 @@ mod tests { use crate::linalg::basic::matrix::DenseMatrix; use crate::preprocessing::series_encoder::CategoryMapper; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn adjust_idxs() { assert_eq!(find_new_idxs(0, &[], &[]), Vec::::new()); @@ -269,7 +272,10 @@ mod tests { (orig, oh_enc) } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn hash_encode_f64_series() { let series = vec![3.0, 1.0, 2.0, 1.0]; @@ -280,7 +286,10 @@ mod tests { let orig_val: f64 = inv.unwrap().into(); assert_eq!(orig_val, 2.0); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + 
wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_fit() { let (x, _) = build_fake_matrix(); @@ -296,7 +305,10 @@ mod tests { assert_eq!(num_cat, vec![2, 4]); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn matrix_transform_test() { let (x, expected_x) = build_fake_matrix(); @@ -312,7 +324,10 @@ mod tests { assert_eq!(nm, expected_x); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fail_on_bad_category() { let m = DenseMatrix::from_2d_array(&[ diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index fc0aa9b..2e424e0 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -420,7 +420,10 @@ mod tests { /// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been /// serialized and deserialized. 
- #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde_fit_for_random_values() { diff --git a/src/preprocessing/series_encoder.rs b/src/preprocessing/series_encoder.rs index 6c81134..5d8b720 100644 --- a/src/preprocessing/series_encoder.rs +++ b/src/preprocessing/series_encoder.rs @@ -199,7 +199,10 @@ where mod tests { use super::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn from_categories() { let fake_categories: Vec = vec![1, 2, 3, 4, 5, 3, 5, 3, 1, 2, 4]; @@ -218,14 +221,20 @@ mod tests { let enc = CategoryMapper::<&str>::from_positional_category_vec(fake_category_pos); enc } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn ordinal_encoding() { let enc = build_fake_str_enc(); assert_eq!(1f64, enc.get_ordinal::(&"dog").unwrap()) } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn category_map_and_vec() { let category_map: HashMap<&str, usize> = vec![("background", 0), ("dog", 1), ("cat", 2)] @@ -240,7 +249,10 @@ mod tests { assert_eq!(oh_vec, res); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn positional_categories_vec() { let enc = build_fake_str_enc(); @@ -252,7 +264,10 @@ mod tests { assert_eq!(oh_vec, res); } - #[cfg_attr(target_arch = "wasm32", 
wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn invert_label_test() { let enc = build_fake_str_enc(); @@ -265,7 +280,10 @@ mod tests { }; } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn test_many_categorys() { let enc = build_fake_str_enc(); diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 3bb3c41..48e5907 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -22,6 +22,8 @@ //! //! //! +/// search parameters +pub mod search; pub mod svc; pub mod svr; @@ -52,6 +54,7 @@ impl<'a> Debug for dyn Kernel<'_> + 'a { } } +#[cfg(feature = "serde")] impl<'a> Serialize for dyn Kernel<'_> + 'a { fn serialize(&self, serializer: S) -> Result where @@ -64,7 +67,8 @@ impl<'a> Serialize for dyn Kernel<'_> + 'a { } /// Pre-defined kernel functions -#[derive(Debug, Clone, Serialize, Deserialize)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] pub struct Kernels {} impl<'a> Kernels { @@ -267,7 +271,10 @@ mod tests { use super::*; use crate::svm::Kernels; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn linear_kernel() { let v1 = vec![1., 2., 3.]; @@ -276,7 +283,10 @@ mod tests { assert_eq!(32f64, Kernels::linear().apply(&v1, &v2).unwrap()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn rbf_kernel() { let v1 = vec![1., 2., 3.]; @@ -291,7 +301,10 @@ mod tests { assert!((0.2265f64 - result) < 1e-4); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + 
#[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn polynomial_kernel() { let v1 = vec![1., 2., 3.]; @@ -306,7 +319,10 @@ mod tests { assert!((4913f64 - result) < std::f64::EPSILON); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn sigmoid_kernel() { let v1 = vec![1., 2., 3.]; diff --git a/src/svm/search/mod.rs b/src/svm/search/mod.rs new file mode 100644 index 0000000..6d86feb --- /dev/null +++ b/src/svm/search/mod.rs @@ -0,0 +1,4 @@ +/// SVC search parameters +pub mod svc_params; +/// SVR search parameters +pub mod svr_params; diff --git a/src/svm/search/svc_params.rs b/src/svm/search/svc_params.rs new file mode 100644 index 0000000..42f686b --- /dev/null +++ b/src/svm/search/svc_params.rs @@ -0,0 +1,183 @@ +// /// SVC grid search parameters +// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +// #[derive(Debug, Clone)] +// pub struct SVCSearchParameters< +// TX: Number + RealNumber, +// TY: Number + Ord, +// X: Array2, +// Y: Array1, +// K: Kernel, +// > { +// #[cfg_attr(feature = "serde", serde(default))] +// /// Number of epochs. +// pub epoch: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Regularization parameter. +// pub c: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Tolerance for stopping epoch. +// pub tol: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// The kernel function. +// pub kernel: Vec, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Unused parameter. 
+// m: PhantomData<(X, Y, TY)>, +// #[cfg_attr(feature = "serde", serde(default))] +// /// Controls the pseudo random number generation for shuffling the data for probability estimates +// seed: Vec>, +// } + +// /// SVC grid search iterator +// pub struct SVCSearchParametersIterator< +// TX: Number + RealNumber, +// TY: Number + Ord, +// X: Array2, +// Y: Array1, +// K: Kernel, +// > { +// svc_search_parameters: SVCSearchParameters, +// current_epoch: usize, +// current_c: usize, +// current_tol: usize, +// current_kernel: usize, +// current_seed: usize, +// } + +// impl, Y: Array1, K: Kernel> +// IntoIterator for SVCSearchParameters +// { +// type Item = SVCParameters<'a, TX, TY, X, Y>; +// type IntoIter = SVCSearchParametersIterator; + +// fn into_iter(self) -> Self::IntoIter { +// SVCSearchParametersIterator { +// svc_search_parameters: self, +// current_epoch: 0, +// current_c: 0, +// current_tol: 0, +// current_kernel: 0, +// current_seed: 0, +// } +// } +// } + +// impl, Y: Array1, K: Kernel> +// Iterator for SVCSearchParametersIterator +// { +// type Item = SVCParameters; + +// fn next(&mut self) -> Option { +// if self.current_epoch == self.svc_search_parameters.epoch.len() +// && self.current_c == self.svc_search_parameters.c.len() +// && self.current_tol == self.svc_search_parameters.tol.len() +// && self.current_kernel == self.svc_search_parameters.kernel.len() +// && self.current_seed == self.svc_search_parameters.seed.len() +// { +// return None; +// } + +// let next = SVCParameters { +// epoch: self.svc_search_parameters.epoch[self.current_epoch], +// c: self.svc_search_parameters.c[self.current_c], +// tol: self.svc_search_parameters.tol[self.current_tol], +// kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), +// m: PhantomData, +// seed: self.svc_search_parameters.seed[self.current_seed], +// }; + +// if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { +// self.current_epoch += 1; +// } else if self.current_c 
+ 1 < self.svc_search_parameters.c.len() { +// self.current_epoch = 0; +// self.current_c += 1; +// } else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol += 1; +// } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol = 0; +// self.current_kernel += 1; +// } else if self.current_seed + 1 < self.svc_search_parameters.seed.len() { +// self.current_epoch = 0; +// self.current_c = 0; +// self.current_tol = 0; +// self.current_kernel = 0; +// self.current_seed += 1; +// } else { +// self.current_epoch += 1; +// self.current_c += 1; +// self.current_tol += 1; +// self.current_kernel += 1; +// self.current_seed += 1; +// } + +// Some(next) +// } +// } + +// impl, Y: Array1, K: Kernel> Default +// for SVCSearchParameters +// { +// fn default() -> Self { +// let default_params: SVCParameters = SVCParameters::default(); + +// SVCSearchParameters { +// epoch: vec![default_params.epoch], +// c: vec![default_params.c], +// tol: vec![default_params.tol], +// kernel: vec![default_params.kernel], +// m: PhantomData, +// seed: vec![default_params.seed], +// } +// } +// } + +// #[cfg(test)] +// mod tests { +// use num::ToPrimitive; + +// use super::*; +// use crate::linalg::basic::matrix::DenseMatrix; +// use crate::metrics::accuracy; +// #[cfg(feature = "serde")] +// use crate::svm::*; + +// #[test] +// fn search_parameters() { +// let parameters: SVCSearchParameters, LinearKernel> = +// SVCSearchParameters { +// epoch: vec![10, 100], +// kernel: vec![LinearKernel {}], +// ..Default::default() +// }; +// let mut iter = parameters.into_iter(); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 10); +// assert_eq!(next.kernel, LinearKernel {}); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 100); +// assert_eq!(next.kernel, LinearKernel {}); +// 
assert!(iter.next().is_none()); +// } + +// #[test] +// fn search_parameters() { +// let parameters: SVCSearchParameters, LinearKernel> = +// SVCSearchParameters { +// epoch: vec![10, 100], +// kernel: vec![LinearKernel {}], +// ..Default::default() +// }; +// let mut iter = parameters.into_iter(); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 10); +// assert_eq!(next.kernel, LinearKernel {}); +// let next = iter.next().unwrap(); +// assert_eq!(next.epoch, 100); +// assert_eq!(next.kernel, LinearKernel {}); +// assert!(iter.next().is_none()); +// } +// } diff --git a/src/svm/search/svr_params.rs b/src/svm/search/svr_params.rs new file mode 100644 index 0000000..03d0ece --- /dev/null +++ b/src/svm/search/svr_params.rs @@ -0,0 +1,112 @@ +// /// SVR grid search parameters +// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +// #[derive(Debug, Clone)] +// pub struct SVRSearchParameters, K: Kernel> { +// /// Epsilon in the epsilon-SVR model. +// pub eps: Vec, +// /// Regularization parameter. +// pub c: Vec, +// /// Tolerance for stopping eps. +// pub tol: Vec, +// /// The kernel function. +// pub kernel: Vec, +// /// Unused parameter. 
+// m: PhantomData, +// } + +// /// SVR grid search iterator +// pub struct SVRSearchParametersIterator, K: Kernel> { +// svr_search_parameters: SVRSearchParameters, +// current_eps: usize, +// current_c: usize, +// current_tol: usize, +// current_kernel: usize, +// } + +// impl, K: Kernel> IntoIterator +// for SVRSearchParameters +// { +// type Item = SVRParameters; +// type IntoIter = SVRSearchParametersIterator; + +// fn into_iter(self) -> Self::IntoIter { +// SVRSearchParametersIterator { +// svr_search_parameters: self, +// current_eps: 0, +// current_c: 0, +// current_tol: 0, +// current_kernel: 0, +// } +// } +// } + +// impl, K: Kernel> Iterator +// for SVRSearchParametersIterator +// { +// type Item = SVRParameters; + +// fn next(&mut self) -> Option { +// if self.current_eps == self.svr_search_parameters.eps.len() +// && self.current_c == self.svr_search_parameters.c.len() +// && self.current_tol == self.svr_search_parameters.tol.len() +// && self.current_kernel == self.svr_search_parameters.kernel.len() +// { +// return None; +// } + +// let next = SVRParameters:: { +// eps: self.svr_search_parameters.eps[self.current_eps], +// c: self.svr_search_parameters.c[self.current_c], +// tol: self.svr_search_parameters.tol[self.current_tol], +// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(), +// m: PhantomData, +// }; + +// if self.current_eps + 1 < self.svr_search_parameters.eps.len() { +// self.current_eps += 1; +// } else if self.current_c + 1 < self.svr_search_parameters.c.len() { +// self.current_eps = 0; +// self.current_c += 1; +// } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() { +// self.current_eps = 0; +// self.current_c = 0; +// self.current_tol += 1; +// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() { +// self.current_eps = 0; +// self.current_c = 0; +// self.current_tol = 0; +// self.current_kernel += 1; +// } else { +// self.current_eps += 1; +// self.current_c += 1; +// 
self.current_tol += 1; +// self.current_kernel += 1; +// } + +// Some(next) +// } +// } + +// impl> Default for SVRSearchParameters { +// fn default() -> Self { +// let default_params: SVRParameters = SVRParameters::default(); + +// SVRSearchParameters { +// eps: vec![default_params.eps], +// c: vec![default_params.c], +// tol: vec![default_params.tol], +// kernel: vec![default_params.kernel], +// m: PhantomData, +// } +// } +// } + +// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +// #[derive(Debug)] +// #[cfg_attr( +// feature = "serde", +// serde(bound( +// serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", +// deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", +// )) +// )] diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 256c3c3..716f521 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -100,22 +100,17 @@ pub struct SVCParameters< X: Array2, Y: Array1, > { - #[cfg_attr(feature = "serde", serde(default))] /// Number of epochs. pub epoch: usize, - #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub c: TX, - #[cfg_attr(feature = "serde", serde(default))] /// Tolerance for stopping criterion. pub tol: TX, #[cfg_attr(feature = "serde", serde(skip_deserializing))] /// The kernel function. pub kernel: Option<&'a dyn Kernel<'a>>, - #[cfg_attr(feature = "serde", serde(default))] /// Unused parameter. 
m: PhantomData<(X, Y, TY)>, - #[cfg_attr(feature = "serde", serde(default))] /// Controls the pseudo random number generation for shuffling the data for probability estimates seed: Option, } @@ -133,7 +128,7 @@ pub struct SVCParameters< pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> { classes: Option>, instances: Option>>, - #[serde(skip)] + #[cfg_attr(feature = "serde", serde(skip))] parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>, w: Option>, b: Option, @@ -948,7 +943,10 @@ mod tests { #[cfg(feature = "serde")] use crate::svm::*; - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn svc_fit_predict() { let x = DenseMatrix::from_2d_array(&[ @@ -996,7 +994,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn svc_fit_decision_function() { let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]); @@ -1034,7 +1035,10 @@ mod tests { assert!(num::Float::abs(y_hat[0]) <= 0.1); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn svc_fit_predict_rbf() { let x = DenseMatrix::from_2d_array(&[ @@ -1083,7 +1087,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn svc_serde() { diff --git a/src/svm/svc_gridsearch.rs b/src/svm/svc_gridsearch.rs deleted file mode 100644 index 6f1de6a..0000000 --- a/src/svm/svc_gridsearch.rs +++ /dev/null @@ -1,184 
+0,0 @@ -/// SVC grid search parameters -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] -pub struct SVCSearchParameters< - TX: Number + RealNumber, - TY: Number + Ord, - X: Array2, - Y: Array1, - K: Kernel, -> { - #[cfg_attr(feature = "serde", serde(default))] - /// Number of epochs. - pub epoch: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Regularization parameter. - pub c: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Tolerance for stopping epoch. - pub tol: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// The kernel function. - pub kernel: Vec, - #[cfg_attr(feature = "serde", serde(default))] - /// Unused parameter. - m: PhantomData<(X, Y, TY)>, - #[cfg_attr(feature = "serde", serde(default))] - /// Controls the pseudo random number generation for shuffling the data for probability estimates - seed: Vec>, -} - -/// SVC grid search iterator -pub struct SVCSearchParametersIterator< - TX: Number + RealNumber, - TY: Number + Ord, - X: Array2, - Y: Array1, - K: Kernel, -> { - svc_search_parameters: SVCSearchParameters, - current_epoch: usize, - current_c: usize, - current_tol: usize, - current_kernel: usize, - current_seed: usize, -} - -impl, Y: Array1, K: Kernel> - IntoIterator for SVCSearchParameters -{ - type Item = SVCParameters<'a, TX, TY, X, Y>; - type IntoIter = SVCSearchParametersIterator; - - fn into_iter(self) -> Self::IntoIter { - SVCSearchParametersIterator { - svc_search_parameters: self, - current_epoch: 0, - current_c: 0, - current_tol: 0, - current_kernel: 0, - current_seed: 0, - } - } -} - -impl, Y: Array1, K: Kernel> - Iterator for SVCSearchParametersIterator -{ - type Item = SVCParameters; - - fn next(&mut self) -> Option { - if self.current_epoch == self.svc_search_parameters.epoch.len() - && self.current_c == self.svc_search_parameters.c.len() - && self.current_tol == self.svc_search_parameters.tol.len() - && self.current_kernel == 
self.svc_search_parameters.kernel.len() - && self.current_seed == self.svc_search_parameters.seed.len() - { - return None; - } - - let next = SVCParameters { - epoch: self.svc_search_parameters.epoch[self.current_epoch], - c: self.svc_search_parameters.c[self.current_c], - tol: self.svc_search_parameters.tol[self.current_tol], - kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), - m: PhantomData, - seed: self.svc_search_parameters.seed[self.current_seed], - }; - - if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { - self.current_epoch += 1; - } else if self.current_c + 1 < self.svc_search_parameters.c.len() { - self.current_epoch = 0; - self.current_c += 1; - } else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol += 1; - } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol = 0; - self.current_kernel += 1; - } else if self.current_seed + 1 < self.svc_search_parameters.seed.len() { - self.current_epoch = 0; - self.current_c = 0; - self.current_tol = 0; - self.current_kernel = 0; - self.current_seed += 1; - } else { - self.current_epoch += 1; - self.current_c += 1; - self.current_tol += 1; - self.current_kernel += 1; - self.current_seed += 1; - } - - Some(next) - } -} - -impl, Y: Array1, K: Kernel> Default - for SVCSearchParameters -{ - fn default() -> Self { - let default_params: SVCParameters = SVCParameters::default(); - - SVCSearchParameters { - epoch: vec![default_params.epoch], - c: vec![default_params.c], - tol: vec![default_params.tol], - kernel: vec![default_params.kernel], - m: PhantomData, - seed: vec![default_params.seed], - } - } -} - - -#[cfg(test)] -mod tests { - use num::ToPrimitive; - - use super::*; - use crate::linalg::basic::matrix::DenseMatrix; - use crate::metrics::accuracy; - #[cfg(feature = "serde")] - use crate::svm::*; - - #[test] 
- fn search_parameters() { - let parameters: SVCSearchParameters, LinearKernel> = - SVCSearchParameters { - epoch: vec![10, 100], - kernel: vec![LinearKernel {}], - ..Default::default() - }; - let mut iter = parameters.into_iter(); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 10); - assert_eq!(next.kernel, LinearKernel {}); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 100); - assert_eq!(next.kernel, LinearKernel {}); - assert!(iter.next().is_none()); - } - - #[test] - fn search_parameters() { - let parameters: SVCSearchParameters, LinearKernel> = - SVCSearchParameters { - epoch: vec![10, 100], - kernel: vec![LinearKernel {}], - ..Default::default() - }; - let mut iter = parameters.into_iter(); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 10); - assert_eq!(next.kernel, LinearKernel {}); - let next = iter.next().unwrap(); - assert_eq!(next.epoch, 100); - assert_eq!(next.kernel, LinearKernel {}); - assert!(iter.next().is_none()); - } -} diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 00191b0..cf35bde 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -79,140 +79,30 @@ use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow}; use crate::error::{Failed, FailedError}; use crate::linalg::basic::arrays::{Array1, Array2, MutArray}; use crate::numbers::basenum::Number; -use crate::numbers::realnum::RealNumber; +use crate::numbers::floatnum::FloatNumber; use crate::svm::Kernel; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// SVR Parameters -pub struct SVRParameters<'a, T: Number + RealNumber> { +pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> { /// Epsilon in the epsilon-SVR model. pub eps: T, /// Regularization parameter. pub c: T, /// Tolerance for stopping criterion. pub tol: T, - #[serde(skip_deserializing)] + #[cfg_attr(feature = "serde", serde(skip_deserializing))] /// The kernel function. 
pub kernel: Option<&'a dyn Kernel<'a>>, } -// /// SVR grid search parameters -// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -// #[derive(Debug, Clone)] -// pub struct SVRSearchParameters, K: Kernel> { -// /// Epsilon in the epsilon-SVR model. -// pub eps: Vec, -// /// Regularization parameter. -// pub c: Vec, -// /// Tolerance for stopping eps. -// pub tol: Vec, -// /// The kernel function. -// pub kernel: Vec, -// /// Unused parameter. -// m: PhantomData, -// } - -// /// SVR grid search iterator -// pub struct SVRSearchParametersIterator, K: Kernel> { -// svr_search_parameters: SVRSearchParameters, -// current_eps: usize, -// current_c: usize, -// current_tol: usize, -// current_kernel: usize, -// } - -// impl, K: Kernel> IntoIterator -// for SVRSearchParameters -// { -// type Item = SVRParameters; -// type IntoIter = SVRSearchParametersIterator; - -// fn into_iter(self) -> Self::IntoIter { -// SVRSearchParametersIterator { -// svr_search_parameters: self, -// current_eps: 0, -// current_c: 0, -// current_tol: 0, -// current_kernel: 0, -// } -// } -// } - -// impl, K: Kernel> Iterator -// for SVRSearchParametersIterator -// { -// type Item = SVRParameters; - -// fn next(&mut self) -> Option { -// if self.current_eps == self.svr_search_parameters.eps.len() -// && self.current_c == self.svr_search_parameters.c.len() -// && self.current_tol == self.svr_search_parameters.tol.len() -// && self.current_kernel == self.svr_search_parameters.kernel.len() -// { -// return None; -// } - -// let next = SVRParameters:: { -// eps: self.svr_search_parameters.eps[self.current_eps], -// c: self.svr_search_parameters.c[self.current_c], -// tol: self.svr_search_parameters.tol[self.current_tol], -// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(), -// m: PhantomData, -// }; - -// if self.current_eps + 1 < self.svr_search_parameters.eps.len() { -// self.current_eps += 1; -// } else if self.current_c + 1 < self.svr_search_parameters.c.len() { 
-// self.current_eps = 0; -// self.current_c += 1; -// } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() { -// self.current_eps = 0; -// self.current_c = 0; -// self.current_tol += 1; -// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() { -// self.current_eps = 0; -// self.current_c = 0; -// self.current_tol = 0; -// self.current_kernel += 1; -// } else { -// self.current_eps += 1; -// self.current_c += 1; -// self.current_tol += 1; -// self.current_kernel += 1; -// } - -// Some(next) -// } -// } - -// impl> Default for SVRSearchParameters { -// fn default() -> Self { -// let default_params: SVRParameters = SVRParameters::default(); - -// SVRSearchParameters { -// eps: vec![default_params.eps], -// c: vec![default_params.c], -// tol: vec![default_params.tol], -// kernel: vec![default_params.kernel], -// m: PhantomData, -// } -// } -// } - -// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -// #[derive(Debug)] -// #[cfg_attr( -// feature = "serde", -// serde(bound( -// serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize", -// deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>", -// )) -// )] - +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] /// Epsilon-Support Vector Regression -pub struct SVR<'a, T: Number + RealNumber, X: Array2, Y: Array1> { +pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> { instances: Option>>, + #[cfg_attr(feature = "serde", serde(skip_deserializing))] parameters: Option<&'a SVRParameters<'a, T>>, w: Option>, b: T, @@ -230,7 +120,7 @@ struct SupportVector { } /// Sequential Minimal Optimization algorithm -struct Optimizer<'a, T: Number + RealNumber> { +struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> { tol: T, c: T, parameters: Option<&'a SVRParameters<'a, T>>, @@ -242,13 +132,15 @@ struct Optimizer<'a, T: Number + RealNumber> { gmaxindex: usize, tau: T, sv: 
Vec>, + /// avoid infinite loop if SMO does not converge + max_iterations: usize, } struct Cache { data: Vec>>>, } -impl<'a, T: Number + RealNumber> SVRParameters<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> { /// Epsilon in the epsilon-SVR model. pub fn with_eps(mut self, eps: T) -> Self { self.eps = eps; @@ -271,7 +163,7 @@ impl<'a, T: Number + RealNumber> SVRParameters<'a, T> { } } -impl<'a, T: Number + RealNumber> Default for SVRParameters<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T> { fn default() -> Self { SVRParameters { eps: T::from_f64(0.1).unwrap(), @@ -282,7 +174,7 @@ impl<'a, T: Number + RealNumber> Default for SVRParameters<'a, T> { } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<'a, T>> for SVR<'a, T, X, Y> { fn new() -> Self { @@ -299,7 +191,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PredictorBorrow<'a, X, T> +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> PredictorBorrow<'a, X, T> for SVR<'a, T, X, Y> { fn predict(&self, x: &'a X) -> Result, Failed> { @@ -307,7 +199,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PredictorBorrow<'a, } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> SVR<'a, T, X, Y> { +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> SVR<'a, T, X, Y> { /// Fits SVR to your data. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. 
/// * `y` - target values @@ -388,7 +280,9 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> SVR<'a, T, X, Y> { } } -impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PartialEq for SVR<'a, T, X, Y> { +impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2, Y: Array1> PartialEq + for SVR<'a, T, X, Y> +{ fn eq(&self, other: &Self) -> bool { if (self.b - other.b).abs() > T::epsilon() * T::two() || self.w.as_ref().unwrap().len() != other.w.as_ref().unwrap().len() @@ -414,7 +308,7 @@ impl<'a, T: Number + RealNumber, X: Array2, Y: Array1> PartialEq for SVR<' } } -impl SupportVector { +impl SupportVector { fn new(i: usize, x: Vec, y: T, eps: T, k: f64) -> SupportVector { SupportVector { index: i, @@ -426,7 +320,7 @@ impl SupportVector { } } -impl<'a, T: Number + RealNumber> Optimizer<'a, T> { +impl<'a, T: Number + FloatNumber + PartialOrd> Optimizer<'a, T> { fn new, Y: Array1>( x: &'a X, y: &'a Y, @@ -468,12 +362,13 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { gmaxindex: 0, tau: T::from_f64(1e-12).unwrap(), sv: support_vectors, + max_iterations: 49999, } } fn find_min_max_gradient(&mut self) { - // self.gmin = ::max_value()(); - // self.gmax = ::min_value(); + self.gmin = ::max_value(); + self.gmax = ::min_value(); for i in 0..self.sv.len() { let v = &self.sv[i]; @@ -511,10 +406,13 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { /// * hyperplane parameters: w and b (computed with T) fn smo(mut self) -> (Vec>, Vec, T) { let cache: Cache = Cache::new(self.sv.len()); - + let mut n_iteration = 0usize; self.find_min_max_gradient(); while self.gmax - self.gmin > self.tol { + if n_iteration > self.max_iterations { + break; + } let v1 = self.svmax; let i = self.gmaxindex; let old_alpha_i = self.sv[v1].alpha[i]; @@ -659,6 +557,7 @@ impl<'a, T: Number + RealNumber> Optimizer<'a, T> { } self.find_min_max_gradient(); + n_iteration += 1; } let b = -(self.gmax + self.gmin) / T::two(); @@ -694,11 +593,11 @@ impl Cache { #[cfg(test)] mod tests { - // use 
super::*; - // use crate::linalg::basic::matrix::DenseMatrix; - // use crate::metrics::mean_squared_error; - // #[cfg(feature = "serde")] - // use crate::svm::*; + use super::*; + use crate::linalg::basic::matrix::DenseMatrix; + use crate::metrics::mean_squared_error; + #[cfg(feature = "serde")] + use crate::svm::Kernels; // #[test] // fn search_parameters() { @@ -718,79 +617,97 @@ mod tests { // assert!(iter.next().is_none()); // } - // TODO: had to disable this test as it runs for too long - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // fn svr_fit_predict() { - // let x = DenseMatrix::from_2d_array(&[ - // &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], - // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], - // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], - // &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], - // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], - // &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], - // &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], - // &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], - // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], - // &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], - // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], - // &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], - // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], - // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], - // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], - // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], - // ]); + //TODO: had to disable this test as it runs for too long + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn svr_fit_predict() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165.0, 110.929, 1950., 
61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], + &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); - // let y: Vec = vec![ - // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, - // 114.2, 115.7, 116.9, - // ]; + let y: Vec = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; - // let knl = Kernels::linear(); - // let y_hat = SVR::fit(&x, &y, &SVRParameters::default() - // .with_eps(2.0) - // .with_c(10.0) - // .with_kernel(&knl) - // ) - // .and_then(|lr| lr.predict(&x)) - // .unwrap(); + let knl = Kernels::linear(); + let y_hat = SVR::fit( + &x, + &y, + &SVRParameters::default() + .with_eps(2.0) + .with_c(10.0) + .with_kernel(&knl), + ) + .and_then(|lr| lr.predict(&x)) + .unwrap(); - // assert!(mean_squared_error(&y_hat, &y) < 2.5); - // } + let t = mean_squared_error(&y_hat, &y); + println!("{:?}", t); + assert!(t < 2.5); + } - // #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] - // #[test] - // #[cfg(feature = "serde")] - // fn svr_serde() { - // let x = DenseMatrix::from_2d_array(&[ - // &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], - // &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], - // &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], - // &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], - // &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], - // &[346.999, 193.2, 359.4, 113.270, 1952., 
63.639], - // &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], - // &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], - // &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], - // &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], - // &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], - // &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], - // &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], - // &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], - // &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], - // &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], - // ]); + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + #[cfg(feature = "serde")] + fn svr_serde() { + let x = DenseMatrix::from_2d_array(&[ + &[234.289, 235.6, 159.0, 107.608, 1947., 60.323], + &[259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[284.599, 335.1, 165.0, 110.929, 1950., 61.187], + &[328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[346.999, 193.2, 359.4, 113.270, 1952., 63.639], + &[365.385, 187.0, 354.7, 115.094, 1953., 64.989], + &[363.112, 357.8, 335.0, 116.219, 1954., 63.761], + &[397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[419.180, 282.2, 285.7, 118.734, 1956., 67.857], + &[442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[444.546, 468.1, 263.7, 121.950, 1958., 66.513], + &[482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[554.894, 400.7, 282.7, 130.081, 1962., 70.551], + ]); - // let y: Vec = vec![ - // 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, - // 114.2, 115.7, 116.9, - // ]; + let y: Vec = vec![ + 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, + 114.2, 115.7, 116.9, + ]; - // let svr = SVR::fit(&x, &y, Default::default()).unwrap(); + let knl = 
Kernels::rbf().with_gamma(0.7); + let params = SVRParameters::default().with_kernel(&knl); - // let deserialized_svr: SVR, LinearKernel> = - // serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + let svr = SVR::fit(&x, &y, ¶ms).unwrap(); - // assert_eq!(svr, deserialized_svr); - // } + let serialized = &serde_json::to_string(&svr).unwrap(); + + println!("{}", &serialized); + + // let deserialized_svr: SVR, LinearKernel> = + // serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap(); + + // assert_eq!(svr, deserialized_svr); + } } diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index e5d366c..043d79b 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -899,7 +899,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn gini_impurity() { assert!( @@ -915,7 +918,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_iris() { let x: DenseMatrix = DenseMatrix::from_2d_array(&[ @@ -968,7 +974,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_predict_baloons() { let x: DenseMatrix = DenseMatrix::from_2d_array(&[ @@ -1003,7 +1012,10 @@ mod tests { ); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() { diff --git 
a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index a2397d1..397040b 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -731,7 +731,10 @@ mod tests { assert!(iter.next().is_none()); } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] fn fit_longley() { let x = DenseMatrix::from_2d_array(&[ @@ -808,7 +811,10 @@ mod tests { } } - #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] #[test] #[cfg(feature = "serde")] fn serde() {