Compare commits
6 Commits
single_linkage
...
v0.4.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09be4681cf | ||
|
|
4841791b7e | ||
|
|
9fef05ecc6 | ||
|
|
c5816b0e1b | ||
|
|
5cc5528367 | ||
|
|
d459c48372 |
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.4.2"
|
version = "0.4.3"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -0,0 +1,777 @@
|
|||||||
|
///
|
||||||
|
/// ### CosinePair: Data-structure for the dynamic closest-pair problem.
|
||||||
|
///
|
||||||
|
/// Reference:
|
||||||
|
/// Eppstein, David: Fast hierarchical clustering and other applications of
|
||||||
|
/// dynamic closest pairs. Journal of Experimental Algorithmics 5 (2000) 1.
|
||||||
|
///
|
||||||
|
/// Example:
|
||||||
|
/// ```
|
||||||
|
/// use smartcore::metrics::distance::PairwiseDistance;
|
||||||
|
/// use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
/// use smartcore::algorithm::neighbour::cosinepair::CosinePair;
|
||||||
|
/// let x = DenseMatrix::<f64>::from_2d_array(&[
|
||||||
|
/// &[5.1, 3.5, 1.4, 0.2],
|
||||||
|
/// &[4.9, 3.0, 1.4, 0.2],
|
||||||
|
/// &[4.7, 3.2, 1.3, 0.2],
|
||||||
|
/// &[4.6, 3.1, 1.5, 0.2],
|
||||||
|
/// &[5.0, 3.6, 1.4, 0.2],
|
||||||
|
/// &[5.4, 3.9, 1.7, 0.4],
|
||||||
|
/// ]).unwrap();
|
||||||
|
/// let cosinepair = CosinePair::new(&x);
|
||||||
|
/// let closest_pair: PairwiseDistance<f64> = cosinepair.unwrap().closest_pair();
|
||||||
|
/// ```
|
||||||
|
/// <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
/// <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use num::Bounded;
|
||||||
|
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::metrics::distance::cosine::Cosine;
|
||||||
|
use crate::metrics::distance::{Distance, PairwiseDistance};
|
||||||
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
|
///
/// Inspired by Python implementation:
/// <https://github.com/carsonfarmer/fastpair/blob/b8b4d3000ab6f795a878936667eee1b557bf353d/fastpair/base.py>
/// MIT License (MIT) Copyright (c) 2016 Carson Farmer
///
/// affinity used is Cosine as it is the most used
///
#[derive(Debug, Clone)]
pub struct CosinePair<'a, T: RealNumber + FloatNumber, M: Array2<T>> {
    /// initial matrix (borrowed for the lifetime of the structure)
    pub samples: &'a M,
    /// closest pair hashmap (connectivity matrix for closest pairs),
    /// keyed by row index of `samples`
    pub distances: HashMap<usize, PairwiseDistance<T>>,
    /// conga line used to keep track of the closest pair (row indices)
    pub neighbours: Vec<usize>,
}
|
||||||
|
|
||||||
|
impl<'a, T: RealNumber + FloatNumber, M: Array2<T>> CosinePair<'a, T, M> {
|
||||||
|
/// Constructor
|
||||||
|
/// Instantiate and initialize the algorithm
|
||||||
|
pub fn new(m: &'a M) -> Result<Self, Failed> {
|
||||||
|
if m.shape().0 < 2 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"min number of rows should be 2",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut init = Self {
|
||||||
|
samples: m,
|
||||||
|
// to be computed in init(..)
|
||||||
|
distances: HashMap::with_capacity(m.shape().0),
|
||||||
|
neighbours: Vec::with_capacity(m.shape().0 + 1),
|
||||||
|
};
|
||||||
|
init.init();
|
||||||
|
Ok(init)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initialise `CosinePair` by passing a `Array2`.
|
||||||
|
/// Build a CosinePairs data-structure from a set of (new) points.
|
||||||
|
fn init(&mut self) {
|
||||||
|
// basic measures
|
||||||
|
let len = self.samples.shape().0;
|
||||||
|
let max_index = self.samples.shape().0 - 1;
|
||||||
|
|
||||||
|
// Store all closest neighbors
|
||||||
|
let _distances = Box::new(HashMap::with_capacity(len));
|
||||||
|
let _neighbours = Box::new(Vec::with_capacity(len));
|
||||||
|
|
||||||
|
let mut distances = *_distances;
|
||||||
|
let mut neighbours = *_neighbours;
|
||||||
|
|
||||||
|
// fill neighbours with -1 values
|
||||||
|
neighbours.extend(0..len);
|
||||||
|
|
||||||
|
// init closest neighbour pairwise data
|
||||||
|
for index_row_i in 0..(max_index) {
|
||||||
|
distances.insert(
|
||||||
|
index_row_i,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_i,
|
||||||
|
neighbour: Option::None,
|
||||||
|
distance: Some(<T as Bounded>::max_value()),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// loop through indeces and neighbours
|
||||||
|
for index_row_i in 0..(len) {
|
||||||
|
// start looking for the neighbour in the second element
|
||||||
|
let mut index_closest = index_row_i + 1; // closest neighbour index
|
||||||
|
let mut nbd: Option<T> = distances[&index_row_i].distance; // init neighbour distance
|
||||||
|
for index_row_j in (index_row_i + 1)..len {
|
||||||
|
distances.insert(
|
||||||
|
index_row_j,
|
||||||
|
PairwiseDistance {
|
||||||
|
node: index_row_j,
|
||||||
|
neighbour: Some(index_row_i),
|
||||||
|
distance: nbd,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let d = Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row_i).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row_j).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
if d < nbd.unwrap().to_f64().unwrap() {
|
||||||
|
// set this j-value to be the closest neighbour
|
||||||
|
index_closest = index_row_j;
|
||||||
|
nbd = Some(T::from(d).unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add that edge
|
||||||
|
distances.entry(index_row_i).and_modify(|e| {
|
||||||
|
e.distance = nbd;
|
||||||
|
e.neighbour = Some(index_closest);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// No more neighbors, terminate conga line.
|
||||||
|
// Last person on the line has no neigbors
|
||||||
|
distances.get_mut(&max_index).unwrap().neighbour = Some(max_index);
|
||||||
|
distances.get_mut(&(len - 1)).unwrap().distance = Some(<T as Bounded>::max_value());
|
||||||
|
|
||||||
|
// compute sparse matrix (connectivity matrix)
|
||||||
|
let mut sparse_matrix = M::zeros(len, len);
|
||||||
|
for (_, p) in distances.iter() {
|
||||||
|
sparse_matrix.set((p.node, p.neighbour.unwrap()), p.distance.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.distances = distances;
|
||||||
|
self.neighbours = neighbours;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query k nearest neighbors for a row that's already in the dataset
|
||||||
|
pub fn query_row(&self, query_row_index: usize, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
if query_row_index >= self.samples.shape().0 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"Query row index out of bounds",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if k == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get distances to all other points
|
||||||
|
let mut distances = self.distances_from(query_row_index);
|
||||||
|
|
||||||
|
// Sort by distance (ascending)
|
||||||
|
distances.sort_by(|a, b| {
|
||||||
|
a.distance
|
||||||
|
.unwrap()
|
||||||
|
.partial_cmp(&b.distance.unwrap())
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Take top k neighbors and convert to (distance, index) format
|
||||||
|
let neighbors: Vec<(T, usize)> = distances
|
||||||
|
.into_iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|pd| (pd.distance.unwrap(), pd.neighbour.unwrap()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(neighbors)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query k nearest neighbors for an external query vector
|
||||||
|
pub fn query(&self, query_vector: &Vec<T>, k: usize) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
if query_vector.len() != self.samples.shape().1 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::FindFailed,
|
||||||
|
"Query vector dimension mismatch",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if k == 0 {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute distances from query vector to all points in the dataset
|
||||||
|
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||||
|
|
||||||
|
for i in 0..self.samples.shape().0 {
|
||||||
|
let dataset_point = Vec::from_iterator(
|
||||||
|
self.samples.get_row(i).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let distance = T::from(Cosine::new().distance(query_vector, &dataset_point)).unwrap();
|
||||||
|
|
||||||
|
distances.push(PairwiseDistance {
|
||||||
|
node: i, // This represents the dataset point index
|
||||||
|
neighbour: Some(i),
|
||||||
|
distance: Some(distance),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by distance (ascending)
|
||||||
|
distances.sort_by(|a, b| {
|
||||||
|
a.distance
|
||||||
|
.unwrap()
|
||||||
|
.partial_cmp(&b.distance.unwrap())
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
// Take top k neighbors and convert to (distance, index) format
|
||||||
|
let neighbors: Vec<(T, usize)> = distances
|
||||||
|
.into_iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|pd| (pd.distance.unwrap(), pd.node))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(neighbors)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Optimized version that reuses the existing distances_from method
|
||||||
|
/// This is more efficient for queries that are points already in the dataset
|
||||||
|
pub fn query_optimized(
|
||||||
|
&self,
|
||||||
|
query_row_index: usize,
|
||||||
|
k: usize,
|
||||||
|
) -> Result<Vec<(T, usize)>, Failed> {
|
||||||
|
// Reuse existing method and sort the results
|
||||||
|
self.query_row(query_row_index, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find closest pair by scanning list of nearest neighbors.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn closest_pair(&self) -> PairwiseDistance<T> {
|
||||||
|
let mut a = self.neighbours[0]; // Start with first point
|
||||||
|
let mut d = self.distances[&a].distance;
|
||||||
|
for p in self.neighbours.iter() {
|
||||||
|
if self.distances[p].distance < d {
|
||||||
|
a = *p; // Update `a` and distance `d`
|
||||||
|
d = self.distances[p].distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let b = self.distances[&a].neighbour;
|
||||||
|
PairwiseDistance {
|
||||||
|
node: a,
|
||||||
|
neighbour: b,
|
||||||
|
distance: d,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///
|
||||||
|
/// Return order dissimilarities from closest to furthest
|
||||||
|
///
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn ordered_pairs(&self) -> std::vec::IntoIter<&PairwiseDistance<T>> {
|
||||||
|
// improvement: implement this to return `impl Iterator<Item = &PairwiseDistance<T>>`
|
||||||
|
// need to implement trait `Iterator` for `Vec<&PairwiseDistance<T>>`
|
||||||
|
let mut distances = self
|
||||||
|
.distances
|
||||||
|
.values()
|
||||||
|
.collect::<Vec<&PairwiseDistance<T>>>();
|
||||||
|
distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||||
|
distances.into_iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Compute distances from input to all other points in data-structure.
|
||||||
|
// input is the row index of the sample matrix
|
||||||
|
//
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn distances_from(&self, index_row: usize) -> Vec<PairwiseDistance<T>> {
|
||||||
|
let mut distances = Vec::<PairwiseDistance<T>>::with_capacity(self.samples.shape().0);
|
||||||
|
for other in self.neighbours.iter() {
|
||||||
|
if index_row != *other {
|
||||||
|
distances.push(PairwiseDistance {
|
||||||
|
node: index_row,
|
||||||
|
neighbour: Some(*other),
|
||||||
|
distance: Some(
|
||||||
|
T::from(Cosine::new().distance(
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(index_row).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
&Vec::from_iterator(
|
||||||
|
self.samples.get_row(*other).iterator(0).copied(),
|
||||||
|
self.samples.shape().1,
|
||||||
|
),
|
||||||
|
))
|
||||||
|
.unwrap(),
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
distances
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unit tests for `CosinePair`: construction, error paths, closest-pair
// search, ordering, querying, and a brute-force cross-check.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};
    use approx::assert_relative_eq;

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_initialization() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x);

        assert!(cosine_pair.is_ok());
        let cp = cosine_pair.unwrap();

        // every internal structure holds one entry per row
        assert_eq!(cp.samples.shape().0, 6);
        assert_eq!(cp.distances.len(), 6);
        assert_eq!(cp.neighbours.len(), 6);
        assert!(!cp.distances.is_empty());
        assert!(!cp.neighbours.is_empty());
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_minimum_rows_error() {
        // Test with only one row - should fail
        let x = DenseMatrix::<f64>::from_2d_array(&[&[5.1, 3.5, 1.4, 0.2]]).unwrap();

        let result = CosinePair::new(&x);
        assert!(result.is_err());

        if let Err(e) = result {
            let expected_error =
                Failed::because(FailedError::FindFailed, "min number of rows should be 2");
            assert_eq!(e, expected_error);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_closest_pair() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 0.0],
            &[0.0, 1.0],
            &[1.0, 1.0],
            &[2.0, 2.0], // This should be closest to [1.0, 1.0] with cosine distance
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let closest_pair = cosine_pair.closest_pair();

        // Verify structure
        assert!(closest_pair.distance.is_some());
        assert!(closest_pair.neighbour.is_some());

        // The closest pair should have the smallest cosine distance
        let distance = closest_pair.distance.unwrap();
        assert!(distance >= 0.0 && distance <= 2.0); // Cosine distance range
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_identical_vectors() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 2.0, 3.0],
            &[1.0, 2.0, 3.0], // Identical vector
            &[4.0, 5.0, 6.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let closest_pair = cosine_pair.closest_pair();

        // Distance between identical vectors should be 0
        let distance = closest_pair.distance.unwrap();
        assert!((distance - 0.0).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_orthogonal_vectors() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 0.0],
            &[0.0, 1.0], // Orthogonal to first
            &[2.0, 3.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Check that orthogonal vectors have cosine distance of 1.0
        let distances_from_first = cosine_pair.distances_from(0);
        let orthogonal_distance = distances_from_first
            .iter()
            .find(|pd| pd.neighbour == Some(1))
            .unwrap()
            .distance
            .unwrap();

        assert!((orthogonal_distance - 1.0).abs() < 1e-8);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_ordered_pairs() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 2.0],
            &[2.0, 1.0],
            &[3.0, 4.0],
            &[4.0, 3.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let ordered_pairs: Vec<_> = cosine_pair.ordered_pairs().collect();

        assert_eq!(ordered_pairs.len(), 4);

        // Check that pairs are ordered by distance (ascending)
        for i in 1..ordered_pairs.len() {
            let prev_distance = ordered_pairs[i - 1].distance.unwrap();
            let curr_distance = ordered_pairs[i].distance.unwrap();
            assert!(prev_distance <= curr_distance);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_query_row() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 0.0, 0.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 0.0, 1.0],
            &[1.0, 1.0, 0.0],
            &[0.0, 1.0, 1.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Query k=2 nearest neighbors for row 0
        let neighbors = cosine_pair.query_row(0, 2).unwrap();

        assert_eq!(neighbors.len(), 2);

        // Check that distances are in ascending order
        assert!(neighbors[0].0 <= neighbors[1].0);

        // All distances should be valid cosine distances (0 to 2)
        for (distance, _) in &neighbors {
            assert!(*distance >= 0.0 && *distance <= 2.0);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_query_row_bounds_error() {
        let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Query with out-of-bounds row index
        let result = cosine_pair.query_row(5, 1);
        assert!(result.is_err());

        if let Err(e) = result {
            let expected_error =
                Failed::because(FailedError::FindFailed, "Query row index out of bounds");
            assert_eq!(e, expected_error);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_query_row_k_zero() {
        let x =
            DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]).unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        // k = 0 is a valid request for an empty result
        let neighbors = cosine_pair.query_row(0, 0).unwrap();

        assert_eq!(neighbors.len(), 0);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_query_external_vector() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 0.0, 0.0],
            &[0.0, 1.0, 0.0],
            &[0.0, 0.0, 1.0],
            &[1.0, 1.0, 0.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Query with external vector
        let query_vector = vec![1.0, 0.5, 0.0];
        let neighbors = cosine_pair.query(&query_vector, 2).unwrap();

        assert_eq!(neighbors.len(), 2);

        // Verify distances are valid and ordered
        assert!(neighbors[0].0 <= neighbors[1].0);
        for (distance, index) in &neighbors {
            assert!(*distance >= 0.0 && *distance <= 2.0);
            assert!(*index < x.shape().0);
        }
    }

    #[test]
    fn cosine_pair_query_dimension_mismatch() {
        let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Query with mismatched dimensions
        let query_vector = vec![1.0, 2.0]; // Only 2 dimensions, but data has 3
        let result = cosine_pair.query(&query_vector, 1);

        assert!(result.is_err());
        if let Err(e) = result {
            let expected_error =
                Failed::because(FailedError::FindFailed, "Query vector dimension mismatch");
            assert_eq!(e, expected_error);
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_query_k_zero_external() {
        let x = DenseMatrix::<f64>::from_2d_array(&[&[1.0, 2.0], &[3.0, 4.0]]).unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let query_vector = vec![1.0, 1.0];
        let neighbors = cosine_pair.query(&query_vector, 0).unwrap();

        assert_eq!(neighbors.len(), 0);
    }

    #[test]
    fn cosine_pair_large_dataset() {
        // Test with larger dataset (similar to Iris)
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[5.1, 3.5, 1.4, 0.2],
            &[4.9, 3.0, 1.4, 0.2],
            &[4.7, 3.2, 1.3, 0.2],
            &[4.6, 3.1, 1.5, 0.2],
            &[5.0, 3.6, 1.4, 0.2],
            &[5.4, 3.9, 1.7, 0.4],
            &[4.6, 3.4, 1.4, 0.3],
            &[5.0, 3.4, 1.5, 0.2],
            &[4.4, 2.9, 1.4, 0.2],
            &[4.9, 3.1, 1.5, 0.1],
            &[7.0, 3.2, 4.7, 1.4],
            &[6.4, 3.2, 4.5, 1.5],
            &[6.9, 3.1, 4.9, 1.5],
            &[5.5, 2.3, 4.0, 1.3],
            &[6.5, 2.8, 4.6, 1.5],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        assert_eq!(cosine_pair.samples.shape().0, 15);
        assert_eq!(cosine_pair.distances.len(), 15);
        assert_eq!(cosine_pair.neighbours.len(), 15);

        // Test closest pair computation
        let closest_pair = cosine_pair.closest_pair();
        assert!(closest_pair.distance.is_some());
        assert!(closest_pair.neighbour.is_some());

        let distance = closest_pair.distance.unwrap();
        assert!(distance >= 0.0 && distance <= 2.0);
    }

    #[test]
    fn cosine_pair_float_precision() {
        // Test with f32 precision
        let x = DenseMatrix::<f32>::from_2d_array(&[
            &[1.0f32, 2.0, 3.0],
            &[4.0f32, 5.0, 6.0],
            &[7.0f32, 8.0, 9.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let closest_pair = cosine_pair.closest_pair();

        assert!(closest_pair.distance.is_some());
        let distance = closest_pair.distance.unwrap();
        assert!(distance >= 0.0 && distance <= 2.0);

        // Test querying (exact f32 values pin the expected ordering)
        let neighbors = cosine_pair.query_row(0, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert_eq!(neighbors[0].1, 1);
        assert_relative_eq!(neighbors[0].0, 0.025368154);
        assert_eq!(neighbors[1].1, 2);
        assert_relative_eq!(neighbors[1].0, 0.040588055);
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_distances_from() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 0.0],
            &[0.0, 1.0],
            &[1.0, 1.0],
            &[2.0, 0.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let distances = cosine_pair.distances_from(0);

        // Should have 3 distances (excluding self)
        assert_eq!(distances.len(), 3);

        // All should be from node 0
        for pd in &distances {
            assert_eq!(pd.node, 0);
            assert!(pd.neighbour.is_some());
            assert!(pd.distance.is_some());
            assert!(pd.neighbour.unwrap() != 0); // Should not include self
        }
    }

    #[cfg_attr(
        all(target_arch = "wasm32", not(target_os = "wasi")),
        wasm_bindgen_test::wasm_bindgen_test
    )]
    #[test]
    fn cosine_pair_consistency_check() {
        // Verify that different query methods return consistent results
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 2.0, 3.0],
            &[4.0, 5.0, 6.0],
            &[7.0, 8.0, 9.0],
            &[2.0, 3.0, 4.0],
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();

        // Query row 0 using internal method
        let neighbors_internal = cosine_pair.query_row(0, 2).unwrap();

        // Query row 0 using optimized method (should be same)
        let neighbors_optimized = cosine_pair.query_optimized(0, 2).unwrap();

        assert_eq!(neighbors_internal.len(), neighbors_optimized.len());
        for i in 0..neighbors_internal.len() {
            let (dist1, idx1) = neighbors_internal[i];
            let (dist2, idx2) = neighbors_optimized[i];
            assert!((dist1 - dist2).abs() < 1e-10);
            assert_eq!(idx1, idx2);
        }
    }

    // Brute force algorithm for testing/comparison:
    // examines every unordered pair of rows, so it is O(n^2) but trivially correct.
    fn closest_pair_brute_force(
        cosine_pair: &CosinePair<'_, f64, DenseMatrix<f64>>,
    ) -> PairwiseDistance<f64> {
        use itertools::Itertools;

        let m = cosine_pair.samples.shape().0;
        let mut closest_pair = PairwiseDistance {
            node: 0,
            neighbour: None,
            distance: Some(f64::MAX),
        };

        for pair in (0..m).combinations(2) {
            let d = Cosine::new().distance(
                &Vec::from_iterator(
                    cosine_pair.samples.get_row(pair[0]).iterator(0).copied(),
                    cosine_pair.samples.shape().1,
                ),
                &Vec::from_iterator(
                    cosine_pair.samples.get_row(pair[1]).iterator(0).copied(),
                    cosine_pair.samples.shape().1,
                ),
            );

            if d < closest_pair.distance.unwrap() {
                closest_pair.node = pair[0];
                closest_pair.neighbour = Some(pair[1]);
                closest_pair.distance = Some(d);
            }
        }

        closest_pair
    }

    #[test]
    fn cosine_pair_vs_brute_force() {
        let x = DenseMatrix::<f64>::from_2d_array(&[
            &[1.0, 2.0, 3.0],
            &[4.0, 5.0, 6.0],
            &[7.0, 8.0, 9.0],
            &[1.1, 2.1, 3.1], // Close to first point
        ])
        .unwrap();

        let cosine_pair = CosinePair::new(&x).unwrap();
        let cp_result = cosine_pair.closest_pair();
        let brute_result = closest_pair_brute_force(&cosine_pair);

        // Results should be identical or very close
        assert!((cp_result.distance.unwrap() - brute_result.distance.unwrap()).abs() < 1e-10);
    }
}
|
||||||
@@ -39,6 +39,8 @@ use crate::numbers::basenum::Number;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub(crate) mod bbd_tree;
|
pub(crate) mod bbd_tree;
|
||||||
|
/// a variant of fastpair using cosine distance
|
||||||
|
pub mod cosinepair;
|
||||||
/// tree data structure for fast nearest neighbor search
|
/// tree data structure for fast nearest neighbor search
|
||||||
pub mod cover_tree;
|
pub mod cover_tree;
|
||||||
/// fastpair closest neighbour algorithm
|
/// fastpair closest neighbour algorithm
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use num_traits::Num;
|
use num_traits::Num;
|
||||||
|
|
||||||
pub trait QuickArgSort {
|
pub trait QuickArgSort {
|
||||||
|
#[allow(dead_code)]
|
||||||
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
fn quick_argsort_mut(&mut self) -> Vec<usize>;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
|||||||
@@ -0,0 +1,317 @@
|
|||||||
|
//! # Agglomerative Hierarchical Clustering
|
||||||
|
//!
|
||||||
|
//! Agglomerative clustering is a "bottom-up" hierarchical clustering method. It works by placing each data point in its own cluster and then successively merging the two most similar clusters until a stopping criterion is met. This process creates a tree-based hierarchy of clusters known as a dendrogram.
|
||||||
|
//!
|
||||||
|
//! The similarity of two clusters is determined by a **linkage criterion**. This implementation uses **single-linkage**, where the distance between two clusters is defined as the minimum distance between any single point in the first cluster and any single point in the second cluster. The distance between points is the standard Euclidean distance.
|
||||||
|
//!
|
||||||
|
//! The algorithm first builds the full hierarchy of `N-1` merges. To obtain a specific number of clusters, `n_clusters`, the algorithm then effectively "cuts" the dendrogram at the point where `n_clusters` remain.
|
||||||
|
//!
|
||||||
|
//! ## Example:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
//! use smartcore::cluster::agglomerative::{AgglomerativeClustering, AgglomerativeClusteringParameters};
|
||||||
|
//!
|
||||||
|
//! // A dataset with 2 distinct groups of points.
|
||||||
|
//! let x = DenseMatrix::from_2d_array(&[
|
||||||
|
//! &[0.0, 0.0], &[1.0, 1.0], &[0.5, 0.5], // Cluster A
|
||||||
|
//! &[10.0, 10.0], &[11.0, 11.0], &[10.5, 10.5], // Cluster B
|
||||||
|
//! ]).unwrap();
|
||||||
|
//!
|
||||||
|
//! // Set parameters to find 2 clusters.
|
||||||
|
//! let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
|
||||||
|
//!
|
||||||
|
//! // Fit the model to the data.
|
||||||
|
//! let clustering = AgglomerativeClustering::<f64, usize, DenseMatrix<f64>, Vec<usize>>::fit(&x, parameters).unwrap();
|
||||||
|
//!
|
||||||
|
//! // Get the cluster assignments.
|
||||||
|
//! let labels = clustering.labels; // e.g., [0, 0, 0, 1, 1, 1]
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! ## References:
|
||||||
|
//!
|
||||||
|
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.2 Hierarchical Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||||
|
//! * ["The Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 14.3.12 Hierarchical Clustering](https://hastie.su.domains/ElemStatLearn/)
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use crate::api::UnsupervisedEstimator;
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
|
||||||
|
/// Parameters for the Agglomerative Clustering algorithm.
#[derive(Debug, Clone, Copy)]
pub struct AgglomerativeClusteringParameters {
    /// The number of clusters to find.
    /// `fit` rejects values greater than the number of samples.
    pub n_clusters: usize,
}
|
||||||
|
|
||||||
|
impl AgglomerativeClusteringParameters {
|
||||||
|
/// Sets the number of clusters.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `n_clusters` - The desired number of clusters.
|
||||||
|
pub fn with_n_clusters(mut self, n_clusters: usize) -> Self {
|
||||||
|
self.n_clusters = n_clusters;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AgglomerativeClusteringParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
AgglomerativeClusteringParameters { n_clusters: 2 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agglomerative Clustering model.
///
/// This implementation uses single-linkage clustering, which is mathematically
/// equivalent to finding the Minimum Spanning Tree (MST) of the data points.
/// The core logic is an efficient implementation of Kruskal's algorithm, which
/// processes all pairwise distances in increasing order and uses a Disjoint
/// Set Union (DSU) data structure to track cluster membership.
#[derive(Debug)]
pub struct AgglomerativeClustering<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    /// The cluster label assigned to each sample, indexed by row of the
    /// fitted data matrix; labels are contiguous starting at 0.
    pub labels: Vec<usize>,
    // Zero-sized markers that tie the generic parameters to the struct
    // without storing any values of those types.
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}
|
||||||
|
|
||||||
|
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClustering<TX, TY, X, Y> {
|
||||||
|
/// Fits the agglomerative clustering model to the data.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `data` - A reference to the input data matrix.
|
||||||
|
/// * `parameters` - The parameters for the clustering algorithm, including `n_clusters`.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A `Result` containing the fitted model with cluster labels, or an error if
|
||||||
|
pub fn fit(data: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
|
||||||
|
let (num_samples, _) = data.shape();
|
||||||
|
let n_clusters = parameters.n_clusters;
|
||||||
|
if n_clusters > num_samples {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::ParametersError,
|
||||||
|
&format!(
|
||||||
|
"n_clusters: {n_clusters} cannot be greater than n_samples: {num_samples}"
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut distance_pairs = Vec::new();
|
||||||
|
for i in 0..num_samples {
|
||||||
|
for j in (i + 1)..num_samples {
|
||||||
|
let distance: f64 = data
|
||||||
|
.get_row(i)
|
||||||
|
.iterator(0)
|
||||||
|
.zip(data.get_row(j).iterator(0))
|
||||||
|
.map(|(&a, &b)| (a.to_f64().unwrap() - b.to_f64().unwrap()).powi(2))
|
||||||
|
.sum::<f64>();
|
||||||
|
|
||||||
|
distance_pairs.push((distance, i, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
distance_pairs.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
|
||||||
|
let mut parent = HashMap::new();
|
||||||
|
let mut children = HashMap::new();
|
||||||
|
for i in 0..num_samples {
|
||||||
|
parent.insert(i, i);
|
||||||
|
children.insert(i, vec![i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut merge_history = Vec::new();
|
||||||
|
let num_merges_needed = num_samples - 1;
|
||||||
|
|
||||||
|
while merge_history.len() < num_merges_needed {
|
||||||
|
let (_, p1, p2) = distance_pairs.pop().unwrap();
|
||||||
|
|
||||||
|
let root1 = parent[&p1];
|
||||||
|
let root2 = parent[&p2];
|
||||||
|
|
||||||
|
if root1 != root2 {
|
||||||
|
let root2_children = children.remove(&root2).unwrap();
|
||||||
|
for child in root2_children.iter() {
|
||||||
|
parent.insert(*child, root1);
|
||||||
|
}
|
||||||
|
let root1_children = children.get_mut(&root1).unwrap();
|
||||||
|
root1_children.extend(root2_children);
|
||||||
|
merge_history.push((root1, root2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut clusters = HashMap::new();
|
||||||
|
let mut assignments = HashMap::new();
|
||||||
|
|
||||||
|
for i in 0..num_samples {
|
||||||
|
clusters.insert(i, vec![i]);
|
||||||
|
assignments.insert(i, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
let merges_to_apply = num_samples - n_clusters;
|
||||||
|
|
||||||
|
for (root1, root2) in merge_history[0..merges_to_apply].iter() {
|
||||||
|
let root1_cluster = assignments[root1];
|
||||||
|
let root2_cluster = assignments[root2];
|
||||||
|
|
||||||
|
let root2_assignments = clusters.remove(&root2_cluster).unwrap();
|
||||||
|
for assignment in root2_assignments.iter() {
|
||||||
|
assignments.insert(*assignment, root1_cluster);
|
||||||
|
}
|
||||||
|
let root1_assignments = clusters.get_mut(&root1_cluster).unwrap();
|
||||||
|
root1_assignments.extend(root2_assignments);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut labels: Vec<usize> = (0..num_samples).map(|_| 0).collect();
|
||||||
|
let mut cluster_keys: Vec<&usize> = clusters.keys().collect();
|
||||||
|
cluster_keys.sort();
|
||||||
|
for (i, key) in cluster_keys.into_iter().enumerate() {
|
||||||
|
for index in clusters[key].iter() {
|
||||||
|
labels[*index] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(AgglomerativeClustering {
|
||||||
|
labels,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
    for AgglomerativeClustering<TX, TY, X, Y>
{
    /// Fits the model via the common `UnsupervisedEstimator` API by
    /// delegating to the inherent [`AgglomerativeClustering::fit`].
    fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
        AgglomerativeClustering::fit(x, parameters)
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use crate::linalg::basic::matrix::DenseMatrix;
    use std::collections::HashSet;

    use super::*;

    /// Two well-separated groups must each map to one label when n_clusters = 2.
    #[test]
    fn test_simple_clustering() {
        // Two distinct clusters, far apart.
        let data = vec![
            0.0, 0.0, 1.0, 1.0, 0.5, 0.5, // Cluster A
            10.0, 10.0, 11.0, 11.0, 10.5, 10.5, // Cluster B
        ];
        let matrix = DenseMatrix::new(6, 2, data, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(2);
        // Using f64 for TY as usize doesn't satisfy the Number trait bound.
        let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &matrix, parameters,
        )
        .unwrap();

        let labels = clustering.labels;

        // Check that all points in the first group have the same label.
        let first_group_label = labels[0];
        assert!(labels[0..3].iter().all(|&l| l == first_group_label));

        // Check that all points in the second group have the same label.
        let second_group_label = labels[3];
        assert!(labels[3..6].iter().all(|&l| l == second_group_label));

        // Check that the two groups have different labels.
        assert_ne!(first_group_label, second_group_label);
    }

    /// Four tight corner groups must yield four distinct, internally
    /// consistent labels when n_clusters = 4.
    #[test]
    fn test_four_clusters() {
        // Four distinct clusters in the corners of a square.
        let data = vec![
            0.0, 0.0, 1.0, 1.0, // Cluster A
            100.0, 100.0, 101.0, 101.0, // Cluster B
            0.0, 100.0, 1.0, 101.0, // Cluster C
            100.0, 0.0, 101.0, 1.0, // Cluster D
        ];
        let matrix = DenseMatrix::new(8, 2, data, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(4);
        let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &matrix, parameters,
        )
        .unwrap();

        let labels = clustering.labels;

        // Verify that there are exactly 4 unique labels produced.
        let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
        assert_eq!(unique_labels.len(), 4);

        // Verify that points within each original group were assigned the same cluster label.
        let label_a = labels[0];
        assert_eq!(label_a, labels[1]);

        let label_b = labels[2];
        assert_eq!(label_b, labels[3]);

        let label_c = labels[4];
        assert_eq!(label_c, labels[5]);

        let label_d = labels[6];
        assert_eq!(label_d, labels[7]);

        // Verify that all four groups received different labels.
        assert_ne!(label_a, label_b);
        assert_ne!(label_a, label_c);
        assert_ne!(label_a, label_d);
        assert_ne!(label_b, label_c);
        assert_ne!(label_b, label_d);
        assert_ne!(label_c, label_d);
    }

    /// Degenerate case: n_clusters == n_samples performs zero merges.
    #[test]
    fn test_n_clusters_equal_to_samples() {
        let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
        let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &matrix, parameters,
        )
        .unwrap();

        // Each point should be its own cluster. Sorting makes the test deterministic.
        let mut labels = clustering.labels;
        labels.sort();
        assert_eq!(labels, vec![0, 1, 2]);
    }

    /// Degenerate case: n_clusters == 1 merges everything into one cluster.
    #[test]
    fn test_one_cluster() {
        let data = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let matrix = DenseMatrix::new(3, 2, data, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(1);
        let clustering = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &matrix, parameters,
        )
        .unwrap();

        // All points should be in the same cluster.
        assert_eq!(clustering.labels, vec![0, 0, 0]);
    }

    /// Asking for more clusters than samples must be a parameters error,
    /// not a panic.
    #[test]
    fn test_error_on_too_many_clusters() {
        let data = vec![0.0, 0.0, 5.0, 5.0];
        let matrix = DenseMatrix::new(2, 2, data, false).unwrap();
        let parameters = AgglomerativeClusteringParameters::default().with_n_clusters(3);
        let result = AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &matrix, parameters,
        );

        assert!(result.is_err());
    }
}
|
||||||
@@ -3,6 +3,7 @@
|
|||||||
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
//! Clustering is the type of unsupervised learning where you divide the population or data points into a number of groups such that data points in the same groups
|
||||||
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
|
//! are more similar to other data points in the same group than those in other groups. In simple words, the aim is to segregate groups with similar traits and assign them into clusters.
|
||||||
|
|
||||||
|
pub mod agglomerative;
|
||||||
pub mod dbscan;
|
pub mod dbscan;
|
||||||
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
|
/// An iterative clustering algorithm that aims to find local maxima in each iteration.
|
||||||
pub mod kmeans;
|
pub mod kmeans;
|
||||||
|
|||||||
@@ -0,0 +1,214 @@
|
|||||||
|
use rand::Rng;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::{Failed, FailedError};
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
|
||||||
|
use crate::rand_custom::get_rng_impl;
|
||||||
|
use crate::tree::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Forest Regressor
/// Some parameters here are passed directly into base estimator.
pub struct BaseForestRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of random sample of predictors to use as split candidates.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether each tree is trained on a bootstrap sample of the rows (`true`)
    /// or on the full dataset (`false`).
    pub bootstrap: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Split-point selection strategy forwarded to the underlying trees.
    pub splitter: Splitter,
}
|
||||||
|
|
||||||
|
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||||
|
for BaseForestRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.trees
|
||||||
|
.iter()
|
||||||
|
.zip(other.trees.iter())
|
||||||
|
.all(|(a, b)| a == b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forest Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseForestRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    // The fitted ensemble; `None` until `fit` has run.
    trees: Option<Vec<BaseTreeRegressor<TX, TY, X, Y>>>,
    // Per-tree membership masks (`true` = row was in that tree's training
    // sample). Only populated when `keep_samples` was set; required for
    // OOB prediction.
    samples: Option<Vec<Vec<bool>>>,
}
|
||||||
|
|
||||||
|
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseForestRegressor<TX, TY, X, Y>
{
    /// Build a forest of trees from the training set.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - the target class values
    ///
    /// # Errors
    /// Fails when the number of rows in `x` does not match the length of `y`,
    /// or when any underlying tree fails to fit.
    pub fn fit(
        x: &X,
        y: &Y,
        parameters: BaseForestRegressorParameters,
    ) -> Result<BaseForestRegressor<TX, TY, X, Y>, Failed> {
        let (n_rows, num_attributes) = x.shape();

        if n_rows != y.shape() {
            return Err(Failed::fit("Number of rows in X should = len(y)"));
        }

        // Default number of split candidates per node: floor(sqrt(p)).
        let mtry = parameters
            .m
            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);

        let mut rng = get_rng_impl(Some(parameters.seed));
        let mut trees: Vec<BaseTreeRegressor<TX, TY, X, Y>> = Vec::new();

        // Per-tree boolean membership masks, recorded only on request
        // (needed later for OOB prediction).
        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
        if parameters.keep_samples {
            // TODO: use with_capacity here
            maybe_all_samples = Some(Vec::new());
        }

        // Per-row sample counts fed to each tree. All ones when bootstrap is
        // disabled, so every tree sees the full dataset exactly once.
        let mut samples: Vec<usize> = (0..n_rows).map(|_| 1).collect();

        for _ in 0..parameters.n_trees {
            if parameters.bootstrap {
                samples =
                    BaseForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
            }

            // Record this tree's membership mask when keep_samples is on.
            if let Some(ref mut all_samples) = maybe_all_samples {
                all_samples.push(samples.iter().map(|x| *x != 0).collect())
            }

            let params = BaseTreeRegressorParameters {
                max_depth: parameters.max_depth,
                min_samples_leaf: parameters.min_samples_leaf,
                min_samples_split: parameters.min_samples_split,
                seed: Some(parameters.seed),
                splitter: parameters.splitter.clone(),
            };
            let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples.clone(), mtry, params)?;
            trees.push(tree);
        }

        Ok(BaseForestRegressor {
            trees: Some(trees),
            samples: maybe_all_samples,
        })
    }

    /// Predict class for `x`
    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
        let mut result = Y::zeros(x.shape().0);

        let (n, _) = x.shape();

        for i in 0..n {
            result.set(i, self.predict_for_row(x, i));
        }

        Ok(result)
    }

    // Average of the individual tree predictions for one row.
    // NOTE(review): panics if called before `fit` (`trees` is `None`);
    // callers in this impl only reach it through `predict`.
    fn predict_for_row(&self, x: &X, row: usize) -> TY {
        let n_trees = self.trees.as_ref().unwrap().len();

        let mut result = TY::zero();

        for tree in self.trees.as_ref().unwrap().iter() {
            result += tree.predict_for_row(x, row);
        }

        result / TY::from_usize(n_trees).unwrap()
    }

    /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
    ///
    /// # Errors
    /// Fails when samples were not kept during training, or when `x` has a
    /// different number of rows than the training matrix.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
        let (n, _) = x.shape();
        if self.samples.is_none() {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Need samples=true for OOB predictions.",
            ))
        } else if self.samples.as_ref().unwrap()[0].len() != n {
            Err(Failed::because(
                FailedError::PredictFailed,
                "Prediction matrix must match matrix used in training for OOB predictions.",
            ))
        } else {
            let mut result = Y::zeros(n);

            for i in 0..n {
                result.set(i, self.predict_for_row_oob(x, i));
            }

            Ok(result)
        }
    }

    // Average prediction over only the trees that did NOT train on `row`
    // (i.e. the row was out-of-bag for them).
    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
        let mut n_trees = 0;
        let mut result = TY::zero();

        for (tree, samples) in self
            .trees
            .as_ref()
            .unwrap()
            .iter()
            .zip(self.samples.as_ref().unwrap())
        {
            // `false` mask entry means the row was out-of-bag for this tree.
            if !samples[row] {
                result += tree.predict_for_row(x, row);
                n_trees += 1;
            }
        }

        // TODO: What to do if there are no oob trees?
        result / TY::from(n_trees).unwrap()
    }

    // Draws `nrows` row indices uniformly with replacement; returns the
    // number of times each row was drawn (a per-row weight vector).
    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
        let mut samples = vec![0; nrows];
        for _ in 0..nrows {
            let xi = rng.gen_range(0..nrows);
            samples[xi] += 1;
        }
        samples
    }
}
|
||||||
@@ -0,0 +1,318 @@
|
|||||||
|
//! # Extra Trees Regressor
|
||||||
|
//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
|
||||||
|
//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
|
||||||
|
//!
|
||||||
|
//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
|
||||||
|
//! reduce the variance of the model and often make the training process faster.
|
||||||
|
//!
|
||||||
|
//! The two key differences from a standard Random Forest are:
|
||||||
|
//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
|
||||||
|
//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
|
||||||
|
//!
|
||||||
|
//! See [ensemble models](../index.html) for more details.
|
||||||
|
//!
|
||||||
|
//! Bigger number of estimators in general improves performance of the algorithm with an increased cost of training time.
|
||||||
|
//! The random sample of _m_ predictors is typically set to be \\(\sqrt{p}\\) from the full set of _p_ predictors.
|
||||||
|
//!
|
||||||
|
//! Example:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
|
//! use smartcore::ensemble::extra_trees_regressor::*;
|
||||||
|
//!
|
||||||
|
//! // Longley dataset ([https://www.statsmodels.org/stable/datasets/generated/longley.html](https://www.statsmodels.org/stable/datasets/generated/longley.html))
|
||||||
|
//! let x = DenseMatrix::from_2d_array(&[
|
||||||
|
//! &[234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||||
|
//! &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||||
|
//! &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||||
|
//! &[284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||||
|
//! &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||||
|
//! &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
|
||||||
|
//! &[365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||||
|
//! &[363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||||
|
//! &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||||
|
//! &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
|
||||||
|
//! &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||||
|
//! &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
|
||||||
|
//! &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||||
|
//! &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||||
|
//! &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||||
|
//! &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
|
||||||
|
//! ]).unwrap();
|
||||||
|
//! let y = vec![
|
||||||
|
//! 83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
|
||||||
|
//! 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
|
||||||
|
//! ];
|
||||||
|
//!
|
||||||
|
//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||||
|
//!
|
||||||
|
//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
use std::default::Default;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
|
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
|
||||||
|
use crate::error::Failed;
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::tree::base_tree_regressor::Splitter;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of the Extra Trees Regressor
/// Some parameters here are passed directly into base estimator.
pub struct ExtraTreesRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The number of trees in the forest.
    pub n_trees: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Number of random sample of predictors to use as split candidates.
    /// When `None`, floor(sqrt(p)) is used at fit time.
    pub m: Option<usize>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub keep_samples: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub seed: u64,
}
|
||||||
|
|
||||||
|
/// Extra Trees Regressor
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct ExtraTreesRegressor<
    TX: Number + FloatNumber + PartialOrd,
    TY: Number,
    X: Array2<TX>,
    Y: Array1<TY>,
> {
    // Underlying forest; `None` until `fit` has run (e.g. on an instance
    // created via `SupervisedEstimator::new`).
    forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
}
|
||||||
|
|
||||||
|
impl ExtraTreesRegressorParameters {
    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
        self.max_depth = Some(max_depth);
        self
    }
    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }
    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }
    /// The number of trees in the forest.
    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
        self.n_trees = n_trees;
        self
    }
    /// Number of random sample of predictors to use as split candidates.
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = Some(m);
        self
    }

    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
        self.keep_samples = keep_samples;
        self
    }

    /// Seed used for bootstrap sampling and feature selection for each tree.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }
}
|
||||||
|
impl Default for ExtraTreesRegressorParameters {
|
||||||
|
fn default() -> Self {
|
||||||
|
ExtraTreesRegressorParameters {
|
||||||
|
max_depth: Option::None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 10,
|
||||||
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Creates an unfitted estimator (no underlying forest yet).
    fn new() -> Self {
        Self {
            forest_regressor: Option::None,
        }
    }

    /// Fits via the common `SupervisedEstimator` API by delegating to the
    /// inherent [`ExtraTreesRegressor::fit`].
    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
        ExtraTreesRegressor::fit(x, y, parameters)
    }
}
|
||||||
|
|
||||||
|
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
{
    /// Predicts via the common `Predictor` API by delegating to the inherent
    /// [`ExtraTreesRegressor::predict`].
    fn predict(&self, x: &X) -> Result<Y, Failed> {
        self.predict(x)
    }
}
|
||||||
|
|
||||||
|
impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
ExtraTreesRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
/// Build a forest of trees from the training set.
|
||||||
|
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
|
/// * `y` - the target class values
|
||||||
|
pub fn fit(
|
||||||
|
x: &X,
|
||||||
|
y: &Y,
|
||||||
|
parameters: ExtraTreesRegressorParameters,
|
||||||
|
) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
|
||||||
|
let regressor_params = BaseForestRegressorParameters {
|
||||||
|
max_depth: parameters.max_depth,
|
||||||
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
|
min_samples_split: parameters.min_samples_split,
|
||||||
|
n_trees: parameters.n_trees,
|
||||||
|
m: parameters.m,
|
||||||
|
keep_samples: parameters.keep_samples,
|
||||||
|
seed: parameters.seed,
|
||||||
|
bootstrap: false,
|
||||||
|
splitter: Splitter::Random,
|
||||||
|
};
|
||||||
|
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
|
||||||
|
|
||||||
|
Ok(ExtraTreesRegressor {
|
||||||
|
forest_regressor: Some(forest_regressor),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predict class for `x`
|
||||||
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
|
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||||
|
forest_regressor.predict(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
|
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||||
|
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||||
|
forest_regressor.predict_oob(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::linalg::basic::matrix::DenseMatrix;
|
||||||
|
use crate::metrics::mean_squared_error;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extra_trees_regressor_fit_predict() {
|
||||||
|
// Use a simpler, more predictable dataset for unit testing.
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2.],
|
||||||
|
&[3., 4.],
|
||||||
|
&[5., 6.],
|
||||||
|
&[7., 8.],
|
||||||
|
&[9., 10.],
|
||||||
|
&[11., 12.],
|
||||||
|
&[13., 14.],
|
||||||
|
&[15., 16.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
|
||||||
|
|
||||||
|
let parameters = ExtraTreesRegressorParameters::default()
|
||||||
|
.with_n_trees(100)
|
||||||
|
.with_seed(42);
|
||||||
|
|
||||||
|
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
|
||||||
|
let y_hat = regressor.predict(&x).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(y_hat.len(), y.len());
|
||||||
|
// A basic check to ensure the model is learning something.
|
||||||
|
// The error should be significantly less than the variance of y.
|
||||||
|
let mse = mean_squared_error(&y, &y_hat);
|
||||||
|
// With this simple dataset, the error should be very low.
|
||||||
|
assert!(mse < 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_fit_predict_higher_dims() {
|
||||||
|
// Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
// The 3rd column is the important one. The rest are noise.
|
||||||
|
&[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
|
||||||
|
&[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
|
||||||
|
&[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
|
||||||
|
&[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
|
||||||
|
&[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
|
||||||
|
&[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
let y = vec![10., 20., 30., 40., 55., 65.];
|
||||||
|
|
||||||
|
let parameters = ExtraTreesRegressorParameters::default()
|
||||||
|
.with_n_trees(100)
|
||||||
|
.with_seed(42);
|
||||||
|
|
||||||
|
let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
|
||||||
|
let y_hat = regressor.predict(&x).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(y_hat.len(), y.len());
|
||||||
|
|
||||||
|
let mse = mean_squared_error(&y, &y_hat);
|
||||||
|
|
||||||
|
// The model should be able to learn this simple relationship perfectly,
|
||||||
|
// ignoring the noise features. The MSE should be very low.
|
||||||
|
assert!(mse < 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reproducibility() {
|
||||||
|
let x = DenseMatrix::from_2d_array(&[
|
||||||
|
&[1., 2.],
|
||||||
|
&[3., 4.],
|
||||||
|
&[5., 6.],
|
||||||
|
&[7., 8.],
|
||||||
|
&[9., 10.],
|
||||||
|
&[11., 12.],
|
||||||
|
])
|
||||||
|
.unwrap();
|
||||||
|
let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
|
|
||||||
|
let params = ExtraTreesRegressorParameters::default().with_seed(42);
|
||||||
|
|
||||||
|
let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
|
||||||
|
let y_hat1 = regressor1.predict(&x).unwrap();
|
||||||
|
|
||||||
|
let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
|
||||||
|
let y_hat2 = regressor2.predict(&x).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(y_hat1, y_hat2);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,6 +16,8 @@
|
|||||||
//!
|
//!
|
||||||
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
|
||||||
|
|
||||||
|
mod base_forest_regressor;
|
||||||
|
pub mod extra_trees_regressor;
|
||||||
/// Random forest classifier
|
/// Random forest classifier
|
||||||
pub mod random_forest_classifier;
|
pub mod random_forest_classifier;
|
||||||
/// Random forest regressor
|
/// Random forest regressor
|
||||||
|
|||||||
@@ -43,7 +43,6 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use rand::Rng;
|
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
@@ -51,15 +50,12 @@ use std::fmt::Debug;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::{Failed, FailedError};
|
use crate::ensemble::base_forest_regressor::{BaseForestRegressor, BaseForestRegressorParameters};
|
||||||
|
use crate::error::Failed;
|
||||||
use crate::linalg::basic::arrays::{Array1, Array2};
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::numbers::floatnum::FloatNumber;
|
use crate::numbers::floatnum::FloatNumber;
|
||||||
|
use crate::tree::base_tree_regressor::Splitter;
|
||||||
use crate::rand_custom::get_rng_impl;
|
|
||||||
use crate::tree::decision_tree_regressor::{
|
|
||||||
DecisionTreeRegressor, DecisionTreeRegressorParameters,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -98,8 +94,7 @@ pub struct RandomForestRegressor<
|
|||||||
X: Array2<TX>,
|
X: Array2<TX>,
|
||||||
Y: Array1<TY>,
|
Y: Array1<TY>,
|
||||||
> {
|
> {
|
||||||
trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
|
forest_regressor: Option<BaseForestRegressor<TX, TY, X, Y>>,
|
||||||
samples: Option<Vec<Vec<bool>>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RandomForestRegressorParameters {
|
impl RandomForestRegressorParameters {
|
||||||
@@ -159,14 +154,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
|
|||||||
for RandomForestRegressor<TX, TY, X, Y>
|
for RandomForestRegressor<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
|
self.forest_regressor == other.forest_regressor
|
||||||
false
|
|
||||||
} else {
|
|
||||||
self.trees
|
|
||||||
.iter()
|
|
||||||
.zip(other.trees.iter())
|
|
||||||
.all(|(a, b)| a == b)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,8 +164,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
|
|||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
trees: Option::None,
|
forest_regressor: Option::None,
|
||||||
samples: Option::None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,128 +384,35 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
|
|||||||
y: &Y,
|
y: &Y,
|
||||||
parameters: RandomForestRegressorParameters,
|
parameters: RandomForestRegressorParameters,
|
||||||
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
|
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
|
||||||
let (n_rows, num_attributes) = x.shape();
|
let regressor_params = BaseForestRegressorParameters {
|
||||||
|
max_depth: parameters.max_depth,
|
||||||
if n_rows != y.shape() {
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
min_samples_split: parameters.min_samples_split,
|
||||||
}
|
n_trees: parameters.n_trees,
|
||||||
|
m: parameters.m,
|
||||||
let mtry = parameters
|
keep_samples: parameters.keep_samples,
|
||||||
.m
|
seed: parameters.seed,
|
||||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
bootstrap: true,
|
||||||
|
splitter: Splitter::Best,
|
||||||
let mut rng = get_rng_impl(Some(parameters.seed));
|
};
|
||||||
let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
|
let forest_regressor = BaseForestRegressor::fit(x, y, regressor_params)?;
|
||||||
|
|
||||||
let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
|
|
||||||
if parameters.keep_samples {
|
|
||||||
// TODO: use with_capacity here
|
|
||||||
maybe_all_samples = Some(Vec::new());
|
|
||||||
}
|
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
|
||||||
let samples: Vec<usize> =
|
|
||||||
RandomForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
|
|
||||||
|
|
||||||
// keep samples is flag is on
|
|
||||||
if let Some(ref mut all_samples) = maybe_all_samples {
|
|
||||||
all_samples.push(samples.iter().map(|x| *x != 0).collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
let params = DecisionTreeRegressorParameters {
|
|
||||||
max_depth: parameters.max_depth,
|
|
||||||
min_samples_leaf: parameters.min_samples_leaf,
|
|
||||||
min_samples_split: parameters.min_samples_split,
|
|
||||||
seed: Some(parameters.seed),
|
|
||||||
};
|
|
||||||
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
|
|
||||||
trees.push(tree);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(RandomForestRegressor {
|
Ok(RandomForestRegressor {
|
||||||
trees: Some(trees),
|
forest_regressor: Some(forest_regressor),
|
||||||
samples: maybe_all_samples,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict class for `x`
|
/// Predict class for `x`
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
let mut result = Y::zeros(x.shape().0);
|
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||||
|
forest_regressor.predict(x)
|
||||||
let (n, _) = x.shape();
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
result.set(i, self.predict_for_row(x, i));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn predict_for_row(&self, x: &X, row: usize) -> TY {
|
|
||||||
let n_trees = self.trees.as_ref().unwrap().len();
|
|
||||||
|
|
||||||
let mut result = TY::zero();
|
|
||||||
|
|
||||||
for tree in self.trees.as_ref().unwrap().iter() {
|
|
||||||
result += tree.predict_for_row(x, row);
|
|
||||||
}
|
|
||||||
|
|
||||||
result / TY::from_usize(n_trees).unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
|
||||||
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
|
||||||
let (n, _) = x.shape();
|
let forest_regressor = self.forest_regressor.as_ref().unwrap();
|
||||||
if self.samples.is_none() {
|
forest_regressor.predict_oob(x)
|
||||||
Err(Failed::because(
|
|
||||||
FailedError::PredictFailed,
|
|
||||||
"Need samples=true for OOB predictions.",
|
|
||||||
))
|
|
||||||
} else if self.samples.as_ref().unwrap()[0].len() != n {
|
|
||||||
Err(Failed::because(
|
|
||||||
FailedError::PredictFailed,
|
|
||||||
"Prediction matrix must match matrix used in training for OOB predictions.",
|
|
||||||
))
|
|
||||||
} else {
|
|
||||||
let mut result = Y::zeros(n);
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
result.set(i, self.predict_for_row_oob(x, i));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
|
|
||||||
let mut n_trees = 0;
|
|
||||||
let mut result = TY::zero();
|
|
||||||
|
|
||||||
for (tree, samples) in self
|
|
||||||
.trees
|
|
||||||
.as_ref()
|
|
||||||
.unwrap()
|
|
||||||
.iter()
|
|
||||||
.zip(self.samples.as_ref().unwrap())
|
|
||||||
{
|
|
||||||
if !samples[row] {
|
|
||||||
result += tree.predict_for_row(x, row);
|
|
||||||
n_trees += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: What to do if there are no oob trees?
|
|
||||||
result / TY::from(n_trees).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
|
|
||||||
let mut samples = vec![0; nrows];
|
|
||||||
for _ in 0..nrows {
|
|
||||||
let xi = rng.gen_range(0..nrows);
|
|
||||||
samples[xi] += 1;
|
|
||||||
}
|
|
||||||
samples
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -130,5 +130,6 @@ pub mod readers;
|
|||||||
pub mod svm;
|
pub mod svm;
|
||||||
/// Supervised tree-based learning methods
|
/// Supervised tree-based learning methods
|
||||||
pub mod tree;
|
pub mod tree;
|
||||||
|
pub mod xgboost;
|
||||||
|
|
||||||
pub(crate) mod rand_custom;
|
pub(crate) mod rand_custom;
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
//! # Cosine Distance Metric
|
||||||
|
//!
|
||||||
|
//! The cosine distance between two points \\( x \\) and \\( y \\) in n-space is defined as:
|
||||||
|
//!
|
||||||
|
//! \\[ d(x, y) = 1 - \frac{x \cdot y}{||x|| ||y||} \\]
|
||||||
|
//!
|
||||||
|
//! where \\( x \cdot y \\) is the dot product of the vectors, and \\( ||x|| \\) and \\( ||y|| \\)
|
||||||
|
//! are their respective magnitudes (Euclidean norms).
|
||||||
|
//!
|
||||||
|
//! Cosine distance measures the angular dissimilarity between vectors, ranging from 0 to 2.
|
||||||
|
//! A value of 0 indicates identical direction (parallel vectors), while larger values indicate
|
||||||
|
//! greater angular separation.
|
||||||
|
//!
|
||||||
|
//! Example:
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
//! use smartcore::metrics::distance::Distance;
|
||||||
|
//! use smartcore::metrics::distance::cosine::Cosine;
|
||||||
|
//!
|
||||||
|
//! let x = vec![1., 1.];
|
||||||
|
//! let y = vec![2., 2.];
|
||||||
|
//!
|
||||||
|
//! let cosine_dist: f64 = Cosine::new().distance(&x, &y);
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use crate::linalg::basic::arrays::ArrayView1;
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
|
||||||
|
use super::Distance;
|
||||||
|
|
||||||
|
/// Cosine distance is a measure of the angular dissimilarity between two non-zero vectors in n-space.
|
||||||
|
/// It is defined as 1 minus the cosine similarity of the vectors.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Cosine<T> {
|
||||||
|
_t: PhantomData<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Default for Cosine<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number> Cosine<T> {
|
||||||
|
/// Instantiate the initial structure
|
||||||
|
pub fn new() -> Cosine<T> {
|
||||||
|
Cosine { _t: PhantomData }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the dot product of two vectors using smartcore's ArrayView1 trait
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn dot_product<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
if x.shape() != y.shape() {
|
||||||
|
panic!("Input vector sizes are different.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the built-in dot product method from ArrayView1 trait
|
||||||
|
x.dot(y).to_f64().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the squared magnitude (norm squared) of a vector
|
||||||
|
#[inline]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn squared_magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
x.iterator(0)
|
||||||
|
.map(|&a| {
|
||||||
|
let val = a.to_f64().unwrap();
|
||||||
|
val * val
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate the magnitude (Euclidean norm) of a vector using smartcore's norm2 method
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn magnitude<A: ArrayView1<T>>(x: &A) -> f64 {
|
||||||
|
// Use the built-in norm2 method from ArrayView1 trait
|
||||||
|
x.norm2()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculate cosine similarity between two vectors
|
||||||
|
#[inline]
|
||||||
|
pub(crate) fn cosine_similarity<A: ArrayView1<T>>(x: &A, y: &A) -> f64 {
|
||||||
|
let dot_product = Self::dot_product(x, y);
|
||||||
|
let magnitude_x = Self::magnitude(x);
|
||||||
|
let magnitude_y = Self::magnitude(y);
|
||||||
|
|
||||||
|
if magnitude_x == 0.0 || magnitude_y == 0.0 {
|
||||||
|
panic!("Cannot compute cosine distance for zero-magnitude vectors.");
|
||||||
|
}
|
||||||
|
|
||||||
|
dot_product / (magnitude_x * magnitude_y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Number, A: ArrayView1<T>> Distance<A> for Cosine<T> {
|
||||||
|
fn distance(&self, x: &A, y: &A) -> f64 {
|
||||||
|
let similarity = Cosine::cosine_similarity(x, y);
|
||||||
|
1.0 - similarity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_identical_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 0.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_orthogonal_vectors() {
|
||||||
|
let a = vec![1, 0];
|
||||||
|
let b = vec![0, 1];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 1.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_opposite_vectors() {
|
||||||
|
let a = vec![1, 2, 3];
|
||||||
|
let b = vec![-1, -2, -3];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
assert!((dist - 2.0).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_general_case() {
|
||||||
|
let a = vec![1.0, 2.0, 3.0];
|
||||||
|
let b = vec![2.0, 1.0, 3.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Expected cosine similarity: (1*2 + 2*1 + 3*3) / (sqrt(1+4+9) * sqrt(4+1+9))
|
||||||
|
// = (2 + 2 + 9) / (sqrt(14) * sqrt(14)) = 13/14 ≈ 0.9286
|
||||||
|
// So cosine distance = 1 - 13/14 = 1/14 ≈ 0.0714
|
||||||
|
let expected_dist = 1.0 - (13.0 / 14.0);
|
||||||
|
assert!((dist - expected_dist).abs() < 1e-8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Input vector sizes are different.")]
|
||||||
|
fn cosine_distance_different_sizes() {
|
||||||
|
let a = vec![1, 2];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let _dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Cannot compute cosine distance for zero-magnitude vectors.")]
|
||||||
|
fn cosine_distance_zero_vector() {
|
||||||
|
let a = vec![0, 0, 0];
|
||||||
|
let b = vec![1, 2, 3];
|
||||||
|
|
||||||
|
let _dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(
|
||||||
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
)]
|
||||||
|
#[test]
|
||||||
|
fn cosine_distance_float_precision() {
|
||||||
|
let a = vec![1.0f32, 2.0, 3.0];
|
||||||
|
let b = vec![4.0f32, 5.0, 6.0];
|
||||||
|
|
||||||
|
let dist: f64 = Cosine::new().distance(&a, &b);
|
||||||
|
|
||||||
|
// Calculate expected value manually
|
||||||
|
let dot_product = 1.0 * 4.0 + 2.0 * 5.0 + 3.0 * 6.0; // = 32
|
||||||
|
let mag_a = (1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0_f64).sqrt(); // = sqrt(14)
|
||||||
|
let mag_b = (4.0 * 4.0 + 5.0 * 5.0 + 6.0 * 6.0_f64).sqrt(); // = sqrt(77)
|
||||||
|
let expected_similarity = dot_product / (mag_a * mag_b);
|
||||||
|
let expected_distance = 1.0 - expected_similarity;
|
||||||
|
|
||||||
|
assert!((dist - expected_distance).abs() < 1e-6);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,8 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
/// Cosine distance
|
||||||
|
pub mod cosine;
|
||||||
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
/// Euclidean Distance is the straight-line distance between two points in Euclidean spacere that presents the shortest distance between these points.
|
||||||
pub mod euclidian;
|
pub mod euclidian;
|
||||||
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
/// Hamming Distance between two strings is the number of positions at which the corresponding symbols are different.
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ pub trait LineSearchMethod<T: Float> {
|
|||||||
/// Find alpha that satisfies strong Wolfe conditions.
|
/// Find alpha that satisfies strong Wolfe conditions.
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
df: &(dyn Fn(T) -> T),
|
df: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
@@ -55,8 +55,8 @@ impl<T: Float> Default for Backtracking<T> {
|
|||||||
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
impl<T: Float> LineSearchMethod<T> for Backtracking<T> {
|
||||||
fn search(
|
fn search(
|
||||||
&self,
|
&self,
|
||||||
f: &(dyn Fn(T) -> T),
|
f: &dyn Fn(T) -> T,
|
||||||
_: &(dyn Fn(T) -> T),
|
_: &dyn Fn(T) -> T,
|
||||||
alpha: T,
|
alpha: T,
|
||||||
f0: T,
|
f0: T,
|
||||||
df0: T,
|
df0: T,
|
||||||
|
|||||||
@@ -0,0 +1,551 @@
|
|||||||
|
use std::collections::LinkedList;
|
||||||
|
use std::default::Default;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use rand::seq::SliceRandom;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::error::Failed;
|
||||||
|
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
|
||||||
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::rand_custom::get_rng_impl;
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
pub enum Splitter {
|
||||||
|
Random,
|
||||||
|
#[default]
|
||||||
|
Best,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
/// Parameters of Regression base_tree
|
||||||
|
pub struct BaseTreeRegressorParameters {
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// The maximum depth of the base_tree.
|
||||||
|
pub max_depth: Option<u16>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// The minimum number of samples required to be at a leaf node.
|
||||||
|
pub min_samples_leaf: usize,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// The minimum number of samples required to split an internal node.
|
||||||
|
pub min_samples_split: usize,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// Controls the randomness of the estimator
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
#[cfg_attr(feature = "serde", serde(default))]
|
||||||
|
/// Determines the strategy used to choose the split at each node.
|
||||||
|
pub splitter: Splitter,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Regression base_tree
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
|
nodes: Vec<Node>,
|
||||||
|
parameters: Option<BaseTreeRegressorParameters>,
|
||||||
|
depth: u16,
|
||||||
|
_phantom_tx: PhantomData<TX>,
|
||||||
|
_phantom_ty: PhantomData<TY>,
|
||||||
|
_phantom_x: PhantomData<X>,
|
||||||
|
_phantom_y: PhantomData<Y>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
BaseTreeRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
/// Get nodes, return a shared reference
|
||||||
|
fn nodes(&self) -> &Vec<Node> {
|
||||||
|
self.nodes.as_ref()
|
||||||
|
}
|
||||||
|
/// Get parameters, return a shared reference
|
||||||
|
fn parameters(&self) -> &BaseTreeRegressorParameters {
|
||||||
|
self.parameters.as_ref().unwrap()
|
||||||
|
}
|
||||||
|
/// Get estimate of intercept, return value
|
||||||
|
fn depth(&self) -> u16 {
|
||||||
|
self.depth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Node {
|
||||||
|
output: f64,
|
||||||
|
split_feature: usize,
|
||||||
|
split_value: Option<f64>,
|
||||||
|
split_score: Option<f64>,
|
||||||
|
true_child: Option<usize>,
|
||||||
|
false_child: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
fn new(output: f64) -> Self {
|
||||||
|
Node {
|
||||||
|
output,
|
||||||
|
split_feature: 0,
|
||||||
|
split_value: Option::None,
|
||||||
|
split_score: Option::None,
|
||||||
|
true_child: Option::None,
|
||||||
|
false_child: Option::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for Node {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
(self.output - other.output).abs() < f64::EPSILON
|
||||||
|
&& self.split_feature == other.split_feature
|
||||||
|
&& match (self.split_value, other.split_value) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
&& match (self.split_score, other.split_score) {
|
||||||
|
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
||||||
|
(None, None) => true,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||||
|
for BaseTreeRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
|
||||||
|
false
|
||||||
|
} else {
|
||||||
|
self.nodes()
|
||||||
|
.iter()
|
||||||
|
.zip(other.nodes().iter())
|
||||||
|
.all(|(a, b)| a == b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
|
x: &'a X,
|
||||||
|
y: &'a Y,
|
||||||
|
node: usize,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
order: &'a [Vec<usize>],
|
||||||
|
true_child_output: f64,
|
||||||
|
false_child_output: f64,
|
||||||
|
level: u16,
|
||||||
|
_phantom_tx: PhantomData<TX>,
|
||||||
|
_phantom_ty: PhantomData<TY>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
NodeVisitor<'a, TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
fn new(
|
||||||
|
node_id: usize,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
order: &'a [Vec<usize>],
|
||||||
|
x: &'a X,
|
||||||
|
y: &'a Y,
|
||||||
|
level: u16,
|
||||||
|
) -> Self {
|
||||||
|
NodeVisitor {
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
node: node_id,
|
||||||
|
samples,
|
||||||
|
order,
|
||||||
|
true_child_output: 0f64,
|
||||||
|
false_child_output: 0f64,
|
||||||
|
level,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
BaseTreeRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
/// Build a decision base_tree regressor from the training data.
|
||||||
|
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
|
||||||
|
/// * `y` - the target values
|
||||||
|
pub fn fit(
|
||||||
|
x: &X,
|
||||||
|
y: &Y,
|
||||||
|
parameters: BaseTreeRegressorParameters,
|
||||||
|
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
|
||||||
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
|
if x_nrows != y.shape() {
|
||||||
|
return Err(Failed::fit("Size of x should equal size of y"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let samples = vec![1; x_nrows];
|
||||||
|
BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn fit_weak_learner(
|
||||||
|
x: &X,
|
||||||
|
y: &Y,
|
||||||
|
samples: Vec<usize>,
|
||||||
|
mtry: usize,
|
||||||
|
parameters: BaseTreeRegressorParameters,
|
||||||
|
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
|
||||||
|
let y_m = y.clone();
|
||||||
|
|
||||||
|
let y_ncols = y_m.shape();
|
||||||
|
let (_, num_attributes) = x.shape();
|
||||||
|
|
||||||
|
let mut nodes: Vec<Node> = Vec::new();
|
||||||
|
let mut rng = get_rng_impl(parameters.seed);
|
||||||
|
|
||||||
|
let mut n = 0;
|
||||||
|
let mut sum = 0f64;
|
||||||
|
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
|
||||||
|
n += *sample_i;
|
||||||
|
sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let root = Node::new(sum / (n as f64));
|
||||||
|
nodes.push(root);
|
||||||
|
let mut order: Vec<Vec<usize>> = Vec::new();
|
||||||
|
|
||||||
|
for i in 0..num_attributes {
|
||||||
|
let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
|
||||||
|
order.push(col_i.argsort_mut());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut base_tree = BaseTreeRegressor {
|
||||||
|
nodes,
|
||||||
|
parameters: Some(parameters),
|
||||||
|
depth: 0u16,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
|
||||||
|
|
||||||
|
let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
|
||||||
|
|
||||||
|
if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
||||||
|
visitor_queue.push_back(visitor);
|
||||||
|
}
|
||||||
|
|
||||||
|
while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
|
||||||
|
match visitor_queue.pop_front() {
|
||||||
|
Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(base_tree)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predict regression value for `x`.
|
||||||
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
|
let mut result = Y::zeros(x.shape().0);
|
||||||
|
|
||||||
|
let (n, _) = x.shape();
|
||||||
|
|
||||||
|
for i in 0..n {
|
||||||
|
result.set(i, self.predict_for_row(x, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
|
||||||
|
let mut result = 0f64;
|
||||||
|
let mut queue: LinkedList<usize> = LinkedList::new();
|
||||||
|
|
||||||
|
queue.push_back(0);
|
||||||
|
|
||||||
|
while !queue.is_empty() {
|
||||||
|
match queue.pop_front() {
|
||||||
|
Some(node_id) => {
|
||||||
|
let node = &self.nodes()[node_id];
|
||||||
|
if node.true_child.is_none() && node.false_child.is_none() {
|
||||||
|
result = node.output;
|
||||||
|
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
||||||
|
<= node.split_value.unwrap_or(f64::NAN)
|
||||||
|
{
|
||||||
|
queue.push_back(node.true_child.unwrap());
|
||||||
|
} else {
|
||||||
|
queue.push_back(node.false_child.unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
TY::from_f64(result).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_best_cutoff(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
|
||||||
|
mtry: usize,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> bool {
|
||||||
|
let (_, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
|
let n: usize = visitor.samples.iter().sum();
|
||||||
|
|
||||||
|
if n < self.parameters().min_samples_split {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let sum = self.nodes()[visitor.node].output * n as f64;
|
||||||
|
|
||||||
|
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if mtry < n_attr {
|
||||||
|
variables.shuffle(rng);
|
||||||
|
}
|
||||||
|
|
||||||
|
let parent_gain =
|
||||||
|
n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
|
||||||
|
|
||||||
|
let splitter = self.parameters().splitter.clone();
|
||||||
|
|
||||||
|
for variable in variables.iter().take(mtry) {
|
||||||
|
match splitter {
|
||||||
|
Splitter::Random => {
|
||||||
|
self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
|
||||||
|
}
|
||||||
|
Splitter::Best => {
|
||||||
|
self.find_best_split(visitor, n, sum, parent_gain, *variable);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.nodes()[visitor.node].split_score.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_random_split(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
|
||||||
|
n: usize,
|
||||||
|
sum: f64,
|
||||||
|
parent_gain: f64,
|
||||||
|
j: usize,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) {
|
||||||
|
let (min_val, max_val) = {
|
||||||
|
let mut min_opt = None;
|
||||||
|
let mut max_opt = None;
|
||||||
|
for &i in &visitor.order[j] {
|
||||||
|
if visitor.samples[i] > 0 {
|
||||||
|
min_opt = Some(*visitor.x.get((i, j)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for &i in visitor.order[j].iter().rev() {
|
||||||
|
if visitor.samples[i] > 0 {
|
||||||
|
max_opt = Some(*visitor.x.get((i, j)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if min_opt.is_none() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
(min_opt.unwrap(), max_opt.unwrap())
|
||||||
|
};
|
||||||
|
|
||||||
|
if min_val >= max_val {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());
|
||||||
|
|
||||||
|
let mut true_sum = 0f64;
|
||||||
|
let mut true_count = 0;
|
||||||
|
for &i in &visitor.order[j] {
|
||||||
|
if visitor.samples[i] > 0 {
|
||||||
|
if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
|
||||||
|
true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
|
||||||
|
true_count += visitor.samples[i];
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let false_count = n - true_count;
|
||||||
|
|
||||||
|
if true_count < self.parameters().min_samples_leaf
|
||||||
|
|| false_count < self.parameters().min_samples_leaf
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let true_mean = if true_count > 0 {
|
||||||
|
true_sum / true_count as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let false_mean = if false_count > 0 {
|
||||||
|
(sum - true_sum) / false_count as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
let gain = (true_count as f64 * true_mean * true_mean
|
||||||
|
+ false_count as f64 * false_mean * false_mean)
|
||||||
|
- parent_gain;
|
||||||
|
|
||||||
|
if self.nodes[visitor.node].split_score.is_none()
|
||||||
|
|| gain > self.nodes[visitor.node].split_score.unwrap()
|
||||||
|
{
|
||||||
|
self.nodes[visitor.node].split_feature = j;
|
||||||
|
self.nodes[visitor.node].split_value = Some(split_value);
|
||||||
|
self.nodes[visitor.node].split_score = Some(gain);
|
||||||
|
visitor.true_child_output = true_mean;
|
||||||
|
visitor.false_child_output = false_mean;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn find_best_split(
|
||||||
|
&mut self,
|
||||||
|
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
|
||||||
|
n: usize,
|
||||||
|
sum: f64,
|
||||||
|
parent_gain: f64,
|
||||||
|
j: usize,
|
||||||
|
) {
|
||||||
|
let mut true_sum = 0f64;
|
||||||
|
let mut true_count = 0;
|
||||||
|
let mut prevx = Option::None;
|
||||||
|
|
||||||
|
for i in visitor.order[j].iter() {
|
||||||
|
if visitor.samples[*i] > 0 {
|
||||||
|
let x_ij = *visitor.x.get((*i, j));
|
||||||
|
|
||||||
|
if prevx.is_none() || x_ij == prevx.unwrap() {
|
||||||
|
prevx = Some(x_ij);
|
||||||
|
true_count += visitor.samples[*i];
|
||||||
|
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let false_count = n - true_count;
|
||||||
|
|
||||||
|
if true_count < self.parameters().min_samples_leaf
|
||||||
|
|| false_count < self.parameters().min_samples_leaf
|
||||||
|
{
|
||||||
|
prevx = Some(x_ij);
|
||||||
|
true_count += visitor.samples[*i];
|
||||||
|
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let true_mean = true_sum / true_count as f64;
|
||||||
|
let false_mean = (sum - true_sum) / false_count as f64;
|
||||||
|
|
||||||
|
let gain = (true_count as f64 * true_mean * true_mean
|
||||||
|
+ false_count as f64 * false_mean * false_mean)
|
||||||
|
- parent_gain;
|
||||||
|
|
||||||
|
if self.nodes()[visitor.node].split_score.is_none()
|
||||||
|
|| gain > self.nodes()[visitor.node].split_score.unwrap()
|
||||||
|
{
|
||||||
|
self.nodes[visitor.node].split_feature = j;
|
||||||
|
self.nodes[visitor.node].split_value =
|
||||||
|
Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
|
||||||
|
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||||
|
|
||||||
|
visitor.true_child_output = true_mean;
|
||||||
|
visitor.false_child_output = false_mean;
|
||||||
|
}
|
||||||
|
|
||||||
|
prevx = Some(x_ij);
|
||||||
|
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
||||||
|
true_count += visitor.samples[*i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split<'a>(
|
||||||
|
&mut self,
|
||||||
|
mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
|
||||||
|
mtry: usize,
|
||||||
|
visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> bool {
|
||||||
|
let (n, _) = visitor.x.shape();
|
||||||
|
let mut tc = 0;
|
||||||
|
let mut fc = 0;
|
||||||
|
let mut true_samples: Vec<usize> = vec![0; n];
|
||||||
|
|
||||||
|
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
|
||||||
|
if visitor.samples[i] > 0 {
|
||||||
|
if visitor
|
||||||
|
.x
|
||||||
|
.get((i, self.nodes()[visitor.node].split_feature))
|
||||||
|
.to_f64()
|
||||||
|
.unwrap()
|
||||||
|
<= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
|
||||||
|
{
|
||||||
|
*true_sample = visitor.samples[i];
|
||||||
|
tc += *true_sample;
|
||||||
|
visitor.samples[i] = 0;
|
||||||
|
} else {
|
||||||
|
fc += visitor.samples[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
|
||||||
|
self.nodes[visitor.node].split_feature = 0;
|
||||||
|
self.nodes[visitor.node].split_value = Option::None;
|
||||||
|
self.nodes[visitor.node].split_score = Option::None;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let true_child_idx = self.nodes().len();
|
||||||
|
|
||||||
|
self.nodes.push(Node::new(visitor.true_child_output));
|
||||||
|
let false_child_idx = self.nodes().len();
|
||||||
|
self.nodes.push(Node::new(visitor.false_child_output));
|
||||||
|
|
||||||
|
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
||||||
|
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
||||||
|
|
||||||
|
self.depth = u16::max(self.depth, visitor.level + 1);
|
||||||
|
|
||||||
|
let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
|
||||||
|
true_child_idx,
|
||||||
|
true_samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
|
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
||||||
|
visitor_queue.push_back(true_visitor);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
|
||||||
|
false_child_idx,
|
||||||
|
visitor.samples,
|
||||||
|
visitor.order,
|
||||||
|
visitor.x,
|
||||||
|
visitor.y,
|
||||||
|
visitor.level + 1,
|
||||||
|
);
|
||||||
|
|
||||||
|
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
||||||
|
visitor_queue.push_back(false_visitor);
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -674,15 +674,20 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
) -> bool {
|
) -> bool {
|
||||||
let (n_rows, n_attr) = visitor.x.shape();
|
let (n_rows, n_attr) = visitor.x.shape();
|
||||||
|
|
||||||
let mut label = Option::None;
|
let mut label = None;
|
||||||
let mut is_pure = true;
|
let mut is_pure = true;
|
||||||
for i in 0..n_rows {
|
for i in 0..n_rows {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if label.is_none() {
|
match label {
|
||||||
label = Option::Some(visitor.y[i]);
|
None => {
|
||||||
} else if visitor.y[i] != label.unwrap() {
|
label = Some(visitor.y[i]);
|
||||||
is_pure = false;
|
}
|
||||||
break;
|
Some(current_label) => {
|
||||||
|
if visitor.y[i] != current_label {
|
||||||
|
is_pure = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,22 +58,17 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
use std::collections::LinkedList;
|
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
use rand::seq::SliceRandom;
|
|
||||||
use rand::Rng;
|
|
||||||
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
|
||||||
use crate::api::{Predictor, SupervisedEstimator};
|
use crate::api::{Predictor, SupervisedEstimator};
|
||||||
use crate::error::Failed;
|
use crate::error::Failed;
|
||||||
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
|
use crate::linalg::basic::arrays::{Array1, Array2};
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
use crate::rand_custom::get_rng_impl;
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -98,41 +93,7 @@ pub struct DecisionTreeRegressorParameters {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
{
|
{
|
||||||
nodes: Vec<Node>,
|
tree_regressor: Option<BaseTreeRegressor<TX, TY, X, Y>>,
|
||||||
parameters: Option<DecisionTreeRegressorParameters>,
|
|
||||||
depth: u16,
|
|
||||||
_phantom_tx: PhantomData<TX>,
|
|
||||||
_phantom_ty: PhantomData<TY>,
|
|
||||||
_phantom_x: PhantomData<X>,
|
|
||||||
_phantom_y: PhantomData<Y>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|
||||||
DecisionTreeRegressor<TX, TY, X, Y>
|
|
||||||
{
|
|
||||||
/// Get nodes, return a shared reference
|
|
||||||
fn nodes(&self) -> &Vec<Node> {
|
|
||||||
self.nodes.as_ref()
|
|
||||||
}
|
|
||||||
/// Get parameters, return a shared reference
|
|
||||||
fn parameters(&self) -> &DecisionTreeRegressorParameters {
|
|
||||||
self.parameters.as_ref().unwrap()
|
|
||||||
}
|
|
||||||
/// Get estimate of intercept, return value
|
|
||||||
fn depth(&self) -> u16 {
|
|
||||||
self.depth
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Node {
|
|
||||||
output: f64,
|
|
||||||
split_feature: usize,
|
|
||||||
split_value: Option<f64>,
|
|
||||||
split_score: Option<f64>,
|
|
||||||
true_child: Option<usize>,
|
|
||||||
false_child: Option<usize>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecisionTreeRegressorParameters {
|
impl DecisionTreeRegressorParameters {
|
||||||
@@ -296,87 +257,11 @@ impl Default for DecisionTreeRegressorSearchParameters {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Node {
|
|
||||||
fn new(output: f64) -> Self {
|
|
||||||
Node {
|
|
||||||
output,
|
|
||||||
split_feature: 0,
|
|
||||||
split_value: Option::None,
|
|
||||||
split_score: Option::None,
|
|
||||||
true_child: Option::None,
|
|
||||||
false_child: Option::None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PartialEq for Node {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
(self.output - other.output).abs() < f64::EPSILON
|
|
||||||
&& self.split_feature == other.split_feature
|
|
||||||
&& match (self.split_value, other.split_value) {
|
|
||||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
|
||||||
(None, None) => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
&& match (self.split_score, other.split_score) {
|
|
||||||
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
|
|
||||||
(None, None) => true,
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
|
||||||
for DecisionTreeRegressor<TX, TY, X, Y>
|
for DecisionTreeRegressor<TX, TY, X, Y>
|
||||||
{
|
{
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
|
self.tree_regressor == other.tree_regressor
|
||||||
false
|
|
||||||
} else {
|
|
||||||
self.nodes()
|
|
||||||
.iter()
|
|
||||||
.zip(other.nodes().iter())
|
|
||||||
.all(|(a, b)| a == b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
|
||||||
x: &'a X,
|
|
||||||
y: &'a Y,
|
|
||||||
node: usize,
|
|
||||||
samples: Vec<usize>,
|
|
||||||
order: &'a [Vec<usize>],
|
|
||||||
true_child_output: f64,
|
|
||||||
false_child_output: f64,
|
|
||||||
level: u16,
|
|
||||||
_phantom_tx: PhantomData<TX>,
|
|
||||||
_phantom_ty: PhantomData<TY>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|
||||||
NodeVisitor<'a, TX, TY, X, Y>
|
|
||||||
{
|
|
||||||
fn new(
|
|
||||||
node_id: usize,
|
|
||||||
samples: Vec<usize>,
|
|
||||||
order: &'a [Vec<usize>],
|
|
||||||
x: &'a X,
|
|
||||||
y: &'a Y,
|
|
||||||
level: u16,
|
|
||||||
) -> Self {
|
|
||||||
NodeVisitor {
|
|
||||||
x,
|
|
||||||
y,
|
|
||||||
node: node_id,
|
|
||||||
samples,
|
|
||||||
order,
|
|
||||||
true_child_output: 0f64,
|
|
||||||
false_child_output: 0f64,
|
|
||||||
level,
|
|
||||||
_phantom_tx: PhantomData,
|
|
||||||
_phantom_ty: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,13 +271,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
{
|
{
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
nodes: vec![],
|
tree_regressor: None,
|
||||||
parameters: Option::None,
|
|
||||||
depth: 0u16,
|
|
||||||
_phantom_tx: PhantomData,
|
|
||||||
_phantom_ty: PhantomData,
|
|
||||||
_phantom_x: PhantomData,
|
|
||||||
_phantom_y: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,283 +299,23 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
y: &Y,
|
y: &Y,
|
||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let tree_parameters = BaseTreeRegressorParameters {
|
||||||
if x_nrows != y.shape() {
|
max_depth: parameters.max_depth,
|
||||||
return Err(Failed::fit("Size of x should equal size of y"));
|
min_samples_leaf: parameters.min_samples_leaf,
|
||||||
}
|
min_samples_split: parameters.min_samples_split,
|
||||||
|
seed: parameters.seed,
|
||||||
let samples = vec![1; x_nrows];
|
splitter: Splitter::Best,
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn fit_weak_learner(
|
|
||||||
x: &X,
|
|
||||||
y: &Y,
|
|
||||||
samples: Vec<usize>,
|
|
||||||
mtry: usize,
|
|
||||||
parameters: DecisionTreeRegressorParameters,
|
|
||||||
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
|
||||||
let y_m = y.clone();
|
|
||||||
|
|
||||||
let y_ncols = y_m.shape();
|
|
||||||
let (_, num_attributes) = x.shape();
|
|
||||||
|
|
||||||
let mut nodes: Vec<Node> = Vec::new();
|
|
||||||
let mut rng = get_rng_impl(parameters.seed);
|
|
||||||
|
|
||||||
let mut n = 0;
|
|
||||||
let mut sum = 0f64;
|
|
||||||
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
|
|
||||||
n += *sample_i;
|
|
||||||
sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
let root = Node::new(sum / (n as f64));
|
|
||||||
nodes.push(root);
|
|
||||||
let mut order: Vec<Vec<usize>> = Vec::new();
|
|
||||||
|
|
||||||
for i in 0..num_attributes {
|
|
||||||
let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
|
|
||||||
order.push(col_i.argsort_mut());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut tree = DecisionTreeRegressor {
|
|
||||||
nodes,
|
|
||||||
parameters: Some(parameters),
|
|
||||||
depth: 0u16,
|
|
||||||
_phantom_tx: PhantomData,
|
|
||||||
_phantom_ty: PhantomData,
|
|
||||||
_phantom_x: PhantomData,
|
|
||||||
_phantom_y: PhantomData,
|
|
||||||
};
|
};
|
||||||
|
let tree = BaseTreeRegressor::fit(x, y, tree_parameters)?;
|
||||||
let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
|
Ok(Self {
|
||||||
|
tree_regressor: Some(tree),
|
||||||
let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
|
})
|
||||||
|
|
||||||
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
|
|
||||||
visitor_queue.push_back(visitor);
|
|
||||||
}
|
|
||||||
|
|
||||||
while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
|
|
||||||
match visitor_queue.pop_front() {
|
|
||||||
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
|
|
||||||
None => break,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(tree)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Predict regression value for `x`.
|
/// Predict regression value for `x`.
|
||||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||||
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
|
||||||
let mut result = Y::zeros(x.shape().0);
|
self.tree_regressor.as_ref().unwrap().predict(x)
|
||||||
|
|
||||||
let (n, _) = x.shape();
|
|
||||||
|
|
||||||
for i in 0..n {
|
|
||||||
result.set(i, self.predict_for_row(x, i));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
|
|
||||||
let mut result = 0f64;
|
|
||||||
let mut queue: LinkedList<usize> = LinkedList::new();
|
|
||||||
|
|
||||||
queue.push_back(0);
|
|
||||||
|
|
||||||
while !queue.is_empty() {
|
|
||||||
match queue.pop_front() {
|
|
||||||
Some(node_id) => {
|
|
||||||
let node = &self.nodes()[node_id];
|
|
||||||
if node.true_child.is_none() && node.false_child.is_none() {
|
|
||||||
result = node.output;
|
|
||||||
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
|
||||||
<= node.split_value.unwrap_or(f64::NAN)
|
|
||||||
{
|
|
||||||
queue.push_back(node.true_child.unwrap());
|
|
||||||
} else {
|
|
||||||
queue.push_back(node.false_child.unwrap());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => break,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
TY::from_f64(result).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn find_best_cutoff(
|
|
||||||
&mut self,
|
|
||||||
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
|
|
||||||
mtry: usize,
|
|
||||||
rng: &mut impl Rng,
|
|
||||||
) -> bool {
|
|
||||||
let (_, n_attr) = visitor.x.shape();
|
|
||||||
|
|
||||||
let n: usize = visitor.samples.iter().sum();
|
|
||||||
|
|
||||||
if n < self.parameters().min_samples_split {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let sum = self.nodes()[visitor.node].output * n as f64;
|
|
||||||
|
|
||||||
let mut variables = (0..n_attr).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if mtry < n_attr {
|
|
||||||
variables.shuffle(rng);
|
|
||||||
}
|
|
||||||
|
|
||||||
let parent_gain =
|
|
||||||
n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
|
|
||||||
|
|
||||||
for variable in variables.iter().take(mtry) {
|
|
||||||
self.find_best_split(visitor, n, sum, parent_gain, *variable);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.nodes()[visitor.node].split_score.is_some()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn find_best_split(
|
|
||||||
&mut self,
|
|
||||||
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
|
|
||||||
n: usize,
|
|
||||||
sum: f64,
|
|
||||||
parent_gain: f64,
|
|
||||||
j: usize,
|
|
||||||
) {
|
|
||||||
let mut true_sum = 0f64;
|
|
||||||
let mut true_count = 0;
|
|
||||||
let mut prevx = Option::None;
|
|
||||||
|
|
||||||
for i in visitor.order[j].iter() {
|
|
||||||
if visitor.samples[*i] > 0 {
|
|
||||||
let x_ij = *visitor.x.get((*i, j));
|
|
||||||
|
|
||||||
if prevx.is_none() || x_ij == prevx.unwrap() {
|
|
||||||
prevx = Some(x_ij);
|
|
||||||
true_count += visitor.samples[*i];
|
|
||||||
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let false_count = n - true_count;
|
|
||||||
|
|
||||||
if true_count < self.parameters().min_samples_leaf
|
|
||||||
|| false_count < self.parameters().min_samples_leaf
|
|
||||||
{
|
|
||||||
prevx = Some(x_ij);
|
|
||||||
true_count += visitor.samples[*i];
|
|
||||||
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let true_mean = true_sum / true_count as f64;
|
|
||||||
let false_mean = (sum - true_sum) / false_count as f64;
|
|
||||||
|
|
||||||
let gain = (true_count as f64 * true_mean * true_mean
|
|
||||||
+ false_count as f64 * false_mean * false_mean)
|
|
||||||
- parent_gain;
|
|
||||||
|
|
||||||
if self.nodes()[visitor.node].split_score.is_none()
|
|
||||||
|| gain > self.nodes()[visitor.node].split_score.unwrap()
|
|
||||||
{
|
|
||||||
self.nodes[visitor.node].split_feature = j;
|
|
||||||
self.nodes[visitor.node].split_value =
|
|
||||||
Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
|
|
||||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
|
||||||
|
|
||||||
visitor.true_child_output = true_mean;
|
|
||||||
visitor.false_child_output = false_mean;
|
|
||||||
}
|
|
||||||
|
|
||||||
prevx = Some(x_ij);
|
|
||||||
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
|
|
||||||
true_count += visitor.samples[*i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn split<'a>(
|
|
||||||
&mut self,
|
|
||||||
mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
|
|
||||||
mtry: usize,
|
|
||||||
visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
|
|
||||||
rng: &mut impl Rng,
|
|
||||||
) -> bool {
|
|
||||||
let (n, _) = visitor.x.shape();
|
|
||||||
let mut tc = 0;
|
|
||||||
let mut fc = 0;
|
|
||||||
let mut true_samples: Vec<usize> = vec![0; n];
|
|
||||||
|
|
||||||
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
|
|
||||||
if visitor.samples[i] > 0 {
|
|
||||||
if visitor
|
|
||||||
.x
|
|
||||||
.get((i, self.nodes()[visitor.node].split_feature))
|
|
||||||
.to_f64()
|
|
||||||
.unwrap()
|
|
||||||
<= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
|
|
||||||
{
|
|
||||||
*true_sample = visitor.samples[i];
|
|
||||||
tc += *true_sample;
|
|
||||||
visitor.samples[i] = 0;
|
|
||||||
} else {
|
|
||||||
fc += visitor.samples[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
|
|
||||||
self.nodes[visitor.node].split_feature = 0;
|
|
||||||
self.nodes[visitor.node].split_value = Option::None;
|
|
||||||
self.nodes[visitor.node].split_score = Option::None;
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let true_child_idx = self.nodes().len();
|
|
||||||
|
|
||||||
self.nodes.push(Node::new(visitor.true_child_output));
|
|
||||||
let false_child_idx = self.nodes().len();
|
|
||||||
self.nodes.push(Node::new(visitor.false_child_output));
|
|
||||||
|
|
||||||
self.nodes[visitor.node].true_child = Some(true_child_idx);
|
|
||||||
self.nodes[visitor.node].false_child = Some(false_child_idx);
|
|
||||||
|
|
||||||
self.depth = u16::max(self.depth, visitor.level + 1);
|
|
||||||
|
|
||||||
let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
|
|
||||||
true_child_idx,
|
|
||||||
true_samples,
|
|
||||||
visitor.order,
|
|
||||||
visitor.x,
|
|
||||||
visitor.y,
|
|
||||||
visitor.level + 1,
|
|
||||||
);
|
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
|
|
||||||
visitor_queue.push_back(true_visitor);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
|
|
||||||
false_child_idx,
|
|
||||||
visitor.samples,
|
|
||||||
visitor.order,
|
|
||||||
visitor.x,
|
|
||||||
visitor.y,
|
|
||||||
visitor.level + 1,
|
|
||||||
);
|
|
||||||
|
|
||||||
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
|
|
||||||
visitor_queue.push_back(false_visitor);
|
|
||||||
}
|
|
||||||
|
|
||||||
true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
|
||||||
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
|
||||||
|
|
||||||
|
pub(crate) mod base_tree_regressor;
|
||||||
/// Classification tree for dependent variables that take a finite number of unordered values.
|
/// Classification tree for dependent variables that take a finite number of unordered values.
|
||||||
pub mod decision_tree_classifier;
|
pub mod decision_tree_classifier;
|
||||||
/// Regression tree for for dependent variables that take continuous or ordered discrete values.
|
/// Regression tree for for dependent variables that take continuous or ordered discrete values.
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
//! # XGBoost
|
||||||
|
//!
|
||||||
|
//! XGBoost, which stands for Extreme Gradient Boosting, is a powerful and efficient implementation of the gradient boosting framework. Gradient boosting is a machine learning technique for regression and classification problems, which produces a prediction model in the form of an ensemble of weak prediction models, typically decision trees.
|
||||||
|
//!
|
||||||
|
//! The core idea of boosting is to build the model in a stage-wise fashion. It learns from its mistakes by sequentially adding new models that correct the errors of the previous ones. Unlike bagging, which trains models in parallel, boosting is a sequential process. Each new tree is fit on a modified version of the original data set, specifically focusing on the instances where the previous models performed poorly.
|
||||||
|
//!
|
||||||
|
//! XGBoost enhances this process through several key innovations. It employs a more regularized model formalization to control over-fitting, which gives it better performance. It also has a highly optimized and parallelized tree construction process, making it significantly faster and more scalable than traditional gradient boosting implementations.
|
||||||
|
//!
|
||||||
|
//! ## References:
|
||||||
|
//!
|
||||||
|
//! * "Elements of Statistical Learning", Hastie T., Tibshirani R., Friedman J., 10. Boosting and Additive Trees
|
||||||
|
//! * XGBoost: A Scalable Tree Boosting System, Chen T., Guestrin C.
|
||||||
|
|
||||||
|
// xgboost implementation
|
||||||
|
pub mod xgb_regressor;
|
||||||
|
pub use xgb_regressor::{XGRegressor, XGRegressorParameters};
|
||||||
@@ -0,0 +1,762 @@
|
|||||||
|
//! # Extreme Gradient Boosting (XGBoost)
//!
//! XGBoost is a highly efficient and effective implementation of the gradient boosting framework.
//! Like other boosting models, it builds an ensemble of sequential decision trees, where each new tree
//! is trained to correct the errors of the previous ones.
//!
//! What makes XGBoost powerful is its use of both the first and second derivatives (gradient and hessian)
//! of the loss function, which allows for more accurate approximations and faster convergence. It also
//! includes built-in regularization techniques (L1/`alpha` and L2/`lambda`) to prevent overfitting.
//!
//! This implementation was ported to Rust from the concepts and algorithm explained in the blog post
//! ["XGBoost from Scratch"](https://randomrealizations.com/posts/xgboost-from-scratch/). It is designed
//! to be a general-purpose regressor that can be used with any objective function that provides a gradient
//! and a hessian.
//!
//! Example:
//!
//! ```
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::xgboost::{XGRegressor, XGRegressorParameters};
//!
//! // Simple dataset: predict y = 2*x
//! let x = DenseMatrix::from_2d_array(&[
//!     &[1.0], &[2.0], &[3.0], &[4.0], &[5.0]
//! ]).unwrap();
//! let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
//!
//! // Use default parameters, but set a few for demonstration
//! let parameters = XGRegressorParameters::default()
//!     .with_n_estimators(50)
//!     .with_max_depth(3)
//!     .with_learning_rate(0.1);
//!
//! // Train the model
//! let model = XGRegressor::fit(&x, &y, parameters).unwrap();
//!
//! // Make predictions
//! let x_test = DenseMatrix::from_2d_array(&[&[6.0], &[7.0]]).unwrap();
//! let y_hat = model.predict(&x_test).unwrap();
//!
//! // y_hat should be close to [12.0, 14.0]
//! ```
//!
|
use rand::{seq::SliceRandom, Rng};
|
||||||
|
use std::{iter::zip, marker::PhantomData};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
api::{PredictorBorrow, SupervisedEstimatorBorrow},
|
||||||
|
error::{Failed, FailedError},
|
||||||
|
linalg::basic::arrays::{Array1, Array2},
|
||||||
|
numbers::basenum::Number,
|
||||||
|
rand_custom::get_rng_impl,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// The objective function optimized during boosting.
///
/// An objective supplies the loss together with its gradient (first
/// derivative) and hessian (second derivative), which drive the
/// second-order tree-building procedure used by XGBoost.
#[derive(Clone, Debug)]
pub enum Objective {
    /// Regression objective based on Mean Squared Error.
    /// Loss: 0.5 * (y_true - y_pred)^2
    MeanSquaredError,
}
|
||||||
|
|
||||||
|
impl Objective {
|
||||||
|
/// Calculates the loss for each sample given the true and predicted values.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `y_true` - A vector of the true target values.
|
||||||
|
/// * `y_pred` - A vector of the predicted values.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// The mean of the calculated loss values.
|
||||||
|
pub fn loss_function<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> f64 {
|
||||||
|
match self {
|
||||||
|
Objective::MeanSquaredError => {
|
||||||
|
zip(y_true.iterator(0), y_pred)
|
||||||
|
.map(|(true_val, pred_val)| {
|
||||||
|
0.5 * (true_val.to_f64().unwrap() - pred_val).powi(2)
|
||||||
|
})
|
||||||
|
.sum::<f64>()
|
||||||
|
/ y_true.shape() as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculates the gradient (first derivative) of the loss function.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `y_true` - A vector of the true target values.
|
||||||
|
/// * `y_pred` - A vector of the predicted values.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A vector of gradients for each sample.
|
||||||
|
pub fn gradient<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &Vec<f64>) -> Vec<f64> {
|
||||||
|
match self {
|
||||||
|
Objective::MeanSquaredError => zip(y_true.iterator(0), y_pred)
|
||||||
|
.map(|(true_val, pred_val)| *pred_val - true_val.to_f64().unwrap())
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Calculates the hessian (second derivative) of the loss function.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `y_true` - A vector of the true target values.
|
||||||
|
/// * `y_pred` - A vector of the predicted values.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A vector of hessians for each sample.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
pub fn hessian<TY: Number, Y: Array1<TY>>(&self, y_true: &Y, y_pred: &[f64]) -> Vec<f64> {
|
||||||
|
match self {
|
||||||
|
Objective::MeanSquaredError => vec![1.0; y_true.shape()],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Represents a single decision tree in the XGBoost ensemble.
|
||||||
|
///
|
||||||
|
/// This is a recursive data structure where each `TreeRegressor` is a node
|
||||||
|
/// that can have a left and a right child, also of type `TreeRegressor`.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
|
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
|
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
|
/// The output value of this node. If it's a leaf, this is the final prediction.
|
||||||
|
value: f64,
|
||||||
|
/// The feature value threshold used to split this node.
|
||||||
|
threshold: f64,
|
||||||
|
/// The index of the feature used for splitting.
|
||||||
|
split_feature_idx: usize,
|
||||||
|
/// The gain in score achieved by this split.
|
||||||
|
split_score: f64,
|
||||||
|
_phantom_tx: PhantomData<TX>,
|
||||||
|
_phantom_ty: PhantomData<TY>,
|
||||||
|
_phantom_x: PhantomData<X>,
|
||||||
|
_phantom_y: PhantomData<Y>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
TreeRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
/// Recursively builds a decision tree (a `TreeRegressor` node).
|
||||||
|
///
|
||||||
|
/// This function determines the optimal split for the given set of samples (`idxs`)
|
||||||
|
/// and then recursively calls itself to build the left and right child nodes.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `data` - The full training dataset.
|
||||||
|
/// * `g` - Gradients for all samples.
|
||||||
|
/// * `h` - Hessians for all samples.
|
||||||
|
/// * `idxs` - The indices of the samples belonging to the current node.
|
||||||
|
/// * `max_depth` - The maximum remaining depth for this branch.
|
||||||
|
/// * `min_child_weight` - The minimum sum of hessians required in a child node.
|
||||||
|
/// * `lambda` - L2 regularization term on weights.
|
||||||
|
/// * `gamma` - Minimum loss reduction required to make a further partition.
|
||||||
|
pub fn fit(
|
||||||
|
data: &X,
|
||||||
|
g: &Vec<f64>,
|
||||||
|
h: &Vec<f64>,
|
||||||
|
idxs: &[usize],
|
||||||
|
max_depth: u16,
|
||||||
|
min_child_weight: f64,
|
||||||
|
lambda: f64,
|
||||||
|
gamma: f64,
|
||||||
|
) -> Self {
|
||||||
|
let g_sum = idxs.iter().map(|&i| g[i]).sum::<f64>();
|
||||||
|
let h_sum = idxs.iter().map(|&i| h[i]).sum::<f64>();
|
||||||
|
let value = -g_sum / (h_sum + lambda);
|
||||||
|
|
||||||
|
let mut best_feature_idx = usize::MAX;
|
||||||
|
let mut best_split_score = 0.0;
|
||||||
|
let mut best_threshold = 0.0;
|
||||||
|
let mut left = Option::None;
|
||||||
|
let mut right = Option::None;
|
||||||
|
|
||||||
|
if max_depth > 0 {
|
||||||
|
Self::insert_child_nodes(
|
||||||
|
data,
|
||||||
|
g,
|
||||||
|
h,
|
||||||
|
idxs,
|
||||||
|
&mut best_feature_idx,
|
||||||
|
&mut best_split_score,
|
||||||
|
&mut best_threshold,
|
||||||
|
&mut left,
|
||||||
|
&mut right,
|
||||||
|
max_depth,
|
||||||
|
min_child_weight,
|
||||||
|
lambda,
|
||||||
|
gamma,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
left,
|
||||||
|
right,
|
||||||
|
value,
|
||||||
|
threshold: best_threshold,
|
||||||
|
split_feature_idx: best_feature_idx,
|
||||||
|
split_score: best_split_score,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finds the best split and creates child nodes if a valid split is found.
|
||||||
|
fn insert_child_nodes(
|
||||||
|
data: &X,
|
||||||
|
g: &Vec<f64>,
|
||||||
|
h: &Vec<f64>,
|
||||||
|
idxs: &[usize],
|
||||||
|
best_feature_idx: &mut usize,
|
||||||
|
best_split_score: &mut f64,
|
||||||
|
best_threshold: &mut f64,
|
||||||
|
left: &mut Option<Box<Self>>,
|
||||||
|
right: &mut Option<Box<Self>>,
|
||||||
|
max_depth: u16,
|
||||||
|
min_child_weight: f64,
|
||||||
|
lambda: f64,
|
||||||
|
gamma: f64,
|
||||||
|
) {
|
||||||
|
let (_, n_features) = data.shape();
|
||||||
|
for i in 0..n_features {
|
||||||
|
Self::find_best_split(
|
||||||
|
data,
|
||||||
|
g,
|
||||||
|
h,
|
||||||
|
idxs,
|
||||||
|
i,
|
||||||
|
best_feature_idx,
|
||||||
|
best_split_score,
|
||||||
|
best_threshold,
|
||||||
|
min_child_weight,
|
||||||
|
lambda,
|
||||||
|
gamma,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// A split is only valid if it results in a positive gain.
|
||||||
|
if *best_split_score > 0.0 {
|
||||||
|
let mut left_idxs = Vec::new();
|
||||||
|
let mut right_idxs = Vec::new();
|
||||||
|
for idx in idxs.iter() {
|
||||||
|
if data.get((*idx, *best_feature_idx)).to_f64().unwrap() <= *best_threshold {
|
||||||
|
left_idxs.push(*idx);
|
||||||
|
} else {
|
||||||
|
right_idxs.push(*idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*left = Some(Box::new(TreeRegressor::fit(
|
||||||
|
data,
|
||||||
|
g,
|
||||||
|
h,
|
||||||
|
&left_idxs,
|
||||||
|
max_depth - 1,
|
||||||
|
min_child_weight,
|
||||||
|
lambda,
|
||||||
|
gamma,
|
||||||
|
)));
|
||||||
|
*right = Some(Box::new(TreeRegressor::fit(
|
||||||
|
data,
|
||||||
|
g,
|
||||||
|
h,
|
||||||
|
&right_idxs,
|
||||||
|
max_depth - 1,
|
||||||
|
min_child_weight,
|
||||||
|
lambda,
|
||||||
|
gamma,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Iterates through a single feature to find the best possible split point.
|
||||||
|
fn find_best_split(
|
||||||
|
data: &X,
|
||||||
|
g: &[f64],
|
||||||
|
h: &[f64],
|
||||||
|
idxs: &[usize],
|
||||||
|
feature_idx: usize,
|
||||||
|
best_feature_idx: &mut usize,
|
||||||
|
best_split_score: &mut f64,
|
||||||
|
best_threshold: &mut f64,
|
||||||
|
min_child_weight: f64,
|
||||||
|
lambda: f64,
|
||||||
|
gamma: f64,
|
||||||
|
) {
|
||||||
|
let mut sorted_idxs = idxs.to_owned();
|
||||||
|
sorted_idxs.sort_by(|a, b| {
|
||||||
|
data.get((*a, feature_idx))
|
||||||
|
.partial_cmp(data.get((*b, feature_idx)))
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
let sum_g = sorted_idxs.iter().map(|&i| g[i]).sum::<f64>();
|
||||||
|
let sum_h = sorted_idxs.iter().map(|&i| h[i]).sum::<f64>();
|
||||||
|
|
||||||
|
let mut sum_g_right = sum_g;
|
||||||
|
let mut sum_h_right = sum_h;
|
||||||
|
let mut sum_g_left = 0.0;
|
||||||
|
let mut sum_h_left = 0.0;
|
||||||
|
|
||||||
|
for i in 0..sorted_idxs.len() - 1 {
|
||||||
|
let idx = sorted_idxs[i];
|
||||||
|
let next_idx = sorted_idxs[i + 1];
|
||||||
|
|
||||||
|
let g_i = g[idx];
|
||||||
|
let h_i = h[idx];
|
||||||
|
let x_i = data.get((idx, feature_idx)).to_f64().unwrap();
|
||||||
|
let x_i_next = data.get((next_idx, feature_idx)).to_f64().unwrap();
|
||||||
|
|
||||||
|
sum_g_left += g_i;
|
||||||
|
sum_h_left += h_i;
|
||||||
|
sum_g_right -= g_i;
|
||||||
|
sum_h_right -= h_i;
|
||||||
|
|
||||||
|
if sum_h_left < min_child_weight || x_i == x_i_next {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if sum_h_right < min_child_weight {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let gain = 0.5
|
||||||
|
* ((sum_g_left * sum_g_left / (sum_h_left + lambda))
|
||||||
|
+ (sum_g_right * sum_g_right / (sum_h_right + lambda))
|
||||||
|
- (sum_g * sum_g / (sum_h + lambda)))
|
||||||
|
- gamma;
|
||||||
|
|
||||||
|
if gain > *best_split_score {
|
||||||
|
*best_split_score = gain;
|
||||||
|
*best_threshold = (x_i + x_i_next) / 2.0;
|
||||||
|
*best_feature_idx = feature_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predicts the output values for a dataset.
|
||||||
|
pub fn predict(&self, data: &X) -> Vec<f64> {
|
||||||
|
let (n_samples, n_features) = data.shape();
|
||||||
|
(0..n_samples)
|
||||||
|
.map(|i| {
|
||||||
|
self.predict_for_row(&Vec::from_iterator(
|
||||||
|
data.get_row(i).iterator(0).copied(),
|
||||||
|
n_features,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predicts the output value for a single row of data by traversing the tree.
|
||||||
|
pub fn predict_for_row(&self, row: &Vec<TX>) -> f64 {
|
||||||
|
// A leaf node is identified by having no children.
|
||||||
|
if self.left.is_none() {
|
||||||
|
return self.value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse down the appropriate branch.
|
||||||
|
let child = if row[self.split_feature_idx].to_f64().unwrap() <= self.threshold {
|
||||||
|
self.left.as_ref().unwrap()
|
||||||
|
} else {
|
||||||
|
self.right.as_ref().unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
child.predict_for_row(row)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parameters for the `jRegressor` model.
|
||||||
|
///
|
||||||
|
/// This struct holds all the hyperparameters that control the training process.
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct XGRegressorParameters {
|
||||||
|
/// The number of boosting rounds or trees to build.
|
||||||
|
pub n_estimators: usize,
|
||||||
|
/// The maximum depth of each tree.
|
||||||
|
pub max_depth: u16,
|
||||||
|
/// Step size shrinkage used to prevent overfitting.
|
||||||
|
pub learning_rate: f64,
|
||||||
|
/// Minimum sum of instance weight (hessian) needed in a child.
|
||||||
|
pub min_child_weight: usize,
|
||||||
|
/// L2 regularization term on weights.
|
||||||
|
pub lambda: f64,
|
||||||
|
/// Minimum loss reduction required to make a further partition on a leaf node.
|
||||||
|
pub gamma: f64,
|
||||||
|
/// The initial prediction score for all instances.
|
||||||
|
pub base_score: f64,
|
||||||
|
/// The fraction of samples to be used for fitting the individual base learners.
|
||||||
|
pub subsample: f64,
|
||||||
|
/// The seed for the random number generator for reproducibility.
|
||||||
|
pub seed: u64,
|
||||||
|
/// The objective function to be optimized.
|
||||||
|
pub objective: Objective,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for XGRegressorParameters {
|
||||||
|
/// Creates a new set of `XGRegressorParameters` with default values.
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
n_estimators: 100,
|
||||||
|
learning_rate: 0.3,
|
||||||
|
max_depth: 6,
|
||||||
|
min_child_weight: 1,
|
||||||
|
lambda: 1.0,
|
||||||
|
gamma: 0.0,
|
||||||
|
base_score: 0.5,
|
||||||
|
subsample: 1.0,
|
||||||
|
seed: 0,
|
||||||
|
objective: Objective::MeanSquaredError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builder pattern for XGRegressorParameters
|
||||||
|
impl XGRegressorParameters {
|
||||||
|
/// Sets the number of boosting rounds or trees to build.
|
||||||
|
pub fn with_n_estimators(mut self, n_estimators: usize) -> Self {
|
||||||
|
self.n_estimators = n_estimators;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the step size shrinkage used to prevent overfitting.
|
||||||
|
///
|
||||||
|
/// Also known as `eta`. A smaller value makes the model more robust by preventing
|
||||||
|
/// too much weight being given to any single tree.
|
||||||
|
pub fn with_learning_rate(mut self, learning_rate: f64) -> Self {
|
||||||
|
self.learning_rate = learning_rate;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the maximum depth of each individual tree.
|
||||||
|
// A lower value helps prevent overfitting.*
|
||||||
|
pub fn with_max_depth(mut self, max_depth: u16) -> Self {
|
||||||
|
self.max_depth = max_depth;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the minimum sum of instance weight (hessian) needed in a child node.
|
||||||
|
///
|
||||||
|
/// If the tree partition step results in a leaf node with the sum of
|
||||||
|
// instance weight less than `min_child_weight`, then the building process*
|
||||||
|
/// will give up further partitioning.
|
||||||
|
pub fn with_min_child_weight(mut self, min_child_weight: usize) -> Self {
|
||||||
|
self.min_child_weight = min_child_weight;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the L2 regularization term on weights (`lambda`).
|
||||||
|
///
|
||||||
|
/// Increasing this value will make the model more conservative.
|
||||||
|
pub fn with_lambda(mut self, lambda: f64) -> Self {
|
||||||
|
self.lambda = lambda;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the minimum loss reduction required to make a further partition on a leaf node.
|
||||||
|
///
|
||||||
|
/// The larger `gamma` is, the more conservative the algorithm will be.
|
||||||
|
pub fn with_gamma(mut self, gamma: f64) -> Self {
|
||||||
|
self.gamma = gamma;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the initial prediction score for all instances.
|
||||||
|
pub fn with_base_score(mut self, base_score: f64) -> Self {
|
||||||
|
self.base_score = base_score;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the fraction of samples to be used for fitting individual base learners.
|
||||||
|
///
|
||||||
|
/// A value of less than 1.0 introduces randomness and helps prevent overfitting.
|
||||||
|
pub fn with_subsample(mut self, subsample: f64) -> Self {
|
||||||
|
self.subsample = subsample;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the seed for the random number generator for reproducibility.
|
||||||
|
pub fn with_seed(mut self, seed: u64) -> Self {
|
||||||
|
self.seed = seed;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the objective function to be optimized during training.
|
||||||
|
pub fn with_objective(mut self, objective: Objective) -> Self {
|
||||||
|
self.objective = objective;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
||||||
|
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
|
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
|
parameters: Option<XGRegressorParameters>,
|
||||||
|
_phantom_ty: PhantomData<TY>,
|
||||||
|
_phantom_tx: PhantomData<TX>,
|
||||||
|
_phantom_y: PhantomData<Y>,
|
||||||
|
_phantom_x: PhantomData<X>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> XGRegressor<TX, TY, X, Y> {
|
||||||
|
/// Fits the XGBoost model to the training data.
|
||||||
|
pub fn fit(data: &X, y: &Y, parameters: XGRegressorParameters) -> Result<Self, Failed> {
|
||||||
|
if parameters.subsample > 1.0 || parameters.subsample <= 0.0 {
|
||||||
|
return Err(Failed::because(
|
||||||
|
FailedError::ParametersError,
|
||||||
|
"Subsample ratio must be in (0, 1].",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let (n_samples, _) = data.shape();
|
||||||
|
let learning_rate = parameters.learning_rate;
|
||||||
|
let mut predictions = vec![parameters.base_score; n_samples];
|
||||||
|
|
||||||
|
let mut regressors = Vec::new();
|
||||||
|
let mut rng = get_rng_impl(Some(parameters.seed));
|
||||||
|
|
||||||
|
for _ in 0..parameters.n_estimators {
|
||||||
|
let gradients = parameters.objective.gradient(y, &predictions);
|
||||||
|
let hessians = parameters.objective.hessian(y, &predictions);
|
||||||
|
|
||||||
|
let sample_idxs = if parameters.subsample < 1.0 {
|
||||||
|
Self::sample_without_replacement(n_samples, parameters.subsample, &mut rng)
|
||||||
|
} else {
|
||||||
|
(0..n_samples).collect::<Vec<usize>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
let regressor = TreeRegressor::fit(
|
||||||
|
data,
|
||||||
|
&gradients,
|
||||||
|
&hessians,
|
||||||
|
&sample_idxs,
|
||||||
|
parameters.max_depth,
|
||||||
|
parameters.min_child_weight as f64,
|
||||||
|
parameters.lambda,
|
||||||
|
parameters.gamma,
|
||||||
|
);
|
||||||
|
|
||||||
|
let corrections = regressor.predict(data);
|
||||||
|
predictions = zip(predictions, corrections)
|
||||||
|
.map(|(pred, correction)| pred + (learning_rate * correction))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
regressors.push(regressor);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
regressors: Some(regressors),
|
||||||
|
parameters: Some(parameters),
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Predicts target values for the given input data.
|
||||||
|
pub fn predict(&self, data: &X) -> Result<Vec<TX>, Failed> {
|
||||||
|
let (n_samples, _) = data.shape();
|
||||||
|
|
||||||
|
let parameters = self.parameters.as_ref().unwrap();
|
||||||
|
let mut predictions = vec![parameters.base_score; n_samples];
|
||||||
|
let regressors = self.regressors.as_ref().unwrap();
|
||||||
|
|
||||||
|
for regressor in regressors.iter() {
|
||||||
|
let corrections = regressor.predict(data);
|
||||||
|
predictions = zip(predictions, corrections)
|
||||||
|
.map(|(pred, correction)| pred + (parameters.learning_rate * correction))
|
||||||
|
.collect();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(predictions
|
||||||
|
.into_iter()
|
||||||
|
.map(|p| TX::from_f64(p).unwrap())
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a random sample of indices without replacement.
|
||||||
|
fn sample_without_replacement(
|
||||||
|
population_size: usize,
|
||||||
|
subsample_ratio: f64,
|
||||||
|
rng: &mut impl Rng,
|
||||||
|
) -> Vec<usize> {
|
||||||
|
let mut indices: Vec<usize> = (0..population_size).collect();
|
||||||
|
indices.shuffle(rng);
|
||||||
|
indices.truncate((population_size as f64 * subsample_ratio) as usize);
|
||||||
|
indices
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Boilerplate implementation for the smartcore traits
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||||
|
SupervisedEstimatorBorrow<'_, X, Y, XGRegressorParameters> for XGRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
regressors: None,
|
||||||
|
parameters: None,
|
||||||
|
_phantom_ty: PhantomData,
|
||||||
|
_phantom_y: PhantomData,
|
||||||
|
_phantom_tx: PhantomData,
|
||||||
|
_phantom_x: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fit(x: &X, y: &Y, parameters: &XGRegressorParameters) -> Result<Self, Failed> {
|
||||||
|
XGRegressor::fit(x, y, parameters.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PredictorBorrow<'_, X, TX>
|
||||||
|
for XGRegressor<TX, TY, X, Y>
|
||||||
|
{
|
||||||
|
fn predict(&self, x: &X) -> Result<Vec<TX>, Failed> {
|
||||||
|
self.predict(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ------------------- TESTS -------------------
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::basic::{arrays::Array, matrix::DenseMatrix};

    /// For MeanSquaredError, the gradient is (pred - true) and the hessian is
    /// the constant 1.0 for every sample.
    #[test]
    fn test_mse_objective() {
        let objective = Objective::MeanSquaredError;
        let y_true = vec![1.0, 2.0, 3.0];
        let y_pred = vec![1.5, 2.5, 2.5];

        let gradients = objective.gradient(&y_true, &y_pred);
        let hessians = objective.hessian(&y_true, &y_pred);

        assert_eq!(gradients, vec![0.5, 0.5, -0.5]);
        assert_eq!(hessians, vec![1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_find_best_split_multidimensional() {
        // Two features: the first is constant, so only the second can split.
        let rows = vec![
            vec![1.0, 10.0], // g = -0.5
            vec![1.0, 20.0], // g = -1.0
            vec![1.0, 30.0], // g = 1.0
            vec![1.0, 40.0], // g = 1.5
        ];
        let data = DenseMatrix::from_2d_vec(&rows).unwrap();
        let g = vec![-0.5, -1.0, 1.0, 1.5];
        let h = vec![1.0, 1.0, 1.0, 1.0];
        let idxs = (0..4).collect::<Vec<usize>>();

        let mut best_feature_idx = usize::MAX;
        let mut best_split_score = 0.0;
        let mut best_threshold = 0.0;

        // Hand-computed gain for the best split (feature 1, lambda = 1.0):
        //   G_left = -1.5, H_left = 2.0;  G_right = 2.5, H_right = 2.0
        //   G_total = 1.0, H_total = 4.0
        //   Gain = 0.5 * (G_l^2/(H_l+λ) + G_r^2/(H_r+λ) - G_t^2/(H_t+λ))
        //        = 0.5 * (2.25/3 + 6.25/3 - 1.0/5) = 1.3166...
        let expected_gain = 1.3166666666666667;

        // Scan every feature; the search must settle on feature 1.
        let (_, n_features) = data.shape();
        for feature in 0..n_features {
            TreeRegressor::<f64, f64, DenseMatrix<f64>, Vec<f64>>::find_best_split(
                &data,
                &g,
                &h,
                &idxs,
                feature,
                &mut best_feature_idx,
                &mut best_split_score,
                &mut best_threshold,
                1.0,
                1.0,
                0.0,
            );
        }

        assert_eq!(best_feature_idx, 1); // Should choose the second feature
        assert!((best_split_score - expected_gain).abs() < 1e-9);
        assert_eq!(best_threshold, 25.0); // (20 + 30) / 2
    }

    /// Builds a one-level tree on multidimensional data and checks the split
    /// and both leaf weights.
    #[test]
    fn test_tree_regressor_fit_multidimensional() {
        let rows = vec![
            vec![1.0, 10.0],
            vec![1.0, 20.0],
            vec![1.0, 30.0],
            vec![1.0, 40.0],
        ];
        let data = DenseMatrix::from_2d_vec(&rows).unwrap();
        let g = vec![-0.5, -1.0, 1.0, 1.5];
        let h = vec![1.0, 1.0, 1.0, 1.0];
        let idxs = (0..4).collect::<Vec<usize>>();

        let tree = TreeRegressor::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(
            &data, &g, &h, &idxs, 2, 1.0, 1.0, 0.0,
        );

        // The root must have split on the informative (second) feature.
        assert!(tree.left.is_some());
        assert!(tree.right.is_some());
        assert_eq!(tree.split_feature_idx, 1); // Should split on the second feature
        assert_eq!(tree.threshold, 25.0);

        // Leaf weights are -G / (H + lambda):
        //   left:  -(-1.5) / (2 + 1) =  0.5
        //   right: -( 2.5) / (2 + 1) = -0.8333...
        assert!((tree.left.unwrap().value - 0.5).abs() < 1e-9);
        assert!((tree.right.unwrap().value - (-0.833333333)).abs() < 1e-9);
    }

    /// Smoke test: XGRegressor fits and predicts on multidimensional data.
    #[test]
    fn test_xgregressor_fit_predict_multidimensional() {
        // Simple 2D data where y is roughly 2*x1 + 3*x2
        let x_rows = vec![
            vec![1.0, 1.0],
            vec![2.0, 1.0],
            vec![1.0, 2.0],
            vec![2.0, 2.0],
        ];
        let x = DenseMatrix::from_2d_vec(&x_rows).unwrap();
        let y = vec![5.0, 7.0, 8.0, 10.0];

        let params = XGRegressorParameters::default()
            .with_n_estimators(10)
            .with_max_depth(2);

        let fit_result = XGRegressor::fit(&x, &y, params);
        assert!(
            fit_result.is_ok(),
            "Fit failed with error: {:?}",
            fit_result.err()
        );

        let model = fit_result.unwrap();
        let predict_result = model.predict(&x);
        assert!(
            predict_result.is_ok(),
            "Predict failed with error: {:?}",
            predict_result.err()
        );

        let predictions = predict_result.unwrap();
        assert_eq!(predictions.len(), 4);
    }
}
|
||||||
Reference in New Issue
Block a user