Merge remote-tracking branch 'sm/development' into predict-probability
This commit is contained in:
@@ -11,7 +11,8 @@ jobs:
|
||||
runs-on: "${{ matrix.platform.os }}-latest"
|
||||
strategy:
|
||||
matrix:
|
||||
platform: [
|
||||
platform:
|
||||
[
|
||||
{ os: "windows", target: "x86_64-pc-windows-msvc" },
|
||||
{ os: "windows", target: "i686-pc-windows-msvc" },
|
||||
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
|
||||
|
||||
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## Added
|
||||
- Seeds to multiple algorithims that depend on random number generation.
|
||||
- Added feature `js` to use WASM in browser
|
||||
|
||||
## BREAKING CHANGE
|
||||
- Added a new parameter to `train_test_split` to define the seed.
|
||||
|
||||
## [0.2.1] - 2022-05-10
|
||||
|
||||
## Added
|
||||
- L2 regularization penalty to the Logistic Regression
|
||||
- Getters for the naive bayes structs
|
||||
|
||||
+17
-5
@@ -16,22 +16,29 @@ categories = ["science"]
|
||||
default = ["datasets"]
|
||||
ndarray-bindings = ["ndarray"]
|
||||
nalgebra-bindings = ["nalgebra"]
|
||||
datasets = []
|
||||
datasets = ["rand_distr", "std"]
|
||||
fp_bench = ["itertools"]
|
||||
std = ["rand/std", "rand/std_rng"]
|
||||
# wasm32 only
|
||||
js = ["getrandom/js"]
|
||||
|
||||
[dependencies]
|
||||
ndarray = { version = "0.15", optional = true }
|
||||
nalgebra = { version = "0.31", optional = true }
|
||||
num-traits = "0.2"
|
||||
num = "0.4"
|
||||
rand = "0.8"
|
||||
rand_distr = "0.4"
|
||||
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
|
||||
rand_distr = { version = "0.4", optional = true }
|
||||
serde = { version = "1", features = ["derive"], optional = true }
|
||||
itertools = { version = "0.10.3", optional = true }
|
||||
cfg-if = "1.0.0"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2", features = ["js"] }
|
||||
getrandom = { version = "0.2", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3"
|
||||
smartcore = { path = ".", features = ["fp_bench"] }
|
||||
criterion = { version = "0.4", default-features = false }
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.1"
|
||||
|
||||
@@ -46,3 +53,8 @@ harness = false
|
||||
name = "naive_bayes"
|
||||
harness = false
|
||||
required-features = ["ndarray-bindings", "nalgebra-bindings"]
|
||||
|
||||
[[bench]]
|
||||
name = "fastpair"
|
||||
harness = false
|
||||
required-features = ["fp_bench"]
|
||||
|
||||
@@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2019-present at SmartCore developers (smartcorelib.org)
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
|
||||
// to run this bench you have to change the declaraion in mod.rs ---> pub mod fastpair;
|
||||
use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
use smartcore::linalg::naive::dense_matrix::*;
|
||||
use std::time::Duration;
|
||||
|
||||
fn closest_pair_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair();
|
||||
}
|
||||
|
||||
fn closest_pair_brute_bench(n: usize, m: usize) -> () {
|
||||
let x = DenseMatrix::<f64>::rand(n, m);
|
||||
let fastpair = FastPair::new(&x);
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
result.closest_pair_brute();
|
||||
}
|
||||
|
||||
fn bench_fastpair(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("FastPair");
|
||||
|
||||
// with full samples size (100) the test will take too long
|
||||
group.significance_level(0.1).sample_size(30);
|
||||
// increase from default 5.0 secs
|
||||
group.measurement_time(Duration::from_secs(60));
|
||||
|
||||
for n_samples in [100_usize, 1000_usize].iter() {
|
||||
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"fastpair --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
|
||||
);
|
||||
group.bench_with_input(
|
||||
BenchmarkId::from_parameter(format!(
|
||||
"brute --- n_samples: {}, n_features: {}",
|
||||
n_samples, n_features
|
||||
)),
|
||||
n_samples,
|
||||
|b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
|
||||
);
|
||||
}
|
||||
}
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_fastpair);
|
||||
criterion_main!(benches);
|
||||
@@ -59,7 +59,7 @@ impl<T: RealNumber> BBDTree<T> {
|
||||
tree
|
||||
}
|
||||
|
||||
pub(in crate) fn clustering(
|
||||
pub(crate) fn clustering(
|
||||
&self,
|
||||
centroids: &[Vec<T>],
|
||||
sums: &mut Vec<Vec<T>>,
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
//!
|
||||
//! Dissimilarities for vector-vector distance
|
||||
//!
|
||||
//! Representing distances as pairwise dissimilarities, so to build a
|
||||
//! graph of closest neighbours. This representation can be reused for
|
||||
//! different implementations (initially used in this library for FastPair).
|
||||
use std::cmp::{Eq, Ordering, PartialOrd};
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
///
|
||||
/// The edge of the subgraph is defined by `PairwiseDistance`.
|
||||
/// The calling algorithm can store a list of distsances as
|
||||
/// a list of these structures.
|
||||
///
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PairwiseDistance<T: RealNumber> {
|
||||
/// index of the vector in the original `Matrix` or list
|
||||
pub node: usize,
|
||||
|
||||
/// index of the closest neighbor in the original `Matrix` or same list
|
||||
pub neighbour: Option<usize>,
|
||||
|
||||
/// measure of distance, according to the algorithm distance function
|
||||
/// if the distance is None, the edge has value "infinite" or max distance
|
||||
/// each algorithm has to match
|
||||
pub distance: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
|
||||
|
||||
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.node == other.node
|
||||
&& self.neighbour == other.neighbour
|
||||
&& self.distance == other.distance
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.distance.partial_cmp(&other.distance)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,569 @@
|
||||
///
|
||||
/// # FastPair: Data-structure for the dynamic closest-pair problem.
|
||||
///
|
||||
/// Reference:
|
||||
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// use smartcore::algorithm::neighbour::distances::PairwiseDistance;
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::algorithm::neighbour::fastpair::FastPair;
|
||||
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
/// &[5.1, 3.5, 1.4, 0.2],
|
||||
/// &[4.9, 3.0, 1.4, 0.2],
|
||||
/// &[4.7, 3.2, 1.3, 0.2],
|
||||
/// &[4.6, 3.1, 1.5, 0.2],
|
||||
/// &[5.0, 3.6, 1.4, 0.2],
|
||||
/// &[5.4, 3.9, 1.7, 0.4],
|
||||
/// ]);
|
||||
/// let fastpair = FastPair::new(&x);
|
||||
/// let closest_pair: PairwiseDistance<f64> = fastpair.unwrap().closest_pair();
|
||||
/// ```
|
||||
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::algorithm::neighbour::distances::PairwiseDistance;
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
///
|
||||
/// Inspired by Python implementation:
|
||||
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
|
||||
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
|
||||
///
|
||||
/// affinity used is Euclidean so to allow linkage with single, ward, complete and average
|
||||
///
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FastPair<'a, T: RealNumber, M: Matrix<T>> {
|
||||
/// initial matrix
|
||||
samples: &'a M,
|
||||
/// closest pair hashmap (connectivity matrix for closest pairs)
|
||||
pub distances: HashMap<usize, PairwiseDistance<T>>,
|
||||
/// conga line used to keep track of the closest pair
|
||||
pub neighbours: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> FastPair<'a, T, M> {
|
||||
///
|
||||
/// Constructor
|
||||
/// Instantiate and inizialise the algorithm
|
||||
///
|
||||
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||
if m.shape().0 < 3 {
|
||||
return Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"min number of rows should be 3",
|
||||
));
|
||||
}
|
||||
|
||||
let mut init = Self {
|
||||
samples: m,
|
||||
// to be computed in init(..)
|
||||
distances: HashMap::with_capacity(m.shape().0),
|
||||
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
||||
};
|
||||
init.init();
|
||||
Ok(init)
|
||||
}
|
||||
|
||||
///
|
||||
/// Initialise `FastPair` by passing a `Matrix`.
|
||||
/// Build a FastPairs data-structure from a set of (new) points.
|
||||
///
|
||||
fn init(&mut self) {
|
||||
// basic measures
|
||||
let len = self.samples.shape().0;
|
||||
let max_index = self.samples.shape().0 - 1;
|
||||
|
||||
// Store all closest neighbors
|
||||
let _distances = Box::new(HashMap::with_capacity(len));
|
||||
let _neighbours = Box::new(Vec::with_capacity(len));
|
||||
|
||||
let mut distances = *_distances;
|
||||
let mut neighbours = *_neighbours;
|
||||
|
||||
// fill neighbours with -1 values
|
||||
neighbours.extend(0..len);
|
||||
|
||||
// init closest neighbour pairwise data
|
||||
for index_row_i in 0..(max_index) {
|
||||
distances.insert(
|
||||
index_row_i,
|
||||
PairwiseDistance {
|
||||
node: index_row_i,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// loop through indeces and neighbours
|
||||
for index_row_i in 0..(len) {
|
||||
// start looking for the neighbour in the second element
|
||||
let mut index_closest = index_row_i + 1; // closest neighbour index
|
||||
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
||||
for index_row_j in (index_row_i + 1)..len {
|
||||
distances.insert(
|
||||
index_row_j,
|
||||
PairwiseDistance {
|
||||
node: index_row_j,
|
||||
neighbour: Some(index_row_i),
|
||||
distance: nbd,
|
||||
},
|
||||
);
|
||||
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row_i)),
|
||||
&(self.samples.get_row_as_vec(index_row_j)),
|
||||
);
|
||||
if d < nbd.unwrap() {
|
||||
// set this j-value to be the closest neighbour
|
||||
index_closest = index_row_j;
|
||||
nbd = Some(d);
|
||||
}
|
||||
}
|
||||
|
||||
// Add that edge
|
||||
distances.entry(index_row_i).and_modify(|e| {
|
||||
e.distance = nbd;
|
||||
e.neighbour = Some(index_closest);
|
||||
});
|
||||
}
|
||||
// No more neighbors, terminate conga line.
|
||||
// Last person on the line has no neigbors
|
||||
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||
distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value());
|
||||
|
||||
// compute sparse matrix (connectivity matrix)
|
||||
let mut sparse_matrix = M::zeros(len, len);
|
||||
for (_, p) in distances.iter() {
|
||||
sparse_matrix.set(p.node, p.neighbour.unwrap(), p.distance.unwrap());
|
||||
}
|
||||
|
||||
self.distances = distances;
|
||||
self.neighbours = neighbours;
|
||||
}
|
||||
|
||||
///
|
||||
/// Find closest pair by scanning list of nearest neighbors.
|
||||
///
|
||||
#[allow(dead_code)]
|
||||
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||
let mut a = self.neighbours[0]; // Start with first point
|
||||
let mut d = self.distances[&a].distance;
|
||||
for p in self.neighbours.iter() {
|
||||
if self.distances[p].distance < d {
|
||||
a = *p; // Update `a` and distance `d`
|
||||
d = self.distances[p].distance;
|
||||
}
|
||||
}
|
||||
let b = self.distances[&a].neighbour;
|
||||
PairwiseDistance {
|
||||
node: a,
|
||||
neighbour: b,
|
||||
distance: d,
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Brute force algorithm, used only for comparison and testing
|
||||
///
|
||||
#[cfg(feature = "fp_bench")]
|
||||
pub fn closest_pair_brute(&self) -> PairwiseDistance<T> {
|
||||
use itertools::Itertools;
|
||||
let m = self.samples.shape().0;
|
||||
|
||||
let mut closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
distance: Some(T::max_value()),
|
||||
};
|
||||
for pair in (0..m).combinations(2) {
|
||||
let d = Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(pair[0])),
|
||||
&(self.samples.get_row_as_vec(pair[1])),
|
||||
);
|
||||
if d < closest_pair.distance.unwrap() {
|
||||
closest_pair.node = pair[0];
|
||||
closest_pair.neighbour = Some(pair[1]);
|
||||
closest_pair.distance = Some(d);
|
||||
}
|
||||
}
|
||||
closest_pair
|
||||
}
|
||||
|
||||
//
|
||||
// Compute distances from input to all other points in data-structure.
|
||||
// input is the row index of the sample matrix
|
||||
//
|
||||
#[allow(dead_code)]
|
||||
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
|
||||
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||
for other in self.neighbours.iter() {
|
||||
if index_row != *other {
|
||||
distances.push(PairwiseDistance {
|
||||
node: index_row,
|
||||
neighbour: Some(*other),
|
||||
distance: Some(Euclidian::squared_distance(
|
||||
&(self.samples.get_row_as_vec(index_row)),
|
||||
&(self.samples.get_row_as_vec(*other)),
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
distances
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_fastpair {
|
||||
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn fastpair_init() {
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let _fastpair = FastPair::new(&x);
|
||||
assert!(_fastpair.is_ok());
|
||||
|
||||
let fastpair = _fastpair.unwrap();
|
||||
|
||||
let distances = fastpair.distances;
|
||||
let neighbours = fastpair.neighbours;
|
||||
|
||||
assert!(distances.len() != 0);
|
||||
assert!(neighbours.len() != 0);
|
||||
|
||||
assert_eq!(10, neighbours.len());
|
||||
assert_eq!(10, distances.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dataset_has_at_least_three_points() {
|
||||
// Create a dataset which consists of only two points:
|
||||
// A(0.0, 0.0) and B(1.0, 1.0).
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]);
|
||||
|
||||
// We expect an error when we run `FastPair` on this dataset,
|
||||
// becuase `FastPair` currently only works on a minimum of 3
|
||||
// points.
|
||||
let _fastpair = FastPair::new(&dataset);
|
||||
|
||||
match _fastpair {
|
||||
Err(e) => {
|
||||
let expected_error =
|
||||
Failed::because(FailedError::FindFailed, "min number of rows should be 3");
|
||||
assert_eq!(e, expected_error)
|
||||
}
|
||||
_ => {
|
||||
assert!(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_minimal() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[0.0], &[2.0], &[9.0]]);
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
let expected_closest_pair = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(1),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
|
||||
let closest_pair_brute = fastpair.closest_pair_brute();
|
||||
assert_eq!(closest_pair_brute, expected_closest_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_dimensional_dataset_2() {
|
||||
let dataset = DenseMatrix::<f64>::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]);
|
||||
|
||||
let result = FastPair::new(&dataset);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let fastpair = result.unwrap();
|
||||
let closest_pair = fastpair.closest_pair();
|
||||
let expected_closest_pair = PairwiseDistance {
|
||||
node: 1,
|
||||
neighbour: Some(3),
|
||||
distance: Some(4.0),
|
||||
};
|
||||
assert_eq!(closest_pair, fastpair.closest_pair_brute());
|
||||
assert_eq!(closest_pair, expected_closest_pair);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_new() {
|
||||
// compute
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
// unwrap results
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
// list of minimal pairwise dissimilarities
|
||||
let dissimilarities = vec![
|
||||
(
|
||||
1,
|
||||
PairwiseDistance {
|
||||
node: 1,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.030000000000000037),
|
||||
},
|
||||
),
|
||||
(
|
||||
10,
|
||||
PairwiseDistance {
|
||||
node: 10,
|
||||
neighbour: Some(12),
|
||||
distance: Some(0.07000000000000003),
|
||||
},
|
||||
),
|
||||
(
|
||||
11,
|
||||
PairwiseDistance {
|
||||
node: 11,
|
||||
neighbour: Some(14),
|
||||
distance: Some(0.18000000000000013),
|
||||
},
|
||||
),
|
||||
(
|
||||
12,
|
||||
PairwiseDistance {
|
||||
node: 12,
|
||||
neighbour: Some(14),
|
||||
distance: Some(0.34000000000000086),
|
||||
},
|
||||
),
|
||||
(
|
||||
13,
|
||||
PairwiseDistance {
|
||||
node: 13,
|
||||
neighbour: Some(14),
|
||||
distance: Some(1.6499999999999997),
|
||||
},
|
||||
),
|
||||
(
|
||||
14,
|
||||
PairwiseDistance {
|
||||
node: 14,
|
||||
neighbour: Some(14),
|
||||
distance: Some(f64::MAX),
|
||||
},
|
||||
),
|
||||
(
|
||||
6,
|
||||
PairwiseDistance {
|
||||
node: 6,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.18000000000000027),
|
||||
},
|
||||
),
|
||||
(
|
||||
0,
|
||||
PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
},
|
||||
),
|
||||
(
|
||||
8,
|
||||
PairwiseDistance {
|
||||
node: 8,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.3100000000000001),
|
||||
},
|
||||
),
|
||||
(
|
||||
2,
|
||||
PairwiseDistance {
|
||||
node: 2,
|
||||
neighbour: Some(3),
|
||||
distance: Some(0.0600000000000001),
|
||||
},
|
||||
),
|
||||
(
|
||||
3,
|
||||
PairwiseDistance {
|
||||
node: 3,
|
||||
neighbour: Some(8),
|
||||
distance: Some(0.08999999999999982),
|
||||
},
|
||||
),
|
||||
(
|
||||
7,
|
||||
PairwiseDistance {
|
||||
node: 7,
|
||||
neighbour: Some(9),
|
||||
distance: Some(0.10999999999999982),
|
||||
},
|
||||
),
|
||||
(
|
||||
9,
|
||||
PairwiseDistance {
|
||||
node: 9,
|
||||
neighbour: Some(13),
|
||||
distance: Some(8.69),
|
||||
},
|
||||
),
|
||||
(
|
||||
4,
|
||||
PairwiseDistance {
|
||||
node: 4,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.050000000000000086),
|
||||
},
|
||||
),
|
||||
(
|
||||
5,
|
||||
PairwiseDistance {
|
||||
node: 5,
|
||||
neighbour: Some(7),
|
||||
distance: Some(0.4900000000000002),
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
let expected: HashMap<_, _> = dissimilarities.into_iter().collect();
|
||||
|
||||
for i in 0..(x.shape().0 - 1) {
|
||||
let input_node = result.samples.get_row_as_vec(i);
|
||||
let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap();
|
||||
let distance = Euclidian::squared_distance(
|
||||
&input_node,
|
||||
&result.samples.get_row_as_vec(input_neighbour),
|
||||
);
|
||||
|
||||
assert_eq!(i, expected.get(&i).unwrap().node);
|
||||
assert_eq!(
|
||||
input_neighbour,
|
||||
expected.get(&i).unwrap().neighbour.unwrap()
|
||||
);
|
||||
assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_closest_pair() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let dissimilarity = fastpair.unwrap().closest_pair();
|
||||
let closest = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
};
|
||||
|
||||
assert_eq!(closest, dissimilarity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_closest_pair_random_matrix() {
|
||||
let x = DenseMatrix::<f64>::rand(200, 25);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let result = fastpair.unwrap();
|
||||
|
||||
let dissimilarity1 = result.closest_pair();
|
||||
let dissimilarity2 = result.closest_pair_brute();
|
||||
|
||||
assert_eq!(dissimilarity1, dissimilarity2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fastpair_distances() {
|
||||
let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
]);
|
||||
// compute
|
||||
let fastpair = FastPair::new(&x);
|
||||
assert!(fastpair.is_ok());
|
||||
|
||||
let dissimilarities = fastpair.unwrap().distances_from(0);
|
||||
|
||||
let mut min_dissimilarity = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: None,
|
||||
distance: Some(f64::MAX),
|
||||
};
|
||||
for p in dissimilarities.iter() {
|
||||
if p.distance.unwrap() < min_dissimilarity.distance.unwrap() {
|
||||
min_dissimilarity = p.clone()
|
||||
}
|
||||
}
|
||||
|
||||
let closest = PairwiseDistance {
|
||||
node: 0,
|
||||
neighbour: Some(4),
|
||||
distance: Some(0.01999999999999995),
|
||||
};
|
||||
|
||||
assert_eq!(closest, min_dissimilarity);
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,10 @@ use serde::{Deserialize, Serialize};
|
||||
pub(crate) mod bbd_tree;
|
||||
/// tree data structure for fast nearest neighbor search
|
||||
pub mod cover_tree;
|
||||
/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair
|
||||
pub mod distances;
|
||||
/// fastpair closest neighbour algorithm
|
||||
pub mod fastpair;
|
||||
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
|
||||
pub mod linear_search;
|
||||
|
||||
@@ -55,6 +59,12 @@ pub enum KNNAlgorithmName {
|
||||
CoverTree,
|
||||
}
|
||||
|
||||
impl Default for KNNAlgorithmName {
|
||||
fn default() -> Self {
|
||||
KNNAlgorithmName::CoverTree
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
|
||||
@@ -12,7 +12,7 @@ pub struct HeapSelection<T: PartialOrd + Debug> {
|
||||
heap: Vec<T>,
|
||||
}
|
||||
|
||||
impl<'a, T: PartialOrd + Debug> HeapSelection<T> {
|
||||
impl<T: PartialOrd + Debug> HeapSelection<T> {
|
||||
pub fn with_capacity(k: usize) -> HeapSelection<T> {
|
||||
HeapSelection {
|
||||
k,
|
||||
|
||||
+130
-1
@@ -65,17 +65,22 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
eps: T,
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// DBSCAN clustering algorithm parameters
|
||||
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
}
|
||||
@@ -109,6 +114,107 @@ impl<T: RealNumber, D: Distance<Vec<T>, T>> DBSCANParameters<T, D> {
|
||||
}
|
||||
}
|
||||
|
||||
/// DBSCAN grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DBSCANSearchParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: Vec<D>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
|
||||
pub min_samples: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
|
||||
pub eps: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// KNN algorithm to use.
|
||||
pub algorithm: Vec<KNNAlgorithmName>,
|
||||
}
|
||||
|
||||
/// DBSCAN grid search iterator
|
||||
pub struct DBSCANSearchParametersIterator<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
|
||||
current_distance: usize,
|
||||
current_min_samples: usize,
|
||||
current_eps: usize,
|
||||
current_algorithm: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> IntoIterator for DBSCANSearchParameters<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
type IntoIter = DBSCANSearchParametersIterator<T, D>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DBSCANSearchParametersIterator {
|
||||
dbscan_search_parameters: self,
|
||||
current_distance: 0,
|
||||
current_min_samples: 0,
|
||||
current_eps: 0,
|
||||
current_algorithm: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> Iterator for DBSCANSearchParametersIterator<T, D> {
|
||||
type Item = DBSCANParameters<T, D>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_distance == self.dbscan_search_parameters.distance.len()
|
||||
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
|
||||
&& self.current_eps == self.dbscan_search_parameters.eps.len()
|
||||
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DBSCANParameters {
|
||||
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
|
||||
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
|
||||
eps: self.dbscan_search_parameters.eps[self.current_eps],
|
||||
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
|
||||
};
|
||||
|
||||
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
|
||||
self.current_distance += 1;
|
||||
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples += 1;
|
||||
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps += 1;
|
||||
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
|
||||
self.current_distance = 0;
|
||||
self.current_min_samples = 0;
|
||||
self.current_eps = 0;
|
||||
self.current_algorithm += 1;
|
||||
} else {
|
||||
self.current_distance += 1;
|
||||
self.current_min_samples += 1;
|
||||
self.current_eps += 1;
|
||||
self.current_algorithm += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for DBSCANSearchParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
let default_params = DBSCANParameters::default();
|
||||
|
||||
DBSCANSearchParameters {
|
||||
distance: vec![default_params.distance],
|
||||
min_samples: vec![default_params.min_samples],
|
||||
eps: vec![default_params.eps],
|
||||
algorithm: vec![default_params.algorithm],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, D: Distance<Vec<T>, T>> PartialEq for DBSCAN<T, D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.cluster_labels.len() == other.cluster_labels.len()
|
||||
@@ -124,7 +230,7 @@ impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
|
||||
distance: Distances::euclidian(),
|
||||
min_samples: 5,
|
||||
eps: T::half(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -268,6 +374,29 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::math::distance::euclidian::Euclidian;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = DBSCANSearchParameters {
|
||||
min_samples: vec![10, 100],
|
||||
eps: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 10);
|
||||
assert_eq!(next.eps, 2.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.min_samples, 100);
|
||||
assert_eq!(next.eps, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_dbscan() {
|
||||
|
||||
+122
-4
@@ -52,10 +52,10 @@
|
||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)
|
||||
|
||||
use rand::Rng;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::Sum;
|
||||
|
||||
use ::rand::Rng;
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -65,6 +65,7 @@ use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::distance::euclidian::*;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// K-Means clustering algorithm
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -101,13 +102,20 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// K-Means clustering algorithm parameters
|
||||
pub struct KMeansParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of clusters.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KMeansParameters {
|
||||
@@ -128,6 +136,93 @@ impl Default for KMeansParameters {
|
||||
KMeansParameters {
|
||||
k: 2,
|
||||
max_iter: 100,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// KMeans grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KMeansSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of clusters.
|
||||
pub k: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub max_iter: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Determines random number generation for centroid initialization.
|
||||
/// Use an int to make the randomness deterministic
|
||||
pub seed: Vec<Option<u64>>,
|
||||
}
|
||||
|
||||
/// KMeans grid search iterator
|
||||
pub struct KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: KMeansSearchParameters,
|
||||
current_k: usize,
|
||||
current_max_iter: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for KMeansSearchParameters {
|
||||
type Item = KMeansParameters;
|
||||
type IntoIter = KMeansSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
KMeansSearchParametersIterator {
|
||||
kmeans_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_max_iter: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for KMeansSearchParametersIterator {
|
||||
type Item = KMeansParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.kmeans_search_parameters.k.len()
|
||||
&& self.current_max_iter == self.kmeans_search_parameters.max_iter.len()
|
||||
&& self.current_seed == self.kmeans_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = KMeansParameters {
|
||||
k: self.kmeans_search_parameters.k[self.current_k],
|
||||
max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter],
|
||||
seed: self.kmeans_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.kmeans_search_parameters.k.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() {
|
||||
self.current_k = 0;
|
||||
self.current_max_iter = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_max_iter += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for KMeansSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = KMeansParameters::default();
|
||||
|
||||
KMeansSearchParameters {
|
||||
k: vec![default_params.k],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,7 +263,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
let (n, d) = data.shape();
|
||||
|
||||
let mut distortion = T::max_value();
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
|
||||
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
|
||||
let mut size = vec![0; parameters.k];
|
||||
let mut centroids = vec![vec![T::zero(); d]; parameters.k];
|
||||
|
||||
@@ -241,8 +336,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(seed);
|
||||
let (n, m) = data.shape();
|
||||
let mut y = vec![0; n];
|
||||
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
|
||||
@@ -313,6 +408,29 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = KMeansSearchParameters {
|
||||
k: vec![2, 4],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 2);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.k, 4);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
@@ -83,11 +83,14 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// PCA parameters
|
||||
pub struct PCAParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: bool,
|
||||
@@ -116,6 +119,83 @@ impl Default for PCAParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// PCA grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PCASearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// By default, covariance matrix is used to compute principal components.
|
||||
/// Enable this flag if you want to use correlation matrix instead.
|
||||
pub use_correlation_matrix: Vec<bool>,
|
||||
}
|
||||
|
||||
/// PCA grid search iterator
|
||||
pub struct PCASearchParametersIterator {
|
||||
pca_search_parameters: PCASearchParameters,
|
||||
current_k: usize,
|
||||
current_use_correlation_matrix: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for PCASearchParameters {
|
||||
type Item = PCAParameters;
|
||||
type IntoIter = PCASearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
PCASearchParametersIterator {
|
||||
pca_search_parameters: self,
|
||||
current_k: 0,
|
||||
current_use_correlation_matrix: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for PCASearchParametersIterator {
|
||||
type Item = PCAParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_k == self.pca_search_parameters.n_components.len()
|
||||
&& self.current_use_correlation_matrix
|
||||
== self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = PCAParameters {
|
||||
n_components: self.pca_search_parameters.n_components[self.current_k],
|
||||
use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix
|
||||
[self.current_use_correlation_matrix],
|
||||
};
|
||||
|
||||
if self.current_k + 1 < self.pca_search_parameters.n_components.len() {
|
||||
self.current_k += 1;
|
||||
} else if self.current_use_correlation_matrix + 1
|
||||
< self.pca_search_parameters.use_correlation_matrix.len()
|
||||
{
|
||||
self.current_k = 0;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
} else {
|
||||
self.current_k += 1;
|
||||
self.current_use_correlation_matrix += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PCASearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = PCAParameters::default();
|
||||
|
||||
PCASearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
use_correlation_matrix: vec![default_params.use_correlation_matrix],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, PCAParameters> for PCA<T, M> {
|
||||
fn fit(x: &M, parameters: PCAParameters) -> Result<Self, Failed> {
|
||||
PCA::fit(x, parameters)
|
||||
@@ -271,6 +351,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = PCASearchParameters {
|
||||
n_components: vec![2, 4],
|
||||
use_correlation_matrix: vec![true, false],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert_eq!(next.use_correlation_matrix, true);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert_eq!(next.use_correlation_matrix, true);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 2);
|
||||
assert_eq!(next.use_correlation_matrix, false);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 4);
|
||||
assert_eq!(next.use_correlation_matrix, false);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
fn us_arrests_data() -> DenseMatrix<f64> {
|
||||
DenseMatrix::from_2d_array(&[
|
||||
&[13.2, 236.0, 58.0, 21.2],
|
||||
|
||||
@@ -69,9 +69,11 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVD parameters
|
||||
pub struct SVDParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of components to keep.
|
||||
pub n_components: usize,
|
||||
}
|
||||
@@ -90,6 +92,61 @@ impl SVDParameters {
|
||||
}
|
||||
}
|
||||
|
||||
/// SVD grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVDSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Maximum number of iterations of the k-means algorithm for a single run.
|
||||
pub n_components: Vec<usize>,
|
||||
}
|
||||
|
||||
/// SVD grid search iterator
|
||||
pub struct SVDSearchParametersIterator {
|
||||
svd_search_parameters: SVDSearchParameters,
|
||||
current_n_components: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for SVDSearchParameters {
|
||||
type Item = SVDParameters;
|
||||
type IntoIter = SVDSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVDSearchParametersIterator {
|
||||
svd_search_parameters: self,
|
||||
current_n_components: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for SVDSearchParametersIterator {
|
||||
type Item = SVDParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_n_components == self.svd_search_parameters.n_components.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVDParameters {
|
||||
n_components: self.svd_search_parameters.n_components[self.current_n_components],
|
||||
};
|
||||
|
||||
self.current_n_components += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SVDSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = SVDParameters::default();
|
||||
|
||||
SVDSearchParameters {
|
||||
n_components: vec![default_params.n_components],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, SVDParameters> for SVD<T, M> {
|
||||
fn fit(x: &M, parameters: SVDParameters) -> Result<Self, Failed> {
|
||||
SVD::fit(x, parameters)
|
||||
@@ -153,6 +210,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = SVDSearchParameters {
|
||||
n_components: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_components, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svd_decompose() {
|
||||
|
||||
@@ -45,8 +45,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -58,6 +58,7 @@ use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::{BaseMatrix, Matrix};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::tree::decision_tree_classifier::{
|
||||
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
|
||||
};
|
||||
@@ -67,20 +68,28 @@ use crate::tree::decision_tree_classifier::{
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: SplitCriterion,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: u16,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
@@ -194,6 +203,234 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestCla
|
||||
}
|
||||
}
|
||||
|
||||
/// RandomForestClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestClassifierSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestClassifier grid search iterator
|
||||
pub struct RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestClassifierSearchParameters {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
type IntoIter = RandomForestClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestClassifierSearchParametersIterator {
|
||||
random_forest_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestClassifierSearchParametersIterator {
|
||||
type Item = RandomForestClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
&& self.current_m == self.random_forest_classifier_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestClassifierParameters {
|
||||
criterion: self.random_forest_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.random_forest_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_classifier_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_classifier_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_classifier_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.n_trees
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() {
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_classifier_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestClassifierParameters::default();
|
||||
|
||||
RandomForestClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -222,7 +459,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
|
||||
@@ -243,9 +480,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
@@ -378,6 +615,29 @@ mod tests_prob {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RandomForestClassifierSearchParameters {
|
||||
n_trees: vec![10, 100],
|
||||
m: vec![None, Some(1)],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, Some(1));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, Some(1));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_predict_iris() {
|
||||
|
||||
@@ -43,8 +43,8 @@
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::Rng;
|
||||
|
||||
use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::tree::decision_tree_regressor::{
|
||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
||||
};
|
||||
@@ -64,18 +65,25 @@ use crate::tree::decision_tree_regressor::{
|
||||
/// Parameters of the Random Forest Regressor
|
||||
/// Some parameters here are passed directly into base estimator.
|
||||
pub struct RandomForestRegressorParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Option<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: u64,
|
||||
}
|
||||
@@ -176,6 +184,198 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for RandomForestReg
|
||||
}
|
||||
}
|
||||
|
||||
/// RandomForestRegressor grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestRegressorSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The number of trees in the forest.
|
||||
pub n_trees: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of random sample of predictors to use as split candidates.
|
||||
pub m: Vec<Option<usize>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Whether to keep samples used for tree generation. This is required for OOB prediction.
|
||||
pub keep_samples: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Seed used for bootstrap sampling and feature selection for each tree.
|
||||
pub seed: Vec<u64>,
|
||||
}
|
||||
|
||||
/// RandomForestRegressor grid search iterator
|
||||
pub struct RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_n_trees: usize,
|
||||
current_m: usize,
|
||||
current_keep_samples: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for RandomForestRegressorSearchParameters {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
type IntoIter = RandomForestRegressorSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RandomForestRegressorSearchParametersIterator {
|
||||
random_forest_regressor_search_parameters: self,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_n_trees: 0,
|
||||
current_m: 0,
|
||||
current_keep_samples: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for RandomForestRegressorSearchParametersIterator {
|
||||
type Item = RandomForestRegressorParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_max_depth
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
&& self.current_m == self.random_forest_regressor_search_parameters.m.len()
|
||||
&& self.current_keep_samples
|
||||
== self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
&& self.current_seed == self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RandomForestRegressorParameters {
|
||||
max_depth: self.random_forest_regressor_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees],
|
||||
m: self.random_forest_regressor_search_parameters.m[self.current_m],
|
||||
keep_samples: self.random_forest_regressor_search_parameters.keep_samples
|
||||
[self.current_keep_samples],
|
||||
seed: self.random_forest_regressor_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_max_depth + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_n_trees + 1
|
||||
< self.random_forest_regressor_search_parameters.n_trees.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees += 1;
|
||||
} else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() {
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m += 1;
|
||||
} else if self.current_keep_samples + 1
|
||||
< self
|
||||
.random_forest_regressor_search_parameters
|
||||
.keep_samples
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples += 1;
|
||||
} else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_n_trees = 0;
|
||||
self.current_m = 0;
|
||||
self.current_keep_samples = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_n_trees += 1;
|
||||
self.current_m += 1;
|
||||
self.current_keep_samples += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestRegressorSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = RandomForestRegressorParameters::default();
|
||||
|
||||
RandomForestRegressorSearchParameters {
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
n_trees: vec![default_params.n_trees],
|
||||
m: vec![default_params.m],
|
||||
keep_samples: vec![default_params.keep_samples],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
/// Build a forest of trees from the training set.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -191,7 +391,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
.m
|
||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||
|
||||
let mut rng = StdRng::seed_from_u64(parameters.seed);
|
||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();
|
||||
|
||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
||||
@@ -208,9 +408,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf,
|
||||
min_samples_split: parameters.min_samples_split,
|
||||
seed: Some(parameters.seed),
|
||||
};
|
||||
let tree =
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
|
||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
@@ -302,6 +502,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RandomForestRegressorSearchParameters {
|
||||
n_trees: vec![10, 100],
|
||||
m: vec![None, Some(1)],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, None);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 10);
|
||||
assert_eq!(next.m, Some(1));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.n_trees, 100);
|
||||
assert_eq!(next.m, Some(1));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
|
||||
@@ -95,7 +95,11 @@ pub mod neighbors;
|
||||
pub(crate) mod optimization;
|
||||
/// Preprocessing utilities
|
||||
pub mod preprocessing;
|
||||
/// Reading in Data.
|
||||
pub mod readers;
|
||||
/// Support Vector Machines
|
||||
pub mod svm;
|
||||
/// Supervised tree-based learning methods
|
||||
pub mod tree;
|
||||
|
||||
pub(crate) mod rand;
|
||||
|
||||
+16
-3
@@ -25,6 +25,19 @@
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//! ```
|
||||
//! use smartcore::linalg::naive::dense_matrix::*;
|
||||
//! use smartcore::linalg::evd::*;
|
||||
//!
|
||||
//! let A = DenseMatrix::from_2d_array(&[
|
||||
//! &[-5.0, 2.0],
|
||||
//! &[-7.0, 4.0],
|
||||
//! ]);
|
||||
//!
|
||||
//! let evd = A.evd(false).unwrap();
|
||||
//! let eigenvectors: DenseMatrix<f64> = evd.V;
|
||||
//! let eigenvalues: Vec<f64> = evd.d;
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
//! * ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/)
|
||||
@@ -799,10 +812,10 @@ fn sort<T: RealNumber, M: BaseMatrix<T>>(d: &mut [T], e: &mut [T], V: &mut M) {
|
||||
}
|
||||
i -= 1;
|
||||
}
|
||||
d[i as usize + 1] = real;
|
||||
e[i as usize + 1] = img;
|
||||
d[(i + 1) as usize] = real;
|
||||
e[(i + 1) as usize] = img;
|
||||
for (k, temp_k) in temp.iter().enumerate().take(n) {
|
||||
V.set(k, i as usize + 1, *temp_k);
|
||||
V.set(k, (i + 1) as usize, *temp_k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,8 +65,11 @@ use high_order::HighOrderOperations;
|
||||
use lu::LUDecomposableMatrix;
|
||||
use qr::QRDecomposableMatrix;
|
||||
use stats::{MatrixPreprocessing, MatrixStats};
|
||||
use std::fs;
|
||||
use svd::SVDDecomposableMatrix;
|
||||
|
||||
use crate::readers;
|
||||
|
||||
/// Column or row vector
|
||||
pub trait BaseVector<T: RealNumber>: Clone + Debug {
|
||||
/// Get an element of a vector
|
||||
@@ -298,9 +301,63 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
/// represents a row in this matrix.
|
||||
type RowVector: BaseVector<T> + Clone + Debug;
|
||||
|
||||
/// Create a matrix from a csv file.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use smartcore::linalg::BaseMatrix;
|
||||
/// use smartcore::readers::csv;
|
||||
/// use std::fs;
|
||||
///
|
||||
/// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||
/// assert_eq!(
|
||||
/// DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
|
||||
/// DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||
/// );
|
||||
/// fs::remove_file("identity.csv");
|
||||
/// ```
|
||||
fn from_csv(
|
||||
path: &str,
|
||||
definition: readers::csv::CSVDefinition<'_>,
|
||||
) -> Result<Self, readers::ReadingError> {
|
||||
readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
|
||||
}
|
||||
|
||||
/// Transforms row vector `vec` into a 1xM matrix.
|
||||
fn from_row_vector(vec: Self::RowVector) -> Self;
|
||||
|
||||
/// Transforms Vector of n rows with dimension m into
|
||||
/// a matrix nxm.
|
||||
/// ```
|
||||
/// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
/// use crate::smartcore::linalg::BaseMatrix;
|
||||
///
|
||||
/// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
|
||||
/// .unwrap();
|
||||
///
|
||||
/// assert_eq!(
|
||||
/// eye,
|
||||
/// DenseMatrix::from_2d_vec(&vec![
|
||||
/// vec![1.0, 0.0, 0.0],
|
||||
/// vec![0.0, 1.0, 0.0],
|
||||
/// vec![0.0, 0.0, 1.0],
|
||||
/// ])
|
||||
/// );
|
||||
fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
|
||||
if rows.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let n = rows.len();
|
||||
let m = rows[0].len();
|
||||
|
||||
let mut result = Self::zeros(n, m);
|
||||
|
||||
for (row_idx, row) in rows.into_iter().enumerate() {
|
||||
result.set_row(row_idx, row);
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Transforms 1-d matrix of 1xM into a row vector.
|
||||
fn to_row_vector(self) -> Self::RowVector;
|
||||
|
||||
@@ -322,6 +379,13 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
/// * `result` - receiver for the row
|
||||
fn copy_row_as_vec(&self, row: usize, result: &mut Vec<T>);
|
||||
|
||||
/// Set row vector at row `row_idx`.
|
||||
fn set_row(&mut self, row_idx: usize, row: Self::RowVector) {
|
||||
for (col_idx, val) in row.to_vec().into_iter().enumerate() {
|
||||
self.set(row_idx, col_idx, val);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a vector with elements of the `col`'th column
|
||||
/// * `col` - column number
|
||||
fn get_col_as_vec(&self, col: usize) -> Vec<T>;
|
||||
@@ -651,6 +715,10 @@ pub trait BaseMatrix<T: RealNumber>: Clone + Debug {
|
||||
|
||||
result
|
||||
}
|
||||
/// Take an individual column from the matrix.
|
||||
fn take_column(&self, column_index: usize) -> Self {
|
||||
self.take(&[column_index], 1)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic matrix with additional mixins like various factorization methods.
|
||||
@@ -761,4 +829,93 @@ mod tests {
|
||||
assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0);
|
||||
assert_eq!(m.take(&vec!(1, 0), 1), expected_1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn take_second_column_from_matrix() {
|
||||
let four_columns: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
&[0.0, 1.0, 2.0, 3.0],
|
||||
]);
|
||||
|
||||
let second_column = four_columns.take_column(1);
|
||||
assert_eq!(
|
||||
second_column,
|
||||
DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]),
|
||||
"The second column was not extracted correctly"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_row_vectors_simple() {
|
||||
let eye = DenseMatrix::from_row_vectors(vec![
|
||||
vec![1., 0., 0.],
|
||||
vec![0., 1., 0.],
|
||||
vec![0., 0., 1.],
|
||||
])
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
eye,
|
||||
DenseMatrix::from_2d_vec(&vec![
|
||||
vec![1.0, 0.0, 0.0],
|
||||
vec![0.0, 1.0, 0.0],
|
||||
vec![0.0, 0.0, 1.0],
|
||||
])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_row_vectors_large() {
|
||||
let eye = DenseMatrix::from_row_vectors(vec![vec![4.25; 5000]; 5000]).unwrap();
|
||||
|
||||
assert_eq!(eye.shape(), (5000, 5000));
|
||||
assert_eq!(eye.get_row(5), vec![4.25; 5000]);
|
||||
}
|
||||
mod matrix_from_csv {
|
||||
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
use crate::readers::csv;
|
||||
use crate::readers::io_testing;
|
||||
use crate::readers::ReadingError;
|
||||
|
||||
#[test]
|
||||
fn simple_read_default_csv() {
|
||||
let test_csv_file = io_testing::TemporaryTextFile::new(
|
||||
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||
5.1,3.5,1.4,0.2\n\
|
||||
4.9,3,1.4,0.2\n\
|
||||
4.7,3.2,1.3,0.2",
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
DenseMatrix::<f64>::from_csv(
|
||||
test_csv_file
|
||||
.expect("Temporary file could not be written.")
|
||||
.path(),
|
||||
csv::CSVDefinition::default()
|
||||
),
|
||||
Ok(DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_existant_input_file() {
|
||||
let potential_error =
|
||||
DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
|
||||
// The exact message is operating system dependant, therefore, I only test that the correct type
|
||||
// error was returned.
|
||||
assert_eq!(
|
||||
potential_error.clone(),
|
||||
Err(ReadingError::CouldNotReadFileSystem {
|
||||
msg: String::from(potential_error.err().unwrap().message().unwrap())
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,16 +71,21 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
@@ -135,6 +140,126 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// ElasticNet grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ElasticNetSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
|
||||
/// For l1_ratio = 0 the penalty is an L2 penalty.
|
||||
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
||||
pub l1_ratio: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// ElasticNet grid search iterator
|
||||
pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
|
||||
lasso_regression_search_parameters: ElasticNetSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_l1_ratio: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
|
||||
type Item = ElasticNetParameters<T>;
|
||||
type IntoIter = ElasticNetSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
ElasticNetSearchParametersIterator {
|
||||
lasso_regression_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_l1_ratio: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
|
||||
type Item = ElasticNetParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
|
||||
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = ElasticNetParameters {
|
||||
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
|
||||
l1_ratio: self.lasso_regression_search_parameters.alpha[self.current_l1_ratio],
|
||||
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.lasso_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_l1_ratio = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_l1_ratio += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = ElasticNetParameters::default();
|
||||
|
||||
ElasticNetSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
l1_ratio: vec![default_params.l1_ratio],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.coefficients == other.coefficients
|
||||
@@ -291,6 +416,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = ElasticNetSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn elasticnet_longley() {
|
||||
|
||||
@@ -38,13 +38,17 @@ use crate::math::num::RealNumber;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: bool,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: usize,
|
||||
}
|
||||
@@ -112,6 +116,106 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Lasso grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LassoSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the strength of the penalty to the loss function.
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The tolerance for the optimization
|
||||
pub tol: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum number of iterations
|
||||
pub max_iter: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Lasso grid search iterator
|
||||
pub struct LassoSearchParametersIterator<T: RealNumber> {
|
||||
lasso_search_parameters: LassoSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
current_tol: usize,
|
||||
current_max_iter: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
|
||||
type Item = LassoParameters<T>;
|
||||
type IntoIter = LassoSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LassoSearchParametersIterator {
|
||||
lasso_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
current_tol: 0,
|
||||
current_max_iter: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
|
||||
type Item = LassoParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.lasso_search_parameters.alpha.len()
|
||||
&& self.current_normalize == self.lasso_search_parameters.normalize.len()
|
||||
&& self.current_tol == self.lasso_search_parameters.tol.len()
|
||||
&& self.current_max_iter == self.lasso_search_parameters.max_iter.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LassoParameters {
|
||||
alpha: self.lasso_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.lasso_search_parameters.normalize[self.current_normalize],
|
||||
tol: self.lasso_search_parameters.tol[self.current_tol],
|
||||
max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize += 1;
|
||||
} else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_normalize = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_max_iter += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_normalize += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_max_iter += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LassoSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = LassoParameters::default();
|
||||
|
||||
LassoSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
tol: vec![default_params.tol],
|
||||
max_iter: vec![default_params.max_iter],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
|
||||
/// Fits Lasso regression to your data.
|
||||
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||
@@ -226,6 +330,29 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LassoSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
max_iter: vec![10, 100],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 10);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 0.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
assert_eq!(next.max_iter, 100);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn lasso_fit_predict() {
|
||||
|
||||
@@ -211,9 +211,7 @@ impl<T: RealNumber, M: Matrix<T>> InteriorPointOptimizer<T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M>
|
||||
for InteriorPointOptimizer<T, M>
|
||||
{
|
||||
impl<T: RealNumber, M: Matrix<T>> BiconjugateGradientSolver<T, M> for InteriorPointOptimizer<T, M> {
|
||||
fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) {
|
||||
let (_, p) = a.shape();
|
||||
|
||||
|
||||
@@ -71,19 +71,21 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Default, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
|
||||
pub enum LinearRegressionSolverName {
|
||||
/// QR decomposition, see [QR](../../linalg/qr/index.html)
|
||||
QR,
|
||||
#[default]
|
||||
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
||||
SVD,
|
||||
}
|
||||
|
||||
/// Linear Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct LinearRegressionParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LinearRegressionSolverName,
|
||||
}
|
||||
@@ -105,10 +107,57 @@ impl LinearRegressionParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionParameters {
|
||||
/// Linear Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LinearRegressionSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LinearRegressionSolverName>,
|
||||
}
|
||||
|
||||
/// Linear Regression grid search iterator
|
||||
pub struct LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: LinearRegressionSearchParameters,
|
||||
current_solver: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for LinearRegressionSearchParameters {
|
||||
type Item = LinearRegressionParameters;
|
||||
type IntoIter = LinearRegressionSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LinearRegressionSearchParametersIterator {
|
||||
linear_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for LinearRegressionSearchParametersIterator {
|
||||
type Item = LinearRegressionParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LinearRegressionParameters {
|
||||
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
};
|
||||
|
||||
self.current_solver += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LinearRegressionSearchParameters {
|
||||
fn default() -> Self {
|
||||
LinearRegressionParameters {
|
||||
solver: LinearRegressionSolverName::SVD,
|
||||
let default_params = LinearRegressionParameters::default();
|
||||
|
||||
LinearRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,6 +249,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LinearRegressionSearchParameters {
|
||||
solver: vec![
|
||||
LinearRegressionSolverName::QR,
|
||||
LinearRegressionSolverName::SVD,
|
||||
],
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
|
||||
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ols_fit_predict() {
|
||||
@@ -285,5 +348,9 @@ mod tests {
|
||||
serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(lr, deserialized_lr);
|
||||
|
||||
let default = LinearRegressionParameters::default();
|
||||
let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap();
|
||||
assert_eq!(parameters.solver, default.solver);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,23 +68,104 @@ use crate::optimization::line_search::Backtracking;
|
||||
use crate::optimization::FunctionOrder;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||
pub enum LogisticRegressionSolverName {
|
||||
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||
LBFGS,
|
||||
}
|
||||
|
||||
impl Default for LogisticRegressionSolverName {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionSolverName::LBFGS
|
||||
}
|
||||
}
|
||||
|
||||
/// Logistic Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: LogisticRegressionSolverName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: T,
|
||||
}
|
||||
|
||||
/// Logistic Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LogisticRegressionSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<LogisticRegressionSolverName>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
}
|
||||
|
||||
/// Logistic Regression grid search iterator
|
||||
pub struct LogisticRegressionSearchParametersIterator<T: RealNumber> {
|
||||
logistic_regression_search_parameters: LogisticRegressionSearchParameters<T>,
|
||||
current_solver: usize,
|
||||
current_alpha: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for LogisticRegressionSearchParameters<T> {
|
||||
type Item = LogisticRegressionParameters<T>;
|
||||
type IntoIter = LogisticRegressionSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
LogisticRegressionSearchParametersIterator {
|
||||
logistic_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
current_alpha: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for LogisticRegressionSearchParametersIterator<T> {
|
||||
type Item = LogisticRegressionParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.logistic_regression_search_parameters.alpha.len()
|
||||
&& self.current_solver == self.logistic_regression_search_parameters.solver.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = LogisticRegressionParameters {
|
||||
solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_solver += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_solver += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for LogisticRegressionSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = LogisticRegressionParameters::default();
|
||||
|
||||
LogisticRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
alpha: vec![default_params.alpha],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Logistic Regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
@@ -133,7 +214,7 @@ impl<T: RealNumber> LogisticRegressionParameters<T> {
|
||||
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
|
||||
fn default() -> Self {
|
||||
LogisticRegressionParameters {
|
||||
solver: LogisticRegressionSolverName::LBFGS,
|
||||
solver: LogisticRegressionSolverName::default(),
|
||||
alpha: T::zero(),
|
||||
}
|
||||
}
|
||||
@@ -452,6 +533,21 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::accuracy;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = LogisticRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().alpha, 0.);
|
||||
assert_eq!(
|
||||
iter.next().unwrap().solver,
|
||||
LogisticRegressionSolverName::LBFGS
|
||||
);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn multiclass_objective_f() {
|
||||
|
||||
@@ -68,7 +68,7 @@ use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||
pub enum RidgeRegressionSolverName {
|
||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||
@@ -77,6 +77,12 @@ pub enum RidgeRegressionSolverName {
|
||||
SVD,
|
||||
}
|
||||
|
||||
impl Default for RidgeRegressionSolverName {
|
||||
fn default() -> Self {
|
||||
RidgeRegressionSolverName::Cholesky
|
||||
}
|
||||
}
|
||||
|
||||
/// Ridge Regression parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -90,6 +96,93 @@ pub struct RidgeRegressionParameters<T: RealNumber> {
|
||||
pub normalize: bool,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RidgeRegressionSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Solver to use for estimation of regression coefficients.
|
||||
pub solver: Vec<RidgeRegressionSolverName>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// If true the regressors X will be normalized before regression
|
||||
/// by subtracting the mean and dividing by the standard deviation.
|
||||
pub normalize: Vec<bool>,
|
||||
}
|
||||
|
||||
/// Ridge Regression grid search iterator
|
||||
pub struct RidgeRegressionSearchParametersIterator<T: RealNumber> {
|
||||
ridge_regression_search_parameters: RidgeRegressionSearchParameters<T>,
|
||||
current_solver: usize,
|
||||
current_alpha: usize,
|
||||
current_normalize: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for RidgeRegressionSearchParameters<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
type IntoIter = RidgeRegressionSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
RidgeRegressionSearchParametersIterator {
|
||||
ridge_regression_search_parameters: self,
|
||||
current_solver: 0,
|
||||
current_alpha: 0,
|
||||
current_normalize: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for RidgeRegressionSearchParametersIterator<T> {
|
||||
type Item = RidgeRegressionParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.ridge_regression_search_parameters.alpha.len()
|
||||
&& self.current_solver == self.ridge_regression_search_parameters.solver.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = RidgeRegressionParameters {
|
||||
solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(),
|
||||
alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha],
|
||||
normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_solver += 1;
|
||||
} else if self.current_normalize + 1
|
||||
< self.ridge_regression_search_parameters.normalize.len()
|
||||
{
|
||||
self.current_alpha = 0;
|
||||
self.current_solver = 0;
|
||||
self.current_normalize += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_solver += 1;
|
||||
self.current_normalize += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for RidgeRegressionSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = RidgeRegressionParameters::default();
|
||||
|
||||
RidgeRegressionSearchParameters {
|
||||
solver: vec![default_params.solver],
|
||||
alpha: vec![default_params.alpha],
|
||||
normalize: vec![default_params.normalize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Ridge regression
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
@@ -120,7 +213,7 @@ impl<T: RealNumber> RidgeRegressionParameters<T> {
|
||||
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
|
||||
fn default() -> Self {
|
||||
RidgeRegressionParameters {
|
||||
solver: RidgeRegressionSolverName::Cholesky,
|
||||
solver: RidgeRegressionSolverName::default(),
|
||||
alpha: T::one(),
|
||||
normalize: true,
|
||||
}
|
||||
@@ -274,6 +367,21 @@ mod tests {
|
||||
use crate::linalg::naive::dense_matrix::*;
|
||||
use crate::metrics::mean_absolute_error;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = RidgeRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
assert_eq!(iter.next().unwrap().alpha, 0.);
|
||||
assert_eq!(
|
||||
iter.next().unwrap().solver,
|
||||
RidgeRegressionSolverName::Cholesky
|
||||
);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn ridge_fit_predict() {
|
||||
|
||||
+28
-3
@@ -7,6 +7,9 @@ use rand::prelude::*;
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::iter::{Product, Sum};
|
||||
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
|
||||
use std::str::FromStr;
|
||||
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// Defines real number
|
||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||
@@ -22,6 +25,7 @@ pub trait RealNumber:
|
||||
+ SubAssign
|
||||
+ MulAssign
|
||||
+ DivAssign
|
||||
+ FromStr
|
||||
{
|
||||
/// Copy sign from `sign` - another real number
|
||||
fn copysign(self, sign: Self) -> Self;
|
||||
@@ -46,8 +50,11 @@ pub trait RealNumber:
|
||||
self * self
|
||||
}
|
||||
|
||||
/// Raw transmutation to u64
|
||||
/// Raw transmutation to u32
|
||||
fn to_f32_bits(self) -> u32;
|
||||
|
||||
/// Raw transmutation to u64
|
||||
fn to_f64_bits(self) -> u64;
|
||||
}
|
||||
|
||||
impl RealNumber for f64 {
|
||||
@@ -74,7 +81,7 @@ impl RealNumber for f64 {
|
||||
}
|
||||
|
||||
fn rand() -> f64 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rng = get_rng_impl(None);
|
||||
rng.gen()
|
||||
}
|
||||
|
||||
@@ -89,6 +96,10 @@ impl RealNumber for f64 {
|
||||
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits() as u32
|
||||
}
|
||||
|
||||
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits()
|
||||
}
|
||||
}
|
||||
|
||||
impl RealNumber for f32 {
|
||||
@@ -115,7 +126,7 @@ impl RealNumber for f32 {
|
||||
}
|
||||
|
||||
fn rand() -> f32 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut rng = get_rng_impl(None);
|
||||
rng.gen()
|
||||
}
|
||||
|
||||
@@ -130,6 +141,10 @@ impl RealNumber for f32 {
|
||||
fn to_f32_bits(self) -> u32 {
|
||||
self.to_bits()
|
||||
}
|
||||
|
||||
fn to_f64_bits(self) -> u64 {
|
||||
self.to_bits() as u64
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -143,4 +158,14 @@ mod tests {
|
||||
assert_eq!(41.0.sigmoid(), 1.);
|
||||
assert_eq!((-41.0).sigmoid(), 0.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f32_from_string() {
|
||||
assert_eq!(f32::from_str("1.111111").unwrap(), 1.111111)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn f64_from_string() {
|
||||
assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
|
||||
}
|
||||
}
|
||||
|
||||
+41
-21
@@ -18,6 +18,8 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -42,34 +44,33 @@ impl Precision {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes = classes.len();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Precision can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
let mut fp = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
fp += 1;
|
||||
}
|
||||
} else {
|
||||
fp += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.5).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn precision_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Precision {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Precision {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
+42
-22
@@ -18,6 +18,9 @@
|
||||
//!
|
||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||
use std::collections::HashSet;
|
||||
use std::convert::TryInto;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -42,34 +45,32 @@ impl Recall {
|
||||
);
|
||||
}
|
||||
|
||||
let mut classes = HashSet::new();
|
||||
for i in 0..y_true.len() {
|
||||
classes.insert(y_true.get(i).to_f64_bits());
|
||||
}
|
||||
let classes: i64 = classes.len().try_into().unwrap();
|
||||
|
||||
let mut tp = 0;
|
||||
let mut p = 0;
|
||||
let n = y_true.len();
|
||||
for i in 0..n {
|
||||
if y_true.get(i) != T::zero() && y_true.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_true.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() {
|
||||
panic!(
|
||||
"Recall can only be applied to binary classification: {}",
|
||||
y_pred.get(i)
|
||||
);
|
||||
}
|
||||
|
||||
let mut fne = 0;
|
||||
for i in 0..y_true.len() {
|
||||
if y_pred.get(i) == y_true.get(i) {
|
||||
if classes == 2 {
|
||||
if y_true.get(i) == T::one() {
|
||||
p += 1;
|
||||
|
||||
if y_pred.get(i) == T::one() {
|
||||
tp += 1;
|
||||
}
|
||||
} else {
|
||||
tp += 1;
|
||||
}
|
||||
} else if classes == 2 {
|
||||
if y_true.get(i) != T::one() {
|
||||
fne += 1;
|
||||
}
|
||||
} else {
|
||||
fne += 1;
|
||||
}
|
||||
}
|
||||
|
||||
T::from_i64(tp).unwrap() / T::from_i64(p).unwrap()
|
||||
T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,5 +89,24 @@ mod tests {
|
||||
|
||||
assert!((score1 - 0.5).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
|
||||
let y_pred: Vec<f64> = vec![0., 0., 1., 1., 1., 1.];
|
||||
let y_true: Vec<f64> = vec![0., 1., 1., 0., 1., 0.];
|
||||
|
||||
let score3: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
assert!((score3 - 0.66666666).abs() < 1e-8);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn recall_multiclass() {
|
||||
let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
|
||||
let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
|
||||
|
||||
let score1: f64 = Recall {}.get_score(&y_pred, &y_true);
|
||||
let score2: f64 = Recall {}.get_score(&y_pred, &y_pred);
|
||||
|
||||
assert!((score1 - 0.333333333).abs() < 1e-8);
|
||||
assert!((score2 - 1.0).abs() < 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
/// grid search results.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GridSearchResult<T: RealNumber, I: Clone> {
|
||||
/// Vector with test scores on each cv split
|
||||
pub cross_validation_result: CrossValidationResult<T>,
|
||||
/// Vector with training scores on each cv split
|
||||
pub parameters: I,
|
||||
}
|
||||
|
||||
/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
|
||||
/// * `fit_estimator` - a `fit` function of an estimator
|
||||
/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `parameter_search` - an iterator for parameters that will be tested.
|
||||
/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
|
||||
/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
|
||||
pub fn grid_search<T, M, I, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameter_search: I,
|
||||
cv: K,
|
||||
score: S,
|
||||
) -> Result<GridSearchResult<T, I::Item>, Failed>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
I: Iterator,
|
||||
I::Item: Clone,
|
||||
E: Predictor<M, M::RowVector>,
|
||||
K: BaseKFold,
|
||||
F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
|
||||
S: Fn(&M::RowVector, &M::RowVector) -> T,
|
||||
{
|
||||
let mut best_result: Option<CrossValidationResult<T>> = None;
|
||||
let mut best_parameters = None;
|
||||
|
||||
for parameters in parameter_search {
|
||||
let result = cross_validate(&fit_estimator, x, y, ¶meters, &cv, &score)?;
|
||||
if best_result.is_none()
|
||||
|| result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
|
||||
{
|
||||
best_parameters = Some(parameters);
|
||||
best_result = Some(result);
|
||||
}
|
||||
}
|
||||
|
||||
if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
|
||||
Ok(GridSearchResult {
|
||||
cross_validation_result,
|
||||
parameters,
|
||||
})
|
||||
} else {
|
||||
Err(Failed::because(
|
||||
FailedError::FindFailed,
|
||||
"there were no parameter sets found",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::linear::logistic_regression::{
|
||||
LogisticRegression, LogisticRegressionSearchParameters,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_grid_search() {
|
||||
let x = DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4],
|
||||
]);
|
||||
let y = vec![
|
||||
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
||||
let cv = KFold {
|
||||
n_splits: 5,
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let parameters = LogisticRegressionSearchParameters {
|
||||
alpha: vec![0., 1.],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let results = grid_search(
|
||||
LogisticRegression::fit,
|
||||
&x,
|
||||
&y,
|
||||
parameters.into_iter(),
|
||||
cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert!([0., 1.].contains(&results.parameters.alpha));
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,8 @@
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::model_selection::BaseKFold;
|
||||
use crate::rand::get_rng_impl;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
/// K-Folds cross-validator
|
||||
pub struct KFold {
|
||||
@@ -14,6 +14,9 @@ pub struct KFold {
|
||||
pub n_splits: usize, // cannot exceed std::usize::MAX
|
||||
/// Whether to shuffle the data before splitting into batches
|
||||
pub shuffle: bool,
|
||||
/// When shuffle is True, seed affects the ordering of the indices.
|
||||
/// Which controls the randomness of each fold
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
impl KFold {
|
||||
@@ -23,8 +26,10 @@ impl KFold {
|
||||
|
||||
// initialise indices
|
||||
let mut indices: Vec<usize> = (0..n_samples).collect();
|
||||
let mut rng = get_rng_impl(self.seed);
|
||||
|
||||
if self.shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
indices.shuffle(&mut rng);
|
||||
}
|
||||
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
|
||||
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
|
||||
@@ -66,6 +71,7 @@ impl Default for KFold {
|
||||
KFold {
|
||||
n_splits: 3,
|
||||
shuffle: true,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,6 +87,12 @@ impl KFold {
|
||||
self.shuffle = shuffle;
|
||||
self
|
||||
}
|
||||
|
||||
/// When shuffle is True, random_state affects the ordering of the indices.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over indices that split data into training and test set.
|
||||
@@ -150,6 +162,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
@@ -165,6 +178,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
|
||||
let test_indices = k.test_indices(&x);
|
||||
@@ -180,6 +194,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let test_masks = k.test_masks(&x);
|
||||
@@ -206,6 +221,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 2,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
|
||||
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
|
||||
@@ -238,6 +254,7 @@ mod tests {
|
||||
let k = KFold {
|
||||
n_splits: 3,
|
||||
shuffle: false,
|
||||
seed: None,
|
||||
};
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
|
||||
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
|
||||
|
||||
+22
-12
@@ -41,7 +41,7 @@
|
||||
//! 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
//! ];
|
||||
//!
|
||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||
//!
|
||||
//! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
|
||||
//! x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
|
||||
@@ -91,8 +91,8 @@
|
||||
//!
|
||||
//! let results = cross_validate(LogisticRegression::fit, //estimator
|
||||
//! &x, &y, //data
|
||||
//! Default::default(), //hyperparameters
|
||||
//! cv, //cross validation split
|
||||
//! &Default::default(), //hyperparameters
|
||||
//! &cv, //cross validation split
|
||||
//! &accuracy).unwrap(); //metric
|
||||
//!
|
||||
//! println!("Training accuracy: {}, test accuracy: {}",
|
||||
@@ -107,8 +107,8 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::thread_rng;
|
||||
|
||||
pub(crate) mod kfold;
|
||||
|
||||
@@ -130,11 +130,13 @@ pub trait BaseKFold {
|
||||
/// * `y` - target values, should be of size _N_
|
||||
/// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
|
||||
/// * `shuffle`, - whether or not to shuffle the data before splitting
|
||||
/// * `seed` - Controls the shuffling applied to the data before applying the split. Pass an int for reproducible output across multiple function calls
|
||||
pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
test_size: f32,
|
||||
shuffle: bool,
|
||||
seed: Option<u64>,
|
||||
) -> (M, M, M::RowVector, M::RowVector) {
|
||||
if x.shape().0 != y.len() {
|
||||
panic!(
|
||||
@@ -143,6 +145,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
y.len()
|
||||
);
|
||||
}
|
||||
let mut rng = get_rng_impl(seed);
|
||||
|
||||
if test_size <= 0. || test_size > 1.0 {
|
||||
panic!("test_size should be between 0 and 1");
|
||||
@@ -159,7 +162,7 @@ pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
|
||||
let mut indices: Vec<usize> = (0..n).collect();
|
||||
|
||||
if shuffle {
|
||||
indices.shuffle(&mut thread_rng());
|
||||
indices.shuffle(&mut rng);
|
||||
}
|
||||
|
||||
let x_train = x.take(&indices[n_test..n], 0);
|
||||
@@ -201,8 +204,8 @@ pub fn cross_validate<T, M, H, E, K, F, S>(
|
||||
fit_estimator: F,
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: H,
|
||||
cv: K,
|
||||
parameters: &H,
|
||||
cv: &K,
|
||||
score: S,
|
||||
) -> Result<CrossValidationResult<T>, Failed>
|
||||
where
|
||||
@@ -292,7 +295,7 @@ mod tests {
|
||||
let x: DenseMatrix<f64> = DenseMatrix::rand(n, 3);
|
||||
let y = vec![0f64; n];
|
||||
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
|
||||
let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
|
||||
|
||||
assert!(
|
||||
x_train.shape().0 > (n as f64 * 0.65) as usize
|
||||
@@ -362,8 +365,15 @@ mod tests {
|
||||
..KFold::default()
|
||||
};
|
||||
|
||||
let results =
|
||||
cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap();
|
||||
let results = cross_validate(
|
||||
BiasedEstimator::fit,
|
||||
&x,
|
||||
&y,
|
||||
&NoParameters {},
|
||||
&cv,
|
||||
&accuracy,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(0.4, results.mean_test_score());
|
||||
assert_eq!(0.4, results.mean_train_score());
|
||||
@@ -404,8 +414,8 @@ mod tests {
|
||||
KNNRegressor::fit,
|
||||
&x,
|
||||
&y,
|
||||
Default::default(),
|
||||
cv,
|
||||
&Default::default(),
|
||||
&cv,
|
||||
&mean_absolute_error,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -114,10 +114,13 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for BernoulliNBDistributi
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BernoulliNBParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
|
||||
pub binarize: Option<T>,
|
||||
}
|
||||
@@ -150,6 +153,91 @@ impl<T: RealNumber> Default for BernoulliNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// BernoulliNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BernoulliNBSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.
|
||||
pub binarize: Vec<Option<T>>,
|
||||
}
|
||||
|
||||
/// BernoulliNB grid search iterator
|
||||
pub struct BernoulliNBSearchParametersIterator<T: RealNumber> {
|
||||
bernoulli_nb_search_parameters: BernoulliNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_priors: usize,
|
||||
current_binarize: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for BernoulliNBSearchParameters<T> {
|
||||
type Item = BernoulliNBParameters<T>;
|
||||
type IntoIter = BernoulliNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
BernoulliNBSearchParametersIterator {
|
||||
bernoulli_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_priors: 0,
|
||||
current_binarize: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for BernoulliNBSearchParametersIterator<T> {
|
||||
type Item = BernoulliNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len()
|
||||
&& self.current_priors == self.bernoulli_nb_search_parameters.priors.len()
|
||||
&& self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = BernoulliNBParameters {
|
||||
alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha],
|
||||
priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize],
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors += 1;
|
||||
} else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors = 0;
|
||||
self.current_binarize += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_priors += 1;
|
||||
self.current_binarize += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for BernoulliNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = BernoulliNBParameters::default();
|
||||
|
||||
BernoulliNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
priors: vec![default_params.priors],
|
||||
binarize: vec![default_params.binarize],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> BernoulliNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -347,6 +435,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = BernoulliNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_bernoulli_naive_bayes() {
|
||||
|
||||
@@ -243,6 +243,7 @@ impl<T: RealNumber> CategoricalNBDistribution<T> {
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoricalNBParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
}
|
||||
@@ -261,6 +262,61 @@ impl<T: RealNumber> Default for CategoricalNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// CategoricalNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CategoricalNBSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
}
|
||||
|
||||
/// CategoricalNB grid search iterator
|
||||
pub struct CategoricalNBSearchParametersIterator<T: RealNumber> {
|
||||
categorical_nb_search_parameters: CategoricalNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for CategoricalNBSearchParameters<T> {
|
||||
type Item = CategoricalNBParameters<T>;
|
||||
type IntoIter = CategoricalNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
CategoricalNBSearchParametersIterator {
|
||||
categorical_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for CategoricalNBSearchParametersIterator<T> {
|
||||
type Item = CategoricalNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = CategoricalNBParameters {
|
||||
alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha],
|
||||
};
|
||||
|
||||
self.current_alpha += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for CategoricalNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = CategoricalNBParameters::default();
|
||||
|
||||
CategoricalNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, PartialEq)]
|
||||
@@ -351,6 +407,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = CategoricalNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_categorical_naive_bayes() {
|
||||
|
||||
@@ -76,8 +76,9 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for GaussianNBDistributio
|
||||
|
||||
/// `GaussianNB` parameters. Use `Default::default()` for default values.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GaussianNBParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
}
|
||||
@@ -90,6 +91,67 @@ impl<T: RealNumber> GaussianNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for GaussianNBParameters<T> {
|
||||
fn default() -> Self {
|
||||
Self { priors: None }
|
||||
}
|
||||
}
|
||||
|
||||
/// GaussianNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GaussianNBSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
}
|
||||
|
||||
/// GaussianNB grid search iterator
|
||||
pub struct GaussianNBSearchParametersIterator<T: RealNumber> {
|
||||
gaussian_nb_search_parameters: GaussianNBSearchParameters<T>,
|
||||
current_priors: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for GaussianNBSearchParameters<T> {
|
||||
type Item = GaussianNBParameters<T>;
|
||||
type IntoIter = GaussianNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
GaussianNBSearchParametersIterator {
|
||||
gaussian_nb_search_parameters: self,
|
||||
current_priors: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for GaussianNBSearchParametersIterator<T> {
|
||||
type Item = GaussianNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_priors == self.gaussian_nb_search_parameters.priors.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = GaussianNBParameters {
|
||||
priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
};
|
||||
|
||||
self.current_priors += 1;
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for GaussianNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = GaussianNBParameters::default();
|
||||
|
||||
GaussianNBSearchParameters {
|
||||
priors: vec![default_params.priors],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> GaussianNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -260,6 +322,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = GaussianNBSearchParameters {
|
||||
priors: vec![Some(vec![1.]), Some(vec![2.])],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.priors, Some(vec![1.]));
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.priors, Some(vec![2.]));
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_gaussian_naive_bayes() {
|
||||
|
||||
@@ -86,8 +86,10 @@ impl<T: RealNumber, M: Matrix<T>> NBDistribution<T, M> for MultinomialNBDistribu
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultinomialNBParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Option<Vec<T>>,
|
||||
}
|
||||
@@ -114,6 +116,78 @@ impl<T: RealNumber> Default for MultinomialNBParameters<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// MultinomialNB grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MultinomialNBSearchParameters<T: RealNumber> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
|
||||
pub alpha: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Prior probabilities of the classes. If specified the priors are not adjusted according to the data
|
||||
pub priors: Vec<Option<Vec<T>>>,
|
||||
}
|
||||
|
||||
/// MultinomialNB grid search iterator
|
||||
pub struct MultinomialNBSearchParametersIterator<T: RealNumber> {
|
||||
multinomial_nb_search_parameters: MultinomialNBSearchParameters<T>,
|
||||
current_alpha: usize,
|
||||
current_priors: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber> IntoIterator for MultinomialNBSearchParameters<T> {
|
||||
type Item = MultinomialNBParameters<T>;
|
||||
type IntoIter = MultinomialNBSearchParametersIterator<T>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
MultinomialNBSearchParametersIterator {
|
||||
multinomial_nb_search_parameters: self,
|
||||
current_alpha: 0,
|
||||
current_priors: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Iterator for MultinomialNBSearchParametersIterator<T> {
|
||||
type Item = MultinomialNBParameters<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_alpha == self.multinomial_nb_search_parameters.alpha.len()
|
||||
&& self.current_priors == self.multinomial_nb_search_parameters.priors.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = MultinomialNBParameters {
|
||||
alpha: self.multinomial_nb_search_parameters.alpha[self.current_alpha],
|
||||
priors: self.multinomial_nb_search_parameters.priors[self.current_priors].clone(),
|
||||
};
|
||||
|
||||
if self.current_alpha + 1 < self.multinomial_nb_search_parameters.alpha.len() {
|
||||
self.current_alpha += 1;
|
||||
} else if self.current_priors + 1 < self.multinomial_nb_search_parameters.priors.len() {
|
||||
self.current_alpha = 0;
|
||||
self.current_priors += 1;
|
||||
} else {
|
||||
self.current_alpha += 1;
|
||||
self.current_priors += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> Default for MultinomialNBSearchParameters<T> {
|
||||
fn default() -> Self {
|
||||
let default_params = MultinomialNBParameters::default();
|
||||
|
||||
MultinomialNBSearchParameters {
|
||||
alpha: vec![default_params.alpha],
|
||||
priors: vec![default_params.priors],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber> MultinomialNBDistribution<T> {
|
||||
/// Fits the distribution to a NxM matrix where N is number of samples and M is number of features.
|
||||
/// * `x` - training data.
|
||||
@@ -297,6 +371,20 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn search_parameters() {
|
||||
let parameters = MultinomialNBSearchParameters {
|
||||
alpha: vec![1., 2.],
|
||||
..Default::default()
|
||||
};
|
||||
let mut iter = parameters.into_iter();
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 1.);
|
||||
let next = iter.next().unwrap();
|
||||
assert_eq!(next.alpha, 2.);
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn run_multinomial_naive_bayes() {
|
||||
|
||||
@@ -49,16 +49,21 @@ use crate::neighbors::KNNWeightFunction;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNClassifierParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
pub distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
@@ -111,8 +116,8 @@ impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNClassifierParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
weight: KNNWeightFunction::default(),
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
|
||||
@@ -52,16 +52,21 @@ use crate::neighbors::KNNWeightFunction;
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KNNRegressorParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// a function that defines a distance between each pair of point in training data.
|
||||
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
|
||||
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
|
||||
distance: D,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
|
||||
pub algorithm: KNNAlgorithmName,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
|
||||
pub weight: KNNWeightFunction,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// number of training samples to consider when estimating class for new point. Default value is 3.
|
||||
pub k: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// this parameter is not used
|
||||
t: PhantomData<T>,
|
||||
}
|
||||
@@ -113,8 +118,8 @@ impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
|
||||
fn default() -> Self {
|
||||
KNNRegressorParameters {
|
||||
distance: Distances::euclidian(),
|
||||
algorithm: KNNAlgorithmName::CoverTree,
|
||||
weight: KNNWeightFunction::Uniform,
|
||||
algorithm: KNNAlgorithmName::default(),
|
||||
weight: KNNWeightFunction::default(),
|
||||
k: 3,
|
||||
t: PhantomData,
|
||||
}
|
||||
|
||||
@@ -58,6 +58,12 @@ pub enum KNNWeightFunction {
|
||||
Distance,
|
||||
}
|
||||
|
||||
impl Default for KNNWeightFunction {
|
||||
fn default() -> Self {
|
||||
KNNWeightFunction::Uniform
|
||||
}
|
||||
}
|
||||
|
||||
impl KNNWeightFunction {
|
||||
fn calc_weights<T: RealNumber>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
|
||||
match *self {
|
||||
|
||||
@@ -5,7 +5,7 @@ pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a;
|
||||
pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a;
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum FunctionOrder {
|
||||
SECOND,
|
||||
THIRD,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
/// Transform a data matrix by replaceing all categorical variables with their one-hot vector equivalents
|
||||
/// Transform a data matrix by replacing all categorical variables with their one-hot vector equivalents
|
||||
pub mod categorical;
|
||||
mod data_traits;
|
||||
/// Preprocess numerical matrices.
|
||||
pub mod numerical;
|
||||
/// Encode a series (column, array) of categorical variables as one-hot vectors
|
||||
pub mod series_encoder;
|
||||
|
||||
@@ -0,0 +1,447 @@
|
||||
//! # Standard-Scaling For [RealNumber](../../math/num/trait.RealNumber.html) Matricies
|
||||
//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by removing the mean and scaling to unit variance.
|
||||
//!
|
||||
//! ### Usage Example
|
||||
//! ```
|
||||
//! use smartcore::api::{Transformer, UnsupervisedEstimator};
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use smartcore::preprocessing::numerical;
|
||||
//! let data = DenseMatrix::from_2d_vec(&vec![
|
||||
//! vec![0.0, 0.0],
|
||||
//! vec![0.0, 0.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ]);
|
||||
//!
|
||||
//! let standard_scaler =
|
||||
//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default())
|
||||
//! .unwrap();
|
||||
//! let transformed_data = standard_scaler.transform(&data).unwrap();
|
||||
//! assert_eq!(
|
||||
//! transformed_data,
|
||||
//! DenseMatrix::from_2d_vec(&vec![
|
||||
//! vec![-1.0, -1.0],
|
||||
//! vec![-1.0, -1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! vec![1.0, 1.0],
|
||||
//! ])
|
||||
//! );
|
||||
//! ```
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::error::{Failed, FailedError};
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
#[cfg(feature = "serde")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configure Behaviour of `StandardScaler`.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
|
||||
pub struct StandardScalerParameters {
|
||||
/// Optionaly adjust mean to be zero.
|
||||
with_mean: bool,
|
||||
/// Optionally adjust standard-deviation to be one.
|
||||
with_std: bool,
|
||||
}
|
||||
impl Default for StandardScalerParameters {
|
||||
fn default() -> Self {
|
||||
StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// With the `StandardScaler` data can be adjusted so
|
||||
/// that every column has a mean of zero and a standard
|
||||
/// deviation of one. This can improve model training for
|
||||
/// scaling sensitive models like neural network or nearest
|
||||
/// neighbors based models.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Clone, Debug, Default, Eq, PartialEq)]
|
||||
pub struct StandardScaler<T: RealNumber> {
|
||||
means: Vec<T>,
|
||||
stds: Vec<T>,
|
||||
parameters: StandardScalerParameters,
|
||||
}
|
||||
impl<T: RealNumber> StandardScaler<T> {
|
||||
/// When the mean should be adjusted, the column mean
|
||||
/// should be kept. Otherwise, replace it by zero.
|
||||
fn adjust_column_mean(&self, mean: T) -> T {
|
||||
if self.parameters.with_mean {
|
||||
mean
|
||||
} else {
|
||||
T::zero()
|
||||
}
|
||||
}
|
||||
/// When the standard-deviation should be adjusted, the column
|
||||
/// standard-deviation should be kept. Otherwise, replace it by one.
|
||||
fn adjust_column_std(&self, std: T) -> T {
|
||||
if self.parameters.with_std {
|
||||
ensure_std_valid(std)
|
||||
} else {
|
||||
T::one()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Make sure the standard deviation is valid. If it is
|
||||
/// negative or zero, it should replaced by the smallest
|
||||
/// positive value the type can have. That way we can savely
|
||||
/// divide the columns with the resulting scalar.
|
||||
fn ensure_std_valid<T: RealNumber>(value: T) -> T {
|
||||
value.max(T::min_positive_value())
|
||||
}
|
||||
|
||||
/// During `fit` the `StandardScaler` computes the column means and standard deviation.
|
||||
impl<T: RealNumber, M: Matrix<T>> UnsupervisedEstimator<M, StandardScalerParameters>
|
||||
for StandardScaler<T>
|
||||
{
|
||||
fn fit(x: &M, parameters: StandardScalerParameters) -> Result<Self, Failed> {
|
||||
Ok(Self {
|
||||
means: x.column_mean(),
|
||||
stds: x.std(0),
|
||||
parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// During `transform` the `StandardScaler` applies the summary statistics
|
||||
/// computed during `fit` to set the mean of each column to zero and the
|
||||
/// standard deviation to one.
|
||||
impl<T: RealNumber, M: Matrix<T>> Transformer<M> for StandardScaler<T> {
|
||||
fn transform(&self, x: &M) -> Result<M, Failed> {
|
||||
let (_, n_cols) = x.shape();
|
||||
if n_cols != self.means.len() {
|
||||
return Err(Failed::because(
|
||||
FailedError::TransformFailed,
|
||||
&format!(
|
||||
"Expected {} columns, but got {} columns instead.",
|
||||
self.means.len(),
|
||||
n_cols,
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(build_matrix_from_columns(
|
||||
self.means
|
||||
.iter()
|
||||
.zip(self.stds.iter())
|
||||
.enumerate()
|
||||
.map(|(column_index, (column_mean, column_std))| {
|
||||
x.take_column(column_index)
|
||||
.sub_scalar(self.adjust_column_mean(*column_mean))
|
||||
.div_scalar(self.adjust_column_std(*column_std))
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
/// From a collection of matrices, that contain columns, construct
|
||||
/// a matrix by stacking the columns horizontally.
|
||||
fn build_matrix_from_columns<T, M>(columns: Vec<M>) -> Option<M>
|
||||
where
|
||||
T: RealNumber,
|
||||
M: Matrix<T>,
|
||||
{
|
||||
if let Some(output_matrix) = columns.first().cloned() {
|
||||
return Some(
|
||||
columns
|
||||
.iter()
|
||||
.skip(1)
|
||||
.fold(output_matrix, |current_matrix, new_colum| {
|
||||
current_matrix.h_stack(new_colum)
|
||||
}),
|
||||
);
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
mod helper_functionality {
|
||||
use super::super::{build_matrix_from_columns, ensure_std_valid};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn combine_three_columns() {
|
||||
assert_eq!(
|
||||
build_matrix_from_columns(vec![
|
||||
DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]),
|
||||
DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],])
|
||||
]),
|
||||
Some(DenseMatrix::from_2d_vec(&vec![
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0],
|
||||
vec![1.0, 2.0, 3.0]
|
||||
]))
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negative_value_should_be_replace_with_minimal_positive_value() {
|
||||
assert_eq!(ensure_std_valid(-1.0), f64::MIN_POSITIVE)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_should_be_replace_with_minimal_positive_value() {
|
||||
assert_eq!(ensure_std_valid(0.0), f64::MIN_POSITIVE)
|
||||
}
|
||||
}
|
||||
mod standard_scaler {
|
||||
use super::super::{StandardScaler, StandardScalerParameters};
|
||||
use crate::api::{Transformer, UnsupervisedEstimator};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::linalg::BaseMatrix;
|
||||
|
||||
#[test]
|
||||
fn dont_adjust_mean_if_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_mean(1.0),
|
||||
1.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn replace_mean_with_zero_if_not_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: false,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_mean(1.0),
|
||||
0.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn dont_adjust_std_if_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
.adjust_column_std(10.0),
|
||||
10.0
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn replace_std_with_one_if_not_used() {
|
||||
assert_eq!(
|
||||
(StandardScaler {
|
||||
means: vec![],
|
||||
stds: vec![],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: false
|
||||
}
|
||||
})
|
||||
.adjust_column_std(10.0),
|
||||
1.0
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper function to apply fit as well as transform at the same time.
|
||||
fn fit_transform_with_default_standard_scaler(
|
||||
values_to_be_transformed: &DenseMatrix<f64>,
|
||||
) -> DenseMatrix<f64> {
|
||||
StandardScaler::fit(
|
||||
values_to_be_transformed,
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap()
|
||||
.transform(values_to_be_transformed)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Fit transform with random generated values, expected values taken from
|
||||
/// sklearn.
|
||||
#[test]
|
||||
fn fit_transform_random_values() {
|
||||
let transformed_values =
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]));
|
||||
println!("{}", transformed_values);
|
||||
assert!(transformed_values.approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866],
|
||||
&[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631],
|
||||
&[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257],
|
||||
&[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021],
|
||||
]),
|
||||
1.0
|
||||
))
|
||||
}
|
||||
|
||||
/// Test `fit` and `transform` for a column with zero variance.
|
||||
#[test]
|
||||
fn fit_transform_with_zero_variance() {
|
||||
assert_eq!(
|
||||
fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0],
|
||||
&[1.0]
|
||||
])),
|
||||
DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]),
|
||||
"When scaling values with zero variance, zero is expected as return value"
|
||||
)
|
||||
}
|
||||
|
||||
/// Test `fit` for columns with nice summary statistics.
|
||||
#[test]
|
||||
fn fit_for_simple_values() {
|
||||
assert_eq!(
|
||||
StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[1.0, 1.0, 1.0],
|
||||
&[1.0, 2.0, 5.0],
|
||||
&[1.0, 1.0, 1.0],
|
||||
&[1.0, 2.0, 5.0]
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
),
|
||||
Ok(StandardScaler {
|
||||
means: vec![1.0, 1.5, 3.0],
|
||||
stds: vec![0.0, 0.5, 2.0],
|
||||
parameters: StandardScalerParameters {
|
||||
with_mean: true,
|
||||
with_std: true
|
||||
}
|
||||
})
|
||||
)
|
||||
}
|
||||
/// Test `fit` for random generated values.
|
||||
#[test]
|
||||
fn fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
fitted_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
/// If `with_std` is set to `false` the values should not be
/// adjusted to have a std of one.
/// The `stds` field is deliberately non-trivial (`[1.0, 2.0]`) to show it
/// is ignored: only the per-column means are subtracted.
#[test]
fn transform_without_std() {
    let standard_scaler = StandardScaler {
        means: vec![1.0, 3.0],
        stds: vec![1.0, 2.0],
        parameters: StandardScalerParameters {
            with_mean: true,
            with_std: false,
        },
    };

    // Only centering is applied: each column has its mean subtracted.
    assert_eq!(
        standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])),
        Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]]))
    )
}
|
||||
|
||||
/// If `with_mean` is set to `false` the values should not be adjusted
/// to have a mean of zero.
/// The `means` field is deliberately non-zero to show it is ignored:
/// values are only divided by the per-column std.
#[test]
fn transform_without_mean() {
    let standard_scaler = StandardScaler {
        means: vec![1.0, 2.0],
        stds: vec![2.0, 3.0],
        parameters: StandardScalerParameters {
            with_mean: false,
            with_std: true,
        },
    };

    // Only scaling is applied: each column is divided by its std.
    assert_eq!(
        standard_scaler
            .transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])),
        Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]]))
    )
}
|
||||
|
||||
/// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been
|
||||
/// serialized and deserialized.
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
#[cfg(feature = "serde")]
|
||||
fn serde_fit_for_random_values() {
|
||||
let fitted_scaler = StandardScaler::fit(
|
||||
&DenseMatrix::from_2d_array(&[
|
||||
&[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793],
|
||||
&[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264],
|
||||
&[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046],
|
||||
&[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442],
|
||||
]),
|
||||
StandardScalerParameters::default(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let deserialized_scaler: StandardScaler<f64> =
|
||||
serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
deserialized_scaler.means,
|
||||
vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625],
|
||||
);
|
||||
|
||||
assert!(
|
||||
&DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq(
|
||||
&DenseMatrix::from_2d_array(&[&[
|
||||
0.29426447500954,
|
||||
0.16758497615485,
|
||||
0.20820945786863,
|
||||
0.23329718831165
|
||||
],]),
|
||||
0.00000000000001
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
use ::rand::SeedableRng;
|
||||
#[cfg(not(feature = "std"))]
|
||||
use rand::rngs::SmallRng as RngImpl;
|
||||
#[cfg(feature = "std")]
|
||||
use rand::rngs::StdRng as RngImpl;
|
||||
|
||||
/// Build the crate-wide RNG (`StdRng` on `std` builds, `SmallRng` otherwise).
///
/// With `Some(seed)` the generator is deterministic, making results
/// reproducible. With `None`, `std` builds draw a fresh seed from
/// `thread_rng`; non-`std` builds have no entropy source and panic,
/// so callers there must always pass an explicit seed.
pub(crate) fn get_rng_impl(seed: Option<u64>) -> RngImpl {
    match seed {
        Some(seed) => RngImpl::seed_from_u64(seed),
        None => {
            // The branch is resolved at compile time via `cfg_if`.
            cfg_if::cfg_if! {
                if #[cfg(feature = "std")] {
                    use rand::RngCore;
                    RngImpl::seed_from_u64(rand::thread_rng().next_u64())
                } else {
                    panic!("seed number needed for non-std build");
                }
            }
        }
    }
}
|
||||
@@ -0,0 +1,487 @@
|
||||
//! This module contains utilities to read in matrices from csv files.
|
||||
//! ```
|
||||
//! use smartcore::readers::csv;
|
||||
//! use smartcore::linalg::naive::dense_matrix::DenseMatrix;
|
||||
//! use crate::smartcore::linalg::BaseMatrix;
|
||||
//! use std::fs;
|
||||
//!
|
||||
//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
|
||||
//! assert_eq!(
|
||||
//! csv::matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
//! fs::File::open("identity.csv").unwrap(),
|
||||
//! csv::CSVDefinition::default()
|
||||
//! )
|
||||
//! .unwrap(),
|
||||
//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
|
||||
//! );
|
||||
//! fs::remove_file("identity.csv");
|
||||
//! ```
|
||||
use crate::linalg::{BaseMatrix, BaseVector};
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::readers::ReadingError;
|
||||
use std::io::Read;
|
||||
|
||||
/// Define the structure of a CSV-file so that it can be read.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CSVDefinition<'a> {
    /// How many rows does the header have?
    n_rows_header: usize,
    /// What separates the fields in your csv-file?
    field_seperator: &'a str,
}
|
||||
impl<'a> Default for CSVDefinition<'a> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
n_rows_header: 1,
|
||||
field_seperator: ",",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Format definition for a single row in a csv file.
|
||||
/// This is used internally to validate rows of the csv file and
|
||||
/// be able to fail as early as possible.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
struct CSVRowFormat<'a> {
|
||||
field_seperator: &'a str,
|
||||
n_fields: usize,
|
||||
}
|
||||
impl<'a> CSVRowFormat<'a> {
|
||||
fn from_csv_definition(definition: &'a CSVDefinition<'_>, n_fields: usize) -> Self {
|
||||
CSVRowFormat {
|
||||
field_seperator: definition.field_seperator,
|
||||
n_fields,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the row format for the csv file from the first row.
|
||||
fn detect_row_format<'a>(
|
||||
csv_text: &'a str,
|
||||
definition: &'a CSVDefinition<'_>,
|
||||
) -> Result<CSVRowFormat<'a>, ReadingError> {
|
||||
let first_line = csv_text
|
||||
.lines()
|
||||
.nth(definition.n_rows_header)
|
||||
.ok_or(ReadingError::NoRowsProvided)?;
|
||||
|
||||
Ok(CSVRowFormat::from_csv_definition(
|
||||
definition,
|
||||
first_line.split(definition.field_seperator).count(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Read in a matrix from a source that contains a csv file.
|
||||
pub fn matrix_from_csv_source<T, RowVector, Matrix>(
|
||||
source: impl Read,
|
||||
definition: CSVDefinition<'_>,
|
||||
) -> Result<Matrix, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||
{
|
||||
let csv_text = read_string_from_source(source)?;
|
||||
let rows = extract_row_vectors_from_csv_text::<T, RowVector, Matrix>(
|
||||
&csv_text,
|
||||
&definition,
|
||||
detect_row_format(&csv_text, &definition)?,
|
||||
)?;
|
||||
|
||||
match Matrix::from_row_vectors(rows) {
|
||||
Some(matrix) => Ok(matrix),
|
||||
None => Err(ReadingError::NoRowsProvided),
|
||||
}
|
||||
}
|
||||
|
||||
/// Given a string containing the contents of a csv file, extract its value
|
||||
/// into row-vectors.
|
||||
fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>(
|
||||
csv_text: &'a str,
|
||||
definition: &'a CSVDefinition<'_>,
|
||||
row_format: CSVRowFormat<'_>,
|
||||
) -> Result<Vec<RowVector>, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
Matrix: BaseMatrix<T, RowVector = RowVector>,
|
||||
{
|
||||
csv_text
|
||||
.lines()
|
||||
.skip(definition.n_rows_header)
|
||||
.enumerate()
|
||||
.map(|(row_index, line)| {
|
||||
enrich_reading_error(
|
||||
extract_vector_from_csv_line(line, &row_format),
|
||||
format!(", Row: {row_index}."),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, ReadingError>>()
|
||||
}
|
||||
|
||||
/// Read a string from source implementing `Read`.
|
||||
fn read_string_from_source(mut source: impl Read) -> Result<String, ReadingError> {
|
||||
let mut string = String::new();
|
||||
source.read_to_string(&mut string)?;
|
||||
Ok(string)
|
||||
}
|
||||
|
||||
/// Extract a vector from a single line of a csv file.
|
||||
fn extract_vector_from_csv_line<T, RowVector>(
|
||||
line: &str,
|
||||
row_format: &CSVRowFormat<'_>,
|
||||
) -> Result<RowVector, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
RowVector: BaseVector<T>,
|
||||
{
|
||||
validate_csv_row(line, row_format)?;
|
||||
let extracted_fields = extract_fields_from_csv_row(line, row_format)?;
|
||||
Ok(BaseVector::from_array(&extracted_fields[..]))
|
||||
}
|
||||
|
||||
/// Extract the fields from a string containing the row of a csv file.
|
||||
fn extract_fields_from_csv_row<T>(
|
||||
row: &str,
|
||||
row_format: &CSVRowFormat<'_>,
|
||||
) -> Result<Vec<T>, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
{
|
||||
row.split(row_format.field_seperator)
|
||||
.enumerate()
|
||||
.map(|(field_number, csv_field)| {
|
||||
enrich_reading_error(
|
||||
extract_value_from_csv_field(csv_field.trim()),
|
||||
format!(" Column: {field_number}"),
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<T>, ReadingError>>()
|
||||
}
|
||||
|
||||
/// Ensure that a string containing a csv row conforms to a specified row format.
|
||||
fn validate_csv_row<'a>(row: &'a str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> {
|
||||
let actual_number_of_fields = row.split(row_format.field_seperator).count();
|
||||
if row_format.n_fields == actual_number_of_fields {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(ReadingError::InvalidRow {
|
||||
msg: format!(
|
||||
"{} fields found but expected {}",
|
||||
actual_number_of_fields, row_format.n_fields
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Add additional text to the message of an error.
|
||||
/// In csv reading it is used to add the line-number / row-number
|
||||
/// The error occured that is only known in the functions above.
|
||||
fn enrich_reading_error<T>(
|
||||
result: Result<T, ReadingError>,
|
||||
additional_text: String,
|
||||
) -> Result<T, ReadingError> {
|
||||
result.map_err(|error| ReadingError::InvalidField {
|
||||
msg: format!(
|
||||
"{}{additional_text}",
|
||||
error.message().unwrap_or("Could not serialize value")
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract the value from a single csv field.
|
||||
fn extract_value_from_csv_field<T>(value_string: &str) -> Result<T, ReadingError>
|
||||
where
|
||||
T: RealNumber,
|
||||
{
|
||||
// By default, `FromStr::Err` does not implement `Debug`.
|
||||
// Restricting it in the library leads to many breaking
|
||||
// changes therefore I have to reconstruct my own, printable
|
||||
// error as good as possible.
|
||||
match value_string.parse::<T>().ok() {
|
||||
Some(value) => Ok(value),
|
||||
None => Err(ReadingError::InvalidField {
|
||||
msg: format!("Value '{}' could not be read.", value_string,),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod matrix_from_csv_source {
|
||||
use super::super::{read_string_from_source, CSVDefinition, ReadingError};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
use crate::readers::{csv::matrix_from_csv_source, io_testing};
|
||||
|
||||
#[test]
|
||||
fn read_simple_string() {
|
||||
assert_eq!(
|
||||
read_string_from_source(io_testing::TestingDataSource::new("test-string")),
|
||||
Ok(String::from("test-string"))
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn read_simple_csv() {
|
||||
assert_eq!(
|
||||
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
io_testing::TestingDataSource::new(
|
||||
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||
5.1,3.5,1.4,0.2\n\
|
||||
4.9,3.0,1.4,0.2\n\
|
||||
4.7,3.2,1.3,0.2",
|
||||
),
|
||||
CSVDefinition::default(),
|
||||
),
|
||||
Ok(DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn read_csv_semicolon_as_seperator() {
|
||||
assert_eq!(
|
||||
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
io_testing::TestingDataSource::new(
|
||||
"'sepal.length';'sepal.width';'petal.length';'petal.width'\n\
|
||||
'Length of sepals.';'Width of Sepals';'Length of petals';'Width of petals'\n\
|
||||
5.1;3.5;1.4;0.2\n\
|
||||
4.9;3.0;1.4;0.2\n\
|
||||
4.7;3.2;1.3;0.2",
|
||||
),
|
||||
CSVDefinition {
|
||||
n_rows_header: 2,
|
||||
field_seperator: ";"
|
||||
},
|
||||
),
|
||||
Ok(DenseMatrix::from_2d_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
]))
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn error_in_colum_1_row_1() {
|
||||
assert_eq!(
|
||||
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
io_testing::TestingDataSource::new(
|
||||
"'sepal.length','sepal.width','petal.length','petal.width'\n\
|
||||
5.1,3.5,1.4,0.2\n\
|
||||
4.9,invalid,1.4,0.2\n\
|
||||
4.7,3.2,1.3,0.2",
|
||||
),
|
||||
CSVDefinition::default(),
|
||||
),
|
||||
Err(ReadingError::InvalidField {
|
||||
msg: String::from("Value 'invalid' could not be read. Column: 1, Row: 1.")
|
||||
})
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn different_number_of_columns() {
|
||||
assert_eq!(
|
||||
matrix_from_csv_source::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
io_testing::TestingDataSource::new(
|
||||
"'field_1','field_2'\n\
|
||||
5.1,3.5\n\
|
||||
4.9,3.0,1.4",
|
||||
),
|
||||
CSVDefinition::default(),
|
||||
),
|
||||
Err(ReadingError::InvalidField {
|
||||
msg: String::from("3 fields found but expected 2, Row: 1.")
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
mod extract_row_vectors_from_csv_text {
|
||||
use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat};
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
|
||||
fn read_default_csv() {
|
||||
assert_eq!(
|
||||
extract_row_vectors_from_csv_text::<f64, Vec<_>, DenseMatrix<_>>(
|
||||
"column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0",
|
||||
&CSVDefinition::default(),
|
||||
CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 3,
|
||||
},
|
||||
),
|
||||
Ok(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]])
|
||||
);
|
||||
}
|
||||
}
|
||||
mod test_validate_csv_row {
|
||||
use super::super::{validate_csv_row, CSVRowFormat, ReadingError};
|
||||
|
||||
#[test]
|
||||
fn valid_row_with_comma() {
|
||||
assert_eq!(
|
||||
validate_csv_row(
|
||||
"1.0, 2.0, 3.0",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 3,
|
||||
},
|
||||
),
|
||||
Ok(())
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn valid_row_with_semicolon() {
|
||||
assert_eq!(
|
||||
validate_csv_row(
|
||||
"1.0; 2.0; 3.0; 4.0",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ";",
|
||||
n_fields: 4,
|
||||
},
|
||||
),
|
||||
Ok(())
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn invalid_number_of_fields() {
|
||||
assert_eq!(
|
||||
validate_csv_row(
|
||||
"1.0; 2.0; 3.0; 4.0",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ";",
|
||||
n_fields: 3,
|
||||
},
|
||||
),
|
||||
Err(ReadingError::InvalidRow {
|
||||
msg: String::from("4 fields found but expected 3")
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
mod extract_fields_from_csv_row {
|
||||
use super::super::{extract_fields_from_csv_row, CSVRowFormat};
|
||||
|
||||
#[test]
|
||||
fn read_four_values_from_csv_row() {
|
||||
assert_eq!(
|
||||
extract_fields_from_csv_row(
|
||||
"1.0; 2.0; 3.0; 4.0",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ";",
|
||||
n_fields: 4
|
||||
}
|
||||
),
|
||||
Ok(vec![1.0, 2.0, 3.0, 4.0])
|
||||
)
|
||||
}
|
||||
}
|
||||
mod detect_row_format {
|
||||
use super::super::{detect_row_format, CSVDefinition, CSVRowFormat, ReadingError};
|
||||
|
||||
#[test]
|
||||
fn detect_2_fields_with_header() {
|
||||
assert_eq!(
|
||||
detect_row_format(
|
||||
"header-1\nheader-2\n1.0,2.0",
|
||||
&CSVDefinition {
|
||||
n_rows_header: 2,
|
||||
field_seperator: ","
|
||||
}
|
||||
)
|
||||
.expect("The row format should be detectable with this input."),
|
||||
CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 2
|
||||
}
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn detect_3_fields_no_header() {
|
||||
assert_eq!(
|
||||
detect_row_format(
|
||||
"1.0,2.0,3.0",
|
||||
&CSVDefinition {
|
||||
n_rows_header: 0,
|
||||
field_seperator: ","
|
||||
}
|
||||
)
|
||||
.expect("The row format should be detectable with this input."),
|
||||
CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 3
|
||||
}
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn detect_no_rows_provided() {
|
||||
assert_eq!(
|
||||
detect_row_format("header\n", &CSVDefinition::default()),
|
||||
Err(ReadingError::NoRowsProvided)
|
||||
)
|
||||
}
|
||||
}
|
||||
mod extract_value_from_csv_field {
|
||||
use super::super::extract_value_from_csv_field;
|
||||
use crate::readers::ReadingError;
|
||||
|
||||
#[test]
|
||||
fn deserialize_f64_from_floating_point() {
|
||||
assert_eq!(extract_value_from_csv_field::<f64>("1.0"), Ok(1.0))
|
||||
}
|
||||
#[test]
|
||||
fn deserialize_f64_from_negative_floating_point() {
|
||||
assert_eq!(extract_value_from_csv_field::<f64>("-1.0"), Ok(-1.0))
|
||||
}
|
||||
#[test]
|
||||
fn deserialize_f64_from_non_floating_point() {
|
||||
assert_eq!(extract_value_from_csv_field::<f64>("1"), Ok(1.0))
|
||||
}
|
||||
#[test]
|
||||
fn cant_deserialize_f64_from_string() {
|
||||
assert_eq!(
|
||||
extract_value_from_csv_field::<f64>("Test"),
|
||||
Err(ReadingError::InvalidField {
|
||||
msg: String::from("Value 'Test' could not be read.")
|
||||
},)
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn deserialize_f32_from_non_floating_point() {
|
||||
assert_eq!(extract_value_from_csv_field::<f32>("12.0"), Ok(12.0))
|
||||
}
|
||||
}
|
||||
mod extract_vector_from_csv_line {
|
||||
use super::super::{extract_vector_from_csv_line, CSVRowFormat, ReadingError};
|
||||
|
||||
#[test]
|
||||
fn extract_five_floating_point_values() {
|
||||
assert_eq!(
|
||||
extract_vector_from_csv_line::<f64, Vec<f64>>(
|
||||
"-1.0,2.0,100.0,12",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 4
|
||||
}
|
||||
),
|
||||
Ok(vec![-1.0, 2.0, 100.0, 12.0])
|
||||
)
|
||||
}
|
||||
#[test]
|
||||
fn cannot_extract_second_value() {
|
||||
assert_eq!(
|
||||
extract_vector_from_csv_line::<f64, Vec<f64>>(
|
||||
"-1.0,test,100.0,12",
|
||||
&CSVRowFormat {
|
||||
field_seperator: ",",
|
||||
n_fields: 4
|
||||
}
|
||||
),
|
||||
Err(ReadingError::InvalidField {
|
||||
msg: String::from("Value 'test' could not be read. Column: 1")
|
||||
})
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
//! The module contains the errors that can happen in the `readers` folder and
|
||||
//! utility functions.
|
||||
|
||||
/// Error wrapping all failures that can happen during loading from file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReadingError {
    /// The file could not be read from the file-system.
    CouldNotReadFileSystem {
        /// More details about the specific file-system error
        /// that occurred.
        msg: String,
    },
    /// No rows exist in the CSV-file.
    NoRowsProvided,
    /// A field in the csv file could not be read.
    InvalidField {
        /// More details about what field could not be
        /// read and where it happened.
        msg: String,
    },
    /// A row from the csv is invalid.
    InvalidRow {
        /// More details about what row could not be read
        /// and where it happened.
        msg: String,
    },
}
|
||||
impl From<std::io::Error> for ReadingError {
|
||||
fn from(io_error: std::io::Error) -> Self {
|
||||
Self::CouldNotReadFileSystem {
|
||||
msg: io_error.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ReadingError {
|
||||
/// Extract the error-message from a `ReadingError`.
|
||||
pub fn message(&self) -> Option<&str> {
|
||||
match self {
|
||||
ReadingError::InvalidField { msg } => Some(msg),
|
||||
ReadingError::InvalidRow { msg } => Some(msg),
|
||||
ReadingError::CouldNotReadFileSystem { msg } => Some(msg),
|
||||
ReadingError::NoRowsProvided => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ReadingError;
|
||||
use std::io;
|
||||
|
||||
#[test]
|
||||
fn reading_error_from_io_error() {
|
||||
let _parsed_reading_error: ReadingError = ReadingError::from(io::Error::new(
|
||||
io::ErrorKind::AlreadyExists,
|
||||
"File already exists .",
|
||||
));
|
||||
}
|
||||
#[test]
|
||||
fn extract_message_from_reading_error() {
|
||||
let error_content = "Path does not exist";
|
||||
assert_eq!(
|
||||
ReadingError::CouldNotReadFileSystem {
|
||||
msg: String::from(error_content)
|
||||
}
|
||||
.message()
|
||||
.expect("This error should contain a mesage"),
|
||||
String::from(error_content)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
//! This module contains functionality to test IO. It has both functions that write
|
||||
//! to the file-system for end-to-end tests, but also abstractions to avoid this by
|
||||
//! reading from strings instead.
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use std::fs;
|
||||
use std::io::Bytes;
|
||||
use std::io::Read;
|
||||
use std::io::{Chain, IoSliceMut, Take, Write};
|
||||
|
||||
/// Writing out a temporary csv file at a random location and cleaning
/// it up on `Drop`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TemporaryTextFile {
    // Randomly generated relative path of the backing file.
    random_path: String,
}
impl TemporaryTextFile {
    /// Write `contents` to a fresh random 16-character alphanumeric path
    /// and return a handle that removes the file again on `Drop`.
    pub fn new(contents: &str) -> std::io::Result<Self> {
        let test_text_file = TemporaryTextFile {
            // 16 random alphanumeric characters make collisions between
            // concurrently running tests unlikely.
            random_path: Alphanumeric.sample_string(&mut rand::thread_rng(), 16),
        };
        string_to_file(contents, &test_text_file.random_path)?;
        Ok(test_text_file)
    }
    /// The path the temporary file was written to.
    pub fn path(&self) -> &str {
        &self.random_path
    }
}
|
||||
/// On `Drop` we cleanup the file-system by removing the file.
impl Drop for TemporaryTextFile {
    fn drop(&mut self) {
        // A failed cleanup would leak the file into the working directory,
        // so it aborts the test run loudly instead of being ignored.
        fs::remove_file(self.path())
            .unwrap_or_else(|_| panic!("Could not clean up temporary file {}.", self.random_path));
    }
}
|
||||
/// Write out a string to file, creating (or truncating) the file at
/// `file_path`.
pub(crate) fn string_to_file(string: &str, file_path: &str) -> std::io::Result<()> {
    // `fs::write` performs exactly `File::create` + `write_all`.
    fs::write(file_path, string.as_bytes())
}
|
||||
|
||||
/// This is used as an alternative struct that implements `Read` so
/// that instead of reading from the file-system, we can test the same
/// functionality without any file-system interaction.
pub(crate) struct TestingDataSource {
    // The text the data source will "read".
    text: String,
}
impl TestingDataSource {
    /// Wrap `text` so it can be consumed through the `Read` trait.
    pub(crate) fn new(text: &str) -> Self {
        Self {
            text: String::from(text),
        }
    }
}
|
||||
/// This is the trait that also `file::File` implements, so by implementing
|
||||
/// it for `TestingDataSource` we can test functionality that reads from
|
||||
/// file in a more lightweight way.
|
||||
impl Read for TestingDataSource {
|
||||
fn read(&mut self, _buf: &mut [u8]) -> Result<usize, std::io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn read_vectored(&mut self, _bufs: &mut [IoSliceMut<'_>]) -> Result<usize, std::io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn read_to_end(&mut self, _buf: &mut Vec<u8>) -> Result<usize, std::io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn read_to_string(&mut self, buf: &mut String) -> Result<usize, std::io::Error> {
|
||||
<String as std::fmt::Write>::write_str(buf, &self.text).unwrap();
|
||||
Ok(0)
|
||||
}
|
||||
fn read_exact(&mut self, _buf: &mut [u8]) -> Result<(), std::io::Error> {
|
||||
unimplemented!()
|
||||
}
|
||||
fn by_ref(&mut self) -> &mut Self
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
fn bytes(self) -> Bytes<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
fn chain<R: Read>(self, _next: R) -> Chain<Self, R>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
fn take(self, _limit: u64) -> Take<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::TestingDataSource;
|
||||
use super::{string_to_file, TemporaryTextFile};
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
use std::path;
|
||||
#[test]
|
||||
fn test_temporary_text_file() {
|
||||
let path_of_temporary_file;
|
||||
{
|
||||
let hello_world_file = TemporaryTextFile::new("Hello World!")
|
||||
.expect("`hello_world_file` should be able to write file.");
|
||||
|
||||
path_of_temporary_file = String::from(hello_world_file.path());
|
||||
assert_eq!(
|
||||
fs::read_to_string(&path_of_temporary_file).expect(
|
||||
"This field should have been written by the `hello_world_file`-object."
|
||||
),
|
||||
"Hello World!"
|
||||
)
|
||||
}
|
||||
// By now `hello_world_file` should have been dropped and the file
|
||||
// should have been cleaned up.
|
||||
assert!(!path::Path::new(&path_of_temporary_file).exists())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_to_file() {
|
||||
let path_of_test_file = "test.file";
|
||||
let contents_of_test_file = "Hello IO-World";
|
||||
|
||||
string_to_file(contents_of_test_file, path_of_test_file)
|
||||
.expect("The file should have been written out.");
|
||||
assert_eq!(
|
||||
fs::read_to_string(path_of_test_file)
|
||||
.expect("The file we test for should have been written."),
|
||||
String::from(contents_of_test_file)
|
||||
);
|
||||
|
||||
// Cleanup the temporary file.
|
||||
fs::remove_file(path_of_test_file)
|
||||
.expect("The test file should exist before and be removed here.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_from_testing_data_source() {
|
||||
let mut test_buffer = String::new();
|
||||
let test_data_content = "Hello non-IO world!";
|
||||
|
||||
TestingDataSource::new(test_data_content)
|
||||
.read_to_string(&mut test_buffer)
|
||||
.expect("Text should have been written to buffer `test_buffer`.");
|
||||
assert_eq!(test_buffer, test_data_content)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
/// Read in from csv.
|
||||
pub mod csv;
|
||||
|
||||
/// Error definition for readers.
|
||||
mod error;
|
||||
/// Utilities to help with testing functionality using IO.
|
||||
/// Only meant for internal usage.
|
||||
#[cfg(test)]
|
||||
pub(crate) mod io_testing;
|
||||
|
||||
pub use error::ReadingError;
|
||||
+5
-5
@@ -33,7 +33,7 @@ use crate::linalg::BaseVector;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
/// Defines a kernel function
|
||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>> {
|
||||
pub trait Kernel<T: RealNumber, V: BaseVector<T>>: Clone {
|
||||
/// Apply kernel function to x_i and x_j
|
||||
fn apply(&self, x_i: &V, x_j: &V) -> T;
|
||||
}
|
||||
@@ -95,12 +95,12 @@ impl Kernels {
|
||||
|
||||
/// Linear Kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct LinearKernel {}
|
||||
|
||||
/// Radial basis function (Gaussian) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RBFKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
@@ -108,7 +108,7 @@ pub struct RBFKernel<T: RealNumber> {
|
||||
|
||||
/// Polynomial kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PolynomialKernel<T: RealNumber> {
|
||||
/// degree of the polynomial
|
||||
pub degree: T,
|
||||
@@ -120,7 +120,7 @@ pub struct PolynomialKernel<T: RealNumber> {
|
||||
|
||||
/// Sigmoid (hyperbolic tangent) kernel
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct SigmoidKernel<T: RealNumber> {
|
||||
/// kernel coefficient
|
||||
pub gamma: T,
|
||||
|
||||
+177
-7
@@ -84,22 +84,154 @@ use crate::error::Failed;
|
||||
use crate::linalg::BaseVector;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
use crate::svm::{Kernel, Kernels, LinearKernel};
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// SVC Parameters
|
||||
pub struct SVCParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of epochs.
|
||||
pub epoch: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub c: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tolerance for stopping criterion.
|
||||
pub tol: T,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The kernel function.
|
||||
pub kernel: K,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||
seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// SVC grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVCSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Number of epochs.
|
||||
pub epoch: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tolerance for stopping epoch.
|
||||
pub tol: Vec<T>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<K>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the pseudo random number generation for shuffling the data for probability estimates
|
||||
seed: Vec<Option<u64>>,
|
||||
}
|
||||
|
||||
/// SVC grid search iterator
|
||||
pub struct SVCSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
svc_search_parameters: SVCSearchParameters<T, M, K>,
|
||||
current_epoch: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
    for SVCSearchParameters<T, M, K>
{
    type Item = SVCParameters<T, M, K>;
    type IntoIter = SVCSearchParametersIterator<T, M, K>;

    /// Start the grid search at the first parameter combination:
    /// every cursor positioned at index 0.
    fn into_iter(self) -> Self::IntoIter {
        SVCSearchParametersIterator {
            svc_search_parameters: self,
            current_epoch: 0,
            current_c: 0,
            current_tol: 0,
            current_kernel: 0,
            current_seed: 0,
        }
    }
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
for SVCSearchParametersIterator<T, M, K>
|
||||
{
|
||||
type Item = SVCParameters<T, M, K>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_epoch == self.svc_search_parameters.epoch.len()
|
||||
&& self.current_c == self.svc_search_parameters.c.len()
|
||||
&& self.current_tol == self.svc_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svc_search_parameters.kernel.len()
|
||||
&& self.current_seed == self.svc_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVCParameters::<T, M, K> {
|
||||
epoch: self.svc_search_parameters.epoch[self.current_epoch],
|
||||
c: self.svc_search_parameters.c[self.current_c],
|
||||
tol: self.svc_search_parameters.tol[self.current_tol],
|
||||
kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
|
||||
m: PhantomData,
|
||||
seed: self.svc_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
|
||||
self.current_epoch += 1;
|
||||
} else if self.current_c + 1 < self.svc_search_parameters.c.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svc_search_parameters.tol.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() {
|
||||
self.current_epoch = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_epoch += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCSearchParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
let default_params: SVCParameters<T, M, LinearKernel> = SVCParameters::default();
|
||||
|
||||
SVCSearchParameters {
|
||||
epoch: vec![default_params.epoch],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel],
|
||||
m: PhantomData,
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
@@ -176,8 +308,15 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVCParameters<T, M
|
||||
tol: self.tol,
|
||||
kernel,
|
||||
m: PhantomData,
|
||||
seed: self.seed,
|
||||
}
|
||||
}
|
||||
|
||||
/// Seed for the pseudo random number generator.
|
||||
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
|
||||
self.seed = seed;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel> {
|
||||
@@ -188,6 +327,7 @@ impl<T: RealNumber, M: Matrix<T>> Default for SVCParameters<T, M, LinearKernel>
|
||||
tol: T::from_f64(1e-3).unwrap(),
|
||||
kernel: Kernels::linear(),
|
||||
m: PhantomData,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -408,7 +548,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let good_enough = T::from_i32(1000).unwrap();
|
||||
|
||||
for _ in 0..self.parameters.epoch {
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
self.process(i, self.x.get_row(i), self.y.get(i), &mut cache);
|
||||
loop {
|
||||
self.reprocess(tol, &mut cache);
|
||||
@@ -441,7 +581,7 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
let mut cp = 0;
|
||||
let mut cn = 0;
|
||||
|
||||
for i in Self::permutate(n) {
|
||||
for i in self.permutate(n) {
|
||||
if self.y.get(i) == T::one() && cp < few {
|
||||
if self.process(i, self.x.get_row(i), self.y.get(i), cache) {
|
||||
cp += 1;
|
||||
@@ -566,8 +706,8 @@ impl<'a, T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Optimizer<'a,
|
||||
self.recalculate_minmax_grad = true;
|
||||
}
|
||||
|
||||
fn permutate(n: usize) -> Vec<usize> {
|
||||
let mut rng = rand::thread_rng();
|
||||
fn permutate(&self, n: usize) -> Vec<usize> {
|
||||
let mut rng = get_rng_impl(self.parameters.seed);
|
||||
let mut range: Vec<usize> = (0..n).collect();
|
||||
range.shuffle(&mut rng);
|
||||
range
|
||||
@@ -737,6 +877,24 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
fn search_parameters() {
    // Two epochs x one kernel -> exactly two parameter sets, epoch fastest.
    let parameters: SVCSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
        SVCSearchParameters {
            epoch: vec![10, 100],
            kernel: vec![LinearKernel {}],
            ..Default::default()
        };
    let produced: Vec<_> = parameters.into_iter().collect();
    assert_eq!(produced.len(), 2);
    assert_eq!(produced[0].epoch, 10);
    assert_eq!(produced[0].kernel, LinearKernel {});
    assert_eq!(produced[1].epoch, 100);
    assert_eq!(produced[1].kernel, LinearKernel {});
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svc_fit_predict() {
|
||||
@@ -772,12 +930,18 @@ mod tests {
|
||||
&y,
|
||||
SVCParameters::default()
|
||||
.with_c(200.0)
|
||||
.with_kernel(Kernels::linear()),
|
||||
.with_kernel(Kernels::linear())
|
||||
.with_seed(Some(100)),
|
||||
)
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
let acc = accuracy(&y_hat, &y);
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
assert!(
|
||||
acc >= 0.9,
|
||||
"accuracy ({}) is not larger or equal to 0.9",
|
||||
acc
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
@@ -860,7 +1024,13 @@ mod tests {
|
||||
.and_then(|lr| lr.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
assert!(accuracy(&y_hat, &y) >= 0.9);
|
||||
let acc = accuracy(&y_hat, &y);
|
||||
|
||||
assert!(
|
||||
acc >= 0.9,
|
||||
"accuracy ({}) is not larger or equal to 0.9",
|
||||
acc
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
|
||||
+122
-1
@@ -94,6 +94,109 @@ pub struct SVRParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVR grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SVRSearchParameters<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
/// Epsilon in the epsilon-SVR model.
|
||||
pub eps: Vec<T>,
|
||||
/// Regularization parameter.
|
||||
pub c: Vec<T>,
|
||||
/// Tolerance for stopping eps.
|
||||
pub tol: Vec<T>,
|
||||
/// The kernel function.
|
||||
pub kernel: Vec<K>,
|
||||
/// Unused parameter.
|
||||
m: PhantomData<M>,
|
||||
}
|
||||
|
||||
/// SVR grid search iterator
|
||||
pub struct SVRSearchParametersIterator<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> {
|
||||
svr_search_parameters: SVRSearchParameters<T, M, K>,
|
||||
current_eps: usize,
|
||||
current_c: usize,
|
||||
current_tol: usize,
|
||||
current_kernel: usize,
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> IntoIterator
|
||||
for SVRSearchParameters<T, M, K>
|
||||
{
|
||||
type Item = SVRParameters<T, M, K>;
|
||||
type IntoIter = SVRSearchParametersIterator<T, M, K>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
SVRSearchParametersIterator {
|
||||
svr_search_parameters: self,
|
||||
current_eps: 0,
|
||||
current_c: 0,
|
||||
current_tol: 0,
|
||||
current_kernel: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> Iterator
|
||||
for SVRSearchParametersIterator<T, M, K>
|
||||
{
|
||||
type Item = SVRParameters<T, M, K>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_eps == self.svr_search_parameters.eps.len()
|
||||
&& self.current_c == self.svr_search_parameters.c.len()
|
||||
&& self.current_tol == self.svr_search_parameters.tol.len()
|
||||
&& self.current_kernel == self.svr_search_parameters.kernel.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = SVRParameters::<T, M, K> {
|
||||
eps: self.svr_search_parameters.eps[self.current_eps],
|
||||
c: self.svr_search_parameters.c[self.current_c],
|
||||
tol: self.svr_search_parameters.tol[self.current_tol],
|
||||
kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(),
|
||||
m: PhantomData,
|
||||
};
|
||||
|
||||
if self.current_eps + 1 < self.svr_search_parameters.eps.len() {
|
||||
self.current_eps += 1;
|
||||
} else if self.current_c + 1 < self.svr_search_parameters.c.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c += 1;
|
||||
} else if self.current_tol + 1 < self.svr_search_parameters.tol.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol += 1;
|
||||
} else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() {
|
||||
self.current_eps = 0;
|
||||
self.current_c = 0;
|
||||
self.current_tol = 0;
|
||||
self.current_kernel += 1;
|
||||
} else {
|
||||
self.current_eps += 1;
|
||||
self.current_c += 1;
|
||||
self.current_tol += 1;
|
||||
self.current_kernel += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: RealNumber, M: Matrix<T>> Default for SVRSearchParameters<T, M, LinearKernel> {
|
||||
fn default() -> Self {
|
||||
let default_params: SVRParameters<T, M, LinearKernel> = SVRParameters::default();
|
||||
|
||||
SVRSearchParameters {
|
||||
eps: vec![default_params.eps],
|
||||
c: vec![default_params.c],
|
||||
tol: vec![default_params.tol],
|
||||
kernel: vec![default_params.kernel],
|
||||
m: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(
|
||||
@@ -242,7 +345,7 @@ impl<T: RealNumber, M: Matrix<T>, K: Kernel<T, M::RowVector>> SVR<T, M, K> {
|
||||
Ok(y_hat)
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||
pub(crate) fn predict_for_row(&self, x: M::RowVector) -> T {
|
||||
let mut f = self.b;
|
||||
|
||||
for i in 0..self.instances.len() {
|
||||
@@ -536,6 +639,24 @@ mod tests {
|
||||
#[cfg(feature = "serde")]
|
||||
use crate::svm::*;
|
||||
|
||||
#[test]
fn search_parameters() {
    // Two eps values x one kernel -> exactly two parameter sets, eps fastest.
    let parameters: SVRSearchParameters<f64, DenseMatrix<f64>, LinearKernel> =
        SVRSearchParameters {
            eps: vec![0., 1.],
            kernel: vec![LinearKernel {}],
            ..Default::default()
        };
    let produced: Vec<_> = parameters.into_iter().collect();
    assert_eq!(produced.len(), 2);
    assert_eq!(produced[0].eps, 0.);
    assert_eq!(produced[0].kernel, LinearKernel {});
    assert_eq!(produced[1].eps, 1.);
    assert_eq!(produced[1].kernel, LinearKernel {});
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn svr_fit_predict() {
|
||||
|
||||
@@ -77,19 +77,27 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Parameters of Decision Tree
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree.
|
||||
pub criterion: SplitCriterion,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The maximum depth of the tree.
|
||||
pub max_depth: Option<u16>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node.
|
||||
pub min_samples_leaf: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node.
|
||||
pub min_samples_split: usize,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the randomness of the estimator
|
||||
pub seed: Option<u64>,
|
||||
}
|
||||
|
||||
/// Decision Tree
|
||||
@@ -115,6 +123,12 @@ pub enum SplitCriterion {
|
||||
ClassificationError,
|
||||
}
|
||||
|
||||
impl Default for SplitCriterion {
|
||||
fn default() -> Self {
|
||||
SplitCriterion::Gini
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug)]
|
||||
struct Node<T: RealNumber> {
|
||||
@@ -193,10 +207,169 @@ impl DecisionTreeClassifierParameters {
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
DecisionTreeClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
criterion: SplitCriterion::default(),
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search parameters
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DecisionTreeClassifierSearchParameters {
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub criterion: Vec<SplitCriterion>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub max_depth: Vec<Option<u16>>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_leaf: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
|
||||
pub min_samples_split: Vec<usize>,
|
||||
#[cfg_attr(feature = "serde", serde(default))]
|
||||
/// Controls the randomness of the estimator
|
||||
pub seed: Vec<Option<u64>>,
|
||||
}
|
||||
|
||||
/// DecisionTreeClassifier grid search iterator
|
||||
pub struct DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters,
|
||||
current_criterion: usize,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for DecisionTreeClassifierSearchParameters {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
type IntoIter = DecisionTreeClassifierSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DecisionTreeClassifierSearchParametersIterator {
|
||||
decision_tree_classifier_search_parameters: self,
|
||||
current_criterion: 0,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for DecisionTreeClassifierSearchParametersIterator {
|
||||
type Item = DecisionTreeClassifierParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_criterion
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
&& self.current_max_depth
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_seed == self.decision_tree_classifier_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DecisionTreeClassifierParameters {
|
||||
criterion: self.decision_tree_classifier_search_parameters.criterion
|
||||
[self.current_criterion]
|
||||
.clone(),
|
||||
max_depth: self.decision_tree_classifier_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_criterion + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.criterion
|
||||
.len()
|
||||
{
|
||||
self.current_criterion += 1;
|
||||
} else if self.current_max_depth + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.decision_tree_classifier_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len()
|
||||
{
|
||||
self.current_criterion = 0;
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_criterion += 1;
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = DecisionTreeClassifierParameters::default();
|
||||
|
||||
DecisionTreeClassifierSearchParameters {
|
||||
criterion: vec![default_params.criterion],
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,7 +458,7 @@ impl<'a, T: RealNumber, M: Matrix<T>> NodeVisitor<'a, T, M> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(in crate) fn which_max(x: &[usize]) -> usize {
|
||||
pub(crate) fn which_max(x: &[usize]) -> usize {
|
||||
let mut m = x[0];
|
||||
let mut which = 0;
|
||||
|
||||
@@ -329,14 +502,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -345,7 +511,6 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
@@ -359,6 +524,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
)));
|
||||
}
|
||||
|
||||
let mut rng = get_rng_impl(parameters.seed);
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
|
||||
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
|
||||
@@ -393,13 +559,13 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -421,7 +587,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
let mut result = 0;
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -651,6 +817,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
fn search_parameters() {
    // 2 depths x 2 split sizes -> 4 combinations; max_depth varies fastest
    // because the (single-element) criterion list never advances.
    let parameters = DecisionTreeClassifierSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };
    let expected = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];
    let produced: Vec<_> = parameters.into_iter().collect();
    assert_eq!(produced.len(), expected.len());
    for (p, &(depth, split)) in produced.iter().zip(expected.iter()) {
        assert_eq!(p.max_depth, depth);
        assert_eq!(p.min_samples_split, split);
    }
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn gini_impurity() {
|
||||
@@ -713,7 +902,8 @@ mod tests {
|
||||
criterion: SplitCriterion::Entropy,
|
||||
max_depth: Some(3),
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2
|
||||
min_samples_split: 2,
|
||||
seed: None
|
||||
}
|
||||
)
|
||||
.unwrap()
|
||||
|
||||
@@ -72,17 +72,24 @@ use crate::api::{Predictor, SupervisedEstimator};
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
use crate::rand::get_rng_impl;
|
||||
|
||||
/// Parameters of Regression Tree
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressorParameters {
    /// The maximum depth of the tree.
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_depth: Option<u16>,
    /// The minimum number of samples required to be at a leaf node.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_leaf: usize,
    /// The minimum number of samples required to split an internal node.
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_split: usize,
    /// Controls the randomness of the estimator
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: Option<u64>,
}
|
||||
|
||||
/// Regression Tree
|
||||
@@ -130,6 +137,139 @@ impl Default for DecisionTreeRegressorParameters {
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 2,
|
||||
seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DecisionTreeRegressor grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct DecisionTreeRegressorSearchParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    #[cfg_attr(feature = "serde", serde(default))]
    pub max_depth: Vec<Option<u16>>,
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_leaf: Vec<usize>,
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    #[cfg_attr(feature = "serde", serde(default))]
    pub min_samples_split: Vec<usize>,
    /// Controls the randomness of the estimator
    #[cfg_attr(feature = "serde", serde(default))]
    pub seed: Vec<Option<u64>>,
}
|
||||
|
||||
/// DecisionTreeRegressor grid search iterator
|
||||
pub struct DecisionTreeRegressorSearchParametersIterator {
|
||||
decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters,
|
||||
current_max_depth: usize,
|
||||
current_min_samples_leaf: usize,
|
||||
current_min_samples_split: usize,
|
||||
current_seed: usize,
|
||||
}
|
||||
|
||||
impl IntoIterator for DecisionTreeRegressorSearchParameters {
|
||||
type Item = DecisionTreeRegressorParameters;
|
||||
type IntoIter = DecisionTreeRegressorSearchParametersIterator;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
DecisionTreeRegressorSearchParametersIterator {
|
||||
decision_tree_regressor_search_parameters: self,
|
||||
current_max_depth: 0,
|
||||
current_min_samples_leaf: 0,
|
||||
current_min_samples_split: 0,
|
||||
current_seed: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for DecisionTreeRegressorSearchParametersIterator {
|
||||
type Item = DecisionTreeRegressorParameters;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.current_max_depth
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
&& self.current_min_samples_leaf
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
&& self.current_min_samples_split
|
||||
== self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
&& self.current_seed == self.decision_tree_regressor_search_parameters.seed.len()
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
let next = DecisionTreeRegressorParameters {
|
||||
max_depth: self.decision_tree_regressor_search_parameters.max_depth
|
||||
[self.current_max_depth],
|
||||
min_samples_leaf: self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf[self.current_min_samples_leaf],
|
||||
min_samples_split: self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split[self.current_min_samples_split],
|
||||
seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed],
|
||||
};
|
||||
|
||||
if self.current_max_depth + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.max_depth
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth += 1;
|
||||
} else if self.current_min_samples_leaf + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_leaf
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf += 1;
|
||||
} else if self.current_min_samples_split + 1
|
||||
< self
|
||||
.decision_tree_regressor_search_parameters
|
||||
.min_samples_split
|
||||
.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split += 1;
|
||||
} else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len()
|
||||
{
|
||||
self.current_max_depth = 0;
|
||||
self.current_min_samples_leaf = 0;
|
||||
self.current_min_samples_split = 0;
|
||||
self.current_seed += 1;
|
||||
} else {
|
||||
self.current_max_depth += 1;
|
||||
self.current_min_samples_leaf += 1;
|
||||
self.current_min_samples_split += 1;
|
||||
self.current_seed += 1;
|
||||
}
|
||||
|
||||
Some(next)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeRegressorSearchParameters {
|
||||
fn default() -> Self {
|
||||
let default_params = DecisionTreeRegressorParameters::default();
|
||||
|
||||
DecisionTreeRegressorSearchParameters {
|
||||
max_depth: vec![default_params.max_depth],
|
||||
min_samples_leaf: vec![default_params.min_samples_leaf],
|
||||
min_samples_split: vec![default_params.min_samples_split],
|
||||
seed: vec![default_params.seed],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -243,14 +383,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(
|
||||
x,
|
||||
y,
|
||||
samples,
|
||||
num_attributes,
|
||||
parameters,
|
||||
&mut rand::thread_rng(),
|
||||
)
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub(crate) fn fit_weak_learner<M: Matrix<T>>(
|
||||
@@ -259,7 +392,6 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
@@ -267,6 +399,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
let (_, num_attributes) = x.shape();
|
||||
|
||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||
let mut rng = get_rng_impl(parameters.seed);
|
||||
|
||||
let mut n = 0;
|
||||
let mut sum = T::zero();
|
||||
@@ -293,13 +426,13 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
|
||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, T, M>> = LinkedList::new();
|
||||
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, rng) {
|
||||
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||
visitor_queue.push_back(visitor);
|
||||
}
|
||||
|
||||
while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
|
||||
match visitor_queue.pop_front() {
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, rng),
|
||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
@@ -321,7 +454,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
pub(crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
let mut result = T::zero();
|
||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||
|
||||
@@ -517,6 +650,29 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::linalg::naive::dense_matrix::DenseMatrix;
|
||||
|
||||
#[test]
fn search_parameters() {
    // 2 depths x 2 split sizes -> 4 combinations; max_depth varies fastest.
    let parameters = DecisionTreeRegressorSearchParameters {
        max_depth: vec![Some(10), Some(100)],
        min_samples_split: vec![1, 2],
        ..Default::default()
    };
    let expected = [
        (Some(10), 1),
        (Some(100), 1),
        (Some(10), 2),
        (Some(100), 2),
    ];
    let produced: Vec<_> = parameters.into_iter().collect();
    assert_eq!(produced.len(), expected.len());
    for (p, &(depth, split)) in produced.iter().zip(expected.iter()) {
        assert_eq!(p.max_depth, depth);
        assert_eq!(p.min_samples_split, split);
    }
}
|
||||
|
||||
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
|
||||
#[test]
|
||||
fn fit_longley() {
|
||||
@@ -562,6 +718,7 @@ mod tests {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 2,
|
||||
min_samples_split: 6,
|
||||
seed: None,
|
||||
},
|
||||
)
|
||||
.and_then(|t| t.predict(&x))
|
||||
@@ -582,6 +739,7 @@ mod tests {
|
||||
max_depth: Option::None,
|
||||
min_samples_leaf: 1,
|
||||
min_samples_split: 3,
|
||||
seed: None,
|
||||
},
|
||||
)
|
||||
.and_then(|t| t.predict(&x))
|
||||
|
||||
Reference in New Issue
Block a user