Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
681fea6cbe | ||
|
|
038108b1c3 | ||
|
|
730c0d64df | ||
|
|
44424807a0 | ||
|
|
76d1ef610d | ||
|
|
4092e24c2a | ||
|
|
17dc9f3bbf | ||
|
|
c8ec8fec00 | ||
|
|
3da433f757 | ||
|
|
4523ac73ff | ||
|
|
ba75f9ffad | ||
|
|
239c00428f | ||
|
|
80a93c1a0e | ||
|
|
4eadd16ce4 | ||
|
|
886b5631b7 | ||
|
|
9c07925d8a | ||
|
|
6f22bbd150 | ||
|
|
dbdc2b2a77 |
@@ -2,6 +2,5 @@
|
||||
# the repo. Unless a later match takes precedence,
|
||||
# Developers in this list will be requested for
|
||||
# review when someone opens a pull request.
|
||||
* @VolodymyrOrlov
|
||||
* @morenol
|
||||
* @Mec-iS
|
||||
|
||||
@@ -50,9 +50,9 @@ $ rust-code-analysis-cli -p src/algorithm/neighbour/fastpair.rs --ls 22 --le 213
|
||||
|
||||
1. After a PR is opened maintainers are notified
|
||||
2. Changes will probably be required to comply with the workflow; these commands are run automatically and all tests must pass:
|
||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||
* **Formatting**: run `rustfmt src/*.rs` to apply automatic formatting
|
||||
* **Linting**: `clippy` is used with command `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`
|
||||
* **Coverage** (optional): `tarpaulin` is used with command `cargo tarpaulin --out Lcov --all-features -- --test-threads 1`
|
||||
* **Testing**: multiple test pipelines are run for different targets
|
||||
3. When everything is OK, code is merged.
|
||||
|
||||
|
||||
@@ -19,14 +19,13 @@ jobs:
|
||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||
{ os: "ubuntu", target: "wasm32-wasi" },
|
||||
]
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
@@ -43,9 +42,6 @@ jobs:
|
||||
- name: Install test runner for wasm
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Install test runner for wasi
|
||||
if: matrix.platform.target == 'wasm32-wasi'
|
||||
run: curl https://wasmtime.dev/install.sh -sSf | bash
|
||||
- name: Stable Build with all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
@@ -65,13 +61,7 @@ jobs:
|
||||
- name: Tests in WASM
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: wasm-pack test --node -- --all-features
|
||||
- name: Tests in WASI
|
||||
if: matrix.platform.target == 'wasm32-wasi'
|
||||
run: |
|
||||
export WASMTIME_HOME="$HOME/.wasmtime"
|
||||
export PATH="$WASMTIME_HOME/bin:$PATH"
|
||||
cargo install cargo-wasi && cargo wasi test
|
||||
|
||||
|
||||
check_features:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
@@ -81,9 +71,9 @@ jobs:
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
|
||||
@@ -12,9 +12,9 @@ jobs:
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v2
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
|
||||
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.4.0] - 2023-04-05
|
||||
|
||||
## Added
|
||||
- WARNING: Breaking changes!
|
||||
- `DenseMatrix` constructors now return `Result` to prevent users from instantiating a matrix with an inconsistent rows/cols count. Their return values need to be unwrapped (e.g. with `unwrap()`); see the tests for examples
|
||||
|
||||
## [0.3.0] - 2022-11-09
|
||||
|
||||
## Added
|
||||
|
||||
+2
-2
@@ -2,7 +2,7 @@
|
||||
name = "smartcore"
|
||||
description = "Machine Learning in Rust."
|
||||
homepage = "https://smartcorelib.org"
|
||||
version = "0.3.2"
|
||||
version = "0.4.2"
|
||||
authors = ["smartcore Developers"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
@@ -48,7 +48,7 @@ getrandom = { version = "0.2.8", optional = true }
|
||||
wasm-bindgen-test = "0.3"
|
||||
|
||||
[dev-dependencies]
|
||||
itertools = "0.10.5"
|
||||
itertools = "0.13.0"
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.1"
|
||||
|
||||
|
||||
@@ -18,4 +18,4 @@
|
||||
-----
|
||||
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
|
||||
|
||||
To start getting familiar with the new smartcore v0.3 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||
To start getting familiar with the new smartcore v0.4 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||
|
||||
@@ -40,11 +40,11 @@ impl BBDTreeNode {
|
||||
|
||||
impl BBDTree {
|
||||
pub fn new<T: Number, M: Array2<T>>(data: &M) -> BBDTree {
|
||||
let nodes = Vec::new();
|
||||
let nodes: Vec<BBDTreeNode> = Vec::new();
|
||||
|
||||
let (n, _) = data.shape();
|
||||
|
||||
let index = (0..n).collect::<Vec<_>>();
|
||||
let index = (0..n).collect::<Vec<usize>>();
|
||||
|
||||
let mut tree = BBDTree {
|
||||
nodes,
|
||||
@@ -343,7 +343,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let tree = BBDTree::new(&data);
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
|
||||
current_cover_set.push((d, &self.root));
|
||||
|
||||
let mut heap = HeapSelection::with_capacity(k);
|
||||
heap.add(std::f64::MAX);
|
||||
heap.add(f64::MAX);
|
||||
|
||||
let mut empty_heap = true;
|
||||
if !self.identical_excluded || self.get_data_value(self.root.idx) != p {
|
||||
@@ -145,7 +145,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
|
||||
}
|
||||
|
||||
let upper_bound = if empty_heap {
|
||||
std::f64::INFINITY
|
||||
f64::INFINITY
|
||||
} else {
|
||||
*heap.peek()
|
||||
};
|
||||
@@ -291,7 +291,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
|
||||
} else {
|
||||
let max_dist = self.max(point_set);
|
||||
let next_scale = (max_scale - 1).min(self.get_scale(max_dist));
|
||||
if next_scale == std::i64::MIN {
|
||||
if next_scale == i64::MIN {
|
||||
let mut children: Vec<Node> = Vec::new();
|
||||
let mut leaf = self.new_leaf(p);
|
||||
children.push(leaf);
|
||||
@@ -435,7 +435,7 @@ impl<T: Debug + PartialEq, D: Distance<T>> CoverTree<T, D> {
|
||||
|
||||
fn get_scale(&self, d: f64) -> i64 {
|
||||
if d == 0f64 {
|
||||
std::i64::MIN
|
||||
i64::MIN
|
||||
} else {
|
||||
(self.inv_log_base * d.ln()).ceil() as i64
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
/// &[4.6, 3.1, 1.5, 0.2],
|
||||
/// &[5.0, 3.6, 1.4, 0.2],
|
||||
/// &[5.4, 3.9, 1.7, 0.4],
|
||||
/// ]);
|
||||
/// ]).unwrap();
|
||||
/// let fastpair = FastPair::new(&x);
|
||||
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
|
||||
/// ```
|
||||
@@ -52,10 +52,8 @@ pub struct FastPair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
///
|
||||
/// Constructor
|
||||
/// Instantiate and inizialise the algorithm
|
||||
///
|
||||
/// Instantiate and initialize the algorithm
|
||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||
if m.shape().0 < 3 {
|
||||
return Err(Failed::because(
|
||||
@@ -74,10 +72,8 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
Ok(init)
|
||||
}
|
||||
|
||||
///
|
||||
/// Initialise `FastPair` by passing a `Array2`.
|
||||
/// Build a FastPairs data-structure from a set of (new) points.
|
||||
///
|
||||
fn init(&mut self) {
|
||||
// basic measures
|
||||
let len = self.samples.shape().0;
|
||||
@@ -158,9 +154,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> FastPair<'a, T, M> {
|
||||
self.neighbours = neighbours;
|
||||
}
|
||||
|
||||
///
|
||||
/// Find closest pair by scanning list of nearest neighbors.
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||
let mut a = self.neighbours[0]; // Start with first point
|
||||
@@ -232,10 +226,10 @@ mod tests_fastpair {
|
||||
use super::*;
|
||||
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
pub fn closest_pair_brute(fastpair: &FastPair<f64, DenseMatrix<f64>>) -> PairwiseDistance<f64> {
|
||||
pub fn closest_pair_brute(
|
||||
fastpair: &FastPair<'_, f64, DenseMatrix<f64>>,
|
||||
) -> PairwiseDistance<f64> {
|
||||
use itertools::Itertools;
|
||||
let m = fastpair.samples.shape().0;
|
||||
|
||||
@@ -286,7 +280,7 @@ mod tests_fastpair {
|
||||
fn dataset_has_at_least_three_points() {
|
||||
// Create a dataset which consists of only two points:
|
||||
// A(0.0, 0.0) and B(1.0, 1.0).
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]).unwrap();
|
||||
|
||||
// We expect an error when we run `FastPair` on this dataset,
|
||||
// because `FastPair` currently only works on a minimum of 3
|
||||
@@ -303,7 +297,7 @@ mod tests_fastpair {
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_minimal() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]).unwrap();
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
@@ -323,7 +317,8 @@ mod tests_fastpair {
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_2() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);
|
||||
let dataset =
|
||||
DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]).unwrap();
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
@@ -358,7 +353,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
@@ -531,7 +527,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
@@ -579,7 +576,8 @@ mod tests_fastpair {
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
@@ -624,7 +622,8 @@ mod tests_fastpair {
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let fastpair = FastPair::new(&x).unwrap();
|
||||
|
||||
let ordered = fastpair.ordered_pairs();
|
||||
@@ -640,4 +639,67 @@ mod tests_fastpair {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_set() {
|
||||
let empty_matrix = DenseMatrix::<f64>::zeros(0, 0);
|
||||
let result = FastPair::new(&empty_matrix);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_point() {
|
||||
let single_point = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
let result = FastPair::new(&single_point);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_two_points() {
|
||||
let two_points = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||
let result = FastPair::new(&two_points);
|
||||
assert!(result.is_err());
|
||||
if let Err(e) = result {
|
||||
assert_eq!(
|
||||
e,
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_three_identical_points() {
|
||||
let identical_points =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 1.0], &[1.0, 1.0], &[1.0, 1.0]]).unwrap();
|
||||
let result = FastPair::new(&identical_points);
|
||||
assert!(result.is_ok());
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
assert_eq!(closest_pair.distance, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_result_unwrapping() {
|
||||
let valid_matrix =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0], &[7.0, 8.0]])
|
||||
.unwrap();
|
||||
|
||||
let result = FastPair::new(&valid_matrix);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// This should not panic
|
||||
let _fastpair = result.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ impl<T, D: Distance<T>> LinearKNNSearch<T, D> {
|
||||
|
||||
for _ in 0..k {
|
||||
heap.add(KNNPoint {
|
||||
distance: std::f64::INFINITY,
|
||||
distance: f64::INFINITY,
|
||||
index: None,
|
||||
});
|
||||
}
|
||||
@@ -215,7 +215,7 @@ mod tests {
|
||||
};
|
||||
|
||||
let point_inf = KNNPoint {
|
||||
distance: std::f64::INFINITY,
|
||||
distance: f64::INFINITY,
|
||||
index: Some(3),
|
||||
};
|
||||
|
||||
|
||||
@@ -133,7 +133,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_add1() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
heap.add(std::f64::INFINITY);
|
||||
heap.add(f64::INFINITY);
|
||||
heap.add(-5f64);
|
||||
heap.add(4f64);
|
||||
heap.add(-1f64);
|
||||
@@ -151,7 +151,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_add2() {
|
||||
let mut heap = HeapSelection::with_capacity(3);
|
||||
heap.add(std::f64::INFINITY);
|
||||
heap.add(f64::INFINITY);
|
||||
heap.add(0.0);
|
||||
heap.add(8.4852);
|
||||
heap.add(5.6568);
|
||||
|
||||
@@ -3,6 +3,7 @@ use num_traits::Num;
|
||||
pub trait QuickArgSort {
|
||||
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn quick_argsort(&self) -> Vec<usize>;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
//! # Agglomerative Hierarchical Clustering
|
||||
//!
|
||||
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
|
||||
//!
|
||||
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
|
||||
//!
|
||||
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
|
||||
//!
|
||||
//! ## Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
|
||||
//!
|
||||
//! // A dataset with 2 distinct groups of points.
|
||||
//! let x = DenseMatrix::from_2d_array(&[
|
||||
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
|
||||
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! // Set parameters to find 2 clusters.
|
||||
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
|
||||
//!
|
||||
//! // Fit the model to the data.
|
||||
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
|
||||
//!
|
||||
//! // Get the cluster assignments.
|
||||
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//!
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::api::UnsupervisedEstimator;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::numbers::basenum::Number;
|
||||
|
||||
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
    /// The number of clusters to find.
    pub n_clusters: usize,
}

impl AgglomerativeClusteringParameters {
    /// Returns a copy of these parameters with the number of clusters replaced.
    ///
    /// The receiver is taken by value and the struct is `Copy`, so the value
    /// this is called on is never modified: the result must be used.
    ///
    /// # Arguments
    /// * `n_clusters` - The desired number of clusters.
    #[must_use = "this returns the updated parameters by value; the receiver is left unchanged"]
    pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
        self.n_clusters = n_clusters;
        self
    }
}

impl Default for AgglomerativeClusteringParameters {
    /// Default parameters: `n_clusters = 2`.
    fn default() -> Self {
        Self { n_clusters: 2 }
    }
}
|
||||
|
||||
/// Agglomerative Clustering model.
|
||||
///
|
||||
/// This implementation uses single-linkage clustering, which is mathematically
|
||||
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
|
||||
/// The core logic is an efficient implementation of Kruskal's algorithm, which
|
||||
/// processes all pairwise distances in increasing order and uses a Disjoint
|
||||
/// Set Union (DSU) data structure to track cluster membership.
|
||||
#[derive(Debug)]
|
||||
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
/// The cluster label assigned to each sample.
|
||||
pub labels: Vec<usize>,
|
||||
_phantom_tx: PhantomData<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
|
||||
/// Fits the agglomerative clustering model to the data.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `data` - A reference to the input data matrix.
|
||||
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing the fitted model with cluster labels, or an error if
|
||||
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
|
||||
let (num_samples, _) = data.shape();
|
||||
let n_clusters = parameters.n_clusters;
|
||||
if n_clusters > num_samples {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
&format!("n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"),
|
||||
));
|
||||
}
|
||||
|
||||
let mut distance_pairs = Vec::new();
|
||||
for i in 0..num_samples {
|
||||
for j in (i + 1)..num_samples {
|
||||
let distance: f64 = data
|
||||
.get_row(i)
|
||||
.iterator(0)
|
||||
.zip(data.get_row(j).iterator(0))
|
||||
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
|
||||
.sum::<f64>();
|
||||
|
||||
distance_pairs.push((distance, i, j));
|
||||
}
|
||||
}
|
||||
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
|
||||
let mut parent = HashMap::new();
|
||||
let mut children = HashMap::new();
|
||||
for i in 0..num_samples {
|
||||
parent.insert(i, i);
|
||||
children.insert(i, vec![i]);
|
||||
}
|
||||
|
||||
let mut merge_history = Vec::new();
|
||||
let num_merges_needed = num_samples - 1;
|
||||
|
||||
while merge_history.len() < num_merges_needed {
|
||||
let (_, p1, p2) = distance_pairs.pop().unwrap();
|
||||
|
||||
let root1 = parent[&p1];
|
||||
let root2 = parent[&p2];
|
||||
|
||||
if root1 != root2 {
|
||||
let root2_children = children.remove(&root2).unwrap();
|
||||
for child in root2_children.iter() {
|
||||
parent.insert(*child, root1);
|
||||
}
|
||||
let root1_children = children.get_mut(&root1).unwrap();
|
||||
root1_children.extend(root2_children);
|
||||
merge_history.push((root1, root2));
|
||||
}
|
||||
}
|
||||
|
||||
let mut clusters = HashMap::new();
|
||||
let mut assignments = HashMap::new();
|
||||
|
||||
for i in 0..num_samples {
|
||||
clusters.insert(i, vec![i]);
|
||||
assignments.insert(i, i);
|
||||
}
|
||||
|
||||
let merges_to_apply = num_samples - n_clusters;
|
||||
|
||||
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
|
||||
let root1_cluster = assignments[root1];
|
||||
let root2_cluster = assignments[root2];
|
||||
|
||||
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
|
||||
for assignment in root2_assignments.iter() {
|
||||
assignments.insert(*assignment, root1_cluster);
|
||||
}
|
||||
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
|
||||
root1_assignments.extend(root2_assignments);
|
||||
}
|
||||
|
||||
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
|
||||
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
|
||||
cluster_keys.sort();
|
||||
for (i, key) in cluster_keys.into_iter().enumerate() {
|
||||
for index in clusters[key].iter() {
|
||||
labels[*index] = i;
|
||||
}
|
||||
}
|
||||
Ok(AgglomerativeClustering {
|
||||
labels,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
|
||||
for AgglomerativeClustering<TX, TY, X, Y>
|
||||
{
|
||||
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
|
||||
AgglomerativeClustering::fit(x, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::linalg::basic::matrix::DenseMatrix;
    use std::collections::HashSet;

    use super::*;

    /// Concrete model type used by every test below.
    ///
    /// `TY`/`Y` only appear in `PhantomData` fields of the model, so `f64` /
    /// `Vec<f64>` is just one possible choice (the module-level example uses
    /// `usize` / `Vec<usize>`).
    type Model = AgglomerativeClustering<f64, f64, DenseMatrix<f64>, Vec<f64>>;

    /// Builds an `n_rows x 2` matrix from `values` and fits a model asking
    /// for `n_clusters` clusters.
    fn fit_points(n_rows: usize, values: Vec<f64>, n_clusters: usize) -> Result<Model, Failed> {
        let matrix = DenseMatrix::new(n_rows, 2, values, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(n_clusters);
        Model::fit(&matrix, parameters)
    }

    #[test]
    fn test_simple_clustering() {
        // Two well-separated groups of three points each.
        let values = vec![
            0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
            10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
        ];
        let labels = fit_points(6, values, 2).unwrap().labels;

        // Points 0..3 share one label.
        let first_group_label = labels[0];
        assert!(labels[0..3].iter().all(|&l| l == first_group_label));

        // Points 3..6 share one label.
        let second_group_label = labels[3];
        assert!(labels[3..6].iter().all(|&l| l == second_group_label));

        // ...and the two labels differ.
        assert_ne!(first_group_label, second_group_label);
    }

    #[test]
    fn test_four_clusters() {
        // Four pairs of points, one pair near each corner of a square.
        let values = vec![
            0.0, 0.0, 1.0, 1.0, // Cluster A
            100.0, 100.0, 101.0, 101.0, // Cluster B
            0.0, 100.0, 1.0, 101.0, // Cluster C
            100.0, 0.0, 101.0, 1.0, // Cluster D
        ];
        let labels = fit_points(8, values, 4).unwrap().labels;

        // Exactly four distinct labels are produced.
        let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
        assert_eq!(unique_labels.len(), 4);

        // Both points of each pair carry the same label.
        let group_labels = [labels[0], labels[2], labels[4], labels[6]];
        for (group, &label) in group_labels.iter().enumerate() {
            assert_eq!(label, labels[2 * group + 1]);
        }

        // Labels of different pairs are pairwise distinct.
        for a in 0..group_labels.len() {
            for b in (a + 1)..group_labels.len() {
                assert_ne!(group_labels[a], group_labels[b]);
            }
        }
    }

    #[test]
    fn test_n_clusters_equal_to_samples() {
        let values = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let mut labels = fit_points(3, values, 3).unwrap().labels;

        // Every point is its own cluster; sort so the comparison does not
        // depend on which point received which label.
        labels.sort();
        assert_eq!(labels, vec![0, 1, 2]);
    }

    #[test]
    fn test_one_cluster() {
        let values = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let clustering = fit_points(3, values, 1).unwrap();

        // A single cluster means every point gets label 0.
        assert_eq!(clustering.labels, vec![0, 0, 0]);
    }

    #[test]
    fn test_error_on_too_many_clusters() {
        // Two samples cannot be split into three clusters.
        let values = vec![0.0, 0.0, 5.0, 5.0];
        let result = fit_points(2, values, 3);

        assert!(result.is_err());
    }
}
|
||||
@@ -315,8 +315,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
}
|
||||
}
|
||||
|
||||
while !neighbors.is_empty() {
|
||||
let neighbor = neighbors.pop().unwrap();
|
||||
while let Some(neighbor) = neighbors.pop() {
|
||||
let index = neighbor.0;
|
||||
|
||||
if y[index] == outlier {
|
||||
@@ -443,7 +442,8 @@ mod tests {
|
||||
&[2.2, 1.2],
|
||||
&[1.8, 0.8],
|
||||
&[3.0, 5.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_labels = vec![1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0];
|
||||
|
||||
@@ -488,7 +488,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let dbscan = DBSCAN::fit(&x, Default::default()).unwrap();
|
||||
|
||||
|
||||
+12
-186
@@ -41,7 +41,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let kmeans = KMeans::fit(&x, KMeansParameters::default().with_k(2)).unwrap(); // Fit to data, 2 clusters
|
||||
//! let y_hat: Vec<u8> = kmeans.predict(&x).unwrap(); // use the same points for prediction
|
||||
@@ -62,7 +62,7 @@ use serde::{Deserialize, Serialize};
|
||||
use crate::algorithm::neighbour::bbd_tree::BBDTree;
|
||||
use crate::api::{Predictor, UnsupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, Array};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||
use crate::metrics::distance::euclidian::*;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
@@ -96,7 +96,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq for KMeans<
|
||||
return false;
|
||||
}
|
||||
for j in 0..self.centroids[i].len() {
|
||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > std::f64::EPSILON {
|
||||
if (self.centroids[i][j] - other.centroids[i][j]).abs() > f64::EPSILON {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -249,7 +249,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> Predictor<X, Y>
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y> {
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
pub fn fit(data: &X, parameters: KMeansParameters) -> Result<KMeans<TX, TY, X, Y>, Failed> {
|
||||
let bbd = BBDTree::new(data);
|
||||
@@ -270,7 +270,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = std::f64::MAX;
|
||||
let mut distortion = f64::MAX;
|
||||
let mut y = KMeans::<TX, TY, X, Y>::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![0f64; d]; parameters.k];
|
||||
@@ -322,109 +322,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
})
|
||||
}
|
||||
|
||||
/// Fit algorithm to _NxM_ matrix where _N_ is number of samples and _M_ is number of features.
|
||||
/// * `data` - training instances to cluster
|
||||
/// * `parameters` - cluster parameters
|
||||
/// * `centroids` - starting centroids
|
||||
pub fn fit_with_centroids(
|
||||
data: &X,
|
||||
parameters: KMeansParameters,
|
||||
centroids: Vec<Vec<f64>>,
|
||||
) -> Result<KMeans<TX, TY, X, Y>, Failed> {
|
||||
|
||||
// TODO: reuse existing methods in `crate::metrics`
|
||||
fn euclidean_distance(point1: &Vec<f64>, point2: &Vec<f64>) -> f64 {
|
||||
let mut dist = 0.0;
|
||||
for i in 0..point1.len() {
|
||||
dist += (point1[i] - point2[i]).powi(2);
|
||||
}
|
||||
dist.sqrt()
|
||||
}
|
||||
|
||||
fn closest_centroid(point: &Vec<f64>, centroids: &Vec<Vec<f64>>) -> usize {
|
||||
let mut closest_idx = 0;
|
||||
let mut closest_dist = std::f64::MAX;
|
||||
for (i, centroid) in centroids.iter().enumerate() {
|
||||
let dist = euclidean_distance(point, centroid);
|
||||
if dist < closest_dist {
|
||||
closest_dist = dist;
|
||||
closest_idx = i;
|
||||
}
|
||||
}
|
||||
closest_idx
|
||||
}
|
||||
|
||||
let bbd = BBDTree::new(data);
|
||||
|
||||
if centroids.len() != parameters.k {
|
||||
return Err(Failed::fit(&format!(
|
||||
"number of centroids ({}) must be equal to k ({})",
|
||||
centroids.len(),
|
||||
parameters.k
|
||||
)));
|
||||
}
|
||||
|
||||
let mut y = vec![0; data.shape().0];
|
||||
for i in 0..data.shape().0 {
|
||||
y[i] = closest_centroid(
|
||||
&Vec::from_iterator(data.get_row(i).iterator(0).map(|e| e.to_f64().unwrap()),
|
||||
data.shape().1), ¢roids
|
||||
);
|
||||
}
|
||||
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut new_centroids = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||
|
||||
for i in 0..data.shape().0 {
|
||||
size[y[i]] += 1;
|
||||
}
|
||||
|
||||
for i in 0..data.shape().0 {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[y[i]][j] += data.get((i, j)).to_f64().unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
for i in 0..parameters.k {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[i][j] /= size[i] as f64;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sums = vec![vec![0f64; data.shape().1]; parameters.k];
|
||||
let mut distortion = std::f64::MAX;
|
||||
|
||||
for _ in 1..=parameters.max_iter {
|
||||
let dist = bbd.clustering(&new_centroids, &mut sums, &mut size, &mut y);
|
||||
for i in 0..parameters.k {
|
||||
if size[i] > 0 {
|
||||
for j in 0..data.shape().1 {
|
||||
new_centroids[i][j] = sums[i][j] / size[i] as f64;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if distortion <= dist {
|
||||
break;
|
||||
} else {
|
||||
distortion = dist;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(KMeans {
|
||||
k: parameters.k,
|
||||
_y: y,
|
||||
size,
|
||||
_distortion: distortion,
|
||||
centroids: new_centroids,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Predict clusters for `x`
|
||||
/// * `x` - matrix with new data to transform of size _KxM_ , where _K_ is number of new samples and _M_ is number of features.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
@@ -434,7 +331,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
let mut row = vec![0f64; x.shape().1];
|
||||
|
||||
for i in 0..n {
|
||||
let mut min_dist = std::f64::MAX;
|
||||
let mut min_dist = f64::MAX;
|
||||
let mut best_cluster = 0;
|
||||
|
||||
for j in 0..self.k {
|
||||
@@ -464,7 +361,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut d = vec![std::f64::MAX; n];
|
||||
let mut d = vec![f64::MAX; n];
|
||||
let mut row = vec![TX::zero(); data.shape().1];
|
||||
|
||||
for j in 1..k {
|
||||
@@ -520,7 +417,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::algorithm::neighbour::fastpair;
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -528,7 +424,7 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn invalid_k() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
|
||||
|
||||
assert!(KMeans::<i32, i32, DenseMatrix<i32>, Vec<i32>>::fit(
|
||||
&x,
|
||||
@@ -596,7 +492,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let kmeans = KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
@@ -607,78 +504,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn fit_with_centroids_predict() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
|
||||
let parameters = KMeansParameters {
|
||||
k: 3,
|
||||
max_iter: 50,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// compute pairs
|
||||
let fastpair = fastpair::FastPair::new(&x).unwrap();
|
||||
|
||||
// compute centroids for N closest pairs
|
||||
let mut n: isize = 2;
|
||||
let mut centroids = vec![vec![0f64; x.shape().1]; n as usize + 1];
|
||||
for p in fastpair.ordered_pairs() {
|
||||
if n == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
centroids[n as usize] = {
|
||||
let mut result: Vec<f64> = Vec::with_capacity(x.shape().1);
|
||||
for val1 in x.get_row(p.node).iterator(0) {
|
||||
for val2 in x.get_row(p.neighbour.unwrap()).iterator(0) {
|
||||
let sum = val1 + val2;
|
||||
let avg = sum * 0.5f64;
|
||||
result.push(avg);
|
||||
}
|
||||
}
|
||||
result
|
||||
};
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
|
||||
|
||||
let kmeans = KMeans::fit_with_centroids(
|
||||
&x, parameters, centroids).unwrap();
|
||||
|
||||
let y: Vec<usize> = kmeans.predict(&x).unwrap();
|
||||
|
||||
for (i, _y_i) in y.iter().enumerate() {
|
||||
assert_eq!({ y[i] }, kmeans._y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -707,7 +532,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let kmeans: KMeans<f32, f32, DenseMatrix<f32>, Vec<f32>> =
|
||||
KMeans::fit(&x, Default::default()).unwrap();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
|
||||
|
||||
pub mod agglomerative;
|
||||
pub mod dbscan;
|
||||
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
|
||||
pub mod kmeans;
|
||||
|
||||
@@ -40,7 +40,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
feature_names: [
|
||||
"Age", "Sex", "BMI", "BP", "S1", "S2", "S3", "S4", "S5", "S6",
|
||||
]
|
||||
.iter()
|
||||
|
||||
@@ -25,16 +25,14 @@ pub fn load_dataset() -> Dataset<f32, f32> {
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
"sepal length (cm)",
|
||||
feature_names: ["sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
"petal length (cm)",
|
||||
"petal width (cm)",
|
||||
]
|
||||
"petal width (cm)"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
target_names: vec!["setosa", "versicolor", "virginica"]
|
||||
target_names: ["setosa", "versicolor", "virginica"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
|
||||
+2
-2
@@ -36,7 +36,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
target: y,
|
||||
num_samples,
|
||||
num_features,
|
||||
feature_names: vec![
|
||||
feature_names: [
|
||||
"sepal length (cm)",
|
||||
"sepal width (cm)",
|
||||
"petal length (cm)",
|
||||
@@ -45,7 +45,7 @@ pub fn load_dataset() -> Dataset<f32, u32> {
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
target_names: vec!["setosa", "versicolor", "virginica"]
|
||||
target_names: ["setosa", "versicolor", "virginica"]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let pca = PCA::fit(&iris, PCAParameters::default().with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
//!
|
||||
@@ -443,6 +443,7 @@ mod tests {
|
||||
&[2.6, 53.0, 66.0, 10.8],
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
])
|
||||
.unwrap()
|
||||
}
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -457,7 +458,8 @@ mod tests {
|
||||
&[0.9952, 0.0588],
|
||||
&[0.0463, 0.9769],
|
||||
&[0.0752, 0.2007],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let pca = PCA::fit(&us_arrests, Default::default()).unwrap();
|
||||
|
||||
@@ -500,7 +502,8 @@ mod tests {
|
||||
-0.974080592182491,
|
||||
0.0723250196376097,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_projection = DenseMatrix::from_2d_array(&[
|
||||
&[-64.8022, -11.448, 2.4949, -2.4079],
|
||||
@@ -553,7 +556,8 @@ mod tests {
|
||||
&[91.5446, -22.9529, 0.402, -0.7369],
|
||||
&[118.1763, 5.5076, 2.7113, -0.205],
|
||||
&[10.4345, -5.9245, 3.7944, 0.5179],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_eigenvalues: Vec<f64> = vec![
|
||||
343544.6277001563,
|
||||
@@ -616,7 +620,8 @@ mod tests {
|
||||
-0.0881962972508558,
|
||||
-0.0096011588898465,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_projection = DenseMatrix::from_2d_array(&[
|
||||
&[0.9856, -1.1334, 0.4443, -0.1563],
|
||||
@@ -669,7 +674,8 @@ mod tests {
|
||||
&[-2.1086, -1.4248, -0.1048, -0.1319],
|
||||
&[-2.0797, 0.6113, 0.1389, -0.1841],
|
||||
&[-0.6294, -0.321, 0.2407, 0.1667],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected_eigenvalues: Vec<f64> = vec![
|
||||
2.480241579149493,
|
||||
@@ -732,7 +738,7 @@ mod tests {
|
||||
// &[4.9, 2.4, 3.3, 1.0],
|
||||
// &[6.6, 2.9, 4.6, 1.3],
|
||||
// &[5.2, 2.7, 3.9, 1.4],
|
||||
// ]);
|
||||
// ]).unwrap();
|
||||
|
||||
// let pca = PCA::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let svd = SVD::fit(&iris, SVDParameters::default().
|
||||
//! with_n_components(2)).unwrap(); // Reduce number of features to 2
|
||||
@@ -292,7 +292,8 @@ mod tests {
|
||||
&[5.7, 81.0, 39.0, 9.3],
|
||||
&[2.6, 53.0, 66.0, 10.8],
|
||||
&[6.8, 161.0, 60.0, 15.6],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let expected = DenseMatrix::from_2d_array(&[
|
||||
&[243.54655757, -18.76673788],
|
||||
@@ -300,7 +301,8 @@ mod tests {
|
||||
&[305.93972467, -15.39087376],
|
||||
&[197.28420365, -11.66808306],
|
||||
&[293.43187394, 1.91163633],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let svd = SVD::fit(&x, Default::default()).unwrap();
|
||||
|
||||
let x_transformed = svd.transform(&x).unwrap();
|
||||
@@ -341,7 +343,7 @@ mod tests {
|
||||
// &[4.9, 2.4, 3.3, 1.0],
|
||||
// &[6.6, 2.9, 4.6, 1.3],
|
||||
// &[5.2, 2.7, 3.9, 1.4],
|
||||
// ]);
|
||||
// ]).unwrap();
|
||||
|
||||
// let svd = SVD::fit(&iris, Default::default()).unwrap();
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
@@ -660,7 +660,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
@@ -733,7 +734,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let classifier = RandomForestClassifier::fit(
|
||||
@@ -786,7 +788,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
|
||||
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||
@@ -574,7 +574,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -648,7 +649,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -702,7 +704,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
|
||||
@@ -32,6 +32,8 @@ pub enum FailedError {
|
||||
SolutionFailed,
|
||||
/// Error in input parameters
|
||||
ParametersError,
|
||||
/// Invalid state error (should never happen)
|
||||
InvalidStateError,
|
||||
}
|
||||
|
||||
impl Failed {
|
||||
@@ -64,6 +66,22 @@ impl Failed {
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `FailedError::ParametersError`
|
||||
pub fn input(msg: &str) -> Self {
|
||||
Failed {
|
||||
err: FailedError::ParametersError,
|
||||
msg: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `FailedError::InvalidStateError`
|
||||
pub fn invalid_state(msg: &str) -> Self {
|
||||
Failed {
|
||||
err: FailedError::InvalidStateError,
|
||||
msg: msg.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// new instance of `err`
|
||||
pub fn because(err: FailedError, msg: &str) -> Self {
|
||||
Failed {
|
||||
@@ -97,6 +115,7 @@ impl fmt::Display for FailedError {
|
||||
FailedError::DecompositionFailed => "Decomposition failed",
|
||||
FailedError::SolutionFailed => "Can't find solution",
|
||||
FailedError::ParametersError => "Error in input, check parameters",
|
||||
FailedError::InvalidStateError => "Invalid state, this should never happen", // useful in development phase of lib
|
||||
};
|
||||
write!(f, "{failed_err_str}")
|
||||
}
|
||||
|
||||
+1
-2
@@ -7,7 +7,6 @@
|
||||
clippy::approx_constant
|
||||
)]
|
||||
#![warn(missing_docs)]
|
||||
#![warn(rustdoc::missing_doc_code_examples)]
|
||||
|
||||
//! # smartcore
|
||||
//!
|
||||
@@ -64,7 +63,7 @@
|
||||
//! &[3., 4.],
|
||||
//! &[5., 6.],
|
||||
//! &[7., 8.],
|
||||
//! &[9., 10.]]);
|
||||
//! &[9., 10.]]).unwrap();
|
||||
//! // Our classes are defined as a vector
|
||||
//! let y = vec![2, 2, 2, 3, 3];
|
||||
//!
|
||||
|
||||
+226
-165
File diff suppressed because it is too large
Load Diff
+221
-92
@@ -19,6 +19,8 @@ use crate::linalg::traits::svd::SVDDecomposable;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
|
||||
use crate::error::Failed;
|
||||
|
||||
/// Dense matrix
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -50,26 +52,26 @@ pub struct DenseMatrixMutView<'a, T: Debug + Display + Copy + Sized> {
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
|
||||
fn new(m: &'a DenseMatrix<T>, rows: Range<usize>, cols: Range<usize>) -> Self {
|
||||
let (start, end, stride) = if m.column_major {
|
||||
(
|
||||
rows.start + cols.start * m.nrows,
|
||||
rows.end + (cols.end - 1) * m.nrows,
|
||||
m.nrows,
|
||||
)
|
||||
fn new(
|
||||
m: &'a DenseMatrix<T>,
|
||||
vrows: Range<usize>,
|
||||
vcols: Range<usize>,
|
||||
) -> Result<Self, Failed> {
|
||||
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
|
||||
Err(Failed::input(
|
||||
"The specified view is outside of the matrix range",
|
||||
))
|
||||
} else {
|
||||
(
|
||||
rows.start * m.ncols + cols.start,
|
||||
(rows.end - 1) * m.ncols + cols.end,
|
||||
m.ncols,
|
||||
)
|
||||
};
|
||||
DenseMatrixView {
|
||||
values: &m.values[start..end],
|
||||
stride,
|
||||
nrows: rows.end - rows.start,
|
||||
ncols: cols.end - cols.start,
|
||||
column_major: m.column_major,
|
||||
let (start, end, stride) =
|
||||
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
|
||||
|
||||
Ok(DenseMatrixView {
|
||||
values: &m.values[start..end],
|
||||
stride,
|
||||
nrows: vrows.end - vrows.start,
|
||||
ncols: vcols.end - vcols.start,
|
||||
column_major: m.column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +91,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixView<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
@@ -102,26 +104,26 @@ impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixView<'a,
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
|
||||
fn new(m: &'a mut DenseMatrix<T>, rows: Range<usize>, cols: Range<usize>) -> Self {
|
||||
let (start, end, stride) = if m.column_major {
|
||||
(
|
||||
rows.start + cols.start * m.nrows,
|
||||
rows.end + (cols.end - 1) * m.nrows,
|
||||
m.nrows,
|
||||
)
|
||||
fn new(
|
||||
m: &'a mut DenseMatrix<T>,
|
||||
vrows: Range<usize>,
|
||||
vcols: Range<usize>,
|
||||
) -> Result<Self, Failed> {
|
||||
if m.is_valid_view(m.shape().0, m.shape().1, &vrows, &vcols) {
|
||||
Err(Failed::input(
|
||||
"The specified view is outside of the matrix range",
|
||||
))
|
||||
} else {
|
||||
(
|
||||
rows.start * m.ncols + cols.start,
|
||||
(rows.end - 1) * m.ncols + cols.end,
|
||||
m.ncols,
|
||||
)
|
||||
};
|
||||
DenseMatrixMutView {
|
||||
values: &mut m.values[start..end],
|
||||
stride,
|
||||
nrows: rows.end - rows.start,
|
||||
ncols: cols.end - cols.start,
|
||||
column_major: m.column_major,
|
||||
let (start, end, stride) =
|
||||
m.stride_range(m.shape().0, m.shape().1, &vrows, &vcols, m.column_major);
|
||||
|
||||
Ok(DenseMatrixMutView {
|
||||
values: &mut m.values[start..end],
|
||||
stride,
|
||||
nrows: vrows.end - vrows.start,
|
||||
ncols: vcols.end - vcols.start,
|
||||
column_major: m.column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,7 +142,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &mut T> + 'b> {
|
||||
fn iter_mut<'b>(&'b mut self, axis: u8) -> Box<dyn Iterator<Item = &'b mut T> + 'b> {
|
||||
let column_major = self.column_major;
|
||||
let stride = self.stride;
|
||||
let ptr = self.values.as_mut_ptr();
|
||||
@@ -167,7 +169,7 @@ impl<'a, T: Debug + Display + Copy + Sized> DenseMatrixMutView<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<'_, T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
@@ -182,42 +184,102 @@ impl<'a, T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrixMutView<
|
||||
impl<T: Debug + Display + Copy + Sized> DenseMatrix<T> {
|
||||
/// Create new instance of `DenseMatrix` without copying data.
|
||||
/// `values` must be laid out in column-major order when `column_major` is `true`, and in row-major order otherwise.
|
||||
pub fn new(nrows: usize, ncols: usize, values: Vec<T>, column_major: bool) -> Self {
|
||||
DenseMatrix {
|
||||
ncols,
|
||||
nrows,
|
||||
values,
|
||||
column_major,
|
||||
pub fn new(
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
values: Vec<T>,
|
||||
column_major: bool,
|
||||
) -> Result<Self, Failed> {
|
||||
let data_len = values.len();
|
||||
if nrows * ncols != values.len() {
|
||||
Err(Failed::input(&format!(
|
||||
"The specified shape: (cols: {ncols}, rows: {nrows}) does not align with data len: {data_len}"
|
||||
)))
|
||||
} else {
|
||||
Ok(DenseMatrix {
|
||||
ncols,
|
||||
nrows,
|
||||
values,
|
||||
column_major,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d array.
|
||||
pub fn from_2d_array(values: &[&[T]]) -> Self {
|
||||
pub fn from_2d_array(values: &[&[T]]) -> Result<Self, Failed> {
|
||||
DenseMatrix::from_2d_vec(&values.iter().map(|row| Vec::from(*row)).collect())
|
||||
}
|
||||
|
||||
/// New instance of `DenseMatrix` from 2d vector.
|
||||
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Self {
|
||||
let nrows = values.len();
|
||||
let ncols = values
|
||||
.first()
|
||||
.unwrap_or_else(|| panic!("Cannot create 2d matrix from an empty vector"))
|
||||
.len();
|
||||
let mut m_values = Vec::with_capacity(nrows * ncols);
|
||||
#[allow(clippy::ptr_arg)]
|
||||
pub fn from_2d_vec(values: &Vec<Vec<T>>) -> Result<Self, Failed> {
|
||||
if values.is_empty() || values[0].is_empty() {
|
||||
Err(Failed::input(
|
||||
"The 2d vec provided is empty; cannot instantiate the matrix",
|
||||
))
|
||||
} else {
|
||||
let nrows = values.len();
|
||||
let ncols = values
|
||||
.first()
|
||||
.unwrap_or_else(|| {
|
||||
panic!("Invalid state: Cannot create 2d matrix from an empty vector")
|
||||
})
|
||||
.len();
|
||||
let mut m_values = Vec::with_capacity(nrows * ncols);
|
||||
|
||||
for c in 0..ncols {
|
||||
for r in values.iter().take(nrows) {
|
||||
m_values.push(r[c])
|
||||
for c in 0..ncols {
|
||||
for r in values.iter().take(nrows) {
|
||||
m_values.push(r[c])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DenseMatrix::new(nrows, ncols, m_values, true)
|
||||
DenseMatrix::new(nrows, ncols, m_values, true)
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterate over values of matrix
|
||||
pub fn iter(&self) -> Iter<'_, T> {
|
||||
self.values.iter()
|
||||
}
|
||||
|
||||
/// Check if the size of the requested view is bounded to matrix rows/cols count
|
||||
fn is_valid_view(
|
||||
&self,
|
||||
n_rows: usize,
|
||||
n_cols: usize,
|
||||
vrows: &Range<usize>,
|
||||
vcols: &Range<usize>,
|
||||
) -> bool {
|
||||
!(vrows.end <= n_rows
|
||||
&& vcols.end <= n_cols
|
||||
&& vrows.start <= n_rows
|
||||
&& vcols.start <= n_cols)
|
||||
}
|
||||
|
||||
/// Compute the range of the requested view: start, end, size of the slice
|
||||
fn stride_range(
|
||||
&self,
|
||||
n_rows: usize,
|
||||
n_cols: usize,
|
||||
vrows: &Range<usize>,
|
||||
vcols: &Range<usize>,
|
||||
column_major: bool,
|
||||
) -> (usize, usize, usize) {
|
||||
let (start, end, stride) = if column_major {
|
||||
(
|
||||
vrows.start + vcols.start * n_rows,
|
||||
vrows.end + (vcols.end - 1) * n_rows,
|
||||
n_rows,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
vrows.start * n_cols + vcols.start,
|
||||
(vrows.end - 1) * n_cols + vcols.end,
|
||||
n_cols,
|
||||
)
|
||||
};
|
||||
(start, end, stride)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> fmt::Display for DenseMatrix<T> {
|
||||
@@ -304,6 +366,7 @@ where
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix<T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
let (row, col) = pos;
|
||||
|
||||
if row >= self.nrows || col >= self.ncols {
|
||||
panic!(
|
||||
"Invalid index ({},{}) for {}x{} matrix",
|
||||
@@ -383,15 +446,15 @@ impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
|
||||
fn get_row<'a>(&'a self, row: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols))
|
||||
Box::new(DenseMatrixView::new(self, row..row + 1, 0..self.ncols).unwrap())
|
||||
}
|
||||
|
||||
fn get_col<'a>(&'a self, col: usize) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1))
|
||||
Box::new(DenseMatrixView::new(self, 0..self.nrows, col..col + 1).unwrap())
|
||||
}
|
||||
|
||||
fn slice<'a>(&'a self, rows: Range<usize>, cols: Range<usize>) -> Box<dyn ArrayView2<T> + 'a> {
|
||||
Box::new(DenseMatrixView::new(self, rows, cols))
|
||||
Box::new(DenseMatrixView::new(self, rows, cols).unwrap())
|
||||
}
|
||||
|
||||
fn slice_mut<'a>(
|
||||
@@ -402,15 +465,17 @@ impl<T: Debug + Display + Copy + Sized> Array2<T> for DenseMatrix<T> {
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Box::new(DenseMatrixMutView::new(self, rows, cols))
|
||||
Box::new(DenseMatrixMutView::new(self, rows, cols).unwrap())
|
||||
}
|
||||
|
||||
// private function so for now assume infallible
|
||||
fn fill(nrows: usize, ncols: usize, value: T) -> Self {
|
||||
DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true)
|
||||
DenseMatrix::new(nrows, ncols, vec![value; nrows * ncols], true).unwrap()
|
||||
}
|
||||
|
||||
// private function so for now assume infallible
|
||||
fn from_iterator<I: Iterator<Item = T>>(iter: I, nrows: usize, ncols: usize, axis: u8) -> Self {
|
||||
DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0)
|
||||
DenseMatrix::new(nrows, ncols, iter.collect(), axis != 0).unwrap()
|
||||
}
|
||||
|
||||
fn transpose(&self) -> Self {
|
||||
@@ -428,7 +493,7 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> LUDecomposable<T> for DenseMatrix<T> {}
|
||||
impl<T: Number + RealNumber> SVDDecomposable<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixView<'_, T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
if self.column_major {
|
||||
&self.values[pos.0 + pos.1 * self.stride]
|
||||
@@ -450,7 +515,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
if self.nrows == 1 {
|
||||
if self.column_major {
|
||||
@@ -488,11 +553,11 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for DenseMatrixView<
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixView<'_, T> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for DenseMatrixView<'_, T> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
if self.column_major {
|
||||
&self.values[pos.0 + pos.1 * self.stride]
|
||||
@@ -514,9 +579,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMa
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
||||
for DenseMatrixMutView<'a, T>
|
||||
{
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for DenseMatrixMutView<'_, T> {
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
if self.column_major {
|
||||
self.values[pos.0 + pos.1 * self.stride] = x;
|
||||
@@ -530,29 +593,90 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for DenseMatrixMutView<'_, T> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for DenseMatrixMutView<'_, T> {}
|
||||
|
||||
impl<T: RealNumber> MatrixStats<T> for DenseMatrix<T> {}
|
||||
|
||||
impl<T: RealNumber> MatrixPreprocessing<T> for DenseMatrix<T> {}
|
||||
|
||||
#[cfg(test)]
|
||||
#[warn(clippy::reversed_empty_ranges)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use approx::relative_eq;
|
||||
|
||||
#[test]
|
||||
fn test_display() {
|
||||
fn test_instantiate_from_2d() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
assert!(x.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_from_2d_empty() {
|
||||
let input: &[&[f64]] = &[&[]];
|
||||
let x = DenseMatrix::from_2d_array(input);
|
||||
assert!(x.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_from_2d_empty2() {
|
||||
let input: &[&[f64]] = &[&[], &[]];
|
||||
let x = DenseMatrix::from_2d_array(input);
|
||||
assert!(x.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view1() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..2, 0..2);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view2() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view3() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 2..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_ok_view4() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 3..3, 0..3);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view1() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 3..4, 0..3);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view2() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let v = DenseMatrixView::new(&x, 0..3, 3..4);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_instantiate_err_view3() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
#[allow(clippy::reversed_empty_ranges)]
|
||||
let v = DenseMatrixView::new(&x, 0..3, 4..3);
|
||||
assert!(v.is_err());
|
||||
}
|
||||
#[test]
|
||||
fn test_display() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
|
||||
println!("{}", &x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_row_col() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
|
||||
assert_eq!(15.0, x.get_col(1).sum());
|
||||
assert_eq!(15.0, x.get_row(1).sum());
|
||||
@@ -561,7 +685,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_row_major() {
|
||||
let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false);
|
||||
let mut x = DenseMatrix::new(2, 3, vec![1, 2, 3, 4, 5, 6], false).unwrap();
|
||||
|
||||
assert_eq!(5, *x.get_col(1).get(1));
|
||||
assert_eq!(7, x.get_col(1).sum());
|
||||
@@ -575,7 +699,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_slice() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
vec![4, 5, 6],
|
||||
@@ -589,7 +714,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_iter_mut() {
|
||||
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]);
|
||||
let mut x = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
|
||||
|
||||
assert_eq!(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], x.values);
|
||||
// add +2 to some elements
|
||||
@@ -625,7 +750,8 @@ mod tests {
|
||||
#[test]
|
||||
fn test_str_array() {
|
||||
let mut x =
|
||||
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]]);
|
||||
DenseMatrix::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"], &["7", "8", "9"]])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(vec!["1", "4", "7", "2", "5", "8", "3", "6", "9"], x.values);
|
||||
x.iterator_mut(0).for_each(|v| *v = "str");
|
||||
@@ -637,7 +763,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_transpose() {
|
||||
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]);
|
||||
let x = DenseMatrix::<&str>::from_2d_array(&[&["1", "2", "3"], &["4", "5", "6"]]).unwrap();
|
||||
|
||||
assert_eq!(vec!["1", "4", "2", "5", "3", "6"], x.values);
|
||||
assert!(x.column_major);
|
||||
@@ -650,7 +776,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_from_iterator() {
|
||||
let data = vec![1, 2, 3, 4, 5, 6];
|
||||
let data = [1, 2, 3, 4, 5, 6];
|
||||
|
||||
let m = DenseMatrix::from_iterator(data.iter(), 2, 3, 0);
|
||||
|
||||
@@ -664,8 +790,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_take() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6]]).unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[1, 2], &[3, 4], &[5, 6]]).unwrap();
|
||||
|
||||
println!("{a}");
|
||||
// take column 0 and 2
|
||||
@@ -677,7 +803,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.3, -2.1, 3.4], &[-4., -5.3, 6.1]]).unwrap();
|
||||
|
||||
let a = a.abs();
|
||||
assert_eq!(vec![1.3, 4.0, 2.1, 5.3, 3.4, 6.1], a.values);
|
||||
@@ -688,7 +814,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_reshape() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9], &[10, 11, 12]])
|
||||
.unwrap();
|
||||
|
||||
let a = a.reshape(2, 6, 0);
|
||||
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], a.values);
|
||||
@@ -701,13 +828,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_eq() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.]]).unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||
let c = DenseMatrix::from_2d_array(&[
|
||||
&[1. + f32::EPSILON, 2., 3.],
|
||||
&[4., 5., 6. + f32::EPSILON],
|
||||
]);
|
||||
let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]]);
|
||||
])
|
||||
.unwrap();
|
||||
let d = DenseMatrix::from_2d_array(&[&[1. + 0.5, 2., 3.], &[4., 5., 6. + f32::EPSILON]])
|
||||
.unwrap();
|
||||
|
||||
assert!(!relative_eq!(a, b));
|
||||
assert!(!relative_eq!(a, d));
|
||||
|
||||
@@ -55,6 +55,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, usize> for Vec<T> {
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for Vec<T> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
// NOTE: this panics in case of out of bounds index
|
||||
self[i] = x
|
||||
}
|
||||
|
||||
@@ -118,7 +119,7 @@ impl<T: Debug + Display + Copy + Sized> Array1<T> for Vec<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self.ptr[i]
|
||||
}
|
||||
@@ -137,7 +138,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecMutView<'a, T
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'_, T> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
self.ptr[i] = x;
|
||||
}
|
||||
@@ -148,10 +149,10 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for VecMutView<'a
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'a, T> {}
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecMutView<'_, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for VecMutView<'_, T> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
|
||||
impl<T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'_, T> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self.ptr[i]
|
||||
}
|
||||
@@ -170,7 +171,7 @@ impl<'a, T: Debug + Display + Copy + Sized> Array<T, usize> for VecView<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'a, T> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for VecView<'_, T> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@@ -211,7 +212,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_len() {
|
||||
let x = vec![1, 2, 3];
|
||||
let x = [1, 2, 3];
|
||||
assert_eq!(3, x.len());
|
||||
}
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayBase<OwnedRepr<T>
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'a, T, Ix2> {
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayView<'_, T, Ix2> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
&self[[pos.0, pos.1]]
|
||||
}
|
||||
@@ -144,11 +144,9 @@ impl<T: Number + RealNumber> EVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2>
|
||||
impl<T: Number + RealNumber> LUDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
impl<T: Number + RealNumber> SVDDecomposable<T> for ArrayBase<OwnedRepr<T>, Ix2> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'a, T, Ix2> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayView<'_, T, Ix2> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
|
||||
for ArrayViewMut<'a, T, Ix2>
|
||||
{
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||
fn get(&self, pos: (usize, usize)) -> &T {
|
||||
&self[[pos.0, pos.1]]
|
||||
}
|
||||
@@ -175,9 +173,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, (usize, usize)>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
||||
for ArrayViewMut<'a, T, Ix2>
|
||||
{
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)> for ArrayViewMut<'_, T, Ix2> {
|
||||
fn set(&mut self, pos: (usize, usize), x: T) {
|
||||
self[[pos.0, pos.1]] = x
|
||||
}
|
||||
@@ -195,9 +191,9 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, (usize, usize)>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'a, T, Ix2> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView2<T> for ArrayViewMut<'_, T, Ix2> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -41,7 +41,7 @@ impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayBase<OwnedRepr<T>
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayBase<OwnedRepr<T>, Ix1> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a, T, Ix1> {
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'_, T, Ix1> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
@@ -60,9 +60,9 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayView<'a
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'a, T, Ix1> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayView<'_, T, Ix1> {}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
|
||||
impl<T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
|
||||
fn get(&self, i: usize) -> &T {
|
||||
&self[i]
|
||||
}
|
||||
@@ -81,7 +81,7 @@ impl<'a, T: Debug + Display + Copy + Sized> BaseArray<T, usize> for ArrayViewMut
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'a, T, Ix1> {
|
||||
impl<T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<'_, T, Ix1> {
|
||||
fn set(&mut self, i: usize, x: T) {
|
||||
self[i] = x;
|
||||
}
|
||||
@@ -92,8 +92,8 @@ impl<'a, T: Debug + Display + Copy + Sized> MutArray<T, usize> for ArrayViewMut<
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
|
||||
impl<'a, T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'a, T, Ix1> {}
|
||||
impl<T: Debug + Display + Copy + Sized> ArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
|
||||
impl<T: Debug + Display + Copy + Sized> MutArrayView1<T> for ArrayViewMut<'_, T, Ix1> {}
|
||||
|
||||
impl<T: Debug + Display + Copy + Sized> Array1<T> for ArrayBase<OwnedRepr<T>, Ix1> {
|
||||
fn slice<'a>(&'a self, range: Range<usize>) -> Box<dyn ArrayView1<T> + 'a> {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
//! &[25., 15., -5.],
|
||||
//! &[15., 18., 0.],
|
||||
//! &[-5., 0., 11.]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let cholesky = A.cholesky().unwrap();
|
||||
//! let lower_triangular: DenseMatrix<f64> = cholesky.L();
|
||||
@@ -175,11 +175,14 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn cholesky_decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let l =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]]);
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 0.0, 0.0], &[3.0, 3.0, 0.0], &[-1.0, 1.0, 3.0]])
|
||||
.unwrap();
|
||||
let u =
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]]);
|
||||
DenseMatrix::from_2d_array(&[&[5.0, 3.0, -1.0], &[0.0, 3.0, 1.0], &[0.0, 0.0, 3.0]])
|
||||
.unwrap();
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
assert!(relative_eq!(cholesky.L().abs(), l.abs(), epsilon = 1e-4));
|
||||
@@ -197,9 +200,10 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn cholesky_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]);
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[40., 51., 28.]]).unwrap();
|
||||
let expected = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0]]).unwrap();
|
||||
|
||||
let cholesky = a.cholesky().unwrap();
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
//! &[0.9000, 0.4000, 0.7000],
|
||||
//! &[0.4000, 0.5000, 0.3000],
|
||||
//! &[0.7000, 0.3000, 0.8000],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let evd = A.evd(true).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
@@ -820,7 +820,8 @@ mod tests {
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.7000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||
|
||||
@@ -828,7 +829,8 @@ mod tests {
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(true).unwrap();
|
||||
|
||||
@@ -839,7 +841,7 @@ mod tests {
|
||||
));
|
||||
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
|
||||
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(
|
||||
@@ -852,7 +854,8 @@ mod tests {
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.8000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values: Vec<f64> = vec![1.79171122, 0.31908143, 0.08920735];
|
||||
|
||||
@@ -860,7 +863,8 @@ mod tests {
|
||||
&[0.7178958, 0.05322098, 0.6812010],
|
||||
&[0.3837711, -0.84702111, -0.1494582],
|
||||
&[0.6952105, 0.43984484, -0.7036135],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
@@ -871,7 +875,7 @@ mod tests {
|
||||
));
|
||||
for (i, eigen_values_i) in eigen_values.iter().enumerate() {
|
||||
assert!((eigen_values_i - evd.d[i]).abs() < 1e-4);
|
||||
assert!((0f64 - evd.e[i]).abs() < std::f64::EPSILON);
|
||||
assert!((0f64 - evd.e[i]).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
#[cfg_attr(
|
||||
@@ -885,7 +889,8 @@ mod tests {
|
||||
&[4.0, -1.0, 1.0, 1.0],
|
||||
&[1.0, 1.0, 3.0, -2.0],
|
||||
&[1.0, 1.0, 4.0, -1.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let eigen_values_d: Vec<f64> = vec![0.0, 2.0, 2.0, 0.0];
|
||||
let eigen_values_e: Vec<f64> = vec![2.2361, 0.9999, -0.9999, -2.2361];
|
||||
@@ -895,7 +900,8 @@ mod tests {
|
||||
&[-0.6707, 0.1059, 0.901, 0.6289],
|
||||
&[0.9159, -0.1378, 0.3816, 0.0806],
|
||||
&[0.6707, 0.1059, 0.901, -0.6289],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let evd = A.evd(false).unwrap();
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ pub trait HighOrderOperations<T: Number>: Array2<T> {
|
||||
/// use smartcore::linalg::traits::high_order::HighOrderOperations;
|
||||
/// use smartcore::linalg::basic::arrays::Array2;
|
||||
///
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]);
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]);
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.]]).unwrap();
|
||||
/// let b = DenseMatrix::from_2d_array(&[&[5., 6.], &[7., 8.], &[9., 10.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[71., 80.], &[92., 104.]]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.ab(true, &b, false), expected);
|
||||
/// ```
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
//! &[1., 2., 3.],
|
||||
//! &[0., 1., 5.],
|
||||
//! &[5., 6., 0.]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let lu = A.lu().unwrap();
|
||||
//! let lower: DenseMatrix<f64> = lu.L();
|
||||
@@ -263,13 +263,13 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
|
||||
let expected_L =
|
||||
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 0., 0.], &[0., 1., 0.], &[0.2, 0.8, 1.]]).unwrap();
|
||||
let expected_U =
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]);
|
||||
DenseMatrix::from_2d_array(&[&[5., 6., 0.], &[0., 1., 5.], &[0., 0., -1.]]).unwrap();
|
||||
let expected_pivot =
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]);
|
||||
DenseMatrix::from_2d_array(&[&[0., 0., 1.], &[0., 1., 0.], &[1., 0., 0.]]).unwrap();
|
||||
let lu = a.lu().unwrap();
|
||||
assert!(relative_eq!(lu.L(), expected_L, epsilon = 1e-4));
|
||||
assert!(relative_eq!(lu.U(), expected_U, epsilon = 1e-4));
|
||||
@@ -281,9 +281,10 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn inverse() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[0., 1., 5.], &[5., 6., 0.]]).unwrap();
|
||||
let expected =
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]]);
|
||||
DenseMatrix::from_2d_array(&[&[-6.0, 3.6, 1.4], &[5.0, -3.0, -1.0], &[-1.0, 0.8, 0.2]])
|
||||
.unwrap();
|
||||
let a_inv = a.lu().and_then(|lu| lu.inverse()).unwrap();
|
||||
assert!(relative_eq!(a_inv, expected, epsilon = 1e-4));
|
||||
}
|
||||
|
||||
+12
-7
@@ -13,7 +13,7 @@
|
||||
//! &[0.9, 0.4, 0.7],
|
||||
//! &[0.4, 0.5, 0.3],
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let qr = A.qr().unwrap();
|
||||
//! let orthogonal: DenseMatrix<f64> = qr.Q();
|
||||
@@ -201,17 +201,20 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn decompose() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let q = DenseMatrix::from_2d_array(&[
|
||||
&[-0.7448, 0.2436, 0.6212],
|
||||
&[-0.331, -0.9432, -0.027],
|
||||
&[-0.5793, 0.2257, -0.7832],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let r = DenseMatrix::from_2d_array(&[
|
||||
&[-1.2083, -0.6373, -1.0842],
|
||||
&[0.0, -0.3064, 0.0682],
|
||||
&[0.0, 0.0, -0.1999],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let qr = a.qr().unwrap();
|
||||
assert!(relative_eq!(qr.Q().abs(), q.abs(), epsilon = 1e-4));
|
||||
assert!(relative_eq!(qr.R().abs(), r.abs(), epsilon = 1e-4));
|
||||
@@ -223,13 +226,15 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn qr_solve_mut() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
|
||||
let expected_w = DenseMatrix::from_2d_array(&[
|
||||
&[-0.2027027, -1.2837838],
|
||||
&[0.8783784, 2.2297297],
|
||||
&[0.4729730, 0.6621622],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let w = a.qr_solve_mut(b).unwrap();
|
||||
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
|
||||
}
|
||||
|
||||
+17
-14
@@ -136,13 +136,12 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
|
||||
/// ```rust
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
|
||||
/// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]);
|
||||
/// let mut a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
|
||||
/// a.binarize_mut(0.);
|
||||
///
|
||||
/// assert_eq!(a, expected);
|
||||
/// ```
|
||||
|
||||
fn binarize_mut(&mut self, threshold: T) {
|
||||
let (nrows, ncols) = self.shape();
|
||||
for row in 0..nrows {
|
||||
@@ -159,8 +158,8 @@ pub trait MatrixPreprocessing<T: RealNumber>: MutArrayView2<T> + Clone {
|
||||
/// ```rust
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::traits::stats::MatrixPreprocessing;
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]);
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]);
|
||||
/// let a = DenseMatrix::from_2d_array(&[&[0., 2., 3.], &[-5., -6., -7.]]).unwrap();
|
||||
/// let expected = DenseMatrix::from_2d_array(&[&[0., 1., 1.],&[0., 0., 0.]]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.binarize(0.), expected);
|
||||
/// ```
|
||||
@@ -186,7 +185,8 @@ mod tests {
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![4., 5., 6., 3., 4.];
|
||||
let expected_1 = vec![1.8, 4.4, 7.];
|
||||
|
||||
@@ -196,7 +196,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_var() {
|
||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
||||
let m = DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
|
||||
let expected_0 = vec![4., 4., 4., 4.];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
@@ -211,12 +211,13 @@ mod tests {
|
||||
let m = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
|
||||
&[0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
|
||||
let expected_1 = vec![1.25, 1.25];
|
||||
|
||||
assert!(m.var(0).approximate_eq(&expected_0, std::f64::EPSILON));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, std::f64::EPSILON));
|
||||
assert!(m.var(0).approximate_eq(&expected_0, f64::EPSILON));
|
||||
assert!(m.var(1).approximate_eq(&expected_1, f64::EPSILON));
|
||||
assert_eq!(
|
||||
m.mean(0),
|
||||
vec![0.0, 0.25, 0.25, 1.25, 1.5, 1.75, 2.75, 3.25]
|
||||
@@ -230,7 +231,8 @@ mod tests {
|
||||
&[1., 2., 3., 1., 2.],
|
||||
&[4., 5., 6., 3., 4.],
|
||||
&[7., 8., 9., 5., 6.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let expected_0 = vec![
|
||||
2.449489742783178,
|
||||
2.449489742783178,
|
||||
@@ -251,10 +253,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_scale() {
|
||||
let m: DenseMatrix<f64> =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2., 3., 4.], &[5., 6., 7., 8.]]).unwrap();
|
||||
|
||||
let expected_0: DenseMatrix<f64> =
|
||||
DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]);
|
||||
DenseMatrix::from_2d_array(&[&[-1., -1., -1., -1.], &[1., 1., 1., 1.]]).unwrap();
|
||||
let expected_1: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[
|
||||
-1.3416407864998738,
|
||||
@@ -268,7 +270,8 @@ mod tests {
|
||||
0.4472135954999579,
|
||||
1.3416407864998738,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(m.mean(0), vec![3.0, 4.0, 5.0, 6.0]);
|
||||
assert_eq!(m.mean(1), vec![2.5, 6.5]);
|
||||
|
||||
+20
-14
@@ -17,7 +17,7 @@
|
||||
//! &[0.9, 0.4, 0.7],
|
||||
//! &[0.4, 0.5, 0.3],
|
||||
//! &[0.7, 0.3, 0.8]
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let svd = A.svd().unwrap();
|
||||
//! let u: DenseMatrix<f64> = svd.U;
|
||||
@@ -48,11 +48,9 @@ pub struct SVD<T: Number + RealNumber, M: SVDDecomposable<T>> {
|
||||
pub V: M,
|
||||
/// Singular values of the original matrix
|
||||
pub s: Vec<T>,
|
||||
///
|
||||
m: usize,
|
||||
///
|
||||
n: usize,
|
||||
///
|
||||
/// Tolerance
|
||||
tol: T,
|
||||
}
|
||||
|
||||
@@ -489,7 +487,8 @@ mod tests {
|
||||
&[0.9000, 0.4000, 0.7000],
|
||||
&[0.4000, 0.5000, 0.3000],
|
||||
&[0.7000, 0.3000, 0.8000],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let s: Vec<f64> = vec![1.7498382, 0.3165784, 0.1335834];
|
||||
|
||||
@@ -497,13 +496,15 @@ mod tests {
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.639158],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let V = DenseMatrix::from_2d_array(&[
|
||||
&[0.6881997, -0.07121225, 0.7220180],
|
||||
&[0.3700456, 0.89044952, -0.2648886],
|
||||
&[0.6240573, -0.44947578, -0.6391588],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
@@ -577,7 +578,8 @@ mod tests {
|
||||
-0.2158704,
|
||||
-0.27529472,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let s: Vec<f64> = vec![
|
||||
3.8589375, 3.4396766, 2.6487176, 2.2317399, 1.5165054, 0.8109055, 0.2706515,
|
||||
@@ -647,7 +649,8 @@ mod tests {
|
||||
0.73034065,
|
||||
-0.43965505,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let V = DenseMatrix::from_2d_array(&[
|
||||
&[
|
||||
@@ -707,7 +710,8 @@ mod tests {
|
||||
0.1654796,
|
||||
-0.32346758,
|
||||
],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let svd = A.svd().unwrap();
|
||||
|
||||
@@ -723,10 +727,11 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn solve() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]]);
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[0.9, 0.4, 0.7], &[0.4, 0.5, 0.3], &[0.7, 0.3, 0.8]])
|
||||
.unwrap();
|
||||
let b = DenseMatrix::from_2d_array(&[&[0.5, 0.2], &[0.5, 0.8], &[0.5, 0.3]]).unwrap();
|
||||
let expected_w =
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]);
|
||||
DenseMatrix::from_2d_array(&[&[-0.20, -1.28], &[0.87, 2.22], &[0.47, 0.66]]).unwrap();
|
||||
let w = a.svd_solve_mut(b).unwrap();
|
||||
assert!(relative_eq!(w, expected_w, epsilon = 1e-2));
|
||||
}
|
||||
@@ -737,7 +742,8 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn decompose_restore() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]);
|
||||
let a =
|
||||
DenseMatrix::from_2d_array(&[&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]]).unwrap();
|
||||
let svd = a.svd().unwrap();
|
||||
let u: &DenseMatrix<f32> = &svd.U; //U
|
||||
let v: &DenseMatrix<f32> = &svd.V; // V
|
||||
|
||||
@@ -12,7 +12,8 @@
|
||||
//! pub struct BGSolver {}
|
||||
//! impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X> for BGSolver {}
|
||||
//!
|
||||
//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
//! let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0.,
|
||||
//! 11.]]).unwrap();
|
||||
//! let b = vec![40., 51., 28.];
|
||||
//! let expected = vec![1.0, 2.0, 3.0];
|
||||
//! let mut x = Vec::zeros(3);
|
||||
@@ -26,9 +27,9 @@ use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2, ArrayView1, MutArrayView1};
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
///
|
||||
/// Trait for Biconjugate Gradient Solver
|
||||
pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
|
||||
///
|
||||
/// Solve Ax = b
|
||||
fn solve_mut(
|
||||
&self,
|
||||
a: &'a X,
|
||||
@@ -108,7 +109,7 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
|
||||
Ok(err)
|
||||
}
|
||||
|
||||
///
|
||||
/// solve preconditioner
|
||||
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
|
||||
let diag = Self::diag(a);
|
||||
let n = diag.len();
|
||||
@@ -132,7 +133,7 @@ pub trait BiconjugateGradientSolver<'a, T: FloatNumber, X: Array2<T>> {
|
||||
y.copy_from(&x.xa(true, a));
|
||||
}
|
||||
|
||||
///
|
||||
/// Extract the diagonal from a matrix
|
||||
fn diag(a: &X) -> Vec<T> {
|
||||
let (nrows, ncols) = a.shape();
|
||||
let n = nrows.min(ncols);
|
||||
@@ -158,9 +159,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn bg_solver() {
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]]);
|
||||
let a = DenseMatrix::from_2d_array(&[&[25., 15., -5.], &[15., 18., 0.], &[-5., 0., 11.]])
|
||||
.unwrap();
|
||||
let b = vec![40., 51., 28.];
|
||||
let expected = vec![1.0, 2.0, 3.0];
|
||||
let expected = [1.0, 2.0, 3.0];
|
||||
|
||||
let mut x = Vec::zeros(3);
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -511,7 +511,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -562,7 +563,8 @@ mod tests {
|
||||
&[17.0, 1918.0, 1.4054969025700674],
|
||||
&[18.0, 1929.0, 1.3271699396384906],
|
||||
&[19.0, 1915.0, 1.1373332337674806],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
1.48, 2.72, 4.52, 5.72, 5.25, 4.07, 3.75, 4.75, 6.77, 4.72, 6.78, 6.79, 8.3, 7.42,
|
||||
@@ -627,7 +629,7 @@ mod tests {
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]);
|
||||
// ]).unwrap();
|
||||
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
|
||||
+2
-1
@@ -418,7 +418,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
|
||||
@@ -16,7 +16,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1, MutArray, MutArra
|
||||
use crate::linear::bg_solver::BiconjugateGradientSolver;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
|
||||
///
|
||||
/// Interior Point Optimizer
|
||||
pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
|
||||
ata: X,
|
||||
d1: Vec<T>,
|
||||
@@ -25,9 +25,8 @@ pub struct InteriorPointOptimizer<T: FloatNumber, X: Array2<T>> {
|
||||
prs: Vec<T>,
|
||||
}
|
||||
|
||||
///
|
||||
impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
///
|
||||
/// Initialize a new Interior Point Optimizer
|
||||
pub fn new(a: &X, n: usize) -> InteriorPointOptimizer<T, X> {
|
||||
InteriorPointOptimizer {
|
||||
ata: a.ab(true, a, false),
|
||||
@@ -38,7 +37,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Run the optimization
|
||||
pub fn optimize(
|
||||
&mut self,
|
||||
x: &X,
|
||||
@@ -101,7 +100,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
|
||||
// CALCULATE DUALITY GAP
|
||||
let xnu = nu.xa(false, x);
|
||||
let max_xnu = xnu.norm(std::f64::INFINITY);
|
||||
let max_xnu = xnu.norm(f64::INFINITY);
|
||||
if max_xnu > lambda_f64 {
|
||||
let lnu = T::from_f64(lambda_f64 / max_xnu).unwrap();
|
||||
nu.mul_scalar_mut(lnu);
|
||||
@@ -208,7 +207,6 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
Ok(w)
|
||||
}
|
||||
|
||||
///
|
||||
fn sumlogneg(f: &X) -> T {
|
||||
let (n, _) = f.shape();
|
||||
let mut sum = T::zero();
|
||||
@@ -220,11 +218,9 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
|
||||
for InteriorPointOptimizer<T, X>
|
||||
{
|
||||
///
|
||||
fn solve_preconditioner(&self, a: &'a X, b: &[T], x: &mut [T]) {
|
||||
let (_, p) = a.shape();
|
||||
|
||||
@@ -234,7 +230,6 @@ impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
fn mat_vec_mul(&self, _: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
let (_, p) = self.ata.shape();
|
||||
let x_slice = Vec::from_slice(x.slice(0..p).as_ref());
|
||||
@@ -246,7 +241,6 @@ impl<'a, T: FloatNumber, X: Array2<T>> BiconjugateGradientSolver<'a, T, X>
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
fn mat_t_vec_mul(&self, a: &X, x: &Vec<T>, y: &mut Vec<T>) {
|
||||
self.mat_vec_mul(a, x, y);
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -341,7 +341,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8,
|
||||
@@ -393,7 +394,7 @@ mod tests {
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]);
|
||||
// ]).unwrap();
|
||||
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<i32> = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
@@ -183,14 +183,11 @@ pub struct LogisticRegression<
|
||||
}
|
||||
|
||||
trait ObjectiveFunction<T: Number + FloatNumber, X: Array2<T>> {
|
||||
///
|
||||
fn f(&self, w_bias: &[T]) -> T;
|
||||
|
||||
///
|
||||
#[allow(clippy::ptr_arg)]
|
||||
fn df(&self, g: &mut Vec<T>, w_bias: &Vec<T>);
|
||||
|
||||
///
|
||||
#[allow(clippy::ptr_arg)]
|
||||
fn partial_dot(w: &[T], x: &X, v_col: usize, m_row: usize) -> T {
|
||||
let mut sum = T::zero();
|
||||
@@ -261,8 +258,8 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
|
||||
for BinaryObjectiveFunction<'a, T, X>
|
||||
impl<T: Number + FloatNumber, X: Array2<T>> ObjectiveFunction<T, X>
|
||||
for BinaryObjectiveFunction<'_, T, X>
|
||||
{
|
||||
fn f(&self, w_bias: &[T]) -> T {
|
||||
let mut f = T::zero();
|
||||
@@ -316,8 +313,8 @@ struct MultiClassObjectiveFunction<'a, T: Number + FloatNumber, X: Array2<T>> {
|
||||
_phantom_t: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
|
||||
for MultiClassObjectiveFunction<'a, T, X>
|
||||
impl<T: Number + FloatNumber + RealNumber, X: Array2<T>> ObjectiveFunction<T, X>
|
||||
for MultiClassObjectiveFunction<'_, T, X>
|
||||
{
|
||||
fn f(&self, w_bias: &[T]) -> T {
|
||||
let mut f = T::zero();
|
||||
@@ -416,7 +413,7 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
|
||||
/// Fits Logistic Regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - target class values
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
/// * `parameters` - other parameters, use `Default::default()` to set parameters to default values.
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
@@ -611,7 +608,8 @@ mod tests {
|
||||
&[10., -2.],
|
||||
&[8., 2.],
|
||||
&[9., 0.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
|
||||
|
||||
@@ -628,11 +626,11 @@ mod tests {
|
||||
objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||
objective.df(&mut g, &vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||
|
||||
assert!((g[0] + 33.000068218163484).abs() < std::f64::EPSILON);
|
||||
assert!((g[0] + 33.000068218163484).abs() < f64::EPSILON);
|
||||
|
||||
let f = objective.f(&[1., 2., 3., 4., 5., 6., 7., 8., 9.]);
|
||||
|
||||
assert!((f - 408.0052230582765).abs() < std::f64::EPSILON);
|
||||
assert!((f - 408.0052230582765).abs() < f64::EPSILON);
|
||||
|
||||
let objective_reg = MultiClassObjectiveFunction {
|
||||
x: &x,
|
||||
@@ -671,7 +669,8 @@ mod tests {
|
||||
&[10., -2.],
|
||||
&[8., 2.],
|
||||
&[9., 0.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y = vec![0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1];
|
||||
|
||||
@@ -687,13 +686,13 @@ mod tests {
|
||||
objective.df(&mut g, &vec![1., 2., 3.]);
|
||||
objective.df(&mut g, &vec![1., 2., 3.]);
|
||||
|
||||
assert!((g[0] - 26.051064349381285).abs() < std::f64::EPSILON);
|
||||
assert!((g[1] - 10.239000702928523).abs() < std::f64::EPSILON);
|
||||
assert!((g[2] - 3.869294270156324).abs() < std::f64::EPSILON);
|
||||
assert!((g[0] - 26.051064349381285).abs() < f64::EPSILON);
|
||||
assert!((g[1] - 10.239000702928523).abs() < f64::EPSILON);
|
||||
assert!((g[2] - 3.869294270156324).abs() < f64::EPSILON);
|
||||
|
||||
let f = objective.f(&[1., 2., 3.]);
|
||||
|
||||
assert!((f - 59.76994756647412).abs() < std::f64::EPSILON);
|
||||
assert!((f - 59.76994756647412).abs() < f64::EPSILON);
|
||||
|
||||
let objective_reg = BinaryObjectiveFunction {
|
||||
x: &x,
|
||||
@@ -733,7 +732,8 @@ mod tests {
|
||||
&[10., -2.],
|
||||
&[8., 2.],
|
||||
&[9., 0.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -818,37 +818,41 @@ mod tests {
|
||||
assert!(reg_coeff_sum < coeff);
|
||||
}
|
||||
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[1., -5.],
|
||||
// &[2., 5.],
|
||||
// &[3., -2.],
|
||||
// &[1., 2.],
|
||||
// &[2., 0.],
|
||||
// &[6., -5.],
|
||||
// &[7., 5.],
|
||||
// &[6., -2.],
|
||||
// &[7., 2.],
|
||||
// &[6., 0.],
|
||||
// &[8., -5.],
|
||||
// &[9., 5.],
|
||||
// &[10., -2.],
|
||||
// &[8., 2.],
|
||||
// &[9., 0.],
|
||||
// ]);
|
||||
// let y: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];
|
||||
//TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "serde")]
fn serde() {
    // Fifteen two-feature observations labelled with three classes (0, 1, 2).
    let features: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1., -5.],
        &[2., 5.],
        &[3., -2.],
        &[1., 2.],
        &[2., 0.],
        &[6., -5.],
        &[7., 5.],
        &[6., -2.],
        &[7., 2.],
        &[6., 0.],
        &[8., -5.],
        &[9., 5.],
        &[10., -2.],
        &[8., 2.],
        &[9., 0.],
    ])
    .unwrap();
    let labels: Vec<i32> = vec![0, 0, 1, 1, 2, 1, 1, 0, 0, 2, 1, 1, 0, 0, 1];

    let fitted = LogisticRegression::fit(&features, &labels, Default::default()).unwrap();

    // Round-trip the fitted model through JSON; the restored model must compare equal.
    let json = serde_json::to_string(&fitted).unwrap();
    let restored: LogisticRegression<f64, i32, DenseMatrix<f64>, Vec<i32>> =
        serde_json::from_str(&json).unwrap();

    assert_eq!(fitted, restored);
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -877,7 +881,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let lr = LogisticRegression::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -890,11 +895,7 @@ mod tests {
|
||||
|
||||
let y_hat = lr.predict(&x).unwrap();
|
||||
|
||||
let error: i32 = y
|
||||
.into_iter()
|
||||
.zip(y_hat.into_iter())
|
||||
.map(|(a, b)| (a - b).abs())
|
||||
.sum();
|
||||
let error: i32 = y.into_iter().zip(y_hat).map(|(a, b)| (a - b).abs()).sum();
|
||||
|
||||
assert!(error <= 1);
|
||||
|
||||
@@ -903,4 +904,46 @@ mod tests {
|
||||
|
||||
assert!(reg_coeff_sum < coeff);
|
||||
}
|
||||
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn lr_fit_predict_random() {
    // Random 52181 x 94 feature matrix with an imbalanced binary target:
    // 2181 ones followed by 50000 zeros.
    let features: DenseMatrix<f32> = DenseMatrix::rand(52181, 94);
    let mut labels: Vec<i32> = vec![1; 2181];
    labels.resize(2181 + 50000, 0);

    let plain = LogisticRegression::fit(&features, &labels, Default::default()).unwrap();
    let regularized = LogisticRegression::fit(
        &features,
        &labels,
        LogisticRegressionParameters::default().with_alpha(1.0),
    )
    .unwrap();

    let plain_predictions = plain.predict(&features).unwrap();
    let regularized_predictions = regularized.predict(&features).unwrap();

    // Only the output length is checked: one prediction per input row.
    assert_eq!(labels.len(), plain_predictions.len());
    assert_eq!(labels.len(), regularized_predictions.len());
}
|
||||
|
||||
#[test]
fn test_logit() {
    // Random 52181 x 94 feature matrix; target is 2181 ones followed by 50000 zeros.
    let features: DenseMatrix<f64> = DenseMatrix::rand(52181, 94);
    let mut labels: Vec<u32> = vec![1; 2181];
    labels.resize(2181 + 50000, 0);
    println!("y vec height: {:?}", labels.len());
    println!("x matrix shape: {:?}", features.shape());

    let model = LogisticRegression::fit(&features, &labels, Default::default()).unwrap();
    let predicted = model.predict(&features).unwrap();

    println!("y_hat shape: {:?}", predicted.shape());

    // One prediction per input row.
    assert_eq!(predicted.shape(), 52181);
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
@@ -455,7 +455,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -513,7 +514,7 @@ mod tests {
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]);
|
||||
// ]).unwrap();
|
||||
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
//! &[68., 590., 37.],
|
||||
//! &[69., 660., 46.],
|
||||
//! &[73., 600., 55.],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let a = data.mean_by(0);
|
||||
//! let b = vec![66., 640., 44.];
|
||||
@@ -151,7 +151,8 @@ mod tests {
|
||||
&[68., 590., 37.],
|
||||
&[69., 660., 46.],
|
||||
&[73., 600., 55.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let a = data.mean_by(0);
|
||||
let b = vec![66., 640., 44.];
|
||||
|
||||
+1
-1
@@ -37,7 +37,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<i8> = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
use crate::{
|
||||
api::{Predictor, SupervisedEstimator},
|
||||
error::{Failed, FailedError},
|
||||
linalg::basic::arrays::{Array2, Array1},
|
||||
numbers::realnum::RealNumber,
|
||||
linalg::basic::arrays::{Array1, Array2},
|
||||
numbers::basenum::Number,
|
||||
numbers::realnum::RealNumber,
|
||||
};
|
||||
|
||||
use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
@@ -84,7 +84,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<i32> = vec![
|
||||
//! 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
//! ];
|
||||
@@ -396,7 +396,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let cv = KFold {
|
||||
@@ -441,7 +442,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -489,7 +491,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -539,7 +542,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let cv = KFold::default().with_n_splits(3);
|
||||
|
||||
@@ -19,14 +19,14 @@
|
||||
//! &[0, 1, 0, 0, 1, 0],
|
||||
//! &[0, 1, 0, 1, 0, 0],
|
||||
//! &[0, 1, 1, 0, 0, 1],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<u32> = vec![0, 0, 0, 1];
|
||||
//!
|
||||
//! let nb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Testing data point is:
|
||||
//! // Chinese Chinese Chinese Tokyo Japan
|
||||
//! let x_test = DenseMatrix::from_2d_array(&[&[0, 1, 1, 0, 0, 1]]);
|
||||
//! let x_test = DenseMatrix::from_2d_array(&[&[0, 1, 1, 0, 0, 1]]).unwrap();
|
||||
//! let y_hat = nb.predict(&x_test).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
@@ -257,8 +257,7 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||
/// * `binarize` - Threshold for binarizing.
|
||||
fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>(
|
||||
@@ -402,10 +401,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
|
||||
{
|
||||
/// Fits BernoulliNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors, alpha for smoothing and
|
||||
/// binarizing threshold.
|
||||
/// binarizing threshold.
|
||||
pub fn fit(x: &X, y: &Y, parameters: BernoulliNBParameters<TX>) -> Result<Self, Failed> {
|
||||
let distribution = if let Some(threshold) = parameters.binarize {
|
||||
BernoulliNBDistribution::fit(
|
||||
@@ -427,6 +426,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
if let Some(threshold) = self.binarize {
|
||||
@@ -527,7 +527,8 @@ mod tests {
|
||||
&[0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
|
||||
&[0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
|
||||
&[0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 0, 1];
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -558,7 +559,7 @@ mod tests {
|
||||
|
||||
// Testing data point is:
|
||||
// Chinese Chinese Chinese Tokyo Japan
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[0.0, 1.0, 1.0, 0.0, 0.0, 1.0]]);
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[0.0, 1.0, 1.0, 0.0, 0.0, 1.0]]).unwrap();
|
||||
let y_hat = bnb.predict(&x_test).unwrap();
|
||||
|
||||
assert_eq!(y_hat, &[1]);
|
||||
@@ -586,7 +587,8 @@ mod tests {
|
||||
&[2, 0, 3, 3, 1, 2, 0, 2, 4, 1],
|
||||
&[2, 4, 0, 4, 2, 4, 1, 3, 1, 4],
|
||||
&[0, 2, 2, 3, 4, 0, 4, 4, 4, 4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 2];
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -643,7 +645,8 @@ mod tests {
|
||||
&[0, 1, 0, 0, 1, 0],
|
||||
&[0, 1, 0, 1, 0, 0],
|
||||
&[0, 1, 1, 0, 0, 1],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 0, 1];
|
||||
|
||||
let bnb = BernoulliNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
//! &[3, 4, 2, 4],
|
||||
//! &[0, 3, 1, 2],
|
||||
//! &[0, 4, 1, 2],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
|
||||
//!
|
||||
//! let nb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -95,7 +95,7 @@ impl<T: Number + Unsigned> PartialEq for CategoricalNBDistribution<T> {
|
||||
return false;
|
||||
}
|
||||
for (a_i_j, b_i_j) in a_i.iter().zip(b_i.iter()) {
|
||||
if (*a_i_j - *b_i_j).abs() > std::f64::EPSILON {
|
||||
if (*a_i_j - *b_i_j).abs() > f64::EPSILON {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -363,7 +363,7 @@ impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> Predictor<X, Y> for Categ
|
||||
impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> CategoricalNB<T, X, Y> {
|
||||
/// Fits CategoricalNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like alpha for smoothing
|
||||
pub fn fit(x: &X, y: &Y, parameters: CategoricalNBParameters) -> Result<Self, Failed> {
|
||||
@@ -375,6 +375,7 @@ impl<T: Number + Unsigned, X: Array2<T>, Y: Array1<T>> CategoricalNB<T, X, Y> {
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.inner.as_ref().unwrap().predict(x)
|
||||
@@ -455,7 +456,8 @@ mod tests {
|
||||
&[1, 1, 1, 1],
|
||||
&[1, 2, 0, 0],
|
||||
&[2, 1, 1, 1],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
|
||||
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -513,7 +515,7 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[0, 2, 1, 0], &[2, 2, 0, 0]]);
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[0, 2, 1, 0], &[2, 2, 0, 0]]).unwrap();
|
||||
let y_hat = cnb.predict(&x_test).unwrap();
|
||||
assert_eq!(y_hat, vec![0, 1]);
|
||||
}
|
||||
@@ -539,7 +541,8 @@ mod tests {
|
||||
&[3, 4, 2, 4],
|
||||
&[0, 3, 1, 2],
|
||||
&[0, 4, 1, 2],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
|
||||
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -571,7 +574,8 @@ mod tests {
|
||||
&[3, 4, 2, 4],
|
||||
&[0, 3, 1, 2],
|
||||
&[0, 4, 1, 2],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<u32> = vec![0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0];
|
||||
let cnb = CategoricalNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
//! &[ 1., 1.],
|
||||
//! &[ 2., 1.],
|
||||
//! &[ 3., 2.],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
|
||||
//!
|
||||
//! let nb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -174,8 +174,7 @@ impl<TY: Number + Ord + Unsigned> GaussianNBDistribution<TY> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||
pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>(
|
||||
x: &X,
|
||||
y: &Y,
|
||||
@@ -317,7 +316,7 @@ impl<TX: Number + RealNumber, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
|
||||
{
|
||||
/// Fits GaussianNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors.
|
||||
pub fn fit(x: &X, y: &Y, parameters: GaussianNBParameters) -> Result<Self, Failed> {
|
||||
@@ -328,6 +327,7 @@ impl<TX: Number + RealNumber, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Arr
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.inner.as_ref().unwrap().predict(x)
|
||||
@@ -395,7 +395,8 @@ mod tests {
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
|
||||
|
||||
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -435,7 +436,8 @@ mod tests {
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
|
||||
|
||||
let priors = vec![0.3, 0.7];
|
||||
@@ -462,7 +464,8 @@ mod tests {
|
||||
&[1., 1.],
|
||||
&[2., 1.],
|
||||
&[3., 2.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![1, 1, 1, 2, 2, 2];
|
||||
|
||||
let gnb = GaussianNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
+532
-20
@@ -89,33 +89,545 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let y_classes = self.distribution.classes();
|
||||
|
||||
if y_classes.is_empty() {
|
||||
return Err(Failed::predict("Failed to predict, no classes available"));
|
||||
}
|
||||
|
||||
let (rows, _) = x.shape();
|
||||
let predictions = (0..rows)
|
||||
.map(|row_index| {
|
||||
let row = x.get_row(row_index);
|
||||
let (prediction, _probability) = y_classes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(class_index, class)| {
|
||||
(
|
||||
class,
|
||||
self.distribution.log_likelihood(class_index, &row)
|
||||
+ self.distribution.prior(class_index).ln(),
|
||||
)
|
||||
})
|
||||
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
|
||||
.unwrap();
|
||||
*prediction
|
||||
})
|
||||
.collect::<Vec<TY>>();
|
||||
let y_hat = Y::from_vec_slice(&predictions);
|
||||
Ok(y_hat)
|
||||
let mut predictions = Vec::with_capacity(rows);
|
||||
let mut all_probs_nan = true;
|
||||
|
||||
for row_index in 0..rows {
|
||||
let row = x.get_row(row_index);
|
||||
let mut max_log_prob = f64::NEG_INFINITY;
|
||||
let mut max_class = None;
|
||||
|
||||
for (class_index, class) in y_classes.iter().enumerate() {
|
||||
let log_likelihood = self.distribution.log_likelihood(class_index, &row);
|
||||
let log_prob = log_likelihood + self.distribution.prior(class_index).ln();
|
||||
|
||||
if !log_prob.is_nan() && log_prob > max_log_prob {
|
||||
max_log_prob = log_prob;
|
||||
max_class = Some(*class);
|
||||
all_probs_nan = false;
|
||||
}
|
||||
}
|
||||
|
||||
predictions.push(max_class.unwrap_or(y_classes[0]));
|
||||
}
|
||||
|
||||
if all_probs_nan {
|
||||
Err(Failed::predict(
|
||||
"Failed to predict, all probabilities were NaN",
|
||||
))
|
||||
} else {
|
||||
Ok(Y::from_vec_slice(&predictions))
|
||||
}
|
||||
}
|
||||
}
|
||||
pub mod bernoulli;
|
||||
pub mod categorical;
|
||||
pub mod gaussian;
|
||||
pub mod multinomial;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use num_traits::float::Float;
|
||||
|
||||
type Model<'d> = BaseNaiveBayes<i32, i32, DenseMatrix<i32>, Vec<i32>, TestDistribution<'d>>;
|
||||
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
struct TestDistribution<'d>(&'d Vec<i32>);
|
||||
|
||||
impl NBDistribution<i32, i32> for TestDistribution<'_> {
|
||||
fn prior(&self, _class_index: usize) -> f64 {
|
||||
1.
|
||||
}
|
||||
|
||||
fn log_likelihood<'a>(
|
||||
&'a self,
|
||||
class_index: usize,
|
||||
_j: &'a Box<dyn ArrayView1<i32> + 'a>,
|
||||
) -> f64 {
|
||||
match self.0.get(class_index) {
|
||||
&v @ 2 | &v @ 10 | &v @ 20 => v as f64,
|
||||
_ => f64::nan(),
|
||||
}
|
||||
}
|
||||
|
||||
fn classes(&self) -> &Vec<i32> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn test_predict() {
    let matrix = DenseMatrix::from_2d_array(&[&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]).unwrap();
    // Builds a model over the given class list and predicts the shared matrix.
    let predict_with =
        |classes: &Vec<i32>| Model::fit(TestDistribution(classes)).unwrap().predict(&matrix);

    // No classes at all: prediction must fail with the dedicated message.
    let no_classes = vec![];
    let err = predict_with(&no_classes)
        .err()
        .unwrap_or_else(|| panic!("Should return error in case of empty classes"));
    assert_eq!(
        err.to_string(),
        "Predict failed: Failed to predict, no classes available"
    );

    // Classes 1 and 3 score NaN, so class 2 is the only usable candidate.
    let partly_nan = vec![1, 2, 3];
    let labels = predict_with(&partly_nan)
        .unwrap_or_else(|_| panic!("Should success in normal case with NaNs"));
    assert_eq!(labels, vec![2, 2, 2]);

    // Every class has a finite score; the largest one (20) wins.
    let all_finite = vec![20, 2, 10];
    let labels = predict_with(&all_finite)
        .unwrap_or_else(|_| panic!("Should success in normal case without NaNs"));
    assert_eq!(labels, vec![20, 20, 20]);
}
|
||||
|
||||
// A simple test distribution using float
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
struct TestDistributionAgain {
|
||||
classes: Vec<u32>,
|
||||
probs: Vec<f64>,
|
||||
}
|
||||
|
||||
impl NBDistribution<f64, u32> for TestDistributionAgain {
|
||||
fn classes(&self) -> &Vec<u32> {
|
||||
&self.classes
|
||||
}
|
||||
fn prior(&self, class_index: usize) -> f64 {
|
||||
self.probs[class_index]
|
||||
}
|
||||
fn log_likelihood<'a>(
|
||||
&'a self,
|
||||
class_index: usize,
|
||||
_j: &'a Box<dyn ArrayView1<f64> + 'a>,
|
||||
) -> f64 {
|
||||
self.probs[class_index].ln()
|
||||
}
|
||||
}
|
||||
|
||||
/// `BaseNaiveBayes` instantiated with `f64` features, `u32` labels and `TestDistributionAgain`.
type TestNB = BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, TestDistributionAgain>;
|
||||
|
||||
/// A distribution without any classes must make `predict` return an error.
#[test]
fn test_predict_empty_classes() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: Vec::new(),
        probs: Vec::new(),
    })
    .unwrap();

    assert!(model.predict(&samples).is_err());
}
|
||||
|
||||
/// With exactly one class available, every sample is expected to be assigned to it.
#[test]
fn test_predict_single_class() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1],
        probs: vec![1.0],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![1, 1]);
}
|
||||
|
||||
/// Three classes with probabilities 0.2 / 0.5 / 0.3: the test expects class 2 (the
/// one with probability 0.5) for every sample.
#[test]
fn test_predict_multiple_classes() {
    let samples =
        DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2, 3],
        probs: vec![0.2, 0.5, 0.3],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![2, 2, 2]);
}
|
||||
|
||||
/// One class has a NaN probability; the test expects the remaining class (2) to be
/// predicted for every sample.
#[test]
fn test_predict_with_nans() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2],
        probs: vec![f64::NAN, 0.5],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![2, 2]);
}
|
||||
|
||||
/// When every class has a NaN probability, `predict` is expected to return an error.
#[test]
fn test_predict_all_nans() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2],
        probs: vec![f64::NAN, f64::NAN],
    })
    .unwrap();

    assert!(model.predict(&samples).is_err());
}
|
||||
|
||||
/// Very small probabilities (1e-300 vs 1e-301): the test expects class 1, the one with
/// the larger probability, for every sample.
#[test]
fn test_predict_extreme_probabilities() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2],
        probs: vec![1e-300, 1e-301],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![1, 1]);
}
|
||||
|
||||
/// The first class has an infinite probability; the test expects that class (1) to be
/// predicted for every sample.
#[test]
fn test_predict_with_infinity() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2, 3],
        probs: vec![f64::INFINITY, 1.0, 2.0],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![1, 1]);
}
|
||||
|
||||
/// The first class has a probability of negative infinity; the test expects class 3,
/// the one with the largest finite probability, for every sample.
#[test]
fn test_predict_with_negative_infinity() {
    let samples = DenseMatrix::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
    let model = TestNB::fit(TestDistributionAgain {
        classes: vec![1, 2, 3],
        probs: vec![f64::NEG_INFINITY, 1.0, 2.0],
    })
    .unwrap();

    assert_eq!(model.predict(&samples).unwrap(), vec![3, 3]);
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_naive_bayes_numerical_stability() {
|
||||
/// Hand-rolled Gaussian distribution parameters used only by this test.
#[derive(Clone, Debug, PartialEq)]
struct GaussianTestDistribution {
    /// Class labels.
    classes: Vec<u32>,
    /// Feature means, indexed as `means[class_index][feature_index]`.
    means: Vec<Vec<f64>>,
    /// Feature variances, indexed as `variances[class_index][feature_index]`.
    variances: Vec<Vec<f64>>,
    /// Prior of each class, indexed by class index.
    priors: Vec<f64>,
}
|
||||
|
||||
impl NBDistribution<f64, u32> for GaussianTestDistribution {
|
||||
fn classes(&self) -> &Vec<u32> {
|
||||
&self.classes
|
||||
}
|
||||
|
||||
fn prior(&self, class_index: usize) -> f64 {
|
||||
self.priors[class_index]
|
||||
}
|
||||
|
||||
fn log_likelihood<'a>(
|
||||
&'a self,
|
||||
class_index: usize,
|
||||
j: &'a Box<dyn ArrayView1<f64> + 'a>,
|
||||
) -> f64 {
|
||||
let means = &self.means[class_index];
|
||||
let variances = &self.variances[class_index];
|
||||
j.iterator(0)
|
||||
.enumerate()
|
||||
.map(|(i, &xi)| {
|
||||
let mean = means[i];
|
||||
let var = variances[i] + 1e-9; // Small smoothing for numerical stability
|
||||
let coeff = -0.5 * (2.0 * std::f64::consts::PI * var).ln();
|
||||
let exponent = -(xi - mean).powi(2) / (2.0 * var);
|
||||
coeff + exponent
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
fn train_distribution(x: &DenseMatrix<f64>, y: &[u32]) -> GaussianTestDistribution {
|
||||
let mut classes: Vec<u32> = y
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<u32>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
classes.sort();
|
||||
let n_classes = classes.len();
|
||||
let n_features = x.shape().1;
|
||||
|
||||
let mut means = vec![vec![0.0; n_features]; n_classes];
|
||||
let mut variances = vec![vec![0.0; n_features]; n_classes];
|
||||
let mut class_counts = vec![0; n_classes];
|
||||
|
||||
// Calculate means and count samples per class
|
||||
for (sample, &class) in x.row_iter().zip(y.iter()) {
|
||||
let class_idx = classes.iter().position(|&c| c == class).unwrap();
|
||||
class_counts[class_idx] += 1;
|
||||
for (i, &value) in sample.iterator(0).enumerate() {
|
||||
means[class_idx][i] += value;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize means
|
||||
for (class_idx, mean) in means.iter_mut().enumerate() {
|
||||
for value in mean.iter_mut() {
|
||||
*value /= class_counts[class_idx] as f64;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate variances
|
||||
for (sample, &class) in x.row_iter().zip(y.iter()) {
|
||||
let class_idx = classes.iter().position(|&c| c == class).unwrap();
|
||||
for (i, &value) in sample.iterator(0).enumerate() {
|
||||
let diff = value - means[class_idx][i];
|
||||
variances[class_idx][i] += diff * diff;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize variances and add small epsilon to avoid zero variance
|
||||
let epsilon = 1e-9;
|
||||
for (class_idx, variance) in variances.iter_mut().enumerate() {
|
||||
for value in variance.iter_mut() {
|
||||
*value = *value / class_counts[class_idx] as f64 + epsilon;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate priors
|
||||
let total_samples = y.len() as f64;
|
||||
let priors: Vec<f64> = class_counts
|
||||
.iter()
|
||||
.map(|&count| count as f64 / total_samples)
|
||||
.collect();
|
||||
|
||||
GaussianTestDistribution {
|
||||
classes,
|
||||
means,
|
||||
variances,
|
||||
priors,
|
||||
}
|
||||
}
|
||||
|
||||
type TestNBGaussian =
|
||||
BaseNaiveBayes<f64, u32, DenseMatrix<f64>, Vec<u32>, GaussianTestDistribution>;
|
||||
|
||||
// Create a constant training dataset
|
||||
let n_samples = 1000;
|
||||
let n_features = 5;
|
||||
let n_classes = 4;
|
||||
|
||||
let mut x_data = Vec::with_capacity(n_samples * n_features);
|
||||
let mut y_data = Vec::with_capacity(n_samples);
|
||||
|
||||
for i in 0..n_samples {
|
||||
for j in 0..n_features {
|
||||
x_data.push((i * j) as f64 % 10.0);
|
||||
}
|
||||
y_data.push((i % n_classes) as u32);
|
||||
}
|
||||
|
||||
let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
|
||||
let y = y_data;
|
||||
|
||||
// Train the model
|
||||
let dist = train_distribution(&x, &y);
|
||||
let nb = TestNBGaussian::fit(dist).unwrap();
|
||||
|
||||
// Create constant test data
|
||||
let n_test_samples = 100;
|
||||
let mut test_x_data = Vec::with_capacity(n_test_samples * n_features);
|
||||
for i in 0..n_test_samples {
|
||||
for j in 0..n_features {
|
||||
test_x_data.push((i * j * 2) as f64 % 15.0);
|
||||
}
|
||||
}
|
||||
let test_x = DenseMatrix::new(n_test_samples, n_features, test_x_data, true).unwrap();
|
||||
|
||||
// Make predictions
|
||||
let predictions = nb
|
||||
.predict(&test_x)
|
||||
.map_err(|e| format!("Prediction failed: {}", e))
|
||||
.unwrap();
|
||||
|
||||
// Check numerical stability
|
||||
assert_eq!(
|
||||
predictions.len(),
|
||||
n_test_samples,
|
||||
"Number of predictions should match number of test samples"
|
||||
);
|
||||
|
||||
// Check that all predictions are valid class labels
|
||||
for &pred in predictions.iter() {
|
||||
assert!(pred < n_classes as u32, "Predicted class should be valid");
|
||||
}
|
||||
|
||||
// Check consistency of predictions
|
||||
let repeated_predictions = nb
|
||||
.predict(&test_x)
|
||||
.map_err(|e| format!("Repeated prediction failed: {}", e))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
predictions, repeated_predictions,
|
||||
"Predictions should be consistent when repeated"
|
||||
);
|
||||
|
||||
// Check extreme values
|
||||
let extreme_x =
|
||||
DenseMatrix::new(2, n_features, vec![f64::MAX; n_features * 2], true).unwrap();
|
||||
let extreme_predictions = nb.predict(&extreme_x);
|
||||
assert!(
|
||||
extreme_predictions.is_err(),
|
||||
"Extreme value input should result in an error"
|
||||
);
|
||||
assert_eq!(
|
||||
extreme_predictions.unwrap_err().to_string(),
|
||||
"Predict failed: Failed to predict, all probabilities were NaN",
|
||||
"Incorrect error message for extreme values"
|
||||
);
|
||||
|
||||
// Check for NaN handling
|
||||
let nan_x = DenseMatrix::new(2, n_features, vec![f64::NAN; n_features * 2], true).unwrap();
|
||||
let nan_predictions = nb.predict(&nan_x);
|
||||
assert!(
|
||||
nan_predictions.is_err(),
|
||||
"NaN input should result in an error"
|
||||
);
|
||||
|
||||
// Check for very small values
|
||||
let small_x =
|
||||
DenseMatrix::new(2, n_features, vec![f64::MIN_POSITIVE; n_features * 2], true).unwrap();
|
||||
let small_predictions = nb
|
||||
.predict(&small_x)
|
||||
.map_err(|e| format!("Small value prediction failed: {}", e))
|
||||
.unwrap();
|
||||
for &pred in small_predictions.iter() {
|
||||
assert!(
|
||||
pred < n_classes as u32,
|
||||
"Predictions for very small values should be valid"
|
||||
);
|
||||
}
|
||||
|
||||
// Check for values close to zero
|
||||
let near_zero_x =
|
||||
DenseMatrix::new(2, n_features, vec![1e-300; n_features * 2], true).unwrap();
|
||||
let near_zero_predictions = nb
|
||||
.predict(&near_zero_x)
|
||||
.map_err(|e| format!("Near-zero value prediction failed: {}", e))
|
||||
.unwrap();
|
||||
for &pred in near_zero_predictions.iter() {
|
||||
assert!(
|
||||
pred < n_classes as u32,
|
||||
"Predictions for near-zero values should be valid"
|
||||
);
|
||||
}
|
||||
|
||||
println!("All numerical stability checks passed!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gaussian_naive_bayes_numerical_stability_random_data() {
|
||||
/// Tiny deterministic pseudo-random generator (a 64-bit linear congruential generator),
/// used so the test data is reproducible without an external RNG crate.
#[derive(Debug)]
struct MySimpleRng {
    /// Current generator state; also the value most recently returned by `next_u64`.
    state: u64,
}

impl MySimpleRng {
    /// Creates a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        MySimpleRng { state: seed }
    }

    /// Advances the generator and returns the next `u64` in the sequence.
    fn next_u64(&mut self) -> u64 {
        // LCG step: state <- state * 6364136223846793005 + 1 (mod 2^64).
        // The multiplier/increment are somewhat arbitrary but commonly used.
        self.state = self.state.wrapping_mul(6364136223846793005).wrapping_add(1);
        self.state
    }

    /// Returns an `f64` in the range `[min, max)`.
    ///
    /// The fraction that scales the range is always in `[0, 1)`; the half-open bound on
    /// the result holds as long as the final `min + fraction * (max - min)` does not
    /// itself round up to `max`.
    fn next_f64(&mut self, min: f64, max: f64) -> f64 {
        // Use only the top 53 bits: they convert to f64 exactly, so the fraction is at
        // most (2^53 - 1) / 2^53 and never reaches 1.0. Dividing the full 64-bit value
        // by `u64::MAX as f64` (which is 2^64) could round to exactly 1.0 and make this
        // function return `max`, contradicting the documented half-open range.
        let fraction = (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64;
        min + fraction * (max - min)
    }

    /// Returns a `usize` in the range `[min, max)` by flooring a value from `next_f64`.
    ///
    /// Expects `min < max`.
    fn gen_range_usize(&mut self, min: usize, max: usize) -> usize {
        let v = self.next_f64(min as f64, max as f64);
        let int_v = v.floor() as isize;
        // Clamp to guard against any floating-point rounding outside the range.
        let clamped = int_v.max(min as isize).min((max - 1) as isize);
        clamped as usize
    }
}
|
||||
use crate::naive_bayes::gaussian::GaussianNB;
|
||||
// We will generate random data in a reproducible way (using a fixed seed).
|
||||
// We will generate random data in a reproducible way:
|
||||
let mut rng = MySimpleRng::new(42);
|
||||
|
||||
let n_samples = 1000;
|
||||
let n_features = 5;
|
||||
let n_classes = 4;
|
||||
|
||||
// Our feature matrix and label vector
|
||||
let mut x_data = Vec::with_capacity(n_samples * n_features);
|
||||
let mut y_data = Vec::with_capacity(n_samples);
|
||||
|
||||
// Fill x_data with random values and y_data with random class labels.
|
||||
for _i in 0..n_samples {
|
||||
for _j in 0..n_features {
|
||||
// We’ll pick random values in [-10, 10).
|
||||
x_data.push(rng.next_f64(-10.0, 10.0));
|
||||
}
|
||||
let class = rng.gen_range_usize(0, n_classes) as u32;
|
||||
y_data.push(class);
|
||||
}
|
||||
|
||||
// Create DenseMatrix from x_data
|
||||
let x = DenseMatrix::new(n_samples, n_features, x_data, true).unwrap();
|
||||
|
||||
// Train GaussianNB
|
||||
let gnb = GaussianNB::fit(&x, &y_data, Default::default())
|
||||
.expect("Fitting GaussianNB with random data failed.");
|
||||
|
||||
// Predict on the same training data to verify no numerical instability
|
||||
let predictions = gnb.predict(&x).expect("Prediction on random data failed.");
|
||||
|
||||
// Basic sanity checks
|
||||
assert_eq!(
|
||||
predictions.len(),
|
||||
n_samples,
|
||||
"Prediction size must match n_samples"
|
||||
);
|
||||
for &pred_class in &predictions {
|
||||
assert!(
|
||||
(pred_class as usize) < n_classes,
|
||||
"Predicted class {} is out of range [0..n_classes).",
|
||||
pred_class
|
||||
);
|
||||
}
|
||||
|
||||
// If you want to compare with scikit-learn, you can do something like:
|
||||
// println!("X = {:?}", &x);
|
||||
// println!("Y = {:?}", &y_data);
|
||||
// println!("predictions = {:?}", &predictions);
|
||||
// and then in Python:
|
||||
// import numpy as np
|
||||
// from sklearn.naive_bayes import GaussianNB
|
||||
// X = np.reshape(np.array(x), (1000, 5), order='F')
|
||||
// Y = np.array(y)
|
||||
// gnb = GaussianNB().fit(X, Y)
|
||||
// preds = gnb.predict(X)
|
||||
// expected = np.array(predictions)
|
||||
// assert expected == preds
|
||||
// They should match closely (or exactly) depending on floating rounding.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,13 +20,13 @@
|
||||
//! &[0, 2, 0, 0, 1, 0],
|
||||
//! &[0, 1, 0, 1, 0, 0],
|
||||
//! &[0, 1, 1, 0, 0, 1],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<u32> = vec![0, 0, 0, 1];
|
||||
//! let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! // Testing data point is:
|
||||
//! // Chinese Chinese Chinese Tokyo Japan
|
||||
//! let x_test = DenseMatrix::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]);
|
||||
//! let x_test = DenseMatrix::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]).unwrap();
|
||||
//! let y_hat = nb.predict(&x_test).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
@@ -207,8 +207,7 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
||||
/// priors are adjusted according to the data.
|
||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||
pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>(
|
||||
x: &X,
|
||||
@@ -345,10 +344,10 @@ impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array
|
||||
{
|
||||
/// Fits MultinomialNB with given data
|
||||
/// * `x` - training data of size NxM where N is the number of samples and M is the number of
|
||||
/// features.
|
||||
/// features.
|
||||
/// * `y` - vector with target values (classes) of length N.
|
||||
/// * `parameters` - additional parameters like class priors, alpha for smoothing and
|
||||
/// binarizing threshold.
|
||||
/// binarizing threshold.
|
||||
pub fn fit(x: &X, y: &Y, parameters: MultinomialNBParameters) -> Result<Self, Failed> {
|
||||
let distribution =
|
||||
MultinomialNBDistribution::fit(x, y, parameters.alpha, parameters.priors)?;
|
||||
@@ -358,6 +357,7 @@ impl<TX: Number + Unsigned, TY: Number + Ord + Unsigned, X: Array2<TX>, Y: Array
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
self.inner.as_ref().unwrap().predict(x)
|
||||
@@ -433,7 +433,8 @@ mod tests {
|
||||
&[0, 2, 0, 0, 1, 0],
|
||||
&[0, 1, 0, 1, 0, 0],
|
||||
&[0, 1, 1, 0, 0, 1],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 0, 1];
|
||||
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -467,7 +468,7 @@ mod tests {
|
||||
|
||||
// Testing data point is:
|
||||
// Chinese Chinese Chinese Tokyo Japan
|
||||
let x_test = DenseMatrix::<u32>::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]);
|
||||
let x_test = DenseMatrix::<u32>::from_2d_array(&[&[0, 3, 1, 0, 0, 1]]).unwrap();
|
||||
let y_hat = mnb.predict(&x_test).unwrap();
|
||||
|
||||
assert_eq!(y_hat, &[0]);
|
||||
@@ -495,7 +496,8 @@ mod tests {
|
||||
&[2, 0, 3, 3, 1, 2, 0, 2, 4, 1],
|
||||
&[2, 4, 0, 4, 2, 4, 1, 3, 1, 4],
|
||||
&[0, 2, 2, 3, 4, 0, 4, 4, 4, 4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![2, 2, 0, 0, 0, 2, 1, 1, 0, 1, 0, 0, 2, 0, 2];
|
||||
let nb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -554,7 +556,8 @@ mod tests {
|
||||
&[0, 1, 0, 0, 1, 0],
|
||||
&[0, 1, 0, 1, 0, 0],
|
||||
&[0, 1, 1, 0, 0, 1],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![0, 0, 0, 1];
|
||||
|
||||
let mnb = MultinomialNB::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
//! &[3., 4.],
|
||||
//! &[5., 6.],
|
||||
//! &[7., 8.],
|
||||
//! &[9., 10.]]);
|
||||
//! &[9., 10.]]).unwrap();
|
||||
//! let y = vec![2, 2, 2, 3, 3]; //your class labels
|
||||
//!
|
||||
//! let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -211,7 +211,7 @@ impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec
|
||||
{
|
||||
/// Fits KNN classifier to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `y` - vector with target values (classes) of length N
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
@@ -261,6 +261,7 @@ impl<TX: Number, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec
|
||||
|
||||
/// Estimates the class labels for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with class estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let mut result = Y::zeros(x.shape().0);
|
||||
@@ -311,7 +312,8 @@ mod tests {
|
||||
#[test]
|
||||
fn knn_fit_predict() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
|
||||
.unwrap();
|
||||
let y = vec![2, 2, 2, 3, 3];
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
@@ -325,7 +327,7 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn knn_fit_predict_weighted() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[1.], &[2.], &[3.], &[4.], &[5.]]).unwrap();
|
||||
let y = vec![2, 2, 2, 3, 3];
|
||||
let knn = KNNClassifier::fit(
|
||||
&x,
|
||||
@@ -336,7 +338,9 @@ mod tests {
|
||||
.with_weight(KNNWeightFunction::Distance),
|
||||
)
|
||||
.unwrap();
|
||||
let y_hat = knn.predict(&DenseMatrix::from_2d_array(&[&[4.1]])).unwrap();
|
||||
let y_hat = knn
|
||||
.predict(&DenseMatrix::from_2d_array(&[&[4.1]]).unwrap())
|
||||
.unwrap();
|
||||
assert_eq!(vec![3], y_hat);
|
||||
}
|
||||
|
||||
@@ -348,7 +352,8 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
|
||||
.unwrap();
|
||||
let y = vec![2, 2, 2, 3, 3];
|
||||
|
||||
let knn = KNNClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
//! &[2., 2.],
|
||||
//! &[3., 3.],
|
||||
//! &[4., 4.],
|
||||
//! &[5., 5.]]);
|
||||
//! &[5., 5.]]).unwrap();
|
||||
//! let y = vec![1., 2., 3., 4., 5.]; //your target values
|
||||
//!
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
@@ -88,25 +88,21 @@ pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D:
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
KNNRegressor<TX, TY, X, Y, D>
|
||||
{
|
||||
///
|
||||
fn y(&self) -> &Y {
|
||||
self.y.as_ref().unwrap()
|
||||
}
|
||||
|
||||
///
|
||||
fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
|
||||
self.knn_algorithm
|
||||
.as_ref()
|
||||
.expect("Missing parameter: KNNAlgorithm")
|
||||
}
|
||||
|
||||
///
|
||||
fn weight(&self) -> &KNNWeightFunction {
|
||||
self.weight.as_ref().expect("Missing parameter: weight")
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
///
|
||||
fn k(&self) -> usize {
|
||||
self.k.unwrap()
|
||||
}
|
||||
@@ -207,7 +203,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
{
|
||||
/// Fits KNN regressor to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data
|
||||
/// * `y` - vector with real values
|
||||
/// * `y` - vector with real values
|
||||
/// * `parameters` - additional parameters like search algorithm and k
|
||||
pub fn fit(
|
||||
x: &X,
|
||||
@@ -250,6 +246,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
|
||||
/// Predict the target for the provided data.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
///
|
||||
/// Returns a vector of size N with estimates.
|
||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||
let mut result = Y::zeros(x.shape().0);
|
||||
@@ -295,9 +292,10 @@ mod tests {
|
||||
#[test]
|
||||
fn knn_fit_predict_weighted() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = [1., 2., 3., 4., 5.];
|
||||
let knn = KNNRegressor::fit(
|
||||
&x,
|
||||
&y,
|
||||
@@ -311,7 +309,7 @@ mod tests {
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - y_exp[i]).abs() < std::f64::EPSILON);
|
||||
assert!((y_hat[i] - y_exp[i]).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,9 +320,10 @@ mod tests {
|
||||
#[test]
|
||||
fn knn_fit_predict_uniform() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
let y_exp = vec![2., 2., 3., 4., 4.];
|
||||
let y_exp = [2., 2., 3., 4., 4.];
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
let y_hat = knn.predict(&x).unwrap();
|
||||
assert_eq!(5, Vec::len(&y_hat));
|
||||
@@ -341,7 +340,8 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde() {
|
||||
let x =
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]]);
|
||||
DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.], &[5., 6.], &[7., 8.], &[9., 10.]])
|
||||
.unwrap();
|
||||
let y = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -64,7 +64,7 @@ impl KNNWeightFunction {
|
||||
KNNWeightFunction::Distance => {
|
||||
// if any points have zero distance from one or more training points,
|
||||
// those training points are weighted as 1.0 and the other points as 0.0
|
||||
if distances.iter().any(|&e| e == 0f64) {
|
||||
if distances.contains(&0f64) {
|
||||
distances
|
||||
.iter()
|
||||
.map(|e| if *e == 0f64 { 1f64 } else { 0f64 })
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// TODO: missing documentation
|
||||
|
||||
use std::default::Default;
|
||||
|
||||
use crate::linalg::basic::arrays::Array1;
|
||||
@@ -8,30 +6,27 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::LineSearchMethod;
|
||||
use crate::optimization::{DF, F};
|
||||
|
||||
///
|
||||
/// Gradient Descent optimization algorithm
|
||||
pub struct GradientDescent {
|
||||
///
|
||||
/// Maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
///
|
||||
/// Relative tolerance for the gradient norm
|
||||
pub g_rtol: f64,
|
||||
///
|
||||
/// Absolute tolerance for the gradient norm
|
||||
pub g_atol: f64,
|
||||
}
|
||||
|
||||
///
|
||||
impl Default for GradientDescent {
|
||||
fn default() -> Self {
|
||||
GradientDescent {
|
||||
max_iter: 10000,
|
||||
g_rtol: std::f64::EPSILON.sqrt(),
|
||||
g_atol: std::f64::EPSILON,
|
||||
g_rtol: f64::EPSILON.sqrt(),
|
||||
g_atol: f64::EPSILON,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
impl<T: FloatNumber> FirstOrderOptimizer<T> for GradientDescent {
|
||||
///
|
||||
fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &'a F<'_, T, X>,
|
||||
|
||||
@@ -11,31 +11,29 @@ use crate::optimization::first_order::{FirstOrderOptimizer, OptimizerResult};
|
||||
use crate::optimization::line_search::LineSearchMethod;
|
||||
use crate::optimization::{DF, F};
|
||||
|
||||
///
|
||||
/// Limited-memory BFGS optimization algorithm
|
||||
pub struct LBFGS {
|
||||
///
|
||||
/// Maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub g_rtol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub g_atol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub x_atol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub x_rtol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub f_abstol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub f_reltol: f64,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub successive_f_tol: usize,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub m: usize,
|
||||
}
|
||||
|
||||
///
|
||||
impl Default for LBFGS {
|
||||
///
|
||||
fn default() -> Self {
|
||||
LBFGS {
|
||||
max_iter: 1000,
|
||||
@@ -51,9 +49,7 @@ impl Default for LBFGS {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
impl LBFGS {
|
||||
///
|
||||
fn two_loops<T: FloatNumber + RealNumber, X: Array1<T>>(&self, state: &mut LBFGSState<T, X>) {
|
||||
let lower = state.iteration.max(self.m) - self.m;
|
||||
let upper = state.iteration;
|
||||
@@ -95,7 +91,6 @@ impl LBFGS {
|
||||
state.s.mul_scalar_mut(-T::one());
|
||||
}
|
||||
|
||||
///
|
||||
fn init_state<T: FloatNumber + RealNumber, X: Array1<T>>(&self, x: &X) -> LBFGSState<T, X> {
|
||||
LBFGSState {
|
||||
x: x.clone(),
|
||||
@@ -119,7 +114,6 @@ impl LBFGS {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
fn update_state<'a, T: FloatNumber + RealNumber, X: Array1<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &'a F<'_, T, X>,
|
||||
@@ -161,7 +155,6 @@ impl LBFGS {
|
||||
df(&mut state.x_df, &state.x);
|
||||
}
|
||||
|
||||
///
|
||||
fn assess_convergence<T: FloatNumber, X: Array1<T>>(
|
||||
&self,
|
||||
state: &mut LBFGSState<T, X>,
|
||||
@@ -173,7 +166,7 @@ impl LBFGS {
|
||||
}
|
||||
|
||||
if state.x.max_diff(&state.x_prev)
|
||||
<= T::from_f64(self.x_rtol * state.x.norm(std::f64::INFINITY)).unwrap()
|
||||
<= T::from_f64(self.x_rtol * state.x.norm(f64::INFINITY)).unwrap()
|
||||
{
|
||||
x_converged = true;
|
||||
}
|
||||
@@ -188,14 +181,13 @@ impl LBFGS {
|
||||
state.counter_f_tol += 1;
|
||||
}
|
||||
|
||||
if state.x_df.norm(std::f64::INFINITY) <= self.g_atol {
|
||||
if state.x_df.norm(f64::INFINITY) <= self.g_atol {
|
||||
g_converged = true;
|
||||
}
|
||||
|
||||
g_converged || x_converged || state.counter_f_tol > self.successive_f_tol
|
||||
}
|
||||
|
||||
///
|
||||
fn update_hessian<T: FloatNumber, X: Array1<T>>(
|
||||
&self,
|
||||
_: &DF<'_, X>,
|
||||
@@ -212,7 +204,6 @@ impl LBFGS {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
#[derive(Debug)]
|
||||
struct LBFGSState<T: FloatNumber, X: Array1<T>> {
|
||||
x: X,
|
||||
@@ -234,9 +225,7 @@ struct LBFGSState<T: FloatNumber, X: Array1<T>> {
|
||||
alpha: T,
|
||||
}
|
||||
|
||||
///
|
||||
impl<T: FloatNumber + RealNumber> FirstOrderOptimizer<T> for LBFGS {
|
||||
///
|
||||
fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &F<'_, T, X>,
|
||||
@@ -248,7 +237,7 @@ impl<T: FloatNumber + RealNumber> FirstOrderOptimizer<T> for LBFGS {
|
||||
|
||||
df(&mut state.x_df, x0);
|
||||
|
||||
let g_converged = state.x_df.norm(std::f64::INFINITY) < self.g_atol;
|
||||
let g_converged = state.x_df.norm(f64::INFINITY) < self.g_atol;
|
||||
let mut converged = g_converged;
|
||||
let stopped = false;
|
||||
|
||||
@@ -299,7 +288,7 @@ mod tests {
|
||||
|
||||
let result = optimizer.optimize(&f, &df, &x0, &ls);
|
||||
|
||||
assert!((result.f_x - 0.0).abs() < std::f64::EPSILON);
|
||||
assert!((result.f_x - 0.0).abs() < f64::EPSILON);
|
||||
assert!((result.x[0] - 1.0).abs() < 1e-8);
|
||||
assert!((result.x[1] - 1.0).abs() < 1e-8);
|
||||
assert!(result.iterations <= 24);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
///
|
||||
/// Gradient descent optimization algorithm
|
||||
pub mod gradient_descent;
|
||||
///
|
||||
/// Limited-memory BFGS optimization algorithm
|
||||
pub mod lbfgs;
|
||||
|
||||
use std::clone::Clone;
|
||||
@@ -11,9 +11,9 @@ use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::optimization::line_search::LineSearchMethod;
|
||||
use crate::optimization::{DF, F};
|
||||
|
||||
///
|
||||
/// First-order optimization is a class of algorithms that use the first derivative of a function to find optimal solutions.
|
||||
pub trait FirstOrderOptimizer<T: FloatNumber> {
|
||||
///
|
||||
/// run first order optimization
|
||||
fn optimize<'a, X: Array1<T>, LS: LineSearchMethod<T>>(
|
||||
&self,
|
||||
f: &F<'_, T, X>,
|
||||
@@ -23,13 +23,13 @@ pub trait FirstOrderOptimizer<T: FloatNumber> {
|
||||
) -> OptimizerResult<T, X>;
|
||||
}
|
||||
|
||||
///
|
||||
/// Result of optimization
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OptimizerResult<T: FloatNumber, X: Array1<T>> {
|
||||
///
|
||||
/// Solution
|
||||
pub x: X,
|
||||
///
|
||||
/// f(x) value
|
||||
pub f_x: T,
|
||||
///
|
||||
/// number of iterations
|
||||
pub iterations: usize,
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
// TODO: missing documentation
|
||||
|
||||
use crate::optimization::FunctionOrder;
|
||||
use num_traits::Float;
|
||||
|
||||
///
|
||||
/// Line search optimization.
|
||||
pub trait LineSearchMethod<T: Float> {
|
||||
///
|
||||
/// Find alpha that satisfies strong Wolfe conditions.
|
||||
fn search(
|
||||
&self,
|
||||
f: &(dyn Fn(T) -> T),
|
||||
@@ -16,32 +14,31 @@ pub trait LineSearchMethod<T: Float> {
|
||||
) -> LineSearchResult<T>;
|
||||
}
|
||||
|
||||
///
|
||||
/// Line search result
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LineSearchResult<T: Float> {
|
||||
///
|
||||
/// Alpha value
|
||||
pub alpha: T,
|
||||
///
|
||||
/// f(alpha) value
|
||||
pub f_x: T,
|
||||
}
|
||||
|
||||
///
|
||||
/// Backtracking line search method.
|
||||
pub struct Backtracking<T: Float> {
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub c1: T,
|
||||
///
|
||||
/// Maximum number of iterations for Backtracking single run
|
||||
pub max_iterations: usize,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub max_infinity_iterations: usize,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub phi: T,
|
||||
///
|
||||
/// TODO: Add documentation
|
||||
pub plo: T,
|
||||
///
|
||||
/// function order
|
||||
pub order: FunctionOrder,
|
||||
}
|
||||
|
||||
///
|
||||
impl<T: Float> Default for Backtracking<T> {
|
||||
fn default() -> Self {
|
||||
Backtracking {
|
||||
@@ -55,9 +52,7 @@ impl<T: Float> Default for Backtracking<T> {
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||
///
|
||||
fn search(
|
||||
&self,
|
||||
f: &(dyn Fn(T) -> T),
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
// TODO: missing documentation
|
||||
|
||||
///
|
||||
/// first order optimization algorithms
|
||||
pub mod first_order;
|
||||
///
|
||||
/// line search algorithms
|
||||
pub mod line_search;
|
||||
|
||||
///
|
||||
/// Function f(x) = y
|
||||
pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
|
||||
///
|
||||
/// Function df(x)
|
||||
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
||||
|
||||
///
|
||||
/// Function order
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum FunctionOrder {
|
||||
///
|
||||
/// Second order
|
||||
SECOND,
|
||||
///
|
||||
/// Third order
|
||||
THIRD,
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
//! &[1.5, 2.0, 1.5, 4.0],
|
||||
//! &[1.5, 1.0, 1.5, 5.0],
|
||||
//! &[1.5, 2.0, 1.5, 6.0],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let encoder_params = OneHotEncoderParams::from_cat_idx(&[1, 3]);
|
||||
//! // Infer number of categories from data and return a reusable encoder
|
||||
//! let encoder = OneHotEncoder::fit(&data, encoder_params).unwrap();
|
||||
@@ -24,7 +24,7 @@
|
||||
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
|
||||
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
|
||||
//! ```
|
||||
use std::iter;
|
||||
use std::iter::repeat_n;
|
||||
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
@@ -75,11 +75,7 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) ->
|
||||
let offset = (0..1).chain(offset_);
|
||||
|
||||
let new_param_idxs: Vec<usize> = (0..num_params)
|
||||
.zip(
|
||||
repeats
|
||||
.zip(offset)
|
||||
.flat_map(|(r, o)| iter::repeat(o).take(r)),
|
||||
)
|
||||
.zip(repeats.zip(offset).flat_map(|(r, o)| repeat_n(o, r)))
|
||||
.map(|(idx, ofst)| idx + ofst)
|
||||
.collect();
|
||||
new_param_idxs
|
||||
@@ -124,7 +120,7 @@ impl OneHotEncoder {
|
||||
let (nrows, _) = data.shape();
|
||||
|
||||
// col buffer to avoid allocations
|
||||
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
|
||||
let mut col_buf: Vec<T> = repeat_n(T::zero(), nrows).collect();
|
||||
|
||||
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
|
||||
|
||||
@@ -240,14 +236,16 @@ mod tests {
|
||||
&[2.0, 1.5, 4.0],
|
||||
&[1.0, 1.5, 5.0],
|
||||
&[2.0, 1.5, 6.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let oh_enc = DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
|
||||
&[0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
|
||||
&[1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
|
||||
&[0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
(orig, oh_enc)
|
||||
}
|
||||
@@ -259,14 +257,16 @@ mod tests {
|
||||
&[1.5, 2.0, 1.5, 4.0],
|
||||
&[1.5, 1.0, 1.5, 5.0],
|
||||
&[1.5, 2.0, 1.5, 6.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let oh_enc = DenseMatrix::from_2d_array(&[
|
||||
&[1.5, 1.0, 0.0, 1.5, 1.0, 0.0, 0.0, 0.0],
|
||||
&[1.5, 0.0, 1.0, 1.5, 0.0, 1.0, 0.0, 0.0],
|
||||
&[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0],
|
||||
&[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
(orig, oh_enc)
|
||||
}
|
||||
@@ -277,7 +277,7 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn hash_encode_f64_series() {
|
||||
let series = vec![3.0, 1.0, 2.0, 1.0];
|
||||
let series = [3.0, 1.0, 2.0, 1.0];
|
||||
let hashable_series: Vec<CategoricalFloat> =
|
||||
series.iter().map(|v| v.to_category()).collect();
|
||||
let enc = CategoryMapper::from_positional_category_vec(hashable_series);
|
||||
@@ -334,7 +334,8 @@ mod tests {
|
||||
&[2.0, 1.5, 4.0],
|
||||
&[1.0, 1.5, 5.0],
|
||||
&[2.0, 1.5, 6.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let params = OneHotEncoderParams::from_cat_idx(&[1]);
|
||||
let result = OneHotEncoder::fit(&m, params);
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
//! vec![0.0, 0.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let standard_scaler =
|
||||
//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default())
|
||||
@@ -24,7 +24,7 @@
|
||||
//! vec![-1.0, -1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ])
|
||||
//! ]).unwrap()
|
||||
//! );
|
||||
//! ```
|
||||
use std::marker::PhantomData;
|
||||
@@ -172,18 +172,14 @@ where
|
||||
T: Number + RealNumber,
|
||||
M: Array2<T>,
|
||||
{
|
||||
if let Some(output_matrix) = columns.first().cloned() {
|
||||
return Some(
|
||||
columns
|
||||
.iter()
|
||||
.skip(1)
|
||||
.fold(output_matrix, |current_matrix, new_colum| {
|
||||
current_matrix.h_stack(new_colum)
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
None
|
||||
}
|
||||
columns.first().cloned().map(|output_matrix| {
|
||||
columns
|
||||
.iter()
|
||||
.skip(1)
|
||||
.fold(output_matrix, |current_matrix, new_colum| {
|
||||
current_matrix.h_stack(new_colum)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -197,15 +193,18 @@ mod tests {
|
||||
fn combine_three_columns() {
|
||||
assert_eq!(
|
||||
build_matrix_from_columns(vec![
|
||||
DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],])
|
||||
DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]).unwrap(),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]).unwrap(),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],]).unwrap()
|
||||
]),
|
||||
Some(DenseMatrix::from_2d_vec(&vec![
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0]
|
||||
]))
|
||||
Some(
|
||||
DenseMatrix::from_2d_vec(&vec![
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0]
|
||||
])
|
||||
.unwrap()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -287,13 +286,15 @@ mod tests {
|
||||
/// sklearn.
|
||||
#[test]
|
||||
fn fit_transform_random_values() {
|
||||
let transformed_values =
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
let transformed_values = fit_transform_with_default_standard_scaler(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]));
|
||||
])
|
||||
.unwrap(),
|
||||
);
|
||||
println!("{transformed_values}");
|
||||
assert!(transformed_values.approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
@@ -301,7 +302,8 @@ mod tests {
|
||||
&[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
|
||||
&[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
|
||||
&[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
|
||||
]),
|
||||
])
|
||||
.unwrap(),
|
||||
1.0
|
||||
))
|
||||
}
|
||||
@@ -310,13 +312,10 @@ mod tests {
|
||||
#[test]
|
||||
fn fit_transform_with_zero_variance() {
|
||||
assert_eq!(
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0]
|
||||
])),
|
||||
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]),
|
||||
fit_transform_with_default_standard_scaler(
|
||||
&DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]).unwrap()
|
||||
),
|
||||
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]).unwrap(),
|
||||
"When scaling values with zero variance, zero is expected as return value"
|
||||
)
|
||||
}
|
||||
@@ -331,7 +330,8 @@ mod tests {
|
||||
&[1.0, 2.0, 5.0],
|
||||
&[1.0, 1.0, 1.0],
|
||||
&[1.0, 2.0, 5.0]
|
||||
]),
|
||||
])
|
||||
.unwrap(),
|
||||
StandardScalerParameters::default(),
|
||||
),
|
||||
Ok(StandardScaler {
|
||||
@@ -354,7 +354,8 @@ mod tests {
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
])
|
||||
.unwrap(),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -364,17 +365,18 @@ mod tests {
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::<f64>::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq(
|
||||
assert!(&DenseMatrix::<f64>::from_2d_vec(&vec![fitted_scaler.stds])
|
||||
.unwrap()
|
||||
.approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
],])
|
||||
.unwrap(),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
))
|
||||
}
|
||||
|
||||
/// If `with_std` is set to `false` the values should not be
|
||||
@@ -392,8 +394,9 @@ mod tests {
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]))
|
||||
standard_scaler
|
||||
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]]).unwrap()),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]).unwrap())
|
||||
)
|
||||
}
|
||||
|
||||
@@ -413,8 +416,8 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
standard_scaler
|
||||
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
|
||||
.transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]]).unwrap()),
|
||||
Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]).unwrap())
|
||||
)
|
||||
}
|
||||
|
||||
@@ -433,7 +436,8 @@ mod tests {
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
])
|
||||
.unwrap(),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -446,17 +450,18 @@ mod tests {
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||
assert!(&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds])
|
||||
.unwrap()
|
||||
.approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
],])
|
||||
.unwrap(),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+4
-3
@@ -30,7 +30,7 @@ pub struct CSVDefinition<'a> {
|
||||
/// What separates the fields in your csv-file?
|
||||
field_seperator: &'a str,
|
||||
}
|
||||
impl<'a> Default for CSVDefinition<'a> {
|
||||
impl Default for CSVDefinition<'_> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
n_rows_header: 1,
|
||||
@@ -238,7 +238,8 @@ mod tests {
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
])
|
||||
.unwrap())
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
@@ -261,7 +262,7 @@ mod tests {
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
]).unwrap())
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
|
||||
+283
-179
@@ -25,14 +25,18 @@
|
||||
/// search parameters
|
||||
pub mod svc;
|
||||
pub mod svr;
|
||||
// /// search parameters space
|
||||
// pub mod search;
|
||||
// search parameters space
|
||||
pub mod search;
|
||||
|
||||
use core::fmt::Debug;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Only import typetag if not compiling for wasm32 and serde is enabled
|
||||
#[cfg(all(feature = "serde", not(target_arch = "wasm32")))]
|
||||
use typetag;
|
||||
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, ArrayView1};
|
||||
|
||||
@@ -48,197 +52,281 @@ pub trait Kernel: Debug {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed>;
|
||||
}
|
||||
|
||||
/// Pre-defined kernel functions
|
||||
/// An enumeration of all supported kernel types.
|
||||
/// This makes kernel selection and parameterization ergonomic, type-safe, and ready for use in parameter structs like SVRParameters.
|
||||
/// You can construct kernels using the provided variants and builder-style methods.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
///
|
||||
/// let linear = Kernels::linear();
|
||||
/// let rbf = Kernels::rbf().with_gamma(0.5);
|
||||
/// let poly = Kernels::polynomial().with_degree(3.0).with_gamma(0.5).with_coef0(1.0);
|
||||
/// let sigmoid = Kernels::sigmoid().with_gamma(0.2).with_coef0(0.0);
|
||||
/// ```
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Kernels;
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Kernels {
|
||||
/// Linear kernel (default).
|
||||
///
|
||||
/// Computes the standard dot product between vectors.
|
||||
Linear,
|
||||
|
||||
/// Radial Basis Function (RBF) kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = exp(-gamma * ||x-y||²)
|
||||
RBF {
|
||||
/// Controls the width of the Gaussian RBF kernel.
|
||||
///
|
||||
/// Larger values of gamma lead to lower bias and higher variance (a tighter fit to the training data).
|
||||
/// This parameter is inversely proportional to the radius of influence
|
||||
/// of samples selected by the model as support vectors.
|
||||
gamma: Option<f64>,
|
||||
},
|
||||
|
||||
/// Polynomial kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||
Polynomial {
|
||||
/// The degree of the polynomial kernel.
|
||||
///
|
||||
/// Integer values are typical (2 = quadratic, 3 = cubic), but any positive real value is valid.
|
||||
/// Higher degree values create decision boundaries with higher complexity.
|
||||
degree: Option<f64>,
|
||||
|
||||
/// Kernel coefficient for the dot product.
|
||||
///
|
||||
/// Controls the influence of higher-degree versus lower-degree terms in the polynomial.
|
||||
/// Must be set before the kernel is applied; `apply` returns an error if it is None.
|
||||
gamma: Option<f64>,
|
||||
|
||||
/// Independent term in the polynomial kernel.
|
||||
///
|
||||
/// Controls the influence of higher-degree versus lower-degree terms.
|
||||
/// `Kernels::polynomial()` initializes this to 1.0; `apply` returns an error if it is None.
|
||||
coef0: Option<f64>,
|
||||
},
|
||||
|
||||
/// Sigmoid kernel.
|
||||
///
|
||||
/// Formula: K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||
Sigmoid {
|
||||
/// Kernel coefficient for the dot product.
|
||||
///
|
||||
/// Controls the scaling of the dot product in the sigmoid function.
|
||||
/// Must be set before the kernel is applied; `apply` returns an error if it is None.
|
||||
gamma: Option<f64>,
|
||||
|
||||
/// Independent term in the sigmoid kernel.
|
||||
///
|
||||
/// Acts as a threshold/bias term in the sigmoid function.
|
||||
/// `Kernels::sigmoid()` initializes this to 1.0; `apply` returns an error if it is None.
|
||||
coef0: Option<f64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Kernels {
|
||||
/// Return a default linear
|
||||
pub fn linear() -> LinearKernel {
|
||||
LinearKernel::default()
|
||||
/// Create a linear kernel.
|
||||
///
|
||||
/// The linear kernel computes the dot product between two vectors:
|
||||
/// K(x, y) = <x, y>
|
||||
pub fn linear() -> Self {
|
||||
Kernels::Linear
|
||||
}
|
||||
/// Return a default RBF
|
||||
pub fn rbf() -> RBFKernel {
|
||||
RBFKernel::default()
|
||||
|
||||
/// Create an RBF kernel with unspecified gamma.
|
||||
///
|
||||
/// The RBF kernel is defined as:
|
||||
/// K(x, y) = exp(-gamma * ||x-y||²)
|
||||
///
|
||||
/// You should specify gamma using `with_gamma()` before using this kernel.
|
||||
pub fn rbf() -> Self {
|
||||
Kernels::RBF { gamma: None }
|
||||
}
|
||||
/// Return a default polynomial
|
||||
pub fn polynomial() -> PolynomialKernel {
|
||||
PolynomialKernel::default()
|
||||
|
||||
/// Create a polynomial kernel with default parameters.
|
||||
///
|
||||
/// The polynomial kernel is defined as:
|
||||
/// K(x, y) = (gamma * <x, y> + coef0)^degree
|
||||
///
|
||||
/// Default values:
|
||||
/// - gamma: None (must be specified)
|
||||
/// - degree: None (must be specified)
|
||||
/// - coef0: 1.0
|
||||
pub fn polynomial() -> Self {
|
||||
Kernels::Polynomial {
|
||||
gamma: None,
|
||||
degree: None,
|
||||
coef0: Some(1.0),
|
||||
}
|
||||
}
|
||||
/// Return a default sigmoid
|
||||
pub fn sigmoid() -> SigmoidKernel {
|
||||
SigmoidKernel::default()
|
||||
|
||||
/// Create a sigmoid kernel with default parameters.
|
||||
///
|
||||
/// The sigmoid kernel is defined as:
|
||||
/// K(x, y) = tanh(gamma * <x, y> + coef0)
|
||||
///
|
||||
/// Default values:
|
||||
/// - gamma: None (must be specified)
|
||||
/// - coef0: 1.0
|
||||
///
|
||||
pub fn sigmoid() -> Self {
|
||||
Kernels::Sigmoid {
|
||||
gamma: None,
|
||||
coef0: Some(1.0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear Kernel
|
||||
#[allow(clippy::derive_partial_eq_without_eq)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct LinearKernel;
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone, PartialEq)]
|
||||
pub struct RBFKernel {
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl RBFKernel {
|
||||
/// assign gamma parameter to kernel (required)
|
||||
/// ```rust
|
||||
/// use smartcore::svm::RBFKernel;
|
||||
/// let knl = RBFKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
/// Set the `gamma` parameter for RBF, polynomial, or sigmoid kernels.
|
||||
///
|
||||
/// The gamma parameter has different interpretations depending on the kernel:
|
||||
/// - For RBF: Controls the width of the Gaussian. Larger values mean tighter fit.
|
||||
/// - For Polynomial: Scaling factor for the dot product.
|
||||
/// - For Sigmoid: Scaling factor for the dot product.
|
||||
///
|
||||
pub fn with_gamma(self, gamma: f64) -> Self {
|
||||
match self {
|
||||
Kernels::RBF { .. } => Kernels::RBF { gamma: Some(gamma) },
|
||||
Kernels::Polynomial { degree, coef0, .. } => Kernels::Polynomial {
|
||||
gamma: Some(gamma),
|
||||
degree,
|
||||
coef0,
|
||||
},
|
||||
Kernels::Sigmoid { coef0, .. } => Kernels::Sigmoid {
|
||||
gamma: Some(gamma),
|
||||
coef0,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Polynomial kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct PolynomialKernel {
|
||||
/// degree of the polynomial
|
||||
pub degree: Option<f64>,
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
/// independent term in kernel function
|
||||
pub coef0: Option<f64>,
|
||||
}
|
||||
/// Set the `degree` parameter for the polynomial kernel.
|
||||
///
|
||||
/// The degree parameter controls the flexibility of the decision boundary.
|
||||
/// Higher degrees create more complex boundaries but may lead to overfitting.
|
||||
///
|
||||
pub fn with_degree(self, degree: f64) -> Self {
|
||||
match self {
|
||||
Kernels::Polynomial { gamma, coef0, .. } => Kernels::Polynomial {
|
||||
degree: Some(degree),
|
||||
gamma,
|
||||
coef0,
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PolynomialKernel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gamma: Option::None,
|
||||
degree: Option::None,
|
||||
coef0: Some(1f64),
|
||||
/// Set the `coef0` parameter for polynomial or sigmoid kernels.
|
||||
///
|
||||
/// The coef0 parameter is the independent term in the kernel function:
|
||||
/// - For Polynomial: Controls the influence of higher-degree vs. lower-degree terms.
|
||||
/// - For Sigmoid: Acts as a threshold/bias term.
|
||||
///
|
||||
pub fn with_coef0(self, coef0: f64) -> Self {
|
||||
match self {
|
||||
Kernels::Polynomial { degree, gamma, .. } => Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0: Some(coef0),
|
||||
},
|
||||
Kernels::Sigmoid { gamma, .. } => Kernels::Sigmoid {
|
||||
gamma,
|
||||
coef0: Some(coef0),
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PolynomialKernel {
|
||||
/// set parameters for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_params(3.0, 0.7, 1.0);
|
||||
/// ```
|
||||
pub fn with_params(mut self, degree: f64, gamma: f64, coef0: f64) -> Self {
|
||||
self.degree = Some(degree);
|
||||
self.gamma = Some(gamma);
|
||||
self.coef0 = Some(coef0);
|
||||
self
|
||||
}
|
||||
/// set gamma parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
}
|
||||
/// set degree parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::PolynomialKernel;
|
||||
/// let knl = PolynomialKernel::default().with_degree(3.0, 100);
|
||||
/// ```
|
||||
pub fn with_degree(self, degree: f64, n_features: usize) -> Self {
|
||||
self.with_params(degree, 1f64, 1f64 / n_features as f64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SigmoidKernel {
|
||||
/// kernel coefficient
|
||||
pub gamma: Option<f64>,
|
||||
/// independent term in kernel function
|
||||
pub coef0: Option<f64>,
|
||||
}
|
||||
|
||||
impl Default for SigmoidKernel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
gamma: Option::None,
|
||||
coef0: Some(1f64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SigmoidKernel {
|
||||
/// set parameters for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::SigmoidKernel;
|
||||
/// let knl = SigmoidKernel::default().with_params(0.7, 1.0);
|
||||
/// ```
|
||||
pub fn with_params(mut self, gamma: f64, coef0: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self.coef0 = Some(coef0);
|
||||
self
|
||||
}
|
||||
/// set gamma parameter for kernel
|
||||
/// ```rust
|
||||
/// use smartcore::svm::SigmoidKernel;
|
||||
/// let knl = SigmoidKernel::default().with_gamma(0.7);
|
||||
/// ```
|
||||
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||
self.gamma = Some(gamma);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of the [`Kernel`] trait for the [`Kernels`] enum in smartcore.
|
||||
///
|
||||
/// This method computes the value of the kernel function between two feature vectors `x_i` and `x_j`,
|
||||
/// according to the variant and parameters of the [`Kernels`] enum. This enables flexible and type-safe
|
||||
/// selection of kernel functions for SVM and SVR models in smartcore.
|
||||
///
|
||||
/// # Supported Kernels
|
||||
///
|
||||
/// - [`Kernels::Linear`]: Computes the standard dot product between `x_i` and `x_j`.
|
||||
/// - [`Kernels::RBF`]: Computes the Radial Basis Function (Gaussian) kernel. Requires `gamma`.
|
||||
/// - [`Kernels::Polynomial`]: Computes the polynomial kernel. Requires `degree`, `gamma`, and `coef0`.
|
||||
/// - [`Kernels::Sigmoid`]: Computes the sigmoid kernel. Requires `gamma` and `coef0`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `x_i`: First input vector (feature vector).
|
||||
/// - `x_j`: Second input vector (feature vector).
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - `Ok(f64)`: The computed kernel value.
|
||||
/// - `Err(Failed)`: If any required kernel parameter is missing.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err(Failed)` if a required parameter (such as `gamma`, `degree`, or `coef0`)
|
||||
/// is `None` for the selected kernel variant.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
/// use smartcore::svm::Kernel;
|
||||
///
|
||||
/// let x = vec![1.0, 2.0, 3.0];
|
||||
/// let y = vec![4.0, 5.0, 6.0];
|
||||
/// let kernel = Kernels::rbf().with_gamma(0.5);
|
||||
/// let value = kernel.apply(&x, &y).unwrap();
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// - This implementation follows smartcore's philosophy: pure Rust, no macros, no unsafe code,
|
||||
/// and an accessible, pythonic API surface for both ML practitioners and Rust beginners.
|
||||
/// - All kernel parameters must be set before calling `apply`; missing parameters will result in an error.
|
||||
///
|
||||
/// See the [`Kernels`] enum documentation for more details on each kernel type and its parameters.
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for LinearKernel {
|
||||
impl Kernel for Kernels {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
Ok(x_i.dot(x_j))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for RBFKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
"gamma should be set, use {Kernel}::default().with_gamma(..)",
|
||||
));
|
||||
match self {
|
||||
Kernels::Linear => Ok(x_i.dot(x_j)),
|
||||
Kernels::RBF { gamma } => {
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-gamma * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0,
|
||||
} => {
|
||||
let degree = degree.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "degree not set")
|
||||
})?;
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let coef0 = coef0.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||
})?;
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((gamma * dot + coef0).powf(degree))
|
||||
}
|
||||
Kernels::Sigmoid { gamma, coef0 } => {
|
||||
let gamma = gamma.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "gamma not set")
|
||||
})?;
|
||||
let coef0 = coef0.ok_or_else(|| {
|
||||
Failed::because(FailedError::ParametersError, "coef0 not set")
|
||||
})?;
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((gamma * dot + coef0).tanh())
|
||||
}
|
||||
}
|
||||
let v_diff = x_i.sub(x_j);
|
||||
Ok((-self.gamma.unwrap() * v_diff.mul(&v_diff).sum()).exp())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for PolynomialKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() || self.degree.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
||||
use {Kernel}::default().with_{parameter}(..)")
|
||||
);
|
||||
}
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok((self.gamma.unwrap() * dot + self.coef0.unwrap()).powf(self.degree.unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(all(feature = "serde", not(target_arch = "wasm32")), typetag::serde)]
|
||||
impl Kernel for SigmoidKernel {
|
||||
fn apply(&self, x_i: &Vec<f64>, x_j: &Vec<f64>) -> Result<f64, Failed> {
|
||||
if self.gamma.is_none() || self.coef0.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError, "gamma, coef0, degree should be set,
|
||||
use {Kernel}::default().with_{parameter}(..)")
|
||||
);
|
||||
}
|
||||
let dot = x_i.dot(x_j);
|
||||
Ok(self.gamma.unwrap() * dot + self.coef0.unwrap().tanh())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +335,18 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
#[test]
|
||||
fn rbf_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
let result = Kernels::rbf()
|
||||
.with_gamma(0.055)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
assert!((0.2265f64 - result) < 1e-4);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -264,7 +364,7 @@ mod tests {
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn rbf_kernel() {
|
||||
fn test_rbf_kernel() {
|
||||
let v1 = vec![1., 2., 3.];
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
@@ -287,12 +387,15 @@ mod tests {
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
let result = Kernels::polynomial()
|
||||
.with_params(3.0, 0.5, 1.0)
|
||||
.with_gamma(0.5)
|
||||
.with_degree(3.0)
|
||||
.with_coef0(1.0)
|
||||
//.with_params(3.0, 0.5, 1.0)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
|
||||
assert!((4913f64 - result) < std::f64::EPSILON);
|
||||
assert!((4913f64 - result).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
@@ -305,7 +408,8 @@ mod tests {
|
||||
let v2 = vec![4., 5., 6.];
|
||||
|
||||
let result = Kernels::sigmoid()
|
||||
.with_params(0.01, 0.1)
|
||||
.with_gamma(0.01)
|
||||
.with_coef0(0.1)
|
||||
.apply(&v1, &v2)
|
||||
.unwrap()
|
||||
.abs();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//! SVC and Grid Search
|
||||
|
||||
/// SVC search parameters
|
||||
pub mod svc_params;
|
||||
/// SVC search parameters
|
||||
|
||||
+282
-101
@@ -1,112 +1,293 @@
|
||||
// /// SVR grid search parameters
|
||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
// #[derive(Debug, Clone)]
|
||||
// pub struct SVRSearchParameters<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
// /// Epsilon in the epsilon-SVR model.
|
||||
// pub eps: Vec<T>,
|
||||
// /// Regularization parameter.
|
||||
// pub c: Vec<T>,
|
||||
// /// Tolerance for stopping eps.
|
||||
// pub tol: Vec<T>,
|
||||
// /// The kernel function.
|
||||
// pub kernel: Vec<K>,
|
||||
// /// Unused parameter.
|
||||
// m: PhantomData<M>,
|
||||
// }
|
||||
//! # SVR Grid Search Parameters
|
||||
//!
|
||||
//! This module provides utilities for defining and iterating over grid search parameter spaces
|
||||
//! for Support Vector Regression (SVR) models in [smartcore](https://github.com/smartcorelib/smartcore).
|
||||
//!
|
||||
//! The main struct, [`SVRSearchParameters`], allows users to specify multiple values for each
|
||||
//! SVR hyperparameter (epsilon, regularization parameter C, tolerance, and kernel function).
|
||||
//! The provided iterator yields all possible combinations (the Cartesian product) of these parameters,
|
||||
//! enabling exhaustive grid search for hyperparameter tuning.
|
||||
//!
|
||||
//!
|
||||
//! ## Example
|
||||
//! ```
|
||||
//! use smartcore::svm::Kernels;
|
||||
//! use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//!
|
||||
//! let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||
//! eps: vec![0.1, 0.2],
|
||||
//! c: vec![1.0, 10.0],
|
||||
//! tol: vec![1e-3],
|
||||
//! kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
//! m: std::marker::PhantomData,
|
||||
//! };
|
||||
//!
|
||||
//! // for param_set in params.into_iter() {
|
||||
//! // Use param_set (of type svr::SVRParameters) to fit and evaluate your SVR model.
|
||||
//! // }
|
||||
//! ```
|
||||
//!
|
||||
//!
|
||||
//! ## Note
|
||||
//! This module is intended for use with smartcore version 0.4 or later. The API is not compatible with older versions.
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// /// SVR grid search iterator
|
||||
// pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
// svr_search_parameters: SVRSearchParameters<T, M, K>,
|
||||
// current_eps: usize,
|
||||
// current_c: usize,
|
||||
// current_tol: usize,
|
||||
// current_kernel: usize,
|
||||
// }
|
||||
use crate::linalg::basic::arrays::Array2;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::svm::{svr, Kernels};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||
// for SVRSearchParameters<T, M, K>
|
||||
// {
|
||||
// type Item = SVRParameters<T, M, K>;
|
||||
// type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
||||
/// ## SVR grid search parameters
|
||||
/// A struct representing a grid of hyperparameters for SVR grid search in smartcore.
|
||||
///
|
||||
/// Each field is a vector of possible values for the corresponding SVR hyperparameter.
|
||||
/// The [`IntoIterator`] implementation yields every possible combination of these parameters
|
||||
/// as an `svr::SVRParameters` struct, suitable for use in model selection routines.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `T`: Numeric type for parameters (e.g., `f64`)
|
||||
/// - `M`: Matrix type implementing [`Array2<T>`]
|
||||
///
|
||||
/// # Fields
|
||||
/// - `eps`: Vector of epsilon values for the epsilon-insensitive loss in SVR.
|
||||
/// - `c`: Vector of regularization parameters (C) for SVR.
|
||||
/// - `tol`: Vector of tolerance values for the stopping criterion.
|
||||
/// - `kernel`: Vector of kernel function variants (see [`Kernels`]).
|
||||
/// - `m`: Phantom data for the matrix type parameter.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use smartcore::svm::Kernels;
|
||||
/// use smartcore::svm::search::svr_params::SVRSearchParameters;
|
||||
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
///
|
||||
/// let params = SVRSearchParameters::<f64, DenseMatrix<f64>> {
|
||||
/// eps: vec![0.1, 0.2],
|
||||
/// c: vec![1.0, 10.0],
|
||||
/// tol: vec![1e-3],
|
||||
/// kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
/// m: std::marker::PhantomData,
|
||||
/// };
|
||||
/// ```
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVRSearchParameters<T: Number + RealNumber, M: Array2<T>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: Vec<T>,
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
/// Tolerance for the stopping criterion.
|
||||
pub tol: Vec<T>,
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<Kernels>,
|
||||
/// Unused parameter.
|
||||
pub m: PhantomData<M>,
|
||||
}
|
||||
|
||||
// fn into_iter(self) -> Self::IntoIter {
|
||||
// SVRSearchParametersIterator {
|
||||
// svr_search_parameters: self,
|
||||
// current_eps: 0,
|
||||
// current_c: 0,
|
||||
// current_tol: 0,
|
||||
// current_kernel: 0,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
/// SVR grid search iterator
|
||||
pub struct SVRSearchParametersIterator<T: Number + RealNumber, M: Array2<T>> {
|
||||
svr_search_parameters: SVRSearchParameters<T, M>,
|
||||
current_eps: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
}
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
// for SVRSearchParametersIterator<T, M, K>
|
||||
// {
|
||||
// type Item = SVRParameters<T, M, K>;
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> IntoIterator
|
||||
for SVRSearchParameters<T, M>
|
||||
{
|
||||
type Item = svr::SVRParameters<T>;
|
||||
type IntoIter = SVRSearchParametersIterator<T, M>;
|
||||
|
||||
// fn next(&mut self) -> Option<Self::Item> {
|
||||
// if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
// && self.current_c == self.svr_search_parameters.c.len()
|
||||
// && self.current_tol == self.svr_search_parameters.tol.len()
|
||||
// && self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
// {
|
||||
// return None;
|
||||
// }
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVRSearchParametersIterator {
|
||||
svr_search_parameters: self,
|
||||
current_eps: 0,
|
||||
current_c: 0,
|
||||
current_tol: 0,
|
||||
current_kernel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// let next = SVRParameters::<T, M, K> {
|
||||
// eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
// c: self.svr_search_parameters.c[self.current_c],
|
||||
// tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
// kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
||||
// m: PhantomData,
|
||||
// };
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Iterator
|
||||
for SVRSearchParametersIterator<T, M>
|
||||
{
|
||||
type Item = svr::SVRParameters<T>;
|
||||
|
||||
// if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
// self.current_eps += 1;
|
||||
// } else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c += 1;
|
||||
// } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c = 0;
|
||||
// self.current_tol += 1;
|
||||
// } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
// self.current_eps = 0;
|
||||
// self.current_c = 0;
|
||||
// self.current_tol = 0;
|
||||
// self.current_kernel += 1;
|
||||
// } else {
|
||||
// self.current_eps += 1;
|
||||
// self.current_c += 1;
|
||||
// self.current_tol += 1;
|
||||
// self.current_kernel += 1;
|
||||
// }
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
&& self.current_c == self.svr_search_parameters.c.len()
|
||||
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// Some(next)
|
||||
// }
|
||||
// }
|
||||
let next = svr::SVRParameters::<T> {
|
||||
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
c: self.svr_search_parameters.c[self.current_c],
|
||||
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
kernel: Some(self.svr_search_parameters.kernel[self.current_kernel].clone()),
|
||||
};
|
||||
|
||||
// impl<T: Number + RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
||||
// fn default() -> Self {
|
||||
// let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
||||
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
self.current_eps += 1;
|
||||
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else {
|
||||
self.current_eps += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
}
|
||||
|
||||
// SVRSearchParameters {
|
||||
// eps: vec![default_params.eps],
|
||||
// c: vec![default_params.c],
|
||||
// tol: vec![default_params.tol],
|
||||
// kernel: vec![default_params.kernel],
|
||||
// m: PhantomData,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
// #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
// #[derive(Debug)]
|
||||
// #[cfg_attr(
|
||||
// feature = "serde",
|
||||
// serde(bound(
|
||||
// serialize = "M::RowVector: Serialize, K: Serialize, T: Serialize",
|
||||
// deserialize = "M::RowVector: Deserialize<'de>, K: Deserialize<'de>, T: Deserialize<'de>",
|
||||
// ))
|
||||
// )]
|
||||
impl<T: Number + FloatNumber + RealNumber, M: Array2<T>> Default for SVRSearchParameters<T, M> {
|
||||
fn default() -> Self {
|
||||
let default_params: svr::SVRParameters<T> = svr::SVRParameters::default();
|
||||
|
||||
SVRSearchParameters {
|
||||
eps: vec![default_params.eps],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel.unwrap_or_else(Kernels::linear)],
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
type T = f64;
|
||||
type M = DenseMatrix<T>;
|
||||
|
||||
#[test]
|
||||
fn test_default_parameters() {
|
||||
let params = SVRSearchParameters::<T, M>::default();
|
||||
assert_eq!(params.eps.len(), 1);
|
||||
assert_eq!(params.c.len(), 1);
|
||||
assert_eq!(params.tol.len(), 1);
|
||||
assert_eq!(params.kernel.len(), 1);
|
||||
// Check that the default kernel is linear
|
||||
assert_eq!(params.kernel[0], Kernels::linear());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_grid_iteration() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![0.1],
|
||||
c: vec![1.0],
|
||||
tol: vec![1e-3],
|
||||
kernel: vec![Kernels::rbf().with_gamma(0.5)],
|
||||
m: PhantomData,
|
||||
};
|
||||
let mut iter = params.into_iter();
|
||||
let param = iter.next().unwrap();
|
||||
assert_eq!(param.eps, 0.1);
|
||||
assert_eq!(param.c, 1.0);
|
||||
assert_eq!(param.tol, 1e-3);
|
||||
assert_eq!(param.kernel, Some(Kernels::rbf().with_gamma(0.5)));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cartesian_grid_iteration() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![0.1, 0.2],
|
||||
c: vec![1.0, 2.0],
|
||||
tol: vec![1e-3],
|
||||
kernel: vec![Kernels::linear(), Kernels::rbf().with_gamma(0.5)],
|
||||
m: PhantomData,
|
||||
};
|
||||
let expected_count =
|
||||
params.eps.len() * params.c.len() * params.tol.len() * params.kernel.len();
|
||||
let results: Vec<_> = params.into_iter().collect();
|
||||
assert_eq!(results.len(), expected_count);
|
||||
|
||||
// Check that all parameter combinations are present
|
||||
let mut seen = vec![];
|
||||
for p in &results {
|
||||
seen.push((p.eps, p.c, p.tol, p.kernel.clone().unwrap()));
|
||||
}
|
||||
for &eps in &[0.1, 0.2] {
|
||||
for &c in &[1.0, 2.0] {
|
||||
for &tol in &[1e-3] {
|
||||
for kernel in &[Kernels::linear(), Kernels::rbf().with_gamma(0.5)] {
|
||||
assert!(seen.contains(&(eps, c, tol, kernel.clone())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_grid() {
|
||||
let params = SVRSearchParameters::<T, M> {
|
||||
eps: vec![],
|
||||
c: vec![],
|
||||
tol: vec![],
|
||||
kernel: vec![],
|
||||
m: PhantomData,
|
||||
};
|
||||
let mut iter = params.into_iter();
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kernel_enum_variants() {
|
||||
let lin = Kernels::linear();
|
||||
let rbf = Kernels::rbf().with_gamma(0.2);
|
||||
let poly = Kernels::polynomial()
|
||||
.with_degree(2.0)
|
||||
.with_gamma(1.0)
|
||||
.with_coef0(0.5);
|
||||
let sig = Kernels::sigmoid().with_gamma(0.3).with_coef0(0.1);
|
||||
|
||||
assert_eq!(lin, Kernels::Linear);
|
||||
match rbf {
|
||||
Kernels::RBF { gamma } => assert_eq!(gamma, Some(0.2)),
|
||||
_ => panic!("Not RBF"),
|
||||
}
|
||||
match poly {
|
||||
Kernels::Polynomial {
|
||||
degree,
|
||||
gamma,
|
||||
coef0,
|
||||
} => {
|
||||
assert_eq!(degree, Some(2.0));
|
||||
assert_eq!(gamma, Some(1.0));
|
||||
assert_eq!(coef0, Some(0.5));
|
||||
}
|
||||
_ => panic!("Not Polynomial"),
|
||||
}
|
||||
match sig {
|
||||
Kernels::Sigmoid { gamma, coef0 } => {
|
||||
assert_eq!(gamma, Some(0.3));
|
||||
assert_eq!(coef0, Some(0.1));
|
||||
}
|
||||
_ => panic!("Not Sigmoid"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+389
-70
@@ -53,15 +53,16 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![ -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||
//! let svc = SVC::fit(&x, &y, params).unwrap();
|
||||
//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl);
|
||||
//! let svc = SVC::fit(&x, &y, parameters).unwrap();
|
||||
//!
|
||||
//! let y_hat = svc.predict(&x).unwrap();
|
||||
//!
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -84,12 +85,194 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||
use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::realnum::RealNumber;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// Configuration for a multi-class Support Vector Machine (SVM) classifier.
|
||||
/// This struct holds the indices of the data points relevant to a specific binary
|
||||
/// classification problem within a multi-class context, and the two classes
|
||||
/// being discriminated.
|
||||
struct MultiClassConfig<TY: Number + Ord> {
|
||||
/// The indices of the data points from the original dataset that belong to the two `classes`.
|
||||
indices: Vec<usize>,
|
||||
/// A tuple representing the two classes that this configuration is designed to distinguish.
|
||||
classes: (TY, TY),
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>>
|
||||
for MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Creates a new, empty `MultiClassSVC` instance.
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
classifiers: Option::None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fits the `MultiClassSVC` model to the provided data and parameters.
|
||||
///
|
||||
/// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array).
|
||||
/// * `y` - A reference to the target labels (1D array).
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` indicating success (`Self`) or failure (`Failed`).
|
||||
fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<Self, Failed> {
|
||||
MultiClassSVC::fit(x, y, parameters)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Predicts the class labels for new data points.
|
||||
///
|
||||
/// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) for which to make predictions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
|
||||
fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
|
||||
Ok(self.predict(x).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// A multi-class Support Vector Machine (SVM) classifier.
|
||||
///
|
||||
/// This struct implements a multi-class SVM using the "one-vs-one" strategy,
|
||||
/// where a separate binary SVC classifier is trained for every pair of classes.
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// * `'a` - Lifetime parameter for borrowed data.
|
||||
/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`).
|
||||
/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`).
|
||||
/// * `X` - The type representing the 2D array of input features (e.g., a matrix).
|
||||
/// * `Y` - The type representing the 1D array of target labels (e.g., a vector).
|
||||
pub struct MultiClassSVC<
|
||||
'a,
|
||||
TX: Number + RealNumber,
|
||||
TY: Number + Ord,
|
||||
X: Array2<TX>,
|
||||
Y: Array1<TY>,
|
||||
> {
|
||||
/// An optional vector of binary `SVC` classifiers.
|
||||
classifiers: Option<Vec<SVC<'a, TX, TY, X, Y>>>,
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
MultiClassSVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy.
|
||||
///
|
||||
/// This method identifies all unique classes in the target labels `y` and then
|
||||
/// trains a binary `SVC` for every unique pair of classes. For each pair, it
|
||||
/// extracts the relevant data points and their labels, and then trains a
|
||||
/// specialized `SVC` for that binary classification task.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array).
|
||||
/// * `y` - A reference to the target labels (1D array).
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier.
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`).
|
||||
pub fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<MultiClassSVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let unique_classes = y.unique();
|
||||
let mut classifiers = Vec::new();
|
||||
// Iterate through all unique pairs of classes (one-vs-one strategy)
|
||||
for i in 0..unique_classes.len() {
|
||||
for j in i..unique_classes.len() {
|
||||
if i == j {
|
||||
continue;
|
||||
}
|
||||
let class0 = unique_classes[j];
|
||||
let class1 = unique_classes[i];
|
||||
|
||||
let mut indices = Vec::new();
|
||||
// Collect indices of data points belonging to the current pair of classes
|
||||
for (index, v) in y.iterator(0).enumerate() {
|
||||
if *v == class0 || *v == class1 {
|
||||
indices.push(index)
|
||||
}
|
||||
}
|
||||
let classes = (class0, class1);
|
||||
let multiclass_config = MultiClassConfig { classes, indices };
|
||||
// Fit a binary SVC for the current pair of classes
|
||||
let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap();
|
||||
classifiers.push(svc);
|
||||
}
|
||||
}
|
||||
Ok(Self {
|
||||
classifiers: Some(classifiers),
|
||||
})
|
||||
}
|
||||
|
||||
/// Predicts the class labels for new data points using the trained multi-class SVM.
|
||||
///
|
||||
/// This method uses a "voting" scheme (majority vote) among all the binary
|
||||
/// classifiers to determine the final prediction for each data point.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) for which to make predictions.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
|
||||
///
|
||||
pub fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
|
||||
// Initialize a HashMap for each data point to store votes for each class
|
||||
let mut polls = vec![HashMap::new(); x.shape().0];
|
||||
// Retrieve the trained binary classifiers
|
||||
let classifiers = self.classifiers.as_ref().unwrap();
|
||||
|
||||
// Iterate through each binary classifier
|
||||
for i in 0..classifiers.len() {
|
||||
let svc = classifiers.get(i).unwrap();
|
||||
let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier
|
||||
|
||||
// For each prediction from the current binary classifier
|
||||
for (j, prediction) in predictions.iter().enumerate() {
|
||||
let prediction = prediction.to_i32().unwrap();
|
||||
let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point
|
||||
// Increment the vote for the predicted class
|
||||
if let Some(count) = poll.get_mut(&prediction) {
|
||||
*count += 1
|
||||
} else {
|
||||
poll.insert(prediction, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine the final prediction for each data point based on majority vote
|
||||
Ok(polls
|
||||
.iter()
|
||||
.map(|v| {
|
||||
// Find the class with the maximum votes for each data point
|
||||
TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// SVC Parameters
|
||||
@@ -123,7 +306,7 @@ pub struct SVCParameters<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX
|
||||
)]
|
||||
/// Support Vector Classifier
|
||||
pub struct SVC<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
classes: Option<Vec<TY>>,
|
||||
classes: Option<(TY, TY)>,
|
||||
instances: Option<Vec<Vec<TX>>>,
|
||||
#[cfg_attr(feature = "serde", serde(skip))]
|
||||
parameters: Option<&'a SVCParameters<TX, TY, X, Y>>,
|
||||
@@ -152,7 +335,9 @@ struct Cache<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1
|
||||
struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> {
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
indices: Option<Vec<usize>>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: &'a (TY, TY),
|
||||
svmin: usize,
|
||||
svmax: usize,
|
||||
gmin: TX,
|
||||
@@ -180,12 +365,12 @@ impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
self.tol = tol;
|
||||
self
|
||||
}
|
||||
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
self
|
||||
}
|
||||
|
||||
/// Seed for the pseudo random number generator.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
@@ -241,17 +426,98 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
|
||||
SVC<'a, TX, TY, X, Y>
|
||||
{
|
||||
/// Fits SVC to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
/// * `y` - class labels
|
||||
/// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
|
||||
/// Fits a binary Support Vector Classifier (SVC) to the provided data.
|
||||
///
|
||||
/// This is the primary `fit` method for a standalone binary SVC. It expects
|
||||
/// the target labels `y` to contain exactly two unique classes. If more or
|
||||
/// fewer than two classes are found, it returns an error. It then extracts
|
||||
/// these two classes and proceeds to optimize and fit the SVC model.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data. `y` must contain exactly two unique class labels.
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the training process.
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
|
||||
/// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails.
|
||||
pub fn fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let (n, _) = x.shape();
|
||||
let classes = y.unique();
|
||||
// Validate that there are exactly two unique classes in the target labels.
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes: {}. A binary SVC requires exactly two classes.",
|
||||
classes.len()
|
||||
)));
|
||||
}
|
||||
let classes = (classes[0], classes[1]);
|
||||
let svc = Self::optimize_and_fit(x, y, parameters, classes, None);
|
||||
svc
|
||||
}
|
||||
|
||||
/// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios.
|
||||
///
|
||||
/// This function is intended to be called by a multi-class strategy (e.g., one-vs-one)
|
||||
/// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies
|
||||
/// the two classes this SVC should discriminate and the subset of data indices
|
||||
/// relevant to these classes. It then delegates the actual optimization and fitting
|
||||
/// to `optimize_and_fit`.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data.
|
||||
/// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance).
|
||||
/// * `multiclass_config` - A `MultiClassConfig` struct containing:
|
||||
/// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish.
|
||||
/// - `indices`: A `Vec<usize>` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.`
|
||||
///
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance.
|
||||
/// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters).
|
||||
fn multiclass_fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
multiclass_config: MultiClassConfig<TY>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let classes = multiclass_config.classes;
|
||||
let indices = multiclass_config.indices;
|
||||
let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices));
|
||||
svc
|
||||
}
|
||||
|
||||
/// Internal function to optimize and fit the Support Vector Classifier.
|
||||
///
|
||||
/// This is the core logic for training a binary SVC. It performs several checks
|
||||
/// (e.g., kernel presence, data shape consistency) and then initializes an
|
||||
/// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `x` - A reference to the input features (2D array) of the training data.
|
||||
/// * `y` - A reference to the target labels (1D array) of the training data.
|
||||
/// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration.
|
||||
/// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate.
|
||||
/// * `indices` - An `Option<Vec<usize>>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered.
|
||||
/// # Returns
|
||||
/// A `Result` which is:
|
||||
/// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias).
|
||||
/// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails.
|
||||
fn optimize_and_fit(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: (TY, TY),
|
||||
indices: Option<Vec<usize>>,
|
||||
) -> Result<SVC<'a, TX, TY, X, Y>, Failed> {
|
||||
let (n_samples, _) = x.shape();
|
||||
|
||||
// Validate that a kernel has been defined in the parameters.
|
||||
if parameters.kernel.is_none() {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
@@ -259,55 +525,39 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
||||
));
|
||||
}
|
||||
|
||||
if n != y.shape() {
|
||||
// Validate that the number of samples in X matches the number of labels in Y.
|
||||
if n_samples != y.shape() {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows of X doesn\'t match number of rows of Y",
|
||||
"Number of rows of X doesn't match number of rows of Y",
|
||||
));
|
||||
}
|
||||
|
||||
let classes = y.unique();
|
||||
|
||||
if classes.len() != 2 {
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes: {}",
|
||||
classes.len()
|
||||
)));
|
||||
}
|
||||
|
||||
// Make sure class labels are either 1 or -1
|
||||
for e in y.iterator(0) {
|
||||
let y_v = e.to_i32().unwrap();
|
||||
if y_v != -1 && y_v != 1 {
|
||||
return Err(Failed::because(
|
||||
FailedError::ParametersError,
|
||||
"Class labels must be 1 or -1",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters);
|
||||
let optimizer: Optimizer<'_, TX, TY, X, Y> =
|
||||
Optimizer::new(x, y, indices, parameters, &classes);
|
||||
|
||||
// Perform the optimization to find the support vectors, weight vector, and bias.
|
||||
// This is where the core SVM algorithm (e.g., SMO) would run.
|
||||
let (support_vectors, weight, b) = optimizer.optimize();
|
||||
|
||||
// Construct and return the fitted SVC model.
|
||||
Ok(SVC::<'a> {
|
||||
classes: Some(classes),
|
||||
instances: Some(support_vectors),
|
||||
parameters: Some(parameters),
|
||||
w: Some(weight),
|
||||
b: Some(b),
|
||||
phantomdata: PhantomData,
|
||||
classes: Some(classes), // Store the two classes the SVC was trained on.
|
||||
instances: Some(support_vectors), // Store the data points that are support vectors.
|
||||
parameters: Some(parameters), // Reference to the parameters used for fitting.
|
||||
w: Some(weight), // The learned weight vector (for linear kernels).
|
||||
b: Some(b), // The learned bias term.
|
||||
phantomdata: PhantomData, // Placeholder for type parameters not directly stored.
|
||||
})
|
||||
}
|
||||
|
||||
/// Predicts estimated class labels from `x`
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict(&self, x: &'a X) -> Result<Vec<TX>, Failed> {
|
||||
let mut y_hat: Vec<TX> = self.decision_function(x)?;
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
let cls_idx = match *y_hat.get(i).unwrap() > TX::zero() {
|
||||
false => TX::from(self.classes.as_ref().unwrap()[0]).unwrap(),
|
||||
true => TX::from(self.classes.as_ref().unwrap()[1]).unwrap(),
|
||||
let cls_idx = match *y_hat.get(i) > TX::zero() {
|
||||
false => TX::from(self.classes.as_ref().unwrap().0).unwrap(),
|
||||
true => TX::from(self.classes.as_ref().unwrap().1).unwrap(),
|
||||
};
|
||||
|
||||
y_hat.set(i, cls_idx);
|
||||
@@ -360,8 +610,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for SVC<'a, TX, TY, X, Y>
|
||||
impl<TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
for SVC<'_, TX, TY, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if (self.b.unwrap().sub(other.b.unwrap())).abs() > TX::epsilon() * TX::two()
|
||||
@@ -445,14 +695,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
fn new(
|
||||
x: &'a X,
|
||||
y: &'a Y,
|
||||
indices: Option<Vec<usize>>,
|
||||
parameters: &'a SVCParameters<TX, TY, X, Y>,
|
||||
classes: &'a (TY, TY),
|
||||
) -> Optimizer<'a, TX, TY, X, Y> {
|
||||
let (n, _) = x.shape();
|
||||
|
||||
Optimizer {
|
||||
x,
|
||||
y,
|
||||
indices,
|
||||
parameters,
|
||||
classes,
|
||||
svmin: 0,
|
||||
svmax: 0,
|
||||
gmin: <TX as Bounded>::max_value(),
|
||||
@@ -478,7 +732,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
for i in self.permutate(n) {
|
||||
x.clear();
|
||||
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
|
||||
self.process(i, &x, *self.y.get(i), &mut cache);
|
||||
let y = if *self.y.get(i) == self.classes.1 {
|
||||
1
|
||||
} else {
|
||||
-1
|
||||
} as f64;
|
||||
self.process(i, &x, y, &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
self.find_min_max_gradient();
|
||||
@@ -514,14 +773,16 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
for i in self.permutate(n) {
|
||||
x.clear();
|
||||
x.extend(self.x.get_row(i).iterator(0).take(n).copied());
|
||||
if *self.y.get(i) == TY::one() && cp < few {
|
||||
if self.process(i, &x, *self.y.get(i), cache) {
|
||||
let y = if *self.y.get(i) == self.classes.1 {
|
||||
1
|
||||
} else {
|
||||
-1
|
||||
} as f64;
|
||||
if y == 1.0 && cp < few {
|
||||
if self.process(i, &x, y, cache) {
|
||||
cp += 1;
|
||||
}
|
||||
} else if *self.y.get(i) == TY::from(-1).unwrap()
|
||||
&& cn < few
|
||||
&& self.process(i, &x, *self.y.get(i), cache)
|
||||
{
|
||||
} else if y == -1.0 && cn < few && self.process(i, &x, y, cache) {
|
||||
cn += 1;
|
||||
}
|
||||
|
||||
@@ -531,14 +792,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache<TX, TY, X, Y>) -> bool {
|
||||
fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache<TX, TY, X, Y>) -> bool {
|
||||
for j in 0..self.sv.len() {
|
||||
if self.sv[j].index == i {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let mut g: f64 = y.to_f64().unwrap();
|
||||
let mut g = y;
|
||||
|
||||
let mut cache_values: Vec<((usize, usize), TX)> = Vec::new();
|
||||
|
||||
@@ -559,8 +820,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
self.find_min_max_gradient();
|
||||
|
||||
if self.gmin < self.gmax
|
||||
&& ((y > TY::zero() && g < self.gmin.to_f64().unwrap())
|
||||
|| (y < TY::zero() && g > self.gmax.to_f64().unwrap()))
|
||||
&& ((y > 0.0 && g < self.gmin.to_f64().unwrap())
|
||||
|| (y < 0.0 && g > self.gmax.to_f64().unwrap()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -590,7 +851,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
),
|
||||
);
|
||||
|
||||
if y > TY::zero() {
|
||||
if y > 0.0 {
|
||||
self.smo(None, Some(0), TX::zero(), cache);
|
||||
} else {
|
||||
self.smo(Some(0), None, TX::zero(), cache);
|
||||
@@ -647,7 +908,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
let gmin = self.gmin;
|
||||
|
||||
let mut idxs_to_drop: HashSet<usize> = HashSet::new();
|
||||
|
||||
self.sv.retain(|v| {
|
||||
if v.alpha == 0f64
|
||||
&& ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap())
|
||||
@@ -666,7 +926,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
|
||||
|
||||
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(self.parameters.seed);
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
let mut range = if let Some(indices) = self.indices.clone() {
|
||||
indices
|
||||
} else {
|
||||
(0..n).collect::<Vec<usize>>()
|
||||
};
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
}
|
||||
@@ -957,19 +1221,20 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<i32> = vec![
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default()
|
||||
let parameters = SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(knl)
|
||||
.with_seed(Some(100));
|
||||
|
||||
let y_hat = SVC::fit(&x, &y, &params)
|
||||
let y_hat = SVC::fit(&x, &y, &parameters)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
|
||||
@@ -983,7 +1248,8 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn svc_fit_decision_function() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]]);
|
||||
let x = DenseMatrix::from_2d_array(&[&[4.0, 0.0], &[0.0, 4.0], &[8.0, 0.0], &[0.0, 8.0]])
|
||||
.unwrap();
|
||||
|
||||
let x2 = DenseMatrix::from_2d_array(&[
|
||||
&[3.0, 3.0],
|
||||
@@ -992,7 +1258,8 @@ mod tests {
|
||||
&[10.0, 10.0],
|
||||
&[1.0, 1.0],
|
||||
&[0.0, 0.0],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<i32> = vec![-1, -1, 1, 1];
|
||||
|
||||
@@ -1045,7 +1312,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<i32> = vec![
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
@@ -1066,6 +1334,56 @@ mod tests {
|
||||
assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9");
|
||||
}
|
||||
|
||||
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn svc_multiclass_fit_predict() {
    // Twenty samples with four features each.
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[5.4, 3.9, 1.7, 0.4],
        &[4.6, 3.4, 1.4, 0.3],
        &[5.0, 3.4, 1.5, 0.2],
        &[4.4, 2.9, 1.4, 0.2],
        &[4.9, 3.1, 1.5, 0.1],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
        &[5.7, 2.8, 4.5, 1.3],
        &[6.3, 3.3, 4.7, 1.6],
        &[4.9, 2.4, 3.3, 1.0],
        &[6.6, 2.9, 4.6, 1.3],
        &[5.2, 2.7, 3.9, 1.4],
    ])
    .unwrap();

    // Three distinct labels (0, 1, 2), so the binary `SVC` alone cannot be used here.
    let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2];

    let knl = Kernels::linear();
    // A fixed seed keeps the run reproducible.
    let parameters = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(knl)
        .with_seed(Some(100));

    // Fit on the full data set and predict the same rows back.
    let y_hat = MultiClassSVC::fit(&x, &y, &parameters)
        .and_then(|lr| lr.predict(&x))
        .unwrap();

    let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));

    assert!(
        acc >= 0.9,
        "Multiclass accuracy ({acc}) is not larger or equal to 0.9"
    );
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -1094,18 +1412,19 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<i32> = vec![
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let params = SVCParameters::default().with_kernel(knl);
|
||||
let svc = SVC::fit(&x, &y, &params).unwrap();
|
||||
let parameters = SVCParameters::default().with_kernel(knl);
|
||||
let svc = SVC::fit(&x, &y, &parameters).unwrap();
|
||||
|
||||
// serialization
|
||||
let deserialized_svc: SVC<f64, i32, _, _> =
|
||||
let deserialized_svc: SVC<'_, f64, i32, _, _> =
|
||||
serde_json::from_str(&serde_json::to_string(&svc).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svc, deserialized_svc);
|
||||
|
||||
+34
-31
@@ -44,16 +44,16 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//!
|
||||
//! let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0,
|
||||
//! 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
//!
|
||||
//! let knl = Kernels::linear();
|
||||
//! let params = &SVRParameters::default().with_eps(2.0).with_c(10.0).with_kernel(knl);
|
||||
//! // let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//! let svr = SVR::fit(&x, &y, params).unwrap();
|
||||
//!
|
||||
//! // let y_hat = svr.predict(&x).unwrap();
|
||||
//! let y_hat = svr.predict(&x).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -80,11 +80,12 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArray};
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::numbers::floatnum::FloatNumber;
|
||||
use crate::svm::Kernel;
|
||||
|
||||
use crate::svm::{Kernel, Kernels};
|
||||
|
||||
/// SVR Parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
/// SVR Parameters
|
||||
pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: T,
|
||||
@@ -97,7 +98,7 @@ pub struct SVRParameters<T: Number + FloatNumber + PartialOrd> {
|
||||
all(feature = "serde", target_arch = "wasm32"),
|
||||
serde(skip_serializing, skip_deserializing)
|
||||
)]
|
||||
pub kernel: Option<Box<dyn Kernel>>,
|
||||
pub kernel: Option<Kernels>,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -160,8 +161,8 @@ impl<T: Number + FloatNumber + PartialOrd> SVRParameters<T> {
|
||||
self
|
||||
}
|
||||
/// The kernel function.
|
||||
pub fn with_kernel<K: Kernel + 'static>(mut self, kernel: K) -> Self {
|
||||
self.kernel = Some(Box::new(kernel));
|
||||
pub fn with_kernel(mut self, kernel: Kernels) -> Self {
|
||||
self.kernel = Some(kernel);
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -281,8 +282,8 @@ impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> SVR<'
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
|
||||
for SVR<'a, T, X, Y>
|
||||
impl<T: Number + FloatNumber + PartialOrd, X: Array2<T>, Y: Array1<T>> PartialEq
|
||||
for SVR<'_, T, X, Y>
|
||||
{
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if (self.b - other.b).abs() > T::epsilon() * T::two()
|
||||
@@ -597,25 +598,25 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_squared_error;
|
||||
use crate::svm::search::svr_params::SVRSearchParameters;
|
||||
use crate::svm::Kernels;
|
||||
|
||||
// #[test]
|
||||
// fn search_parameters() {
|
||||
// let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
|
||||
// SVRSearchParameters {
|
||||
// eps: vec![0., 1.],
|
||||
// kernel: vec![LinearKernel {}],
|
||||
// ..Default::default()
|
||||
// };
|
||||
// let mut iter = parameters.into_iter();
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 0.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// let next = iter.next().unwrap();
|
||||
// assert_eq!(next.eps, 1.);
|
||||
// assert_eq!(next.kernel, LinearKernel {});
|
||||
// assert!(iter.next().is_none());
|
||||
// }
|
||||
#[test]
fn search_parameters() {
    // Search grid with two `eps` candidates and one linear kernel; every other
    // field keeps its default.
    let grid: SVRSearchParameters<f64, DenseMatrix<f64>> = SVRSearchParameters {
        eps: vec![0., 1.],
        kernel: vec![Kernels::linear()],
        ..Default::default()
    };
    let mut candidates = grid.into_iter();

    // Only the first candidate's `eps` is asserted at the moment.
    let first = candidates.next().unwrap();
    assert_eq!(first.eps, 0.);
    // Checks below are still disabled (kernel equality and the second candidate):
    // assert_eq!(first.kernel, LinearKernel {});
    // let second = candidates.next().unwrap();
    // assert_eq!(second.eps, 1.);
    // assert_eq!(second.kernel, LinearKernel {});
    // assert!(candidates.next().is_none());
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
@@ -640,14 +641,15 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let knl = Kernels::linear();
|
||||
let knl: Kernels = Kernels::linear();
|
||||
let y_hat = SVR::fit(
|
||||
&x,
|
||||
&y,
|
||||
@@ -688,7 +690,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
@@ -700,7 +703,7 @@ mod tests {
|
||||
|
||||
let svr = SVR::fit(&x, &y, &params).unwrap();
|
||||
|
||||
let deserialized_svr: SVR<f64, DenseMatrix<f64>, _> =
|
||||
let deserialized_svr: SVR<'_, f64, DenseMatrix<f64>, _> =
|
||||
serde_json::from_str(&serde_json::to_string(&svr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(svr, deserialized_svr);
|
||||
|
||||
@@ -48,7 +48,7 @@
|
||||
//! &[4.9, 2.4, 3.3, 1.0],
|
||||
//! &[6.6, 2.9, 4.6, 1.3],
|
||||
//! &[5.2, 2.7, 3.9, 1.4],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y = vec![ 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
//! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
//!
|
||||
@@ -77,7 +77,9 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::basic::arrays::MutArray;
|
||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::numbers::basenum::Number;
|
||||
use crate::rand_custom::get_rng_impl;
|
||||
|
||||
@@ -116,6 +118,7 @@ pub struct DecisionTreeClassifier<
|
||||
num_classes: usize,
|
||||
classes: Vec<TY>,
|
||||
depth: u16,
|
||||
num_features: usize,
|
||||
_phantom_tx: PhantomData<TX>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
_phantom_y: PhantomData<Y>,
|
||||
@@ -159,11 +162,13 @@ pub enum SplitCriterion {
|
||||
#[derive(Debug, Clone)]
|
||||
struct Node {
|
||||
output: usize,
|
||||
n_node_samples: usize,
|
||||
split_feature: usize,
|
||||
split_value: Option<f64>,
|
||||
split_score: Option<f64>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
impurity: Option<f64>,
|
||||
}
|
||||
|
||||
impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||
@@ -194,12 +199,12 @@ impl PartialEq for Node {
|
||||
self.output == other.output
|
||||
&& self.split_feature == other.split_feature
|
||||
&& match (self.split_value, other.split_value) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
|
||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
&& match (self.split_score, other.split_score) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
|
||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
@@ -400,14 +405,16 @@ impl Default for DecisionTreeClassifierSearchParameters {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
fn new(output: usize) -> Self {
|
||||
fn new(output: usize, n_node_samples: usize) -> Self {
|
||||
Node {
|
||||
output,
|
||||
n_node_samples,
|
||||
split_feature: 0,
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None,
|
||||
impurity: Option::None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -507,6 +514,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
num_classes: 0usize,
|
||||
classes: vec![],
|
||||
depth: 0u16,
|
||||
num_features: 0usize,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
@@ -578,7 +586,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
count[yi[i]] += samples[i];
|
||||
}
|
||||
|
||||
let root = Node::new(which_max(&count));
|
||||
let root = Node::new(which_max(&count), y_ncols);
|
||||
change_nodes.push(root);
|
||||
let mut order: Vec<Vec<usize>> = Vec::new();
|
||||
|
||||
@@ -593,6 +601,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
num_classes: k,
|
||||
classes,
|
||||
depth: 0u16,
|
||||
num_features: num_attributes,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
@@ -606,7 +615,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth() < tree.parameters().max_depth.unwrap_or(std::u16::MAX) {
|
||||
while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
@@ -643,7 +652,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
if node.true_child.is_none() && node.false_child.is_none() {
|
||||
result = node.output;
|
||||
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
||||
<= node.split_value.unwrap_or(std::f64::NAN)
|
||||
<= node.split_value.unwrap_or(f64::NAN)
|
||||
{
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
@@ -678,16 +687,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
}
|
||||
}
|
||||
|
||||
if is_pure {
|
||||
return false;
|
||||
}
|
||||
|
||||
let n = visitor.samples.iter().sum();
|
||||
|
||||
if n <= self.parameters().min_samples_split {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut count = vec![0; self.num_classes];
|
||||
let mut false_count = vec![0; self.num_classes];
|
||||
for i in 0..n_rows {
|
||||
@@ -696,7 +696,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
}
|
||||
}
|
||||
|
||||
let parent_impurity = impurity(&self.parameters().criterion, &count, n);
|
||||
self.nodes[visitor.node].impurity = Some(impurity(&self.parameters().criterion, &count, n));
|
||||
|
||||
if is_pure {
|
||||
return false;
|
||||
}
|
||||
|
||||
if n <= self.parameters().min_samples_split {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||
|
||||
@@ -705,14 +713,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
}
|
||||
|
||||
for variable in variables.iter().take(mtry) {
|
||||
self.find_best_split(
|
||||
visitor,
|
||||
n,
|
||||
&count,
|
||||
&mut false_count,
|
||||
parent_impurity,
|
||||
*variable,
|
||||
);
|
||||
self.find_best_split(visitor, n, &count, &mut false_count, *variable);
|
||||
}
|
||||
|
||||
self.nodes()[visitor.node].split_score.is_some()
|
||||
@@ -724,7 +725,6 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
n: usize,
|
||||
count: &[usize],
|
||||
false_count: &mut [usize],
|
||||
parent_impurity: f64,
|
||||
j: usize,
|
||||
) {
|
||||
let mut true_count = vec![0; self.num_classes];
|
||||
@@ -760,6 +760,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
|
||||
let true_label = which_max(&true_count);
|
||||
let false_label = which_max(false_count);
|
||||
let parent_impurity = self.nodes()[visitor.node].impurity.unwrap();
|
||||
let gain = parent_impurity
|
||||
- tc as f64 / n as f64
|
||||
* impurity(&self.parameters().criterion, &true_count, tc)
|
||||
@@ -804,9 +805,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
.get((i, self.nodes()[visitor.node].split_feature))
|
||||
.to_f64()
|
||||
.unwrap()
|
||||
<= self.nodes()[visitor.node]
|
||||
.split_value
|
||||
.unwrap_or(std::f64::NAN)
|
||||
<= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
|
||||
{
|
||||
*true_sample = visitor.samples[i];
|
||||
tc += *true_sample;
|
||||
@@ -827,9 +826,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
|
||||
let true_child_idx = self.nodes().len();
|
||||
|
||||
self.nodes.push(Node::new(visitor.true_child_output));
|
||||
self.nodes.push(Node::new(visitor.true_child_output, tc));
|
||||
let false_child_idx = self.nodes().len();
|
||||
self.nodes.push(Node::new(visitor.false_child_output));
|
||||
self.nodes.push(Node::new(visitor.false_child_output, fc));
|
||||
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
||||
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
||||
|
||||
@@ -863,11 +862,104 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Compute feature importances for the fitted tree.
|
||||
pub fn compute_feature_importances(&self, normalize: bool) -> Vec<f64> {
|
||||
let mut importances = vec![0f64; self.num_features];
|
||||
|
||||
for node in self.nodes().iter() {
|
||||
if node.true_child.is_none() && node.false_child.is_none() {
|
||||
continue;
|
||||
}
|
||||
let left = &self.nodes()[node.true_child.unwrap()];
|
||||
let right = &self.nodes()[node.false_child.unwrap()];
|
||||
|
||||
importances[node.split_feature] += node.n_node_samples as f64 * node.impurity.unwrap()
|
||||
- left.n_node_samples as f64 * left.impurity.unwrap()
|
||||
- right.n_node_samples as f64 * right.impurity.unwrap();
|
||||
}
|
||||
for item in importances.iter_mut() {
|
||||
*item /= self.nodes()[0].n_node_samples as f64;
|
||||
}
|
||||
if normalize {
|
||||
let sum = importances.iter().sum::<f64>();
|
||||
for importance in importances.iter_mut() {
|
||||
*importance /= sum;
|
||||
}
|
||||
}
|
||||
importances
|
||||
}
|
||||
|
||||
/// Predict class probabilities for the input samples.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input samples as a matrix where each row is a sample and each column is a feature.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Result` containing a `DenseMatrix<f64>` where each row corresponds to a sample and each column
|
||||
/// corresponds to a class. The values represent the probability of the sample belonging to each class.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if at least one row prediction process fails.
|
||||
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
|
||||
let (n_samples, _) = x.shape();
|
||||
let n_classes = self.classes().len();
|
||||
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
|
||||
|
||||
for i in 0..n_samples {
|
||||
let probs = self.predict_proba_for_row(x, i)?;
|
||||
for (j, &prob) in probs.iter().enumerate() {
|
||||
result.set((i, j), prob);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Predict class probabilities for a single input sample.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The input matrix containing all samples.
|
||||
/// * `row` - The index of the row in `x` for which to predict probabilities.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of probabilities, one for each class, representing the probability
|
||||
/// of the input sample belonging to each class.
|
||||
fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
|
||||
let mut node = 0;
|
||||
|
||||
while let Some(current_node) = self.nodes().get(node) {
|
||||
if current_node.true_child.is_none() && current_node.false_child.is_none() {
|
||||
// Leaf node reached
|
||||
let mut probs = vec![0.0; self.classes().len()];
|
||||
probs[current_node.output] = 1.0;
|
||||
return Ok(probs);
|
||||
}
|
||||
|
||||
let split_feature = current_node.split_feature;
|
||||
let split_value = current_node.split_value.unwrap_or(f64::NAN);
|
||||
|
||||
if x.get((row, split_feature)).to_f64().unwrap() <= split_value {
|
||||
node = current_node.true_child.unwrap();
|
||||
} else {
|
||||
node = current_node.false_child.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
// This should never happen if the tree is properly constructed
|
||||
Err(Failed::predict("Nodes iteration did not reach leaf"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
@@ -899,17 +991,62 @@ mod tests {
|
||||
)]
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
assert!((impurity(&SplitCriterion::Gini, &[7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
|
||||
assert!((impurity(&SplitCriterion::Gini, &[7, 3], 10) - 0.42).abs() < f64::EPSILON);
|
||||
assert!(
|
||||
(impurity(&SplitCriterion::Entropy, &[7, 3], 10) - 0.8812908992306927).abs()
|
||||
< std::f64::EPSILON
|
||||
< f64::EPSILON
|
||||
);
|
||||
assert!(
|
||||
(impurity(&SplitCriterion::ClassificationError, &[7, 3], 10) - 0.3).abs()
|
||||
< std::f64::EPSILON
|
||||
< f64::EPSILON
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_predict_proba() {
    // Ten samples, four features; the first five are labelled 0, the last five 1.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
    ])
    .unwrap();
    let y: Vec<usize> = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
    let probabilities = tree.predict_proba(&x).unwrap();

    // One row per sample, one column per class.
    assert_eq!(probabilities.shape(), (10, 2));

    // Every row must be a probability distribution.
    (0..10).for_each(|row| {
        let row_sum: f64 = probabilities.get_row(row).sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-6,
            "Row probabilities should sum to 1"
        );
    });

    // Each sample must give its own training label a strictly higher probability
    // than the other class (class 0 for rows 0..5, class 1 for rows 5..10).
    for (i, &label) in y.iter().enumerate() {
        let other = 1 - label;
        assert!(probabilities.get((i, label)) > probabilities.get((i, other)));
    }
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -938,7 +1075,8 @@ mod tests {
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
assert_eq!(
|
||||
@@ -1005,7 +1143,8 @@ mod tests {
|
||||
&[0., 0., 1., 1.],
|
||||
&[0., 0., 0., 0.],
|
||||
&[0., 0., 0., 1.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
|
||||
|
||||
assert_eq!(
|
||||
@@ -1016,6 +1155,43 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn test_compute_feature_importances() {
    // Twenty binary samples with four features each.
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1., 1., 1., 0.],
        &[1., 1., 1., 0.],
        &[1., 1., 1., 1.],
        &[1., 1., 0., 0.],
        &[1., 1., 0., 1.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 1.],
        &[1., 0., 0., 0.],
        &[1., 0., 0., 1.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 1.],
        &[0., 1., 0., 0.],
        &[0., 1., 0., 1.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 1.],
        &[0., 0., 0., 0.],
        &[0., 0., 0., 1.],
    ])
    .unwrap();
    let y: Vec<u32> = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();

    // (normalize flag, expected importance per feature)
    let cases: [(bool, [f64; 4]); 2] = [
        (false, [0., 0., 0.21333333333333332, 0.26666666666666666]),
        (true, [0., 0., 0.4444444444444444, 0.5555555555555556]),
    ];

    for (normalize, expected) in cases.iter() {
        let actual = tree.compute_feature_importances(*normalize);
        assert_eq!(actual.len(), expected.len());

        // The values come out of floating-point arithmetic, so compare with a
        // tolerance instead of exact equality.
        for (feature, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < 1e-10,
                "importance of feature {feature} (normalize = {normalize}) is {a}, expected {e}"
            );
        }
    }
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -1044,7 +1220,8 @@ mod tests {
|
||||
&[0., 0., 1., 1.],
|
||||
&[0., 0., 0., 0.],
|
||||
&[0., 0., 0., 1.],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y = vec![1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0];
|
||||
|
||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
@@ -39,7 +39,7 @@
|
||||
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
//! ]);
|
||||
//! ]).unwrap();
|
||||
//! let y: Vec<f64> = vec![
|
||||
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0,
|
||||
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
|
||||
@@ -311,15 +311,15 @@ impl Node {
|
||||
|
||||
impl PartialEq for Node {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(self.output - other.output).abs() < std::f64::EPSILON
|
||||
(self.output - other.output).abs() < f64::EPSILON
|
||||
&& self.split_feature == other.split_feature
|
||||
&& match (self.split_value, other.split_value) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
|
||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
&& match (self.split_score, other.split_score) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < std::f64::EPSILON,
|
||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
@@ -478,7 +478,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth() < tree.parameters().max_depth.unwrap_or(std::u16::MAX) {
|
||||
while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
@@ -515,7 +515,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
if node.true_child.is_none() && node.false_child.is_none() {
|
||||
result = node.output;
|
||||
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
||||
<= node.split_value.unwrap_or(std::f64::NAN)
|
||||
<= node.split_value.unwrap_or(f64::NAN)
|
||||
{
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
@@ -640,9 +640,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
.get((i, self.nodes()[visitor.node].split_feature))
|
||||
.to_f64()
|
||||
.unwrap()
|
||||
<= self.nodes()[visitor.node]
|
||||
.split_value
|
||||
.unwrap_or(std::f64::NAN)
|
||||
<= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
|
||||
{
|
||||
*true_sample = visitor.samples[i];
|
||||
tc += *true_sample;
|
||||
@@ -753,7 +751,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
@@ -767,7 +766,7 @@ mod tests {
|
||||
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
||||
}
|
||||
|
||||
let expected_y = vec![
|
||||
let expected_y = [
|
||||
87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85,
|
||||
114.85, 114.85, 114.85,
|
||||
];
|
||||
@@ -788,7 +787,7 @@ mod tests {
|
||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||
}
|
||||
|
||||
let expected_y = vec![
|
||||
let expected_y = [
|
||||
83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4,
|
||||
113.4, 116.30, 116.30,
|
||||
];
|
||||
@@ -834,7 +833,8 @@ mod tests {
|
||||
&[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
]);
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![
|
||||
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
114.2, 115.7, 116.9,
|
||||
|
||||
Reference in New Issue
Block a user