Compare commits
11 Commits
v0.4.6
...
development
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f53cb36b9d | ||
|
|
c57a4370ba | ||
|
|
78f18505b1 | ||
|
|
58a8624fa9 | ||
|
|
18de2aa244 | ||
|
|
2bf5f7a1a5 | ||
|
|
0caa8306ff | ||
|
|
2f63148de4 | ||
|
|
f9e473c919 | ||
|
|
70d8a0f34b | ||
|
|
0e42a97514 |
+10
-30
@@ -31,33 +31,21 @@ jobs:
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-${{ matrix.platform.target }}
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
target: ${{ matrix.platform.target }}
|
||||
profile: minimal
|
||||
default: true
|
||||
targets: ${{ matrix.platform.target }}
|
||||
- name: Install test runner for wasm
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||
- name: Stable Build with all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --all-features --target ${{ matrix.platform.target }}
|
||||
run: cargo build --all-features --target ${{ matrix.platform.target }}
|
||||
- name: Stable Build without features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --target ${{ matrix.platform.target }}
|
||||
run: cargo build --target ${{ matrix.platform.target }}
|
||||
- name: Tests
|
||||
if: matrix.platform.target == 'x86_64-unknown-linux-gnu' || matrix.platform.target == 'x86_64-pc-windows-msvc' || matrix.platform.target == 'aarch64-apple-darwin'
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --all-features
|
||||
run: cargo test --all-features
|
||||
- name: Tests in WASM
|
||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||
run: wasm-pack test --node -- --all-features
|
||||
@@ -78,17 +66,9 @@ jobs:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-features-${{ hashFiles('**/Cargo.toml') }}
|
||||
key: ${{ runner.os }}-cargo-features-${{ hashFiles('Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-cargo-features
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
target: ${{ matrix.platform.target }}
|
||||
profile: minimal
|
||||
default: true
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
- name: Stable Build
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --no-default-features ${{ matrix.features }}
|
||||
run: cargo build --no-default-features ${{ matrix.features }}
|
||||
|
||||
@@ -19,26 +19,15 @@ jobs:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-coverage-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
key: ${{ runner.os }}-coverage-cargo-${{ hashFiles('Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-coverage-cargo
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
profile: minimal
|
||||
default: true
|
||||
uses: dtolnay/rust-toolchain@nightly
|
||||
- name: Install cargo-tarpaulin
|
||||
uses: actions-rs/install@v0.1
|
||||
with:
|
||||
crate: cargo-tarpaulin
|
||||
version: latest
|
||||
use-tool-cache: true
|
||||
run: cargo install cargo-tarpaulin
|
||||
- name: Run cargo-tarpaulin
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: tarpaulin
|
||||
args: --out Lcov --all-features -- --test-threads 1
|
||||
run: cargo tarpaulin --out Lcov --all-features -- --test-threads 1
|
||||
- name: Upload to codecov.io
|
||||
uses: codecov/codecov-action@v2
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
fail_ci_if_error: false
|
||||
|
||||
@@ -6,36 +6,27 @@ on:
|
||||
pull_request:
|
||||
branches: [ development ]
|
||||
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
TZ: "/usr/share/zoneinfo/your/location"
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cache .cargo and target
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo
|
||||
./target
|
||||
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-lint-cargo-${{ hashFiles('**/Cargo.toml') }}
|
||||
key: ${{ runner.os }}-lint-cargo-${{ hashFiles('Cargo.toml') }}
|
||||
restore-keys: ${{ runner.os }}-lint-cargo
|
||||
- name: Install Rust toolchain
|
||||
uses: actions-rs/toolchain@v1
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
profile: minimal
|
||||
default: true
|
||||
- run: rustup component add rustfmt
|
||||
- name: Check formt
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: fmt
|
||||
args: --all -- --check
|
||||
- run: rustup component add clippy
|
||||
components: rustfmt, clippy
|
||||
- name: Check format
|
||||
run: cargo fmt --all -- --check
|
||||
- name: Run clippy
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: clippy
|
||||
args: --all-features -- -Drust-2018-idioms -Dwarnings
|
||||
run: cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings
|
||||
|
||||
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.4.8] - 2025-11-29
|
||||
- WARNING: Breaking changes!
|
||||
- `LassoParameters` and `LassoSearchParameters` have a new field `fit_intercept`. When it is set to false, the `beta_0` term in the formula will be forced to zero, and `intercept` field in `Lasso` will be set to `None`.
|
||||
|
||||
|
||||
## [0.4.0] - 2023-04-05
|
||||
|
||||
## Added
|
||||
|
||||
+1
-1
@@ -2,7 +2,7 @@
|
||||
name = "smartcore"
|
||||
description = "Machine Learning in Rust."
|
||||
homepage = "https://smartcorelib.org"
|
||||
version = "0.4.6"
|
||||
version = "0.4.9"
|
||||
authors = ["smartcore Developers"]
|
||||
edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
@@ -345,6 +345,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
l1_reg * gamma,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
true,
|
||||
)?;
|
||||
|
||||
for i in 0..p {
|
||||
@@ -371,6 +372,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
l1_reg * gamma,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
true,
|
||||
)?;
|
||||
|
||||
for i in 0..p {
|
||||
|
||||
+145
-55
@@ -9,7 +9,7 @@
|
||||
//!
|
||||
//! Lasso coefficient estimates solve the problem:
|
||||
//!
|
||||
//! \\[\underset{\beta}{minimize} \space \space \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||
//! \\[\underset{\beta}{minimize} \space \space \frac{1}{n} \sum_{i=1}^n \left( y_i - \beta_0 - \sum_{j=1}^p \beta_jx_{ij} \right)^2 + \alpha \sum_{j=1}^p \lVert \beta_j \rVert_1\\]
|
||||
//!
|
||||
//! This problem is solved with an interior-point method that is comparable to coordinate descent in solving large problems with modest accuracy,
|
||||
//! but is able to solve them with high accuracy with relatively small additional computational cost.
|
||||
@@ -53,6 +53,9 @@ pub struct LassoParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||
pub fit_intercept: bool,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -86,6 +89,12 @@ impl LassoParameters {
|
||||
self.max_iter = max_iter;
|
||||
self
|
||||
}
|
||||
|
||||
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||
pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
|
||||
self.fit_intercept = fit_intercept;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LassoParameters {
|
||||
@@ -95,6 +104,7 @@ impl Default for LassoParameters {
|
||||
normalize: true,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
fit_intercept: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,8 +128,8 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
{
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
coefficients: Option::None,
|
||||
intercept: Option::None,
|
||||
coefficients: None,
|
||||
intercept: None,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
}
|
||||
@@ -155,6 +165,9 @@ pub struct LassoSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If false, force the intercept parameter (beta_0) to be zero.
|
||||
pub fit_intercept: Vec<bool>,
|
||||
}
|
||||
|
||||
/// Lasso grid search iterator
|
||||
@@ -164,6 +177,7 @@ pub struct LassoSearchParametersIterator {
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
current_fit_intercept: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LassoSearchParameters {
|
||||
@@ -177,6 +191,7 @@ impl IntoIterator for LassoSearchParameters {
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
current_fit_intercept: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,6 +204,7 @@ impl Iterator for LassoSearchParametersIterator {
|
||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||
&& self.current_fit_intercept == self.lasso_search_parameters.fit_intercept.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
@@ -198,6 +214,7 @@ impl Iterator for LassoSearchParametersIterator {
|
||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||
fit_intercept: self.lasso_search_parameters.fit_intercept[self.current_fit_intercept],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||
@@ -214,11 +231,19 @@ impl Iterator for LassoSearchParametersIterator {
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else if self.current_fit_intercept + 1 < self.lasso_search_parameters.fit_intercept.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter = 0;
|
||||
self.current_fit_intercept += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
self.current_fit_intercept += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
@@ -234,6 +259,7 @@ impl Default for LassoSearchParameters {
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
fit_intercept: vec![default_params.fit_intercept],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -246,7 +272,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
pub fn fit(x: &X, y: &Y, parameters: LassoParameters) -> Result<Lasso<TX, TY, X, Y>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
|
||||
if n <= p {
|
||||
if n < p {
|
||||
return Err(Failed::fit(
|
||||
"Number of rows in X should be >= number of columns in X",
|
||||
));
|
||||
@@ -283,19 +309,23 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
l1_reg,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
parameters.fit_intercept,
|
||||
)?;
|
||||
|
||||
for (j, col_std_j) in col_std.iter().enumerate().take(p) {
|
||||
w[j] /= *col_std_j;
|
||||
}
|
||||
|
||||
let mut b = TX::zero();
|
||||
let b = if parameters.fit_intercept {
|
||||
let mut xw_mean = TX::zero();
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
xw_mean += w[i] * *col_mean_i;
|
||||
}
|
||||
|
||||
for (i, col_mean_i) in col_mean.iter().enumerate().take(p) {
|
||||
b += w[i] * *col_mean_i;
|
||||
}
|
||||
|
||||
b = TX::from_f64(y.mean_by()).unwrap() - b;
|
||||
Some(TX::from_f64(y.mean_by()).unwrap() - xw_mean)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
(X::from_column(&w), b)
|
||||
} else {
|
||||
let mut optimizer = InteriorPointOptimizer::new(x, p);
|
||||
@@ -306,13 +336,21 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
l1_reg,
|
||||
parameters.max_iter,
|
||||
TX::from_f64(parameters.tol).unwrap(),
|
||||
parameters.fit_intercept,
|
||||
)?;
|
||||
|
||||
(X::from_column(&w), TX::from_f64(y.mean_by()).unwrap())
|
||||
(
|
||||
X::from_column(&w),
|
||||
if parameters.fit_intercept {
|
||||
Some(TX::from_f64(y.mean_by()).unwrap())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Lasso {
|
||||
intercept: Some(b),
|
||||
intercept: b,
|
||||
coefficients: Some(w),
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_y: PhantomData,
|
||||
@@ -369,6 +407,7 @@ impl<TX: FloatNumber + RealNumber, TY: Number, X: Array2<TX>, Y: Array1<TY>> Las
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::basic::arrays::Array;
|
||||
use crate::linalg::basic::matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
@@ -377,30 +416,28 @@ mod tests {
|
||||
let parameters = LassoSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
fit_intercept: vec![false, true],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
|
||||
let mut iter = parameters.clone().into_iter();
|
||||
for current_fit_intercept in 0..parameters.fit_intercept.len() {
|
||||
for current_max_iter in 0..parameters.max_iter.len() {
|
||||
for current_alpha in 0..parameters.alpha.len() {
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, parameters.alpha[current_alpha]);
|
||||
assert_eq!(next.max_iter, parameters.max_iter[current_max_iter]);
|
||||
assert_eq!(
|
||||
next.fit_intercept,
|
||||
parameters.fit_intercept[current_fit_intercept]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
fn get_example_x_y() -> (DenseMatrix<f64>, Vec<f64>) {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
&[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
@@ -426,6 +463,17 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
(x, y)
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
let (x, y) = get_example_x_y();
|
||||
|
||||
let y_hat = Lasso::fit(&x, &y, Default::default())
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
@@ -440,6 +488,7 @@ mod tests {
|
||||
normalize: false,
|
||||
tol: 1e-4,
|
||||
max_iter: 1000,
|
||||
fit_intercept: true,
|
||||
},
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
@@ -448,35 +497,76 @@ mod tests {
|
||||
assert!(mean_absolute_error(&y_hat, &y) < 2.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_full_rank_x() {
|
||||
// x: randn(3,3) * 10, demean, then round to 2 decimal points
|
||||
// y = x @ [10.0, 0.2, -3.0], round to 2 decimal points
|
||||
let param = LassoParameters::default()
|
||||
.with_normalize(false)
|
||||
.with_alpha(200.0);
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[-8.9, -2.24, 8.89],
|
||||
&[-4.02, 8.89, 12.33],
|
||||
&[12.92, -6.65, -21.22],
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
let y = vec![-116.12, -75.41, 191.53];
|
||||
let w = Lasso::fit(&x, &y, param)
|
||||
.unwrap()
|
||||
.coefficients()
|
||||
.iterator(0)
|
||||
.copied()
|
||||
.collect();
|
||||
|
||||
let expected_w = vec![5.20289531, 0., -5.32823882]; // by coordinate descent
|
||||
assert!(mean_absolute_error(&w, &expected_w) < 1e-3); // actual mean_absolute_error is about 2e-4
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn test_fit_intercept() {
|
||||
let (x, y) = get_example_x_y();
|
||||
let fit_result = Lasso::fit(
|
||||
&x,
|
||||
&y,
|
||||
LassoParameters {
|
||||
alpha: 0.1,
|
||||
normalize: false,
|
||||
tol: 1e-8,
|
||||
max_iter: 1000,
|
||||
fit_intercept: false,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let w = fit_result.coefficients().iterator(0).copied().collect();
|
||||
// by sklearn LassoLars. coordinate descent doesn't converge well
|
||||
let expected_w = vec![
|
||||
0.18335684,
|
||||
0.02106526,
|
||||
0.00703214,
|
||||
-1.35952542,
|
||||
0.09295222,
|
||||
0.,
|
||||
];
|
||||
assert!(mean_absolute_error(&w, &expected_w) < 1e-6);
|
||||
assert_eq!(fit_result.intercept, None);
|
||||
}
|
||||
|
||||
// TODO: serialization for the new DenseMatrix needs to be implemented
|
||||
// #[cfg_attr(all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test)]
|
||||
// #[test]
|
||||
// #[cfg(feature = "serde")]
|
||||
// fn serde() {
|
||||
// let x = DenseMatrix::from_2d_array(&[
|
||||
// &[234.289, 235.6, 159.0, 107.608, 1947., 60.323],
|
||||
// &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
// &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
// &[284.599, 335.1, 165.0, 110.929, 1950., 61.187],
|
||||
// &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
// &[346.999, 193.2, 359.4, 113.270, 1952., 63.639],
|
||||
// &[365.385, 187.0, 354.7, 115.094, 1953., 64.989],
|
||||
// &[363.112, 357.8, 335.0, 116.219, 1954., 63.761],
|
||||
// &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
// &[419.180, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
// &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
// &[444.546, 468.1, 263.7, 121.950, 1958., 66.513],
|
||||
// &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
// &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
// &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
// &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||
// ]);
|
||||
|
||||
// let y = vec![
|
||||
// 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||
// 114.2, 115.7, 116.9,
|
||||
// ];
|
||||
|
||||
// let (x, y) = get_lasso_sample_x_y();
|
||||
// let lr = Lasso::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
// let deserialized_lr: Lasso<f64, f64, DenseMatrix<f64>, Vec<f64>> =
|
||||
|
||||
@@ -45,6 +45,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
lambda: T,
|
||||
max_iter: usize,
|
||||
tol: T,
|
||||
fit_intercept: bool,
|
||||
) -> Result<Vec<T>, Failed> {
|
||||
let (n, p) = x.shape();
|
||||
let p_f64 = T::from_usize(p).unwrap();
|
||||
@@ -52,6 +53,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
let lambda = lambda.max(T::epsilon());
|
||||
|
||||
//parameters
|
||||
let max_ls_iter = 100;
|
||||
let pcgmaxi = 5000;
|
||||
let min_pcgtol = T::from_f64(0.1).unwrap();
|
||||
let eta = T::from_f64(1E-3).unwrap();
|
||||
@@ -61,9 +63,12 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
let mu = T::two();
|
||||
|
||||
// let y = M::from_row_vector(y.sub_scalar(y.mean_by())).transpose();
|
||||
let y = y.sub_scalar(T::from_f64(y.mean_by()).unwrap());
|
||||
let y = if fit_intercept {
|
||||
y.sub_scalar(T::from_f64(y.mean_by()).unwrap())
|
||||
} else {
|
||||
y.to_owned()
|
||||
};
|
||||
|
||||
let mut max_ls_iter = 100;
|
||||
let mut pitr = 0;
|
||||
let mut w = Vec::zeros(p);
|
||||
let mut neww = w.clone();
|
||||
@@ -165,7 +170,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
s = T::one();
|
||||
let gdx = grad.dot(&dxu);
|
||||
|
||||
let lsiter = 0;
|
||||
let mut lsiter = 0;
|
||||
while lsiter < max_ls_iter {
|
||||
for i in 0..p {
|
||||
neww[i] = w[i] + s * dx[i];
|
||||
@@ -190,7 +195,7 @@ impl<T: FloatNumber, X: Array2<T>> InteriorPointOptimizer<T, X> {
|
||||
}
|
||||
}
|
||||
s = beta * s;
|
||||
max_ls_iter += 1;
|
||||
lsiter += 1;
|
||||
}
|
||||
|
||||
if lsiter == max_ls_iter {
|
||||
|
||||
+88
-23
@@ -4,7 +4,9 @@
|
||||
//!
|
||||
//! \\[precision = \frac{tp}{tp + fp}\\]
|
||||
//!
|
||||
//! where tp (true positive) - correct result, fp (false positive) - unexpected result
|
||||
//! where tp (true positive) - correct result, fp (false positive) - unexpected result.
|
||||
//! For binary classification, this is precision for the positive class (assumed to be 1.0).
|
||||
//! For multiclass, this is macro-averaged precision (average of per-class precisions).
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
@@ -19,7 +21,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
@@ -61,33 +64,63 @@ impl<T: RealNumber> Metrics<T> for Precision<T> {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.shape() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes = classes.len();
|
||||
let n = y_true.shape();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut fp = 0;
|
||||
for i in 0..y_true.shape() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if *y_true.get(i) == T::one() {
|
||||
let mut classes_set: HashSet<u64> = HashSet::new();
|
||||
for i in 0..n {
|
||||
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: usize = classes_set.len();
|
||||
|
||||
if classes == 2 {
|
||||
// Binary case: precision for positive class (assumed T::one())
|
||||
let positive = T::one();
|
||||
let mut tp: usize = 0;
|
||||
let mut fp_count: usize = 0;
|
||||
for i in 0..n {
|
||||
let t = *y_true.get(i);
|
||||
let p = *y_pred.get(i);
|
||||
if p == t {
|
||||
if t == positive {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if *y_true.get(i) == T::one() {
|
||||
fp += 1;
|
||||
} else if t != positive {
|
||||
fp_count += 1;
|
||||
}
|
||||
}
|
||||
if tp + fp_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
fp += 1;
|
||||
tp as f64 / (tp + fp_count) as f64
|
||||
}
|
||||
} else {
|
||||
// Multiclass case: macro-averaged precision
|
||||
let mut predicted: HashMap<u64, usize> = HashMap::new();
|
||||
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||
for i in 0..n {
|
||||
let p_bits = y_pred.get(i).to_f64_bits();
|
||||
*predicted.entry(p_bits).or_insert(0) += 1;
|
||||
if *y_true.get(i) == *y_pred.get(i) {
|
||||
*tp_map.entry(p_bits).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
let mut precision_sum = 0.0;
|
||||
for &bits in &classes_set {
|
||||
let pred_count = *predicted.get(&bits).unwrap_or(&0);
|
||||
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||
let prec = if pred_count > 0 {
|
||||
tp as f64 / pred_count as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
precision_sum += prec;
|
||||
}
|
||||
if classes == 0 {
|
||||
0.0
|
||||
} else {
|
||||
precision_sum / classes as f64
|
||||
}
|
||||
}
|
||||
|
||||
tp as f64 / (tp as f64 + fp as f64)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,7 +147,7 @@ mod tests {
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
|
||||
let score3: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||
assert!((score3 - 0.6666666666).abs() < 1e-8);
|
||||
assert!((score3 - 0.5).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
@@ -132,4 +165,36 @@ mod tests {
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn precision_multiclass_imbalanced() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||
|
||||
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||
let expected = (0.5 + 0.5 + 1.0) / 3.0;
|
||||
assert!((score - expected).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn precision_multiclass_unpredicted_class() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2., 3.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2., 0.];
|
||||
|
||||
let score: f64 = Precision::new().get_score(&y_true, &y_pred);
|
||||
// Class 0: pred=3, tp=1 -> 1/3 ≈0.333
|
||||
// Class 1: pred=2, tp=1 -> 0.5
|
||||
// Class 2: pred=2, tp=2 -> 1.0
|
||||
// Class 3: pred=0, tp=0 -> 0.0
|
||||
let expected = (1.0 / 3.0 + 0.5 + 1.0 + 0.0) / 4.0;
|
||||
assert!((score - expected).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
+64
-24
@@ -4,7 +4,9 @@
|
||||
//!
|
||||
//! \\[recall = \frac{tp}{tp + fn}\\]
|
||||
//!
|
||||
//! where tp (true positive) - correct result, fn (false negative) - missing result
|
||||
//! where tp (true positive) - correct result, fn (false negative) - missing result.
|
||||
//! For binary classification, this is recall for the positive class (assumed to be 1.0).
|
||||
//! For multiclass, this is macro-averaged recall (average of per-class recalls).
|
||||
//!
|
||||
//! Example:
|
||||
//!
|
||||
@@ -20,8 +22,7 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::convert::TryInto;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
@@ -52,7 +53,7 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
||||
}
|
||||
}
|
||||
/// Calculated recall score
|
||||
/// * `y_true` - cround truth (correct) labels.
|
||||
/// * `y_true` - ground truth (correct) labels.
|
||||
/// * `y_pred` - predicted labels, as returned by a classifier.
|
||||
fn get_score(&self, y_true: &dyn ArrayView1<T>, y_pred: &dyn ArrayView1<T>) -> f64 {
|
||||
if y_true.shape() != y_pred.shape() {
|
||||
@@ -63,32 +64,57 @@ impl<T: RealNumber> Metrics<T> for Recall<T> {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.shape() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: i64 = classes.len().try_into().unwrap();
|
||||
let n = y_true.shape();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut fne = 0;
|
||||
for i in 0..y_true.shape() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if *y_true.get(i) == T::one() {
|
||||
let mut classes_set = HashSet::new();
|
||||
for i in 0..n {
|
||||
classes_set.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: usize = classes_set.len();
|
||||
|
||||
if classes == 2 {
|
||||
// Binary case: recall for positive class (assumed T::one())
|
||||
let positive = T::one();
|
||||
let mut tp: usize = 0;
|
||||
let mut fn_count: usize = 0;
|
||||
for i in 0..n {
|
||||
let t = *y_true.get(i);
|
||||
let p = *y_pred.get(i);
|
||||
if p == t {
|
||||
if t == positive {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if *y_true.get(i) != T::one() {
|
||||
fne += 1;
|
||||
} else if t == positive {
|
||||
fn_count += 1;
|
||||
}
|
||||
}
|
||||
if tp + fn_count == 0 {
|
||||
0.0
|
||||
} else {
|
||||
fne += 1;
|
||||
tp as f64 / (tp + fn_count) as f64
|
||||
}
|
||||
} else {
|
||||
// Multiclass case: macro-averaged recall
|
||||
let mut support: HashMap<u64, usize> = HashMap::new();
|
||||
let mut tp_map: HashMap<u64, usize> = HashMap::new();
|
||||
for i in 0..n {
|
||||
let t_bits = y_true.get(i).to_f64_bits();
|
||||
*support.entry(t_bits).or_insert(0) += 1;
|
||||
if *y_true.get(i) == *y_pred.get(i) {
|
||||
*tp_map.entry(t_bits).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
let mut recall_sum = 0.0;
|
||||
for (&bits, &sup) in &support {
|
||||
let tp = *tp_map.get(&bits).unwrap_or(&0);
|
||||
recall_sum += tp as f64 / sup as f64;
|
||||
}
|
||||
if support.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
recall_sum / support.len() as f64
|
||||
}
|
||||
}
|
||||
tp as f64 / (tp as f64 + fne as f64)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +141,7 @@ mod tests {
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
|
||||
let score3: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||
assert!((score3 - 0.5).abs() < 1e-8);
|
||||
assert!((score3 - (2.0 / 3.0)).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
@@ -133,4 +159,18 @@ mod tests {
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn recall_multiclass_imbalanced() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 1., 2., 0., 2.];
|
||||
|
||||
let score: f64 = Recall::new().get_score(&y_true, &y_pred);
|
||||
let expected = (0.5 + 1.0 + (2.0 / 3.0)) / 3.0;
|
||||
assert!((score - expected).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! # K Nearest Neighbors Regressor
|
||||
//! # K Nearest Neighbors Regressor with Feature Sparsing
|
||||
//!
|
||||
//! Regressor that predicts estimated values as a function of k nearest neightbours.
|
||||
//! Now supports feature sparsing - the ability to consider only a subset of features during prediction.
|
||||
//!
|
||||
//! `KNNRegressor` relies on 2 backend algorithms to speedup KNN queries:
|
||||
//! * [`LinearSearch`](../../algorithm/neighbour/linear_search/index.html)
|
||||
@@ -29,6 +30,10 @@
|
||||
//!
|
||||
//! let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
//! let y_hat = knn.predict(&x).unwrap();
|
||||
//!
|
||||
//! // Predict using only features at indices 0
|
||||
//! let feature_indices = vec![0];
|
||||
//! let y_hat_sparse = knn.predict_sparse(&x, &feature_indices).unwrap();
|
||||
//! ```
|
||||
//!
|
||||
//! variable `y_hat` will hold predicted value
|
||||
@@ -77,12 +82,13 @@ pub struct KNNRegressorParameters<T: Number, D: Distance<Vec<T>>> {
|
||||
pub struct KNNRegressor<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
{
|
||||
y: Option<Y>,
|
||||
x: Option<X>, // Store training data for sparse feature prediction
|
||||
knn_algorithm: Option<KNNAlgorithm<TX, D>>,
|
||||
distance: Option<D>, // Store distance function for sparse prediction
|
||||
weight: Option<KNNWeightFunction>,
|
||||
k: Option<usize>,
|
||||
_phantom_tx: PhantomData<TX>,
|
||||
_phantom_ty: PhantomData<TY>,
|
||||
_phantom_x: PhantomData<X>,
|
||||
}
|
||||
|
||||
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
@@ -92,12 +98,20 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
self.y.as_ref().unwrap()
|
||||
}
|
||||
|
||||
fn x(&self) -> &X {
|
||||
self.x.as_ref().unwrap()
|
||||
}
|
||||
|
||||
fn knn_algorithm(&self) -> &KNNAlgorithm<TX, D> {
|
||||
self.knn_algorithm
|
||||
.as_ref()
|
||||
.expect("Missing parameter: KNNAlgorithm")
|
||||
}
|
||||
|
||||
fn distance(&self) -> &D {
|
||||
self.distance.as_ref().expect("Missing parameter: distance")
|
||||
}
|
||||
|
||||
fn weight(&self) -> &KNNWeightFunction {
|
||||
self.weight.as_ref().expect("Missing parameter: weight")
|
||||
}
|
||||
@@ -176,12 +190,13 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
y: Option::None,
|
||||
x: Option::None,
|
||||
knn_algorithm: Option::None,
|
||||
distance: Option::None,
|
||||
weight: Option::None,
|
||||
k: Option::None,
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,16 +246,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
)));
|
||||
}
|
||||
|
||||
let knn_algo = parameters.algorithm.fit(data, parameters.distance)?;
|
||||
let knn_algo = parameters.algorithm.fit(data, parameters.distance.clone())?;
|
||||
|
||||
Ok(KNNRegressor {
|
||||
y: Some(y.clone()),
|
||||
x: Some(x.clone()),
|
||||
k: Some(parameters.k),
|
||||
knn_algorithm: Some(knn_algo),
|
||||
distance: Some(parameters.distance),
|
||||
weight: Some(parameters.weight),
|
||||
_phantom_tx: PhantomData,
|
||||
_phantom_ty: PhantomData,
|
||||
_phantom_x: PhantomData,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -262,6 +278,45 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Predict the target for the provided data using only specified features.
|
||||
/// * `x` - data of shape NxM where N is number of data points to estimate and M is number of features.
|
||||
/// * `feature_indices` - indices of features to consider (e.g., [0, 2, 4] to use only features at positions 0, 2, and 4)
|
||||
///
|
||||
/// Returns a vector of size N with estimates.
|
||||
pub fn predict_sparse(&self, x: &X, feature_indices: &[usize]) -> Result<Y, Failed> {
|
||||
let (n_samples, n_features) = x.shape();
|
||||
|
||||
// Validate feature indices
|
||||
for &idx in feature_indices {
|
||||
if idx >= n_features {
|
||||
return Err(Failed::predict(&format!(
|
||||
"Feature index {} out of bounds (max: {})",
|
||||
idx,
|
||||
n_features - 1
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if feature_indices.is_empty() {
|
||||
return Err(Failed::predict(
|
||||
"feature_indices cannot be empty"
|
||||
));
|
||||
}
|
||||
|
||||
let mut result = Y::zeros(n_samples);
|
||||
|
||||
let mut row_vec = vec![TX::zero(); feature_indices.len()];
|
||||
for (i, row) in x.row_iter().enumerate() {
|
||||
// Extract only the specified features
|
||||
for (j, &feat_idx) in feature_indices.iter().enumerate() {
|
||||
row_vec[j] = *row.get(feat_idx);
|
||||
}
|
||||
result.set(i, self.predict_for_row_sparse(&row_vec, feature_indices)?);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_for_row(&self, row: &Vec<TX>) -> Result<TY, Failed> {
|
||||
let search_result = self.knn_algorithm().find(row, self.k.unwrap())?;
|
||||
let mut result = TY::zero();
|
||||
@@ -277,6 +332,50 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>>
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn predict_for_row_sparse(
|
||||
&self,
|
||||
row: &Vec<TX>,
|
||||
feature_indices: &[usize],
|
||||
) -> Result<TY, Failed> {
|
||||
let training_data = self.x();
|
||||
let (n_training_samples, _) = training_data.shape();
|
||||
let k = self.k.unwrap();
|
||||
|
||||
// Manually compute distances using only specified features
|
||||
let mut distances: Vec<(usize, f64)> = Vec::with_capacity(n_training_samples);
|
||||
|
||||
for i in 0..n_training_samples {
|
||||
let train_row = training_data.get_row(i);
|
||||
|
||||
// Extract sparse features from training data
|
||||
let mut train_sparse = Vec::with_capacity(feature_indices.len());
|
||||
for &feat_idx in feature_indices {
|
||||
train_sparse.push(*train_row.get(feat_idx));
|
||||
}
|
||||
|
||||
// Compute distance using only selected features
|
||||
let dist = self.distance().distance(row, &train_sparse);
|
||||
distances.push((i, dist));
|
||||
}
|
||||
|
||||
// Sort by distance and take k nearest
|
||||
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
let k_nearest: Vec<(usize, f64)> = distances.into_iter().take(k).collect();
|
||||
|
||||
// Compute weighted prediction
|
||||
let mut result = TY::zero();
|
||||
let weights = self
|
||||
.weight()
|
||||
.calc_weights(k_nearest.iter().map(|v| v.1).collect());
|
||||
let w_sum: f64 = weights.iter().copied().sum();
|
||||
|
||||
for (neighbor, w) in k_nearest.iter().zip(weights.iter()) {
|
||||
result += *self.y().get(neighbor.0) * TY::from_f64(*w / w_sum).unwrap();
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -332,6 +431,91 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn knn_predict_sparse() {
|
||||
// Training data with 3 features
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 10.],
|
||||
&[3., 4., 20.],
|
||||
&[5., 6., 30.],
|
||||
&[7., 8., 40.],
|
||||
&[9., 10., 50.],
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
// Test data
|
||||
let x_test = DenseMatrix::from_2d_array(&[
|
||||
&[1., 2., 999.], // Third feature is very different
|
||||
&[5., 6., 999.],
|
||||
])
|
||||
.unwrap();
|
||||
|
||||
// Predict using only first two features (ignore the third)
|
||||
let feature_indices = vec![0, 1];
|
||||
let y_hat_sparse = knn.predict_sparse(&x_test, &feature_indices).unwrap();
|
||||
|
||||
// Should get good predictions since we're ignoring the mismatched third feature
|
||||
assert_eq!(2, Vec::len(&y_hat_sparse));
|
||||
assert!((y_hat_sparse[0] - 2.0).abs() < 1.0); // Should be close to 1-2
|
||||
assert!((y_hat_sparse[1] - 3.0).abs() < 1.0); // Should be close to 3
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn knn_predict_sparse_single_feature() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[1., 100., 1000.],
|
||||
&[2., 200., 2000.],
|
||||
&[3., 300., 3000.],
|
||||
&[4., 400., 4000.],
|
||||
&[5., 500., 5000.],
|
||||
])
|
||||
.unwrap();
|
||||
let y: Vec<f64> = vec![1., 2., 3., 4., 5.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[1.5, 999., 9999.]]).unwrap();
|
||||
|
||||
// Use only first feature
|
||||
let y_hat = knn.predict_sparse(&x_test, &[0]).unwrap();
|
||||
|
||||
// Should predict based on first feature only
|
||||
assert_eq!(1, Vec::len(&y_hat));
|
||||
assert!((y_hat[0] - 1.5).abs() < 1.0);
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
)]
|
||||
#[test]
|
||||
fn knn_predict_sparse_invalid_indices() {
|
||||
let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[3., 4.]]).unwrap();
|
||||
let y: Vec<f64> = vec![1., 2.];
|
||||
|
||||
let knn = KNNRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
let x_test = DenseMatrix::from_2d_array(&[&[1., 2.]]).unwrap();
|
||||
|
||||
// Index out of bounds
|
||||
let result = knn.predict_sparse(&x_test, &[5]);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Empty indices
|
||||
let result = knn.predict_sparse(&x_test, &[]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
@@ -350,4 +534,4 @@ mod tests {
|
||||
|
||||
assert_eq!(knn, deserialized_knn);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -53,10 +53,14 @@ use crate::{
|
||||
rand_custom::get_rng_impl,
|
||||
};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Defines the objective function to be optimized.
|
||||
/// The objective function provides the loss, gradient (first derivative), and
|
||||
/// hessian (second derivative) required for the XGBoost algorithm.
|
||||
#[derive(Clone, Debug)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
pub enum Objective {
|
||||
/// The objective for regression tasks using Mean Squared Error.
|
||||
/// Loss: 0.5 * (y_true - y_pred)^2
|
||||
@@ -122,6 +126,8 @@ impl Objective {
|
||||
/// This is a recursive data structure where each `TreeRegressor` is a node
|
||||
/// that can have a left and a right child, also of type `TreeRegressor`.
|
||||
#[allow(dead_code)]
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||
@@ -374,6 +380,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
/// Parameters for the `jRegressor` model.
|
||||
///
|
||||
/// This struct holds all the hyperparameters that control the training process.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct XGRegressorParameters {
|
||||
/// The number of boosting rounds or trees to build.
|
||||
@@ -494,6 +501,8 @@ impl XGRegressorParameters {
|
||||
}
|
||||
|
||||
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
||||
parameters: Option<XGRegressorParameters>,
|
||||
|
||||
Reference in New Issue
Block a user