From b482acdc8de07d10a292be2c31f523c782ffb4ee Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Wed, 13 Jul 2022 21:06:05 -0400 Subject: [PATCH 01/19] Fix clippy warnings (#139) Co-authored-by: Luis Moreno --- src/algorithm/sort/heap_select.rs | 2 +- src/linear/lasso_optimizer.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/algorithm/sort/heap_select.rs b/src/algorithm/sort/heap_select.rs index beb698f..bc880bc 100644 --- a/src/algorithm/sort/heap_select.rs +++ b/src/algorithm/sort/heap_select.rs @@ -12,7 +12,7 @@ pub struct HeapSelection { heap: Vec, } -impl<'a, T: PartialOrd + Debug> HeapSelection { +impl HeapSelection { pub fn with_capacity(k: usize) -> HeapSelection { HeapSelection { k, diff --git a/src/linear/lasso_optimizer.rs b/src/linear/lasso_optimizer.rs index c4340fc..aa09128 100644 --- a/src/linear/lasso_optimizer.rs +++ b/src/linear/lasso_optimizer.rs @@ -211,9 +211,7 @@ impl> InteriorPointOptimizer { } } -impl<'a, T: RealNumber, M: Matrix> BiconjugateGradientSolver - for InteriorPointOptimizer -{ +impl> BiconjugateGradientSolver for InteriorPointOptimizer { fn solve_preconditioner(&self, a: &M, b: &M, x: &mut M) { let (_, p) = a.shape(); From d905ebea1500155c76a26f82095c939a5c5e57c9 Mon Sep 17 00:00:00 2001 From: Chris McComb Date: Fri, 12 Aug 2022 17:38:13 -0400 Subject: [PATCH 02/19] Added additional doctest and fixed indices (#141) --- src/algorithm/neighbour/bbd_tree.rs | 2 +- src/linalg/evd.rs | 19 ++++++++++++++++--- src/optimization/mod.rs | 2 +- src/svm/svr.rs | 2 +- src/tree/decision_tree_classifier.rs | 4 ++-- src/tree/decision_tree_regressor.rs | 2 +- 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/algorithm/neighbour/bbd_tree.rs b/src/algorithm/neighbour/bbd_tree.rs index 293a822..93ea050 100644 --- a/src/algorithm/neighbour/bbd_tree.rs +++ b/src/algorithm/neighbour/bbd_tree.rs @@ -59,7 +59,7 @@ impl BBDTree { tree } - pub(in crate) fn clustering( + pub(crate) fn clustering( &self, centroids: &[Vec], sums: &mut Vec>, diff --git a/src/linalg/evd.rs b/src/linalg/evd.rs index bf195a0..fdca1fb 100644 --- a/src/linalg/evd.rs +++ b/src/linalg/evd.rs @@ -25,6 +25,19 @@ //! let eigenvectors: DenseMatrix = evd.V; //! let eigenvalues: Vec = evd.d; //! ``` +//! ``` +//! use smartcore::linalg::naive::dense_matrix::*; +//! use smartcore::linalg::evd::*; +//! +//! let A = DenseMatrix::from_2d_array(&[ +//! &[-5.0, 2.0], +//! &[-7.0, 4.0], +//! ]); +//! +//! let evd = A.evd(false).unwrap(); +//! let eigenvectors: DenseMatrix = evd.V; +//! let eigenvalues: Vec = evd.d; +//! ``` //! //! ## References: //! 
* ["Numerical Recipes: The Art of Scientific Computing", Press W.H., Teukolsky S.A., Vetterling W.T, Flannery B.P, 3rd ed., Section 11 Eigensystems](http://numerical.recipes/) @@ -799,10 +812,10 @@ fn sort>(d: &mut [T], e: &mut [T], V: &mut M) { } i -= 1; } - d[i as usize + 1] = real; - e[i as usize + 1] = img; + d[(i + 1) as usize] = real; + e[(i + 1) as usize] = img; for (k, temp_k) in temp.iter().enumerate().take(n) { - V.set(k, i as usize + 1, *temp_k); + V.set(k, (i + 1) as usize, *temp_k); } } } diff --git a/src/optimization/mod.rs b/src/optimization/mod.rs index b0be9d6..127b534 100644 --- a/src/optimization/mod.rs +++ b/src/optimization/mod.rs @@ -5,7 +5,7 @@ pub type F<'a, T, X> = dyn for<'b> Fn(&'b X) -> T + 'a; pub type DF<'a, X> = dyn for<'b> Fn(&'b mut X, &'b X) + 'a; #[allow(clippy::upper_case_acronyms)] -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] pub enum FunctionOrder { SECOND, THIRD, diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 3257111..18c73d1 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -242,7 +242,7 @@ impl, K: Kernel> SVR { Ok(y_hat) } - pub(in crate) fn predict_for_row(&self, x: M::RowVector) -> T { + pub(crate) fn predict_for_row(&self, x: M::RowVector) -> T { let mut f = self.b; for i in 0..self.instances.len() { diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index d86f59a..35889e4 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -285,7 +285,7 @@ impl<'a, T: RealNumber, M: Matrix> NodeVisitor<'a, T, M> { } } -pub(in crate) fn which_max(x: &[usize]) -> usize { +pub(crate) fn which_max(x: &[usize]) -> usize { let mut m = x[0]; let mut which = 0; @@ -421,7 +421,7 @@ impl DecisionTreeClassifier { Ok(result.to_row_vector()) } - pub(in crate) fn predict_for_row>(&self, x: &M, row: usize) -> usize { + pub(crate) fn predict_for_row>(&self, x: &M, row: usize) -> usize { let mut result = 0; let mut queue: LinkedList = LinkedList::new(); diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 94fa0f8..25f5e7e 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -321,7 +321,7 @@ impl DecisionTreeRegressor { Ok(result.to_row_vector()) } - pub(in crate) fn predict_for_row>(&self, x: &M, row: usize) -> T { + pub(crate) fn predict_for_row>(&self, x: &M, row: usize) -> T { let mut result = T::zero(); let mut queue: LinkedList = LinkedList::new(); From a1c56a859e73525fb55536ed73a04eaee8c8aafd Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Tue, 23 Aug 2022 16:56:21 +0100 Subject: [PATCH 03/19] Implement fastpair (#142) * initial fastpair implementation * FastPair initial implementation * implement fastpair * Add random test * Add bench for fastpair * Refactor with constructor for FastPair * Add serialization for PairwiseDistance * Add fp_bench feature for fastpair bench --- Cargo.toml | 7 + benches/fastpair.rs | 56 +++ src/algorithm/neighbour/distances.rs | 48 +++ src/algorithm/neighbour/fastpair.rs | 554 +++++++++++++++++++++++++++ src/algorithm/neighbour/mod.rs | 4 + 5 files changed, 669 insertions(+) create mode 100644 benches/fastpair.rs create mode 100644 src/algorithm/neighbour/distances.rs create mode 100644 src/algorithm/neighbour/fastpair.rs diff --git a/Cargo.toml b/Cargo.toml index 2978238..e83a0cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] datasets = [] +fp_bench = [] 
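+# Note: `fp_bench` gates the `fastpair` benchmark target declared further below
+# (and the brute-force helper it relies on); a typical invocation would be
+# `cargo bench --bench fastpair --features fp_bench`.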
[dependencies]
ndarray = { version = "0.15", optional = true }
@@ -26,6 +27,7 @@ num = "0.4"
rand = "0.8"
rand_distr = "0.4"
serde = { version = "1", features = ["derive"], optional = true }
+itertools = "0.10.3"

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
@@ -46,3 +48,8 @@ harness = false
name = "naive_bayes"
harness = false
required-features = ["ndarray-bindings", "nalgebra-bindings"]
+
+[[bench]]
+name = "fastpair"
+harness = false
+required-features = ["fp_bench"]
\ No newline at end of file
diff --git a/benches/fastpair.rs b/benches/fastpair.rs
new file mode 100644
index 0000000..baa0e90
--- /dev/null
+++ b/benches/fastpair.rs
@@ -0,0 +1,56 @@
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+
+// to run this bench you have to change the declaration in mod.rs ---> pub mod fastpair;
+use smartcore::algorithm::neighbour::fastpair::FastPair;
+use smartcore::linalg::naive::dense_matrix::*;
+use std::time::Duration;
+
+fn closest_pair_bench(n: usize, m: usize) {
+    let x = DenseMatrix::<f64>::rand(n, m);
+    let fastpair = FastPair::new(&x);
+    let result = fastpair.unwrap();
+
+    result.closest_pair();
+}
+
+fn closest_pair_brute_bench(n: usize, m: usize) {
+    let x = DenseMatrix::<f64>::rand(n, m);
+    let fastpair = FastPair::new(&x);
+    let result = fastpair.unwrap();
+
+    result.closest_pair_brute();
+}
+
+fn bench_fastpair(c: &mut Criterion) {
+    let mut group = c.benchmark_group("FastPair");
+
+    // with the full sample size (100) the bench would take too long
+    group.significance_level(0.1).sample_size(30);
+    // increase from the default 5.0 secs
+    group.measurement_time(Duration::from_secs(60));
+
+    for n_samples in [100_usize, 1000_usize].iter() {
+        for n_features in [10_usize, 100_usize, 1000_usize].iter() {
+            group.bench_with_input(
+                BenchmarkId::from_parameter(format!(
+                    "fastpair --- n_samples: {}, n_features: {}",
+                    n_samples, n_features
+                )),
+                n_samples,
+                |b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
+            );
+            group.bench_with_input(
+                BenchmarkId::from_parameter(format!(
+                    "brute --- n_samples: {}, n_features: {}",
+                    n_samples, n_features
+                )),
+                n_samples,
+                |b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
+            );
+        }
+    }
+    group.finish();
+}
+
+criterion_group!(benches, bench_fastpair);
+criterion_main!(benches);
diff --git a/src/algorithm/neighbour/distances.rs b/src/algorithm/neighbour/distances.rs
new file mode 100644
index 0000000..56a7ed6
--- /dev/null
+++ b/src/algorithm/neighbour/distances.rs
@@ -0,0 +1,48 @@
+//!
+//! Dissimilarities for vector-vector distance
+//!
+//! Distances are represented as pairwise dissimilarities, so as to build a
+//! graph of closest neighbours. This representation can be reused for
+//! different implementations (initially used in this library for FastPair).
+use std::cmp::{Eq, Ordering, PartialOrd};
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::math::num::RealNumber;
+
+///
+/// The edge of the subgraph is defined by `PairwiseDistance`.
+/// The calling algorithm can store a list of distances as
+/// a list of these structures.
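+///
+/// For instance, an edge connecting node `0` to node `4` at squared distance
+/// `0.02` would be represented as (a sketch; fields as defined below):
+/// `PairwiseDistance { node: 0, neighbour: Some(4), distance: Some(0.02) }`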
+/// +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy)] +pub struct PairwiseDistance { + /// index of the vector in the original `Matrix` or list + pub node: usize, + + /// index of the closest neighbor in the original `Matrix` or same list + pub neighbour: Option, + + /// measure of distance, according to the algorithm distance function + /// if the distance is None, the edge has value "infinite" or max distance + /// each algorithm has to match + pub distance: Option, +} + +impl Eq for PairwiseDistance {} + +impl PartialEq for PairwiseDistance { + fn eq(&self, other: &Self) -> bool { + self.node == other.node + && self.neighbour == other.neighbour + && self.distance == other.distance + } +} + +impl PartialOrd for PairwiseDistance { + fn partial_cmp(&self, other: &Self) -> Option { + self.distance.partial_cmp(&other.distance) + } +} diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs new file mode 100644 index 0000000..dfc6f58 --- /dev/null +++ b/src/algorithm/neighbour/fastpair.rs @@ -0,0 +1,554 @@ +#![allow(non_snake_case)] +use itertools::Itertools; +/// +/// FastPair: Data-structure for the dynamic closest-pair problem. +/// +/// Reference: +/// Eppstein, David: Fast hierarchical clustering and other applications of +/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1. +/// +use std::collections::HashMap; + +use crate::algorithm::neighbour::distances::PairwiseDistance; +use crate::error::{Failed, FailedError}; +use crate::linalg::Matrix; +use crate::math::distance::euclidian::Euclidian; +use crate::math::num::RealNumber; + +/// +/// FastPair +/// +/// Ported from Python implementation: +/// +/// MIT License (MIT) Copyright (c) 2016 Carson Farmer +/// +/// affinity used is Euclidean so to allow linkage with single, ward, complete and average +/// +#[derive(Debug, Clone)] +pub struct FastPair<'a, T: RealNumber, M: Matrix> { + /// initial matrix + samples: &'a M, + /// closest pair hashmap (connectivity matrix for closest pairs) + pub distances: HashMap>, + /// conga line used to keep track of the closest pair + pub neighbours: Vec, +} + +impl<'a, T: RealNumber, M: Matrix> FastPair<'a, T, M> { + /// + /// Constructor + /// Instantiate and inizialise the algorithm + /// + pub fn new(m: &'a M) -> Result { + if m.shape().0 < 3 { + return Err(Failed::because( + FailedError::FindFailed, + "min number of rows should be 3", + )); + } + + let mut init = Self { + samples: m, + // to be computed in init(..) + distances: HashMap::with_capacity(m.shape().0), + neighbours: Vec::with_capacity(m.shape().0 + 1), + }; + init.init(); + Ok(init) + } + + /// + /// Initialise `FastPair` by passing a `Matrix`. + /// Build a FastPairs data-structure from a set of (new) points. 
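+    ///
+    /// A sketch of the idea (per the Eppstein paper referenced above): the
+    /// O(n^2) pass below gives every point a nearest neighbour among the
+    /// points that follow it, forming the "conga line"; `closest_pair` then
+    /// scans only these n candidate edges instead of all n*(n-1)/2 pairs.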
+ /// + fn init(&mut self) { + // basic measures + let len = self.samples.shape().0; + let max_index = self.samples.shape().0 - 1; + + // Store all closest neighbors + let _distances = Box::new(HashMap::with_capacity(len)); + let _neighbours = Box::new(Vec::with_capacity(len)); + + let mut distances = *_distances; + let mut neighbours = *_neighbours; + + // fill neighbours with -1 values + neighbours.extend(0..len); + + // init closest neighbour pairwise data + for index_row_i in 0..(max_index) { + distances.insert( + index_row_i, + PairwiseDistance { + node: index_row_i, + neighbour: None, + distance: Some(T::max_value()), + }, + ); + } + + // loop through indeces and neighbours + for index_row_i in 0..(len) { + // start looking for the neighbour in the second element + let mut index_closest = index_row_i + 1; // closest neighbour index + let mut nbd: Option = distances[&index_row_i].distance; // init neighbour distance + for index_row_j in (index_row_i + 1)..len { + distances.insert( + index_row_j, + PairwiseDistance { + node: index_row_j, + neighbour: Some(index_row_i), + distance: nbd, + }, + ); + + let d = Euclidian::squared_distance( + &(self.samples.get_row_as_vec(index_row_i)), + &(self.samples.get_row_as_vec(index_row_j)), + ); + if d < nbd.unwrap() { + // set this j-value to be the closest neighbour + index_closest = index_row_j; + nbd = Some(d); + } + } + + // Add that edge + distances.entry(index_row_i).and_modify(|e| { + e.distance = nbd; + e.neighbour = Some(index_closest); + }); + } + // No more neighbors, terminate conga line. + // Last person on the line has no neigbors + distances.get_mut(&max_index).unwrap().neighbour = Some(max_index); + distances.get_mut(&(len - 1)).unwrap().distance = Some(T::max_value()); + + // compute sparse matrix (connectivity matrix) + let mut sparse_matrix = M::zeros(len, len); + for (_, p) in distances.iter() { + sparse_matrix.set(p.node, p.neighbour.unwrap(), p.distance.unwrap()); + } + + self.distances = distances; + self.neighbours = neighbours; + } + + /// + /// Find closest pair by scanning list of nearest neighbors. + /// + #[allow(dead_code)] + pub fn closest_pair(&self) -> PairwiseDistance { + let mut a = self.neighbours[0]; // Start with first point + let mut d = self.distances[&a].distance; + for p in self.neighbours.iter() { + if self.distances[p].distance < d { + a = *p; // Update `a` and distance `d` + d = self.distances[p].distance; + } + } + let b = self.distances[&a].neighbour; + PairwiseDistance { + node: a, + neighbour: b, + distance: d, + } + } + + /// + /// Brute force algorithm, used only for comparison and testing + /// + #[cfg(feature = "fp_bench")] + pub fn closest_pair_brute(&self) -> PairwiseDistance { + let m = self.samples.shape().0; + + let mut closest_pair = PairwiseDistance { + node: 0, + neighbour: None, + distance: Some(T::max_value()), + }; + for pair in (0..m).combinations(2) { + let d = Euclidian::squared_distance( + &(self.samples.get_row_as_vec(pair[0])), + &(self.samples.get_row_as_vec(pair[1])), + ); + if d < closest_pair.distance.unwrap() { + closest_pair.node = pair[0]; + closest_pair.neighbour = Some(pair[1]); + closest_pair.distance = Some(d); + } + } + closest_pair + } + + // + // Compute distances from input to all other points in data-structure. 
+ // input is the row index of the sample matrix + // + #[allow(dead_code)] + fn distances_from(&self, index_row: usize) -> Vec> { + let mut distances = Vec::>::with_capacity(self.samples.shape().0); + for other in self.neighbours.iter() { + if index_row != *other { + distances.push(PairwiseDistance { + node: index_row, + neighbour: Some(*other), + distance: Some(Euclidian::squared_distance( + &(self.samples.get_row_as_vec(index_row)), + &(self.samples.get_row_as_vec(*other)), + )), + }) + } + } + distances + } +} + +#[cfg(test)] +mod tests_fastpair { + + use super::*; + use crate::linalg::naive::dense_matrix::*; + + #[test] + fn fastpair_init() { + let x: DenseMatrix = DenseMatrix::rand(10, 4); + let _fastpair = FastPair::new(&x); + assert!(_fastpair.is_ok()); + + let fastpair = _fastpair.unwrap(); + + let distances = fastpair.distances; + let neighbours = fastpair.neighbours; + + assert!(distances.len() != 0); + assert!(neighbours.len() != 0); + + assert_eq!(10, neighbours.len()); + assert_eq!(10, distances.len()); + } + + #[test] + fn dataset_has_at_least_three_points() { + // Create a dataset which consists of only two points: + // A(0.0, 0.0) and B(1.0, 1.0). + let dataset = DenseMatrix::::from_2d_array(&[&[0.0, 0.0], &[1.0, 1.0]]); + + // We expect an error when we run `FastPair` on this dataset, + // becuase `FastPair` currently only works on a minimum of 3 + // points. + let _fastpair = FastPair::new(&dataset); + + match _fastpair { + Err(e) => { + let expected_error = + Failed::because(FailedError::FindFailed, "min number of rows should be 3"); + assert_eq!(e, expected_error) + } + _ => { + assert!(false); + } + } + } + + #[test] + fn one_dimensional_dataset_minimal() { + let dataset = DenseMatrix::::from_2d_array(&[&[0.0], &[2.0], &[9.0]]); + + let result = FastPair::new(&dataset); + assert!(result.is_ok()); + + let fastpair = result.unwrap(); + let closest_pair = fastpair.closest_pair(); + let expected_closest_pair = PairwiseDistance { + node: 0, + neighbour: Some(1), + distance: Some(4.0), + }; + assert_eq!(closest_pair, expected_closest_pair); + + let closest_pair_brute = fastpair.closest_pair_brute(); + assert_eq!(closest_pair_brute, expected_closest_pair); + } + + #[test] + fn one_dimensional_dataset_2() { + let dataset = DenseMatrix::::from_2d_array(&[&[27.0], &[0.0], &[9.0], &[2.0]]); + + let result = FastPair::new(&dataset); + assert!(result.is_ok()); + + let fastpair = result.unwrap(); + let closest_pair = fastpair.closest_pair(); + let expected_closest_pair = PairwiseDistance { + node: 1, + neighbour: Some(3), + distance: Some(4.0), + }; + assert_eq!(closest_pair, fastpair.closest_pair_brute()); + assert_eq!(closest_pair, expected_closest_pair); + } + + #[test] + fn fastpair_new() { + // compute + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]); + let fastpair = FastPair::new(&x); + assert!(fastpair.is_ok()); + + // unwrap results + let result = fastpair.unwrap(); + + // list of minimal pairwise dissimilarities + let dissimilarities = vec![ + ( + 1, + PairwiseDistance { + node: 1, + neighbour: Some(9), + distance: Some(0.030000000000000037), + }, + ), + ( + 10, + PairwiseDistance { + node: 10, + 
neighbour: Some(12), + distance: Some(0.07000000000000003), + }, + ), + ( + 11, + PairwiseDistance { + node: 11, + neighbour: Some(14), + distance: Some(0.18000000000000013), + }, + ), + ( + 12, + PairwiseDistance { + node: 12, + neighbour: Some(14), + distance: Some(0.34000000000000086), + }, + ), + ( + 13, + PairwiseDistance { + node: 13, + neighbour: Some(14), + distance: Some(1.6499999999999997), + }, + ), + ( + 14, + PairwiseDistance { + node: 14, + neighbour: Some(14), + distance: Some(f64::MAX), + }, + ), + ( + 6, + PairwiseDistance { + node: 6, + neighbour: Some(7), + distance: Some(0.18000000000000027), + }, + ), + ( + 0, + PairwiseDistance { + node: 0, + neighbour: Some(4), + distance: Some(0.01999999999999995), + }, + ), + ( + 8, + PairwiseDistance { + node: 8, + neighbour: Some(9), + distance: Some(0.3100000000000001), + }, + ), + ( + 2, + PairwiseDistance { + node: 2, + neighbour: Some(3), + distance: Some(0.0600000000000001), + }, + ), + ( + 3, + PairwiseDistance { + node: 3, + neighbour: Some(8), + distance: Some(0.08999999999999982), + }, + ), + ( + 7, + PairwiseDistance { + node: 7, + neighbour: Some(9), + distance: Some(0.10999999999999982), + }, + ), + ( + 9, + PairwiseDistance { + node: 9, + neighbour: Some(13), + distance: Some(8.69), + }, + ), + ( + 4, + PairwiseDistance { + node: 4, + neighbour: Some(7), + distance: Some(0.050000000000000086), + }, + ), + ( + 5, + PairwiseDistance { + node: 5, + neighbour: Some(7), + distance: Some(0.4900000000000002), + }, + ), + ]; + + let expected: HashMap<_, _> = dissimilarities.into_iter().collect(); + + for i in 0..(x.shape().0 - 1) { + let input_node = result.samples.get_row_as_vec(i); + let input_neighbour: usize = expected.get(&i).unwrap().neighbour.unwrap(); + let distance = Euclidian::squared_distance( + &input_node, + &result.samples.get_row_as_vec(input_neighbour), + ); + + assert_eq!(i, expected.get(&i).unwrap().node); + assert_eq!( + input_neighbour, + expected.get(&i).unwrap().neighbour.unwrap() + ); + assert_eq!(distance, expected.get(&i).unwrap().distance.unwrap()); + } + } + + #[test] + fn fastpair_closest_pair() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]); + // compute + let fastpair = FastPair::new(&x); + assert!(fastpair.is_ok()); + + let dissimilarity = fastpair.unwrap().closest_pair(); + let closest = PairwiseDistance { + node: 0, + neighbour: Some(4), + distance: Some(0.01999999999999995), + }; + + assert_eq!(closest, dissimilarity); + } + + #[test] + fn fastpair_closest_pair_random_matrix() { + let x = DenseMatrix::::rand(200, 25); + // compute + let fastpair = FastPair::new(&x); + assert!(fastpair.is_ok()); + + let result = fastpair.unwrap(); + + let dissimilarity1 = result.closest_pair(); + let dissimilarity2 = result.closest_pair_brute(); + + assert_eq!(dissimilarity1, dissimilarity2); + } + + #[test] + fn fastpair_distances() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 
4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + ]); + // compute + let fastpair = FastPair::new(&x); + assert!(fastpair.is_ok()); + + let dissimilarities = fastpair.unwrap().distances_from(0); + + let mut min_dissimilarity = PairwiseDistance { + node: 0, + neighbour: None, + distance: Some(f64::MAX), + }; + for p in dissimilarities.iter() { + if p.distance.unwrap() < min_dissimilarity.distance.unwrap() { + min_dissimilarity = p.clone() + } + } + + let closest = PairwiseDistance { + node: 0, + neighbour: Some(4), + distance: Some(0.01999999999999995), + }; + + assert_eq!(closest, min_dissimilarity); + } +} diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 321ec01..42ab7bc 100644 --- a/src/algorithm/neighbour/mod.rs +++ b/src/algorithm/neighbour/mod.rs @@ -41,6 +41,10 @@ use serde::{Deserialize, Serialize}; pub(crate) mod bbd_tree; /// tree data structure for fast nearest neighbor search pub mod cover_tree; +/// dissimilarities for vector-vector distance. Linkage algorithms used in fastpair +pub mod distances; +/// fastpair closest neighbour algorithm +pub mod fastpair; /// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched. pub mod linear_search; From 3d2f4f71fa6a9c15df154d4ed90a8e75b995a92a Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Wed, 24 Aug 2022 13:40:22 +0100 Subject: [PATCH 04/19] Add example for FastPair (#144) * Add example * Move to top * Add imports to example * Fix imports --- src/algorithm/neighbour/fastpair.rs | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index dfc6f58..e14c2b3 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -1,12 +1,30 @@ #![allow(non_snake_case)] use itertools::Itertools; /// -/// FastPair: Data-structure for the dynamic closest-pair problem. +/// # FastPair: Data-structure for the dynamic closest-pair problem. /// /// Reference: /// Eppstein, David: Fast hierarchical clustering and other applications of /// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1. /// +/// Example: +/// ``` +/// use smartcore::algorithm::neighbour::distances::PairwiseDistance; +/// use smartcore::linalg::naive::dense_matrix::DenseMatrix; +/// use smartcore::algorithm::neighbour::fastpair::FastPair; +/// let x = DenseMatrix::::from_2d_array(&[ +/// &[5.1, 3.5, 1.4, 0.2], +/// &[4.9, 3.0, 1.4, 0.2], +/// &[4.7, 3.2, 1.3, 0.2], +/// &[4.6, 3.1, 1.5, 0.2], +/// &[5.0, 3.6, 1.4, 0.2], +/// &[5.4, 3.9, 1.7, 0.4], +/// ]); +/// let fastpair = FastPair::new(&x); +/// let closest_pair: PairwiseDistance = fastpair.unwrap().closest_pair(); +/// ``` +/// +/// use std::collections::HashMap; use crate::algorithm::neighbour::distances::PairwiseDistance; @@ -16,9 +34,7 @@ use crate::math::distance::euclidian::Euclidian; use crate::math::num::RealNumber; /// -/// FastPair -/// -/// Ported from Python implementation: +/// Inspired by Python implementation: /// /// MIT License (MIT) Copyright (c) 2016 Carson Farmer /// From d305406dfd4f7d1a14f788f0b538278e4c2519b9 Mon Sep 17 00:00:00 2001 From: Tim Toebrock <35797763+titoeb@users.noreply.github.com> Date: Fri, 26 Aug 2022 16:20:20 +0200 Subject: [PATCH 05/19] Implementation of Standard scaler (#143) * docs: Fix typo in doc for categorical transformer. 
* feat: Add option to take a column from Matrix. I created the method `Matrix::take_column` that uses the `Matrix::take`-interface to extract a single column from a matrix. I need that feature in the implementation of `StandardScaler`. * feat: Add `StandardScaler`. Authored-by: titoeb --- src/linalg/mod.rs | 21 ++ src/preprocessing/mod.rs | 4 +- src/preprocessing/numerical.rs | 404 +++++++++++++++++++++++++++++++++ 3 files changed, 428 insertions(+), 1 deletion(-) create mode 100644 src/preprocessing/numerical.rs diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 59b6089..8e27c0b 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -651,6 +651,10 @@ pub trait BaseMatrix: Clone + Debug { result } + /// Take an individual column from the matrix. + fn take_column(&self, column_index: usize) -> Self { + self.take(&[column_index], 1) + } } /// Generic matrix with additional mixins like various factorization methods. @@ -761,4 +765,21 @@ mod tests { assert_eq!(m.take(&vec!(1, 1, 3), 0), expected_0); assert_eq!(m.take(&vec!(1, 0), 1), expected_1); } + + #[test] + fn take_second_column_from_matrix() { + let four_columns: DenseMatrix = DenseMatrix::from_2d_array(&[ + &[0.0, 1.0, 2.0, 3.0], + &[0.0, 1.0, 2.0, 3.0], + &[0.0, 1.0, 2.0, 3.0], + &[0.0, 1.0, 2.0, 3.0], + ]); + + let second_column = four_columns.take_column(1); + assert_eq!( + second_column, + DenseMatrix::from_2d_array(&[&[1.0], &[1.0], &[1.0], &[1.0]]), + "The second column was not extracted correctly" + ); + } } diff --git a/src/preprocessing/mod.rs b/src/preprocessing/mod.rs index 32a0cfa..915fdab 100644 --- a/src/preprocessing/mod.rs +++ b/src/preprocessing/mod.rs @@ -1,5 +1,7 @@ -/// Transform a data matrix by replaceing all categorical variables with their one-hot vector equivalents +/// Transform a data matrix by replacing all categorical variables with their one-hot vector equivalents pub mod categorical; mod data_traits; +/// Preprocess numerical matrices. +pub mod numerical; /// Encode a series (column, array) of categorical variables as one-hot vectors pub mod series_encoder; diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs new file mode 100644 index 0000000..cc90b28 --- /dev/null +++ b/src/preprocessing/numerical.rs @@ -0,0 +1,404 @@ +//! # Standard-Scaling For [RealNumber](../../math/num/trait.RealNumber.html) Matricies +//! Transform a data [Matrix](../../linalg/trait.BaseMatrix.html) by removing the mean and scaling to unit variance. +//! +//! ### Usage Example +//! ``` +//! use smartcore::api::{Transformer, UnsupervisedEstimator}; +//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; +//! use smartcore::preprocessing::numerical; +//! let data = DenseMatrix::from_2d_vec(&vec![ +//! vec![0.0, 0.0], +//! vec![0.0, 0.0], +//! vec![1.0, 1.0], +//! vec![1.0, 1.0], +//! ]); +//! +//! let standard_scaler = +//! numerical::StandardScaler::fit(&data, numerical::StandardScalerParameters::default()) +//! .unwrap(); +//! let transformed_data = standard_scaler.transform(&data).unwrap(); +//! assert_eq!( +//! transformed_data, +//! DenseMatrix::from_2d_vec(&vec![ +//! vec![-1.0, -1.0], +//! vec![-1.0, -1.0], +//! vec![1.0, 1.0], +//! vec![1.0, 1.0], +//! ]) +//! ); +//! ``` +use crate::api::{Transformer, UnsupervisedEstimator}; +use crate::error::{Failed, FailedError}; +use crate::linalg::Matrix; +use crate::math::num::RealNumber; + +/// Configure Behaviour of `StandardScaler`. 
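+/// A minimal sketch of the default configuration (both flags are `true`,
+/// per the `Default` impl below):
+///
+/// ```
+/// use smartcore::preprocessing::numerical::StandardScalerParameters;
+/// let params = StandardScalerParameters::default();
+/// ```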
+#[derive(Clone, Debug, Copy, Eq, PartialEq)] +pub struct StandardScalerParameters { + /// Optionaly adjust mean to be zero. + with_mean: bool, + /// Optionally adjust standard-deviation to be one. + with_std: bool, +} +impl Default for StandardScalerParameters { + fn default() -> Self { + StandardScalerParameters { + with_mean: true, + with_std: true, + } + } +} + +/// With the `StandardScaler` data can be adjusted so +/// that every column has a mean of zero and a standard +/// deviation of one. This can improve model training for +/// scaling sensitive models like neural network or nearest +/// neighbors based models. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct StandardScaler { + means: Vec, + stds: Vec, + parameters: StandardScalerParameters, +} +impl StandardScaler { + /// When the mean should be adjusted, the column mean + /// should be kept. Otherwise, replace it by zero. + fn adjust_column_mean(&self, mean: T) -> T { + if self.parameters.with_mean { + mean + } else { + T::zero() + } + } + /// When the standard-deviation should be adjusted, the column + /// standard-deviation should be kept. Otherwise, replace it by one. + fn adjust_column_std(&self, std: T) -> T { + if self.parameters.with_std { + ensure_std_valid(std) + } else { + T::one() + } + } +} + +/// Make sure the standard deviation is valid. If it is +/// negative or zero, it should replaced by the smallest +/// positive value the type can have. That way we can savely +/// divide the columns with the resulting scalar. +fn ensure_std_valid(value: T) -> T { + value.max(T::min_positive_value()) +} + +/// During `fit` the `StandardScaler` computes the column means and standard deviation. +impl> UnsupervisedEstimator + for StandardScaler +{ + fn fit(x: &M, parameters: StandardScalerParameters) -> Result { + Ok(Self { + means: x.column_mean(), + stds: x.std(0), + parameters, + }) + } +} + +/// During `transform` the `StandardScaler` applies the summary statistics +/// computed during `fit` to set the mean of each column to zero and the +/// standard deviation to one. +impl> Transformer for StandardScaler { + fn transform(&self, x: &M) -> Result { + let (_, n_cols) = x.shape(); + if n_cols != self.means.len() { + return Err(Failed::because( + FailedError::TransformFailed, + &format!( + "Expected {} columns, but got {} columns instead.", + self.means.len(), + n_cols, + ), + )); + } + + Ok(build_matrix_from_columns( + self.means + .iter() + .zip(self.stds.iter()) + .enumerate() + .map(|(column_index, (column_mean, column_std))| { + x.take_column(column_index) + .sub_scalar(self.adjust_column_mean(*column_mean)) + .div_scalar(self.adjust_column_std(*column_std)) + }) + .collect(), + ) + .unwrap()) + } +} + +/// From a collection of matrices, that contain columns, construct +/// a matrix by stacking the columns horizontally. 
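+/// For example, three 3x1 column matrices `[1;1;1]`, `[2;2;2]` and `[3;3;3]`
+/// fold into a single 3x3 matrix via repeated `h_stack` (see the
+/// `combine_three_columns` test below); an empty input yields `None`.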
+fn build_matrix_from_columns(columns: Vec) -> Option +where + T: RealNumber, + M: Matrix, +{ + if let Some(output_matrix) = columns.first().cloned() { + return Some( + columns + .iter() + .skip(1) + .fold(output_matrix, |current_matrix, new_colum| { + current_matrix.h_stack(new_colum) + }), + ); + } else { + None + } +} + +#[cfg(test)] +mod tests { + + mod helper_functionality { + use super::super::{build_matrix_from_columns, ensure_std_valid}; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn combine_three_columns() { + assert_eq!( + build_matrix_from_columns(vec![ + DenseMatrix::from_2d_vec(&vec![vec![1.0], vec![1.0], vec![1.0],]), + DenseMatrix::from_2d_vec(&vec![vec![2.0], vec![2.0], vec![2.0],]), + DenseMatrix::from_2d_vec(&vec![vec![3.0], vec![3.0], vec![3.0],]) + ]), + Some(DenseMatrix::from_2d_vec(&vec![ + vec![1.0, 2.0, 3.0], + vec![1.0, 2.0, 3.0], + vec![1.0, 2.0, 3.0] + ])) + ) + } + + #[test] + fn negative_value_should_be_replace_with_minimal_positive_value() { + assert_eq!(ensure_std_valid(-1.0), f64::MIN_POSITIVE) + } + + #[test] + fn zero_should_be_replace_with_minimal_positive_value() { + assert_eq!(ensure_std_valid(0.0), f64::MIN_POSITIVE) + } + } + mod standard_scaler { + use super::super::{StandardScaler, StandardScalerParameters}; + use crate::api::{Transformer, UnsupervisedEstimator}; + use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::linalg::BaseMatrix; + + #[test] + fn dont_adjust_mean_if_used() { + assert_eq!( + (StandardScaler { + means: vec![], + stds: vec![], + parameters: StandardScalerParameters { + with_mean: true, + with_std: true + } + }) + .adjust_column_mean(1.0), + 1.0 + ) + } + #[test] + fn replace_mean_with_zero_if_not_used() { + assert_eq!( + (StandardScaler { + means: vec![], + stds: vec![], + parameters: StandardScalerParameters { + with_mean: false, + with_std: true + } + }) + .adjust_column_mean(1.0), + 0.0 + ) + } + #[test] + fn dont_adjust_std_if_used() { + assert_eq!( + (StandardScaler { + means: vec![], + stds: vec![], + parameters: StandardScalerParameters { + with_mean: true, + with_std: true + } + }) + .adjust_column_std(10.0), + 10.0 + ) + } + #[test] + fn replace_std_with_one_if_not_used() { + assert_eq!( + (StandardScaler { + means: vec![], + stds: vec![], + parameters: StandardScalerParameters { + with_mean: true, + with_std: false + } + }) + .adjust_column_std(10.0), + 1.0 + ) + } + + /// Helper function to apply fit as well as transform at the same time. + fn fit_transform_with_default_standard_scaler( + values_to_be_transformed: &DenseMatrix, + ) -> DenseMatrix { + StandardScaler::fit( + values_to_be_transformed, + StandardScalerParameters::default(), + ) + .unwrap() + .transform(values_to_be_transformed) + .unwrap() + } + + /// Fit transform with random generated values, expected values taken from + /// sklearn. 
+ #[test] + fn fit_transform_random_values() { + let transformed_values = + fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[ + &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793], + &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], + &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], + &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], + ])); + println!("{}", transformed_values); + assert!(transformed_values.approximate_eq( + &DenseMatrix::from_2d_array(&[ + &[-1.1154020653, -0.4031985330, 0.9284605204, -0.4271473866], + &[-0.7615464283, -0.7076698384, -1.1075452562, 1.2632979631], + &[0.4832504303, -0.6106747444, 1.0630075435, 0.5494084257], + &[1.3936980634, 1.7215431158, -0.8839228078, -1.3855590021], + ]), + 1.0 + )) + } + + /// Test `fit` and `transform` for a column with zero variance. + #[test] + fn fit_transform_with_zero_variance() { + assert_eq!( + fit_transform_with_default_standard_scaler(&DenseMatrix::from_2d_array(&[ + &[1.0], + &[1.0], + &[1.0], + &[1.0] + ])), + DenseMatrix::from_2d_array(&[&[0.0], &[0.0], &[0.0], &[0.0]]), + "When scaling values with zero variance, zero is expected as return value" + ) + } + + /// Test `fit` for columns with nice summary statistics. + #[test] + fn fit_for_simple_values() { + assert_eq!( + StandardScaler::fit( + &DenseMatrix::from_2d_array(&[ + &[1.0, 1.0, 1.0], + &[1.0, 2.0, 5.0], + &[1.0, 1.0, 1.0], + &[1.0, 2.0, 5.0] + ]), + StandardScalerParameters::default(), + ), + Ok(StandardScaler { + means: vec![1.0, 1.5, 3.0], + stds: vec![0.0, 0.5, 2.0], + parameters: StandardScalerParameters { + with_mean: true, + with_std: true + } + }) + ) + } + /// Test `fit` for random generated values. + #[test] + fn fit_for_random_values() { + let fitted_scaler = StandardScaler::fit( + &DenseMatrix::from_2d_array(&[ + &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793], + &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], + &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], + &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], + ]), + StandardScalerParameters::default(), + ) + .unwrap(); + + assert_eq!( + fitted_scaler.means, + vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625], + ); + + assert!( + &DenseMatrix::from_2d_vec(&vec![fitted_scaler.stds]).approximate_eq( + &DenseMatrix::from_2d_array(&[&[ + 0.29426447500954, + 0.16758497615485, + 0.20820945786863, + 0.23329718831165 + ],]), + 0.00000000000001 + ) + ) + } + + /// If `with_std` is set to `false` the values should not be + /// adjusted to have a std of one. + #[test] + fn transform_without_std() { + let standard_scaler = StandardScaler { + means: vec![1.0, 3.0], + stds: vec![1.0, 2.0], + parameters: StandardScalerParameters { + with_mean: true, + with_std: false, + }, + }; + + assert_eq!( + standard_scaler.transform(&DenseMatrix::from_2d_array(&[&[0.0, 2.0], &[2.0, 4.0]])), + Ok(DenseMatrix::from_2d_array(&[&[-1.0, -1.0], &[1.0, 1.0]])) + ) + } + + /// If `with_mean` is set to `false` the values should not be adjusted + /// to have a mean of zero. 
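+        /// Worked through with the numbers used here: centering is skipped,
+        /// so each column is only divided by its std of [2.0, 3.0], giving
+        /// [0, 9] -> [0, 3] and [4, 12] -> [2, 4].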
+ #[test] + fn transform_without_mean() { + let standard_scaler = StandardScaler { + means: vec![1.0, 2.0], + stds: vec![2.0, 3.0], + parameters: StandardScalerParameters { + with_mean: false, + with_std: true, + }, + }; + + assert_eq!( + standard_scaler + .transform(&DenseMatrix::from_2d_array(&[&[0.0, 9.0], &[4.0, 12.0]])), + Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]])) + ) + } + } +} From 4d5f64c7585cc5edaa62a2dadeaa2f878ec81993 Mon Sep 17 00:00:00 2001 From: Christos Katsakioris Date: Tue, 6 Sep 2022 20:37:54 +0300 Subject: [PATCH 06/19] Add serde for StandardScaler (#148) * Derive `serde::Serialize` and `serde::Deserialize` for `StandardScaler`. * Add relevant unit test. Signed-off-by: Christos Katsakioris Signed-off-by: Christos Katsakioris --- src/preprocessing/numerical.rs | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/preprocessing/numerical.rs b/src/preprocessing/numerical.rs index cc90b28..e2205c3 100644 --- a/src/preprocessing/numerical.rs +++ b/src/preprocessing/numerical.rs @@ -32,7 +32,11 @@ use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + /// Configure Behaviour of `StandardScaler`. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Debug, Copy, Eq, PartialEq)] pub struct StandardScalerParameters { /// Optionaly adjust mean to be zero. @@ -54,6 +58,7 @@ impl Default for StandardScalerParameters { /// deviation of one. This can improve model training for /// scaling sensitive models like neural network or nearest /// neighbors based models. +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct StandardScaler { means: Vec, @@ -400,5 +405,43 @@ mod tests { Ok(DenseMatrix::from_2d_array(&[&[0.0, 3.0], &[2.0, 4.0]])) ) } + + /// Same as `fit_for_random_values` test, but using a `StandardScaler` that has been + /// serialized and deserialized. + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + #[cfg(feature = "serde")] + fn serde_fit_for_random_values() { + let fitted_scaler = StandardScaler::fit( + &DenseMatrix::from_2d_array(&[ + &[0.1004222429, 0.2194113576, 0.9310663354, 0.3313593793], + &[0.2045493861, 0.1683865411, 0.5071506765, 0.7257355264], + &[0.5708488802, 0.1846414616, 0.9590802982, 0.5591871046], + &[0.8387612750, 0.5754861361, 0.5537109852, 0.1077646442], + ]), + StandardScalerParameters::default(), + ) + .unwrap(); + + let deserialized_scaler: StandardScaler = + serde_json::from_str(&serde_json::to_string(&fitted_scaler).unwrap()).unwrap(); + + assert_eq!( + deserialized_scaler.means, + vec![0.42864544605, 0.2869813741, 0.737752073825, 0.431011663625], + ); + + assert!( + &DenseMatrix::from_2d_vec(&vec![deserialized_scaler.stds]).approximate_eq( + &DenseMatrix::from_2d_array(&[&[ + 0.29426447500954, + 0.16758497615485, + 0.20820945786863, + 0.23329718831165 + ],]), + 0.00000000000001 + ) + ) + } } } From e445f0d558319b282e9846ff575c22f1c83e73ea Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 12 Sep 2022 12:03:43 -0400 Subject: [PATCH 07/19] Update criterion requirement from 0.3 to 0.4 (#150) * Update criterion requirement from 0.3 to 0.4 Updates the requirements on [criterion](https://github.com/bheisler/criterion.rs) to permit the latest version. 
- [Release notes](https://github.com/bheisler/criterion.rs/releases) - [Changelog](https://github.com/bheisler/criterion.rs/blob/master/CHANGELOG.md) - [Commits](https://github.com/bheisler/criterion.rs/compare/0.3.0...0.4.0) --- updated-dependencies: - dependency-name: criterion dependency-type: direct:production ... Signed-off-by: dependabot[bot] * fix criterion Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Luis Moreno --- Cargo.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e83a0cc..069e223 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,8 @@ itertools = "0.10.3" getrandom = { version = "0.2", features = ["js"] } [dev-dependencies] -criterion = "0.3" +smartcore = { path = ".", features = ["fp_bench"] } +criterion = { version = "0.4", default-features = false } serde_json = "1.0" bincode = "1.3.1" @@ -52,4 +53,4 @@ required-features = ["ndarray-bindings", "nalgebra-bindings"] [[bench]] name = "fastpair" harness = false -required-features = ["fp_bench"] \ No newline at end of file +required-features = ["fp_bench"] From 2e5f88fad8a3c5f9c124cb99f06be33640c53b50 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Tue, 13 Sep 2022 08:23:45 -0700 Subject: [PATCH 08/19] Handle multiclass precision/recall (#152) * handle multiclass precision/recall --- src/math/num.rs | 13 +++++++- src/metrics/precision.rs | 64 ++++++++++++++++++++++++-------------- src/metrics/recall.rs | 66 ++++++++++++++++++++++++++-------------- 3 files changed, 97 insertions(+), 46 deletions(-) diff --git a/src/math/num.rs b/src/math/num.rs index 7199949..c454b9d 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -46,8 +46,11 @@ pub trait RealNumber: self * self } - /// Raw transmutation to u64 + /// Raw transmutation to u32 fn to_f32_bits(self) -> u32; + + /// Raw transmutation to u64 + fn to_f64_bits(self) -> u64; } impl RealNumber for f64 { @@ -89,6 +92,10 @@ impl RealNumber for f64 { fn to_f32_bits(self) -> u32 { self.to_bits() as u32 } + + fn to_f64_bits(self) -> u64 { + self.to_bits() + } } impl RealNumber for f32 { @@ -130,6 +137,10 @@ impl RealNumber for f32 { fn to_f32_bits(self) -> u32 { self.to_bits() } + + fn to_f64_bits(self) -> u64 { + self.to_bits() as u64 + } } #[cfg(test)] diff --git a/src/metrics/precision.rs b/src/metrics/precision.rs index a0171aa..a2bad30 100644 --- a/src/metrics/precision.rs +++ b/src/metrics/precision.rs @@ -18,6 +18,8 @@ //! //! //! 
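+//!
+//! Multiclass labels are handled as well; a sketch mirroring the
+//! `precision_multiclass` unit test added below:
+//!
+//! ```
+//! use smartcore::metrics::precision::Precision;
+//! let y_true: Vec<f64> = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.];
+//! let y_pred: Vec<f64> = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.];
+//! let score: f64 = Precision {}.get_score(&y_pred, &y_true); // ~0.333
+//! ```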
+use std::collections::HashSet; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -42,34 +44,33 @@ impl Precision { ); } + let mut classes = HashSet::new(); + for i in 0..y_true.len() { + classes.insert(y_true.get(i).to_f64_bits()); + } + let classes = classes.len(); + let mut tp = 0; - let mut p = 0; - let n = y_true.len(); - for i in 0..n { - if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { - panic!( - "Precision can only be applied to binary classification: {}", - y_true.get(i) - ); - } - - if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { - panic!( - "Precision can only be applied to binary classification: {}", - y_pred.get(i) - ); - } - - if y_pred.get(i) == T::one() { - p += 1; - - if y_true.get(i) == T::one() { + let mut fp = 0; + for i in 0..y_true.len() { + if y_pred.get(i) == y_true.get(i) { + if classes == 2 { + if y_true.get(i) == T::one() { + tp += 1; + } + } else { tp += 1; } + } else if classes == 2 { + if y_true.get(i) == T::one() { + fp += 1; + } + } else { + fp += 1; } } - T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fp).unwrap()) } } @@ -88,5 +89,24 @@ mod tests { assert!((score1 - 0.5).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); + + let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; + let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; + + let score3: f64 = Precision {}.get_score(&y_pred, &y_true); + assert!((score3 - 0.5).abs() < 1e-8); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn precision_multiclass() { + let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.]; + + let score1: f64 = Precision {}.get_score(&y_pred, &y_true); + let score2: f64 = Precision {}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.333333333).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); } } diff --git a/src/metrics/recall.rs b/src/metrics/recall.rs index 18863ae..48ddeeb 100644 --- a/src/metrics/recall.rs +++ b/src/metrics/recall.rs @@ -18,6 +18,9 @@ //! //! //! 
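+//!
+//! As in the precision module, multiclass labels are supported; the score is
+//! computed as tp / (tp + fn) over exact label matches (see the
+//! `recall_multiclass` test added below).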
+use std::collections::HashSet; +use std::convert::TryInto; + #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -42,34 +45,32 @@ impl Recall { ); } + let mut classes = HashSet::new(); + for i in 0..y_true.len() { + classes.insert(y_true.get(i).to_f64_bits()); + } + let classes: i64 = classes.len().try_into().unwrap(); + let mut tp = 0; - let mut p = 0; - let n = y_true.len(); - for i in 0..n { - if y_true.get(i) != T::zero() && y_true.get(i) != T::one() { - panic!( - "Recall can only be applied to binary classification: {}", - y_true.get(i) - ); - } - - if y_pred.get(i) != T::zero() && y_pred.get(i) != T::one() { - panic!( - "Recall can only be applied to binary classification: {}", - y_pred.get(i) - ); - } - - if y_true.get(i) == T::one() { - p += 1; - - if y_pred.get(i) == T::one() { + let mut fne = 0; + for i in 0..y_true.len() { + if y_pred.get(i) == y_true.get(i) { + if classes == 2 { + if y_true.get(i) == T::one() { + tp += 1; + } + } else { tp += 1; } + } else if classes == 2 { + if y_true.get(i) != T::one() { + fne += 1; + } + } else { + fne += 1; } } - - T::from_i64(tp).unwrap() / T::from_i64(p).unwrap() + T::from_i64(tp).unwrap() / (T::from_i64(tp).unwrap() + T::from_i64(fne).unwrap()) } } @@ -88,5 +89,24 @@ mod tests { assert!((score1 - 0.5).abs() < 1e-8); assert!((score2 - 1.0).abs() < 1e-8); + + let y_pred: Vec = vec![0., 0., 1., 1., 1., 1.]; + let y_true: Vec = vec![0., 1., 1., 0., 1., 0.]; + + let score3: f64 = Recall {}.get_score(&y_pred, &y_true); + assert!((score3 - 0.66666666).abs() < 1e-8); + } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn recall_multiclass() { + let y_true: Vec = vec![0., 0., 0., 1., 1., 1., 2., 2., 2.]; + let y_pred: Vec = vec![0., 1., 2., 0., 1., 2., 0., 1., 2.]; + + let score1: f64 = Recall {}.get_score(&y_pred, &y_true); + let score2: f64 = Recall {}.get_score(&y_pred, &y_pred); + + assert!((score1 - 0.333333333).abs() < 1e-8); + assert!((score2 - 1.0).abs() < 1e-8); } } From 4685fc73e08da4168f9d2eb91d914cc1bac37460 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Mon, 19 Sep 2022 02:31:56 -0700 Subject: [PATCH 09/19] grid search (#154) * grid search draft * hyperparam search for linear estimators --- src/linear/elastic_net.rs | 138 ++++++++++++++++++++++++++++ src/linear/lasso.rs | 122 ++++++++++++++++++++++++ src/linear/linear_regression.rs | 70 +++++++++++++- src/linear/logistic_regression.rs | 88 +++++++++++++++++- src/linear/ridge_regression.rs | 101 +++++++++++++++++++- src/model_selection/hyper_tuning.rs | 117 +++++++++++++++++++++++ src/model_selection/mod.rs | 24 +++-- 7 files changed, 649 insertions(+), 11 deletions(-) create mode 100644 src/model_selection/hyper_tuning.rs diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs index ce13435..0e9cb57 100644 --- a/src/linear/elastic_net.rs +++ b/src/linear/elastic_net.rs @@ -135,6 +135,121 @@ impl Default for ElasticNetParameters { } } +/// ElasticNet grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct ElasticNetSearchParameters { + /// Regularization parameter. + pub alpha: Vec, + /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1. + /// For l1_ratio = 0 the penalty is an L2 penalty. + /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2. 
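+    /// In the usual elastic-net convention (as in scikit-learn; assumed here
+    /// rather than verified against this crate's objective) the penalty is
+    /// `alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||_2^2)`.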
+    pub l1_ratio: Vec<T>,
+    /// If `true`, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
+    pub normalize: Vec<bool>,
+    /// The tolerance for the optimization
+    pub tol: Vec<T>,
+    /// The maximum number of iterations
+    pub max_iter: Vec<usize>,
+}
+
+/// ElasticNet grid search iterator
+pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
+    lasso_regression_search_parameters: ElasticNetSearchParameters<T>,
+    current_alpha: usize,
+    current_l1_ratio: usize,
+    current_normalize: usize,
+    current_tol: usize,
+    current_max_iter: usize,
+}
+
+impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
+    type Item = ElasticNetParameters<T>;
+    type IntoIter = ElasticNetSearchParametersIterator<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        ElasticNetSearchParametersIterator {
+            lasso_regression_search_parameters: self,
+            current_alpha: 0,
+            current_l1_ratio: 0,
+            current_normalize: 0,
+            current_tol: 0,
+            current_max_iter: 0,
+        }
+    }
+}
+
+impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
+    type Item = ElasticNetParameters<T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
+            && self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
+            && self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
+            && self.current_tol == self.lasso_regression_search_parameters.tol.len()
+            && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
+        {
+            return None;
+        }
+
+        let next = ElasticNetParameters {
+            alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
+            l1_ratio: self.lasso_regression_search_parameters.l1_ratio[self.current_l1_ratio],
+            normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
+            tol: self.lasso_regression_search_parameters.tol[self.current_tol],
+            max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
+        };
+
+        if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
+            self.current_alpha += 1;
+        } else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio += 1;
+        } else if self.current_normalize + 1
+            < self.lasso_regression_search_parameters.normalize.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize += 1;
+        } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol += 1;
+        } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
+        {
+            self.current_alpha = 0;
+            self.current_l1_ratio = 0;
+            self.current_normalize = 0;
+            self.current_tol = 0;
+            self.current_max_iter += 1;
+        } else {
+            self.current_alpha += 1;
+            self.current_l1_ratio += 1;
+            self.current_normalize += 1;
+            self.current_tol += 1;
+            self.current_max_iter += 1;
+        }
+
+        Some(next)
+    }
+}
+
+impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
+    fn default() -> Self {
+        let default_params = ElasticNetParameters::default();
+
+        ElasticNetSearchParameters {
+            alpha: vec![default_params.alpha],
+            l1_ratio: vec![default_params.l1_ratio],
+            normalize: vec![default_params.normalize],
+            tol: vec![default_params.tol],
+            max_iter: vec![default_params.max_iter],
+        }
+    }
+}
+
impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
    fn eq(&self, other: &Self) -> bool {
        self.coefficients == other.coefficients
@@ -291,6 +406,29
@@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[test] + fn search_parameters() { + let parameters = ElasticNetSearchParameters { + alpha: vec![0., 1.], + max_iter: vec![10, 100], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 0.); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 0.); + assert_eq!(next.max_iter, 100); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + assert_eq!(next.max_iter, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn elasticnet_longley() { diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 7edd325..7e80a8b 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -112,6 +112,105 @@ impl> Predictor for Lasso { } } +/// Lasso grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct LassoSearchParameters { + /// Controls the strength of the penalty to the loss function. + pub alpha: Vec, + /// If true the regressors X will be normalized before regression + /// by subtracting the mean and dividing by the standard deviation. + pub normalize: Vec, + /// The tolerance for the optimization + pub tol: Vec, + /// The maximum number of iterations + pub max_iter: Vec, +} + +/// Lasso grid search iterator +pub struct LassoSearchParametersIterator { + lasso_regression_search_parameters: LassoSearchParameters, + current_alpha: usize, + current_normalize: usize, + current_tol: usize, + current_max_iter: usize, +} + +impl IntoIterator for LassoSearchParameters { + type Item = LassoParameters; + type IntoIter = LassoSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + LassoSearchParametersIterator { + lasso_regression_search_parameters: self, + current_alpha: 0, + current_normalize: 0, + current_tol: 0, + current_max_iter: 0, + } + } +} + +impl Iterator for LassoSearchParametersIterator { + type Item = LassoParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.lasso_regression_search_parameters.alpha.len() + && self.current_normalize == self.lasso_regression_search_parameters.normalize.len() + && self.current_tol == self.lasso_regression_search_parameters.tol.len() + && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len() + { + return None; + } + + let next = LassoParameters { + alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha], + normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize], + tol: self.lasso_regression_search_parameters.tol[self.current_tol], + max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter], + }; + + if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_normalize + 1 + < self.lasso_regression_search_parameters.normalize.len() + { + self.current_alpha = 0; + self.current_normalize += 1; + } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() { + self.current_alpha = 0; + self.current_normalize = 0; + self.current_tol += 1; + } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len() + { + self.current_alpha = 0; + 
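+            // odometer-style carry: clear the remaining lower-order counters,
+            // then advance `max_iter`, so that iteration covers the full grid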
self.current_normalize = 0; + self.current_tol = 0; + self.current_max_iter += 1; + } else { + self.current_alpha += 1; + self.current_normalize += 1; + self.current_tol += 1; + self.current_max_iter += 1; + } + + Some(next) + } +} + +impl Default for LassoSearchParameters { + fn default() -> Self { + let default_params = LassoParameters::default(); + + LassoSearchParameters { + alpha: vec![default_params.alpha], + normalize: vec![default_params.normalize], + tol: vec![default_params.tol], + max_iter: vec![default_params.max_iter], + } + } +} + impl> Lasso { /// Fits Lasso regression to your data. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. @@ -226,6 +325,29 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[test] + fn search_parameters() { + let parameters = LassoSearchParameters { + alpha: vec![0., 1.], + max_iter: vec![10, 100], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 0.); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 0.); + assert_eq!(next.max_iter, 100); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + assert_eq!(next.max_iter, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn lasso_fit_predict() { diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index b1f7c51..c95e6e1 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -71,7 +71,7 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable. pub enum LinearRegressionSolverName { /// QR decomposition, see [QR](../../linalg/qr/index.html) @@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters { } } +/// Linear Regression grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct LinearRegressionSearchParameters { + /// Solver to use for estimation of regression coefficients. 
+ pub solver: Vec, +} + +/// Linear Regression grid search iterator +pub struct LinearRegressionSearchParametersIterator { + linear_regression_search_parameters: LinearRegressionSearchParameters, + current_solver: usize, +} + +impl IntoIterator for LinearRegressionSearchParameters { + type Item = LinearRegressionParameters; + type IntoIter = LinearRegressionSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + LinearRegressionSearchParametersIterator { + linear_regression_search_parameters: self, + current_solver: 0, + } + } +} + +impl Iterator for LinearRegressionSearchParametersIterator { + type Item = LinearRegressionParameters; + + fn next(&mut self) -> Option { + if self.current_solver == self.linear_regression_search_parameters.solver.len() { + return None; + } + + let next = LinearRegressionParameters { + solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(), + }; + + self.current_solver += 1; + + Some(next) + } +} + +impl Default for LinearRegressionSearchParameters { + fn default() -> Self { + let default_params = LinearRegressionParameters::default(); + + LinearRegressionSearchParameters { + solver: vec![default_params.solver], + } + } +} + impl> PartialEq for LinearRegression { fn eq(&self, other: &Self) -> bool { self.coefficients == other.coefficients @@ -200,6 +254,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[test] + fn search_parameters() { + let parameters = LinearRegressionSearchParameters { + solver: vec![ + LinearRegressionSolverName::QR, + LinearRegressionSolverName::SVD, + ], + }; + let mut iter = parameters.into_iter(); + assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR); + assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ols_fit_predict() { diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 1a20077..3a4c706 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -68,7 +68,7 @@ use crate::optimization::line_search::Backtracking; use crate::optimization::FunctionOrder; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] /// Solver options for Logistic regression. Right now only LBFGS solver is supported. pub enum LogisticRegressionSolverName { /// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html) @@ -85,6 +85,77 @@ pub struct LogisticRegressionParameters { pub alpha: T, } +/// Logistic Regression grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct LogisticRegressionSearchParameters { + /// Solver to use for estimation of regression coefficients. + pub solver: Vec, + /// Regularization parameter. 
+ pub alpha: Vec, +} + +/// Logistic Regression grid search iterator +pub struct LogisticRegressionSearchParametersIterator { + logistic_regression_search_parameters: LogisticRegressionSearchParameters, + current_solver: usize, + current_alpha: usize, +} + +impl IntoIterator for LogisticRegressionSearchParameters { + type Item = LogisticRegressionParameters; + type IntoIter = LogisticRegressionSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + LogisticRegressionSearchParametersIterator { + logistic_regression_search_parameters: self, + current_solver: 0, + current_alpha: 0, + } + } +} + +impl Iterator for LogisticRegressionSearchParametersIterator { + type Item = LogisticRegressionParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.logistic_regression_search_parameters.alpha.len() + && self.current_solver == self.logistic_regression_search_parameters.solver.len() + { + return None; + } + + let next = LogisticRegressionParameters { + solver: self.logistic_regression_search_parameters.solver[self.current_solver].clone(), + alpha: self.logistic_regression_search_parameters.alpha[self.current_alpha], + }; + + if self.current_alpha + 1 < self.logistic_regression_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_solver + 1 < self.logistic_regression_search_parameters.solver.len() + { + self.current_alpha = 0; + self.current_solver += 1; + } else { + self.current_alpha += 1; + self.current_solver += 1; + } + + Some(next) + } +} + +impl Default for LogisticRegressionSearchParameters { + fn default() -> Self { + let default_params = LogisticRegressionParameters::default(); + + LogisticRegressionSearchParameters { + solver: vec![default_params.solver], + alpha: vec![default_params.alpha], + } + } +} + /// Logistic Regression #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] @@ -452,6 +523,21 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::accuracy; + #[test] + fn search_parameters() { + let parameters = LogisticRegressionSearchParameters { + alpha: vec![0., 1.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + assert_eq!(iter.next().unwrap().alpha, 0.); + assert_eq!( + iter.next().unwrap().solver, + LogisticRegressionSolverName::LBFGS + ); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn multiclass_objective_f() { diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index ecad250..4c3d4ff 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -68,7 +68,7 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] /// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable. pub enum RidgeRegressionSolverName { /// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html) @@ -90,6 +90,90 @@ pub struct RidgeRegressionParameters { pub normalize: bool, } +/// Ridge Regression grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct RidgeRegressionSearchParameters { + /// Solver to use for estimation of regression coefficients. + pub solver: Vec, + /// Regularization parameter. 
+ pub alpha: Vec, + /// If true the regressors X will be normalized before regression + /// by subtracting the mean and dividing by the standard deviation. + pub normalize: Vec, +} + +/// Ridge Regression grid search iterator +pub struct RidgeRegressionSearchParametersIterator { + ridge_regression_search_parameters: RidgeRegressionSearchParameters, + current_solver: usize, + current_alpha: usize, + current_normalize: usize, +} + +impl IntoIterator for RidgeRegressionSearchParameters { + type Item = RidgeRegressionParameters; + type IntoIter = RidgeRegressionSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + RidgeRegressionSearchParametersIterator { + ridge_regression_search_parameters: self, + current_solver: 0, + current_alpha: 0, + current_normalize: 0, + } + } +} + +impl Iterator for RidgeRegressionSearchParametersIterator { + type Item = RidgeRegressionParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.ridge_regression_search_parameters.alpha.len() + && self.current_solver == self.ridge_regression_search_parameters.solver.len() + { + return None; + } + + let next = RidgeRegressionParameters { + solver: self.ridge_regression_search_parameters.solver[self.current_solver].clone(), + alpha: self.ridge_regression_search_parameters.alpha[self.current_alpha], + normalize: self.ridge_regression_search_parameters.normalize[self.current_normalize], + }; + + if self.current_alpha + 1 < self.ridge_regression_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_solver + 1 < self.ridge_regression_search_parameters.solver.len() { + self.current_alpha = 0; + self.current_solver += 1; + } else if self.current_normalize + 1 + < self.ridge_regression_search_parameters.normalize.len() + { + self.current_alpha = 0; + self.current_solver = 0; + self.current_normalize += 1; + } else { + self.current_alpha += 1; + self.current_solver += 1; + self.current_normalize += 1; + } + + Some(next) + } +} + +impl Default for RidgeRegressionSearchParameters { + fn default() -> Self { + let default_params = RidgeRegressionParameters::default(); + + RidgeRegressionSearchParameters { + solver: vec![default_params.solver], + alpha: vec![default_params.alpha], + normalize: vec![default_params.normalize], + } + } +} + /// Ridge regression #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] @@ -274,6 +358,21 @@ mod tests { use crate::linalg::naive::dense_matrix::*; use crate::metrics::mean_absolute_error; + #[test] + fn search_parameters() { + let parameters = RidgeRegressionSearchParameters { + alpha: vec![0., 1.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + assert_eq!(iter.next().unwrap().alpha, 0.); + assert_eq!( + iter.next().unwrap().solver, + RidgeRegressionSolverName::Cholesky + ); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn ridge_fit_predict() { diff --git a/src/model_selection/hyper_tuning.rs b/src/model_selection/hyper_tuning.rs new file mode 100644 index 0000000..3093fbd --- /dev/null +++ b/src/model_selection/hyper_tuning.rs @@ -0,0 +1,117 @@ +/// grid search results. 
+#[derive(Clone, Debug)]
+pub struct GridSearchResult<T: RealNumber, I> {
+    /// Cross-validation scores obtained with the best parameter set
+    pub cross_validation_result: CrossValidationResult<T>,
+    /// The parameter set that produced the best mean test score
+    pub parameters: I,
+}
+
+use crate::api::Predictor;
+use crate::error::{Failed, FailedError};
+use crate::linalg::Matrix;
+use crate::math::num::RealNumber;
+use crate::model_selection::{cross_validate, BaseKFold, CrossValidationResult};
+
+/// Search for the best estimator by testing all possible combinations with cross-validation using given metric.
+/// * `fit_estimator` - a `fit` function of an estimator
+/// * `x` - features, matrix of size _NxM_ where _N_ is number of samples and _M_ is number of attributes.
+/// * `y` - target values, should be of size _N_
+/// * `parameter_search` - an iterator for parameters that will be tested.
+/// * `cv` - the cross-validation splitting strategy, should be an instance of [`BaseKFold`](./trait.BaseKFold.html)
+/// * `score` - a metric to use for evaluation, see [metrics](../metrics/index.html)
+pub fn grid_search<T, M, I, E, K, F, S>(
+    fit_estimator: F,
+    x: &M,
+    y: &M::RowVector,
+    parameter_search: I,
+    cv: K,
+    score: S,
+) -> Result<GridSearchResult<T, I::Item>, Failed>
+where
+    T: RealNumber,
+    M: Matrix<T>,
+    I: Iterator,
+    I::Item: Clone,
+    E: Predictor<M, M::RowVector>,
+    K: BaseKFold,
+    F: Fn(&M, &M::RowVector, I::Item) -> Result<E, Failed>,
+    S: Fn(&M::RowVector, &M::RowVector) -> T,
+{
+    let mut best_result: Option<CrossValidationResult<T>> = None;
+    let mut best_parameters = None;
+
+    for parameters in parameter_search {
+        let result = cross_validate(&fit_estimator, x, y, &parameters, &cv, &score)?;
+        if best_result.is_none()
+            || result.mean_test_score() > best_result.as_ref().unwrap().mean_test_score()
+        {
+            best_parameters = Some(parameters);
+            best_result = Some(result);
+        }
+    }
+
+    if let (Some(parameters), Some(cross_validation_result)) = (best_parameters, best_result) {
+        Ok(GridSearchResult {
+            cross_validation_result,
+            parameters,
+        })
+    } else {
+        Err(Failed::because(
+            FailedError::FindFailed,
+            "there were no parameter sets found",
+        ))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::naive::dense_matrix::*;
+    use crate::linear::logistic_regression::{
+        LogisticRegression, LogisticRegressionSearchParameters,
+    };
+    use crate::metrics::accuracy;
+    use crate::model_selection::kfold::KFold;
+
+    #[test]
+    fn test_grid_search() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+        let y = vec![
+            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+        ];
+
+        let cv = KFold {
+            n_splits: 5,
+            ..KFold::default()
+        };
+
+        let parameters = LogisticRegressionSearchParameters {
+            alpha: vec![0., 1.],
+            ..Default::default()
+        };
+
+        let results = grid_search(
+            LogisticRegression::fit,
+            &x,
+            &y,
+            parameters.into_iter(),
+            cv,
+            &accuracy,
+        )
+        .unwrap();
+
+        assert!([0., 1.].contains(&results.parameters.alpha));
+    }
+}
\ No newline at end of file
diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs
index d283176..68f0635 100644
--- a/src/model_selection/mod.rs
+++ b/src/model_selection/mod.rs
@@ -91,8 +91,8 @@
 //!
 //! let results = cross_validate(LogisticRegression::fit, //estimator
 //!                &x, &y,                                 //data
-//!                Default::default(),                     //hyperparameters
-//!                cv,                                     //cross validation split
+//!                &Default::default(),                    //hyperparameters
+//!                &cv,                                    //cross validation split
 //!                &accuracy).unwrap();                    //metric
 //!
 //! 
println!("Training accuracy: {}, test accuracy: {}", @@ -201,8 +201,8 @@ pub fn cross_validate( fit_estimator: F, x: &M, y: &M::RowVector, - parameters: H, - cv: K, + parameters: &H, + cv: &K, score: S, ) -> Result, Failed> where @@ -281,6 +281,7 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + use crate::metrics::{accuracy, mean_absolute_error}; use crate::model_selection::kfold::KFold; use crate::neighbors::knn_regressor::KNNRegressor; @@ -362,8 +363,15 @@ mod tests { ..KFold::default() }; - let results = - cross_validate(BiasedEstimator::fit, &x, &y, NoParameters {}, cv, &accuracy).unwrap(); + let results = cross_validate( + BiasedEstimator::fit, + &x, + &y, + &NoParameters {}, + &cv, + &accuracy, + ) + .unwrap(); assert_eq!(0.4, results.mean_test_score()); assert_eq!(0.4, results.mean_train_score()); @@ -404,8 +412,8 @@ mod tests { KNNRegressor::fit, &x, &y, - Default::default(), - cv, + &Default::default(), + &cv, &mean_absolute_error, ) .unwrap(); From b6f585e60fed106d163269a4db093c751141216a Mon Sep 17 00:00:00 2001 From: Tim Toebrock <35797763+titoeb@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:38:01 +0200 Subject: [PATCH 10/19] Implement a generic read_csv method (#147) * feat: Add interface to build `Matrix` from rows. * feat: Add option to derive `RealNumber` from string. To construct a `Matrix` from csv, and therefore from string, I need to be able to deserialize a generic `RealNumber` from string. * feat: Implement `Matrix::read_csv`. --- src/lib.rs | 2 + src/linalg/mod.rs | 100 ++++++++ src/math/num.rs | 12 + src/readers/csv.rs | 487 ++++++++++++++++++++++++++++++++++++++ src/readers/error.rs | 71 ++++++ src/readers/io_testing.rs | 158 +++++++++++++ src/readers/mod.rs | 11 + 7 files changed, 841 insertions(+) create mode 100644 src/readers/csv.rs create mode 100644 src/readers/error.rs create mode 100644 src/readers/io_testing.rs create mode 100644 src/readers/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 2edada4..e9e1c3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,8 @@ pub mod neighbors; pub(crate) mod optimization; /// Preprocessing utilities pub mod preprocessing; +/// Reading in Data. +pub mod readers; /// Support Vector Machines pub mod svm; /// Supervised tree-based learning methods diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 8e27c0b..9f1697c 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -65,8 +65,11 @@ use high_order::HighOrderOperations; use lu::LUDecomposableMatrix; use qr::QRDecomposableMatrix; use stats::{MatrixPreprocessing, MatrixStats}; +use std::fs; use svd::SVDDecomposableMatrix; +use crate::readers; + /// Column or row vector pub trait BaseVector: Clone + Debug { /// Get an element of a vector @@ -298,9 +301,60 @@ pub trait BaseMatrix: Clone + Debug { /// represents a row in this matrix. type RowVector: BaseVector + Clone + Debug; + /// Create a matrix from a csv file. 
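+    /// The file is opened eagerly and parsed by
+    /// [`readers::csv::matrix_from_csv_source`]; any IO or parsing failure
+    /// is reported as a `ReadingError`.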
+    /// ```
+    /// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
+    /// use smartcore::linalg::BaseMatrix;
+    /// use smartcore::readers::csv;
+    /// use std::fs;
+    ///
+    /// fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0");
+    /// assert_eq!(
+    ///     DenseMatrix::<f64>::from_csv("identity.csv", csv::CSVDefinition::default()).unwrap(),
+    ///     DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap()
+    /// );
+    /// fs::remove_file("identity.csv");
+    /// ```
+    fn from_csv(
+        path: &str,
+        definition: readers::csv::CSVDefinition<'_>,
+    ) -> Result<Self, readers::ReadingError> {
+        readers::csv::matrix_from_csv_source(fs::File::open(path)?, definition)
+    }
+
     /// Transforms row vector `vec` into a 1xM matrix.
     fn from_row_vector(vec: Self::RowVector) -> Self;
 
+    /// Transforms a vector of n row vectors, each of length m,
+    /// into an n x m matrix.
+    /// ```
+    /// use smartcore::linalg::naive::dense_matrix::DenseMatrix;
+    /// use crate::smartcore::linalg::BaseMatrix;
+    ///
+    /// let eye = DenseMatrix::from_row_vectors(vec![vec![1., 0., 0.], vec![0., 1., 0.], vec![0., 0., 1.]])
+    ///     .unwrap();
+    ///
+    /// assert_eq!(
+    ///     eye,
+    ///     DenseMatrix::from_2d_vec(&vec![
+    ///         vec![1.0, 0.0, 0.0],
+    ///         vec![0.0, 1.0, 0.0],
+    ///         vec![0.0, 0.0, 1.0],
+    ///     ])
+    /// );
+    /// ```
+    fn from_row_vectors(rows: Vec<Self::RowVector>) -> Option<Self> {
+        if let Some(first_row) = rows.first().cloned() {
+            return Some(rows.iter().skip(1).cloned().fold(
+                Self::from_row_vector(first_row),
+                |current_matrix, new_row| {
+                    current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row))
+                },
+            ));
+        } else {
+            None
+        }
+    }
+
     /// Transforms 1-d matrix of 1xM into a row vector.
     fn to_row_vector(self) -> Self::RowVector;
@@ -782,4 +836,50 @@ mod tests {
             "The second column was not extracted correctly"
         );
     }
+    mod matrix_from_csv {
+
+        use crate::linalg::naive::dense_matrix::DenseMatrix;
+        use crate::linalg::BaseMatrix;
+        use crate::readers::csv;
+        use crate::readers::io_testing;
+        use crate::readers::ReadingError;
+
+        #[test]
+        fn simple_read_default_csv() {
+            let test_csv_file = io_testing::TemporaryTextFile::new(
+                "'sepal.length','sepal.width','petal.length','petal.width'\n\
+                5.1,3.5,1.4,0.2\n\
+                4.9,3,1.4,0.2\n\
+                4.7,3.2,1.3,0.2",
+            );
+
+            assert_eq!(
+                DenseMatrix::<f64>::from_csv(
+                    test_csv_file
+                        .expect("Temporary file could not be written.")
+                        .path(),
+                    csv::CSVDefinition::default()
+                ),
+                Ok(DenseMatrix::from_2d_array(&[
+                    &[5.1, 3.5, 1.4, 0.2],
+                    &[4.9, 3.0, 1.4, 0.2],
+                    &[4.7, 3.2, 1.3, 0.2],
+                ]))
+            )
+        }
+
+        #[test]
+        fn non_existent_input_file() {
+            let potential_error =
+                DenseMatrix::<f64>::from_csv("/invalid/path", csv::CSVDefinition::default());
+            // The exact message is operating system dependent, therefore I only test that the correct type of
+            // error was returned.
+ assert_eq!( + potential_error.clone(), + Err(ReadingError::CouldNotReadFileSystem { + msg: String::from(potential_error.err().unwrap().message().unwrap()) + }) + ) + } + } } diff --git a/src/math/num.rs b/src/math/num.rs index c454b9d..433ad28 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -7,6 +7,7 @@ use rand::prelude::*; use std::fmt::{Debug, Display}; use std::iter::{Product, Sum}; use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; +use std::str::FromStr; /// Defines real number /// @@ -22,6 +23,7 @@ pub trait RealNumber: + SubAssign + MulAssign + DivAssign + + FromStr { /// Copy sign from `sign` - another real number fn copysign(self, sign: Self) -> Self; @@ -154,4 +156,14 @@ mod tests { assert_eq!(41.0.sigmoid(), 1.); assert_eq!((-41.0).sigmoid(), 0.); } + + #[test] + fn f32_from_string() { + assert_eq!(f32::from_str("1.111111").unwrap(), 1.111111) + } + + #[test] + fn f64_from_string() { + assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111) + } } diff --git a/src/readers/csv.rs b/src/readers/csv.rs new file mode 100644 index 0000000..e80d99b --- /dev/null +++ b/src/readers/csv.rs @@ -0,0 +1,487 @@ +//! This module contains utitilities to read-in matrices from csv files. +//! ``` +//! use smartcore::readers::csv; +//! use smartcore::linalg::naive::dense_matrix::DenseMatrix; +//! use crate::smartcore::linalg::BaseMatrix; +//! use std::fs; +//! +//! fs::write("identity.csv", "header\n1.0,0.0\n0.0,1.0"); +//! assert_eq!( +//! csv::matrix_from_csv_source::, DenseMatrix<_>>( +//! fs::File::open("identity.csv").unwrap(), +//! csv::CSVDefinition::default() +//! ) +//! .unwrap(), +//! DenseMatrix::from_row_vectors(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap() +//! ); +//! fs::remove_file("identity.csv"); +//! ``` +use crate::linalg::{BaseMatrix, BaseVector}; +use crate::math::num::RealNumber; +use crate::readers::ReadingError; +use std::io::Read; + +/// Define the structure of a CSV-file so that it can be read. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CSVDefinition<'a> { + /// How many rows does the header have? + n_rows_header: usize, + /// What seperates the fields in your csv-file? + field_seperator: &'a str, +} +impl<'a> Default for CSVDefinition<'a> { + fn default() -> Self { + Self { + n_rows_header: 1, + field_seperator: ",", + } + } +} + +/// Format definition for a single row in a csv file. +/// This is used internally to validate rows of the csv file and +/// be able to fail as early as possible. +#[derive(Clone, Debug, PartialEq, Eq)] +struct CSVRowFormat<'a> { + field_seperator: &'a str, + n_fields: usize, +} +impl<'a> CSVRowFormat<'a> { + fn from_csv_definition(definition: &'a CSVDefinition<'_>, n_fields: usize) -> Self { + CSVRowFormat { + field_seperator: definition.field_seperator, + n_fields, + } + } +} + +/// Detect the row format for the csv file from the first row. +fn detect_row_format<'a>( + csv_text: &'a str, + definition: &'a CSVDefinition<'_>, +) -> Result, ReadingError> { + let first_line = csv_text + .lines() + .nth(definition.n_rows_header) + .ok_or(ReadingError::NoRowsProvided)?; + + Ok(CSVRowFormat::from_csv_definition( + definition, + first_line.split(definition.field_seperator).count(), + )) +} + +/// Read in a matrix from a source that contains a csv file. 
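+/// The source can be anything that implements `std::io::Read`: an open
+/// `fs::File` as in the module example above, or an in-memory buffer such
+/// as the `TestingDataSource` used in the tests below.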
+pub fn matrix_from_csv_source( + source: impl Read, + definition: CSVDefinition<'_>, +) -> Result +where + T: RealNumber, + RowVector: BaseVector, + Matrix: BaseMatrix, +{ + let csv_text = read_string_from_source(source)?; + let rows = extract_row_vectors_from_csv_text::( + &csv_text, + &definition, + detect_row_format(&csv_text, &definition)?, + )?; + + match Matrix::from_row_vectors(rows) { + Some(matrix) => Ok(matrix), + None => Err(ReadingError::NoRowsProvided), + } +} + +/// Given a string containing the contents of a csv file, extract its value +/// into row-vectors. +fn extract_row_vectors_from_csv_text<'a, T, RowVector, Matrix>( + csv_text: &'a str, + definition: &'a CSVDefinition<'_>, + row_format: CSVRowFormat<'_>, +) -> Result, ReadingError> +where + T: RealNumber, + RowVector: BaseVector, + Matrix: BaseMatrix, +{ + csv_text + .lines() + .skip(definition.n_rows_header) + .enumerate() + .map(|(row_index, line)| { + enrich_reading_error( + extract_vector_from_csv_line(line, &row_format), + format!(", Row: {row_index}."), + ) + }) + .collect::, ReadingError>>() +} + +/// Read a string from source implementing `Read`. +fn read_string_from_source(mut source: impl Read) -> Result { + let mut string = String::new(); + source.read_to_string(&mut string)?; + Ok(string) +} + +/// Extract a vector from a single line of a csv file. +fn extract_vector_from_csv_line( + line: &str, + row_format: &CSVRowFormat<'_>, +) -> Result +where + T: RealNumber, + RowVector: BaseVector, +{ + validate_csv_row(line, row_format)?; + let extracted_fields = extract_fields_from_csv_row(line, row_format)?; + Ok(BaseVector::from_array(&extracted_fields[..])) +} + +/// Extract the fields from a string containing the row of a csv file. +fn extract_fields_from_csv_row( + row: &str, + row_format: &CSVRowFormat<'_>, +) -> Result, ReadingError> +where + T: RealNumber, +{ + row.split(row_format.field_seperator) + .enumerate() + .map(|(field_number, csv_field)| { + enrich_reading_error( + extract_value_from_csv_field(csv_field.trim()), + format!(" Column: {field_number}"), + ) + }) + .collect::, ReadingError>>() +} + +/// Ensure that a string containing a csv row conforms to a specified row format. +fn validate_csv_row<'a>(row: &'a str, row_format: &CSVRowFormat<'_>) -> Result<(), ReadingError> { + let actual_number_of_fields = row.split(row_format.field_seperator).count(); + if row_format.n_fields == actual_number_of_fields { + Ok(()) + } else { + Err(ReadingError::InvalidRow { + msg: format!( + "{} fields found but expected {}", + actual_number_of_fields, row_format.n_fields + ), + }) + } +} + +/// Add additional text to the message of an error. +/// In csv reading it is used to add the line-number / row-number +/// The error occured that is only known in the functions above. +fn enrich_reading_error( + result: Result, + additional_text: String, +) -> Result { + result.map_err(|error| ReadingError::InvalidField { + msg: format!( + "{}{additional_text}", + error.message().unwrap_or("Could not serialize value") + ), + }) +} + +/// Extract the value from a single csv field. +fn extract_value_from_csv_field(value_string: &str) -> Result +where + T: RealNumber, +{ + // By default, `FromStr::Err` does not implement `Debug`. + // Restricting it in the library leads to many breaking + // changes therefore I have to reconstruct my own, printable + // error as good as possible. 
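+    // `str::parse::<T>()` is available here because `RealNumber` now
+    // requires `FromStr` (see the addition to `src/math/num.rs` above);
+    // `.ok()` drops the opaque error so a printable one is built instead.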
+ match value_string.parse::().ok() { + Some(value) => Ok(value), + None => Err(ReadingError::InvalidField { + msg: format!("Value '{}' could not be read.", value_string,), + }), + } +} + +#[cfg(test)] +mod tests { + mod matrix_from_csv_source { + use super::super::{read_string_from_source, CSVDefinition, ReadingError}; + use crate::linalg::naive::dense_matrix::DenseMatrix; + use crate::readers::{csv::matrix_from_csv_source, io_testing}; + + #[test] + fn read_simple_string() { + assert_eq!( + read_string_from_source(io_testing::TestingDataSource::new("test-string")), + Ok(String::from("test-string")) + ) + } + #[test] + fn read_simple_csv() { + assert_eq!( + matrix_from_csv_source::, DenseMatrix<_>>( + io_testing::TestingDataSource::new( + "'sepal.length','sepal.width','petal.length','petal.width'\n\ + 5.1,3.5,1.4,0.2\n\ + 4.9,3.0,1.4,0.2\n\ + 4.7,3.2,1.3,0.2", + ), + CSVDefinition::default(), + ), + Ok(DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + ])) + ) + } + #[test] + fn read_csv_semicolon_as_seperator() { + assert_eq!( + matrix_from_csv_source::, DenseMatrix<_>>( + io_testing::TestingDataSource::new( + "'sepal.length';'sepal.width';'petal.length';'petal.width'\n\ + 'Length of sepals.';'Width of Sepals';'Length of petals';'Width of petals'\n\ + 5.1;3.5;1.4;0.2\n\ + 4.9;3.0;1.4;0.2\n\ + 4.7;3.2;1.3;0.2", + ), + CSVDefinition { + n_rows_header: 2, + field_seperator: ";" + }, + ), + Ok(DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + ])) + ) + } + #[test] + fn error_in_colum_1_row_1() { + assert_eq!( + matrix_from_csv_source::, DenseMatrix<_>>( + io_testing::TestingDataSource::new( + "'sepal.length','sepal.width','petal.length','petal.width'\n\ + 5.1,3.5,1.4,0.2\n\ + 4.9,invalid,1.4,0.2\n\ + 4.7,3.2,1.3,0.2", + ), + CSVDefinition::default(), + ), + Err(ReadingError::InvalidField { + msg: String::from("Value 'invalid' could not be read. 
Column: 1, Row: 1.") + }) + ) + } + #[test] + fn different_number_of_columns() { + assert_eq!( + matrix_from_csv_source::, DenseMatrix<_>>( + io_testing::TestingDataSource::new( + "'field_1','field_2'\n\ + 5.1,3.5\n\ + 4.9,3.0,1.4", + ), + CSVDefinition::default(), + ), + Err(ReadingError::InvalidField { + msg: String::from("3 fields found but expected 2, Row: 1.") + }) + ) + } + } + mod extract_row_vectors_from_csv_text { + use super::super::{extract_row_vectors_from_csv_text, CSVDefinition, CSVRowFormat}; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn read_default_csv() { + assert_eq!( + extract_row_vectors_from_csv_text::, DenseMatrix<_>>( + "column 1, column 2, column3\n1.0,2.0,3.0\n4.0,5.0,6.0", + &CSVDefinition::default(), + CSVRowFormat { + field_seperator: ",", + n_fields: 3, + }, + ), + Ok(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]) + ); + } + } + mod test_validate_csv_row { + use super::super::{validate_csv_row, CSVRowFormat, ReadingError}; + + #[test] + fn valid_row_with_comma() { + assert_eq!( + validate_csv_row( + "1.0, 2.0, 3.0", + &CSVRowFormat { + field_seperator: ",", + n_fields: 3, + }, + ), + Ok(()) + ) + } + #[test] + fn valid_row_with_semicolon() { + assert_eq!( + validate_csv_row( + "1.0; 2.0; 3.0; 4.0", + &CSVRowFormat { + field_seperator: ";", + n_fields: 4, + }, + ), + Ok(()) + ) + } + #[test] + fn invalid_number_of_fields() { + assert_eq!( + validate_csv_row( + "1.0; 2.0; 3.0; 4.0", + &CSVRowFormat { + field_seperator: ";", + n_fields: 3, + }, + ), + Err(ReadingError::InvalidRow { + msg: String::from("4 fields found but expected 3") + }) + ) + } + } + mod extract_fields_from_csv_row { + use super::super::{extract_fields_from_csv_row, CSVRowFormat}; + + #[test] + fn read_four_values_from_csv_row() { + assert_eq!( + extract_fields_from_csv_row( + "1.0; 2.0; 3.0; 4.0", + &CSVRowFormat { + field_seperator: ";", + n_fields: 4 + } + ), + Ok(vec![1.0, 2.0, 3.0, 4.0]) + ) + } + } + mod detect_row_format { + use super::super::{detect_row_format, CSVDefinition, CSVRowFormat, ReadingError}; + + #[test] + fn detect_2_fields_with_header() { + assert_eq!( + detect_row_format( + "header-1\nheader-2\n1.0,2.0", + &CSVDefinition { + n_rows_header: 2, + field_seperator: "," + } + ) + .expect("The row format should be detectable with this input."), + CSVRowFormat { + field_seperator: ",", + n_fields: 2 + } + ) + } + #[test] + fn detect_3_fields_no_header() { + assert_eq!( + detect_row_format( + "1.0,2.0,3.0", + &CSVDefinition { + n_rows_header: 0, + field_seperator: "," + } + ) + .expect("The row format should be detectable with this input."), + CSVRowFormat { + field_seperator: ",", + n_fields: 3 + } + ) + } + #[test] + fn detect_no_rows_provided() { + assert_eq!( + detect_row_format("header\n", &CSVDefinition::default()), + Err(ReadingError::NoRowsProvided) + ) + } + } + mod extract_value_from_csv_field { + use super::super::extract_value_from_csv_field; + use crate::readers::ReadingError; + + #[test] + fn deserialize_f64_from_floating_point() { + assert_eq!(extract_value_from_csv_field::("1.0"), Ok(1.0)) + } + #[test] + fn deserialize_f64_from_negative_floating_point() { + assert_eq!(extract_value_from_csv_field::("-1.0"), Ok(-1.0)) + } + #[test] + fn deserialize_f64_from_non_floating_point() { + assert_eq!(extract_value_from_csv_field::("1"), Ok(1.0)) + } + #[test] + fn cant_deserialize_f64_from_string() { + assert_eq!( + extract_value_from_csv_field::("Test"), + Err(ReadingError::InvalidField { + msg: String::from("Value 'Test' could not be read.") + 
},) + ) + } + #[test] + fn deserialize_f32_from_non_floating_point() { + assert_eq!(extract_value_from_csv_field::("12.0"), Ok(12.0)) + } + } + mod extract_vector_from_csv_line { + use super::super::{extract_vector_from_csv_line, CSVRowFormat, ReadingError}; + + #[test] + fn extract_five_floating_point_values() { + assert_eq!( + extract_vector_from_csv_line::>( + "-1.0,2.0,100.0,12", + &CSVRowFormat { + field_seperator: ",", + n_fields: 4 + } + ), + Ok(vec![-1.0, 2.0, 100.0, 12.0]) + ) + } + #[test] + fn cannot_extract_second_value() { + assert_eq!( + extract_vector_from_csv_line::>( + "-1.0,test,100.0,12", + &CSVRowFormat { + field_seperator: ",", + n_fields: 4 + } + ), + Err(ReadingError::InvalidField { + msg: String::from("Value 'test' could not be read. Column: 1") + }) + ) + } + } +} diff --git a/src/readers/error.rs b/src/readers/error.rs new file mode 100644 index 0000000..16e910d --- /dev/null +++ b/src/readers/error.rs @@ -0,0 +1,71 @@ +//! The module contains the errors that can happen in the `readers` folder and +//! utility functions. + +/// Error wrapping all failures that can happen during loading from file. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ReadingError { + /// The file could not be read from the file-system. + CouldNotReadFileSystem { + /// More details about the specific file-system error + /// that occured. + msg: String, + }, + /// No rows exists in the CSV-file. + NoRowsProvided, + /// A field in the csv file could not be read. + InvalidField { + /// More details about what field could not be + /// read and where it happened. + msg: String, + }, + /// A row from the csv is invalid. + InvalidRow { + /// More details about what row could not be read + /// and where it happened. + msg: String, + }, +} +impl From for ReadingError { + fn from(io_error: std::io::Error) -> Self { + Self::CouldNotReadFileSystem { + msg: io_error.to_string(), + } + } +} +impl ReadingError { + /// Extract the error-message from a `ReadingError`. + pub fn message(&self) -> Option<&str> { + match self { + ReadingError::InvalidField { msg } => Some(msg), + ReadingError::InvalidRow { msg } => Some(msg), + ReadingError::CouldNotReadFileSystem { msg } => Some(msg), + ReadingError::NoRowsProvided => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::ReadingError; + use std::io; + + #[test] + fn reading_error_from_io_error() { + let _parsed_reading_error: ReadingError = ReadingError::from(io::Error::new( + io::ErrorKind::AlreadyExists, + "File already exists .", + )); + } + #[test] + fn extract_message_from_reading_error() { + let error_content = "Path does not exist"; + assert_eq!( + ReadingError::CouldNotReadFileSystem { + msg: String::from(error_content) + } + .message() + .expect("This error should contain a mesage"), + String::from(error_content) + ) + } +} diff --git a/src/readers/io_testing.rs b/src/readers/io_testing.rs new file mode 100644 index 0000000..1376a5d --- /dev/null +++ b/src/readers/io_testing.rs @@ -0,0 +1,158 @@ +//! This module contains functionality to test IO. It has both functions that write +//! to the file-system for end-to-end tests, but also abstractions to avoid this by +//! reading from strings instead. +use rand::distributions::{Alphanumeric, DistString}; +use std::fs; +use std::io::Bytes; +use std::io::Read; +use std::io::{Chain, IoSliceMut, Take, Write}; + +/// Writing out a temporary csv file at a random location and cleaning +/// it up on `Drop`. 
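+/// A sketch of the intended usage:
+/// ```ignore
+/// let tmp = TemporaryTextFile::new("a,b\n1.0,2.0")?;
+/// let text = std::fs::read_to_string(tmp.path())?;
+/// // the file is removed again when `tmp` goes out of scope
+/// ```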
+#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct TemporaryTextFile { + random_path: String, +} +impl TemporaryTextFile { + pub fn new(contents: &str) -> std::io::Result { + let test_text_file = TemporaryTextFile { + random_path: Alphanumeric.sample_string(&mut rand::thread_rng(), 16), + }; + string_to_file(contents, &test_text_file.random_path)?; + Ok(test_text_file) + } + pub fn path(&self) -> &str { + &self.random_path + } +} +/// On `Drop` we cleanup the file-system by remove the file. +impl Drop for TemporaryTextFile { + fn drop(&mut self) { + fs::remove_file(self.path()) + .unwrap_or_else(|_| panic!("Could not clean up temporary file {}.", self.random_path)); + } +} +/// Write out a string to file. +pub(crate) fn string_to_file(string: &str, file_path: &str) -> std::io::Result<()> { + let mut csv_file = fs::File::create(file_path)?; + csv_file.write_all(string.as_bytes())?; + Ok(()) +} + +/// This is used an an alternative struct that implements `Read` so +/// that instead of reading from the file-system, we can test the same +/// functionality without any file-system interaction. +pub(crate) struct TestingDataSource { + text: String, +} +impl TestingDataSource { + pub(crate) fn new(text: &str) -> Self { + Self { + text: String::from(text), + } + } +} +/// This is the trait that also `file::File` implements, so by implementing +/// it for `TestingDataSource` we can test functionality that reads from +/// file in a more lightweight way. +impl Read for TestingDataSource { + fn read(&mut self, _buf: &mut [u8]) -> Result { + unimplemented!() + } + + fn read_vectored(&mut self, _bufs: &mut [IoSliceMut<'_>]) -> Result { + unimplemented!() + } + + fn read_to_end(&mut self, _buf: &mut Vec) -> Result { + unimplemented!() + } + fn read_to_string(&mut self, buf: &mut String) -> Result { + ::write_str(buf, &self.text).unwrap(); + Ok(0) + } + fn read_exact(&mut self, _buf: &mut [u8]) -> Result<(), std::io::Error> { + unimplemented!() + } + fn by_ref(&mut self) -> &mut Self + where + Self: Sized, + { + unimplemented!() + } + fn bytes(self) -> Bytes + where + Self: Sized, + { + unimplemented!() + } + fn chain(self, _next: R) -> Chain + where + Self: Sized, + { + unimplemented!() + } + fn take(self, _limit: u64) -> Take + where + Self: Sized, + { + unimplemented!() + } +} + +#[cfg(test)] +mod test { + use super::TestingDataSource; + use super::{string_to_file, TemporaryTextFile}; + use std::fs; + use std::io::Read; + use std::path; + #[test] + fn test_temporary_text_file() { + let path_of_temporary_file; + { + let hello_world_file = TemporaryTextFile::new("Hello World!") + .expect("`hello_world_file` should be able to write file."); + + path_of_temporary_file = String::from(hello_world_file.path()); + assert_eq!( + fs::read_to_string(&path_of_temporary_file).expect( + "This field should have been written by the `hello_world_file`-object." + ), + "Hello World!" + ) + } + // By now `hello_world_file` should have been dropped and the file + // should have been cleaned up. + assert!(!path::Path::new(&path_of_temporary_file).exists()) + } + + #[test] + fn test_string_to_file() { + let path_of_test_file = "test.file"; + let contents_of_test_file = "Hello IO-World"; + + string_to_file(contents_of_test_file, path_of_test_file) + .expect("The file should have been written out."); + assert_eq!( + fs::read_to_string(path_of_test_file) + .expect("The file we test for should have been written."), + String::from(contents_of_test_file) + ); + + // Cleanup the temporary file. 
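+        // (`string_to_file` has no `Drop` guard, unlike `TemporaryTextFile`,
+        // so this test removes the file itself.)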
+ fs::remove_file(path_of_test_file) + .expect("The test file should exist before and be removed here."); + } + + #[test] + fn read_from_testing_data_source() { + let mut test_buffer = String::new(); + let test_data_content = "Hello non-IO world!"; + + TestingDataSource::new(test_data_content) + .read_to_string(&mut test_buffer) + .expect("Text should have been written to buffer `test_buffer`."); + assert_eq!(test_buffer, test_data_content) + } +} diff --git a/src/readers/mod.rs b/src/readers/mod.rs new file mode 100644 index 0000000..6fc6f92 --- /dev/null +++ b/src/readers/mod.rs @@ -0,0 +1,11 @@ +/// Read in from csv. +pub mod csv; + +/// Error definition for readers. +mod error; +/// Utilities to help with testing functionality using IO. +/// Only meant for internal usage. +#[cfg(test)] +pub(crate) mod io_testing; + +pub use error::ReadingError; From 2510ca4e9d1a4e46518f3e36e2583801ec0ad16f Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Mon, 19 Sep 2022 10:44:01 -0400 Subject: [PATCH 11/19] fix: fix compilation warnings when running only with default features (#160) * fix: fix compilation warnings when running only with default features Co-authored-by: Luis Moreno --- Cargo.toml | 4 ++-- src/algorithm/neighbour/fastpair.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 069e223..aa649fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] datasets = [] -fp_bench = [] +fp_bench = ["itertools"] [dependencies] ndarray = { version = "0.15", optional = true } @@ -27,7 +27,7 @@ num = "0.4" rand = "0.8" rand_distr = "0.4" serde = { version = "1", features = ["derive"], optional = true } -itertools = "0.10.3" +itertools = { version = "0.10.3", optional = true } [target.'cfg(target_arch = "wasm32")'.dependencies] getrandom = { version = "0.2", features = ["js"] } diff --git a/src/algorithm/neighbour/fastpair.rs b/src/algorithm/neighbour/fastpair.rs index e14c2b3..bf3bca3 100644 --- a/src/algorithm/neighbour/fastpair.rs +++ b/src/algorithm/neighbour/fastpair.rs @@ -1,5 +1,3 @@ -#![allow(non_snake_case)] -use itertools::Itertools; /// /// # FastPair: Data-structure for the dynamic closest-pair problem. /// @@ -177,6 +175,7 @@ impl<'a, T: RealNumber, M: Matrix> FastPair<'a, T, M> { /// #[cfg(feature = "fp_bench")] pub fn closest_pair_brute(&self) -> PairwiseDistance { + use itertools::Itertools; let m = self.samples.shape().0; let mut closest_pair = PairwiseDistance { From 436da104d7d6e96c43b58c106ade158ac9c5d446 Mon Sep 17 00:00:00 2001 From: Lorenzo Date: Mon, 19 Sep 2022 18:00:17 +0100 Subject: [PATCH 12/19] Update LICENSE --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index 261eeb9..3cd5786 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2019-present at SmartCore developers (smartcorelib.org) Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
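The `TestingDataSource` above hand-implements `Read` with most methods left `unimplemented!()`. For comparison, the standard library already ships in-memory `Read` implementations (`&[u8]` and `std::io::Cursor`) that can stand in for a file in the same way; a minimal sketch, independent of smartcore:

```rust
use std::io::Read;

fn main() -> std::io::Result<()> {
    // A byte slice implements `Read`, so a string literal's bytes can feed
    // any function that accepts `impl Read` without touching the file-system.
    let mut source: &[u8] = b"1.0,2.0\n3.0,4.0";
    let mut text = String::new();
    source.read_to_string(&mut text)?;
    assert_eq!(text, "1.0,2.0\n3.0,4.0");
    Ok(())
}
```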
From 6a2e10452fe7eb43b8bad33481363f715fb222b7 Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Tue, 20 Sep 2022 06:21:02 -0400 Subject: [PATCH 13/19] Make rand_distr optional (#161) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index aa649fc..a0ad984 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ categories = ["science"] default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] -datasets = [] +datasets = ["rand_distr"] fp_bench = ["itertools"] [dependencies] @@ -25,7 +25,7 @@ nalgebra = { version = "0.31", optional = true } num-traits = "0.2" num = "0.4" rand = "0.8" -rand_distr = "0.4" +rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } itertools = { version = "0.10.3", optional = true } From c21e75276a5fd64652e07f541a0858da3c5c5690 Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Tue, 20 Sep 2022 06:29:54 -0400 Subject: [PATCH 14/19] =?UTF-8?q?feat:=20allocate=20first=20and=20then=20p?= =?UTF-8?q?roceed=20to=20create=20matrix=20from=20Vec=20of=20Ro=E2=80=A6?= =?UTF-8?q?=20(#159)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: allocate first and then proceed to create matrix from Vec of RowVectors --- src/linalg/mod.rs | 54 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 9f1697c..4fb3ebf 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -343,16 +343,19 @@ pub trait BaseMatrix: Clone + Debug { /// ]) /// ); fn from_row_vectors(rows: Vec) -> Option { - if let Some(first_row) = rows.first().cloned() { - return Some(rows.iter().skip(1).cloned().fold( - Self::from_row_vector(first_row), - |current_matrix, new_row| { - current_matrix.v_stack(&BaseMatrix::from_row_vector(new_row)) - }, - )); - } else { - None + if rows.is_empty() { + return None; } + let n = rows.len(); + let m = rows[0].len(); + + let mut result = Self::zeros(n, m); + + for (row_idx, row) in rows.into_iter().enumerate() { + result.set_row(row_idx, row); + } + + Some(result) } /// Transforms 1-d matrix of 1xM into a row vector. @@ -376,6 +379,13 @@ pub trait BaseMatrix: Clone + Debug { /// * `result` - receiver for the row fn copy_row_as_vec(&self, row: usize, result: &mut Vec); + /// Set row vector at row `row_idx`. 
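+    /// The default implementation copies element by element through `set`,
+    /// so matrix types with contiguous row storage may want to override it.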
+ fn set_row(&mut self, row_idx: usize, row: Self::RowVector) { + for (col_idx, val) in row.to_vec().into_iter().enumerate() { + self.set(row_idx, col_idx, val); + } + } + /// Get a vector with elements of the `col`'th column /// * `col` - column number fn get_col_as_vec(&self, col: usize) -> Vec; @@ -836,6 +846,32 @@ mod tests { "The second column was not extracted correctly" ); } + + #[test] + fn test_from_row_vectors_simple() { + let eye = DenseMatrix::from_row_vectors(vec![ + vec![1., 0., 0.], + vec![0., 1., 0.], + vec![0., 0., 1.], + ]) + .unwrap(); + assert_eq!( + eye, + DenseMatrix::from_2d_vec(&vec![ + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + ]) + ); + } + + #[test] + fn test_from_row_vectors_large() { + let eye = DenseMatrix::from_row_vectors(vec![vec![4.25; 5000]; 5000]).unwrap(); + + assert_eq!(eye.shape(), (5000, 5000)); + assert_eq!(eye.get_row(5), vec![4.25; 5000]); + } mod matrix_from_csv { use crate::linalg::naive::dense_matrix::DenseMatrix; From 69d8be35de5bd5e67ee75f03bb6071eeaeee7a11 Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Tue, 20 Sep 2022 12:12:09 -0400 Subject: [PATCH 15/19] Provide better output in flaky tests (#163) --- src/svm/svc.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 74f31c7..87fb743 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -776,8 +776,13 @@ mod tests { ) .and_then(|lr| lr.predict(&x)) .unwrap(); + let acc = accuracy(&y_hat, &y); - assert!(accuracy(&y_hat, &y) >= 0.9); + assert!( + acc >= 0.9, + "accuracy ({}) is not larger or equal to 0.9", + acc + ); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] @@ -860,7 +865,13 @@ mod tests { .and_then(|lr| lr.predict(&x)) .unwrap(); - assert!(accuracy(&y_hat, &y) >= 0.9); + let acc = accuracy(&y_hat, &y); + + assert!( + acc >= 0.9, + "accuracy ({}) is not larger or equal to 0.9", + acc + ); } #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] From 48514d1b1548d1b432c9cee002fada4e90382d9d Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 21 Sep 2022 12:34:21 -0700 Subject: [PATCH 16/19] Complete grid search params (#166) * grid search draft * hyperparam search for linear estimators * grid search for ensembles * support grid search for more algos * grid search for unsupervised algos * minor cleanup --- src/cluster/dbscan.rs | 120 +++++++++++ src/cluster/kmeans.rs | 93 +++++++++ src/decomposition/pca.rs | 98 +++++++++ src/decomposition/svd.rs | 68 +++++++ src/ensemble/random_forest_classifier.rs | 243 +++++++++++++++++++++++ src/ensemble/random_forest_regressor.rs | 208 +++++++++++++++++++ src/linear/lasso.rs | 31 ++- src/model_selection/hyper_tuning.rs | 2 +- src/model_selection/mod.rs | 1 - src/naive_bayes/bernoulli.rs | 96 +++++++++ src/naive_bayes/categorical.rs | 68 +++++++ src/naive_bayes/gaussian.rs | 76 ++++++- src/naive_bayes/multinomial.rs | 84 ++++++++ src/svm/mod.rs | 10 +- src/svm/svc.rs | 121 +++++++++++ src/svm/svr.rs | 121 +++++++++++ src/tree/decision_tree_classifier.rs | 161 +++++++++++++++ src/tree/decision_tree_regressor.rs | 137 +++++++++++++ 18 files changed, 1713 insertions(+), 25 deletions(-) diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index 7f2baef..621d017 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -109,6 +109,103 @@ impl, T>> DBSCANParameters { } } +/// DBSCAN grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, 
Deserialize))] +#[derive(Debug, Clone)] +pub struct DBSCANSearchParameters, T>> { + /// a function that defines a distance between each pair of point in training data. + /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. + /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. + pub distance: Vec, + /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. + pub min_samples: Vec, + /// The maximum distance between two samples for one to be considered as in the neighborhood of the other. + pub eps: Vec, + /// KNN algorithm to use. + pub algorithm: Vec, +} + +/// DBSCAN grid search iterator +pub struct DBSCANSearchParametersIterator, T>> { + dbscan_search_parameters: DBSCANSearchParameters, + current_distance: usize, + current_min_samples: usize, + current_eps: usize, + current_algorithm: usize, +} + +impl, T>> IntoIterator for DBSCANSearchParameters { + type Item = DBSCANParameters; + type IntoIter = DBSCANSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DBSCANSearchParametersIterator { + dbscan_search_parameters: self, + current_distance: 0, + current_min_samples: 0, + current_eps: 0, + current_algorithm: 0, + } + } +} + +impl, T>> Iterator for DBSCANSearchParametersIterator { + type Item = DBSCANParameters; + + fn next(&mut self) -> Option { + if self.current_distance == self.dbscan_search_parameters.distance.len() + && self.current_min_samples == self.dbscan_search_parameters.min_samples.len() + && self.current_eps == self.dbscan_search_parameters.eps.len() + && self.current_algorithm == self.dbscan_search_parameters.algorithm.len() + { + return None; + } + + let next = DBSCANParameters { + distance: self.dbscan_search_parameters.distance[self.current_distance].clone(), + min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples], + eps: self.dbscan_search_parameters.eps[self.current_eps], + algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(), + }; + + if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() { + self.current_distance += 1; + } else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() { + self.current_distance = 0; + self.current_min_samples += 1; + } else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() { + self.current_distance = 0; + self.current_min_samples = 0; + self.current_eps += 1; + } else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() { + self.current_distance = 0; + self.current_min_samples = 0; + self.current_eps = 0; + self.current_algorithm += 1; + } else { + self.current_distance += 1; + self.current_min_samples += 1; + self.current_eps += 1; + self.current_algorithm += 1; + } + + Some(next) + } +} + +impl Default for DBSCANSearchParameters { + fn default() -> Self { + let default_params = DBSCANParameters::default(); + + DBSCANSearchParameters { + distance: vec![default_params.distance], + min_samples: vec![default_params.min_samples], + eps: vec![default_params.eps], + algorithm: vec![default_params.algorithm], + } + } +} + impl, T>> PartialEq for DBSCAN { fn eq(&self, other: &Self) -> bool { self.cluster_labels.len() == other.cluster_labels.len() @@ -268,6 +365,29 @@ mod tests { #[cfg(feature = "serde")] use crate::math::distance::euclidian::Euclidian; + #[test] + fn search_parameters() { + let parameters = DBSCANSearchParameters { + min_samples: 
vec![10, 100], + eps: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 10); + assert_eq!(next.eps, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 100); + assert_eq!(next.eps, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 10); + assert_eq!(next.eps, 2.); + let next = iter.next().unwrap(); + assert_eq!(next.min_samples, 100); + assert_eq!(next.eps, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_dbscan() { diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 05af680..8ecbb2e 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -132,6 +132,76 @@ impl Default for KMeansParameters { } } +/// KMeans grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct KMeansSearchParameters { + /// Number of clusters. + pub k: Vec, + /// Maximum number of iterations of the k-means algorithm for a single run. + pub max_iter: Vec, +} + +/// KMeans grid search iterator +pub struct KMeansSearchParametersIterator { + kmeans_search_parameters: KMeansSearchParameters, + current_k: usize, + current_max_iter: usize, +} + +impl IntoIterator for KMeansSearchParameters { + type Item = KMeansParameters; + type IntoIter = KMeansSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + KMeansSearchParametersIterator { + kmeans_search_parameters: self, + current_k: 0, + current_max_iter: 0, + } + } +} + +impl Iterator for KMeansSearchParametersIterator { + type Item = KMeansParameters; + + fn next(&mut self) -> Option { + if self.current_k == self.kmeans_search_parameters.k.len() + && self.current_max_iter == self.kmeans_search_parameters.max_iter.len() + { + return None; + } + + let next = KMeansParameters { + k: self.kmeans_search_parameters.k[self.current_k], + max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter], + }; + + if self.current_k + 1 < self.kmeans_search_parameters.k.len() { + self.current_k += 1; + } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() { + self.current_k = 0; + self.current_max_iter += 1; + } else { + self.current_k += 1; + self.current_max_iter += 1; + } + + Some(next) + } +} + +impl Default for KMeansSearchParameters { + fn default() -> Self { + let default_params = KMeansParameters::default(); + + KMeansSearchParameters { + k: vec![default_params.k], + max_iter: vec![default_params.max_iter], + } + } +} + impl> UnsupervisedEstimator for KMeans { fn fit(x: &M, parameters: KMeansParameters) -> Result { KMeans::fit(x, parameters) @@ -313,6 +383,29 @@ mod tests { ); } + #[test] + fn search_parameters() { + let parameters = KMeansSearchParameters { + k: vec![2, 4], + max_iter: vec![10, 100], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.k, 2); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.k, 4); + assert_eq!(next.max_iter, 10); + let next = iter.next().unwrap(); + assert_eq!(next.k, 2); + assert_eq!(next.max_iter, 100); + let next = iter.next().unwrap(); + assert_eq!(next.k, 4); + assert_eq!(next.max_iter, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { diff --git a/src/decomposition/pca.rs 
b/src/decomposition/pca.rs index 9aebae2..296926a 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -116,6 +116,81 @@ impl Default for PCAParameters { } } +/// PCA grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct PCASearchParameters { + /// Number of components to keep. + pub n_components: Vec, + /// By default, covariance matrix is used to compute principal components. + /// Enable this flag if you want to use correlation matrix instead. + pub use_correlation_matrix: Vec, +} + +/// PCA grid search iterator +pub struct PCASearchParametersIterator { + pca_search_parameters: PCASearchParameters, + current_k: usize, + current_use_correlation_matrix: usize, +} + +impl IntoIterator for PCASearchParameters { + type Item = PCAParameters; + type IntoIter = PCASearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + PCASearchParametersIterator { + pca_search_parameters: self, + current_k: 0, + current_use_correlation_matrix: 0, + } + } +} + +impl Iterator for PCASearchParametersIterator { + type Item = PCAParameters; + + fn next(&mut self) -> Option { + if self.current_k == self.pca_search_parameters.n_components.len() + && self.current_use_correlation_matrix + == self.pca_search_parameters.use_correlation_matrix.len() + { + return None; + } + + let next = PCAParameters { + n_components: self.pca_search_parameters.n_components[self.current_k], + use_correlation_matrix: self.pca_search_parameters.use_correlation_matrix + [self.current_use_correlation_matrix], + }; + + if self.current_k + 1 < self.pca_search_parameters.n_components.len() { + self.current_k += 1; + } else if self.current_use_correlation_matrix + 1 + < self.pca_search_parameters.use_correlation_matrix.len() + { + self.current_k = 0; + self.current_use_correlation_matrix += 1; + } else { + self.current_k += 1; + self.current_use_correlation_matrix += 1; + } + + Some(next) + } +} + +impl Default for PCASearchParameters { + fn default() -> Self { + let default_params = PCAParameters::default(); + + PCASearchParameters { + n_components: vec![default_params.n_components], + use_correlation_matrix: vec![default_params.use_correlation_matrix], + } + } +} + impl> UnsupervisedEstimator for PCA { fn fit(x: &M, parameters: PCAParameters) -> Result { PCA::fit(x, parameters) @@ -271,6 +346,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[test] + fn search_parameters() { + let parameters = PCASearchParameters { + n_components: vec![2, 4], + use_correlation_matrix: vec![true, false], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 2); + assert_eq!(next.use_correlation_matrix, true); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 4); + assert_eq!(next.use_correlation_matrix, true); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 2); + assert_eq!(next.use_correlation_matrix, false); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 4); + assert_eq!(next.use_correlation_matrix, false); + assert!(iter.next().is_none()); + } + fn us_arrests_data() -> DenseMatrix { DenseMatrix::from_2d_array(&[ &[13.2, 236.0, 58.0, 21.2], diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index 3807760..3001fd9 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -90,6 +90,60 @@ impl SVDParameters { } } +/// SVD grid search parameters +#[cfg_attr(feature = 
"serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVDSearchParameters { + /// Number of components to keep. + pub n_components: Vec, +} + +/// SVD grid search iterator +pub struct SVDSearchParametersIterator { + svd_search_parameters: SVDSearchParameters, + current_n_components: usize, +} + +impl IntoIterator for SVDSearchParameters { + type Item = SVDParameters; + type IntoIter = SVDSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVDSearchParametersIterator { + svd_search_parameters: self, + current_n_components: 0, + } + } +} + +impl Iterator for SVDSearchParametersIterator { + type Item = SVDParameters; + + fn next(&mut self) -> Option { + if self.current_n_components == self.svd_search_parameters.n_components.len() { + return None; + } + + let next = SVDParameters { + n_components: self.svd_search_parameters.n_components[self.current_n_components], + }; + + self.current_n_components += 1; + + Some(next) + } +} + +impl Default for SVDSearchParameters { + fn default() -> Self { + let default_params = SVDParameters::default(); + + SVDSearchParameters { + n_components: vec![default_params.n_components], + } + } +} + impl> UnsupervisedEstimator for SVD { fn fit(x: &M, parameters: SVDParameters) -> Result { SVD::fit(x, parameters) @@ -153,6 +207,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; + #[test] + fn search_parameters() { + let parameters = SVDSearchParameters { + n_components: vec![10, 100], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 10); + let next = iter.next().unwrap(); + assert_eq!(next.n_components, 100); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svd_decompose() { diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 247b502..a4d6e75 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -193,6 +193,226 @@ impl> Predictor for RandomForestCla } } +/// RandomForestClassifier grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct RandomForestClassifierSearchParameters { + /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub criterion: Vec, + /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_split: Vec, + /// The number of trees in the forest. + pub n_trees: Vec, + /// Number of random sample of predictors to use as split candidates. + pub m: Vec>, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: Vec, + /// Seed used for bootstrap sampling and feature selection for each tree. 
+ pub seed: Vec, +} + +/// RandomForestClassifier grid search iterator +pub struct RandomForestClassifierSearchParametersIterator { + random_forest_classifier_search_parameters: RandomForestClassifierSearchParameters, + current_criterion: usize, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, + current_n_trees: usize, + current_m: usize, + current_keep_samples: usize, + current_seed: usize, +} + +impl IntoIterator for RandomForestClassifierSearchParameters { + type Item = RandomForestClassifierParameters; + type IntoIter = RandomForestClassifierSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + RandomForestClassifierSearchParametersIterator { + random_forest_classifier_search_parameters: self, + current_criterion: 0, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + current_n_trees: 0, + current_m: 0, + current_keep_samples: 0, + current_seed: 0, + } + } +} + +impl Iterator for RandomForestClassifierSearchParametersIterator { + type Item = RandomForestClassifierParameters; + + fn next(&mut self) -> Option { + if self.current_criterion + == self + .random_forest_classifier_search_parameters + .criterion + .len() + && self.current_max_depth + == self + .random_forest_classifier_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .random_forest_classifier_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .random_forest_classifier_search_parameters + .min_samples_split + .len() + && self.current_n_trees + == self + .random_forest_classifier_search_parameters + .n_trees + .len() + && self.current_m == self.random_forest_classifier_search_parameters.m.len() + && self.current_keep_samples + == self + .random_forest_classifier_search_parameters + .keep_samples + .len() + && self.current_seed == self.random_forest_classifier_search_parameters.seed.len() + { + return None; + } + + let next = RandomForestClassifierParameters { + criterion: self.random_forest_classifier_search_parameters.criterion + [self.current_criterion] + .clone(), + max_depth: self.random_forest_classifier_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .random_forest_classifier_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .random_forest_classifier_search_parameters + .min_samples_split[self.current_min_samples_split], + n_trees: self.random_forest_classifier_search_parameters.n_trees[self.current_n_trees], + m: self.random_forest_classifier_search_parameters.m[self.current_m], + keep_samples: self.random_forest_classifier_search_parameters.keep_samples + [self.current_keep_samples], + seed: self.random_forest_classifier_search_parameters.seed[self.current_seed], + }; + + if self.current_criterion + 1 + < self + .random_forest_classifier_search_parameters + .criterion + .len() + { + self.current_criterion += 1; + } else if self.current_max_depth + 1 + < self + .random_forest_classifier_search_parameters + .max_depth + .len() + { + self.current_criterion = 0; + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .random_forest_classifier_search_parameters + .min_samples_leaf + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .random_forest_classifier_search_parameters + .min_samples_split + .len() + { + 
self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else if self.current_n_trees + 1 + < self + .random_forest_classifier_search_parameters + .n_trees + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees += 1; + } else if self.current_m + 1 < self.random_forest_classifier_search_parameters.m.len() { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m += 1; + } else if self.current_keep_samples + 1 + < self + .random_forest_classifier_search_parameters + .keep_samples + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples += 1; + } else if self.current_seed + 1 < self.random_forest_classifier_search_parameters.seed.len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples = 0; + self.current_seed += 1; + } else { + self.current_criterion += 1; + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + self.current_n_trees += 1; + self.current_m += 1; + self.current_keep_samples += 1; + self.current_seed += 1; + } + + Some(next) + } +} + +impl Default for RandomForestClassifierSearchParameters { + fn default() -> Self { + let default_params = RandomForestClassifierParameters::default(); + + RandomForestClassifierSearchParameters { + criterion: vec![default_params.criterion], + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + n_trees: vec![default_params.n_trees], + m: vec![default_params.m], + keep_samples: vec![default_params.keep_samples], + seed: vec![default_params.seed], + } + } +} + impl RandomForestClassifier { /// Build a forest of trees from the training set. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. 
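// [Editor's note] An illustrative sketch, not part of the patch: the index fields in the
// odometer-style iterator above advance first-field-fastest, so `into_iter()` yields exactly
// one `RandomForestClassifierParameters` per element of the Cartesian product of the candidate
// vectors. The values below are hypothetical; every field left unlisted keeps its single
// default candidate.
let params = RandomForestClassifierSearchParameters {
    n_trees: vec![10, 50, 100],
    max_depth: vec![None, Some(8)],
    ..Default::default()
};
assert_eq!(params.into_iter().count(), 6); // 3 n_trees values * 2 max_depth values = 6 parameter sets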
@@ -346,6 +566,29 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::*; + #[test] + fn search_parameters() { + let parameters = RandomForestClassifierSearchParameters { + n_trees: vec![10, 100], + m: vec![None, Some(1)], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, Some(1)); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, Some(1)); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_predict_iris() { diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 08a7dcc..ec78137 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -176,6 +176,191 @@ impl> Predictor for RandomForestReg } } +/// RandomForestRegressor grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct RandomForestRegressorSearchParameters { + /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_split: Vec, + /// The number of trees in the forest. + pub n_trees: Vec, + /// Number of random sample of predictors to use as split candidates. + pub m: Vec>, + /// Whether to keep samples used for tree generation. This is required for OOB prediction. + pub keep_samples: Vec, + /// Seed used for bootstrap sampling and feature selection for each tree. 
+ pub seed: Vec, +} + +/// RandomForestRegressor grid search iterator +pub struct RandomForestRegressorSearchParametersIterator { + random_forest_regressor_search_parameters: RandomForestRegressorSearchParameters, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, + current_n_trees: usize, + current_m: usize, + current_keep_samples: usize, + current_seed: usize, +} + +impl IntoIterator for RandomForestRegressorSearchParameters { + type Item = RandomForestRegressorParameters; + type IntoIter = RandomForestRegressorSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + RandomForestRegressorSearchParametersIterator { + random_forest_regressor_search_parameters: self, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + current_n_trees: 0, + current_m: 0, + current_keep_samples: 0, + current_seed: 0, + } + } +} + +impl Iterator for RandomForestRegressorSearchParametersIterator { + type Item = RandomForestRegressorParameters; + + fn next(&mut self) -> Option { + if self.current_max_depth + == self + .random_forest_regressor_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .random_forest_regressor_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .random_forest_regressor_search_parameters + .min_samples_split + .len() + && self.current_n_trees == self.random_forest_regressor_search_parameters.n_trees.len() + && self.current_m == self.random_forest_regressor_search_parameters.m.len() + && self.current_keep_samples + == self + .random_forest_regressor_search_parameters + .keep_samples + .len() + && self.current_seed == self.random_forest_regressor_search_parameters.seed.len() + { + return None; + } + + let next = RandomForestRegressorParameters { + max_depth: self.random_forest_regressor_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .random_forest_regressor_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .random_forest_regressor_search_parameters + .min_samples_split[self.current_min_samples_split], + n_trees: self.random_forest_regressor_search_parameters.n_trees[self.current_n_trees], + m: self.random_forest_regressor_search_parameters.m[self.current_m], + keep_samples: self.random_forest_regressor_search_parameters.keep_samples + [self.current_keep_samples], + seed: self.random_forest_regressor_search_parameters.seed[self.current_seed], + }; + + if self.current_max_depth + 1 + < self + .random_forest_regressor_search_parameters + .max_depth + .len() + { + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .random_forest_regressor_search_parameters + .min_samples_leaf + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .random_forest_regressor_search_parameters + .min_samples_split + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else if self.current_n_trees + 1 + < self.random_forest_regressor_search_parameters.n_trees.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees += 1; + } else if self.current_m + 1 < self.random_forest_regressor_search_parameters.m.len() { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + 
self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m += 1; + } else if self.current_keep_samples + 1 + < self + .random_forest_regressor_search_parameters + .keep_samples + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples += 1; + } else if self.current_seed + 1 < self.random_forest_regressor_search_parameters.seed.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_n_trees = 0; + self.current_m = 0; + self.current_keep_samples = 0; + self.current_seed += 1; + } else { + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + self.current_n_trees += 1; + self.current_m += 1; + self.current_keep_samples += 1; + self.current_seed += 1; + } + + Some(next) + } +} + +impl Default for RandomForestRegressorSearchParameters { + fn default() -> Self { + let default_params = RandomForestRegressorParameters::default(); + + RandomForestRegressorSearchParameters { + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + n_trees: vec![default_params.n_trees], + m: vec![default_params.m], + keep_samples: vec![default_params.keep_samples], + seed: vec![default_params.seed], + } + } +} + impl RandomForestRegressor { /// Build a forest of trees from the training set. /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. @@ -302,6 +487,29 @@ mod tests { use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::mean_absolute_error; + #[test] + fn search_parameters() { + let parameters = RandomForestRegressorSearchParameters { + n_trees: vec![10, 100], + m: vec![None, Some(1)], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, None); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 10); + assert_eq!(next.m, Some(1)); + let next = iter.next().unwrap(); + assert_eq!(next.n_trees, 100); + assert_eq!(next.m, Some(1)); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() { diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index 7e80a8b..aae7e50 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -129,7 +129,7 @@ pub struct LassoSearchParameters { /// Lasso grid search iterator pub struct LassoSearchParametersIterator { - lasso_regression_search_parameters: LassoSearchParameters, + lasso_search_parameters: LassoSearchParameters, current_alpha: usize, current_normalize: usize, current_tol: usize, @@ -142,7 +142,7 @@ impl IntoIterator for LassoSearchParameters { fn into_iter(self) -> Self::IntoIter { LassoSearchParametersIterator { - lasso_regression_search_parameters: self, + lasso_search_parameters: self, current_alpha: 0, current_normalize: 0, current_tol: 0, @@ -155,34 +155,31 @@ impl Iterator for LassoSearchParametersIterator { type Item = LassoParameters; fn next(&mut self) -> Option { - if self.current_alpha == self.lasso_regression_search_parameters.alpha.len() - && self.current_normalize == self.lasso_regression_search_parameters.normalize.len() - 
&& self.current_tol == self.lasso_regression_search_parameters.tol.len() - && self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len() + if self.current_alpha == self.lasso_search_parameters.alpha.len() + && self.current_normalize == self.lasso_search_parameters.normalize.len() + && self.current_tol == self.lasso_search_parameters.tol.len() + && self.current_max_iter == self.lasso_search_parameters.max_iter.len() { return None; } let next = LassoParameters { - alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha], - normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize], - tol: self.lasso_regression_search_parameters.tol[self.current_tol], - max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter], + alpha: self.lasso_search_parameters.alpha[self.current_alpha], + normalize: self.lasso_search_parameters.normalize[self.current_normalize], + tol: self.lasso_search_parameters.tol[self.current_tol], + max_iter: self.lasso_search_parameters.max_iter[self.current_max_iter], }; - if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() { + if self.current_alpha + 1 < self.lasso_search_parameters.alpha.len() { self.current_alpha += 1; - } else if self.current_normalize + 1 - < self.lasso_regression_search_parameters.normalize.len() - { + } else if self.current_normalize + 1 < self.lasso_search_parameters.normalize.len() { self.current_alpha = 0; self.current_normalize += 1; - } else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() { + } else if self.current_tol + 1 < self.lasso_search_parameters.tol.len() { self.current_alpha = 0; self.current_normalize = 0; self.current_tol += 1; - } else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len() - { + } else if self.current_max_iter + 1 < self.lasso_search_parameters.max_iter.len() { self.current_alpha = 0; self.current_normalize = 0; self.current_tol = 0; diff --git a/src/model_selection/hyper_tuning.rs b/src/model_selection/hyper_tuning.rs index 3093fbd..cb69da1 100644 --- a/src/model_selection/hyper_tuning.rs +++ b/src/model_selection/hyper_tuning.rs @@ -114,4 +114,4 @@ mod tests { assert!([0., 1.].contains(&results.parameters.alpha)); } -} \ No newline at end of file +} diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs index 68f0635..6f737d6 100644 --- a/src/model_selection/mod.rs +++ b/src/model_selection/mod.rs @@ -281,7 +281,6 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::*; - use crate::metrics::{accuracy, mean_absolute_error}; use crate::model_selection::kfold::KFold; use crate::neighbors::knn_regressor::KNNRegressor; diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 95c4d36..29c6c84 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -150,6 +150,88 @@ impl Default for BernoulliNBParameters { } } +/// BernoulliNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct BernoulliNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). + pub alpha: Vec, + /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data + pub priors: Vec>>, + /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors. 
+ pub binarize: Vec>, +} + +/// BernoulliNB grid search iterator +pub struct BernoulliNBSearchParametersIterator { + bernoulli_nb_search_parameters: BernoulliNBSearchParameters, + current_alpha: usize, + current_priors: usize, + current_binarize: usize, +} + +impl IntoIterator for BernoulliNBSearchParameters { + type Item = BernoulliNBParameters; + type IntoIter = BernoulliNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + BernoulliNBSearchParametersIterator { + bernoulli_nb_search_parameters: self, + current_alpha: 0, + current_priors: 0, + current_binarize: 0, + } + } +} + +impl Iterator for BernoulliNBSearchParametersIterator { + type Item = BernoulliNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.bernoulli_nb_search_parameters.alpha.len() + && self.current_priors == self.bernoulli_nb_search_parameters.priors.len() + && self.current_binarize == self.bernoulli_nb_search_parameters.binarize.len() + { + return None; + } + + let next = BernoulliNBParameters { + alpha: self.bernoulli_nb_search_parameters.alpha[self.current_alpha], + priors: self.bernoulli_nb_search_parameters.priors[self.current_priors].clone(), + binarize: self.bernoulli_nb_search_parameters.binarize[self.current_binarize], + }; + + if self.current_alpha + 1 < self.bernoulli_nb_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_priors + 1 < self.bernoulli_nb_search_parameters.priors.len() { + self.current_alpha = 0; + self.current_priors += 1; + } else if self.current_binarize + 1 < self.bernoulli_nb_search_parameters.binarize.len() { + self.current_alpha = 0; + self.current_priors = 0; + self.current_binarize += 1; + } else { + self.current_alpha += 1; + self.current_priors += 1; + self.current_binarize += 1; + } + + Some(next) + } +} + +impl Default for BernoulliNBSearchParameters { + fn default() -> Self { + let default_params = BernoulliNBParameters::default(); + + BernoulliNBSearchParameters { + alpha: vec![default_params.alpha], + priors: vec![default_params.priors], + binarize: vec![default_params.binarize], + } + } +} + impl BernoulliNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. @@ -347,6 +429,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = BernoulliNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_bernoulli_naive_bayes() { diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 8706702..7855688 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -261,6 +261,60 @@ impl Default for CategoricalNBParameters { } } +/// CategoricalNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct CategoricalNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). 
+ pub alpha: Vec, +} + +/// CategoricalNB grid search iterator +pub struct CategoricalNBSearchParametersIterator { + categorical_nb_search_parameters: CategoricalNBSearchParameters, + current_alpha: usize, +} + +impl IntoIterator for CategoricalNBSearchParameters { + type Item = CategoricalNBParameters; + type IntoIter = CategoricalNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + CategoricalNBSearchParametersIterator { + categorical_nb_search_parameters: self, + current_alpha: 0, + } + } +} + +impl Iterator for CategoricalNBSearchParametersIterator { + type Item = CategoricalNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.categorical_nb_search_parameters.alpha.len() { + return None; + } + + let next = CategoricalNBParameters { + alpha: self.categorical_nb_search_parameters.alpha[self.current_alpha], + }; + + self.current_alpha += 1; + + Some(next) + } +} + +impl Default for CategoricalNBSearchParameters { + fn default() -> Self { + let default_params = CategoricalNBParameters::default(); + + CategoricalNBSearchParameters { + alpha: vec![default_params.alpha], + } + } +} + /// CategoricalNB implements the categorical naive Bayes algorithm for categorically distributed data. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, PartialEq)] @@ -351,6 +405,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = CategoricalNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_categorical_naive_bayes() { diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index bd23919..24bbdd3 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -76,7 +76,7 @@ impl> NBDistribution for GaussianNBDistributio /// `GaussianNB` parameters. Use `Default::default()` for default values. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Default, Clone)] +#[derive(Debug, Clone)] pub struct GaussianNBParameters { /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, @@ -90,6 +90,66 @@ impl GaussianNBParameters { } } +impl Default for GaussianNBParameters { + fn default() -> Self { + Self { priors: None } + } +} + +/// GaussianNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct GaussianNBSearchParameters { + /// Prior probabilities of the classes. 
If specified the priors are not adjusted according to the data + pub priors: Vec>>, +} + +/// GaussianNB grid search iterator +pub struct GaussianNBSearchParametersIterator { + gaussian_nb_search_parameters: GaussianNBSearchParameters, + current_priors: usize, +} + +impl IntoIterator for GaussianNBSearchParameters { + type Item = GaussianNBParameters; + type IntoIter = GaussianNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + GaussianNBSearchParametersIterator { + gaussian_nb_search_parameters: self, + current_priors: 0, + } + } +} + +impl Iterator for GaussianNBSearchParametersIterator { + type Item = GaussianNBParameters; + + fn next(&mut self) -> Option { + if self.current_priors == self.gaussian_nb_search_parameters.priors.len() { + return None; + } + + let next = GaussianNBParameters { + priors: self.gaussian_nb_search_parameters.priors[self.current_priors].clone(), + }; + + self.current_priors += 1; + + Some(next) + } +} + +impl Default for GaussianNBSearchParameters { + fn default() -> Self { + let default_params = GaussianNBParameters::default(); + + GaussianNBSearchParameters { + priors: vec![default_params.priors], + } + } +} + impl GaussianNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. @@ -260,6 +320,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = GaussianNBSearchParameters { + priors: vec![Some(vec![1.]), Some(vec![2.])], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.priors, Some(vec![1.])); + let next = iter.next().unwrap(); + assert_eq!(next.priors, Some(vec![2.])); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_gaussian_naive_bayes() { diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index f42b99e..6e846c1 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -114,6 +114,76 @@ impl Default for MultinomialNBParameters { } } +/// MultinomialNB grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct MultinomialNBSearchParameters { + /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). + pub alpha: Vec, + /// Prior probabilities of the classes. 
If specified the priors are not adjusted according to the data + pub priors: Vec>>, +} + +/// MultinomialNB grid search iterator +pub struct MultinomialNBSearchParametersIterator { + multinomial_nb_search_parameters: MultinomialNBSearchParameters, + current_alpha: usize, + current_priors: usize, +} + +impl IntoIterator for MultinomialNBSearchParameters { + type Item = MultinomialNBParameters; + type IntoIter = MultinomialNBSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + MultinomialNBSearchParametersIterator { + multinomial_nb_search_parameters: self, + current_alpha: 0, + current_priors: 0, + } + } +} + +impl Iterator for MultinomialNBSearchParametersIterator { + type Item = MultinomialNBParameters; + + fn next(&mut self) -> Option { + if self.current_alpha == self.multinomial_nb_search_parameters.alpha.len() + && self.current_priors == self.multinomial_nb_search_parameters.priors.len() + { + return None; + } + + let next = MultinomialNBParameters { + alpha: self.multinomial_nb_search_parameters.alpha[self.current_alpha], + priors: self.multinomial_nb_search_parameters.priors[self.current_priors].clone(), + }; + + if self.current_alpha + 1 < self.multinomial_nb_search_parameters.alpha.len() { + self.current_alpha += 1; + } else if self.current_priors + 1 < self.multinomial_nb_search_parameters.priors.len() { + self.current_alpha = 0; + self.current_priors += 1; + } else { + self.current_alpha += 1; + self.current_priors += 1; + } + + Some(next) + } +} + +impl Default for MultinomialNBSearchParameters { + fn default() -> Self { + let default_params = MultinomialNBParameters::default(); + + MultinomialNBSearchParameters { + alpha: vec![default_params.alpha], + priors: vec![default_params.priors], + } + } +} + impl MultinomialNBDistribution { /// Fits the distribution to a NxM matrix where N is number of samples and M is number of features. /// * `x` - training data. 
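// [Editor's note] A usage sketch with hypothetical values, not part of the patch: the same
// odometer pattern as the other search parameters, except that `priors` holds owned
// `Option<Vec<T>>` values, which is why the iterator clones the current entry on each yield
// instead of copying it.
let params = MultinomialNBSearchParameters {
    alpha: vec![0.5, 1.0],
    priors: vec![None, Some(vec![0.3, 0.7])],
};
// alpha cycles fastest: (0.5, None), (1.0, None), (0.5, Some(..)), (1.0, Some(..))
assert_eq!(params.into_iter().count(), 4);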
@@ -297,6 +367,20 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = MultinomialNBSearchParameters { + alpha: vec![1., 2.], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 1.); + let next = iter.next().unwrap(); + assert_eq!(next.alpha, 2.); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn run_multinomial_naive_bayes() { diff --git a/src/svm/mod.rs b/src/svm/mod.rs index 55df584..4c71b3f 100644 --- a/src/svm/mod.rs +++ b/src/svm/mod.rs @@ -33,7 +33,7 @@ use crate::linalg::BaseVector; use crate::math::num::RealNumber; /// Defines a kernel function -pub trait Kernel> { +pub trait Kernel>: Clone { /// Apply kernel function to x_i and x_j fn apply(&self, x_i: &V, x_j: &V) -> T; } @@ -95,12 +95,12 @@ impl Kernels { /// Linear Kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct LinearKernel {} /// Radial basis function (Gaussian) kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct RBFKernel { /// kernel coefficient pub gamma: T, @@ -108,7 +108,7 @@ pub struct RBFKernel { /// Polynomial kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PolynomialKernel { /// degree of the polynomial pub degree: T, @@ -120,7 +120,7 @@ pub struct PolynomialKernel { /// Sigmoid (hyperbolic tangent) kernel #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct SigmoidKernel { /// kernel coefficient pub gamma: T, diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 87fb743..46b0b68 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -102,6 +102,109 @@ pub struct SVCParameters, K: Kernel m: PhantomData, } +/// SVC grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVCSearchParameters, K: Kernel> { + /// Number of epochs. + pub epoch: Vec, + /// Regularization parameter. + pub c: Vec, + /// Tolerance for stopping epoch. + pub tol: Vec, + /// The kernel function. + pub kernel: Vec, + /// Unused parameter. 
+ m: PhantomData, +} + +/// SVC grid search iterator +pub struct SVCSearchParametersIterator, K: Kernel> { + svc_search_parameters: SVCSearchParameters, + current_epoch: usize, + current_c: usize, + current_tol: usize, + current_kernel: usize, +} + +impl, K: Kernel> IntoIterator + for SVCSearchParameters +{ + type Item = SVCParameters; + type IntoIter = SVCSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVCSearchParametersIterator { + svc_search_parameters: self, + current_epoch: 0, + current_c: 0, + current_tol: 0, + current_kernel: 0, + } + } +} + +impl, K: Kernel> Iterator + for SVCSearchParametersIterator +{ + type Item = SVCParameters; + + fn next(&mut self) -> Option { + if self.current_epoch == self.svc_search_parameters.epoch.len() + && self.current_c == self.svc_search_parameters.c.len() + && self.current_tol == self.svc_search_parameters.tol.len() + && self.current_kernel == self.svc_search_parameters.kernel.len() + { + return None; + } + + let next = SVCParameters:: { + epoch: self.svc_search_parameters.epoch[self.current_epoch], + c: self.svc_search_parameters.c[self.current_c], + tol: self.svc_search_parameters.tol[self.current_tol], + kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(), + m: PhantomData, + }; + + if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() { + self.current_epoch += 1; + } else if self.current_c + 1 < self.svc_search_parameters.c.len() { + self.current_epoch = 0; + self.current_c += 1; + } else if self.current_tol + 1 < self.svc_search_parameters.tol.len() { + self.current_epoch = 0; + self.current_c = 0; + self.current_tol += 1; + } else if self.current_kernel + 1 < self.svc_search_parameters.kernel.len() { + self.current_epoch = 0; + self.current_c = 0; + self.current_tol = 0; + self.current_kernel += 1; + } else { + self.current_epoch += 1; + self.current_c += 1; + self.current_tol += 1; + self.current_kernel += 1; + } + + Some(next) + } +} + +impl> Default for SVCSearchParameters { + fn default() -> Self { + let default_params: SVCParameters = SVCParameters::default(); + + SVCSearchParameters { + epoch: vec![default_params.epoch], + c: vec![default_params.c], + tol: vec![default_params.tol], + kernel: vec![default_params.kernel], + m: PhantomData, + } + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] #[cfg_attr( @@ -737,6 +840,24 @@ mod tests { #[cfg(feature = "serde")] use crate::svm::*; + #[test] + fn search_parameters() { + let parameters: SVCSearchParameters, LinearKernel> = + SVCSearchParameters { + epoch: vec![10, 100], + kernel: vec![LinearKernel {}], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.epoch, 10); + assert_eq!(next.kernel, LinearKernel {}); + let next = iter.next().unwrap(); + assert_eq!(next.epoch, 100); + assert_eq!(next.kernel, LinearKernel {}); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svc_fit_predict() { diff --git a/src/svm/svr.rs b/src/svm/svr.rs index 18c73d1..25326d4 100644 --- a/src/svm/svr.rs +++ b/src/svm/svr.rs @@ -94,6 +94,109 @@ pub struct SVRParameters, K: Kernel m: PhantomData, } +/// SVR grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct SVRSearchParameters, K: Kernel> { + /// Epsilon in the epsilon-SVR model. + pub eps: Vec, + /// Regularization parameter. 
+ pub c: Vec, + /// Tolerance for stopping eps. + pub tol: Vec, + /// The kernel function. + pub kernel: Vec, + /// Unused parameter. + m: PhantomData, +} + +/// SVR grid search iterator +pub struct SVRSearchParametersIterator, K: Kernel> { + svr_search_parameters: SVRSearchParameters, + current_eps: usize, + current_c: usize, + current_tol: usize, + current_kernel: usize, +} + +impl, K: Kernel> IntoIterator + for SVRSearchParameters +{ + type Item = SVRParameters; + type IntoIter = SVRSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + SVRSearchParametersIterator { + svr_search_parameters: self, + current_eps: 0, + current_c: 0, + current_tol: 0, + current_kernel: 0, + } + } +} + +impl, K: Kernel> Iterator + for SVRSearchParametersIterator +{ + type Item = SVRParameters; + + fn next(&mut self) -> Option { + if self.current_eps == self.svr_search_parameters.eps.len() + && self.current_c == self.svr_search_parameters.c.len() + && self.current_tol == self.svr_search_parameters.tol.len() + && self.current_kernel == self.svr_search_parameters.kernel.len() + { + return None; + } + + let next = SVRParameters:: { + eps: self.svr_search_parameters.eps[self.current_eps], + c: self.svr_search_parameters.c[self.current_c], + tol: self.svr_search_parameters.tol[self.current_tol], + kernel: self.svr_search_parameters.kernel[self.current_kernel].clone(), + m: PhantomData, + }; + + if self.current_eps + 1 < self.svr_search_parameters.eps.len() { + self.current_eps += 1; + } else if self.current_c + 1 < self.svr_search_parameters.c.len() { + self.current_eps = 0; + self.current_c += 1; + } else if self.current_tol + 1 < self.svr_search_parameters.tol.len() { + self.current_eps = 0; + self.current_c = 0; + self.current_tol += 1; + } else if self.current_kernel + 1 < self.svr_search_parameters.kernel.len() { + self.current_eps = 0; + self.current_c = 0; + self.current_tol = 0; + self.current_kernel += 1; + } else { + self.current_eps += 1; + self.current_c += 1; + self.current_tol += 1; + self.current_kernel += 1; + } + + Some(next) + } +} + +impl> Default for SVRSearchParameters { + fn default() -> Self { + let default_params: SVRParameters = SVRParameters::default(); + + SVRSearchParameters { + eps: vec![default_params.eps], + c: vec![default_params.c], + tol: vec![default_params.tol], + kernel: vec![default_params.kernel], + m: PhantomData, + } + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] #[cfg_attr( @@ -536,6 +639,24 @@ mod tests { #[cfg(feature = "serde")] use crate::svm::*; + #[test] + fn search_parameters() { + let parameters: SVRSearchParameters, LinearKernel> = + SVRSearchParameters { + eps: vec![0., 1.], + kernel: vec![LinearKernel {}], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.eps, 0.); + assert_eq!(next.kernel, LinearKernel {}); + let next = iter.next().unwrap(); + assert_eq!(next.eps, 1.); + assert_eq!(next.kernel, LinearKernel {}); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn svr_fit_predict() { diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index 35889e4..a1699af 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -201,6 +201,144 @@ impl Default for DecisionTreeClassifierParameters { } } +/// DecisionTreeClassifier grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, 
Deserialize))] +#[derive(Debug, Clone)] +pub struct DecisionTreeClassifierSearchParameters { + /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub criterion: Vec, + /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) + pub min_samples_split: Vec, +} + +/// DecisionTreeClassifier grid search iterator +pub struct DecisionTreeClassifierSearchParametersIterator { + decision_tree_classifier_search_parameters: DecisionTreeClassifierSearchParameters, + current_criterion: usize, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, +} + +impl IntoIterator for DecisionTreeClassifierSearchParameters { + type Item = DecisionTreeClassifierParameters; + type IntoIter = DecisionTreeClassifierSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DecisionTreeClassifierSearchParametersIterator { + decision_tree_classifier_search_parameters: self, + current_criterion: 0, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + } + } +} + +impl Iterator for DecisionTreeClassifierSearchParametersIterator { + type Item = DecisionTreeClassifierParameters; + + fn next(&mut self) -> Option { + if self.current_criterion + == self + .decision_tree_classifier_search_parameters + .criterion + .len() + && self.current_max_depth + == self + .decision_tree_classifier_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .decision_tree_classifier_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .decision_tree_classifier_search_parameters + .min_samples_split + .len() + { + return None; + } + + let next = DecisionTreeClassifierParameters { + criterion: self.decision_tree_classifier_search_parameters.criterion + [self.current_criterion] + .clone(), + max_depth: self.decision_tree_classifier_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .decision_tree_classifier_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .decision_tree_classifier_search_parameters + .min_samples_split[self.current_min_samples_split], + }; + + if self.current_criterion + 1 + < self + .decision_tree_classifier_search_parameters + .criterion + .len() + { + self.current_criterion += 1; + } else if self.current_max_depth + 1 + < self + .decision_tree_classifier_search_parameters + .max_depth + .len() + { + self.current_criterion = 0; + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .decision_tree_classifier_search_parameters + .min_samples_leaf + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .decision_tree_classifier_search_parameters + .min_samples_split + .len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else { + self.current_criterion += 1; + self.current_max_depth += 1; + 
self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + } + + Some(next) + } +} + +impl Default for DecisionTreeClassifierSearchParameters { + fn default() -> Self { + let default_params = DecisionTreeClassifierParameters::default(); + + DecisionTreeClassifierSearchParameters { + criterion: vec![default_params.criterion], + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + } + } +} + impl Node { fn new(index: usize, output: usize) -> Self { Node { @@ -651,6 +789,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = DecisionTreeClassifierSearchParameters { + max_depth: vec![Some(10), Some(100)], + min_samples_split: vec![1, 2], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 2); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 2); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn gini_impurity() { diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 25f5e7e..f48de33 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -134,6 +134,120 @@ impl Default for DecisionTreeRegressorParameters { } } +/// DecisionTreeRegressor grid search parameters +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Clone)] +pub struct DecisionTreeRegressorSearchParameters { + /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub max_depth: Vec>, + /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_leaf: Vec, + /// The minimum number of samples required to split an internal node. 
See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) + pub min_samples_split: Vec, +} + +/// DecisionTreeRegressor grid search iterator +pub struct DecisionTreeRegressorSearchParametersIterator { + decision_tree_regressor_search_parameters: DecisionTreeRegressorSearchParameters, + current_max_depth: usize, + current_min_samples_leaf: usize, + current_min_samples_split: usize, +} + +impl IntoIterator for DecisionTreeRegressorSearchParameters { + type Item = DecisionTreeRegressorParameters; + type IntoIter = DecisionTreeRegressorSearchParametersIterator; + + fn into_iter(self) -> Self::IntoIter { + DecisionTreeRegressorSearchParametersIterator { + decision_tree_regressor_search_parameters: self, + current_max_depth: 0, + current_min_samples_leaf: 0, + current_min_samples_split: 0, + } + } +} + +impl Iterator for DecisionTreeRegressorSearchParametersIterator { + type Item = DecisionTreeRegressorParameters; + + fn next(&mut self) -> Option { + if self.current_max_depth + == self + .decision_tree_regressor_search_parameters + .max_depth + .len() + && self.current_min_samples_leaf + == self + .decision_tree_regressor_search_parameters + .min_samples_leaf + .len() + && self.current_min_samples_split + == self + .decision_tree_regressor_search_parameters + .min_samples_split + .len() + { + return None; + } + + let next = DecisionTreeRegressorParameters { + max_depth: self.decision_tree_regressor_search_parameters.max_depth + [self.current_max_depth], + min_samples_leaf: self + .decision_tree_regressor_search_parameters + .min_samples_leaf[self.current_min_samples_leaf], + min_samples_split: self + .decision_tree_regressor_search_parameters + .min_samples_split[self.current_min_samples_split], + }; + + if self.current_max_depth + 1 + < self + .decision_tree_regressor_search_parameters + .max_depth + .len() + { + self.current_max_depth += 1; + } else if self.current_min_samples_leaf + 1 + < self + .decision_tree_regressor_search_parameters + .min_samples_leaf + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf += 1; + } else if self.current_min_samples_split + 1 + < self + .decision_tree_regressor_search_parameters + .min_samples_split + .len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split += 1; + } else { + self.current_max_depth += 1; + self.current_min_samples_leaf += 1; + self.current_min_samples_split += 1; + } + + Some(next) + } +} + +impl Default for DecisionTreeRegressorSearchParameters { + fn default() -> Self { + let default_params = DecisionTreeRegressorParameters::default(); + + DecisionTreeRegressorSearchParameters { + max_depth: vec![default_params.max_depth], + min_samples_leaf: vec![default_params.min_samples_leaf], + min_samples_split: vec![default_params.min_samples_split], + } + } +} + impl Node { fn new(index: usize, output: T) -> Self { Node { @@ -517,6 +631,29 @@ mod tests { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; + #[test] + fn search_parameters() { + let parameters = DecisionTreeRegressorSearchParameters { + max_depth: vec![Some(10), Some(100)], + min_samples_split: vec![1, 2], + ..Default::default() + }; + let mut iter = parameters.into_iter(); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(10)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 1); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, 
Some(10)); + assert_eq!(next.min_samples_split, 2); + let next = iter.next().unwrap(); + assert_eq!(next.max_depth, Some(100)); + assert_eq!(next.min_samples_split, 2); + assert!(iter.next().is_none()); + } + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] #[test] fn fit_longley() { From 3a44161406b7f7b80226f4e576afd56c2c899ffe Mon Sep 17 00:00:00 2001 From: morenol <22335041+morenol@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:35:22 -0400 Subject: [PATCH 17/19] Lmm/add seeds in more algorithms (#164) * Provide better output in flaky tests * feat: add seed parameter to multiple algorithms * Update changelog Co-authored-by: Luis Moreno --- .github/workflows/ci.yml | 23 ++++++++++++----------- CHANGELOG.md | 9 +++++++++ Cargo.toml | 10 +++++++--- src/cluster/kmeans.rs | 13 +++++++++---- src/ensemble/random_forest_classifier.rs | 11 ++++++----- src/ensemble/random_forest_regressor.rs | 11 ++++++----- src/lib.rs | 2 ++ src/math/num.rs | 6 ++++-- src/model_selection/kfold.rs | 21 +++++++++++++++++++-- src/model_selection/mod.rs | 11 +++++++---- src/rand.rs | 21 +++++++++++++++++++++ src/svm/svc.rs | 22 +++++++++++++++++----- src/tree/decision_tree_classifier.rs | 22 ++++++++++------------ src/tree/decision_tree_regressor.rs | 21 ++++++++++----------- 14 files changed, 139 insertions(+), 64 deletions(-) create mode 100644 src/rand.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5041117..82d0eab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,23 +2,24 @@ name: CI on: push: - branches: [ main, development ] + branches: [main, development] pull_request: - branches: [ development ] + branches: [development] jobs: tests: runs-on: "${{ matrix.platform.os }}-latest" strategy: matrix: - platform: [ - { os: "windows", target: "x86_64-pc-windows-msvc" }, - { os: "windows", target: "i686-pc-windows-msvc" }, - { os: "ubuntu", target: "x86_64-unknown-linux-gnu" }, - { os: "ubuntu", target: "i686-unknown-linux-gnu" }, - { os: "ubuntu", target: "wasm32-unknown-unknown" }, - { os: "macos", target: "aarch64-apple-darwin" }, - ] + platform: + [ + { os: "windows", target: "x86_64-pc-windows-msvc" }, + { os: "windows", target: "i686-pc-windows-msvc" }, + { os: "ubuntu", target: "x86_64-unknown-linux-gnu" }, + { os: "ubuntu", target: "i686-unknown-linux-gnu" }, + { os: "ubuntu", target: "wasm32-unknown-unknown" }, + { os: "macos", target: "aarch64-apple-darwin" }, + ] env: TZ: "/usr/share/zoneinfo/your/location" steps: @@ -40,7 +41,7 @@ jobs: default: true - name: Install test runner for wasm if: matrix.platform.target == 'wasm32-unknown-unknown' - run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh + run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh - name: Stable Build uses: actions-rs/cargo@v1 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index ade6825..79e77e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## Added +- Seeds to multiple algorithms that depend on random number generation. +- Added feature `js` to use WASM in the browser + +## BREAKING CHANGE +- Added a new parameter to `train_test_split` to define the seed. 
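[Editor's note: an illustrative call under the breaking change above, a sketch only; it assumes the new trailing argument is an `Option<u64>` seed and that `None` preserves the previous non-deterministic shuffle: `let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, Some(42));`]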
+ +## [0.2.1] - 2022-05-10 + ## Added - L2 regularization penalty to the Logistic Regression - Getters for the naive bayes structs diff --git a/Cargo.toml b/Cargo.toml index a0ad984..51b9887 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,21 +16,25 @@ categories = ["science"] default = ["datasets"] ndarray-bindings = ["ndarray"] nalgebra-bindings = ["nalgebra"] -datasets = ["rand_distr"] +datasets = ["rand_distr", "std"] fp_bench = ["itertools"] +std = ["rand/std", "rand/std_rng"] +# wasm32 only +js = ["getrandom/js"] [dependencies] ndarray = { version = "0.15", optional = true } nalgebra = { version = "0.31", optional = true } num-traits = "0.2" num = "0.4" -rand = "0.8" +rand = { version = "0.8", default-features = false, features = ["small_rng"] } rand_distr = { version = "0.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } itertools = { version = "0.10.3", optional = true } +cfg-if = "1.0.0" [target.'cfg(target_arch = "wasm32")'.dependencies] -getrandom = { version = "0.2", features = ["js"] } +getrandom = { version = "0.2", optional = true } [dev-dependencies] smartcore = { path = ".", features = ["fp_bench"] } diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 8ecbb2e..fee1425 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -52,10 +52,10 @@ //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/) //! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf) -use rand::Rng; use std::fmt::Debug; use std::iter::Sum; +use ::rand::Rng; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -65,6 +65,7 @@ use crate::error::Failed; use crate::linalg::Matrix; use crate::math::distance::euclidian::*; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; /// K-Means clustering algorithm #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -108,6 +109,9 @@ pub struct KMeansParameters { pub k: usize, /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: usize, + /// Determines random number generation for centroid initialization. + /// Use an int to make the randomness deterministic + pub seed: Option, } impl KMeansParameters { @@ -128,6 +132,7 @@ impl Default for KMeansParameters { KMeansParameters { k: 2, max_iter: 100, + seed: None, } } } @@ -238,7 +243,7 @@ impl KMeans { let (n, d) = data.shape(); let mut distortion = T::max_value(); - let mut y = KMeans::kmeans_plus_plus(data, parameters.k); + let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed); let mut size = vec![0; parameters.k]; let mut centroids = vec![vec![T::zero(); d]; parameters.k]; @@ -311,8 +316,8 @@ impl KMeans { Ok(result.to_row_vector()) } - fn kmeans_plus_plus>(data: &M, k: usize) -> Vec { - let mut rng = rand::thread_rng(); + fn kmeans_plus_plus>(data: &M, k: usize, seed: Option) -> Vec { + let mut rng = get_rng_impl(seed); let (n, m) = data.shape(); let mut y = vec![0; n]; let mut centroid = data.get_row_as_vec(rng.gen_range(0..n)); diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index a4d6e75..331dab7 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -45,8 +45,8 @@ //! //! //! 
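Before the random forest diffs continue: the `kmeans.rs` hunk above threads the new `seed` field through `kmeans_plus_plus`, so a fixed seed pins down the k-means++ centroid initialization. A minimal sketch of a deterministic fit, assuming the public struct-literal construction shown in the diff (data and seed value are illustrative):

```rust
use smartcore::cluster::kmeans::{KMeans, KMeansParameters};
use smartcore::linalg::naive::dense_matrix::DenseMatrix;

fn main() {
    let x = DenseMatrix::from_2d_array(&[
        &[1.0, 2.0],
        &[1.2, 1.9],
        &[8.0, 8.5],
        &[8.2, 8.1],
    ]);

    let params = KMeansParameters {
        k: 2,
        max_iter: 100,
        seed: Some(7),
    };

    // The same seed selects the same initial centroids, so repeated
    // runs produce the same cluster labels.
    let labels = KMeans::fit(&x, params).and_then(|m| m.predict(&x)).unwrap();
    println!("{:?}", labels);
}
```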
-use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::Rng; + use std::default::Default; use std::fmt::Debug; @@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::tree::decision_tree_classifier::{ which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, }; @@ -441,7 +442,7 @@ impl RandomForestClassifier { .unwrap() }); - let mut rng = StdRng::seed_from_u64(parameters.seed); + let mut rng = get_rng_impl(Some(parameters.seed)); let classes = y_m.unique(); let k = classes.len(); let mut trees: Vec> = Vec::new(); @@ -462,9 +463,9 @@ impl RandomForestClassifier { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, + seed: Some(parameters.seed), }; - let tree = - DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; + let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?; trees.push(tree); } diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index ec78137..1270685 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -43,8 +43,8 @@ //! //! -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::Rng; + use std::default::Default; use std::fmt::Debug; @@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::tree::decision_tree_regressor::{ DecisionTreeRegressor, DecisionTreeRegressorParameters, }; @@ -376,7 +377,7 @@ impl RandomForestRegressor { .m .unwrap_or((num_attributes as f64).sqrt().floor() as usize); - let mut rng = StdRng::seed_from_u64(parameters.seed); + let mut rng = get_rng_impl(Some(parameters.seed)); let mut trees: Vec> = Vec::new(); let mut maybe_all_samples: Option>> = Option::None; @@ -393,9 +394,9 @@ impl RandomForestRegressor { max_depth: parameters.max_depth, min_samples_leaf: parameters.min_samples_leaf, min_samples_split: parameters.min_samples_split, + seed: Some(parameters.seed), }; - let tree = - DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?; + let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?; trees.push(tree); } diff --git a/src/lib.rs b/src/lib.rs index e9e1c3d..b46ee10 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,3 +101,5 @@ pub mod readers; pub mod svm; /// Supervised tree-based learning methods pub mod tree; + +pub(crate) mod rand; diff --git a/src/math/num.rs b/src/math/num.rs index 433ad28..1ec20fb 100644 --- a/src/math/num.rs +++ b/src/math/num.rs @@ -9,6 +9,8 @@ use std::iter::{Product, Sum}; use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign}; use std::str::FromStr; +use crate::rand::get_rng_impl; + /// Defines real number /// pub trait RealNumber: @@ -79,7 +81,7 @@ impl RealNumber for f64 { } fn rand() -> f64 { - let mut rng = rand::thread_rng(); + let mut rng = get_rng_impl(None); rng.gen() } @@ -124,7 +126,7 @@ impl RealNumber for f32 { } fn rand() -> f32 { - let mut rng = rand::thread_rng(); + let mut rng = get_rng_impl(None); rng.gen() } diff --git a/src/model_selection/kfold.rs b/src/model_selection/kfold.rs index 8706954..ef48b87 100644 --- a/src/model_selection/kfold.rs +++ 
b/src/model_selection/kfold.rs
@@ -5,8 +5,8 @@
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
 use crate::model_selection::BaseKFold;
+use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
-use rand::thread_rng;

 /// K-Folds cross-validator
 pub struct KFold {
@@ -14,6 +14,9 @@ pub struct KFold {
     pub n_splits: usize, // cannot exceed std::usize::MAX
     /// Whether to shuffle the data before splitting into batches
     pub shuffle: bool,
+    /// When `shuffle` is true, the seed affects the ordering of the indices,
+    /// which controls the randomness of each fold.
+    pub seed: Option,
 }

 impl KFold {
@@ -23,8 +26,10 @@ impl KFold {
         // initialise indices
         let mut indices: Vec = (0..n_samples).collect();

+        let mut rng = get_rng_impl(self.seed);
+
         if self.shuffle {
-            indices.shuffle(&mut thread_rng());
+            indices.shuffle(&mut rng);
         }
         // return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
         let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
@@ -66,6 +71,7 @@ impl Default for KFold {
         KFold {
             n_splits: 3,
             shuffle: true,
+            seed: None,
         }
     }
 }
@@ -81,6 +87,12 @@ impl KFold {
         self.shuffle = shuffle;
         self
     }
+
+    /// When `shuffle` is true, the seed affects the ordering of the indices.
+    pub fn with_seed(mut self, seed: Option) -> Self {
+        self.seed = seed;
+        self
+    }
 }

 /// An iterator over indices that split data into training and test set.
@@ -150,6 +162,7 @@ mod tests {
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix = DenseMatrix::rand(33, 100);
         let test_indices = k.test_indices(&x);
@@ -165,6 +178,7 @@
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix = DenseMatrix::rand(34, 100);
         let test_indices = k.test_indices(&x);
@@ -180,6 +194,7 @@
         let k = KFold {
             n_splits: 2,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix = DenseMatrix::rand(22, 100);
         let test_masks = k.test_masks(&x);
@@ -206,6 +221,7 @@
         let k = KFold {
             n_splits: 2,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix = DenseMatrix::rand(22, 100);
         let train_test_splits: Vec<(Vec, Vec)> = k.split(&x).collect();
@@ -238,6 +254,7 @@
         let k = KFold {
             n_splits: 3,
             shuffle: false,
+            seed: None,
         };
         let x: DenseMatrix = DenseMatrix::rand(10, 4);
         let expected: Vec<(Vec, Vec)> = vec![
diff --git a/src/model_selection/mod.rs b/src/model_selection/mod.rs
index 6f737d6..21cf7ed 100644
--- a/src/model_selection/mod.rs
+++ b/src/model_selection/mod.rs
@@ -41,7 +41,7 @@
 //!           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
 //! ];
 //!
-//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true);
+//! let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None);
 //!
 //! println!("X train: {:?}, y train: {}, X test: {:?}, y test: {}",
 //!         x_train.shape(), y_train.len(), x_test.shape(), y_test.len());
@@ -107,8 +107,8 @@
 use crate::error::Failed;
 use crate::linalg::BaseVector;
 use crate::linalg::Matrix;
 use crate::math::num::RealNumber;
+use crate::rand::get_rng_impl;
 use rand::seq::SliceRandom;
-use rand::thread_rng;

 pub(crate) mod kfold;

@@ -130,11 +130,13 @@ pub trait BaseKFold {
 /// * `y` - target values, should be of size _N_
 /// * `test_size`, (0, 1] - the proportion of the dataset to include in the test split.
 /// * `shuffle`, - whether or not to shuffle the data before splitting
+/// * `seed` - Controls the shuffling applied to the data before applying the split. 
Pass an int for reproducible output across multiple function calls pub fn train_test_split>( x: &M, y: &M::RowVector, test_size: f32, shuffle: bool, + seed: Option, ) -> (M, M, M::RowVector, M::RowVector) { if x.shape().0 != y.len() { panic!( @@ -143,6 +145,7 @@ pub fn train_test_split>( y.len() ); } + let mut rng = get_rng_impl(seed); if test_size <= 0. || test_size > 1.0 { panic!("test_size should be between 0 and 1"); @@ -159,7 +162,7 @@ pub fn train_test_split>( let mut indices: Vec = (0..n).collect(); if shuffle { - indices.shuffle(&mut thread_rng()); + indices.shuffle(&mut rng); } let x_train = x.take(&indices[n_test..n], 0); @@ -292,7 +295,7 @@ mod tests { let x: DenseMatrix = DenseMatrix::rand(n, 3); let y = vec![0f64; n]; - let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true); + let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y, 0.2, true, None); assert!( x_train.shape().0 > (n as f64 * 0.65) as usize diff --git a/src/rand.rs b/src/rand.rs new file mode 100644 index 0000000..d90e9c9 --- /dev/null +++ b/src/rand.rs @@ -0,0 +1,21 @@ +use ::rand::SeedableRng; +#[cfg(not(feature = "std"))] +use rand::rngs::SmallRng as RngImpl; +#[cfg(feature = "std")] +use rand::rngs::StdRng as RngImpl; + +pub(crate) fn get_rng_impl(seed: Option) -> RngImpl { + match seed { + Some(seed) => RngImpl::seed_from_u64(seed), + None => { + cfg_if::cfg_if! { + if #[cfg(feature = "std")] { + use rand::RngCore; + RngImpl::seed_from_u64(rand::thread_rng().next_u64()) + } else { + panic!("seed number needed for non-std build"); + } + } + } + } +} diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 46b0b68..94c6d9e 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -84,6 +84,7 @@ use crate::error::Failed; use crate::linalg::BaseVector; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; use crate::svm::{Kernel, Kernels, LinearKernel}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -100,6 +101,8 @@ pub struct SVCParameters, K: Kernel pub kernel: K, /// Unused parameter. 
m: PhantomData, + /// Controls the pseudo random number generation for shuffling the data for probability estimates + seed: Option, } /// SVC grid search parameters @@ -279,8 +282,15 @@ impl, K: Kernel> SVCParameters) -> Self { + self.seed = seed; + self + } } impl> Default for SVCParameters { @@ -291,6 +301,7 @@ impl> Default for SVCParameters tol: T::from_f64(1e-3).unwrap(), kernel: Kernels::linear(), m: PhantomData, + seed: None, } } } @@ -511,7 +522,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, let good_enough = T::from_i32(1000).unwrap(); for _ in 0..self.parameters.epoch { - for i in Self::permutate(n) { + for i in self.permutate(n) { self.process(i, self.x.get_row(i), self.y.get(i), &mut cache); loop { self.reprocess(tol, &mut cache); @@ -544,7 +555,7 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, let mut cp = 0; let mut cn = 0; - for i in Self::permutate(n) { + for i in self.permutate(n) { if self.y.get(i) == T::one() && cp < few { if self.process(i, self.x.get_row(i), self.y.get(i), cache) { cp += 1; @@ -669,8 +680,8 @@ impl<'a, T: RealNumber, M: Matrix, K: Kernel> Optimizer<'a, self.recalculate_minmax_grad = true; } - fn permutate(n: usize) -> Vec { - let mut rng = rand::thread_rng(); + fn permutate(&self, n: usize) -> Vec { + let mut rng = get_rng_impl(self.parameters.seed); let mut range: Vec = (0..n).collect(); range.shuffle(&mut rng); range @@ -893,7 +904,8 @@ mod tests { &y, SVCParameters::default() .with_c(200.0) - .with_kernel(Kernels::linear()), + .with_kernel(Kernels::linear()) + .with_seed(Some(100)), ) .and_then(|lr| lr.predict(&x)) .unwrap(); diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs index a1699af..a14c104 100644 --- a/src/tree/decision_tree_classifier.rs +++ b/src/tree/decision_tree_classifier.rs @@ -77,6 +77,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -90,6 +91,8 @@ pub struct DecisionTreeClassifierParameters { pub min_samples_leaf: usize, /// The minimum number of samples required to split an internal node. 
pub min_samples_split: usize, + /// Controls the randomness of the estimator + pub seed: Option, } /// Decision Tree @@ -197,6 +200,7 @@ impl Default for DecisionTreeClassifierParameters { max_depth: None, min_samples_leaf: 1, min_samples_split: 2, + seed: None, } } } @@ -467,14 +471,7 @@ impl DecisionTreeClassifier { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeClassifier::fit_weak_learner( - x, - y, - samples, - num_attributes, - parameters, - &mut rand::thread_rng(), - ) + DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters) } pub(crate) fn fit_weak_learner>( @@ -483,7 +480,6 @@ impl DecisionTreeClassifier { samples: Vec, mtry: usize, parameters: DecisionTreeClassifierParameters, - rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); let (_, y_ncols) = y_m.shape(); @@ -497,6 +493,7 @@ impl DecisionTreeClassifier { ))); } + let mut rng = get_rng_impl(parameters.seed); let mut yi: Vec = vec![0; y_ncols]; for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) { @@ -531,13 +528,13 @@ impl DecisionTreeClassifier { let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, rng) { + if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), + Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } @@ -874,7 +871,8 @@ mod tests { criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, - min_samples_split: 2 + min_samples_split: 2, + seed: None } ) .unwrap() diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index f48de33..7d88c40 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -72,6 +72,7 @@ use crate::api::{Predictor, SupervisedEstimator}; use crate::error::Failed; use crate::linalg::Matrix; use crate::math::num::RealNumber; +use crate::rand::get_rng_impl; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -83,6 +84,8 @@ pub struct DecisionTreeRegressorParameters { pub min_samples_leaf: usize, /// The minimum number of samples required to split an internal node. 
pub min_samples_split: usize, + /// Controls the randomness of the estimator + pub seed: Option, } /// Regression Tree @@ -130,6 +133,7 @@ impl Default for DecisionTreeRegressorParameters { max_depth: None, min_samples_leaf: 1, min_samples_split: 2, + seed: None, } } } @@ -357,14 +361,7 @@ impl DecisionTreeRegressor { ) -> Result, Failed> { let (x_nrows, num_attributes) = x.shape(); let samples = vec![1; x_nrows]; - DecisionTreeRegressor::fit_weak_learner( - x, - y, - samples, - num_attributes, - parameters, - &mut rand::thread_rng(), - ) + DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) } pub(crate) fn fit_weak_learner>( @@ -373,7 +370,6 @@ impl DecisionTreeRegressor { samples: Vec, mtry: usize, parameters: DecisionTreeRegressorParameters, - rng: &mut impl Rng, ) -> Result, Failed> { let y_m = M::from_row_vector(y.clone()); @@ -381,6 +377,7 @@ impl DecisionTreeRegressor { let (_, num_attributes) = x.shape(); let mut nodes: Vec> = Vec::new(); + let mut rng = get_rng_impl(parameters.seed); let mut n = 0; let mut sum = T::zero(); @@ -407,13 +404,13 @@ impl DecisionTreeRegressor { let mut visitor_queue: LinkedList> = LinkedList::new(); - if tree.find_best_cutoff(&mut visitor, mtry, rng) { + if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) { visitor_queue.push_back(visitor); } while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { match visitor_queue.pop_front() { - Some(node) => tree.split(node, mtry, &mut visitor_queue, rng), + Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng), None => break, }; } @@ -699,6 +696,7 @@ mod tests { max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6, + seed: None, }, ) .and_then(|t| t.predict(&x)) @@ -719,6 +717,7 @@ mod tests { max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3, + seed: None, }, ) .and_then(|t| t.predict(&x)) From 403d3f234833382e30b0f06b66d0ebf0b35de0a3 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 21 Sep 2022 16:15:26 -0700 Subject: [PATCH 18/19] add seed param to search params (#168) --- src/cluster/kmeans.rs | 13 +++++++++++++ src/svm/svc.rs | 14 ++++++++++++++ src/tree/decision_tree_classifier.rs | 20 ++++++++++++++++++++ src/tree/decision_tree_regressor.rs | 14 ++++++++++++++ 4 files changed, 61 insertions(+) diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index fee1425..404f7b0 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -145,6 +145,9 @@ pub struct KMeansSearchParameters { pub k: Vec, /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: Vec, + /// Determines random number generation for centroid initialization. 
+ /// Use an int to make the randomness deterministic + pub seed: Vec>, } /// KMeans grid search iterator @@ -152,6 +155,7 @@ pub struct KMeansSearchParametersIterator { kmeans_search_parameters: KMeansSearchParameters, current_k: usize, current_max_iter: usize, + current_seed: usize, } impl IntoIterator for KMeansSearchParameters { @@ -163,6 +167,7 @@ impl IntoIterator for KMeansSearchParameters { kmeans_search_parameters: self, current_k: 0, current_max_iter: 0, + current_seed: 0, } } } @@ -173,6 +178,7 @@ impl Iterator for KMeansSearchParametersIterator { fn next(&mut self) -> Option { if self.current_k == self.kmeans_search_parameters.k.len() && self.current_max_iter == self.kmeans_search_parameters.max_iter.len() + && self.current_seed == self.kmeans_search_parameters.seed.len() { return None; } @@ -180,6 +186,7 @@ impl Iterator for KMeansSearchParametersIterator { let next = KMeansParameters { k: self.kmeans_search_parameters.k[self.current_k], max_iter: self.kmeans_search_parameters.max_iter[self.current_max_iter], + seed: self.kmeans_search_parameters.seed[self.current_seed], }; if self.current_k + 1 < self.kmeans_search_parameters.k.len() { @@ -187,9 +194,14 @@ impl Iterator for KMeansSearchParametersIterator { } else if self.current_max_iter + 1 < self.kmeans_search_parameters.max_iter.len() { self.current_k = 0; self.current_max_iter += 1; + } else if self.current_seed + 1 < self.kmeans_search_parameters.seed.len() { + self.current_k = 0; + self.current_max_iter = 0; + self.current_seed += 1; } else { self.current_k += 1; self.current_max_iter += 1; + self.current_seed += 1; } Some(next) @@ -203,6 +215,7 @@ impl Default for KMeansSearchParameters { KMeansSearchParameters { k: vec![default_params.k], max_iter: vec![default_params.max_iter], + seed: vec![default_params.seed], } } } diff --git a/src/svm/svc.rs b/src/svm/svc.rs index 94c6d9e..d390866 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -119,6 +119,8 @@ pub struct SVCSearchParameters, K: Kernel, /// Unused parameter. 
m: PhantomData,
+    /// Controls the pseudo random number generation for shuffling the data for probability estimates
+    seed: Vec>,
 }

 /// SVC grid search iterator
@@ -128,6 +130,7 @@ pub struct SVCSearchParametersIterator, K: Kernel, K: Kernel> IntoIterator
@@ -143,6 +146,7 @@ impl, K: Kernel> IntoIterator
             current_c: 0,
             current_tol: 0,
             current_kernel: 0,
+            current_seed: 0,
         }
     }
 }
@@ -157,6 +161,7 @@ impl, K: Kernel> Iterator
             && self.current_c == self.svc_search_parameters.c.len()
             && self.current_tol == self.svc_search_parameters.tol.len()
             && self.current_kernel == self.svc_search_parameters.kernel.len()
+            && self.current_seed == self.svc_search_parameters.seed.len()
         {
             return None;
         }
@@ -167,6 +172,7 @@ impl, K: Kernel> Iterator
             tol: self.svc_search_parameters.tol[self.current_tol],
             kernel: self.svc_search_parameters.kernel[self.current_kernel].clone(),
             m: PhantomData,
+            seed: self.svc_search_parameters.seed[self.current_seed],
         };

         if self.current_epoch + 1 < self.svc_search_parameters.epoch.len() {
@@ -183,11 +189,18 @@ impl, K: Kernel> Iterator
             self.current_c = 0;
             self.current_tol = 0;
             self.current_kernel += 1;
+        } else if self.current_seed + 1 < self.svc_search_parameters.seed.len() {
+            self.current_epoch = 0;
+            self.current_c = 0;
+            self.current_tol = 0;
+            self.current_kernel = 0;
+            self.current_seed += 1;
         } else {
             self.current_epoch += 1;
             self.current_c += 1;
             self.current_tol += 1;
             self.current_kernel += 1;
+            self.current_seed += 1;
         }

         Some(next)
@@ -204,6 +217,7 @@ impl> Default for SVCSearchParameters
             tol: vec![default_params.tol],
             kernel: vec![default_params.kernel],
             m: PhantomData,
+            seed: vec![default_params.seed],
         }
     }
 }
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index a14c104..acc3fb0 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -212,15 +212,22 @@ impl Default for DecisionTreeClassifierParameters {

 /// DecisionTreeClassifier grid search parameters
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 pub struct DecisionTreeClassifierSearchParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
     pub criterion: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
     pub max_depth: Vec>,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html)
     pub min_samples_leaf: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to split an internal node. 
See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_split: Vec, + #[cfg_attr(feature = "serde", serde(default))] + /// Controls the randomness of the estimator + pub seed: Vec>, } /// DecisionTreeClassifier grid search iterator @@ -226,6 +233,7 @@ pub struct DecisionTreeClassifierSearchParametersIterator { current_max_depth: usize, current_min_samples_leaf: usize, current_min_samples_split: usize, + current_seed: usize, } impl IntoIterator for DecisionTreeClassifierSearchParameters { @@ -239,6 +247,7 @@ impl IntoIterator for DecisionTreeClassifierSearchParameters { current_max_depth: 0, current_min_samples_leaf: 0, current_min_samples_split: 0, + current_seed: 0, } } } @@ -267,6 +276,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { .decision_tree_classifier_search_parameters .min_samples_split .len() + && self.current_seed == self.decision_tree_classifier_search_parameters.seed.len() { return None; } @@ -283,6 +293,7 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { min_samples_split: self .decision_tree_classifier_search_parameters .min_samples_split[self.current_min_samples_split], + seed: self.decision_tree_classifier_search_parameters.seed[self.current_seed], }; if self.current_criterion + 1 @@ -319,11 +330,19 @@ impl Iterator for DecisionTreeClassifierSearchParametersIterator { self.current_max_depth = 0; self.current_min_samples_leaf = 0; self.current_min_samples_split += 1; + } else if self.current_seed + 1 < self.decision_tree_classifier_search_parameters.seed.len() + { + self.current_criterion = 0; + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_seed += 1; } else { self.current_criterion += 1; self.current_max_depth += 1; self.current_min_samples_leaf += 1; self.current_min_samples_split += 1; + self.current_seed += 1; } Some(next) @@ -339,6 +358,7 @@ impl Default for DecisionTreeClassifierSearchParameters { max_depth: vec![default_params.max_depth], min_samples_leaf: vec![default_params.min_samples_leaf], min_samples_split: vec![default_params.min_samples_split], + seed: vec![default_params.seed], } } } diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs index 7d88c40..12bb9c9 100644 --- a/src/tree/decision_tree_regressor.rs +++ b/src/tree/decision_tree_regressor.rs @@ -148,6 +148,8 @@ pub struct DecisionTreeRegressorSearchParameters { pub min_samples_leaf: Vec, /// The minimum number of samples required to split an internal node. 
See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) pub min_samples_split: Vec, + /// Controls the randomness of the estimator + pub seed: Vec>, } /// DecisionTreeRegressor grid search iterator @@ -156,6 +158,7 @@ pub struct DecisionTreeRegressorSearchParametersIterator { current_max_depth: usize, current_min_samples_leaf: usize, current_min_samples_split: usize, + current_seed: usize, } impl IntoIterator for DecisionTreeRegressorSearchParameters { @@ -168,6 +171,7 @@ impl IntoIterator for DecisionTreeRegressorSearchParameters { current_max_depth: 0, current_min_samples_leaf: 0, current_min_samples_split: 0, + current_seed: 0, } } } @@ -191,6 +195,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { .decision_tree_regressor_search_parameters .min_samples_split .len() + && self.current_seed == self.decision_tree_regressor_search_parameters.seed.len() { return None; } @@ -204,6 +209,7 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { min_samples_split: self .decision_tree_regressor_search_parameters .min_samples_split[self.current_min_samples_split], + seed: self.decision_tree_regressor_search_parameters.seed[self.current_seed], }; if self.current_max_depth + 1 @@ -230,10 +236,17 @@ impl Iterator for DecisionTreeRegressorSearchParametersIterator { self.current_max_depth = 0; self.current_min_samples_leaf = 0; self.current_min_samples_split += 1; + } else if self.current_seed + 1 < self.decision_tree_regressor_search_parameters.seed.len() + { + self.current_max_depth = 0; + self.current_min_samples_leaf = 0; + self.current_min_samples_split = 0; + self.current_seed += 1; } else { self.current_max_depth += 1; self.current_min_samples_leaf += 1; self.current_min_samples_split += 1; + self.current_seed += 1; } Some(next) @@ -248,6 +261,7 @@ impl Default for DecisionTreeRegressorSearchParameters { max_depth: vec![default_params.max_depth], min_samples_leaf: vec![default_params.min_samples_leaf], min_samples_split: vec![default_params.min_samples_split], + seed: vec![default_params.seed], } } } From 764309e313224ba0f6f9047e55c7507da0145224 Mon Sep 17 00:00:00 2001 From: Montana Low Date: Wed, 21 Sep 2022 19:48:31 -0700 Subject: [PATCH 19/19] make default params available to serde (#167) * add seed param to search params * make default params available to serde * lints * create defaults for enums * lint --- src/algorithm/neighbour/mod.rs | 6 ++++++ src/cluster/dbscan.rs | 11 ++++++++++- src/cluster/kmeans.rs | 7 +++++++ src/decomposition/pca.rs | 5 +++++ src/decomposition/svd.rs | 3 +++ src/ensemble/random_forest_classifier.rs | 16 ++++++++++++++++ src/ensemble/random_forest_regressor.rs | 14 ++++++++++++++ src/linear/elastic_net.rs | 10 ++++++++++ src/linear/lasso.rs | 8 ++++++++ src/linear/linear_regression.rs | 19 +++++++++---------- src/linear/logistic_regression.rs | 12 +++++++++++- src/linear/ridge_regression.rs | 11 ++++++++++- src/naive_bayes/bernoulli.rs | 6 ++++++ src/naive_bayes/categorical.rs | 2 ++ src/naive_bayes/gaussian.rs | 2 ++ src/naive_bayes/multinomial.rs | 4 ++++ src/neighbors/knn_classifier.rs | 9 +++++++-- src/neighbors/knn_regressor.rs | 9 +++++++-- src/neighbors/mod.rs | 6 ++++++ src/svm/svc.rs | 12 ++++++++++++ src/tree/decision_tree_classifier.rs | 13 ++++++++++++- src/tree/decision_tree_regressor.rs | 8 ++++++++ 22 files changed, 175 insertions(+), 18 deletions(-) diff --git a/src/algorithm/neighbour/mod.rs b/src/algorithm/neighbour/mod.rs index 42ab7bc..f59448a 100644 --- a/src/algorithm/neighbour/mod.rs +++ 
b/src/algorithm/neighbour/mod.rs @@ -59,6 +59,12 @@ pub enum KNNAlgorithmName { CoverTree, } +impl Default for KNNAlgorithmName { + fn default() -> Self { + KNNAlgorithmName::CoverTree + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] pub(crate) enum KNNAlgorithm, T>> { diff --git a/src/cluster/dbscan.rs b/src/cluster/dbscan.rs index 621d017..ba8722e 100644 --- a/src/cluster/dbscan.rs +++ b/src/cluster/dbscan.rs @@ -65,17 +65,22 @@ pub struct DBSCAN, T>> { eps: T, } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// DBSCAN clustering algorithm parameters pub struct DBSCANParameters, T>> { + #[cfg_attr(feature = "serde", serde(default))] /// a function that defines a distance between each pair of point in training data. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. pub distance: D, + #[cfg_attr(feature = "serde", serde(default))] /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. pub min_samples: usize, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum distance between two samples for one to be considered as in the neighborhood of the other. pub eps: T, + #[cfg_attr(feature = "serde", serde(default))] /// KNN algorithm to use. pub algorithm: KNNAlgorithmName, } @@ -113,14 +118,18 @@ impl, T>> DBSCANParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct DBSCANSearchParameters, T>> { + #[cfg_attr(feature = "serde", serde(default))] /// a function that defines a distance between each pair of point in training data. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. pub distance: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. pub min_samples: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum distance between two samples for one to be considered as in the neighborhood of the other. pub eps: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// KNN algorithm to use. pub algorithm: Vec, } @@ -221,7 +230,7 @@ impl Default for DBSCANParameters { distance: Distances::euclidian(), min_samples: 5, eps: T::half(), - algorithm: KNNAlgorithmName::CoverTree, + algorithm: KNNAlgorithmName::default(), } } } diff --git a/src/cluster/kmeans.rs b/src/cluster/kmeans.rs index 404f7b0..6f45e6c 100644 --- a/src/cluster/kmeans.rs +++ b/src/cluster/kmeans.rs @@ -102,13 +102,17 @@ impl PartialEq for KMeans { } } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// K-Means clustering algorithm parameters pub struct KMeansParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Number of clusters. pub k: usize, + #[cfg_attr(feature = "serde", serde(default))] /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: usize, + #[cfg_attr(feature = "serde", serde(default))] /// Determines random number generation for centroid initialization. 
/// Use an int to make the randomness deterministic pub seed: Option, @@ -141,10 +145,13 @@ impl Default for KMeansParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct KMeansSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Number of clusters. pub k: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Maximum number of iterations of the k-means algorithm for a single run. pub max_iter: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Determines random number generation for centroid initialization. /// Use an int to make the randomness deterministic pub seed: Vec>, diff --git a/src/decomposition/pca.rs b/src/decomposition/pca.rs index 296926a..7961d41 100644 --- a/src/decomposition/pca.rs +++ b/src/decomposition/pca.rs @@ -83,11 +83,14 @@ impl> PartialEq for PCA { } } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// PCA parameters pub struct PCAParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Number of components to keep. pub n_components: usize, + #[cfg_attr(feature = "serde", serde(default))] /// By default, covariance matrix is used to compute principal components. /// Enable this flag if you want to use correlation matrix instead. pub use_correlation_matrix: bool, @@ -120,8 +123,10 @@ impl Default for PCAParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct PCASearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Number of components to keep. pub n_components: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// By default, covariance matrix is used to compute principal components. /// Enable this flag if you want to use correlation matrix instead. pub use_correlation_matrix: Vec, diff --git a/src/decomposition/svd.rs b/src/decomposition/svd.rs index 3001fd9..9a1e33d 100644 --- a/src/decomposition/svd.rs +++ b/src/decomposition/svd.rs @@ -69,9 +69,11 @@ impl> PartialEq for SVD { } } +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] /// SVD parameters pub struct SVDParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Number of components to keep. pub n_components: usize, } @@ -94,6 +96,7 @@ impl SVDParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct SVDSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Maximum number of iterations of the k-means algorithm for a single run. pub n_components: Vec, } diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 331dab7..4264305 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -67,20 +67,28 @@ use crate::tree::decision_tree_classifier::{ #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct RandomForestClassifierParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub criterion: SplitCriterion, + #[cfg_attr(feature = "serde", serde(default))] /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub max_depth: Option, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to be at a leaf node. 
See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_leaf: usize, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_split: usize, + #[cfg_attr(feature = "serde", serde(default))] /// The number of trees in the forest. pub n_trees: u16, + #[cfg_attr(feature = "serde", serde(default))] /// Number of random sample of predictors to use as split candidates. pub m: Option, + #[cfg_attr(feature = "serde", serde(default))] /// Whether to keep samples used for tree generation. This is required for OOB prediction. pub keep_samples: bool, + #[cfg_attr(feature = "serde", serde(default))] /// Seed used for bootstrap sampling and feature selection for each tree. pub seed: u64, } @@ -198,20 +206,28 @@ impl> Predictor for RandomForestCla #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct RandomForestClassifierSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Split criteria to use when building a tree. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub criterion: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub max_depth: Vec>, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_leaf: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_split: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The number of trees in the forest. pub n_trees: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Number of random sample of predictors to use as split candidates. pub m: Vec>, + #[cfg_attr(feature = "serde", serde(default))] /// Whether to keep samples used for tree generation. This is required for OOB prediction. pub keep_samples: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Seed used for bootstrap sampling and feature selection for each tree. pub seed: Vec, } diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs index 1270685..d7e61c3 100644 --- a/src/ensemble/random_forest_regressor.rs +++ b/src/ensemble/random_forest_regressor.rs @@ -65,18 +65,25 @@ use crate::tree::decision_tree_regressor::{ /// Parameters of the Random Forest Regressor /// Some parameters here are passed directly into base estimator. pub struct RandomForestRegressorParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) pub max_depth: Option, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) pub min_samples_leaf: usize, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to split an internal node. 
See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html) pub min_samples_split: usize, + #[cfg_attr(feature = "serde", serde(default))] /// The number of trees in the forest. pub n_trees: usize, + #[cfg_attr(feature = "serde", serde(default))] /// Number of random sample of predictors to use as split candidates. pub m: Option, + #[cfg_attr(feature = "serde", serde(default))] /// Whether to keep samples used for tree generation. This is required for OOB prediction. pub keep_samples: bool, + #[cfg_attr(feature = "serde", serde(default))] /// Seed used for bootstrap sampling and feature selection for each tree. pub seed: u64, } @@ -181,18 +188,25 @@ impl> Predictor for RandomForestReg #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct RandomForestRegressorSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Tree max depth. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub max_depth: Vec>, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to be at a leaf node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_leaf: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The minimum number of samples required to split an internal node. See [Decision Tree Classifier](../../tree/decision_tree_classifier/index.html) pub min_samples_split: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The number of trees in the forest. pub n_trees: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Number of random sample of predictors to use as split candidates. pub m: Vec>, + #[cfg_attr(feature = "serde", serde(default))] /// Whether to keep samples used for tree generation. This is required for OOB prediction. pub keep_samples: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Seed used for bootstrap sampling and feature selection for each tree. pub seed: Vec, } diff --git a/src/linear/elastic_net.rs b/src/linear/elastic_net.rs index 0e9cb57..8ba3287 100644 --- a/src/linear/elastic_net.rs +++ b/src/linear/elastic_net.rs @@ -71,16 +71,21 @@ use crate::linear::lasso_optimizer::InteriorPointOptimizer; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct ElasticNetParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub alpha: T, + #[cfg_attr(feature = "serde", serde(default))] /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1. /// For l1_ratio = 0 the penalty is an L2 penalty. /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2. pub l1_ratio: T, + #[cfg_attr(feature = "serde", serde(default))] /// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation. pub normalize: bool, + #[cfg_attr(feature = "serde", serde(default))] /// The tolerance for the optimization pub tol: T, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum number of iterations pub max_iter: usize, } @@ -139,16 +144,21 @@ impl Default for ElasticNetParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct ElasticNetSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub alpha: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The elastic net mixing parameter, with 0 <= l1_ratio <= 1. 
/// For l1_ratio = 0 the penalty is an L2 penalty. /// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2. pub l1_ratio: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// If True, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation. pub normalize: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The tolerance for the optimization pub tol: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum number of iterations pub max_iter: Vec, } diff --git a/src/linear/lasso.rs b/src/linear/lasso.rs index aae7e50..d1445a0 100644 --- a/src/linear/lasso.rs +++ b/src/linear/lasso.rs @@ -38,13 +38,17 @@ use crate::math::num::RealNumber; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct LassoParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Controls the strength of the penalty to the loss function. pub alpha: T, + #[cfg_attr(feature = "serde", serde(default))] /// If true the regressors X will be normalized before regression /// by subtracting the mean and dividing by the standard deviation. pub normalize: bool, + #[cfg_attr(feature = "serde", serde(default))] /// The tolerance for the optimization pub tol: T, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum number of iterations pub max_iter: usize, } @@ -116,13 +120,17 @@ impl> Predictor for Lasso { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct LassoSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Controls the strength of the penalty to the loss function. pub alpha: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// If true the regressors X will be normalized before regression /// by subtracting the mean and dividing by the standard deviation. pub normalize: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The tolerance for the optimization pub tol: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// The maximum number of iterations pub max_iter: Vec, } diff --git a/src/linear/linear_regression.rs b/src/linear/linear_regression.rs index c95e6e1..12769bb 100644 --- a/src/linear/linear_regression.rs +++ b/src/linear/linear_regression.rs @@ -71,19 +71,21 @@ use crate::linalg::Matrix; use crate::math::num::RealNumber; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Default, Clone, Eq, PartialEq)] /// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable. pub enum LinearRegressionSolverName { /// QR decomposition, see [QR](../../linalg/qr/index.html) QR, + #[default] /// SVD decomposition, see [SVD](../../linalg/svd/index.html) SVD, } /// Linear Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Debug, Clone)] +#[derive(Debug, Default, Clone)] pub struct LinearRegressionParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Solver to use for estimation of regression coefficients. 
pub solver: LinearRegressionSolverName, } @@ -105,18 +107,11 @@ impl LinearRegressionParameters { } } -impl Default for LinearRegressionParameters { - fn default() -> Self { - LinearRegressionParameters { - solver: LinearRegressionSolverName::SVD, - } - } -} - /// Linear Regression grid search parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct LinearRegressionSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Solver to use for estimation of regression coefficients. pub solver: Vec, } @@ -353,5 +348,9 @@ mod tests { serde_json::from_str(&serde_json::to_string(&lr).unwrap()).unwrap(); assert_eq!(lr, deserialized_lr); + + let default = LinearRegressionParameters::default(); + let parameters: LinearRegressionParameters = serde_json::from_str("{}").unwrap(); + assert_eq!(parameters.solver, default.solver); } } diff --git a/src/linear/logistic_regression.rs b/src/linear/logistic_regression.rs index 3a4c706..e8fd01f 100644 --- a/src/linear/logistic_regression.rs +++ b/src/linear/logistic_regression.rs @@ -75,12 +75,20 @@ pub enum LogisticRegressionSolverName { LBFGS, } +impl Default for LogisticRegressionSolverName { + fn default() -> Self { + LogisticRegressionSolverName::LBFGS + } +} + /// Logistic Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct LogisticRegressionParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Solver to use for estimation of regression coefficients. pub solver: LogisticRegressionSolverName, + #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub alpha: T, } @@ -89,8 +97,10 @@ pub struct LogisticRegressionParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct LogisticRegressionSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Solver to use for estimation of regression coefficients. pub solver: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub alpha: Vec, } @@ -204,7 +214,7 @@ impl LogisticRegressionParameters { impl Default for LogisticRegressionParameters { fn default() -> Self { LogisticRegressionParameters { - solver: LogisticRegressionSolverName::LBFGS, + solver: LogisticRegressionSolverName::default(), alpha: T::zero(), } } diff --git a/src/linear/ridge_regression.rs b/src/linear/ridge_regression.rs index 4c3d4ff..396953d 100644 --- a/src/linear/ridge_regression.rs +++ b/src/linear/ridge_regression.rs @@ -77,6 +77,12 @@ pub enum RidgeRegressionSolverName { SVD, } +impl Default for RidgeRegressionSolverName { + fn default() -> Self { + RidgeRegressionSolverName::Cholesky + } +} + /// Ridge Regression parameters #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] @@ -94,10 +100,13 @@ pub struct RidgeRegressionParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct RidgeRegressionSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Solver to use for estimation of regression coefficients. pub solver: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Regularization parameter. pub alpha: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// If true the regressors X will be normalized before regression /// by subtracting the mean and dividing by the standard deviation. 
pub normalize: Vec, @@ -204,7 +213,7 @@ impl RidgeRegressionParameters { impl Default for RidgeRegressionParameters { fn default() -> Self { RidgeRegressionParameters { - solver: RidgeRegressionSolverName::Cholesky, + solver: RidgeRegressionSolverName::default(), alpha: T::one(), normalize: true, } diff --git a/src/naive_bayes/bernoulli.rs b/src/naive_bayes/bernoulli.rs index 29c6c84..d71197e 100644 --- a/src/naive_bayes/bernoulli.rs +++ b/src/naive_bayes/bernoulli.rs @@ -114,10 +114,13 @@ impl> NBDistribution for BernoulliNBDistributi #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct BernoulliNBParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, + #[cfg_attr(feature = "serde", serde(default))] /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors. pub binarize: Option, } @@ -154,10 +157,13 @@ impl Default for BernoulliNBParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct BernoulliNBSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Vec>>, + #[cfg_attr(feature = "serde", serde(default))] /// Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors. pub binarize: Vec>, } diff --git a/src/naive_bayes/categorical.rs b/src/naive_bayes/categorical.rs index 7855688..9cda7a8 100644 --- a/src/naive_bayes/categorical.rs +++ b/src/naive_bayes/categorical.rs @@ -243,6 +243,7 @@ impl CategoricalNBDistribution { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct CategoricalNBParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, } @@ -265,6 +266,7 @@ impl Default for CategoricalNBParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct CategoricalNBSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: Vec, } diff --git a/src/naive_bayes/gaussian.rs b/src/naive_bayes/gaussian.rs index 24bbdd3..37aeb0f 100644 --- a/src/naive_bayes/gaussian.rs +++ b/src/naive_bayes/gaussian.rs @@ -78,6 +78,7 @@ impl> NBDistribution for GaussianNBDistributio #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct GaussianNBParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, } @@ -100,6 +101,7 @@ impl Default for GaussianNBParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct GaussianNBSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. 
If specified the priors are not adjusted according to the data pub priors: Vec>>, } diff --git a/src/naive_bayes/multinomial.rs b/src/naive_bayes/multinomial.rs index 6e846c1..8119fa9 100644 --- a/src/naive_bayes/multinomial.rs +++ b/src/naive_bayes/multinomial.rs @@ -86,8 +86,10 @@ impl> NBDistribution for MultinomialNBDistribu #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct MultinomialNBParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: T, + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Option>, } @@ -118,8 +120,10 @@ impl Default for MultinomialNBParameters { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct MultinomialNBSearchParameters { + #[cfg_attr(feature = "serde", serde(default))] /// Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing). pub alpha: Vec, + #[cfg_attr(feature = "serde", serde(default))] /// Prior probabilities of the classes. If specified the priors are not adjusted according to the data pub priors: Vec>>, } diff --git a/src/neighbors/knn_classifier.rs b/src/neighbors/knn_classifier.rs index 8723900..5e34ce7 100644 --- a/src/neighbors/knn_classifier.rs +++ b/src/neighbors/knn_classifier.rs @@ -49,16 +49,21 @@ use crate::neighbors::KNNWeightFunction; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct KNNClassifierParameters, T>> { + #[cfg_attr(feature = "serde", serde(default))] /// a function that defines a distance between each pair of point in training data. /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait. /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions. pub distance: D, + #[cfg_attr(feature = "serde", serde(default))] /// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default. pub algorithm: KNNAlgorithmName, + #[cfg_attr(feature = "serde", serde(default))] /// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`. pub weight: KNNWeightFunction, + #[cfg_attr(feature = "serde", serde(default))] /// number of training samples to consider when estimating class for new point. Default value is 3. pub k: usize, + #[cfg_attr(feature = "serde", serde(default))] /// this parameter is not used t: PhantomData, } @@ -111,8 +116,8 @@ impl Default for KNNClassifierParameters { fn default() -> Self { KNNClassifierParameters { distance: Distances::euclidian(), - algorithm: KNNAlgorithmName::CoverTree, - weight: KNNWeightFunction::Uniform, + algorithm: KNNAlgorithmName::default(), + weight: KNNWeightFunction::default(), k: 3, t: PhantomData, } diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs index 649cd1f..8fdda3d 100644 --- a/src/neighbors/knn_regressor.rs +++ b/src/neighbors/knn_regressor.rs @@ -52,16 +52,21 @@ use crate::neighbors::KNNWeightFunction; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] pub struct KNNRegressorParameters, T>> { + #[cfg_attr(feature = "serde", serde(default))] /// a function that defines a distance between each pair of point in training data. 
diff --git a/src/neighbors/knn_regressor.rs b/src/neighbors/knn_regressor.rs
index 649cd1f..8fdda3d 100644
--- a/src/neighbors/knn_regressor.rs
+++ b/src/neighbors/knn_regressor.rs
@@ -52,16 +52,21 @@ use crate::neighbors::KNNWeightFunction;
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 pub struct KNNRegressorParameters, T>> {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// a function that defines a distance between each pair of point in training data.
     /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
     /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
     distance: D,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// backend search algorithm. See [`knn search algorithms`](../../algorithm/neighbour/index.html). `CoverTree` is default.
     pub algorithm: KNNAlgorithmName,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// weighting function that is used to calculate estimated class value. Default function is `KNNWeightFunction::Uniform`.
     pub weight: KNNWeightFunction,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// number of training samples to consider when estimating class for new point. Default value is 3.
     pub k: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// this parameter is not used
     t: PhantomData,
 }
@@ -113,8 +118,8 @@ impl Default for KNNRegressorParameters {
     fn default() -> Self {
         KNNRegressorParameters {
             distance: Distances::euclidian(),
-            algorithm: KNNAlgorithmName::CoverTree,
-            weight: KNNWeightFunction::Uniform,
+            algorithm: KNNAlgorithmName::default(),
+            weight: KNNWeightFunction::default(),
             k: 3,
             t: PhantomData,
         }
diff --git a/src/neighbors/mod.rs b/src/neighbors/mod.rs
index 86b1e46..5a713ab 100644
--- a/src/neighbors/mod.rs
+++ b/src/neighbors/mod.rs
@@ -58,6 +58,12 @@ pub enum KNNWeightFunction {
     Distance,
 }
 
+impl Default for KNNWeightFunction {
+    fn default() -> Self {
+        KNNWeightFunction::Uniform
+    }
+}
+
 impl KNNWeightFunction {
     fn calc_weights(&self, distances: Vec) -> std::vec::Vec {
         match *self {
diff --git a/src/svm/svc.rs b/src/svm/svc.rs
index d390866..97b91de 100644
--- a/src/svm/svc.rs
+++ b/src/svm/svc.rs
@@ -91,16 +91,22 @@ use crate::svm::{Kernel, Kernels, LinearKernel};
 #[derive(Debug, Clone)]
 /// SVC Parameters
 pub struct SVCParameters, K: Kernel> {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Number of epochs.
     pub epoch: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Regularization parameter.
     pub c: T,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Tolerance for stopping criterion.
     pub tol: T,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The kernel function.
     pub kernel: K,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Unused parameter.
     m: PhantomData,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Controls the pseudo random number generation for shuffling the data for probability estimates
     seed: Option,
 }
@@ -109,16 +115,22 @@ pub struct SVCParameters, K: Kernel
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 pub struct SVCSearchParameters, K: Kernel> {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Number of epochs.
     pub epoch: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Regularization parameter.
     pub c: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Tolerance for stopping epoch.
     pub tol: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The kernel function.
     pub kernel: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Unused parameter.
     m: PhantomData,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Controls the pseudo random number generation for shuffling the data for probability estimates
     seed: Vec>,
 }
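Replacing literals such as `KNNAlgorithmName::CoverTree` with `KNNAlgorithmName::default()` in the struct's `Default` impl is more than cosmetic: it makes the enum's own `Default` the single source of truth, so the serde fallback and `Parameters::default()` cannot drift apart. A sketch of the pattern with hypothetical names:

    #[allow(dead_code)]
    #[derive(Debug)]
    enum WeightFunction {
        Uniform,
        Distance,
    }

    impl Default for WeightFunction {
        fn default() -> Self {
            WeightFunction::Uniform
        }
    }

    #[derive(Debug)]
    struct Params {
        weight: WeightFunction,
        k: usize,
    }

    impl Default for Params {
        fn default() -> Self {
            Params {
                // Delegating to the enum means a future change to its
                // default variant propagates here automatically.
                weight: WeightFunction::default(),
                k: 3,
            }
        }
    }

    fn main() {
        println!("{:?}", Params::default()); // Params { weight: Uniform, k: 3 }
    }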
diff --git a/src/tree/decision_tree_classifier.rs b/src/tree/decision_tree_classifier.rs
index acc3fb0..d330fdf 100644
--- a/src/tree/decision_tree_classifier.rs
+++ b/src/tree/decision_tree_classifier.rs
@@ -83,14 +83,19 @@ use crate::rand::get_rng_impl;
 #[derive(Debug, Clone)]
 /// Parameters of Decision Tree
 pub struct DecisionTreeClassifierParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Split criteria to use when building a tree.
     pub criterion: SplitCriterion,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The maximum depth of the tree.
     pub max_depth: Option,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to be at a leaf node.
     pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Controls the randomness of the estimator
     pub seed: Option,
 }
@@ -118,6 +123,12 @@ pub enum SplitCriterion {
     ClassificationError,
 }
 
+impl Default for SplitCriterion {
+    fn default() -> Self {
+        SplitCriterion::Gini
+    }
+}
+
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug)]
 struct Node {
@@ -196,7 +207,7 @@ impl DecisionTreeClassifierParameters {
 impl Default for DecisionTreeClassifierParameters {
     fn default() -> Self {
         DecisionTreeClassifierParameters {
-            criterion: SplitCriterion::Gini,
+            criterion: SplitCriterion::default(),
             max_depth: None,
             min_samples_leaf: 1,
             min_samples_split: 2,
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index 12bb9c9..c745a0d 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -78,12 +78,16 @@ use crate::rand::get_rng_impl;
 #[derive(Debug, Clone)]
 /// Parameters of Regression Tree
 pub struct DecisionTreeRegressorParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The maximum depth of the tree.
     pub max_depth: Option,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to be at a leaf node.
     pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to split an internal node.
     pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Controls the randomness of the estimator
     pub seed: Option,
 }
@@ -142,12 +146,16 @@ impl Default for DecisionTreeRegressorParameters {
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
 pub struct DecisionTreeRegressorSearchParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
     pub max_depth: Vec>,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
     pub min_samples_leaf: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
     pub min_samples_split: Vec,
+    #[cfg_attr(feature = "serde", serde(default))]
     /// Controls the randomness of the estimator
     pub seed: Vec>,
 }
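One last observation on the tree parameters: the structs keep hand-written `Default` impls rather than `#[derive(Default)]`, because the documented defaults (`min_samples_leaf: 1`, `min_samples_split: 2`) are not the zero values a derive would produce. A usage sketch with a stand-in struct (not the crate's real type):

    #[derive(Debug, Clone)]
    struct TreeParams {
        max_depth: Option<u16>,
        min_samples_leaf: usize,
        min_samples_split: usize,
    }

    impl Default for TreeParams {
        fn default() -> Self {
            TreeParams {
                max_depth: None,
                min_samples_leaf: 1,  // a derived Default would give 0 here
                min_samples_split: 2, // and here
            }
        }
    }

    fn main() {
        // Override one field, take the documented defaults for the rest.
        let params = TreeParams {
            max_depth: Some(5),
            ..TreeParams::default()
        };
        assert_eq!(params.min_samples_split, 2);
        println!("{:?}", params);
    }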