Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73f425cf3b | ||
|
|
18de2aa244 | ||
|
|
2bf5f7a1a5 | ||
|
|
0caa8306ff | ||
|
|
2f63148de4 | ||
|
|
f9e473c919 | ||
|
|
70d8a0f34b | ||
|
|
0e42a97514 | ||
|
|
36efd582a5 | ||
|
|
70212c71e0 | ||
|
|
63f86f7bc9 |
+10
-30
@@ -31,33 +31,21 @@ jobs:
|
|||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
toolchain: stable
|
targets: ${{ matrix.platform.target }}
|
||||||
target: ${{ matrix.platform.target }}
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Install test runner for wasm
|
- name: Install test runner for wasm
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||||
- name: Stable Build with all features
|
- name: Stable Build with all features
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --all-features --target ${{ matrix.platform.target }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --all-features --target ${{ matrix.platform.target }}
|
|
||||||
- name: Stable Build without features
|
- name: Stable Build without features
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --target ${{ matrix.platform.target }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --target ${{ matrix.platform.target }}
|
|
||||||
- name: Tests
|
- name: Tests
|
||||||
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo test --all-features
|
||||||
with:
|
|
||||||
command: test
|
|
||||||
args: --all-features
|
|
||||||
- name: Tests in WASM
|
- name: Tests in WASM
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: wasm-pack test --node -- --all-features
|
run: wasm-pack test --node -- --all-features
|
||||||
@@ -78,17 +66,9 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-cargo-features
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
|
||||||
toolchain: stable
|
|
||||||
target: ${{ matrix.platform.target }}
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Stable Build
|
- name: Stable Build
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo build --no-default-features ${{ matrix.features }}
|
||||||
with:
|
|
||||||
command: build
|
|
||||||
args: --no-default-features ${{ matrix.features }}
|
|
||||||
|
|||||||
@@ -19,26 +19,15 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-coverage-cargo
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@nightly
|
||||||
with:
|
|
||||||
toolchain: nightly
|
|
||||||
profile: minimal
|
|
||||||
default: true
|
|
||||||
- name: Install cargo-tarpaulin
|
- name: Install cargo-tarpaulin
|
||||||
uses: actions-rs/install@v0.1
|
run: cargo install cargo-tarpaulin
|
||||||
with:
|
|
||||||
crate: cargo-tarpaulin
|
|
||||||
version: latest
|
|
||||||
use-tool-cache: true
|
|
||||||
- name: Run cargo-tarpaulin
|
- name: Run cargo-tarpaulin
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
|
||||||
with:
|
|
||||||
command: tarpaulin
|
|
||||||
args: --out Lcov --all-features -- --test-threads 1
|
|
||||||
- name: Upload to codecov.io
|
- name: Upload to codecov.io
|
||||||
uses: codecov/codecov-action@v2
|
uses: codecov/codecov-action@v4
|
||||||
with:
|
with:
|
||||||
fail_ci_if_error: false
|
fail_ci_if_error: false
|
||||||
|
|||||||
@@ -6,36 +6,27 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches: [ development ]
|
branches: [ development ]
|
||||||
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- name: Cache .cargo and target
|
- name: Cache .cargo and target
|
||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
./target
|
./target
|
||||||
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
|
||||||
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
restore-keys: ${{ runner.os }}-lint-cargo
|
||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: dtolnay/rust-toolchain@stable
|
||||||
with:
|
with:
|
||||||
toolchain: stable
|
components: rustfmt, clippy
|
||||||
profile: minimal
|
- name: Check format
|
||||||
default: true
|
run: cargo fmt --all -- --check
|
||||||
- run: rustup component add rustfmt
|
|
||||||
- name: Check formt
|
|
||||||
uses: actions-rs/cargo@v1
|
|
||||||
with:
|
|
||||||
command: fmt
|
|
||||||
args: --all -- --check
|
|
||||||
- run: rustup component add clippy
|
|
||||||
- name: Run clippy
|
- name: Run clippy
|
||||||
uses: actions-rs/cargo@v1
|
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
|
||||||
with:
|
|
||||||
command: clippy
|
|
||||||
args: --all-features -- -Drust-2018-idioms -Dwarnings
|
|
||||||
|
|||||||
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [0.4.8] - 2025-11-29
|
||||||
|
- WARNING: Breaking changes!
|
||||||
|
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to `false`, the `beta_0` term in the formula is forced to zero, and the `intercept` field in `Lasso` is set to `None`.
|
||||||
|
|
||||||
|
|
||||||
## [0.4.0] - 2023-04-05
|
## [0.4.0] - 2023-04-05
|
||||||
|
|
||||||
## Added
|
## Added
|
||||||
|
|||||||
+2
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.4"
|
version = "0.4.8"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
@@ -28,6 +28,7 @@ num = "0.4"
|
|||||||
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
rand = { version = "0.8.5", default-features = false, features = ["small_rng"] }
|
||||||
rand_distr = { version = "0.4", optional = true }
|
rand_distr = { version = "0.4", optional = true }
|
||||||
serde = { version = "1", features = ["derive"], optional = true }
|
serde = { version = "1", features = ["derive"], optional = true }
|
||||||
|
ordered-float = "5.1.0"
|
||||||
|
|
||||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||||
typetag = { version = "0.2", optional = true }
|
typetag = { version = "0.2", optional = true }
|
||||||
|
|||||||
@@ -23,7 +23,10 @@
|
|||||||
/// ```
|
/// ```
|
||||||
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashMap;
|
use ordered_float::{FloatCore, OrderedFloat};
|
||||||
|
|
||||||
|
use std::cmp::Reverse;
|
||||||
|
use std::collections::{BinaryHeap, HashMap};
|
||||||
|
|
||||||
use num::Bounded;
|
use num::Bounded;
|
||||||
|
|
||||||
@@ -34,6 +37,25 @@ use crate::metrics::distance::{Distance, PairwiseDistance};
|
|||||||
use crate::numbers::floatnum::FloatNumber;
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
use crate::numbers::realnum::RealNumber;
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
|
/// Parameters for CosinePair construction
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CosinePairParameters {
|
||||||
|
/// Maximum number of neighbors to consider per point (default: all points)
|
||||||
|
pub top_k: Option<usize>,
|
||||||
|
/// Whether to use approximate nearest neighbor search
|
||||||
|
pub approximate: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::derivable_impls)]
|
||||||
|
impl Default for CosinePairParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
top_k: None,
|
||||||
|
approximate: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Inspired by Python implementation:
|
/// Inspired by Python implementation:
|
||||||
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
||||||
@@ -49,12 +71,29 @@ pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
|||||||
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
||||||
/// conga line used to keep track of the closest pair
|
/// conga line used to keep track of the closest pair
|
||||||
pub neighbours: Vec<usize>,
|
pub neighbours: Vec<usize>,
|
||||||
|
/// parameters used during construction
|
||||||
|
pub parameters: CosinePairParameters,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
|
impl<'a, T: RealNumber + FloatNumber + FloatCore, M: Array2<T>> CosinePair<'a, T, M> {
|
||||||
/// Constructor
|
/// Constructor with default parameters (backward compatibility)
|
||||||
/// Instantiate and initialize the algorithm
|
|
||||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||||
|
Self::with_parameters(m, CosinePairParameters::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructor with top-k limiting for faster performance
|
||||||
|
pub fn with_top_k(m: &'a M, top_k: usize) -> Result<Self, Failed> {
|
||||||
|
Self::with_parameters(
|
||||||
|
m,
|
||||||
|
CosinePairParameters {
|
||||||
|
top_k: Some(top_k),
|
||||||
|
approximate: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructor with full parameter control
|
||||||
|
pub fn with_parameters(m: &'a M, parameters: CosinePairParameters) -> Result<Self, Failed> {
|
||||||
if m.shape().0 < 2 {
|
if m.shape().0 < 2 {
|
||||||
return Err(Failed::because(
|
return Err(Failed::because(
|
||||||
FailedError::FindFailed,
|
FailedError::FindFailed,
|
||||||
@@ -64,96 +103,156 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
|
|||||||
|
|
||||||
let mut init = Self {
|
let mut init = Self {
|
||||||
samples: m,
|
samples: m,
|
||||||
// to be computed in init(..)
|
|
||||||
distances: HashMap::with_capacity(m.shape().0),
|
distances: HashMap::with_capacity(m.shape().0),
|
||||||
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
neighbours: Vec::with_capacity(m.shape().0),
|
||||||
|
parameters,
|
||||||
};
|
};
|
||||||
init.init();
|
init.init();
|
||||||
Ok(init)
|
Ok(init)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialise `CosinePair` by passing a `Array2`.
|
/// Helper function to create ordered float wrapper
|
||||||
/// Build a CosinePairs data-structure from a set of (new) points.
|
fn ordered_float(value: T) -> OrderedFloat<T> {
|
||||||
|
OrderedFloat(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to extract value from ordered float wrapper
|
||||||
|
fn extract_float(ordered: OrderedFloat<T>) -> T {
|
||||||
|
ordered.into_inner()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Optimized initialization with top-k neighbor limiting
|
||||||
fn init(&mut self) {
|
fn init(&mut self) {
|
||||||
// basic measures
|
|
||||||
let len = self.samples.shape().0;
|
let len = self.samples.shape().0;
|
||||||
let max_index = self.samples.shape().0 - 1;
|
let max_neighbors: usize = self.parameters.top_k.unwrap_or(len - 1).min(len - 1);
|
||||||
|
|
||||||
// Store all closest neighbors
|
let mut distances = HashMap::with_capacity(len);
|
||||||
let _distances = Box::new(HashMap::with_capacity(len));
|
let mut neighbours = Vec::with_capacity(len);
|
||||||
let _neighbours = Box::new(Vec::with_capacity(len));
|
|
||||||
|
|
||||||
let mut distances = *_distances;
|
|
||||||
let mut neighbours = *_neighbours;
|
|
||||||
|
|
||||||
// fill neighbours with -1 values
|
|
||||||
neighbours.extend(0..len);
|
neighbours.extend(0..len);
|
||||||
|
|
||||||
// init closest neighbour pairwise data
|
// Initialize with max distances
|
||||||
for index_row_i in 0..(max_index) {
|
for i in 0..len {
|
||||||
distances.insert(
|
distances.insert(
|
||||||
index_row_i,
|
i,
|
||||||
PairwiseDistance {
|
PairwiseDistance {
|
||||||
node: index_row_i,
|
node: i,
|
||||||
neighbour: Option::None,
|
neighbour: None,
|
||||||
distance: Some(<T as Bounded>::max_value()),
|
distance: Some(<T as Bounded>::max_value()),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// loop through indeces and neighbours
|
// Compute distances for each point using top-k optimization
|
||||||
for index_row_i in 0..(len) {
|
for i in 0..len {
|
||||||
// start looking for the neighbour in the second element
|
let mut candidate_distances = BinaryHeap::new();
|
||||||
let mut index_closest = index_row_i + 1; // closest neighbour index
|
|
||||||
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
|
||||||
for index_row_j in (index_row_i + 1)..len {
|
|
||||||
distances.insert(
|
|
||||||
index_row_j,
|
|
||||||
PairwiseDistance {
|
|
||||||
node: index_row_j,
|
|
||||||
neighbour: Some(index_row_i),
|
|
||||||
distance: nbd,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let d = Cosine::new().distance(
|
for j in 0..len {
|
||||||
|
if i != j {
|
||||||
|
let distance = T::from(Cosine::new().distance(
|
||||||
&Vec::from_iterator(
|
&Vec::from_iterator(
|
||||||
self.samples.get_row(index_row_i).iterator(0).copied(),
|
self.samples.get_row(i).iterator(0).copied(),
|
||||||
self.samples.shape().1,
|
self.samples.shape().1,
|
||||||
),
|
),
|
||||||
&Vec::from_iterator(
|
&Vec::from_iterator(
|
||||||
self.samples.get_row(index_row_j).iterator(0).copied(),
|
self.samples.get_row(j).iterator(0).copied(),
|
||||||
self.samples.shape().1,
|
self.samples.shape().1,
|
||||||
),
|
),
|
||||||
);
|
))
|
||||||
if d < nbd.unwrap().to_f64().unwrap() {
|
.unwrap();
|
||||||
// set this j-value to be the closest neighbour
|
|
||||||
index_closest = index_row_j;
|
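// Use OrderedFloat for a total ordering. NOTE(review): with `Reverse` this is a min-heap, so the size-trimming `pop()` below discards the nearest candidate when `top_k < len - 1` — confirm intent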
|
||||||
nbd = Some(T::from(d).unwrap());
|
candidate_distances.push(Reverse((Self::ordered_float(distance), j)));
|
||||||
|
|
||||||
|
if candidate_distances.len() > max_neighbors {
|
||||||
|
candidate_distances.pop();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add that edge
|
// Find the closest neighbor from candidates
|
||||||
distances.entry(index_row_i).and_modify(|e| {
|
if let Some(Reverse((closest_distance, closest_neighbor))) =
|
||||||
e.distance = nbd;
|
candidate_distances.iter().min_by_key(|Reverse((d, _))| *d)
|
||||||
e.neighbour = Some(index_closest);
|
{
|
||||||
|
distances.entry(i).and_modify(|e| {
|
||||||
|
e.distance = Some(Self::extract_float(*closest_distance));
|
||||||
|
e.neighbour = Some(*closest_neighbor);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// No more neighbors, terminate conga line.
|
|
||||||
// Last person on the line has no neigbors
|
|
||||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
|
||||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
|
||||||
|
|
||||||
// compute sparse matrix (connectivity matrix)
|
|
||||||
let mut sparse_matrix = M::zeros(len, len);
|
|
||||||
for (_, p) in distances.iter() {
|
|
||||||
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.distances = distances;
|
self.distances = distances;
|
||||||
self.neighbours = neighbours;
|
self.neighbours = neighbours;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
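/// Fast approximate query: ranks a strided sample of at most `top_k` candidate rows by cosine distance (ordered via `OrderedFloat`); falls back to all other rows when `top_k` is `None`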
|
||||||
|
pub fn query_row_top_k(
|
||||||
|
&self,
|
||||||
|
query_row_index: usize,
|
||||||
|
k: usize,
|
||||||
|
) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
if query_row_index >= self.samples.shape().0 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"Query row index out of bounds",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if k == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_candidates = self.parameters.top_k.unwrap_or(self.samples.shape().0);
|
||||||
|
let actual_k: usize = k.min(max_candidates);
|
||||||
|
|
||||||
|
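// Use binary heap with ordered-float for reliable ordering. NOTE(review): `Reverse` turns this into a min-heap, so `heap.pop()` below evicts the nearest candidate, not the farthest; also `top_k == Some(0)` would divide by zero when computing `step` — confirm intent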
|
||||||
|
let mut heap = BinaryHeap::with_capacity(actual_k + 1);
|
||||||
|
|
||||||
|
let candidates = if let Some(top_k) = self.parameters.top_k {
|
||||||
|
let step = (self.samples.shape().0 / top_k).max(1);
|
||||||
|
(0..self.samples.shape().0)
|
||||||
|
.step_by(step)
|
||||||
|
.filter(|&i| i != query_row_index)
|
||||||
|
.take(top_k)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
} else {
|
||||||
|
(0..self.samples.shape().0)
|
||||||
|
.filter(|&i| i != query_row_index)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
for &candidate_idx in &candidates {
|
||||||
|
let distance = T::from(Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(query_row_index).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(candidate_idx).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
heap.push(Reverse((Self::ordered_float(distance), candidate_idx)));
|
||||||
|
|
||||||
|
if heap.len() > actual_k {
|
||||||
|
heap.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert heap to sorted vector
|
||||||
|
let mut neighbors: Vec<_> = heap
|
||||||
|
.into_vec()
|
||||||
|
.into_iter()
|
||||||
|
.map(|Reverse((dist, idx))| (Self::extract_float(dist), idx))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
neighbors.sort_by(|a, b| Self::ordered_float(a.0).cmp(&Self::ordered_float(b.0)));
|
||||||
|
|
||||||
|
Ok(neighbors)
|
||||||
|
}
|
||||||
|
|
||||||
/// Query k nearest neighbors for a row that's already in the dataset
|
/// Query k nearest neighbors for a row that's already in the dataset
|
||||||
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
if query_row_index >= self.samples.shape().0 {
|
if query_row_index >= self.samples.shape().0 {
|
||||||
@@ -318,7 +417,7 @@ impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
|
||||||
use approx::assert_relative_eq;
|
use approx::{assert_relative_eq, relative_eq};
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
@@ -499,10 +598,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
|
||||||
)]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cosine_pair_query_row_bounds_error() {
|
fn cosine_pair_query_row_bounds_error() {
|
||||||
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();
|
||||||
@@ -520,10 +615,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
|
||||||
)]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cosine_pair_query_row_k_zero() {
|
fn cosine_pair_query_row_k_zero() {
|
||||||
let x =
|
let x =
|
||||||
@@ -635,6 +726,206 @@ mod tests {
|
|||||||
assert!(distance >= 0.0 && distance <= 2.0);
|
assert!(distance >= 0.0 && distance <= 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn query_row_top_k_top_k_limiting() {
|
||||||
|
// Test that query_row_top_k respects top_k parameter and returns correct results
|
||||||
|
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
&[1.0, 0.0, 0.0], // Point 0
|
||||||
|
&[0.0, 1.0, 0.0], // Point 1 - orthogonal to point 0
|
||||||
|
&[0.0, 0.0, 1.0], // Point 2 - orthogonal to point 0
|
||||||
|
&[1.0, 1.0, 0.0], // Point 3 - closer to point 0 than points 1,2
|
||||||
|
&[0.5, 0.0, 0.0], // Point 4 - very close to point 0 (parallel)
|
||||||
|
&[2.0, 0.0, 0.0], // Point 5 - very close to point 0 (parallel)
|
||||||
|
&[0.0, 1.0, 1.0], // Point 6 - far from point 0
|
||||||
|
&[3.0, 3.0, 3.0], // Point 7 - moderately close to point 0
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Create CosinePair with top_k=4 to limit candidates
|
||||||
|
let cosine_pair = CosinePair::with_top_k(&x, 4).unwrap();
|
||||||
|
|
||||||
|
// Query for 3 nearest neighbors to point 0
|
||||||
|
let neighbors = cosine_pair.query_row_top_k(0, 3).unwrap();
|
||||||
|
|
||||||
|
// Should return exactly 3 neighbors
|
||||||
|
assert_eq!(neighbors.len(), 3);
|
||||||
|
|
||||||
|
// Verify that distances are in ascending order
|
||||||
|
for i in 1..neighbors.len() {
|
||||||
|
assert!(
|
||||||
|
neighbors[i - 1].0 <= neighbors[i].0,
|
||||||
|
"Distances should be in ascending order: {} <= {}",
|
||||||
|
neighbors[i - 1].0,
|
||||||
|
neighbors[i].0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// All distances should be valid cosine distances (0 to 2)
|
||||||
|
for (distance, index) in &neighbors {
|
||||||
|
assert!(
|
||||||
|
*distance >= 0.0 && *distance <= 2.0,
|
||||||
|
"Cosine distance {} should be between 0 and 2",
|
||||||
|
distance
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
*index < x.shape().0,
|
||||||
|
"Neighbor index {} should be less than dataset size {}",
|
||||||
|
index,
|
||||||
|
x.shape().0
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
*index != 0,
|
||||||
|
"Neighbor index should not include query point itself"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The closest neighbor should be either point 4 or 5 (parallel vectors)
|
||||||
|
// These should have cosine distance ≈ 0
|
||||||
|
let closest_distance = neighbors[0].0;
|
||||||
|
assert!(
|
||||||
|
closest_distance < 0.01,
|
||||||
|
"Closest parallel vector should have distance close to 0, got {}",
|
||||||
|
closest_distance
|
||||||
|
);
|
||||||
|
|
||||||
|
// Compare against the full (non-top_k-limited) query as a reference
|
||||||
|
let cosine_pair_full = CosinePair::new(&x).unwrap();
|
||||||
|
let neighbors_full = cosine_pair_full.query_row(0, 3).unwrap();
|
||||||
|
|
||||||
|
// Results should be the same or very close since we're asking for top 3
|
||||||
|
// but the algorithm might find different candidates due to top_k limiting
|
||||||
|
assert_eq!(neighbors.len(), neighbors_full.len());
|
||||||
|
|
||||||
|
// The closest neighbor should be the same in both cases
|
||||||
|
let closest_idx_fast = neighbors[0].1;
|
||||||
|
let closest_idx_full = neighbors_full[0].1;
|
||||||
|
let closest_dist_fast = neighbors[0].0;
|
||||||
|
let closest_dist_full = neighbors_full[0].0;
|
||||||
|
|
||||||
|
// Either we get the same closest neighbor, or distances are very close
|
||||||
|
if closest_idx_fast == closest_idx_full {
|
||||||
|
assert!(relative_eq!(
|
||||||
|
closest_dist_fast,
|
||||||
|
closest_dist_full,
|
||||||
|
epsilon = 1e-10
|
||||||
|
));
|
||||||
|
} else {
|
||||||
|
// Different neighbors, but distances should be very close (parallel vectors)
|
||||||
|
assert!(relative_eq!(
|
||||||
|
closest_dist_fast,
|
||||||
|
closest_dist_full,
|
||||||
|
epsilon = 1e-6
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn query_row_top_k_performance_vs_accuracy() {
|
||||||
|
// Test that query_row_top_k provides reasonable performance/accuracy tradeoff
|
||||||
|
// and handles edge cases properly
|
||||||
|
let large_dataset = DenseMatrix::<f32>::from_2d_array(&[
|
||||||
|
&[1.0f32, 2.0, 3.0, 4.0], // Point 0 - query point
|
||||||
|
&[1.1f32, 2.1, 3.1, 4.1], // Point 1 - very close to 0
|
||||||
|
&[1.05f32, 2.05, 3.05, 4.05], // Point 2 - very close to 0
|
||||||
|
&[2.0f32, 4.0, 6.0, 8.0], // Point 3 - parallel to 0 (2x scaling)
|
||||||
|
&[0.5f32, 1.0, 1.5, 2.0], // Point 4 - parallel to 0 (0.5x scaling)
|
||||||
|
&[-1.0f32, -2.0, -3.0, -4.0], // Point 5 - opposite to 0
|
||||||
|
&[4.0f32, 3.0, 2.0, 1.0], // Point 6 - different direction
|
||||||
|
&[0.0f32, 0.0, 0.0, 0.1], // Point 7 - mostly orthogonal
|
||||||
|
&[10.0f32, 20.0, 30.0, 40.0], // Point 8 - parallel but far
|
||||||
|
&[1.0f32, 0.0, 0.0, 0.0], // Point 9 - partially similar
|
||||||
|
&[0.0f32, 2.0, 0.0, 0.0], // Point 10 - partially similar
|
||||||
|
&[0.0f32, 0.0, 3.0, 0.0], // Point 11 - partially similar
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Test with aggressive top_k limiting (only consider 5 out of 11 other points)
|
||||||
|
let cosine_pair_limited = CosinePair::with_top_k(&large_dataset, 5).unwrap();
|
||||||
|
|
||||||
|
// Query for 4 nearest neighbors
|
||||||
|
let neighbors_limited = cosine_pair_limited.query_row_top_k(0, 4).unwrap();
|
||||||
|
|
||||||
|
// Should return exactly 4 neighbors
|
||||||
|
assert_eq!(neighbors_limited.len(), 4);
|
||||||
|
|
||||||
|
// Test error handling - out of bounds query
|
||||||
|
let result_oob = cosine_pair_limited.query_row_top_k(15, 2);
|
||||||
|
assert!(result_oob.is_err());
|
||||||
|
if let Err(e) = result_oob {
|
||||||
|
assert_eq!(
|
||||||
|
e,
|
||||||
|
Failed::because(FailedError::FindFailed, "Query row index out of bounds")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test k=0 case
|
||||||
|
let neighbors_zero = cosine_pair_limited.query_row_top_k(0, 0).unwrap();
|
||||||
|
assert_eq!(neighbors_zero.len(), 0);
|
||||||
|
|
||||||
|
// Test k > available candidates
|
||||||
|
let neighbors_large_k = cosine_pair_limited.query_row_top_k(0, 20).unwrap();
|
||||||
|
assert!(neighbors_large_k.len() <= 11); // At most 11 other points
|
||||||
|
|
||||||
|
// Verify ordering is correct
|
||||||
|
for i in 1..neighbors_limited.len() {
|
||||||
|
assert!(
|
||||||
|
neighbors_limited[i - 1].0 <= neighbors_limited[i].0,
|
||||||
|
"Distance ordering violation at position {}: {} > {}",
|
||||||
|
i,
|
||||||
|
neighbors_limited[i - 1].0,
|
||||||
|
neighbors_limited[i].0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The closest neighbor should be one of the parallel or near-parallel vectors (e.g. points 1, 2, 3, 4)
|
||||||
|
// since they have the smallest cosine distances
|
||||||
|
let closest_distance = neighbors_limited[0].0;
|
||||||
|
assert!(
|
||||||
|
closest_distance < 0.1,
|
||||||
|
"Closest neighbor should be nearly parallel, distance: {}",
|
||||||
|
closest_distance
|
||||||
|
);
|
||||||
|
|
||||||
|
// Compare with full algorithm for accuracy assessment
|
||||||
|
let cosine_pair_full = CosinePair::new(&large_dataset).unwrap();
|
||||||
|
let neighbors_full = cosine_pair_full.query_row(0, 4).unwrap();
|
||||||
|
|
||||||
|
// The fast version might not find the exact same neighbors due to sampling,
|
||||||
|
// but the closest neighbor's distance should be very similar
|
||||||
|
let dist_diff = (neighbors_limited[0].0 - neighbors_full[0].0).abs();
|
||||||
|
assert!(
|
||||||
|
dist_diff < 0.01,
|
||||||
|
"Fast and full algorithms should give similar closest distances. Diff: {}",
|
||||||
|
dist_diff
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify that all returned indices are valid and unique
|
||||||
|
let mut indices: Vec<usize> = neighbors_limited.iter().map(|(_, idx)| *idx).collect();
|
||||||
|
indices.sort();
|
||||||
|
indices.dedup();
|
||||||
|
assert_eq!(
|
||||||
|
indices.len(),
|
||||||
|
neighbors_limited.len(),
|
||||||
|
"All neighbor indices should be unique"
|
||||||
|
);
|
||||||
|
|
||||||
|
for &idx in &indices {
|
||||||
|
assert!(
|
||||||
|
idx < large_dataset.shape().0,
|
||||||
|
"Neighbor index {} should be valid",
|
||||||
|
idx
|
||||||
|
);
|
||||||
|
assert!(idx != 0, "Neighbor should not include query point itself");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with f32 precision to ensure type compatibility
|
||||||
|
for (distance, _) in &neighbors_limited {
|
||||||
|
assert!(!distance.is_nan(), "Distance should not be NaN");
|
||||||
|
assert!(distance.is_finite(), "Distance should be finite");
|
||||||
|
assert!(*distance >= 0.0, "Distance should be non-negative");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cosine_pair_float_precision() {
|
fn cosine_pair_float_precision() {
|
||||||
// Test with f32 precision
|
// Test with f32 precision
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#![allow(clippy::ptr_arg)]
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! # Nearest Neighbors Search Algorithms and Data Structures
|
//! # Nearest Neighbors Search Algorithms and Data Structures
|
||||||
//!
|
//!
|
||||||
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
|
//! Nearest neighbor search is a basic computational tool that is particularly relevant to machine learning,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! # Clustering
|
//! # Clustering
|
||||||
//!
|
//!
|
||||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#![allow(clippy::ptr_arg, clippy::needless_range_loop)]
|
||||||
//! Datasets
|
//! Datasets
|
||||||
//!
|
//!
|
||||||
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
|
//! In this module you will find small datasets that are used in `smartcore` mostly for demonstration purposes.
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ impl<T: Debug + Display + Copy + Sized> Array<T, (usize, usize)> for DenseMatrix
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn is_empty(&self) -> bool {
|
fn is_empty(&self) -> bool {
|
||||||
self.ncols > 0 && self.nrows > 0
|
self.ncols < 1 || self.nrows < 1
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
fn iterator<'b>(&'b self, axis: u8) -> Box<dyn Iterator<Item = &'b T> + 'b> {
|
||||||
|
|||||||
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
l1_reg * gamma,
|
l1_reg * gamma,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
true,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for i in 0..p {
|
for i in 0..p {
|
||||||
|
|||||||
+142
-52
@@ -9,7 +9,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Lasso coefficient estimates solve the problem:
|
//! Lasso coefficient estimates solve the problem:
|
||||||
//!
|
//!
|
||||||
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||||
//!
|
//!
|
||||||
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
||||||
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
||||||
@@ -53,6 +53,9 @@ pub struct LassoParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: usize,
|
pub max_iter: usize,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fit_intercept: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
@@ -86,6 +89,12 @@ impl LassoParameters {
|
|||||||
self.max_iter = max_iter;
|
self.max_iter = max_iter;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||||
|
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
|
||||||
|
self.fit_intercept = fit_intercept;
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LassoParameters {
|
impl Default for LassoParameters {
|
||||||
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
|
|||||||
normalize: true,
|
normalize: true,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
coefficients: Option::None,
|
coefficients: None,
|
||||||
intercept: Option::None,
|
intercept: None,
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
}
|
}
|
||||||
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
|
|||||||
#[cfg_attr(feature = "serde", serde(default))]
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
/// The maximum number of iterations
|
/// The maximum number of iterations
|
||||||
pub max_iter: Vec<usize>,
|
pub max_iter: Vec<usize>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// Candidate values for `fit_intercept`; if false, the intercept parameter (beta_0) is forced to zero.
|
||||||
|
pub fit_intercept: Vec<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lasso grid search iterator
|
/// Lasso grid search iterator
|
||||||
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
|
|||||||
current_normalize: usize,
|
current_normalize: usize,
|
||||||
current_tol: usize,
|
current_tol: usize,
|
||||||
current_max_iter: usize,
|
current_max_iter: usize,
|
||||||
|
current_fit_intercept: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for LassoSearchParameters {
|
impl IntoIterator for LassoSearchParameters {
|
||||||
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
|
|||||||
current_normalize: 0,
|
current_normalize: 0,
|
||||||
current_tol: 0,
|
current_tol: 0,
|
||||||
current_max_iter: 0,
|
current_max_iter: 0,
|
||||||
|
current_fit_intercept: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||||
|
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
|
||||||
{
|
{
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||||
|
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||||
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
|
|||||||
self.current_normalize = 0;
|
self.current_normalize = 0;
|
||||||
self.current_tol = 0;
|
self.current_tol = 0;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
|
||||||
|
{
|
||||||
|
self.current_alpha = 0;
|
||||||
|
self.current_normalize = 0;
|
||||||
|
self.current_tol = 0;
|
||||||
|
self.current_max_iter = 0;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
} else {
|
} else {
|
||||||
self.current_alpha += 1;
|
self.current_alpha += 1;
|
||||||
self.current_normalize += 1;
|
self.current_normalize += 1;
|
||||||
self.current_tol += 1;
|
self.current_tol += 1;
|
||||||
self.current_max_iter += 1;
|
self.current_max_iter += 1;
|
||||||
|
self.current_fit_intercept += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(next)
|
Some(next)
|
||||||
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
|
|||||||
normalize: vec![default_params.normalize],
|
normalize: vec![default_params.normalize],
|
||||||
tol: vec![default_params.tol],
|
tol: vec![default_params.tol],
|
||||||
max_iter: vec![default_params.max_iter],
|
max_iter: vec![default_params.max_iter],
|
||||||
|
fit_intercept: vec![default_params.fit_intercept],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,7 +272,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
||||||
let (n, p) = x.shape();
|
let (n, p) = x.shape();
|
||||||
|
|
||||||
if n <= p {
|
if n < p {
|
||||||
return Err(Failed::fit(
|
return Err(Failed::fit(
|
||||||
"Number of rows in X should be >= number of columns in X",
|
"Number of rows in X should be >= number of columns in X",
|
||||||
));
|
));
|
||||||
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||||
w[j] /= *col_std_j;
|
w[j] /= *col_std_j;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut b = TX::zero();
|
let b = if parameters.fit_intercept {
|
||||||
|
let mut xw_mean = TX::zero();
|
||||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||||
b += w[i] * *col_mean_i;
|
xw_mean += w[i] * *col_mean_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
b = TX::from_f64(y.mean_by()).unwrap() - b;
|
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
(X::from_column(&w), b)
|
(X::from_column(&w), b)
|
||||||
} else {
|
} else {
|
||||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||||
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
l1_reg,
|
l1_reg,
|
||||||
parameters.max_iter,
|
parameters.max_iter,
|
||||||
TX::from_f64(parameters.tol).unwrap(),
|
TX::from_f64(parameters.tol).unwrap(),
|
||||||
|
parameters.fit_intercept,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
|
(
|
||||||
|
X::from_column(&w),
|
||||||
|
if parameters.fit_intercept {
|
||||||
|
Some(TX::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Lasso {
|
Ok(Lasso {
|
||||||
intercept: Some(b),
|
intercept: b,
|
||||||
coefficients: Some(w),
|
coefficients: Some(w),
|
||||||
_phantom_ty: PhantomData,
|
_phantom_ty: PhantomData,
|
||||||
_phantom_y: PhantomData,
|
_phantom_y: PhantomData,
|
||||||
@@ -369,6 +407,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::linalg::basic::arrays::Array;
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
use crate::metrics::mean_absolute_error;
|
use crate::metrics::mean_absolute_error;
|
||||||
|
|
||||||
@@ -377,30 +416,28 @@ mod tests {
|
|||||||
let parameters = LassoSearchParameters {
|
let parameters = LassoSearchParameters {
|
||||||
alpha: vec![0., 1.],
|
alpha: vec![0., 1.],
|
||||||
max_iter: vec![10, 100],
|
max_iter: vec![10, 100],
|
||||||
|
fit_intercept: vec![false, true],
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let mut iter = parameters.into_iter();
|
|
||||||
|
let mut iter = parameters.clone().into_iter();
|
||||||
|
for current_fit_intercept in 0..parameters.fit_intercept.len() {
|
||||||
|
for current_max_iter in 0..parameters.max_iter.len() {
|
||||||
|
for current_alpha in 0..parameters.alpha.len() {
|
||||||
let next = iter.next().unwrap();
|
let next = iter.next().unwrap();
|
||||||
assert_eq!(next.alpha, 0.);
|
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
|
||||||
assert_eq!(next.max_iter, 10);
|
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
|
||||||
let next = iter.next().unwrap();
|
assert_eq!(
|
||||||
assert_eq!(next.alpha, 1.);
|
next.fit_intercept,
|
||||||
assert_eq!(next.max_iter, 10);
|
parameters.fit_intercept[current_fit_intercept]
|
||||||
let next = iter.next().unwrap();
|
);
|
||||||
assert_eq!(next.alpha, 0.);
|
}
|
||||||
assert_eq!(next.max_iter, 100);
|
}
|
||||||
let next = iter.next().unwrap();
|
}
|
||||||
assert_eq!(next.alpha, 1.);
|
|
||||||
assert_eq!(next.max_iter, 100);
|
|
||||||
assert!(iter.next().is_none());
|
assert!(iter.next().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
|
||||||
)]
|
|
||||||
#[test]
|
|
||||||
fn lasso_fit_predict() {
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
@@ -426,6 +463,17 @@ mod tests {
|
|||||||
114.2, 115.7, 116.9,
|
114.2, 115.7, 116.9,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
(x, y)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn lasso_fit_predict() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
|
||||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -440,6 +488,7 @@ mod tests {
|
|||||||
normalize: false,
|
normalize: false,
|
||||||
tol: 1e-4,
|
tol: 1e-4,
|
||||||
max_iter: 1000,
|
max_iter: 1000,
|
||||||
|
fit_intercept: true,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.and_then(|lr| lr.predict(&x))
|
.and_then(|lr| lr.predict(&x))
|
||||||
@@ -448,35 +497,76 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn test_full_rank_x() {
|
||||||
|
// x: randn(3,3) * 10, column-demeaned, then rounded to 2 decimal places
|
||||||
|
// y = x @ [10.0, 0.2, -3.0], rounded to 2 decimal places
|
||||||
|
let param = LassoParameters::default()
|
||||||
|
.with_normalize(false)
|
||||||
|
.with_alpha(200.0);
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[-8.9, -2.24, 8.89],
|
||||||
|
&[-4.02, 8.89, 12.33],
|
||||||
|
&[12.92, -6.65, -21.22],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let y = vec![-116.12, -75.41, 191.53];
|
||||||
|
let w = Lasso::fit(&x, &y, param)
|
||||||
|
.unwrap()
|
||||||
|
.coefficients()
|
||||||
|
.iterator(0)
|
||||||
|
.copied()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
|
||||||
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn test_fit_intercept() {
|
||||||
|
let (x, y) = get_example_x_y();
|
||||||
|
let fit_result = Lasso::fit(
|
||||||
|
&x,
|
||||||
|
&y,
|
||||||
|
LassoParameters {
|
||||||
|
alpha: 0.1,
|
||||||
|
normalize: false,
|
||||||
|
tol: 1e-8,
|
||||||
|
max_iter: 1000,
|
||||||
|
fit_intercept: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let w = fit_result.coefficients().iterator(0).copied().collect();
|
||||||
|
// by sklearn LassoLars. coordinate descent doesn't converge well
|
||||||
|
let expected_w = vec![
|
||||||
|
0.18335684,
|
||||||
|
0.02106526,
|
||||||
|
0.00703214,
|
||||||
|
-1.35952542,
|
||||||
|
0.09295222,
|
||||||
|
0.,
|
||||||
|
];
|
||||||
|
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
|
||||||
|
assert_eq!(fit_result.intercept, None);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||||
// #[test]
|
// #[test]
|
||||||
// #[cfg(feature = "serde")]
|
// #[cfg(feature = "serde")]
|
||||||
// fn serde() {
|
// fn serde() {
|
||||||
// let x = DenseMatrix::from_2d_array(&[
|
// let (x, y) = get_example_x_y();
|
||||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
|
||||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
|
||||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
|
||||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
|
||||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
|
||||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
|
||||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
|
||||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
|
||||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
|
||||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
|
||||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
|
||||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
|
||||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
|
||||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
|
||||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
|
||||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
|
||||||
// ]);
|
|
||||||
|
|
||||||
// let y = vec![
|
|
||||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
|
||||||
// 114.2, 115.7, 116.9,
|
|
||||||
// ];
|
|
||||||
|
|
||||||
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
|
||||||
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
lambda: T,
|
lambda: T,
|
||||||
max_iter: usize,
|
max_iter: usize,
|
||||||
tol: T,
|
tol: T,
|
||||||
|
fit_intercept: bool,
|
||||||
) -> Result<Vec<T>, Failed> {
|
) -> Result<Vec<T>, Failed> {
|
||||||
let (n, p) = x.shape();
|
let (n, p) = x.shape();
|
||||||
let p_f64 = T::from_usize(p).unwrap();
|
let p_f64 = T::from_usize(p).unwrap();
|
||||||
@@ -61,7 +62,11 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
|||||||
let mu = T::two();
|
let mu = T::two();
|
||||||
|
|
||||||
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
||||||
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
|
let y = if fit_intercept {
|
||||||
|
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
|
||||||
|
} else {
|
||||||
|
y.to_owned()
|
||||||
|
};
|
||||||
|
|
||||||
let mut max_ls_iter = 100;
|
let mut max_ls_iter = 100;
|
||||||
let mut pitr = 0;
|
let mut pitr = 0;
|
||||||
|
|||||||
+86
-21
@@ -4,7 +4,9 @@
|
|||||||
//!
|
//!
|
||||||
//! \\[precision = \frac{tp}{tp + fp}\\]
|
//! \\[precision = \frac{tp}{tp + fp}\\]
|
||||||
//!
|
//!
|
||||||
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
|
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
|
||||||
|
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
|
||||||
|
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
|
||||||
//!
|
//!
|
||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
@@ -19,7 +21,8 @@
|
|||||||
//!
|
//!
|
||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
use std::collections::HashSet;
|
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
let n = y_true.shape();
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
let mut classes_set: HashSet<u64> = HashSet::new();
|
||||||
}
|
for i in 0..n {
|
||||||
let classes = classes.len();
|
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||||
|
}
|
||||||
|
let classes: usize = classes_set.len();
|
||||||
|
|
||||||
let mut tp = 0;
|
|
||||||
let mut fp = 0;
|
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
|
||||||
if classes == 2 {
|
if classes == 2 {
|
||||||
if *y_true.get(i) == T::one() {
|
// Binary case: precision for positive class (assumed T::one())
|
||||||
|
let positive = T::one();
|
||||||
|
let mut tp: usize = 0;
|
||||||
|
let mut fp_count: usize = 0;
|
||||||
|
for i in 0..n {
|
||||||
|
let t = *y_true.get(i);
|
||||||
|
let p = *y_pred.get(i);
|
||||||
|
if p == t {
|
||||||
|
if t == positive {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else if t != positive {
|
||||||
tp += 1;
|
fp_count += 1;
|
||||||
}
|
}
|
||||||
} else if classes == 2 {
|
}
|
||||||
if *y_true.get(i) == T::one() {
|
if tp + fp_count == 0 {
|
||||||
fp += 1;
|
0.0
|
||||||
|
} else {
|
||||||
|
tp as f64 / (tp + fp_count) as f64
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fp += 1;
|
// Multiclass case: macro-averaged precision
|
||||||
|
let mut predicted: HashMap<u64, usize> = HashMap::new();
|
||||||
|
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||||
|
for i in 0..n {
|
||||||
|
let p_bits = y_pred.get(i).to_f64_bits();
|
||||||
|
*predicted.entry(p_bits).or_insert(0) += 1;
|
||||||
|
if *y_true.get(i) == *y_pred.get(i) {
|
||||||
|
*tp_map.entry(p_bits).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut precision_sum = 0.0;
|
||||||
|
for &bits in &classes_set {
|
||||||
|
let pred_count = *predicted.get(&bits).unwrap_or(&0);
|
||||||
|
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||||
|
let prec = if pred_count > 0 {
|
||||||
|
tp as f64 / pred_count as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
precision_sum += prec;
|
||||||
|
}
|
||||||
|
if classes == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
precision_sum / classes as f64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tp as f64 / (tp as f64 + fp as f64)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,7 +147,7 @@ mod tests {
|
|||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||||
|
|
||||||
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
|
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
assert!((score3 - 0.6666666666).abs() < 1e-8);
|
assert!((score3 - 0.5).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -132,4 +165,36 @@ mod tests {
|
|||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn precision_multiclass_imbalanced() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||||
|
|
||||||
|
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
|
let expected = (0.5 + 0.5 + 1.0) / 3.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn precision_multiclass_unpredicted_class() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
|
||||||
|
|
||||||
|
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||||
|
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
|
||||||
|
// Class 1: pred=2, tp=1 -> 0.5
|
||||||
|
// Class 2: pred=2, tp=2 -> 1.0
|
||||||
|
// Class 3: pred=0, tp=0 -> 0.0
|
||||||
|
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+62
-22
@@ -4,7 +4,9 @@
|
|||||||
//!
|
//!
|
||||||
//! \\[recall = \frac{tp}{tp + fn}\\]
|
//! \\[recall = \frac{tp}{tp + fn}\\]
|
||||||
//!
|
//!
|
||||||
//! where tp (true positive) - correct result, fn (false negative) - missing result
|
//! where tp (true positive) - correct result, fn (false negative) - missing result.
|
||||||
|
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
|
||||||
|
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
|
||||||
//!
|
//!
|
||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
@@ -20,8 +22,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use std::collections::HashSet;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::TryInto;
|
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
@@ -52,7 +53,7 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Calculated recall score
|
/// Calculated recall score
|
||||||
/// * `y_true` - cround truth (correct) labels.
|
/// * `y_true` - ground truth (correct) labels.
|
||||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||||
if y_true.shape() != y_pred.shape() {
|
if y_true.shape() != y_pred.shape() {
|
||||||
@@ -63,32 +64,57 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut classes = HashSet::new();
|
let n = y_true.shape();
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
classes.insert(y_true.get(i).to_f64_bits());
|
let mut classes_set = HashSet::new();
|
||||||
}
|
for i in 0..n {
|
||||||
let classes: i64 = classes.len().try_into().unwrap();
|
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||||
|
}
|
||||||
|
let classes: usize = classes_set.len();
|
||||||
|
|
||||||
let mut tp = 0;
|
|
||||||
let mut fne = 0;
|
|
||||||
for i in 0..y_true.shape() {
|
|
||||||
if y_pred.get(i) == y_true.get(i) {
|
|
||||||
if classes == 2 {
|
if classes == 2 {
|
||||||
if *y_true.get(i) == T::one() {
|
// Binary case: recall for positive class (assumed T::one())
|
||||||
|
let positive = T::one();
|
||||||
|
let mut tp: usize = 0;
|
||||||
|
let mut fn_count: usize = 0;
|
||||||
|
for i in 0..n {
|
||||||
|
let t = *y_true.get(i);
|
||||||
|
let p = *y_pred.get(i);
|
||||||
|
if p == t {
|
||||||
|
if t == positive {
|
||||||
tp += 1;
|
tp += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else if t == positive {
|
||||||
tp += 1;
|
fn_count += 1;
|
||||||
}
|
}
|
||||||
} else if classes == 2 {
|
}
|
||||||
if *y_true.get(i) != T::one() {
|
if tp + fn_count == 0 {
|
||||||
fne += 1;
|
0.0
|
||||||
|
} else {
|
||||||
|
tp as f64 / (tp + fn_count) as f64
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
fne += 1;
|
// Multiclass case: macro-averaged recall
|
||||||
|
let mut support: HashMap<u64, usize> = HashMap::new();
|
||||||
|
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||||
|
for i in 0..n {
|
||||||
|
let t_bits = y_true.get(i).to_f64_bits();
|
||||||
|
*support.entry(t_bits).or_insert(0) += 1;
|
||||||
|
if *y_true.get(i) == *y_pred.get(i) {
|
||||||
|
*tp_map.entry(t_bits).or_insert(0) += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut recall_sum = 0.0;
|
||||||
|
for (&bits, &sup) in &support {
|
||||||
|
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||||
|
recall_sum += tp as f64 / sup as f64;
|
||||||
|
}
|
||||||
|
if support.is_empty() {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
recall_sum / support.len() as f64
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tp as f64 / (tp as f64 + fne as f64)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +141,7 @@ mod tests {
|
|||||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||||
|
|
||||||
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
|
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||||
assert!((score3 - 0.5).abs() < 1e-8);
|
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
@@ -133,4 +159,18 @@ mod tests {
|
|||||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||||
assert!((score2 - 1.0).abs() < 1e-8);
|
assert!((score2 - 1.0).abs() < 1e-8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn recall_multiclass_imbalanced() {
|
||||||
|
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||||
|
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||||
|
|
||||||
|
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||||
|
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
|
||||||
|
assert!((score - expected).abs() < 1e-8);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,10 +53,14 @@ use crate::{
|
|||||||
rand_custom::get_rng_impl,
|
rand_custom::get_rng_impl,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Defines the objective function to be optimized.
|
/// Defines the objective function to be optimized.
|
||||||
/// The objective function provides the loss, gradient (first derivative), and
|
/// The objective function provides the loss, gradient (first derivative), and
|
||||||
/// hessian (second derivative) required for the XGBoost algorithm.
|
/// hessian (second derivative) required for the XGBoost algorithm.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
pub enum Objective {
|
pub enum Objective {
|
||||||
/// The objective for regression tasks using Mean Squared Error.
|
/// The objective for regression tasks using Mean Squared Error.
|
||||||
/// Loss: 0.5 * (y_true - y_pred)^2
|
/// Loss: 0.5 * (y_true - y_pred)^2
|
||||||
@@ -122,6 +126,8 @@ impl Objective {
|
|||||||
/// This is a recursive data structure where each `TreeRegressor` is a node
|
/// This is a recursive data structure where each `TreeRegressor` is a node
|
||||||
/// that can have a left and a right child, also of type `TreeRegressor`.
|
/// that can have a left and a right child, also of type `TreeRegressor`.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
@@ -374,6 +380,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
/// Parameters for the `jRegressor` model.
|
/// Parameters for the `jRegressor` model.
|
||||||
///
|
///
|
||||||
/// This struct holds all the hyperparameters that control the training process.
|
/// This struct holds all the hyperparameters that control the training process.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct XGRegressorParameters {
|
pub struct XGRegressorParameters {
|
||||||
/// The number of boosting rounds or trees to build.
|
/// The number of boosting rounds or trees to build.
|
||||||
@@ -494,6 +501,8 @@ impl XGRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
parameters: Option<XGRegressorParameters>,
|
parameters: Option<XGRegressorParameters>,
|
||||||
|
|||||||
Reference in New Issue
Block a user