Merge branch 'development' into prdct-prb
@@ -26,6 +26,17 @@ Take a look to the conventions established by existing code:
 * Every module should provide comprehensive tests at the end, in its `mod tests {}` sub-module. These tests can be flagged or not with configuration flags to allow the WebAssembly target.
 * Run `cargo doc --no-deps --open` and read the generated documentation in the browser to be sure that your changes reflect in the documentation and that new code is documented.
 
+#### digging deeper
+
+* a nice overview of the codebase is given by the [static analyzer](https://mozilla.github.io/rust-code-analysis/metrics.html):
+
+```
+$ cargo install rust-code-analysis-cli
+// print metrics for every module
+$ rust-code-analysis-cli -m -O json -o . -p src/ --pr
+// print full AST for a module
+$ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213 -d > ast.txt
+```
+
+* find more information about what happens in your binary with [`twiggy`](https://rustwasm.github.io/twiggy/install.html). This needs a compiled binary, so create a brief `main {}` function using `smartcore` and then point `twiggy` to that file.
 
 ## Issue Report Process
 
 1. Go to the project's issues.
@@ -0,0 +1,43 @@
+# smartcore: Introduction to modules
+
+Important source of information:
+
+* [Rust API guidelines](https://rust-lang.github.io/api-guidelines/about.html)
+
+## Walkthrough: traits system and basic structures
+
+#### numbers
+
+The library is founded on basic traits provided by `num-traits`. Basic traits are in `src/numbers`. These traits are used to define all the procedures in the library, to make everything safer and to provide constraints on what implementations can handle.
+
+#### linalg
+
+`numbers` are put to use in the linear algebra structures of the **`src/linalg/basic`** module. These sub-modules define the traits used all over the code base.
+
+* *arrays*: data structures like `Array`, `Array1` (1-dimensional) and `Array2` (matrix, 2-D), plus their "views" traits. Views are used to provide no-footprint access to data; they have composed traits to allow writing (mutable traits: `MutArray`, `ArrayViewMut`, ...).
+* *matrix*: the main entry point for matrix operations and currently the only provided structure, in the shape of `struct DenseMatrix`. A matrix can be instantiated and automatically makes available all the traits in *arrays* (a sparse matrices implementation will be provided).
+* *vector*: convenience traits are implemented for `std::Vec` to allow extensive reuse.
+
+These are all traits and by definition they do not allow instantiation. For instantiable structures, see implementations like `DenseMatrix` with its constructor.
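The trait-centric design described here can be sketched with a minimal, self-contained example. The trait and method names below are simplified stand-ins, not smartcore's actual definitions: behavior is declared on a trait, and `std::Vec` is made usable as a 1-D array by implementing the trait for it.

```rust
// Simplified stand-in for smartcore's trait-based arrays (not the real API):
// algorithms are written against the trait, not against a concrete type.
trait ArrayView1<T> {
    fn get(&self, i: usize) -> &T;
    fn shape(&self) -> usize;

    // A default method: any implementor gets this behavior for free.
    fn sum(&self) -> T
    where
        T: Copy + std::iter::Sum<T>,
    {
        (0..self.shape()).map(|i| *self.get(i)).sum()
    }
}

// Convenience implementation for `Vec`, mirroring the *vector* sub-module idea.
impl<T> ArrayView1<T> for Vec<T> {
    fn get(&self, i: usize) -> &T {
        &self[i]
    }
    fn shape(&self) -> usize {
        self.len()
    }
}

fn main() {
    let v = vec![1.0_f64, 2.0, 3.0];
    // The same `Vec` now exposes the array behavior through the trait.
    println!("sum = {}", v.sum());
}
```

Any other backing store (a dense matrix row, a view into a slice) could implement the same trait and be passed to the same generic code.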
+#### linalg/traits
+
+The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. They are used to specify characteristics of, and constraints on, the types accepted by the various algorithms; for example, they make it possible to state that a matrix is `QRDecomposable` and/or `SVDDecomposable`. See the docstrings for references to the theoretical framework.
+
+As above, these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation of Linear Regression requires the input data `X` to be, in `smartcore`'s trait system, `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`: a 2-D matrix that is both QR- and SVD-decomposable. That is exactly what the provided structure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {}; impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
+
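The constraint mechanism can be illustrated with a small self-contained sketch; the trait and type names here are simplified stand-ins for smartcore's actual ones, with the decomposability traits reduced to markers.

```rust
// Capability traits: a type advertises what algorithms may assume about it.
trait Array2 {
    fn nrows(&self) -> usize;
}
trait QRDecomposable: Array2 {}
trait SVDDecomposable: Array2 {}

// A concrete matrix type opts in by implementing the capability traits.
struct DenseMatrix {
    rows: usize,
}
impl Array2 for DenseMatrix {
    fn nrows(&self) -> usize {
        self.rows
    }
}
impl QRDecomposable for DenseMatrix {}
impl SVDDecomposable for DenseMatrix {}

// An algorithm accepts any input satisfying the combined bounds, mirroring
// the `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>` idea.
fn fit<X: Array2 + QRDecomposable + SVDDecomposable>(x: &X) -> usize {
    x.nrows()
}

fn main() {
    let m = DenseMatrix { rows: 3 };
    println!("{}", fit(&m)); // 3
}
```

A type lacking one of the markers simply fails to compile when passed to `fit`, which is how the constraints are enforced at the type level rather than at runtime.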
+#### metrics
+
+Implementations of metrics (classification, regression, cluster, ...) and distance measures (Euclidean, Hamming, Manhattan, ...). For example: `Accuracy`, `F1`, `AUC`, `Precision`, `R2`. As with everything else in the code base, these implementations reuse `numbers` and `linalg` traits and structures.
+
+They are collected in structures like `pub struct ClassificationMetrics<T> {}` that implement `metrics::Metrics`: groups of functions (classification, regression, cluster, ...) that provide instantiation for the structures. Each of those instantiations can be passed around via the relative function, like `pub fn accuracy<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> T`. This provides a mechanism for metrics to be passed to higher-level interfaces like `cross_validate`:
+
+```rust
+let results = cross_validate(
+    BiasedEstimator::new(), // custom estimator
+    &x, &y,                 // input data
+    NoParameters {},        // extra parameters
+    cv,                     // type of cross validator
+    &accuracy,              // **metrics function** <--------
+).unwrap();
+```
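The idea of a metric being a plain generic function that a higher-level driver receives as an argument can be sketched without smartcore's full type system; `accuracy` and `evaluate` below are simplified stand-ins for the real `accuracy` and `cross_validate`.

```rust
// A metric is just a generic function over array-like inputs,
// so callers can pass it around like any other value.
fn accuracy<T: PartialEq>(y_true: &[T], y_pred: &[T]) -> f64 {
    let hits = y_true.iter().zip(y_pred).filter(|(a, b)| a == b).count();
    hits as f64 / y_true.len() as f64
}

// A higher-level driver (stand-in for `cross_validate`) receives the metric
// as a function pointer instead of hard-coding one.
fn evaluate<T: PartialEq>(metric: fn(&[T], &[T]) -> f64, y_true: &[T], y_pred: &[T]) -> f64 {
    metric(y_true, y_pred)
}

fn main() {
    let y_true = [0, 1, 1, 0];
    let y_pred = [0, 1, 0, 0];
    println!("{}", evaluate(accuracy, &y_true, &y_pred)); // 0.75
}
```

Swapping in a different metric is just a matter of passing a different function, which is the flexibility `cross_validate` relies on.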
+
+TODO: complete for all modules
+
+## Notebooks
+
+Proceed to the [**notebooks**](https://github.com/smartcorelib/smartcore-jupyter/) to see these modules in action.
@@ -1,5 +1,6 @@
 ### I'm submitting a
 - [ ] bug report.
+- [ ] improvement.
 - [ ] feature request.
 
 ### Current Behaviour:
@@ -46,11 +46,16 @@ jobs:
       - name: Install test runner for wasi
         if: matrix.platform.target == 'wasm32-wasi'
         run: curl https://wasmtime.dev/install.sh -sSf | bash
-      - name: Stable Build
+      - name: Stable Build with all features
        uses: actions-rs/cargo@v1
        with:
          command: build
          args: --all-features --target ${{ matrix.platform.target }}
+      - name: Stable Build without features
+        uses: actions-rs/cargo@v1
+        with:
+          command: build
+          args: --target ${{ matrix.platform.target }}
       - name: Tests
         if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
         uses: actions-rs/cargo@v1
@@ -27,3 +27,5 @@ out.svg
 FlameGraph/
 out.stacks
+*.json
+*.txt
@@ -4,22 +4,29 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [Unreleased]
+## [0.3.0] - 2022-11-09
 
 ## Added
-- Seeds to multiple algorithims that depend on random number generation.
-- Added feature `js` to use WASM in browser
-- Drop `nalgebra-bindings` feature
-- Complete refactoring with *extensive API changes* that includes:
+- WARNING: Breaking changes!
+- Complete refactoring with **extensive API changes** that includes:
   * moving to a new traits system, less structs more traits
   * adapting all the modules to the new traits system
-  * moving towards Rust 2021, in particular the use of `dyn` and `as_ref`
+  * moving to Rust 2021, use of object-safe traits and `as_ref`
-  * reorganization of the code base, trying to eliminate duplicates
+  * reorganization of the code base, eliminate duplicates
+- implements `readers` (needs "serde" feature) for reading/writing CSV files, extendible to other formats
+- default feature is now Wasm-/Wasi-first
 
-## BREAKING CHANGE
+## Changed
-- Added a new parameter to `train_test_split` to define the seed.
+- WARNING: Breaking changes!
+- Seeds to multiple algorithms that depend on random number generation
+- Added a new parameter to `train_test_split` to define the seed
+- changed use of "serde" feature
 
-## [0.2.1] - 2022-05-10
+## Dropped
+- WARNING: Breaking changes!
+- Drop `nalgebra-bindings` feature; only `ndarray` as supported library
+
+## [0.2.1] - 2021-05-10
 
 ## Added
 - L2 regularization penalty to the Logistic Regression
@@ -1,9 +1,9 @@
 [package]
 name = "smartcore"
-description = "The most advanced machine learning library in rust."
+description = "Machine Learning in Rust."
 homepage = "https://smartcorelib.org"
-version = "0.4.0"
+version = "0.3.0"
-authors = ["SmartCore Developers"]
+authors = ["smartcore Developers"]
 edition = "2021"
 license = "Apache-2.0"
 documentation = "https://docs.rs/smartcore"
@@ -11,6 +11,13 @@ repository = "https://github.com/smartcorelib/smartcore"
 readme = "README.md"
 keywords = ["machine-learning", "statistical", "ai", "optimization", "linear-algebra"]
 categories = ["science"]
+exclude = [
+    ".github",
+    ".gitignore",
+    "smartcore.iml",
+    "smartcore.svg",
+    "tests/"
+]
 
 [dependencies]
 approx = "0.5.1"
@@ -22,38 +29,37 @@ rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
 rand_distr = { version = "0.4", optional = true }
 serde = { version = "1", features = ["derive"], optional = true }
 
+[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
+typetag = { version = "0.2", optional = true }
+
 [features]
-default = ["serde", "datasets"]
+default = []
-serde = ["dep:serde"]
+serde = ["dep:serde", "dep:typetag"]
 ndarray-bindings = ["dep:ndarray"]
-datasets = ["dep:rand_distr", "std"]
+datasets = ["dep:rand_distr", "std_rand", "serde"]
-std = ["rand/std_rng", "rand/std"]
+std_rand = ["rand/std_rng", "rand/std"]
-# wasm32 only
+# used by wasm32-unknown-unknown for in-browser usage
 js = ["getrandom/js"]
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
-getrandom = { version = "0.2", optional = true }
+getrandom = { version = "0.2.8", optional = true }
 
-[dev-dependencies]
-itertools = "*"
-criterion = { version = "0.4", default-features = false }
-serde_json = "1.0"
-bincode = "1.3.1"
 
 [target.'cfg(all(target_arch = "wasm32", not(target_os = "wasi")))'.dev-dependencies]
 wasm-bindgen-test = "0.3"
 
+[dev-dependencies]
+itertools = "0.10.5"
+serde_json = "1.0"
+bincode = "1.3.1"
 
 [workspace]
-resolver = "2"
 
 [profile.test]
 debug = 1
 opt-level = 3
-split-debuginfo = "unpacked"
 
 [profile.release]
 strip = true
-debug = 1
 lto = true
 codegen-units = 1
 overflow-checks = true
@@ -186,7 +186,7 @@
    same "printed page" as the copyright notice for easier
    identification within third-party archives.
 
-Copyright 2019-present at SmartCore developers (smartcorelib.org)
+Copyright 2019-present at smartcore developers (smartcorelib.org)
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -1,60 +1,21 @@
 <p align="center">
   <a href="https://smartcorelib.org">
-    <img src="smartcore.svg" width="450" alt="SmartCore">
+    <img src="smartcore.svg" width="450" alt="smartcore">
   </a>
 </p>
 <p align = "center">
   <strong>
-    <a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-examples">Examples</a>
+    <a href="https://smartcorelib.org">User guide</a> | <a href="https://docs.rs/smartcore/">API</a> | <a href="https://github.com/smartcorelib/smartcore-jupyter">Notebooks</a>
   </strong>
 </p>
 
 -----
 
 <p align = "center">
-  <b>The Most Advanced Machine Learning Library In Rust.</b>
+  <b>Machine Learning in Rust</b>
 </p>
 
 -----
+
+[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
 
-To start getting familiar with the new Smartcore v0.5 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, your feedback is valuable for the future of the library.
+To start getting familiar with the new smartcore v0.5 API, there is now a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter) available. Please see the instructions there; contributions are welcome, see [CONTRIBUTING](.github/CONTRIBUTING.md).
-
-## Developers
-
-Contributions welcome, please start from [CONTRIBUTING and other relevant files](.github/CONTRIBUTING.md).
-
-### Walkthrough: traits system and basic structures
-
-#### numbers
-The library is founded on basic traits provided by `num-traits`. Basic traits are in `src/numbers`. These traits are used to define all the procedures in the library to make everything safer and provide constraints to what implementations can handle.
-
-#### linalg
-`numbers` are made at use in linear algebra structures in the **`src/linalg/basic`** module. These sub-modules define the traits used all over the code base.
-
-* *arrays*: In particular data structures like `Array`, `Array1` (1-dimensional), `Array2` (matrix, 2-D); plus their "views" traits. Views are used to provide no-footprint access to data, they have composed traits to allow writing (mutable traits: `MutArray`, `ArrayViewMut`, ...).
-* *matrix*: This provides the main entrypoint to matrices operations and currently the only structure provided in the shape of `struct DenseMatrix`. A matrix can be instantiated and automatically make available all the traits in "arrays" (sparse matrices implementation will be provided).
-* *vector*: Convenience traits are implemented for `std::Vec` to allow extensive reuse.
-
-These are all traits and by definition they do not allow instantiation. For instantiable structures see implementation like `DenseMatrix` with relative constructor.
-
-#### linalg/traits
-The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. These traits are used to specify characteristics and constraints for types accepted by various algorithms. For example these allow to define if a matrix is `QRDecomposable` and/or `SVDDecomposable`. See docstring for references to theoretical framework.
-
-As above these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation for Linear Regression requires the input data `X` to be in `smartcore`'s trait system `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`, a 2-D matrix that is both QR and SVD decomposable; that is what the provided structure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {};impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
-
-#### metrics
-Implementations for metrics (classification, regression, cluster, ...) and distance measure (Euclidean, Hamming, Manhattan, ...). For example: `Accuracy`, `F1`, `AUC`, `Precision`, `R2`. As everything else in the code base, these implementations reuse `numbers` and `linalg` traits and structures.
-
-These are collected in structures like `pub struct ClassificationMetrics<T> {}` that implements `metrics::Metrics`, these are groups of functions (classification, regression, cluster, ...) that provide instantiation for the structures. Each of those instantiation can be passed around using the relative function, like `pub fn accuracy<T: Number + RealNumber + FloatNumber, V: ArrayView1<T>>(y_true: &V, y_pred: &V) -> T`. This provides a mechanism for metrics to be passed to higher interfaces like the `cross_validate`:
-
-```rust
-let results =
-    cross_validate(
-        BiasedEstimator::fit, // custom estimator
-        &x, &y, // input data
-        NoParameters {}, // extra parameters
-        cv, // type of cross validator
-        &accuracy // **metrics function** <--------
-    ).unwrap();
-```
-
-TODO: complete for all modules
@@ -76,5 +76,5 @@
          y="81.876823"
          x="91.861809"
          id="tspan842"
-         sodipodi:role="line">SmartCore</tspan></text>
+         sodipodi:role="line">smartcore</tspan></text>
 </svg>

 Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
@@ -64,7 +64,7 @@ struct Node {
     max_dist: f64,
     parent_dist: f64,
     children: Vec<Node>,
-    scale: i64,
+    _scale: i64,
 }
 
 #[derive(Debug)]
@@ -84,7 +84,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
             max_dist: 0f64,
             parent_dist: 0f64,
             children: Vec::new(),
-            scale: 0,
+            _scale: 0,
         };
         let mut tree = CoverTree {
             base,
@@ -245,7 +245,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
                 max_dist: 0f64,
                 parent_dist: 0f64,
                 children: Vec::new(),
-                scale: 100,
+                _scale: 100,
             }
         }
 
@@ -306,7 +306,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
                 max_dist: 0f64,
                 parent_dist: 0f64,
                 children,
-                scale: 100,
+                _scale: 100,
             }
         } else {
             let mut far: Vec<DistanceSet> = Vec::new();
@@ -375,7 +375,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
                 max_dist: self.max(consumed_set),
                 parent_dist: 0f64,
                 children,
-                scale: (top_scale - max_scale),
+                _scale: (top_scale - max_scale),
             }
         }
     }
@@ -11,7 +11,7 @@
 //! these re-calculated centroids becoming the new centers of their respective clusters. Next all instances of the training set are re-assigned to their closest cluster again.
 //! This iterative process continues until convergence is achieved and the clusters are considered settled.
 //!
-//! Initial choice of K data points is very important and has big effect on performance of the algorithm. SmartCore uses k-means++ algorithm to initialize cluster centers.
+//! Initial choice of K data points is very important and has a big effect on performance of the algorithm. `smartcore` uses the k-means++ algorithm to initialize cluster centers.
 //!
 //! Example:
 //!
@@ -74,7 +74,7 @@ pub struct KMeans<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
     k: usize,
     _y: Vec<usize>,
     size: Vec<usize>,
-    distortion: f64,
+    _distortion: f64,
     centroids: Vec<Vec<f64>>,
     _phantom_tx: PhantomData<TX>,
     _phantom_ty: PhantomData<TY>,
@@ -313,7 +313,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
             k: parameters.k,
             _y: y,
             size,
-            distortion,
+            _distortion: distortion,
             centroids,
             _phantom_tx: PhantomData,
             _phantom_ty: PhantomData,
@@ -470,7 +470,7 @@ mod tests {
         wasm_bindgen_test::wasm_bindgen_test
     )]
     #[test]
-    fn fit_predict_iris() {
+    fn fit_predict() {
         let x = DenseMatrix::from_2d_array(&[
             &[5.1, 3.5, 1.4, 0.2],
             &[4.9, 3.0, 1.4, 0.2],
@@ -1,6 +1,6 @@
 //! Datasets
 //!
-//! In this module you will find small datasets that are used in SmartCore mostly for demonstration purposes.
+//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
 pub mod boston;
 pub mod breast_cancer;
 pub mod diabetes;
@@ -7,7 +7,7 @@
 //! set and then aggregate their individual predictions to form a final prediction. In classification setting the overall prediction is the most commonly
 //! occurring majority class among the individual predictions.
 //!
-//! In SmartCore you will find implementation of RandomForest - a popular averaging algorithms based on randomized [decision trees](../tree/index.html).
+//! In `smartcore` you will find an implementation of RandomForest - a popular averaging algorithm based on randomized [decision trees](../tree/index.html).
 //! Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of
 //! decision trees on bootstrapped training samples. But when building these decision trees, each time a split in a tree is considered,
 //! a random sample of _m_ predictors is chosen as split candidates from the full set of _p_ predictors.
@@ -104,7 +104,6 @@ pub struct RandomForestClassifier<
     X: Array2<TX>,
     Y: Array1<TY>,
 > {
-    parameters: Option<RandomForestClassifierParameters>,
     trees: Option<Vec<DecisionTreeClassifier<TX, TY, X, Y>>>,
     classes: Option<Vec<TY>>,
     samples: Option<Vec<Vec<bool>>>,
@@ -198,7 +197,6 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y:
 {
     fn new() -> Self {
         Self {
-            parameters: Option::None,
             trees: Option::None,
             classes: Option::None,
             samples: Option::None,
@@ -501,7 +499,6 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
         }
 
         Ok(RandomForestClassifier {
-            parameters: Some(parameters),
             trees: Some(trees),
             classes: Some(classes),
             samples: maybe_all_samples,
@@ -669,7 +666,7 @@ mod tests {
         wasm_bindgen_test::wasm_bindgen_test
     )]
     #[test]
-    fn fit_predict_iris() {
+    fn fit_predict() {
         let x = DenseMatrix::from_2d_array(&[
             &[5.1, 3.5, 1.4, 0.2],
             &[4.9, 3.0, 1.4, 0.2],
@@ -98,7 +98,6 @@ pub struct RandomForestRegressor<
     X: Array2<TX>,
     Y: Array1<TY>,
 > {
-    parameters: Option<RandomForestRegressorParameters>,
     trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
     samples: Option<Vec<Vec<bool>>>,
 }
@@ -177,7 +176,6 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
 {
     fn new() -> Self {
         Self {
-            parameters: Option::None,
             trees: Option::None,
             samples: Option::None,
         }
@@ -434,7 +432,6 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
         }
 
         Ok(RandomForestRegressor {
-            parameters: Some(parameters),
             trees: Some(trees),
             samples: maybe_all_samples,
         })
+30
-11
@@ -8,25 +8,38 @@
|
|||||||
#![warn(missing_docs)]
|
#![warn(missing_docs)]
|
||||||
#![warn(rustdoc::missing_doc_code_examples)]
|
#![warn(rustdoc::missing_doc_code_examples)]
|
||||||
|
|
||||||
//! # SmartCore
|
//! # smartcore
|
||||||
//!
|
//!
|
||||||
//! Welcome to SmartCore, machine learning in Rust!
|
//! Welcome to `smartcore`, machine learning in Rust!
//!
//! `smartcore` features various classification, regression and clustering algorithms including support vector machines, random forests, k-means and DBSCAN,
//! as well as tools for model selection and model evaluation.
//!
//! `smartcore` provides its own traits system that extends the Rust standard library to deal with linear algebra and common
//! computational models. Its API is designed using well-recognizable patterns. Extra features (like support for [ndarray](https://docs.rs/ndarray)
//! structures) are available via optional features.
//!
//! ## Getting Started
//!
//! To start using the latest stable version of `smartcore`, simply add the following to your `Cargo.toml` file:
//! ```ignore
//! [dependencies]
//! smartcore = "*"
//! ```
//!
//! To use the `smartcore` development version with the latest unstable additions:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", branch = "development" }
//! ```
//!
//! There are different features that can be added to the base library, for example to add sample datasets:
//! ```ignore
//! [dependencies]
//! smartcore = { git = "https://github.com/smartcorelib/smartcore", features = ["datasets"] }
//! ```
//! Check `smartcore`'s `Cargo.toml` for available features.
//!
//! ## Using Jupyter
//! For a quick introduction, Jupyter Notebooks are available [here](https://github.com/smartcorelib/smartcore-jupyter/tree/main/notebooks).
//! You can set up a local environment to run Rust notebooks using [EVCXR](https://github.com/google/evcxr)
@@ -37,7 +50,7 @@
//! For example, you can use this code to fit a [K Nearest Neighbors classifier](neighbors/knn_classifier/index.html) to a dataset that is defined as a standard Rust vector:
//!
//! ```
//! // DenseMatrix definition
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! // KNNClassifier
//! use smartcore::neighbors::knn_classifier::*;
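The body of that doc example is elided between these hunks. As a rough, dependency-free stand-in for what a nearest-neighbor fit-and-predict does, here is a 1-NN sketch over plain vectors (all names are illustrative, not the `smartcore` API):

```rust
// Minimal 1-nearest-neighbor classifier over plain Rust vectors.
// Illustrative sketch only; smartcore wraps this idea in KNNClassifier.
fn predict_1nn(train_x: &[Vec<f64>], train_y: &[i32], query: &[f64]) -> i32 {
    let mut best = (f64::INFINITY, 0);
    for (xi, &yi) in train_x.iter().zip(train_y) {
        // squared Euclidean distance to the query point
        let d: f64 = xi.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
        if d < best.0 {
            best = (d, yi);
        }
    }
    best.1
}

fn main() {
    let x = vec![vec![1.0, 2.0], vec![1.1, 2.1], vec![8.0, 9.0], vec![8.2, 9.1]];
    let y = vec![0, 0, 1, 1];
    // the query sits next to the class-1 cluster
    println!("{}", predict_1nn(&x, &y, &[8.1, 9.0]));
}
```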
@@ -62,7 +75,9 @@
//! ```
//!
//! ## Overview
//!
//! ### Supported algorithms
//! All machine learning algorithms are grouped into these broad categories:
//! * [Clustering](cluster/index.html), unsupervised clustering of unlabeled data.
//! * [Matrix Decomposition](decomposition/index.html), various methods for matrix decomposition.
//! * [Linear Models](linear/index.html), regression and classification methods where the output is assumed to have a linear relation to explanatory variables
@@ -71,11 +86,14 @@
//! * [Nearest Neighbors](neighbors/index.html), K Nearest Neighbors for classification and regression
//! * [Naive Bayes](naive_bayes/index.html), statistical classification technique based on Bayes' Theorem
//! * [SVM](svm/index.html), support vector machines
//!
//! ### Linear Algebra traits system
//! For an introduction to `smartcore`'s traits system see [this notebook](https://github.com/smartcorelib/smartcore-jupyter/blob/5523993c53c6ec1fd72eea130ef4e7883121c1ea/notebooks/01-A-little-bit-about-numbers.ipynb)

/// Fundamental numbers traits
pub mod numbers;

/// Various algorithms and helper methods that are used elsewhere in `smartcore`
pub mod algorithm;
pub mod api;
@@ -89,7 +107,7 @@ pub mod decomposition;
/// Ensemble methods, including Random Forest classifier and regressor
pub mod ensemble;
pub mod error;
/// Diverse collection of linear algebra abstractions and methods that power `smartcore` algorithms
pub mod linalg;
/// Supervised classification and regression models that assume a linear relationship between dependent and explanatory variables.
pub mod linear;
@@ -105,8 +123,9 @@ pub mod neighbors;
pub mod optimization;
/// Preprocessing utilities
pub mod preprocessing;
/// Reading in data from serialized formats
#[cfg(feature = "serde")]
pub mod readers;
/// Support Vector Machines
pub mod svm;
/// Supervised tree-based learning methods
@@ -12,7 +12,7 @@
//! \\[\hat{\beta} = (X^TX)^{-1}X^Ty \\]
//!
//! the \\((X^TX)^{-1}\\) term is both computationally expensive and numerically unstable. An alternative approach is to use a matrix decomposition to avoid this operation.
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [QR](../../linalg/qr/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The QR decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the QR decomposition, all matrices have an SVD decomposition.
//!
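To make the QR route explicit: factoring \\(X = QR\\) with orthonormal \\(Q\\) and upper-triangular \\(R\\) turns the normal equation into a triangular solve, so \\((X^TX)^{-1}\\) is never formed:

```latex
X = QR,\qquad
X^TX\hat{\beta} = X^Ty
\;\Longrightarrow\;
R^TR\hat{\beta} = R^TQ^Ty
\;\Longrightarrow\;
R\hat{\beta} = Q^Ty
```

Cancelling \\(R^T\\) uses \\(Q^TQ = I\\) and the invertibility of \\(R\\) for full-rank \\(X\\), which is exactly why QR fails for rank-deficient data matrices while SVD does not.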
@@ -113,7 +113,6 @@ pub struct LinearRegression<
> {
    coefficients: Option<X>,
    intercept: Option<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}
@@ -210,7 +209,6 @@ impl<
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        }
@@ -276,7 +274,6 @@ impl<
        Ok(LinearRegression {
            intercept: Some(*w.get((num_attributes, 0))),
            coefficients: Some(weights),
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        })
@@ -5,7 +5,7 @@
//!
//! \\[ Pr(y=1) \approx \frac{e^{\beta_0 + \sum_{i=1}^n \beta_iX_i}}{1 + e^{\beta_0 + \sum_{i=1}^n \beta_iX_i}} \\]
//!
//! `smartcore` uses the [limited-memory BFGS](https://en.wikipedia.org/wiki/Limited-memory_BFGS) method to find estimates of the regression coefficients, \\(\beta\\)
//!
//! Example:
//!
@@ -518,12 +518,9 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
            for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
                result.set(
                    i,
                    self.classes()[usize::from(
                        RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half(),
                    )],
                );
            }
        } else {
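The refactor above replaces an `if … { 1 } else { 0 }` class index with `usize::from(bool)`. A dependency-free sketch of the same idiom (`predict_class` and the fixed `0.5` threshold are illustrative stand-ins for the real method's fields):

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn predict_class(classes: &[i32], score: f64) -> i32 {
    // usize::from(true) == 1 and usize::from(false) == 0, so this picks
    // classes[1] exactly when the sigmoid clears the 0.5 threshold.
    classes[usize::from(sigmoid(score) > 0.5)]
}

fn main() {
    let classes = [0, 1];
    println!("{}", predict_class(&classes, 2.0));  // positive score -> class 1
    println!("{}", predict_class(&classes, -2.0)); // negative score -> class 0
}
```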
@@ -12,7 +12,7 @@
//! where \\(\alpha \geq 0\\) is a tuning parameter that controls the strength of regularization. When \\(\alpha = 0\\) the penalty term has no effect, and ridge regression will produce the least squares estimates.
//! However, as \\(\alpha \rightarrow \infty\\), the impact of the shrinkage penalty grows, and the ridge regression coefficient estimates will approach zero.
//!
//! `smartcore` uses [SVD](../../linalg/svd/index.html) and [Cholesky](../../linalg/cholesky/index.html) matrix decomposition to find estimates of \\(\hat{\beta}\\).
//! The Cholesky decomposition is more computationally efficient and more numerically stable than calculating the normal equation directly,
//! but does not work for all data matrices. Unlike the Cholesky decomposition, all matrices have an SVD decomposition.
//!
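For reference, the penalized normal equation that these decompositions solve is:

```latex
\hat{\beta}^{\text{ridge}} = (X^TX + \alpha I)^{-1}X^Ty
```

Cholesky applies here because \\(X^TX + \alpha I\\) is symmetric and, for \\(\alpha > 0\\), positive definite.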
@@ -197,7 +197,6 @@ pub struct RidgeRegression<
> {
    coefficients: Option<X>,
    intercept: Option<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_y: PhantomData<Y>,
}
@@ -259,7 +258,6 @@ impl<
        Self {
            coefficients: Option::None,
            intercept: Option::None,
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        }
@@ -367,7 +365,6 @@ impl<
        Ok(RidgeRegression {
            intercept: Some(b),
            coefficients: Some(w),
            _phantom_ty: PhantomData,
            _phantom_y: PhantomData,
        })
@@ -12,7 +12,7 @@
//! let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
//! let y_true: Vec<f64> = vec![0., 1., 2., 3.];
//!
//! let score: f64 = Accuracy::new().get_score(&y_true, &y_pred);
//! ```
//! With integers:
//! ```
@@ -21,7 +21,7 @@
//! let y_pred: Vec<i64> = vec![0, 2, 1, 3];
//! let y_true: Vec<i64> = vec![0, 1, 2, 3];
//!
//! let score: f64 = Accuracy::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -92,7 +92,7 @@ mod tests {
        let y_pred: Vec<f64> = vec![0., 2., 1., 3.];
        let y_true: Vec<f64> = vec![0., 1., 2., 3.];

        let score1: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_pred);
        let score2: f64 = Accuracy::<f64>::new().get_score(&y_true, &y_true);

        assert!((score1 - 0.5).abs() < 1e-8);
@@ -108,7 +108,7 @@ mod tests {
        let y_pred: Vec<i32> = vec![0, 2, 1, 3];
        let y_true: Vec<i32> = vec![0, 1, 2, 3];

        let score1: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_pred);
        let score2: f64 = Accuracy::<i32>::new().get_score(&y_true, &y_true);

        assert_eq!(score1, 0.5);
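The hunks above consistently move `y_true` into the first argument position. What the metric computes can be sketched without the crate (`accuracy` here is a plain-Rust stand-in, not the `smartcore` API):

```rust
// Fraction of positions where the prediction matches the ground truth.
fn accuracy(y_true: &[i64], y_pred: &[i64]) -> f64 {
    let hits = y_true.iter().zip(y_pred).filter(|(t, p)| t == p).count();
    hits as f64 / y_true.len() as f64
}

fn main() {
    // Same vectors as the doc example: positions 1 and 2 are swapped,
    // so two of the four predictions are correct.
    let y_true = vec![0, 1, 2, 3];
    let y_pred = vec![0, 2, 1, 3];
    println!("{}", accuracy(&y_true, &y_pred)); // 0.5
}
```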
@@ -2,7 +2,7 @@
//! Computes the area under the receiver operating characteristic (ROC) curve, which is equal to the probability that a classifier will rank a
//! randomly chosen positive instance higher than a randomly chosen negative one.
//!
//! `smartcore` calculates ROC AUC from the Wilcoxon or Mann-Whitney U test.
//!
//! Example:
//! ```
@@ -15,7 +15,7 @@
//! let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
//!
//! let beta = 1.0; // beta's default is 1.0 anyway
//! let score: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -92,7 +92,7 @@ mod tests {
        let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];

        let beta = 1.0;
        let score1: f64 = F1::new_with(beta).get_score(&y_true, &y_pred);
        let score2: f64 = F1::new_with(beta).get_score(&y_true, &y_true);

        println!("{:?}", score1);
@@ -14,7 +14,7 @@
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mae: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -85,7 +85,7 @@ mod tests {
        let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
        let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];

        let score1: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
        let score2: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_true);

        assert!((score1 - 0.5).abs() < 1e-8);
@@ -14,7 +14,7 @@
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mse: f64 = MeanSquareError::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -85,7 +85,7 @@ mod tests {
        let y_true: Vec<f64> = vec![3., -0.5, 2., 7.];
        let y_pred: Vec<f64> = vec![2.5, 0.0, 2., 8.];

        let score1: f64 = MeanSquareError::new().get_score(&y_true, &y_pred);
        let score2: f64 = MeanSquareError::new().get_score(&y_true, &y_true);

        assert!((score1 - 0.375).abs() < 1e-8);
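A dependency-free check of the two error values asserted above (0.5 for MAE and 0.375 for MSE on the same vectors); the function names are plain-Rust stand-ins, not the crate's API:

```rust
// Mean of absolute residuals.
fn mean_absolute_error(y_true: &[f64], y_pred: &[f64]) -> f64 {
    y_true.iter().zip(y_pred).map(|(t, p)| (t - p).abs()).sum::<f64>() / y_true.len() as f64
}

// Mean of squared residuals.
fn mean_square_error(y_true: &[f64], y_pred: &[f64]) -> f64 {
    y_true.iter().zip(y_pred).map(|(t, p)| (t - p) * (t - p)).sum::<f64>() / y_true.len() as f64
}

fn main() {
    let y_true = [3.0, -0.5, 2.0, 7.0];
    let y_pred = [2.5, 0.0, 2.0, 8.0];
    // residuals are 0.5, -0.5, 0, -1: MAE = 2.0/4 = 0.5
    println!("MAE = {}", mean_absolute_error(&y_true, &y_pred));
    // squared residuals are 0.25, 0.25, 0, 1: MSE = 1.5/4 = 0.375
    println!("MSE = {}", mean_square_error(&y_true, &y_pred));
}
```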
@@ -4,7 +4,7 @@
//! In a feedback loop you build your model first, then you get feedback from metrics, improve it and repeat until your model achieves the desired performance.
//! Evaluation metrics help to explain the performance of a model and to compare models based on an objective criterion.
//!
//! Choosing the right metric is crucial when evaluating machine learning models. In `smartcore` you will find metrics for these classes of ML models:
//!
//! * [Classification metrics](struct.ClassificationMetrics.html)
//! * [Regression metrics](struct.RegressionMetrics.html)
@@ -14,7 +14,7 @@
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//!
//! let score: f64 = Precision::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -104,17 +104,17 @@ mod tests {
        let y_true: Vec<f64> = vec![0., 1., 1., 0.];
        let y_pred: Vec<f64> = vec![0., 0., 1., 1.];

        let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
        let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);

        assert!((score1 - 0.5).abs() < 1e-8);
        assert!((score2 - 1.0).abs() < 1e-8);

        let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
        let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];

        let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
        assert!((score3 - 0.6666666666).abs() < 1e-8);
    }

    #[cfg_attr(
@@ -126,7 +126,7 @@ mod tests {
        let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
        let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];

        let score1: f64 = Precision::new().get_score(&y_true, &y_pred);
        let score2: f64 = Precision::new().get_score(&y_pred, &y_pred);

        assert!((score1 - 0.333333333).abs() < 1e-8);
@@ -14,7 +14,7 @@
//! let y_pred: Vec<f64> = vec![3., -0.5, 2., 7.];
//! let y_true: Vec<f64> = vec![2.5, 0.0, 2., 8.];
//!
//! let mae: f64 = MeanAbsoluteError::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -14,7 +14,7 @@
//! let y_pred: Vec<f64> = vec![0., 1., 1., 0.];
//! let y_true: Vec<f64> = vec![0., 0., 1., 1.];
//!
//! let score: f64 = Recall::new().get_score(&y_true, &y_pred);
//! ```
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
@@ -105,17 +105,17 @@ mod tests {
        let y_true: Vec<f64> = vec![0., 1., 1., 0.];
        let y_pred: Vec<f64> = vec![0., 0., 1., 1.];

        let score1: f64 = Recall::new().get_score(&y_true, &y_pred);
        let score2: f64 = Recall::new().get_score(&y_pred, &y_pred);

        assert!((score1 - 0.5).abs() < 1e-8);
        assert!((score2 - 1.0).abs() < 1e-8);

        let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
        let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];

        let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
        assert!((score3 - 0.5).abs() < 1e-8);
    }

    #[cfg_attr(
@@ -127,7 +127,7 @@ mod tests {
        let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
        let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];

        let score1: f64 = Recall::new().get_score(&y_true, &y_pred);
        let score2: f64 = Recall::new().get_score(&y_pred, &y_pred);

        assert!((score1 - 0.333333333).abs() < 1e-8);
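For reference, the usual textbook definitions behind these two metrics are precision = TP/(TP+FP) and recall = TP/(TP+FN). A dependency-free sketch on a fresh dataset (function names are illustrative, not the crate's API, and the crate's exact argument conventions are as shown in the tests above):

```rust
// Count true positives, false positives and false negatives for binary labels.
fn counts(y_true: &[u8], y_pred: &[u8]) -> (f64, f64, f64) {
    let tp = y_true.iter().zip(y_pred).filter(|(t, p)| **t == 1 && **p == 1).count();
    let fp = y_true.iter().zip(y_pred).filter(|(t, p)| **t == 0 && **p == 1).count();
    let fn_ = y_true.iter().zip(y_pred).filter(|(t, p)| **t == 1 && **p == 0).count();
    (tp as f64, fp as f64, fn_ as f64)
}

fn precision(y_true: &[u8], y_pred: &[u8]) -> f64 {
    let (tp, fp, _) = counts(y_true, y_pred);
    tp / (tp + fp) // fraction of predicted positives that are real
}

fn recall(y_true: &[u8], y_pred: &[u8]) -> f64 {
    let (tp, _, fn_) = counts(y_true, y_pred);
    tp / (tp + fn_) // fraction of real positives that were found
}

fn main() {
    let y_true = [1, 1, 1, 0, 0, 0];
    let y_pred = [1, 1, 0, 1, 1, 0];
    println!("precision = {}", precision(&y_true, &y_pred)); // 2 TP / 4 predicted positives
    println!("recall    = {}", recall(&y_true, &y_pred));    // 2 TP / 3 actual positives
}
```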
@@ -7,7 +7,7 @@
//! Splitting data into multiple subsets helps us to find the right combination of hyperparameters, estimate model performance and choose the right model for
//! the data.
//!
//! In `smartcore` a random split into training and test sets can be quickly computed with the [train_test_split](./fn.train_test_split.html) helper function.
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
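Conceptually, a train/test split just shuffles indices and cuts at the requested ratio. A dependency-free sketch (the tiny LCG stands in for a proper RNG so the example stays within the standard library; this is not the `train_test_split` implementation):

```rust
fn train_test_split<T: Clone>(data: &[T], test_ratio: f64, seed: u64) -> (Vec<T>, Vec<T>) {
    let mut idx: Vec<usize> = (0..data.len()).collect();
    let mut state = seed;
    // Fisher-Yates shuffle driven by a linear congruential generator
    for i in (1..idx.len()).rev() {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        let j = (state >> 33) as usize % (i + 1);
        idx.swap(i, j);
    }
    let n_test = (data.len() as f64 * test_ratio).round() as usize;
    let test = idx[..n_test].iter().map(|&i| data[i].clone()).collect();
    let train = idx[n_test..].iter().map(|&i| data[i].clone()).collect();
    (train, test)
}

fn main() {
    let data: Vec<i32> = (0..10).collect();
    let (train, test) = train_test_split(&data, 0.2, 42);
    println!("train: {:?}", train);
    println!("test:  {:?}", test);
}
```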
@@ -364,6 +364,20 @@ pub struct BernoulliNB<
    binarize: Option<TX>,
}

impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
    fmt::Display for BernoulliNB<TX, TY, X, Y>
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "BernoulliNB:\ninner: {:?}\nbinarize: {:?}",
            self.inner.as_ref().unwrap(),
            self.binarize.as_ref().unwrap()
        )?;
        Ok(())
    }
}

impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, BernoulliNBParameters<TX>> for BernoulliNB<TX, TY, X, Y>
{
@@ -594,6 +608,9 @@ mod tests {
            ]
        );

        // test Display
        println!("{}", &bnb);

        let distribution = bnb.inner.clone().unwrap().distribution;

        assert_eq!(
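The `Display` impls added in these naive Bayes hunks all follow the same pattern: debug-print the optional inner state. A self-contained, standard-library version of that pattern on a toy struct (`Model` and its field are hypothetical, for illustration only):

```rust
use std::fmt;

struct Model {
    inner: Option<String>,
}

impl fmt::Display for Model {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Like the impls in the diff, this unwraps the inner state,
        // so it panics if Display is used on an unfitted model.
        writeln!(f, "Model:\ninner: {:?}", self.inner.as_ref().unwrap())
    }
}

fn main() {
    let m = Model { inner: Some("fitted".to_string()) };
    println!("{}", m);
}
```

The `unwrap()` mirrors the diff's impls: they are only meant to be called after `fit`, as the new `println!("{}", &bnb)` test lines do.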
@@ -139,6 +139,17 @@ impl<T: Number + Unsigned> NBDistribution<T, T> for CategoricalNBDistribution<T>
    }
}

impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> fmt::Display for CategoricalNB<T, X, Y> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(
            f,
            "CategoricalNB:\ninner: {:?}",
            self.inner.as_ref().unwrap()
        )?;
        Ok(())
    }
}

impl<T: Number + Unsigned> CategoricalNBDistribution<T> {
    /// Fits the distribution to a NxM matrix where N is the number of samples and M is the number of features.
    /// * `x` - training data.
@@ -539,6 +550,8 @@ mod tests {
        let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
        let y_hat = cnb.predict(&x).unwrap();
        assert_eq!(y_hat, vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1]);

        println!("{}", &cnb);
    }

    #[cfg_attr(
@@ -271,6 +271,19 @@ pub struct GaussianNB<
    inner: Option<BaseNaiveBayes<TX, TY, X, Y, GaussianNBDistribution<TY>>>,
}

impl<
        TX: Number + RealNumber,
        TY: Number + Ord + Unsigned,
        X: Array2<TX>,
        Y: Array1<TY>,
    > fmt::Display for GaussianNB<TX, TY, X, Y>
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "GaussianNB:\ninner: {:?}", self.inner.as_ref().unwrap())?;
        Ok(())
    }
}

impl<
    TX: Number + RealNumber,
    TY: Number + Ord + Unsigned,
@@ -433,6 +446,9 @@ mod tests {
        let gnb = GaussianNB::fit(&x, &y, parameters).unwrap();

        assert_eq!(gnb.class_priors(), &priors);

        // test Display for GaussianNB
        println!("{}", &gnb);
    }

    #[cfg_attr(
@@ -309,6 +309,19 @@ pub struct MultinomialNB<
     inner: Option<BaseNaiveBayes<TX, TY, X, Y, MultinomialNBDistribution<TY>>>,
 }

+impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>> fmt::Display
+    for MultinomialNB<TX, TY, X, Y>
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        writeln!(
+            f,
+            "MultinomialNB:\ninner: {:?}",
+            self.inner.as_ref().unwrap()
+        )?;
+        Ok(())
+    }
+}
+
 impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array1<TY>>
     SupervisedEstimator<X, Y, MultinomialNBParameters> for MultinomialNB<TX, TY, X, Y>
 {
@@ -500,6 +513,9 @@ mod tests {
             ]
         );

+        // test display
+        println!("{}", &nb);
+
         let y_hat = nb.predict(&x).unwrap();

         let distribution = nb.inner.clone().unwrap().distribution;
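The `fmt::Display` implementations added above follow the standard Rust pattern: implement `fmt` on the type and delegate to `writeln!` on the formatter. A dependency-free sketch of the same pattern, using a hypothetical `Model` type in place of smartcore's estimators:

```rust
use std::fmt;

// Hypothetical stand-in for a fitted model wrapping optional inner state.
struct Model {
    inner: Option<String>,
}

impl fmt::Display for Model {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the diff: a header line plus the Debug view of the inner state.
        writeln!(f, "Model:\ninner: {:?}", self.inner.as_ref().unwrap())?;
        Ok(())
    }
}

fn main() {
    let m = Model { inner: Some("fitted".to_string()) };
    // `{}` now works anywhere Display is expected, e.g. println! in the tests.
    assert_eq!(format!("{}", &m), "Model:\ninner: \"fitted\"\n");
}
```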
@@ -1,6 +1,6 @@
 //! # K Nearest Neighbors Classifier
 //!
-//! SmartCore relies on 2 backend algorithms to speedup KNN queries:
+//! `smartcore` relies on 2 backend algorithms to speedup KNN queries:
 //! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
 //! * [`CoverTree`](../../algorithm/neighbour/cover_tree/index.html)
 //!
@@ -1,5 +1,5 @@
 //! # Real Number
-//! Most algorithms in SmartCore rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
+//! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
 //! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.

 use num_traits::Float;
+19 -7
@@ -1,19 +1,31 @@
-#[cfg(not(feature = "std"))]
-pub(crate) use rand::rngs::SmallRng as RngImpl;
-#[cfg(feature = "std")]
-pub(crate) use rand::rngs::StdRng as RngImpl;
+#[cfg(not(feature = "std_rand"))]
+pub use rand::rngs::SmallRng as RngImpl;
+#[cfg(feature = "std_rand")]
+pub use rand::rngs::StdRng as RngImpl;
 use rand::SeedableRng;

-pub(crate) fn get_rng_impl(seed: Option<u64>) -> RngImpl {
+/// Custom switch for random functions
+pub fn get_rng_impl(seed: Option<u64>) -> RngImpl {
     match seed {
         Some(seed) => RngImpl::seed_from_u64(seed),
         None => {
             cfg_if::cfg_if! {
-                if #[cfg(feature = "std")] {
+                if #[cfg(feature = "std_rand")] {
                     use rand::RngCore;
                     RngImpl::seed_from_u64(rand::thread_rng().next_u64())
                 } else {
-                    panic!("seed number needed for non-std build");
+                    // no std_rand feature build, use getrandom
+                    #[cfg(feature = "js")]
+                    {
+                        let mut buf = [0u8; 64];
+                        getrandom::getrandom(&mut buf).unwrap();
+                        RngImpl::seed_from_u64(buf[0] as u64)
+                    }
+                    #[cfg(not(feature = "js"))]
+                    {
+                        // Using 0 as default seed
+                        RngImpl::seed_from_u64(0)
+                    }
                 }
             }
         }
     }
 }
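The reworked `get_rng_impl` above dispatches on an optional seed: an explicit seed always wins, and only the no-seed case falls back to platform entropy or, on restricted targets, a fixed default seed. A dependency-free sketch of that seed-or-default dispatch, with a toy generator standing in for the `rand` types:

```rust
// Toy linear congruential generator standing in for RngImpl (assumption:
// real code uses rand::rngs::StdRng / SmallRng via SeedableRng).
struct Lcg(u64);

impl Lcg {
    fn seed_from_u64(seed: u64) -> Self {
        Lcg(seed)
    }
    fn next_u64(&mut self) -> u64 {
        // Constants from Knuth's MMIX LCG.
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        self.0
    }
}

// Mirrors get_rng_impl: an explicit seed wins; otherwise fall back to a
// fixed seed of 0, as in the diff's non-std_rand / non-js branch.
fn get_rng_impl(seed: Option<u64>) -> Lcg {
    match seed {
        Some(seed) => Lcg::seed_from_u64(seed),
        None => Lcg::seed_from_u64(0),
    }
}

fn main() {
    // Same seed, same stream: runs are reproducible.
    let mut a = get_rng_impl(Some(42));
    let mut b = get_rng_impl(Some(42));
    assert_eq!(a.next_u64(), b.next_u64());
}
```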
+39 -32
@@ -1,23 +1,24 @@
 //! This module contains utilities to read in matrices from csv files.
-//! ```
+//! ```rust
 //! use smartcore::readers::csv;
-//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
-//! use crate::smartcore::linalg::BaseMatrix;
+//! use smartcore::linalg::basic::matrix::DenseMatrix;
 //! use std::fs;
 //!
 //! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
-//! assert_eq!(
-//!     csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
+//!
+//! let mtx = csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
 //!     fs::File::open("identity.csv").unwrap(),
 //!     csv::CSVDefinition::default()
 //! )
-//! .unwrap(),
-//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
-//! );
+//! .unwrap();
+//! println!("{}", &mtx);
+//!
 //! fs::remove_file("identity.csv");
 //! ```
-use crate::linalg::{BaseMatrix, BaseVector};
-use crate::math::num::RealNumber;
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::realnum::RealNumber;
 use crate::readers::ReadingError;
 use std::io::Read;

@@ -77,35 +78,41 @@ pub fn matrix_from_csv_source<T, RowVector, Matrix>(
     definition: CSVDefinition<'_>,
 ) -> Result<Matrix, ReadingError>
 where
-    T: RealNumber,
-    RowVector: BaseVector<T>,
-    Matrix: BaseMatrix<T, RowVector = RowVector>,
+    T: Number + RealNumber + std::str::FromStr,
+    RowVector: Array1<T>,
+    Matrix: Array2<T>,
 {
     let csv_text = read_string_from_source(source)?;
-    let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
+    let rows: Vec<Vec<T>> = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
         &csv_text,
         &definition,
         detect_row_format(&csv_text, &definition)?,
     )?;
+    let nrows = rows.len();
+    let ncols = rows[0].len();
+
-    match Matrix::from_row_vectors(rows) {
-        Some(matrix) => Ok(matrix),
-        None => Err(ReadingError::NoRowsProvided),
+    // TODO: try to return ReadingError
+    let array2 = Matrix::from_iterator(rows.into_iter().flatten(), nrows, ncols, 0);
+
+    if array2.shape() != (nrows, ncols) {
+        Err(ReadingError::ShapesDoNotMatch { msg: String::new() })
+    } else {
+        Ok(array2)
     }
 }

 /// Given a string containing the contents of a csv file, extract its value
 /// into row-vectors.
-fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
+fn extract_row_vectors_from_csv_text<
+    'a,
+    T: Number + RealNumber + std::str::FromStr,
+    RowVector: Array1<T>,
+    Matrix: Array2<T>,
+>(
     csv_text: &'a str,
     definition: &'a CSVDefinition<'_>,
     row_format: CSVRowFormat<'_>,
-) -> Result<Vec<RowVector>, ReadingError>
-where
-    T: RealNumber,
-    RowVector: BaseVector<T>,
-    Matrix: BaseMatrix<T, RowVector = RowVector>,
-{
+) -> Result<Vec<Vec<T>>, ReadingError> {
     csv_text
         .lines()
         .skip(definition.n_rows_header)
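The new reader path above collects rows as `Vec<Vec<T>>`, flattens them into the matrix constructor, and then compares the resulting shape against the expected one. A standalone sketch of that flatten-and-verify step (plain vectors standing in for the `Array2` matrix type):

```rust
// Flatten row vectors and check that the element count matches nrows * ncols,
// mirroring the shape test after Matrix::from_iterator in the diff.
fn flatten_rows(rows: Vec<Vec<f64>>) -> Result<(Vec<f64>, usize, usize), String> {
    let nrows = rows.len();
    let ncols = rows.first().map_or(0, |r| r.len());
    let data: Vec<f64> = rows.into_iter().flatten().collect();
    if data.len() != nrows * ncols {
        // Analogous to ReadingError::ShapesDoNotMatch in the real code.
        Err("shapes do not match".to_string())
    } else {
        Ok((data, nrows, ncols))
    }
}

fn main() {
    // A well-formed 2x2 identity flattens to 4 elements in row-major order.
    let ok = flatten_rows(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
    assert_eq!(ok, (vec![1.0, 0.0, 0.0, 1.0], 2, 2));
    // A ragged input is rejected instead of silently mis-shaping the matrix.
    assert!(flatten_rows(vec![vec![1.0, 0.0], vec![0.0]]).is_err());
}
```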
@@ -132,12 +139,12 @@ fn extract_vector_from_csv_line<T, RowVector>(
     row_format: &CSVRowFormat<'_>,
 ) -> Result<RowVector, ReadingError>
 where
-    T: RealNumber,
-    RowVector: BaseVector<T>,
+    T: Number + RealNumber + std::str::FromStr,
+    RowVector: Array1<T>,
 {
     validate_csv_row(line, row_format)?;
-    let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
-    Ok(BaseVector::from_array(&extracted_fields[..]))
+    let extracted_fields: Vec<T> = extract_fields_from_csv_row(line, row_format)?;
+    Ok(Array1::from_vec_slice(&extracted_fields[..]))
 }

 /// Extract the fields from a string containing the row of a csv file.
@@ -146,7 +153,7 @@ fn extract_fields_from_csv_row<T>(
     row_format: &CSVRowFormat<'_>,
 ) -> Result<Vec<T>, ReadingError>
 where
-    T: RealNumber,
+    T: Number + RealNumber + std::str::FromStr,
 {
     row.split(row_format.field_seperator)
         .enumerate()
@@ -192,7 +199,7 @@ fn enrich_reading_error<T>(
 /// Extract the value from a single csv field.
 fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
 where
-    T: RealNumber,
+    T: Number + RealNumber + std::str::FromStr,
 {
     // By default, `FromStr::Err` does not implement `Debug`.
     // Restricting it in the library leads to many breaking
@@ -210,7 +217,7 @@ where
 mod tests {
     mod matrix_from_csv_source {
         use super::super::{read_string_from_source, CSVDefinition, ReadingError};
-        use crate::linalg::naive::dense_matrix::DenseMatrix;
+        use crate::linalg::basic::matrix::DenseMatrix;
         use crate::readers::{csv::matrix_from_csv_source, io_testing};

         #[test]
@@ -298,7 +305,7 @@ mod tests {
     }
     mod extract_row_vectors_from_csv_text {
         use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
-        use crate::linalg::naive::dense_matrix::DenseMatrix;
+        use crate::linalg::basic::matrix::DenseMatrix;

         #[test]
         fn read_default_csv() {
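The `T: std::str::FromStr` bound threaded through the reader functions above is what lets each CSV field be parsed with `str::parse`. A minimal sketch of per-field parsing under that bound (hypothetical `parse_row` helper, similar in spirit to `extract_fields_from_csv_row`):

```rust
use std::str::FromStr;

// Parse one CSV row into typed values; an unparsable field becomes an error
// message naming its column, rather than a panic.
fn parse_row<T: FromStr>(row: &str, sep: char) -> Result<Vec<T>, String> {
    row.split(sep)
        .enumerate()
        .map(|(i, field)| {
            field
                .trim()
                .parse::<T>()
                .map_err(|_| format!("invalid field at column {}", i))
        })
        .collect() // collects Iterator<Item = Result<T, _>> into Result<Vec<T>, _>
}

fn main() {
    let vals: Vec<f64> = parse_row("1.0,0.0", ',').unwrap();
    assert_eq!(vals, vec![1.0, 0.0]);
    // A bad field surfaces as Err instead of corrupting the matrix.
    assert!(parse_row::<f64>("1.0,oops", ',').is_err());
}
```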
@@ -24,6 +24,12 @@ pub enum ReadingError {
         /// and where it happened.
         msg: String,
     },
+    /// Shape after deserialization is wrong
+    ShapesDoNotMatch {
+        /// More details about what row could not be read
+        /// and where it happened.
+        msg: String,
+    },
 }
 impl From<std::io::Error> for ReadingError {
     fn from(io_error: std::io::Error) -> Self {
@@ -39,6 +45,7 @@ impl ReadingError {
             ReadingError::InvalidField { msg } => Some(msg),
             ReadingError::InvalidRow { msg } => Some(msg),
             ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
+            ReadingError::ShapesDoNotMatch { msg } => Some(msg),
             ReadingError::NoRowsProvided => None,
         }
     }
@@ -107,6 +107,7 @@ mod test {
     use std::fs;
     use std::io::Read;
     use std::path;
+    #[cfg(not(target_arch = "wasm32"))]
     #[test]
     fn test_temporary_text_file() {
         let path_of_temporary_file;
@@ -126,7 +127,7 @@ mod test {
         // should have been cleaned up.
         assert!(!path::Path::new(&path_of_temporary_file).exists())
     }
+    #[cfg(not(target_arch = "wasm32"))]
     #[test]
     fn test_string_to_file() {
         let path_of_test_file = "test.file";
+69 -93
@@ -9,7 +9,7 @@
 //! SVM is memory efficient since it uses only a subset of training data to find a decision boundary. This subset is called support vectors.
 //!
 //! In SVM distance between a data point and the support vectors is defined by the kernel function.
-//! SmartCore supports multiple kernel functions but you can always define a new kernel function by implementing the `Kernel` trait. Not all functions can be a kernel.
+//! `smartcore` supports multiple kernel functions but you can always define a new kernel function by implementing the `Kernel` trait. Not all functions can be a kernel.
 //! Building a new kernel requires a good mathematical understanding of the [Mercer theorem](https://en.wikipedia.org/wiki/Mercer%27s_theorem)
 //! that gives necessary and sufficient condition for a function to be a kernel function.
 //!
@@ -23,15 +23,13 @@
 //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
 //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
-/// search parameters
-pub mod search;
 pub mod svc;
 pub mod svr;
+// /// search parameters space
+// pub mod search;

 use core::fmt::Debug;
-use std::marker::PhantomData;
-
-#[cfg(feature = "serde")]
-use serde::ser::{SerializeStruct, Serializer};
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};

@@ -40,52 +38,36 @@ use crate::linalg::basic::arrays::{Array1, ArrayView1};

 /// Defines a kernel function.
 /// This is an object-safe trait.
-pub trait Kernel<'a> {
+#[cfg_attr(
+    all(feature = "serde", not(target_arch = "wasm32")),
+    typetag::serde(tag = "type")
+)]
+pub trait Kernel: Debug {
     #[allow(clippy::ptr_arg)]
     /// Apply kernel function to x_i and x_j
     fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
-    /// Return a serializable name
-    fn name(&self) -> &'a str;
-}
-
-impl<'a> Debug for dyn Kernel<'_> + 'a {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        write!(f, "Kernel<f64>")
-    }
-}
-
-#[cfg(feature = "serde")]
-impl<'a> Serialize for dyn Kernel<'_> + 'a {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut s = serializer.serialize_struct("Kernel", 1)?;
-        s.serialize_field("type", &self.name())?;
-        s.end()
-    }
 }

 /// Pre-defined kernel functions
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
-pub struct Kernels {}
+pub struct Kernels;

-impl<'a> Kernels {
+impl Kernels {
     /// Return a default linear
-    pub fn linear() -> LinearKernel<'a> {
+    pub fn linear() -> LinearKernel {
         LinearKernel::default()
     }
     /// Return a default RBF
-    pub fn rbf() -> RBFKernel<'a> {
+    pub fn rbf() -> RBFKernel {
         RBFKernel::default()
     }
     /// Return a default polynomial
-    pub fn polynomial() -> PolynomialKernel<'a> {
+    pub fn polynomial() -> PolynomialKernel {
         PolynomialKernel::default()
     }
     /// Return a default sigmoid
-    pub fn sigmoid() -> SigmoidKernel<'a> {
+    pub fn sigmoid() -> SigmoidKernel {
         SigmoidKernel::default()
     }
 }
@@ -93,40 +75,25 @@ impl<'a> Kernels {
 /// Linear Kernel
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, PartialEq)]
-pub struct LinearKernel<'a> {
-    phantom: PhantomData<&'a ()>,
-}
-
-impl<'a> Default for LinearKernel<'a> {
-    fn default() -> Self {
-        Self {
-            phantom: PhantomData,
-        }
-    }
-}
+#[derive(Debug, Clone, PartialEq, Eq, Default)]
+pub struct LinearKernel;

 /// Radial basis function (Gaussian) kernel
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone, PartialEq)]
-pub struct RBFKernel<'a> {
+#[derive(Debug, Default, Clone, PartialEq)]
+pub struct RBFKernel {
     /// kernel coefficient
     pub gamma: Option<f64>,
-    phantom: PhantomData<&'a ()>,
-}
-
-impl<'a> Default for RBFKernel<'a> {
-    fn default() -> Self {
-        Self {
-            gamma: Option::None,
-            phantom: PhantomData,
-        }
-    }
 }

 #[allow(dead_code)]
-impl<'a> RBFKernel<'a> {
-    fn with_gamma(mut self, gamma: f64) -> Self {
+impl RBFKernel {
+    /// assign gamma parameter to kernel (required)
+    /// ```rust
+    /// use smartcore::svm::RBFKernel;
+    /// let knl = RBFKernel::default().with_gamma(0.7);
+    /// ```
+    pub fn with_gamma(mut self, gamma: f64) -> Self {
         self.gamma = Some(gamma);
         self
     }
@@ -135,42 +102,52 @@ impl<'a> RBFKernel<'a> {
 /// Polynomial kernel
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone, PartialEq)]
-pub struct PolynomialKernel<'a> {
+pub struct PolynomialKernel {
     /// degree of the polynomial
     pub degree: Option<f64>,
     /// kernel coefficient
     pub gamma: Option<f64>,
     /// independent term in kernel function
     pub coef0: Option<f64>,
-    phantom: PhantomData<&'a ()>,
 }

-impl<'a> Default for PolynomialKernel<'a> {
+impl Default for PolynomialKernel {
     fn default() -> Self {
         Self {
             gamma: Option::None,
             degree: Option::None,
             coef0: Some(1f64),
-            phantom: PhantomData,
         }
     }
 }

-#[allow(dead_code)]
-impl<'a> PolynomialKernel<'a> {
-    fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
+impl PolynomialKernel {
+    /// set parameters for kernel
+    /// ```rust
+    /// use smartcore::svm::PolynomialKernel;
+    /// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
+    /// ```
+    pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
         self.degree = Some(degree);
         self.gamma = Some(gamma);
         self.coef0 = Some(coef0);
         self
     }
-    fn with_gamma(mut self, gamma: f64) -> Self {
+    /// set gamma parameter for kernel
+    /// ```rust
+    /// use smartcore::svm::PolynomialKernel;
+    /// let knl = PolynomialKernel::default().with_gamma(0.7);
+    /// ```
+    pub fn with_gamma(mut self, gamma: f64) -> Self {
         self.gamma = Some(gamma);
         self
     }
-    fn with_degree(self, degree: f64, n_features: usize) -> Self {
+    /// set degree parameter for kernel
+    /// ```rust
+    /// use smartcore::svm::PolynomialKernel;
+    /// let knl = PolynomialKernel::default().with_degree(3.0, 100);
+    /// ```
+    pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
         self.with_params(degree, 1f64, 1f64 / n_features as f64)
     }
 }
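The kernels refactored above share one design: required parameters live in `Option` fields, public builder-style setters (`with_gamma`, `with_params`) fill them in, and `apply` fails if a required parameter was never set. A dependency-free sketch of that design, with a hypothetical `Rbf` type (not smartcore's `RBFKernel`):

```rust
// Hypothetical RBF-style kernel with a required parameter stored as Option,
// set through a consuming builder method as in RBFKernel::with_gamma.
#[derive(Debug, Default, Clone, PartialEq)]
struct Rbf {
    gamma: Option<f64>,
}

impl Rbf {
    fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = Some(gamma);
        self
    }
    // apply() refuses to run with an unset gamma instead of guessing a default.
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> Result<f64, String> {
        let gamma = self.gamma.ok_or("gamma is not set")?;
        let sq_dist: f64 = x_i.iter().zip(x_j).map(|(a, b)| (a - b) * (a - b)).sum();
        Ok((-gamma * sq_dist).exp())
    }
}

fn main() {
    // An unconfigured kernel is an error, not a silent default.
    assert!(Rbf::default().apply(&[0.0], &[1.0]).is_err());
    let knl = Rbf::default().with_gamma(0.5);
    // Identical points always map to 1.0 under the RBF kernel.
    assert_eq!(knl.apply(&[1.0, 2.0], &[1.0, 2.0]).unwrap(), 1.0);
}
```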
@@ -178,47 +155,53 @@ impl<'a> PolynomialKernel<'a> {
 /// Sigmoid (hyperbolic tangent) kernel
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone, PartialEq)]
-pub struct SigmoidKernel<'a> {
+pub struct SigmoidKernel {
     /// kernel coefficient
     pub gamma: Option<f64>,
     /// independent term in kernel function
     pub coef0: Option<f64>,
-    phantom: PhantomData<&'a ()>,
 }

-impl<'a> Default for SigmoidKernel<'a> {
+impl Default for SigmoidKernel {
     fn default() -> Self {
         Self {
             gamma: Option::None,
             coef0: Some(1f64),
-            phantom: PhantomData,
         }
     }
 }

-#[allow(dead_code)]
-impl<'a> SigmoidKernel<'a> {
-    fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
+impl SigmoidKernel {
+    /// set parameters for kernel
+    /// ```rust
+    /// use smartcore::svm::SigmoidKernel;
+    /// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
+    /// ```
+    pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
         self.gamma = Some(gamma);
         self.coef0 = Some(coef0);
         self
     }
-    fn with_gamma(mut self, gamma: f64) -> Self {
+    /// set gamma parameter for kernel
+    /// ```rust
+    /// use smartcore::svm::SigmoidKernel;
+    /// let knl = SigmoidKernel::default().with_gamma(0.7);
+    /// ```
+    pub fn with_gamma(mut self, gamma: f64) -> Self {
         self.gamma = Some(gamma);
         self
     }
 }
-impl<'a> Kernel<'a> for LinearKernel<'a> {
+#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
+impl Kernel for LinearKernel {
     fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
         Ok(x_i.dot(x_j))
     }
-    fn name(&self) -> &'a str {
-        "Linear"
-    }
 }

-impl<'a> Kernel<'a> for RBFKernel<'a> {
+#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
+impl Kernel for RBFKernel {
     fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
         if self.gamma.is_none() {
             return Err(Failed::because(
@@ -229,12 +212,10 @@ impl<'a> Kernel<'a> for RBFKernel<'a> {
         let v_diff = x_i.sub(x_j);
         Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
     }
-    fn name(&self) -> &'a str {
-        "RBF"
-    }
 }

-impl<'a> Kernel<'a> for PolynomialKernel<'a> {
+#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
+impl Kernel for PolynomialKernel {
     fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
         if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
             return Err(Failed::because(
@@ -245,12 +226,10 @@ impl<'a> Kernel<'a> for PolynomialKernel<'a> {
         let dot = x_i.dot(x_j);
         Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
     }
-    fn name(&self) -> &'a str {
-        "Polynomial"
-    }
 }

-impl<'a> Kernel<'a> for SigmoidKernel<'a> {
+#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
+impl Kernel for SigmoidKernel {
     fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
         if self.gamma.is_none() || self.coef0.is_none() {
             return Err(Failed::because(
@@ -261,9 +240,6 @@ impl<'a> Kernel<'a> for SigmoidKernel<'a> {
         let dot = x_i.dot(x_j);
         Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
     }
-    fn name(&self) -> &'a str {
-        "Sigmoid"
-    }
 }

 #[cfg(test)]
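The refactor above replaces borrowed `&'a dyn Kernel<'a>` references with owned `Box<dyn Kernel>` values, which removes the lifetime parameter from the parameter structs and lets serde (via typetag) handle the trait object. A dependency-free sketch of holding a boxed kernel trait object (hypothetical `Params` and `Linear` types, no serde):

```rust
use std::fmt::Debug;

// Object-safe kernel trait with a Debug supertrait, as in the refactor.
trait Kernel: Debug {
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64;
}

#[derive(Debug, Default)]
struct Linear;

impl Kernel for Linear {
    fn apply(&self, x_i: &[f64], x_j: &[f64]) -> f64 {
        x_i.iter().zip(x_j).map(|(a, b)| a * b).sum() // dot product
    }
}

// Owning a Box<dyn Kernel> removes the borrow (and its lifetime parameter)
// that the old Option<&'a dyn Kernel<'a>> field forced onto the params struct.
#[derive(Debug, Default)]
struct Params {
    kernel: Option<Box<dyn Kernel>>,
}

impl Params {
    fn with_kernel<K: Kernel + 'static>(mut self, k: K) -> Self {
        self.kernel = Some(Box::new(k));
        self
    }
}

fn main() {
    // Note with_kernel now takes the kernel by value, matching the doc-example
    // change from with_kernel(&knl) to with_kernel(knl) in the diff.
    let p = Params::default().with_kernel(Linear);
    let k = p.kernel.as_ref().unwrap();
    assert_eq!(k.apply(&[1.0, 2.0], &[3.0, 4.0]), 11.0);
}
```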
+30 -33
@@ -20,7 +20,7 @@
 //!
 //! Where \\( m \\) is a number of training samples, \\( y_i \\) is a label value (either 1 or -1) and \\(\langle\vec{w}, \vec{x}_i \rangle + b\\) is a decision boundary.
 //!
-//! To solve this optimization problem, SmartCore uses an [approximate SVM solver](https://leon.bottou.org/projects/lasvm).
+//! To solve this optimization problem, `smartcore` uses an [approximate SVM solver](https://leon.bottou.org/projects/lasvm).
 //! The optimizer reaches accuracies similar to that of a real SVM after performing two passes through the training examples. You can choose the number of passes
 //! through the data that the algorithm takes by changing the `epoch` parameter of the classifier.
 //!
@@ -58,7 +58,7 @@
 //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
 //!
 //! let knl = Kernels::linear();
-//! let params = &SVCParameters::default().with_c(200.0).with_kernel(&knl);
+//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
 //! let svc = SVC::fit(&x, &y, params).unwrap();
 //!
 //! let y_hat = svc.predict(&x).unwrap();
@@ -91,24 +91,21 @@ use crate::rand_custom::get_rng_impl;
 use crate::svm::Kernel;

 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 /// SVC Parameters
-pub struct SVCParameters<
-    'a,
-    TX: Number + RealNumber,
-    TY: Number + Ord,
-    X: Array2<TX>,
-    Y: Array1<TY>,
-> {
+pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
     /// Number of epochs.
     pub epoch: usize,
     /// Regularization parameter.
     pub c: TX,
     /// Tolerance for stopping criterion.
     pub tol: TX,
-    #[cfg_attr(feature = "serde", serde(skip_deserializing))]
     /// The kernel function.
-    pub kernel: Option<&'a dyn Kernel<'a>>,
+    #[cfg_attr(
+        all(feature = "serde", target_arch = "wasm32"),
+        serde(skip_serializing, skip_deserializing)
+    )]
+    pub kernel: Option<Box<dyn Kernel>>,
     /// Unused parameter.
     m: PhantomData<(X, Y, TY)>,
     /// Controls the pseudo random number generation for shuffling the data for probability estimates
@@ -129,7 +126,7 @@ pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
     classes: Option<Vec<TY>>,
     instances: Option<Vec<Vec<TX>>>,
     #[cfg_attr(feature = "serde", serde(skip))]
-    parameters: Option<&'a SVCParameters<'a, TX, TY, X, Y>>,
+    parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
     w: Option<Vec<TX>>,
     b: Option<TX>,
     phantomdata: PhantomData<(X, Y)>,
@@ -155,7 +152,7 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
 struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
     x: &'a X,
     y: &'a Y,
-    parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
+    parameters: &'a SVCParameters<TX, TY, X, Y>,
     svmin: usize,
     svmax: usize,
     gmin: TX,
@@ -165,8 +162,8 @@ struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y
|
|||||||
recalculate_minmax_grad: bool,
|
recalculate_minmax_grad: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||||
SVCParameters<'a, TX, TY, X, Y>
|
SVCParameters<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
/// Number of epochs.
|
/// Number of epochs.
|
||||||
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
pub fn with_epoch(mut self, epoch: usize) -> Self {
|
||||||
@@ -184,8 +181,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
/// The kernel function.
|
/// The kernel function.
|
||||||
pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
|
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||||
self.kernel = Some(kernel);
|
self.kernel = Some(Box::new(kernel));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,8 +193,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> Default
|
||||||
for SVCParameters<'a, TX, TY, X, Y>
|
for SVCParameters<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
SVCParameters {
|
SVCParameters {
|
||||||
@@ -212,7 +209,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<'a, TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -227,7 +224,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
fn fit(
|
fn fit(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Result<Self, Failed> {
|
) -> Result<Self, Failed> {
|
||||||
SVC::fit(x, y, parameters)
|
SVC::fit(x, y, parameters)
|
||||||
}
|
}
|
||||||
@@ -251,7 +248,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
|||||||
pub fn fit(
|
pub fn fit(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -447,7 +444,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
|||||||
fn new(
|
fn new(
|
||||||
x: &'a X,
|
x: &'a X,
|
||||||
y: &'a Y,
|
y: &'a Y,
|
||||||
parameters: &'a SVCParameters<'a, TX, TY, X, Y>,
|
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||||
) -> Optimizer<'a, TX, TY, X, Y> {
|
) -> Optimizer<'a, TX, TY, X, Y> {
|
||||||
let (n, _) = x.shape();
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
@@ -940,8 +937,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::metrics::accuracy;
|
use crate::metrics::accuracy;
|
||||||
#[cfg(feature = "serde")]
|
use crate::svm::Kernels;
|
||||||
use crate::svm::*;
|
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
@@ -979,7 +975,7 @@ mod tests {
|
|||||||
let knl = Kernels::linear();
|
let knl = Kernels::linear();
|
||||||
let params = SVCParameters::default()
|
let params = SVCParameters::default()
|
||||||
.with_c(200.0)
|
.with_c(200.0)
|
||||||
.with_kernel(&knl)
|
.with_kernel(knl)
|
||||||
.with_seed(Some(100));
|
.with_seed(Some(100));
|
||||||
|
|
||||||
let y_hat = SVC::fit(&x, &y, ¶ms)
|
let y_hat = SVC::fit(&x, &y, ¶ms)
|
||||||
@@ -1018,7 +1014,7 @@ mod tests {
|
|||||||
&y,
|
&y,
|
||||||
&SVCParameters::default()
|
&SVCParameters::default()
|
||||||
.with_c(200.0)
|
.with_c(200.0)
|
||||||
.with_kernel(&Kernels::linear()),
|
.with_kernel(Kernels::linear()),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.decision_function(&x2))
|
.and_then(|lr| lr.decision_function(&x2))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1073,7 +1069,7 @@ mod tests {
|
|||||||
&y,
|
&y,
|
||||||
&SVCParameters::default()
|
&SVCParameters::default()
|
||||||
.with_c(1.0)
|
.with_c(1.0)
|
||||||
.with_kernel(&Kernels::rbf().with_gamma(0.7)),
|
.with_kernel(Kernels::rbf().with_gamma(0.7)),
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1092,7 +1088,7 @@ mod tests {
|
|||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
)]
|
)]
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||||
fn svc_serde() {
|
fn svc_serde() {
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
&[5.1, 3.5, 1.4, 0.2],
|
||||||
@@ -1122,12 +1118,13 @@ mod tests {
|
|||||||
];
|
];
|
||||||
|
|
||||||
let knl = Kernels::linear();
|
let knl = Kernels::linear();
|
||||||
let params = SVCParameters::default().with_kernel(&knl);
|
let params = SVCParameters::default().with_kernel(knl);
|
||||||
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
let svc = SVC::fit(&x, &y, ¶ms).unwrap();
|
||||||
|
|
||||||
// serialization
|
// serialization
|
||||||
let serialized_svc = &serde_json::to_string(&svc).unwrap();
|
let deserialized_svc: SVC<f64, i32, _, _> =
|
||||||
|
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
|
||||||
|
|
||||||
println!("{:?}", serialized_svc);
|
assert_eq!(svc, deserialized_svc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
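The recurring change in this file is an ownership change: the kernel moves from a borrowed trait object (`Option<&'a dyn Kernel<'a>>`) to an owned one (`Option<Box<dyn Kernel>>`), which lets the `'a` lifetime parameter disappear from `SVCParameters`. A minimal, self-contained sketch of that pattern — the `Kernel` trait body, `Linear`, and `Params` here are illustrative stand-ins, not the smartcore API:

```rust
// Sketch: replacing a borrowed `&'a dyn Trait` field with an owned `Box<dyn Trait>`.
trait Kernel {
    fn apply(&self, a: f64, b: f64) -> f64;
}

struct Linear;
impl Kernel for Linear {
    fn apply(&self, a: f64, b: f64) -> f64 {
        a * b
    }
}

#[derive(Default)]
struct Params {
    // Owned trait object: no lifetime parameter on `Params` itself.
    kernel: Option<Box<dyn Kernel>>,
}

impl Params {
    // Takes the kernel by value and boxes it, like the new `with_kernel`.
    fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
        self.kernel = Some(Box::new(kernel));
        self
    }
}

fn main() {
    let p = Params::default().with_kernel(Linear);
    let v = p.kernel.as_ref().map(|k| k.apply(2.0, 3.0));
    println!("{:?}", v); // Some(6.0)
}
```

Because the builder now consumes its argument, every call site in the diff changes from `with_kernel(&knl)` to `with_kernel(knl)`; the trade-off is that `Clone` can no longer be derived for the parameter struct (hence `#[derive(Debug, Clone)]` becoming `#[derive(Debug)]`).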
+24 -27
@@ -50,7 +50,7 @@
 //! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
 //!
 //! let knl = Kernels::linear();
-//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(&knl);
+//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
 //! // let svr = SVR::fit(&x, &y, params).unwrap();
 //!
 //! // let y_hat = svr.predict(&x).unwrap();
@@ -83,18 +83,21 @@ use crate::numbers::floatnum::FloatNumber;
 use crate::svm::Kernel;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 /// SVR Parameters
-pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
+pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
     /// Epsilon in the epsilon-SVR model.
     pub eps: T,
     /// Regularization parameter.
     pub c: T,
     /// Tolerance for stopping criterion.
     pub tol: T,
-    #[cfg_attr(feature = "serde", serde(skip_deserializing))]
     /// The kernel function.
-    pub kernel: Option<&'a dyn Kernel<'a>>,
+    #[cfg_attr(
+        all(feature = "serde", target_arch = "wasm32"),
+        serde(skip_serializing, skip_deserializing)
+    )]
+    pub kernel: Option<Box<dyn Kernel>>,
 }
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -103,7 +106,7 @@ pub struct SVRParameters<'a, T: Number + FloatNumber + PartialOrd> {
 pub struct SVR<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> {
     instances: Option<Vec<Vec<f64>>>,
     #[cfg_attr(feature = "serde", serde(skip_deserializing))]
-    parameters: Option<&'a SVRParameters<'a, T>>,
+    parameters: Option<&'a SVRParameters<T>>,
     w: Option<Vec<T>>,
     b: T,
     phantom: PhantomData<(X, Y)>,
@@ -123,7 +126,7 @@ struct SupportVector<T> {
 struct Optimizer<'a, T: Number + FloatNumber + PartialOrd> {
     tol: T,
     c: T,
-    parameters: Option<&'a SVRParameters<'a, T>>,
+    parameters: Option<&'a SVRParameters<T>>,
     svmin: usize,
     svmax: usize,
     gmin: T,
@@ -140,7 +143,7 @@ struct Cache<T: Clone> {
     data: Vec<RefCell<Option<Vec<T>>>>,
 }
 
-impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> {
+impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
     /// Epsilon in the epsilon-SVR model.
     pub fn with_eps(mut self, eps: T) -> Self {
         self.eps = eps;
@@ -157,13 +160,13 @@ impl<'a, T: Number + FloatNumber + PartialOrd> SVRParameters<'a, T> {
         self
     }
     /// The kernel function.
-    pub fn with_kernel(mut self, kernel: &'a (dyn Kernel<'a>)) -> Self {
-        self.kernel = Some(kernel);
+    pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
+        self.kernel = Some(Box::new(kernel));
         self
     }
 }
 
-impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T> {
+impl<T: Number + FloatNumber + PartialOrd> Default for SVRParameters<T> {
     fn default() -> Self {
         SVRParameters {
             eps: T::from_f64(0.1).unwrap(),
@@ -175,7 +178,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Default for SVRParameters<'a, T>
 }
 
 impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
-    SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<'a, T>> for SVR<'a, T, X, Y>
+    SupervisedEstimatorBorrow<'a, X, Y, SVRParameters<T>> for SVR<'a, T, X, Y>
 {
     fn new() -> Self {
         Self {
@@ -186,7 +189,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>>
             phantom: PhantomData,
         }
     }
-    fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<'a, T>) -> Result<Self, Failed> {
+    fn fit(x: &'a X, y: &'a Y, parameters: &'a SVRParameters<T>) -> Result<Self, Failed> {
         SVR::fit(x, y, parameters)
     }
 }
@@ -208,7 +211,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
     pub fn fit(
         x: &'a X,
         y: &'a Y,
-        parameters: &'a SVRParameters<'a, T>,
+        parameters: &'a SVRParameters<T>,
     ) -> Result<SVR<'a, T, X, Y>, Failed> {
         let (n, _) = x.shape();
 
@@ -324,7 +327,7 @@ impl<'a, T: Number + FloatNumber + PartialOrd> Optimizer<'a, T> {
     fn new<X: Array2<T>, Y: Array1<T>>(
         x: &'a X,
         y: &'a Y,
-        parameters: &'a SVRParameters<'a, T>,
+        parameters: &'a SVRParameters<T>,
     ) -> Optimizer<'a, T> {
         let (n, _) = x.shape();
 
@@ -596,7 +599,6 @@ mod tests {
     use super::*;
     use crate::linalg::basic::matrix::DenseMatrix;
     use crate::metrics::mean_squared_error;
-    #[cfg(feature = "serde")]
     use crate::svm::Kernels;
 
     // #[test]
@@ -617,7 +619,6 @@ mod tests {
     // assert!(iter.next().is_none());
     // }
 
-    //TODO: had to disable this test as it runs for too long
     #[cfg_attr(
         all(target_arch = "wasm32", not(target_os = "wasi")),
         wasm_bindgen_test::wasm_bindgen_test
@@ -655,7 +656,7 @@ mod tests {
             &SVRParameters::default()
                 .with_eps(2.0)
                 .with_c(10.0)
-                .with_kernel(&knl),
+                .with_kernel(knl),
         )
         .and_then(|lr| lr.predict(&x))
         .unwrap();
@@ -670,7 +671,7 @@ mod tests {
         wasm_bindgen_test::wasm_bindgen_test
     )]
     #[test]
-    #[cfg(feature = "serde")]
+    #[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
     fn svr_serde() {
         let x = DenseMatrix::from_2d_array(&[
             &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
@@ -697,17 +698,13 @@ mod tests {
         ];
 
         let knl = Kernels::rbf().with_gamma(0.7);
-        let params = SVRParameters::default().with_kernel(&knl);
+        let params = SVRParameters::default().with_kernel(knl);
 
         let svr = SVR::fit(&x, &y, &params).unwrap();
 
-        let serialized = &serde_json::to_string(&svr).unwrap();
+        let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
+            serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
 
-        println!("{}", &serialized);
-
-        // let deserialized_svr: SVR<f64, DenseMatrix<f64>, LinearKernel> =
-        //     serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
-
-        // assert_eq!(svr, deserialized_svr);
+        assert_eq!(svr, deserialized_svr);
     }
 }
@@ -163,7 +163,6 @@ impl Default for SplitCriterion {
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 struct Node {
-    index: usize,
     output: usize,
     split_feature: usize,
     split_value: Option<f64>,
@@ -406,9 +405,8 @@ impl Default for DecisionTreeClassifierSearchParameters {
 }
 
 impl Node {
-    fn new(index: usize, output: usize) -> Self {
+    fn new(output: usize) -> Self {
         Node {
-            index,
             output,
             split_feature: 0,
             split_value: Option::None,
@@ -582,7 +580,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
             count[yi[i]] += samples[i];
         }
 
-        let root = Node::new(0, which_max(&count));
+        let root = Node::new(which_max(&count));
         change_nodes.push(root);
         let mut order: Vec<Vec<usize>> = Vec::new();
 
@@ -673,7 +671,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
         let mut is_pure = true;
         for i in 0..n_rows {
             if visitor.samples[i] > 0 {
-                if label == Option::None {
+                if label.is_none() {
                     label = Option::Some(visitor.y[i]);
                 } else if visitor.y[i] != label.unwrap() {
                     is_pure = false;
@@ -831,11 +829,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
 
         let true_child_idx = self.nodes().len();
 
-        self.nodes
-            .push(Node::new(true_child_idx, visitor.true_child_output));
+        self.nodes.push(Node::new(visitor.true_child_output));
         let false_child_idx = self.nodes().len();
-        self.nodes
-            .push(Node::new(false_child_idx, visitor.false_child_output));
+        self.nodes.push(Node::new(visitor.false_child_output));
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
 
@@ -923,6 +919,7 @@ mod tests {
         wasm_bindgen_test::wasm_bindgen_test
     )]
     #[test]
+    #[cfg(feature = "datasets")]
     fn fit_predict_iris() {
         let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
             &[5.1, 3.5, 1.4, 0.2],
@@ -11,7 +11,7 @@
 //!
 //! where \\(\hat{y}_{Rk}\\) is the mean response for the training observations withing region _k_.
 //!
-//! SmartCore uses recursive binary splitting approach to build \\(R_1, R_2, ..., R_K\\) regions. The approach begins at the top of the tree and then successively splits the predictor space
+//! `smartcore` uses recursive binary splitting approach to build \\(R_1, R_2, ..., R_K\\) regions. The approach begins at the top of the tree and then successively splits the predictor space
 //! one predictor at a time. At each step of the tree-building process, the best split is made at that particular step, rather than looking ahead and picking a split that will lead to a better
 //! tree in some future step.
 //!
@@ -128,7 +128,6 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 struct Node {
-    index: usize,
     output: f64,
     split_feature: usize,
     split_value: Option<f64>,
@@ -299,9 +298,8 @@ impl Default for DecisionTreeRegressorSearchParameters {
 }
 
 impl Node {
-    fn new(index: usize, output: f64) -> Self {
+    fn new(output: f64) -> Self {
         Node {
-            index,
             output,
             split_feature: 0,
             split_value: Option::None,
@@ -450,7 +448,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
             sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
         }
 
-        let root = Node::new(0, sum / (n as f64));
+        let root = Node::new(sum / (n as f64));
         nodes.push(root);
         let mut order: Vec<Vec<usize>> = Vec::new();
 
@@ -511,7 +509,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         match queue.pop_front() {
             Some(node_id) => {
                 let node = &self.nodes()[node_id];
-                if node.true_child == None && node.false_child == None {
+                if node.true_child.is_none() && node.false_child.is_none() {
                     result = node.output;
                 } else if x.get((row, node.split_feature)).to_f64().unwrap()
                     <= node.split_value.unwrap_or(std::f64::NAN)
@@ -557,7 +555,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
             self.find_best_split(visitor, n, sum, parent_gain, *variable);
         }
 
-        self.nodes()[visitor.node].split_score != Option::None
+        self.nodes()[visitor.node].split_score.is_some()
     }
 
     fn find_best_split(
@@ -662,11 +660,9 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 
         let true_child_idx = self.nodes().len();
 
-        self.nodes
-            .push(Node::new(true_child_idx, visitor.true_child_output));
+        self.nodes.push(Node::new(visitor.true_child_output));
         let false_child_idx = self.nodes().len();
-        self.nodes
-            .push(Node::new(false_child_idx, visitor.false_child_output));
+        self.nodes.push(Node::new(visitor.false_child_output));
 
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
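Dropping the `index` field from `Node` works because a node's identity is already its position in the `nodes` vector, a common index-based arena layout for trees in Rust. A minimal sketch of that pattern, with illustrative names rather than the smartcore types:

```rust
// Sketch: an index-based tree arena where a node's id is its Vec position.
#[derive(Debug)]
struct Node {
    output: usize,
    true_child: Option<usize>,
    false_child: Option<usize>,
}

impl Node {
    fn new(output: usize) -> Self {
        Node { output, true_child: None, false_child: None }
    }
}

struct Tree {
    nodes: Vec<Node>,
}

impl Tree {
    // Push children and wire them up by index. No `index` field is needed:
    // `nodes.len()` just before each push *is* the new node's id.
    fn add_children(&mut self, parent: usize, t_out: usize, f_out: usize) {
        let true_child_idx = self.nodes.len();
        self.nodes.push(Node::new(t_out));
        let false_child_idx = self.nodes.len();
        self.nodes.push(Node::new(f_out));
        self.nodes[parent].true_child = Some(true_child_idx);
        self.nodes[parent].false_child = Some(false_child_idx);
    }
}

fn main() {
    let mut tree = Tree { nodes: vec![Node::new(0)] };
    tree.add_children(0, 1, 2);
    assert_eq!(tree.nodes[0].true_child, Some(1));
    assert_eq!(tree.nodes[0].false_child, Some(2));
    assert_eq!(tree.nodes[2].output, 2);
}
```

This mirrors the diff's `let true_child_idx = self.nodes().len();` bookkeeping: the index is computed at push time, so duplicating it inside each `Node` was redundant state that could drift out of sync.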
+1 -1
@@ -9,7 +9,7 @@
 //! Decision trees suffer from high variance and often does not deliver best prediction accuracy when compared to other supervised learning approaches, such as linear and logistic regression.
 //! Hence some techniques such as [Random Forests](../ensemble/index.html) use more than one decision tree to improve performance of the algorithm.
 //!
-//! SmartCore uses [CART](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees_.28CART.29) learning technique to build both classification and regression trees.
+//! `smartcore` uses [CART](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees_.28CART.29) learning technique to build both classification and regression trees.
 //!
 //! ## References:
 //!
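Several hunks above also swap comparisons like `node.true_child == None` for `node.true_child.is_none()` and `split_score != Option::None` for `split_score.is_some()`. Beyond reading better (clippy flags the comparison form, I believe via its `partialeq_to_none` lint), the method form also works when the inner type does not implement `PartialEq`. A standalone illustration:

```rust
// `opt == None` requires the inner type to implement PartialEq;
// `is_none()` / `is_some()` work for any Option.
#[derive(Debug)]
struct NotComparable; // deliberately has no PartialEq impl

fn main() {
    let score: Option<f64> = None;
    // Idiomatic checks, as in the diff:
    assert!(score.is_none());
    assert!(!score.is_some());

    // This compiles even though NotComparable is not PartialEq;
    // `child == None` would not.
    let child: Option<NotComparable> = Some(NotComparable);
    assert!(child.is_some());
}
```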