refactored decision tree into reusable components (#316)

* refactored decision tree into reusable components

* got rid of api code from base tree because its an implementation detail

* got rid of api code from base tree because its an implementation detail

* changed name
This commit is contained in:
Daniel Lacina
2025-07-12 05:25:53 -05:00
committed by GitHub
parent 5cc5528367
commit c5816b0e1b
3 changed files with 583 additions and 385 deletions
+551
View File
@@ -0,0 +1,551 @@
use std::collections::LinkedList;
use std::default::Default;
use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
/// Strategy used to choose the split point at each node of the tree.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Default)]
pub enum Splitter {
    /// Draw a single random cut-off between the feature's min and max
    /// (extra-trees style).
    Random,
    /// Exhaustively evaluate candidate cut-offs and keep the best-scoring one.
    #[default]
    Best,
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// Parameters of Regression base_tree
pub struct BaseTreeRegressorParameters {
    #[cfg_attr(feature = "serde", serde(default))]
    /// The maximum depth of the base_tree.
    /// `None` means growth is limited only by the other stopping criteria.
    pub max_depth: Option<u16>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to be at a leaf node.
    pub min_samples_leaf: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// The minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Controls the randomness of the estimator
    /// (feature shuffling and, for [`Splitter::Random`], cut-off sampling).
    pub seed: Option<u64>,
    #[cfg_attr(feature = "serde", serde(default))]
    /// Determines the strategy used to choose the split at each node.
    pub splitter: Splitter,
}
/// Regression base_tree
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    /// Flat arena of tree nodes; index 0 is the root and children are
    /// referenced by index into this vector.
    nodes: Vec<Node>,
    /// Hyper-parameters used for fitting; set to `Some` by `fit_weak_learner`.
    parameters: Option<BaseTreeRegressorParameters>,
    /// Depth of the deepest level reached while growing the tree.
    depth: u16,
    // Marker fields: the matrix/vector types appear only in method
    // signatures, not in stored data.
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
    _phantom_x: PhantomData<X>,
    _phantom_y: PhantomData<Y>,
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    BaseTreeRegressor<TX, TY, X, Y>
{
    /// Get nodes, return a shared reference
    fn nodes(&self) -> &Vec<Node> {
        self.nodes.as_ref()
    }

    /// Get parameters, return a shared reference.
    /// Panics if the tree was built without parameters; `fit_weak_learner`
    /// always sets them, so this is an internal invariant.
    fn parameters(&self) -> &BaseTreeRegressorParameters {
        self.parameters.as_ref().unwrap()
    }

    /// Get the depth of the fitted tree, return value
    fn depth(&self) -> u16 {
        self.depth
    }
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
/// A single node of the regression tree, stored in the flat `nodes` arena.
struct Node {
    /// Mean target value of the samples reaching this node (the prediction
    /// when the node is a leaf).
    output: f64,
    /// Index of the feature this node splits on (0 when no split is set).
    split_feature: usize,
    /// Threshold: samples with feature value <= `split_value` go to `true_child`.
    split_value: Option<f64>,
    /// Gain of the chosen split; `None` while no split is recorded.
    split_score: Option<f64>,
    /// Arena index of the child for samples satisfying the split.
    true_child: Option<usize>,
    /// Arena index of the child for the remaining samples.
    false_child: Option<usize>,
}
impl Node {
fn new(output: f64) -> Self {
Node {
output,
split_feature: 0,
split_value: Option::None,
split_score: Option::None,
true_child: Option::None,
false_child: Option::None,
}
}
}
impl PartialEq for Node {
    /// Two nodes are equal when their outputs and split metadata agree,
    /// comparing floating-point fields to within `f64::EPSILON`.
    fn eq(&self, other: &Self) -> bool {
        fn close(a: f64, b: f64) -> bool {
            (a - b).abs() < f64::EPSILON
        }
        fn opt_close(a: Option<f64>, b: Option<f64>) -> bool {
            match (a, b) {
                (Some(x), Some(y)) => close(x, y),
                (None, None) => true,
                _ => false,
            }
        }
        close(self.output, other.output)
            && self.split_feature == other.split_feature
            && opt_close(self.split_value, other.split_value)
            && opt_close(self.split_score, other.split_score)
    }
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for BaseTreeRegressor<TX, TY, X, Y>
{
    /// Trees are equal when they have the same depth and the same node
    /// arena. `Vec<Node>` equality already checks length first and then
    /// compares elements pairwise, matching the manual zip-based loop.
    fn eq(&self, other: &Self) -> bool {
        self.depth == other.depth && self.nodes() == other.nodes()
    }
}
/// Work item for the breadth-first tree-growing loop: everything needed to
/// evaluate and split one node.
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
    /// Training features.
    x: &'a X,
    /// Training targets.
    y: &'a Y,
    /// Arena index of the node being grown.
    node: usize,
    /// Per-row sample weights (0 = row not present at this node).
    samples: Vec<usize>,
    /// For each feature, row indices sorted ascending by that feature's value.
    order: &'a [Vec<usize>],
    /// Mean target of the prospective "true" child (filled by the split search).
    true_child_output: f64,
    /// Mean target of the prospective "false" child.
    false_child_output: f64,
    /// Depth of the node in the tree (root = 1).
    level: u16,
    _phantom_tx: PhantomData<TX>,
    _phantom_ty: PhantomData<TY>,
}
impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
    NodeVisitor<'a, TX, TY, X, Y>
{
    /// Build a visitor for `node_id` over the given sample weights,
    /// pre-computed per-feature sort orders and training data, positioned
    /// at tree depth `level`. Child outputs start at zero and are filled
    /// in later by the split search.
    fn new(
        node_id: usize,
        samples: Vec<usize>,
        order: &'a [Vec<usize>],
        x: &'a X,
        y: &'a Y,
        level: u16,
    ) -> Self {
        Self {
            node: node_id,
            samples,
            order,
            x,
            y,
            level,
            true_child_output: 0.0,
            false_child_output: 0.0,
            _phantom_tx: PhantomData,
            _phantom_ty: PhantomData,
        }
    }
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
BaseTreeRegressor<TX, TY, X, Y>
{
/// Build a decision base_tree regressor from the training data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
/// * `y` - the target values
pub fn fit(
x: &X,
y: &Y,
parameters: BaseTreeRegressorParameters,
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}
let samples = vec![1; x_nrows];
BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
/// Fit a tree on (possibly weighted / bootstrapped) samples, considering
/// `mtry` features at each split. Used directly by ensemble learners.
///
/// * `samples` - per-row multiplicities; a row with weight 0 is excluded.
/// * `mtry` - number of candidate features examined per split.
pub(crate) fn fit_weak_learner(
    x: &X,
    y: &Y,
    samples: Vec<usize>,
    mtry: usize,
    parameters: BaseTreeRegressorParameters,
) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
    // NOTE: the previous version cloned `y` into a local (`let y_m =
    // y.clone();`) even though the targets are only ever read through a
    // shared reference; the clone of the whole target array is dropped.
    let y_ncols = y.shape();
    let (_, num_attributes) = x.shape();
    let mut nodes: Vec<Node> = Vec::new();
    let mut rng = get_rng_impl(parameters.seed);

    // Weighted count and target sum over the included samples; the root
    // prediction is the overall weighted mean.
    let mut n = 0;
    let mut sum = 0f64;
    for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
        n += *sample_i;
        sum += *sample_i as f64 * y.get(i).to_f64().unwrap();
    }

    let root = Node::new(sum / (n as f64));
    nodes.push(root);

    // Pre-sort row indices by each feature once; the split search walks
    // these orders instead of re-sorting at every node.
    let mut order: Vec<Vec<usize>> = Vec::new();
    for i in 0..num_attributes {
        let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
        order.push(col_i.argsort_mut());
    }

    let mut base_tree = BaseTreeRegressor {
        nodes,
        parameters: Some(parameters),
        depth: 0u16,
        _phantom_tx: PhantomData,
        _phantom_ty: PhantomData,
        _phantom_x: PhantomData,
        _phantom_y: PhantomData,
    };

    // Grow breadth-first: the queue holds only nodes for which a split
    // has already been found.
    let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, y, 1);
    let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
    if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
        visitor_queue.push_back(visitor);
    }
    // Stop once the configured maximum depth is reached (unbounded when
    // `max_depth` is None) or the queue drains.
    while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
        match visitor_queue.pop_front() {
            Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
            None => break,
        };
    }
    Ok(base_tree)
}
/// Predict regression value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
    let n_rows = x.shape().0;
    let mut predictions = Y::zeros(n_rows);
    // Score each observation independently by walking the tree.
    (0..n_rows).for_each(|row| predictions.set(row, self.predict_for_row(x, row)));
    Ok(predictions)
}
/// Predict the target for a single observation by descending from the
/// root to a leaf.
///
/// The previous version pushed child indices through a heap-allocated
/// `LinkedList` that never held more than one element; a direct index
/// descent performs the identical traversal without any allocation.
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
    let mut node_id = 0;
    loop {
        let node = &self.nodes()[node_id];
        // A node with no children is a leaf: its output is the prediction.
        if node.true_child.is_none() && node.false_child.is_none() {
            return TY::from_f64(node.output).unwrap();
        }
        // NaN threshold compares false, sending the sample down the
        // "false" branch — same behavior as the original queue walk.
        node_id = if x.get((row, node.split_feature)).to_f64().unwrap()
            <= node.split_value.unwrap_or(f64::NAN)
        {
            node.true_child.unwrap()
        } else {
            node.false_child.unwrap()
        };
    }
}
/// Search for a split at `visitor.node` over up to `mtry` features, writing
/// the winning feature/value/score into the node and the prospective child
/// means into the visitor. Returns `true` when a split was recorded.
fn find_best_cutoff(
    &mut self,
    visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
    mtry: usize,
    rng: &mut impl Rng,
) -> bool {
    let (_, n_attr) = visitor.x.shape();
    // Total weighted number of samples reaching this node.
    let n: usize = visitor.samples.iter().sum();
    if n < self.parameters().min_samples_split {
        return false;
    }
    // node.output is the weighted mean target, so output * n recovers the sum.
    let sum = self.nodes()[visitor.node].output * n as f64;
    let mut variables = (0..n_attr).collect::<Vec<_>>();
    // When only a subset of features is considered (e.g. random forests),
    // shuffle so `take(mtry)` below picks a random subset.
    if mtry < n_attr {
        variables.shuffle(rng);
    }
    // Parent contribution to the gain: n * mean^2.
    let parent_gain =
        n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
    // Clone the splitter so no shared borrow of self.parameters() is held
    // across the `&mut self` calls in the loop.
    let splitter = self.parameters().splitter.clone();
    for variable in variables.iter().take(mtry) {
        match splitter {
            Splitter::Random => {
                self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
            }
            Splitter::Best => {
                self.find_best_split(visitor, n, sum, parent_gain, *variable);
            }
        }
    }
    self.nodes()[visitor.node].split_score.is_some()
}
/// Extra-trees style search for feature `j`: draw one random cut-off
/// uniformly between the feature's min and max over the node's samples and
/// record it on the node if its gain beats the current best.
fn find_random_split(
    &mut self,
    visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
    n: usize,
    sum: f64,
    parent_gain: f64,
    j: usize,
    rng: &mut impl Rng,
) {
    // Find the min and max of feature j among included rows; because
    // order[j] is sorted by feature j, the first/last included rows give
    // the extremes directly.
    let (min_val, max_val) = {
        let mut min_opt = None;
        let mut max_opt = None;
        for &i in &visitor.order[j] {
            if visitor.samples[i] > 0 {
                min_opt = Some(*visitor.x.get((i, j)));
                break;
            }
        }
        for &i in visitor.order[j].iter().rev() {
            if visitor.samples[i] > 0 {
                max_opt = Some(*visitor.x.get((i, j)));
                break;
            }
        }
        // No included rows at all: nothing to split on.
        if min_opt.is_none() {
            return;
        }
        (min_opt.unwrap(), max_opt.unwrap())
    };
    // A usable split needs at least two distinct feature values.
    if min_val >= max_val {
        return;
    }
    let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());
    // Accumulate the weighted count/sum of the "true" side (x <= cut-off).
    // The order is ascending in feature j, so we can stop at the first
    // included row past the cut-off.
    let mut true_sum = 0f64;
    let mut true_count = 0;
    for &i in &visitor.order[j] {
        if visitor.samples[i] > 0 {
            if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
                true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
                true_count += visitor.samples[i];
            } else {
                break;
            }
        }
    }
    let false_count = n - true_count;
    // Enforce the minimum leaf size on both sides.
    if true_count < self.parameters().min_samples_leaf
        || false_count < self.parameters().min_samples_leaf
    {
        return;
    }
    let true_mean = if true_count > 0 {
        true_sum / true_count as f64
    } else {
        0.0
    };
    let false_mean = if false_count > 0 {
        (sum - true_sum) / false_count as f64
    } else {
        0.0
    };
    // Variance-reduction style gain: sum over children of n_k * mean_k^2,
    // minus the parent's n * mean^2.
    let gain = (true_count as f64 * true_mean * true_mean
        + false_count as f64 * false_mean * false_mean)
        - parent_gain;
    // Keep only the best-scoring split seen so far for this node.
    if self.nodes[visitor.node].split_score.is_none()
        || gain > self.nodes[visitor.node].split_score.unwrap()
    {
        self.nodes[visitor.node].split_feature = j;
        self.nodes[visitor.node].split_value = Some(split_value);
        self.nodes[visitor.node].split_score = Some(gain);
        visitor.true_child_output = true_mean;
        visitor.false_child_output = false_mean;
    }
}
/// Exhaustive CART-style search over feature `j`: walk the rows in
/// ascending feature order, evaluating a candidate cut-off between every
/// pair of consecutive distinct values, and record the best-scoring split
/// on the node.
fn find_best_split(
    &mut self,
    visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
    n: usize,
    sum: f64,
    parent_gain: f64,
    j: usize,
) {
    // Running weighted count/sum of the "true" (low-value) side.
    let mut true_sum = 0f64;
    let mut true_count = 0;
    // Previous feature value seen; None before the first included row.
    let mut prevx = Option::None;
    for i in visitor.order[j].iter() {
        if visitor.samples[*i] > 0 {
            let x_ij = *visitor.x.get((*i, j));
            // First row, or same value as before: no new boundary yet —
            // just fold the row into the "true" side.
            if prevx.is_none() || x_ij == prevx.unwrap() {
                prevx = Some(x_ij);
                true_count += visitor.samples[*i];
                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                continue;
            }
            let false_count = n - true_count;
            // Candidate violates the minimum leaf size: skip it but keep
            // accumulating.
            if true_count < self.parameters().min_samples_leaf
                || false_count < self.parameters().min_samples_leaf
            {
                prevx = Some(x_ij);
                true_count += visitor.samples[*i];
                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
                continue;
            }
            let true_mean = true_sum / true_count as f64;
            let false_mean = (sum - true_sum) / false_count as f64;
            // Variance-reduction style gain: sum over children of
            // n_k * mean_k^2, minus the parent's n * mean^2.
            let gain = (true_count as f64 * true_mean * true_mean
                + false_count as f64 * false_mean * false_mean)
                - parent_gain;
            if self.nodes()[visitor.node].split_score.is_none()
                || gain > self.nodes()[visitor.node].split_score.unwrap()
            {
                self.nodes[visitor.node].split_feature = j;
                // Cut midway between the two neighbouring distinct values.
                self.nodes[visitor.node].split_value =
                    Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
                self.nodes[visitor.node].split_score = Option::Some(gain);
                visitor.true_child_output = true_mean;
                visitor.false_child_output = false_mean;
            }
            // Fold the current row into the "true" side for the next
            // candidate boundary.
            prevx = Some(x_ij);
            true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
            true_count += visitor.samples[*i];
        }
    }
}
/// Apply the split recorded on `visitor.node`: partition its samples into
/// two children, append the child nodes to the arena, and enqueue each
/// child for further growth when a split can be found for it. Returns
/// `false` (after clearing the node's split) when either child would fall
/// below the minimum leaf size.
fn split<'a>(
    &mut self,
    mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
    mtry: usize,
    visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
    rng: &mut impl Rng,
) -> bool {
    let (n, _) = visitor.x.shape();
    let mut tc = 0;
    let mut fc = 0;
    // Move the weight of rows satisfying the split into `true_samples`;
    // whatever remains in visitor.samples forms the "false" child.
    let mut true_samples: Vec<usize> = vec![0; n];
    for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
        if visitor.samples[i] > 0 {
            if visitor
                .x
                .get((i, self.nodes()[visitor.node].split_feature))
                .to_f64()
                .unwrap()
                <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
            {
                *true_sample = visitor.samples[i];
                tc += *true_sample;
                visitor.samples[i] = 0;
            } else {
                fc += visitor.samples[i];
            }
        }
    }
    // Roll the split back if either side is too small; the node stays a leaf.
    if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
        self.nodes[visitor.node].split_feature = 0;
        self.nodes[visitor.node].split_value = Option::None;
        self.nodes[visitor.node].split_score = Option::None;
        return false;
    }
    // Append the children to the arena and link them by index.
    let true_child_idx = self.nodes().len();
    self.nodes.push(Node::new(visitor.true_child_output));
    let false_child_idx = self.nodes().len();
    self.nodes.push(Node::new(visitor.false_child_output));
    self.nodes[visitor.node].true_child = Some(true_child_idx);
    self.nodes[visitor.node].false_child = Some(false_child_idx);
    self.depth = u16::max(self.depth, visitor.level + 1);
    // Each child is queued only if a further split exists for it.
    let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
        true_child_idx,
        true_samples,
        visitor.order,
        visitor.x,
        visitor.y,
        visitor.level + 1,
    );
    if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
        visitor_queue.push_back(true_visitor);
    }
    let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
        false_child_idx,
        visitor.samples,
        visitor.order,
        visitor.x,
        visitor.y,
        visitor.level + 1,
    );
    if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
        visitor_queue.push_back(false_visitor);
    }
    true
}
}
+31 -385
View File
@@ -58,22 +58,17 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use std::collections::LinkedList;
use std::default::Default; use std::default::Default;
use std::fmt::Debug; use std::fmt::Debug;
use std::marker::PhantomData;
use rand::seq::SliceRandom;
use rand::Rng;
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
use crate::api::{Predictor, SupervisedEstimator}; use crate::api::{Predictor, SupervisedEstimator};
use crate::error::Failed; use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1}; use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number; use crate::numbers::basenum::Number;
use crate::rand_custom::get_rng_impl;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -98,41 +93,7 @@ pub struct DecisionTreeRegressorParameters {
#[derive(Debug)] #[derive(Debug)]
pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{ {
nodes: Vec<Node>, tree_regressor: Option<BaseTreeRegressor<TX, TY, X, Y>>,
parameters: Option<DecisionTreeRegressorParameters>,
depth: u16,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
_phantom_x: PhantomData<X>,
_phantom_y: PhantomData<Y>,
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
DecisionTreeRegressor<TX, TY, X, Y>
{
/// Get nodes, return a shared reference
fn nodes(&self) -> &Vec<Node> {
self.nodes.as_ref()
}
/// Get parameters, return a shared reference
fn parameters(&self) -> &DecisionTreeRegressorParameters {
self.parameters.as_ref().unwrap()
}
/// Get estimate of intercept, return value
fn depth(&self) -> u16 {
self.depth
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
output: f64,
split_feature: usize,
split_value: Option<f64>,
split_score: Option<f64>,
true_child: Option<usize>,
false_child: Option<usize>,
} }
impl DecisionTreeRegressorParameters { impl DecisionTreeRegressorParameters {
@@ -296,87 +257,11 @@ impl Default for DecisionTreeRegressorSearchParameters {
} }
} }
impl Node {
fn new(output: f64) -> Self {
Node {
output,
split_feature: 0,
split_value: Option::None,
split_score: Option::None,
true_child: Option::None,
false_child: Option::None,
}
}
}
impl PartialEq for Node {
fn eq(&self, other: &Self) -> bool {
(self.output - other.output).abs() < f64::EPSILON
&& self.split_feature == other.split_feature
&& match (self.split_value, other.split_value) {
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
(None, None) => true,
_ => false,
}
&& match (self.split_score, other.split_score) {
(Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
(None, None) => true,
_ => false,
}
}
}
impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
for DecisionTreeRegressor<TX, TY, X, Y> for DecisionTreeRegressor<TX, TY, X, Y>
{ {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if self.depth != other.depth || self.nodes().len() != other.nodes().len() { self.tree_regressor == other.tree_regressor
false
} else {
self.nodes()
.iter()
.zip(other.nodes().iter())
.all(|(a, b)| a == b)
}
}
}
struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
x: &'a X,
y: &'a Y,
node: usize,
samples: Vec<usize>,
order: &'a [Vec<usize>],
true_child_output: f64,
false_child_output: f64,
level: u16,
_phantom_tx: PhantomData<TX>,
_phantom_ty: PhantomData<TY>,
}
impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
NodeVisitor<'a, TX, TY, X, Y>
{
fn new(
node_id: usize,
samples: Vec<usize>,
order: &'a [Vec<usize>],
x: &'a X,
y: &'a Y,
level: u16,
) -> Self {
NodeVisitor {
x,
y,
node: node_id,
samples,
order,
true_child_output: 0f64,
false_child_output: 0f64,
level,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
}
} }
} }
@@ -386,13 +271,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{ {
fn new() -> Self { fn new() -> Self {
Self { Self {
nodes: vec![], tree_regressor: None,
parameters: Option::None,
depth: 0u16,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
} }
} }
@@ -420,13 +299,17 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
y: &Y, y: &Y,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> { ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape(); let tree_parameters = BaseTreeRegressorParameters {
if x_nrows != y.shape() { max_depth: parameters.max_depth,
return Err(Failed::fit("Size of x should equal size of y")); min_samples_leaf: parameters.min_samples_leaf,
} min_samples_split: parameters.min_samples_split,
seed: parameters.seed,
let samples = vec![1; x_nrows]; splitter: Splitter::Best,
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) };
let tree = BaseTreeRegressor::fit(x, y, tree_parameters)?;
Ok(Self {
tree_regressor: Some(tree),
})
} }
pub(crate) fn fit_weak_learner( pub(crate) fn fit_weak_learner(
@@ -436,267 +319,30 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
mtry: usize, mtry: usize,
parameters: DecisionTreeRegressorParameters, parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> { ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let y_m = y.clone(); let tree_parameters = BaseTreeRegressorParameters {
max_depth: parameters.max_depth,
let y_ncols = y_m.shape(); min_samples_leaf: parameters.min_samples_leaf,
let (_, num_attributes) = x.shape(); min_samples_split: parameters.min_samples_split,
seed: parameters.seed,
let mut nodes: Vec<Node> = Vec::new(); splitter: Splitter::Best,
let mut rng = get_rng_impl(parameters.seed);
let mut n = 0;
let mut sum = 0f64;
for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
n += *sample_i;
sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
}
let root = Node::new(sum / (n as f64));
nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new();
for i in 0..num_attributes {
let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
order.push(col_i.argsort_mut());
}
let mut tree = DecisionTreeRegressor {
nodes,
parameters: Some(parameters),
depth: 0u16,
_phantom_tx: PhantomData,
_phantom_ty: PhantomData,
_phantom_x: PhantomData,
_phantom_y: PhantomData,
}; };
let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples, mtry, tree_parameters)?;
let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1); Ok(Self {
tree_regressor: Some(tree),
let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new(); })
if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
visitor_queue.push_back(visitor);
}
while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
match visitor_queue.pop_front() {
Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
None => break,
};
}
Ok(tree)
} }
/// Predict regression value for `x`. /// Predict regression value for `x`.
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
pub fn predict(&self, x: &X) -> Result<Y, Failed> { pub fn predict(&self, x: &X) -> Result<Y, Failed> {
let mut result = Y::zeros(x.shape().0); self.tree_regressor.as_ref().unwrap().predict(x)
let (n, _) = x.shape();
for i in 0..n {
result.set(i, self.predict_for_row(x, i));
}
Ok(result)
} }
pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY { pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
let mut result = 0f64; self.tree_regressor
let mut queue: LinkedList<usize> = LinkedList::new(); .as_ref()
.unwrap()
queue.push_back(0); .predict_for_row(x, row)
while !queue.is_empty() {
match queue.pop_front() {
Some(node_id) => {
let node = &self.nodes()[node_id];
if node.true_child.is_none() && node.false_child.is_none() {
result = node.output;
} else if x.get((row, node.split_feature)).to_f64().unwrap()
<= node.split_value.unwrap_or(f64::NAN)
{
queue.push_back(node.true_child.unwrap());
} else {
queue.push_back(node.false_child.unwrap());
}
}
None => break,
};
}
TY::from_f64(result).unwrap()
}
fn find_best_cutoff(
&mut self,
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
mtry: usize,
rng: &mut impl Rng,
) -> bool {
let (_, n_attr) = visitor.x.shape();
let n: usize = visitor.samples.iter().sum();
if n < self.parameters().min_samples_split {
return false;
}
let sum = self.nodes()[visitor.node].output * n as f64;
let mut variables = (0..n_attr).collect::<Vec<_>>();
if mtry < n_attr {
variables.shuffle(rng);
}
let parent_gain =
n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
for variable in variables.iter().take(mtry) {
self.find_best_split(visitor, n, sum, parent_gain, *variable);
}
self.nodes()[visitor.node].split_score.is_some()
}
fn find_best_split(
&mut self,
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
n: usize,
sum: f64,
parent_gain: f64,
j: usize,
) {
let mut true_sum = 0f64;
let mut true_count = 0;
let mut prevx = Option::None;
for i in visitor.order[j].iter() {
if visitor.samples[*i] > 0 {
let x_ij = *visitor.x.get((*i, j));
if prevx.is_none() || x_ij == prevx.unwrap() {
prevx = Some(x_ij);
true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
continue;
}
let false_count = n - true_count;
if true_count < self.parameters().min_samples_leaf
|| false_count < self.parameters().min_samples_leaf
{
prevx = Some(x_ij);
true_count += visitor.samples[*i];
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
continue;
}
let true_mean = true_sum / true_count as f64;
let false_mean = (sum - true_sum) / false_count as f64;
let gain = (true_count as f64 * true_mean * true_mean
+ false_count as f64 * false_mean * false_mean)
- parent_gain;
if self.nodes()[visitor.node].split_score.is_none()
|| gain > self.nodes()[visitor.node].split_score.unwrap()
{
self.nodes[visitor.node].split_feature = j;
self.nodes[visitor.node].split_value =
Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
self.nodes[visitor.node].split_score = Option::Some(gain);
visitor.true_child_output = true_mean;
visitor.false_child_output = false_mean;
}
prevx = Some(x_ij);
true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
true_count += visitor.samples[*i];
}
}
}
fn split<'a>(
&mut self,
mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
mtry: usize,
visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
rng: &mut impl Rng,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
let mut fc = 0;
let mut true_samples: Vec<usize> = vec![0; n];
for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
if visitor.samples[i] > 0 {
if visitor
.x
.get((i, self.nodes()[visitor.node].split_feature))
.to_f64()
.unwrap()
<= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
{
*true_sample = visitor.samples[i];
tc += *true_sample;
visitor.samples[i] = 0;
} else {
fc += visitor.samples[i];
}
}
}
if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
self.nodes[visitor.node].split_feature = 0;
self.nodes[visitor.node].split_value = Option::None;
self.nodes[visitor.node].split_score = Option::None;
return false;
}
let true_child_idx = self.nodes().len();
self.nodes.push(Node::new(visitor.true_child_output));
let false_child_idx = self.nodes().len();
self.nodes.push(Node::new(visitor.false_child_output));
self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx);
self.depth = u16::max(self.depth, visitor.level + 1);
let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
true_child_idx,
true_samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
visitor_queue.push_back(true_visitor);
}
let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
false_child_idx,
visitor.samples,
visitor.order,
visitor.x,
visitor.y,
visitor.level + 1,
);
if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
visitor_queue.push_back(false_visitor);
}
true
} }
} }
+1
View File
@@ -19,6 +19,7 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
mod base_tree_regressor;
/// Classification tree for dependent variables that take a finite number of unordered values. /// Classification tree for dependent variables that take a finite number of unordered values.
pub mod decision_tree_classifier; pub mod decision_tree_classifier;
/// Regression tree for dependent variables that take continuous or ordered discrete values. /// Regression tree for dependent variables that take continuous or ordered discrete values.