Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76d1ef610d | ||
|
|
4092e24c2a | ||
|
|
17dc9f3bbf |
@@ -19,14 +19,13 @@ jobs:
|
|||||||
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
|
||||||
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
|
||||||
{ os: "macos", target: "aarch64-apple-darwin" },
|
{ os: "macos", target: "aarch64-apple-darwin" },
|
||||||
{ os: "ubuntu", target: "wasm32-wasi" },
|
|
||||||
]
|
]
|
||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Cache .cargo and target
|
- name: Cache .cargo and target
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
@@ -36,16 +35,13 @@ jobs:
|
|||||||
- name: Install Rust toolchain
|
- name: Install Rust toolchain
|
||||||
uses: actions-rs/toolchain@v1
|
uses: actions-rs/toolchain@v1
|
||||||
with:
|
with:
|
||||||
toolchain: 1.81 # 1.82 seems to break wasm32 tests https://github.com/rustwasm/wasm-bindgen/issues/4274
|
toolchain: stable
|
||||||
target: ${{ matrix.platform.target }}
|
target: ${{ matrix.platform.target }}
|
||||||
profile: minimal
|
profile: minimal
|
||||||
default: true
|
default: true
|
||||||
- name: Install test runner for wasm
|
- name: Install test runner for wasm
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
|
||||||
- name: Install test runner for wasi
|
|
||||||
if: matrix.platform.target == 'wasm32-wasi'
|
|
||||||
run: curl https://wasmtime.dev/install.sh -sSf | bash
|
|
||||||
- name: Stable Build with all features
|
- name: Stable Build with all features
|
||||||
uses: actions-rs/cargo@v1
|
uses: actions-rs/cargo@v1
|
||||||
with:
|
with:
|
||||||
@@ -65,13 +61,7 @@ jobs:
|
|||||||
- name: Tests in WASM
|
- name: Tests in WASM
|
||||||
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
if: matrix.platform.target == 'wasm32-unknown-unknown'
|
||||||
run: wasm-pack test --node -- --all-features
|
run: wasm-pack test --node -- --all-features
|
||||||
- name: Tests in WASI
|
|
||||||
if: matrix.platform.target == 'wasm32-wasi'
|
|
||||||
run: |
|
|
||||||
export WASMTIME_HOME="$HOME/.wasmtime"
|
|
||||||
export PATH="$WASMTIME_HOME/bin:$PATH"
|
|
||||||
cargo install cargo-wasi && cargo wasi test
|
|
||||||
|
|
||||||
check_features:
|
check_features:
|
||||||
runs-on: "${{ matrix.platform.os }}-latest"
|
runs-on: "${{ matrix.platform.os }}-latest"
|
||||||
strategy:
|
strategy:
|
||||||
@@ -81,9 +71,9 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Cache .cargo and target
|
- name: Cache .cargo and target
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
TZ: "/usr/share/zoneinfo/your/location"
|
TZ: "/usr/share/zoneinfo/your/location"
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
- name: Cache .cargo
|
- name: Cache .cargo
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Cache .cargo and target
|
- name: Cache .cargo and target
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/.cargo
|
~/.cargo
|
||||||
|
|||||||
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.0"
|
version = "0.4.1"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -18,4 +18,4 @@
|
|||||||
-----
|
-----
|
||||||
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
|
[](https://github.com/smartcorelib/smartcore/actions/workflows/ci.yml)
|
||||||
|
|
||||||
To start getting familiar with the new smartcore v0.3 API, there is now available a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter). Please see instructions there, contributions welcome see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
To start getting familiar with the new smartcore v0.4 API, a [**Jupyter Notebook environment repository**](https://github.com/smartcorelib/smartcore-jupyter) is now available. Please see the instructions there; contributions are welcome, see [CONTRIBUTING](.github/CONTRIBUTING.md).
|
||||||
|
|||||||
@@ -1,219 +0,0 @@
|
|||||||
//! This module provides FastPair, a data-structure for efficiently tracking the dynamic
|
|
||||||
//! closest pairs in a set of points, with an example usage in hierarchical clustering.[2][3][5]
|
|
||||||
//!
|
|
||||||
//! ## Purpose
|
|
||||||
//!
|
|
||||||
//! FastPair allows quick retrieval of the nearest neighbor for each data point by maintaining
|
|
||||||
//! a "conga line" of closest pairs. Each point retains a link to its known nearest neighbor,
|
|
||||||
//! and updates in the data structure propagate accordingly. This can be leveraged in
|
|
||||||
//! agglomerative clustering steps, where merging or insertion of new points must be reflected
|
|
||||||
//! in nearest-neighbor relationships.
|
|
||||||
//!
|
|
||||||
//! ## Example
|
|
||||||
//!
|
|
||||||
//! ```
|
|
||||||
//! use smartcore::metrics::distance::PairwiseDistance;
|
|
||||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
|
||||||
//! use smartcore::algorithm::neighbour::fastpair::FastPair;
|
|
||||||
//!
|
|
||||||
//! let x = DenseMatrix::from_2d_array(&[
|
|
||||||
//! &[5.1, 3.5, 1.4, 0.2],
|
|
||||||
//! &[4.9, 3.0, 1.4, 0.2],
|
|
||||||
//! &[4.7, 3.2, 1.3, 0.2],
|
|
||||||
//! &[4.6, 3.1, 1.5, 0.2],
|
|
||||||
//! &[5.0, 3.6, 1.4, 0.2],
|
|
||||||
//! &[5.4, 3.9, 1.7, 0.4],
|
|
||||||
//! ]).unwrap();
|
|
||||||
//!
|
|
||||||
//! let fastpair = FastPair::new(&x).unwrap();
|
|
||||||
//! let closest = fastpair.closest_pair();
|
|
||||||
//! println!("Closest pair: {:?}", closest);
|
|
||||||
//! ```
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use num::Bounded;
|
|
||||||
|
|
||||||
use crate::error::{Failed, FailedError};
|
|
||||||
use crate::linalg::basic::arrays::{Array, Array1, Array2};
|
|
||||||
use crate::metrics::distance::euclidian::Euclidian;
|
|
||||||
use crate::metrics::distance::PairwiseDistance;
|
|
||||||
use crate::numbers::floatnum::FloatNumber;
|
|
||||||
use crate::numbers::realnum::RealNumber;
|
|
||||||
|
|
||||||
/// Eppstein dynamic closest-pair structure
|
|
||||||
/// `M` is a matrix-like type implementing `Array2<T>` that provides row access
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct EppsteinDCP<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
|
|
||||||
samples: &'a M,
|
|
||||||
// "buckets" store, for each row, a small structure recording potential neighbors
|
|
||||||
neighbors: HashMap<usize, PairwiseDistance<T>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> EppsteinDCP<'a, T, M> {
|
|
||||||
/// Creates a new EppsteinDCP instance with the given data
|
|
||||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
|
||||||
if m.shape().0 < 3 {
|
|
||||||
return Err(Failed::because(
|
|
||||||
FailedError::FindFailed,
|
|
||||||
"min number of rows should be 3",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut this = Self {
|
|
||||||
samples: m,
|
|
||||||
neighbors: HashMap::with_capacity(m.shape().0),
|
|
||||||
};
|
|
||||||
this.initialize();
|
|
||||||
Ok(this)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build an initial "conga line" or chain of potential neighbors
|
|
||||||
/// akin to Eppstein’s technique[2].
|
|
||||||
fn initialize(&mut self) {
|
|
||||||
let n = self.samples.shape().0;
|
|
||||||
if n < 2 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Assign each row i some large distance by default
|
|
||||||
for i in 0..n {
|
|
||||||
self.neighbors.insert(
|
|
||||||
i,
|
|
||||||
PairwiseDistance {
|
|
||||||
node: i,
|
|
||||||
neighbour: None,
|
|
||||||
distance: Some(<T as Bounded>::max_value()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
// Example: link each i to the next, forming a chain
|
|
||||||
// (depending on the actual Eppstein approach, can refine)
|
|
||||||
for i in 0..(n - 1) {
|
|
||||||
let dist = self.compute_dist(i, i + 1);
|
|
||||||
self.neighbors.entry(i).and_modify(|pd| {
|
|
||||||
pd.neighbour = Some(i + 1);
|
|
||||||
pd.distance = Some(dist);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
// Potential refinement steps omitted for brevity
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insert a point into the structure.
|
|
||||||
pub fn insert(&mut self, row_idx: usize) {
|
|
||||||
// Expand data, find neighbor to link with
|
|
||||||
// For example, link row_idx to nearest among existing
|
|
||||||
let mut best_neighbor = None;
|
|
||||||
let mut best_d = <T as Bounded>::max_value();
|
|
||||||
for (i, _) in &self.neighbors {
|
|
||||||
let d = self.compute_dist(*i, row_idx);
|
|
||||||
if d < best_d {
|
|
||||||
best_d = d;
|
|
||||||
best_neighbor = Some(*i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.neighbors.insert(
|
|
||||||
row_idx,
|
|
||||||
PairwiseDistance {
|
|
||||||
node: row_idx,
|
|
||||||
neighbour: best_neighbor,
|
|
||||||
distance: Some(best_d),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
// For the best_neighbor, you might want to see if row_idx becomes closer
|
|
||||||
if let Some(kn) = best_neighbor {
|
|
||||||
let dist = self.compute_dist(row_idx, kn);
|
|
||||||
let entry = self.neighbors.get_mut(&kn).unwrap();
|
|
||||||
if dist < entry.distance.unwrap() {
|
|
||||||
entry.neighbour = Some(row_idx);
|
|
||||||
entry.distance = Some(dist);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// For hierarchical clustering, discover minimal pairs, then merge
|
|
||||||
pub fn closest_pair(&self) -> Option<PairwiseDistance<T>> {
|
|
||||||
let mut min_pair: Option<PairwiseDistance<T>> = None;
|
|
||||||
for (_, pd) in &self.neighbors {
|
|
||||||
if let Some(d) = pd.distance {
|
|
||||||
if min_pair.is_none() || d < min_pair.as_ref().unwrap().distance.unwrap() {
|
|
||||||
min_pair = Some(pd.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
min_pair
|
|
||||||
}
|
|
||||||
|
|
||||||
fn compute_dist(&self, i: usize, j: usize) -> T {
|
|
||||||
// Example: squared Euclidean distance (sum of squared differences; no square root is taken)
|
|
||||||
let row_i = self.samples.get_row(i);
|
|
||||||
let row_j = self.samples.get_row(j);
|
|
||||||
row_i
|
|
||||||
.iterator(0)
|
|
||||||
.zip(row_j.iterator(0))
|
|
||||||
.map(|(a, b)| (*a - *b) * (*a - *b))
|
|
||||||
.sum()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Simple usage
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests_eppstein {
|
|
||||||
use super::*;
|
|
||||||
use crate::linalg::basic::matrix::DenseMatrix;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_eppstein() {
|
|
||||||
let matrix =
|
|
||||||
DenseMatrix::from_2d_array(&[&vec![1.0, 2.0], &vec![2.0, 2.0], &vec![5.0, 3.0]])
|
|
||||||
.unwrap();
|
|
||||||
let mut dcp = EppsteinDCP::new(&matrix).unwrap();
|
|
||||||
dcp.insert(2);
|
|
||||||
let cp = dcp.closest_pair();
|
|
||||||
assert!(cp.is_some());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn compare_fastpair_eppstein() {
|
|
||||||
use crate::algorithm::neighbour::fastpair::FastPair;
|
|
||||||
// Assuming EppsteinDCP is implemented in a similar module
|
|
||||||
use crate::algorithm::neighbour::eppstein::EppsteinDCP;
|
|
||||||
|
|
||||||
// Create a static example matrix
|
|
||||||
let x = DenseMatrix::from_2d_array(&[
|
|
||||||
&[5.1, 3.5, 1.4, 0.2],
|
|
||||||
&[4.9, 3.0, 1.4, 0.2],
|
|
||||||
&[4.7, 3.2, 1.3, 0.2],
|
|
||||||
&[4.6, 3.1, 1.5, 0.2],
|
|
||||||
&[5.0, 3.6, 1.4, 0.2],
|
|
||||||
&[5.4, 3.9, 1.7, 0.4],
|
|
||||||
&[4.6, 3.4, 1.4, 0.3],
|
|
||||||
&[5.0, 3.4, 1.5, 0.2],
|
|
||||||
&[4.4, 2.9, 1.4, 0.2],
|
|
||||||
&[4.9, 3.1, 1.5, 0.1],
|
|
||||||
])
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// Build FastPair
|
|
||||||
let fastpair = FastPair::new(&x).unwrap();
|
|
||||||
let pair_fastpair = fastpair.closest_pair();
|
|
||||||
|
|
||||||
// Build EppsteinDCP
|
|
||||||
let eppstein = EppsteinDCP::new(&x).unwrap();
|
|
||||||
let pair_eppstein = eppstein.closest_pair();
|
|
||||||
|
|
||||||
// Compare the results
|
|
||||||
assert_eq!(pair_fastpair.node, pair_eppstein.as_ref().unwrap().node);
|
|
||||||
assert_eq!(
|
|
||||||
pair_fastpair.neighbour.unwrap(),
|
|
||||||
pair_eppstein.as_ref().unwrap().neighbour.unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
// Use a small epsilon for floating-point comparison
|
|
||||||
let epsilon = 1e-9;
|
|
||||||
let diff: f64 =
|
|
||||||
pair_fastpair.distance.unwrap() - pair_eppstein.as_ref().unwrap().distance.unwrap();
|
|
||||||
assert!(diff.abs() < epsilon);
|
|
||||||
|
|
||||||
println!("FastPair result: {:?}", pair_fastpair);
|
|
||||||
println!("EppsteinDCP result: {:?}", pair_eppstein);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -41,9 +41,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
/// eppstein pairwise closest neighbour algorithm
|
/// fastpair closest neighbour algorithm
|
||||||
pub mod eppstein;
|
|
||||||
/// fastpair pairwise closest neighbour algorithm
|
|
||||||
pub mod fastpair;
|
pub mod fastpair;
|
||||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||||
pub mod linear_search;
|
pub mod linear_search;
|
||||||
|
|||||||
@@ -663,6 +663,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_instantiate_err_view3() {
|
fn test_instantiate_err_view3() {
|
||||||
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
let x = DenseMatrix::from_2d_array(&[&[1., 2., 3.], &[4., 5., 6.], &[7., 8., 9.]]).unwrap();
|
||||||
|
#[allow(clippy::reversed_empty_ranges)]
|
||||||
let v = DenseMatrixView::new(&x, 0..3, 4..3);
|
let v = DenseMatrixView::new(&x, 0..3, 4..3);
|
||||||
assert!(v.is_err());
|
assert!(v.is_err());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -257,8 +257,7 @@ impl<TY: Number + Ord + Unsigned> BernoulliNBDistribution<TY> {
|
|||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
/// * `y` - vector with target values (classes) of length N.
|
/// * `y` - vector with target values (classes) of length N.
|
||||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||||
/// priors are adjusted according to the data.
|
|
||||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||||
/// * `binarize` - Threshold for binarizing.
|
/// * `binarize` - Threshold for binarizing.
|
||||||
fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>(
|
fn fit<TX: Number + PartialOrd, X: Array2<TX>, Y: Array1<TY>>(
|
||||||
|
|||||||
@@ -174,8 +174,7 @@ impl<TY: Number + Ord + Unsigned> GaussianNBDistribution<TY> {
|
|||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
/// * `y` - vector with target values (classes) of length N.
|
/// * `y` - vector with target values (classes) of length N.
|
||||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||||
/// priors are adjusted according to the data.
|
|
||||||
pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>(
|
pub fn fit<TX: Number + RealNumber, X: Array2<TX>, Y: Array1<TY>>(
|
||||||
x: &X,
|
x: &X,
|
||||||
y: &Y,
|
y: &Y,
|
||||||
|
|||||||
@@ -207,8 +207,7 @@ impl<TY: Number + Ord + Unsigned> MultinomialNBDistribution<TY> {
|
|||||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||||
/// * `x` - training data.
|
/// * `x` - training data.
|
||||||
/// * `y` - vector with target values (classes) of length N.
|
/// * `y` - vector with target values (classes) of length N.
|
||||||
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined,
|
/// * `priors` - Optional vector with prior probabilities of the classes. If not defined, priors are adjusted according to the data.
|
||||||
/// priors are adjusted according to the data.
|
|
||||||
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
/// * `alpha` - Additive (Laplace/Lidstone) smoothing parameter.
|
||||||
pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>(
|
pub fn fit<TX: Number + Unsigned, X: Array2<TX>, Y: Array1<TY>>(
|
||||||
x: &X,
|
x: &X,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
|
//! // &[1.5, 1.0, 0.0, 1.5, 0.0, 0.0, 1.0, 0.0]
|
||||||
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
|
//! // &[1.5, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 1.0]
|
||||||
//! ```
|
//! ```
|
||||||
use std::iter;
|
use std::iter::repeat_n;
|
||||||
|
|
||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::linalg::basic::arrays::Array2;
|
use crate::linalg::basic::arrays::Array2;
|
||||||
@@ -75,11 +75,7 @@ fn find_new_idxs(num_params: usize, cat_sizes: &[usize], cat_idxs: &[usize]) ->
|
|||||||
let offset = (0..1).chain(offset_);
|
let offset = (0..1).chain(offset_);
|
||||||
|
|
||||||
let new_param_idxs: Vec<usize> = (0..num_params)
|
let new_param_idxs: Vec<usize> = (0..num_params)
|
||||||
.zip(
|
.zip(repeats.zip(offset).flat_map(|(r, o)| repeat_n(o, r)))
|
||||||
repeats
|
|
||||||
.zip(offset)
|
|
||||||
.flat_map(|(r, o)| iter::repeat(o).take(r)),
|
|
||||||
)
|
|
||||||
.map(|(idx, ofst)| idx + ofst)
|
.map(|(idx, ofst)| idx + ofst)
|
||||||
.collect();
|
.collect();
|
||||||
new_param_idxs
|
new_param_idxs
|
||||||
@@ -124,7 +120,7 @@ impl OneHotEncoder {
|
|||||||
let (nrows, _) = data.shape();
|
let (nrows, _) = data.shape();
|
||||||
|
|
||||||
// col buffer to avoid allocations
|
// col buffer to avoid allocations
|
||||||
let mut col_buf: Vec<T> = iter::repeat(T::zero()).take(nrows).collect();
|
let mut col_buf: Vec<T> = repeat_n(T::zero(), nrows).collect();
|
||||||
|
|
||||||
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
|
let mut res: Vec<CategoryMapper<CategoricalFloat>> = Vec::with_capacity(idxs.len());
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user