diff --git a/src/tree/base_tree_regressor.rs b/src/tree/base_tree_regressor.rs
new file mode 100644
index 0000000..8728894
--- /dev/null
+++ b/src/tree/base_tree_regressor.rs
@@ -0,0 +1,551 @@
+use std::collections::LinkedList;
+use std::default::Default;
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+use rand::seq::SliceRandom;
+use rand::Rng;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::error::Failed;
+use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
+use crate::numbers::basenum::Number;
+use crate::rand_custom::get_rng_impl;
+
+/// The strategy used to choose the split value at each node.
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone, Default)]
+pub enum Splitter {
+    /// Draw the split value at random between the feature's min and max
+    Random,
+    /// Search exhaustively for the split value with the best score
+    #[default]
+    Best,
+}
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the regression tree
+pub struct BaseTreeRegressorParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The maximum depth of the tree.
+    pub max_depth: Option<u16>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to be at a leaf node.
+    pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to split an internal node.
+    pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Controls the randomness of the estimator
+    pub seed: Option<u64>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Determines the strategy used to choose the split at each node.
+    pub splitter: Splitter,
+}
+
+/// Regression tree
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct BaseTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
+    nodes: Vec<Node>,
+    parameters: Option<BaseTreeRegressorParameters>,
+    depth: u16,
+    _phantom_tx: PhantomData<TX>,
+    _phantom_ty: PhantomData<TY>,
+    _phantom_x: PhantomData<X>,
+    _phantom_y: PhantomData<Y>,
+}
+
+impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    BaseTreeRegressor<TX, TY, X, Y>
+{
+    /// Get nodes, return a shared reference
+    fn nodes(&self) -> &Vec<Node> {
+        self.nodes.as_ref()
+    }
+    /// Get parameters, return a shared reference
+    fn parameters(&self) -> &BaseTreeRegressorParameters {
+        self.parameters.as_ref().unwrap()
+    }
+    /// Get depth of the fitted tree, return value
+    fn depth(&self) -> u16 {
+        self.depth
+    }
+}
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+struct Node {
+    output: f64,
+    split_feature: usize,
+    split_value: Option<f64>,
+    split_score: Option<f64>,
+    true_child: Option<usize>,
+    false_child: Option<usize>,
+}
+
+impl Node {
+    fn new(output: f64) -> Self {
+        Node {
+            output,
+            split_feature: 0,
+            split_value: Option::None,
+            split_score: Option::None,
+            true_child: Option::None,
+            false_child: Option::None,
+        }
+    }
+}
+
+impl PartialEq for Node {
+    fn eq(&self, other: &Self) -> bool {
+        (self.output - other.output).abs() < f64::EPSILON
+            && self.split_feature == other.split_feature
+            && match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
+                (None, None) => true,
+                _ => false,
+            }
+            && match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+
+impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for BaseTreeRegressor<TX, TY, X, Y>
+{
+    fn eq(&self, other: &Self) -> bool {
+        if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
+            false
+        } else {
+            self.nodes()
+                .iter()
+                .zip(other.nodes().iter())
+                .all(|(a, b)| a == b)
+        }
+    }
+}
+
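+/// Per-node state used while the tree is grown breadth-first: the index of
+/// the node being split, the per-row sample counts that reached it, and the
+/// outputs its two children will receive once a split is chosen.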
+struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
+    x: &'a X,
+    y: &'a Y,
+    node: usize,
+    samples: Vec<usize>,
+    order: &'a [Vec<usize>],
+    true_child_output: f64,
+    false_child_output: f64,
+    level: u16,
+    _phantom_tx: PhantomData<TX>,
+    _phantom_ty: PhantomData<TY>,
+}
+
+impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    NodeVisitor<'a, TX, TY, X, Y>
+{
+    fn new(
+        node_id: usize,
+        samples: Vec<usize>,
+        order: &'a [Vec<usize>],
+        x: &'a X,
+        y: &'a Y,
+        level: u16,
+    ) -> Self {
+        NodeVisitor {
+            x,
+            y,
+            node: node_id,
+            samples,
+            order,
+            true_child_output: 0f64,
+            false_child_output: 0f64,
+            level,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+        }
+    }
+}
+
+impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    BaseTreeRegressor<TX, TY, X, Y>
+{
+    /// Build a regression tree from the training data.
+    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+    pub fn fit(
+        x: &X,
+        y: &Y,
+        parameters: BaseTreeRegressorParameters,
+    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
+        let (x_nrows, num_attributes) = x.shape();
+        if x_nrows != y.shape() {
+            return Err(Failed::fit("Size of x should equal size of y"));
+        }
+
+        let samples = vec![1; x_nrows];
+        BaseTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+    }
+
+    pub(crate) fn fit_weak_learner(
+        x: &X,
+        y: &Y,
+        samples: Vec<usize>,
+        mtry: usize,
+        parameters: BaseTreeRegressorParameters,
+    ) -> Result<BaseTreeRegressor<TX, TY, X, Y>, Failed> {
+        let y_m = y.clone();
+
+        let y_ncols = y_m.shape();
+        let (_, num_attributes) = x.shape();
+
+        let mut nodes: Vec<Node> = Vec::new();
+        let mut rng = get_rng_impl(parameters.seed);
+
+        // The root node predicts the weighted mean of all targets.
+        let mut n = 0;
+        let mut sum = 0f64;
+        for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
+            n += *sample_i;
+            sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
+        }
+
+        let root = Node::new(sum / (n as f64));
+        nodes.push(root);
+        let mut order: Vec<Vec<usize>> = Vec::new();
+
+        // Pre-sort every column once; the split search walks these orders.
+        for i in 0..num_attributes {
+            let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
+            order.push(col_i.argsort_mut());
+        }
+
+        let mut base_tree = BaseTreeRegressor {
+            nodes,
+            parameters: Some(parameters),
+            depth: 0u16,
+            _phantom_tx: PhantomData,
+            _phantom_ty: PhantomData,
+            _phantom_x: PhantomData,
+            _phantom_y: PhantomData,
+        };
+
+        let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
+
+        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
+
+        if base_tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
+            visitor_queue.push_back(visitor);
+        }
+
+        while base_tree.depth() < base_tree.parameters().max_depth.unwrap_or(u16::MAX) {
+            match visitor_queue.pop_front() {
+                Some(node) => base_tree.split(node, mtry, &mut visitor_queue, &mut rng),
+                None => break,
+            };
+        }
+
+        Ok(base_tree)
+    }
+
+    /// Predict regression value for `x`.
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+        let mut result = Y::zeros(x.shape().0);
+
+        let (n, _) = x.shape();
+
+        for i in 0..n {
+            result.set(i, self.predict_for_row(x, i));
+        }
+
+        Ok(result)
+    }
+
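+    /// Route a single row from the root to a leaf: follow `true_child` while
+    /// the feature value is `<=` the split value, otherwise `false_child`,
+    /// and return the leaf's mean output.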
+    pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
+        let mut result = 0f64;
+        let mut queue: LinkedList<usize> = LinkedList::new();
+
+        queue.push_back(0);
+
+        while !queue.is_empty() {
+            match queue.pop_front() {
+                Some(node_id) => {
+                    let node = &self.nodes()[node_id];
+                    if node.true_child.is_none() && node.false_child.is_none() {
+                        result = node.output;
+                    } else if x.get((row, node.split_feature)).to_f64().unwrap()
+                        <= node.split_value.unwrap_or(f64::NAN)
+                    {
+                        queue.push_back(node.true_child.unwrap());
+                    } else {
+                        queue.push_back(node.false_child.unwrap());
+                    }
+                }
+                None => break,
+            };
+        }
+
+        TY::from_f64(result).unwrap()
+    }
+
+    fn find_best_cutoff(
+        &mut self,
+        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
+        mtry: usize,
+        rng: &mut impl Rng,
+    ) -> bool {
+        let (_, n_attr) = visitor.x.shape();
+
+        let n: usize = visitor.samples.iter().sum();
+
+        if n < self.parameters().min_samples_split {
+            return false;
+        }
+
+        let sum = self.nodes()[visitor.node].output * n as f64;
+
+        let mut variables = (0..n_attr).collect::<Vec<_>>();
+
+        // When mtry < n_attr, shuffle so that `take(mtry)` examines a random
+        // subset of the features.
+        if mtry < n_attr {
+            variables.shuffle(rng);
+        }
+
+        let parent_gain =
+            n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
+
+        let splitter = self.parameters().splitter.clone();
+
+        for variable in variables.iter().take(mtry) {
+            match splitter {
+                Splitter::Random => {
+                    self.find_random_split(visitor, n, sum, parent_gain, *variable, rng);
+                }
+                Splitter::Best => {
+                    self.find_best_split(visitor, n, sum, parent_gain, *variable);
+                }
+            }
+        }
+
+        self.nodes()[visitor.node].split_score.is_some()
+    }
+
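+    // Both split finders score a candidate threshold with
+    //     gain = n_true * mean_true^2 + n_false * mean_false^2 - parent_gain,
+    // where parent_gain = n * mean^2. Within a node, SSE = sum(y^2) - n * mean^2,
+    // so maximizing this gain is equivalent to minimizing the total squared
+    // error of the two children.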
+    fn find_random_split(
+        &mut self,
+        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
+        n: usize,
+        sum: f64,
+        parent_gain: f64,
+        j: usize,
+        rng: &mut impl Rng,
+    ) {
+        // Find the min and max of feature `j` over the rows still in this node.
+        let (min_val, max_val) = {
+            let mut min_opt = None;
+            let mut max_opt = None;
+            for &i in &visitor.order[j] {
+                if visitor.samples[i] > 0 {
+                    min_opt = Some(*visitor.x.get((i, j)));
+                    break;
+                }
+            }
+            for &i in visitor.order[j].iter().rev() {
+                if visitor.samples[i] > 0 {
+                    max_opt = Some(*visitor.x.get((i, j)));
+                    break;
+                }
+            }
+            if min_opt.is_none() {
+                return;
+            }
+            (min_opt.unwrap(), max_opt.unwrap())
+        };
+
+        if min_val >= max_val {
+            return;
+        }
+
+        let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());
+
+        let mut true_sum = 0f64;
+        let mut true_count = 0;
+        for &i in &visitor.order[j] {
+            if visitor.samples[i] > 0 {
+                if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
+                    true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
+                    true_count += visitor.samples[i];
+                } else {
+                    break;
+                }
+            }
+        }
+
+        let false_count = n - true_count;
+
+        if true_count < self.parameters().min_samples_leaf
+            || false_count < self.parameters().min_samples_leaf
+        {
+            return;
+        }
+
+        let true_mean = if true_count > 0 {
+            true_sum / true_count as f64
+        } else {
+            0.0
+        };
+        let false_mean = if false_count > 0 {
+            (sum - true_sum) / false_count as f64
+        } else {
+            0.0
+        };
+        let gain = (true_count as f64 * true_mean * true_mean
+            + false_count as f64 * false_mean * false_mean)
+            - parent_gain;
+
+        if self.nodes[visitor.node].split_score.is_none()
+            || gain > self.nodes[visitor.node].split_score.unwrap()
+        {
+            self.nodes[visitor.node].split_feature = j;
+            self.nodes[visitor.node].split_value = Some(split_value);
+            self.nodes[visitor.node].split_score = Some(gain);
+            visitor.true_child_output = true_mean;
+            visitor.false_child_output = false_mean;
+        }
+    }
+
+    fn find_best_split(
+        &mut self,
+        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
+        n: usize,
+        sum: f64,
+        parent_gain: f64,
+        j: usize,
+    ) {
+        let mut true_sum = 0f64;
+        let mut true_count = 0;
+        let mut prevx = Option::None;
+
+        for i in visitor.order[j].iter() {
+            if visitor.samples[*i] > 0 {
+                let x_ij = *visitor.x.get((*i, j));
+
+                // Only evaluate a cutoff when the feature value changes.
+                if prevx.is_none() || x_ij == prevx.unwrap() {
+                    prevx = Some(x_ij);
+                    true_count += visitor.samples[*i];
+                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                    continue;
+                }
+
+                let false_count = n - true_count;
+
+                if true_count < self.parameters().min_samples_leaf
+                    || false_count < self.parameters().min_samples_leaf
+                {
+                    prevx = Some(x_ij);
+                    true_count += visitor.samples[*i];
+                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                    continue;
+                }
+
+                let true_mean = true_sum / true_count as f64;
+                let false_mean = (sum - true_sum) / false_count as f64;
+
+                let gain = (true_count as f64 * true_mean * true_mean
+                    + false_count as f64 * false_mean * false_mean)
+                    - parent_gain;
+
+                if self.nodes()[visitor.node].split_score.is_none()
+                    || gain > self.nodes()[visitor.node].split_score.unwrap()
+                {
+                    // Split halfway between the two distinct adjacent values.
+                    self.nodes[visitor.node].split_feature = j;
+                    self.nodes[visitor.node].split_value =
+                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
+                    self.nodes[visitor.node].split_score = Option::Some(gain);
+
+                    visitor.true_child_output = true_mean;
+                    visitor.false_child_output = false_mean;
+                }
+
+                prevx = Some(x_ij);
+                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                true_count += visitor.samples[*i];
+            }
+        }
+    }
+
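+    /// Apply the winning split: move the rows that satisfy it into a new
+    /// "true" child, keep the rest for the "false" child, and queue both
+    /// children for further splitting. Rolls the split back and returns
+    /// early if either child would violate `min_samples_leaf`.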
+    fn split<'a>(
+        &mut self,
+        mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
+        mtry: usize,
+        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
+        rng: &mut impl Rng,
+    ) -> bool {
+        let (n, _) = visitor.x.shape();
+        let mut tc = 0;
+        let mut fc = 0;
+        let mut true_samples: Vec<usize> = vec![0; n];
+
+        for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
+            if visitor.samples[i] > 0 {
+                if visitor
+                    .x
+                    .get((i, self.nodes()[visitor.node].split_feature))
+                    .to_f64()
+                    .unwrap()
+                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
+                {
+                    *true_sample = visitor.samples[i];
+                    tc += *true_sample;
+                    visitor.samples[i] = 0;
+                } else {
+                    fc += visitor.samples[i];
+                }
+            }
+        }
+
+        if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
+            self.nodes[visitor.node].split_feature = 0;
+            self.nodes[visitor.node].split_value = Option::None;
+            self.nodes[visitor.node].split_score = Option::None;
+
+            return false;
+        }
+
+        let true_child_idx = self.nodes().len();
+
+        self.nodes.push(Node::new(visitor.true_child_output));
+        let false_child_idx = self.nodes().len();
+        self.nodes.push(Node::new(visitor.false_child_output));
+
+        self.nodes[visitor.node].true_child = Some(true_child_idx);
+        self.nodes[visitor.node].false_child = Some(false_child_idx);
+
+        self.depth = u16::max(self.depth, visitor.level + 1);
+
+        let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
+            true_child_idx,
+            true_samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
+
+        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
+            visitor_queue.push_back(true_visitor);
+        }
+
+        let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
+            false_child_idx,
+            visitor.samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
+
+        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
+            visitor_queue.push_back(false_visitor);
+        }
+
+        true
+    }
+}
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index d735697..154ba2e 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -58,22 +58,17 @@
 //!
 //!
-use std::collections::LinkedList;
 use std::default::Default;
 use std::fmt::Debug;
-use std::marker::PhantomData;
-
-use rand::seq::SliceRandom;
-use rand::Rng;
 
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Serialize};
 
+use super::base_tree_regressor::{BaseTreeRegressor, BaseTreeRegressorParameters, Splitter};
 use crate::api::{Predictor, SupervisedEstimator};
 use crate::error::Failed;
-use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
+use crate::linalg::basic::arrays::{Array1, Array2};
 use crate::numbers::basenum::Number;
-use crate::rand_custom::get_rng_impl;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -98,41 +93,7 @@ pub struct DecisionTreeRegressorParameters {
 #[derive(Debug)]
 pub struct DecisionTreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
-    nodes: Vec<Node>,
-    parameters: Option<DecisionTreeRegressorParameters>,
-    depth: u16,
-    _phantom_tx: PhantomData<TX>,
-    _phantom_ty: PhantomData<TY>,
-    _phantom_x: PhantomData<X>,
-    _phantom_y: PhantomData<Y>,
-}
-
-impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
-    DecisionTreeRegressor<TX, TY, X, Y>
-{
-    /// Get nodes, return a shared reference
-    fn nodes(&self) -> &Vec<Node> {
-        self.nodes.as_ref()
-    }
-    /// Get parameters, return a shared reference
-    fn parameters(&self) -> &DecisionTreeRegressorParameters {
-        self.parameters.as_ref().unwrap()
-    }
-    /// Get estimate of intercept, return value
-    fn depth(&self) -> u16 {
-        self.depth
-    }
-}
-
-#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-#[derive(Debug, Clone)]
-struct Node {
-    output: f64,
-    split_feature: usize,
-    split_value: Option<f64>,
-    split_score: Option<f64>,
-    true_child: Option<usize>,
-    false_child: Option<usize>,
+    tree_regressor: Option<BaseTreeRegressor<TX, TY, X, Y>>,
 }
 
 impl DecisionTreeRegressorParameters {
@@ -296,87 +257,11 @@ impl Default for DecisionTreeRegressorSearchParameters {
     }
 }
 
-impl Node {
-    fn new(output: f64) -> Self {
-        Node {
-            output,
-            split_feature: 0,
-            split_value: Option::None,
-            split_score: Option::None,
-            true_child: Option::None,
-            false_child: Option::None,
-        }
-    }
-}
-
-impl PartialEq for Node {
-    fn eq(&self, other: &Self) -> bool {
-        (self.output - other.output).abs() < f64::EPSILON
-            && self.split_feature == other.split_feature
-            && match (self.split_value, other.split_value) {
-                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
-                (None, None) => true,
-                _ => false,
-            }
-            && match (self.split_score, other.split_score) {
-                (Some(a), Some(b)) => (a - b).abs() < f64::EPSILON,
-                (None, None) => true,
-                _ => false,
-            }
-    }
-}
-
 impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
     for DecisionTreeRegressor<TX, TY, X, Y>
 {
     fn eq(&self, other: &Self) -> bool {
-        if self.depth != other.depth || self.nodes().len() != other.nodes().len() {
-            false
-        } else {
-            self.nodes()
-                .iter()
-                .zip(other.nodes().iter())
-                .all(|(a, b)| a == b)
-        }
-    }
-}
-
-struct NodeVisitor<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
-    x: &'a X,
-    y: &'a Y,
-    node: usize,
-    samples: Vec<usize>,
-    order: &'a [Vec<usize>],
-    true_child_output: f64,
-    false_child_output: f64,
-    level: u16,
-    _phantom_tx: PhantomData<TX>,
-    _phantom_ty: PhantomData<TY>,
-}
-
-impl<'a, TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
-    NodeVisitor<'a, TX, TY, X, Y>
-{
-    fn new(
-        node_id: usize,
-        samples: Vec<usize>,
-        order: &'a [Vec<usize>],
-        x: &'a X,
-        y: &'a Y,
-        level: u16,
-    ) -> Self {
-        NodeVisitor {
-            x,
-            y,
-            node: node_id,
-            samples,
-            order,
-            true_child_output: 0f64,
-            false_child_output: 0f64,
-            level,
-            _phantom_tx: PhantomData,
-            _phantom_ty: PhantomData,
-        }
+        self.tree_regressor == other.tree_regressor
     }
 }
 
@@ -386,13 +271,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 {
     fn new() -> Self {
         Self {
-            nodes: vec![],
-            parameters: Option::None,
-            depth: 0u16,
-            _phantom_tx: PhantomData,
-            _phantom_ty: PhantomData,
-            _phantom_x: PhantomData,
-            _phantom_y: PhantomData,
+            tree_regressor: None,
         }
     }
 
@@ -420,13 +299,17 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         y: &Y,
         parameters: DecisionTreeRegressorParameters,
     ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
-        let (x_nrows, num_attributes) = x.shape();
-        if x_nrows != y.shape() {
-            return Err(Failed::fit("Size of x should equal size of y"));
-        }
-
-        let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
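+        // Translate the public parameters into the shared base-tree
+        // parameters; Splitter::Best keeps the exhaustive split search this
+        // estimator used before the refactor.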
+        let tree_parameters = BaseTreeRegressorParameters {
+            max_depth: parameters.max_depth,
+            min_samples_leaf: parameters.min_samples_leaf,
+            min_samples_split: parameters.min_samples_split,
+            seed: parameters.seed,
+            splitter: Splitter::Best,
+        };
+        let tree = BaseTreeRegressor::fit(x, y, tree_parameters)?;
+        Ok(Self {
+            tree_regressor: Some(tree),
+        })
     }
 
     pub(crate) fn fit_weak_learner(
@@ -436,267 +319,30 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
     ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
-        let y_m = y.clone();
-
-        let y_ncols = y_m.shape();
-        let (_, num_attributes) = x.shape();
-
-        let mut nodes: Vec<Node> = Vec::new();
-        let mut rng = get_rng_impl(parameters.seed);
-
-        let mut n = 0;
-        let mut sum = 0f64;
-        for (i, sample_i) in samples.iter().enumerate().take(y_ncols) {
-            n += *sample_i;
-            sum += *sample_i as f64 * y_m.get(i).to_f64().unwrap();
-        }
-
-        let root = Node::new(sum / (n as f64));
-        nodes.push(root);
-        let mut order: Vec<Vec<usize>> = Vec::new();
-
-        for i in 0..num_attributes {
-            let mut col_i: Vec<TX> = x.get_col(i).iterator(0).copied().collect();
-            order.push(col_i.argsort_mut());
-        }
-
-        let mut tree = DecisionTreeRegressor {
-            nodes,
-            parameters: Some(parameters),
-            depth: 0u16,
-            _phantom_tx: PhantomData,
-            _phantom_ty: PhantomData,
-            _phantom_x: PhantomData,
-            _phantom_y: PhantomData,
+        let tree_parameters = BaseTreeRegressorParameters {
+            max_depth: parameters.max_depth,
+            min_samples_leaf: parameters.min_samples_leaf,
+            min_samples_split: parameters.min_samples_split,
+            seed: parameters.seed,
+            splitter: Splitter::Best,
         };
-
-        let mut visitor = NodeVisitor::<TX, TY, X, Y>::new(0, samples, &order, x, &y_m, 1);
-
-        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
-
-        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
-            visitor_queue.push_back(visitor);
-        }
-
-        while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
-            match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
-                None => break,
-            };
-        }
-
-        Ok(tree)
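+        // `mtry` caps how many candidate features the base tree examines at
+        // each split; `fit` passes the full feature count so every feature
+        // is considered.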
+        let tree = BaseTreeRegressor::fit_weak_learner(x, y, samples, mtry, tree_parameters)?;
+        Ok(Self {
+            tree_regressor: Some(tree),
+        })
     }
 
     /// Predict regression value for `x`.
     /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
-        let mut result = Y::zeros(x.shape().0);
-
-        let (n, _) = x.shape();
-
-        for i in 0..n {
-            result.set(i, self.predict_for_row(x, i));
-        }
-
-        Ok(result)
+        self.tree_regressor.as_ref().unwrap().predict(x)
     }
 
     pub(crate) fn predict_for_row(&self, x: &X, row: usize) -> TY {
-        let mut result = 0f64;
-        let mut queue: LinkedList<usize> = LinkedList::new();
-
-        queue.push_back(0);
-
-        while !queue.is_empty() {
-            match queue.pop_front() {
-                Some(node_id) => {
-                    let node = &self.nodes()[node_id];
-                    if node.true_child.is_none() && node.false_child.is_none() {
-                        result = node.output;
-                    } else if x.get((row, node.split_feature)).to_f64().unwrap()
-                        <= node.split_value.unwrap_or(f64::NAN)
-                    {
-                        queue.push_back(node.true_child.unwrap());
-                    } else {
-                        queue.push_back(node.false_child.unwrap());
-                    }
-                }
-                None => break,
-            };
-        }
-
-        TY::from_f64(result).unwrap()
-    }
-
-    fn find_best_cutoff(
-        &mut self,
-        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
-        mtry: usize,
-        rng: &mut impl Rng,
-    ) -> bool {
-        let (_, n_attr) = visitor.x.shape();
-
-        let n: usize = visitor.samples.iter().sum();
-
-        if n < self.parameters().min_samples_split {
-            return false;
-        }
-
-        let sum = self.nodes()[visitor.node].output * n as f64;
-
-        let mut variables = (0..n_attr).collect::<Vec<_>>();
-
-        if mtry < n_attr {
-            variables.shuffle(rng);
-        }
-
-        let parent_gain =
-            n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
-
-        for variable in variables.iter().take(mtry) {
-            self.find_best_split(visitor, n, sum, parent_gain, *variable);
-        }
-
-        self.nodes()[visitor.node].split_score.is_some()
-    }
-
-    fn find_best_split(
-        &mut self,
-        visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
-        n: usize,
-        sum: f64,
-        parent_gain: f64,
-        j: usize,
-    ) {
-        let mut true_sum = 0f64;
-        let mut true_count = 0;
-        let mut prevx = Option::None;
-
-        for i in visitor.order[j].iter() {
-            if visitor.samples[*i] > 0 {
-                let x_ij = *visitor.x.get((*i, j));
-
-                if prevx.is_none() || x_ij == prevx.unwrap() {
-                    prevx = Some(x_ij);
-                    true_count += visitor.samples[*i];
-                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                    continue;
-                }
-
-                let false_count = n - true_count;
-
-                if true_count < self.parameters().min_samples_leaf
-                    || false_count < self.parameters().min_samples_leaf
-                {
-                    prevx = Some(x_ij);
-                    true_count += visitor.samples[*i];
-                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                    continue;
-                }
-
-                let true_mean = true_sum / true_count as f64;
-                let false_mean = (sum - true_sum) / false_count as f64;
-
-                let gain = (true_count as f64 * true_mean * true_mean
-                    + false_count as f64 * false_mean * false_mean)
-                    - parent_gain;
-
-                if self.nodes()[visitor.node].split_score.is_none()
-                    || gain > self.nodes()[visitor.node].split_score.unwrap()
-                {
-                    self.nodes[visitor.node].split_feature = j;
-                    self.nodes[visitor.node].split_value =
-                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
-                    self.nodes[visitor.node].split_score = Option::Some(gain);
-
-                    visitor.true_child_output = true_mean;
-                    visitor.false_child_output = false_mean;
-                }
-
-                prevx = Some(x_ij);
-                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                true_count += visitor.samples[*i];
-            }
-        }
-    }
-
-    fn split<'a>(
-        &mut self,
-        mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
-        mtry: usize,
-        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
-        rng: &mut impl Rng,
-    ) -> bool {
-        let (n, _) = visitor.x.shape();
-        let mut tc = 0;
-        let mut fc = 0;
-        let mut true_samples: Vec<usize> = vec![0; n];
-
-        for (i, true_sample) in true_samples.iter_mut().enumerate().take(n) {
-            if visitor.samples[i] > 0 {
-                if visitor
-                    .x
-                    .get((i, self.nodes()[visitor.node].split_feature))
-                    .to_f64()
-                    .unwrap()
-                    <= self.nodes()[visitor.node].split_value.unwrap_or(f64::NAN)
-                {
-                    *true_sample = visitor.samples[i];
-                    tc += *true_sample;
-                    visitor.samples[i] = 0;
-                } else {
-                    fc += visitor.samples[i];
-                }
-            }
-        }
-
-        if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
-            self.nodes[visitor.node].split_feature = 0;
-            self.nodes[visitor.node].split_value = Option::None;
-            self.nodes[visitor.node].split_score = Option::None;
-
-            return false;
-        }
-
-        let true_child_idx = self.nodes().len();
-
-        self.nodes.push(Node::new(visitor.true_child_output));
-        let false_child_idx = self.nodes().len();
-        self.nodes.push(Node::new(visitor.false_child_output));
-
-        self.nodes[visitor.node].true_child = Some(true_child_idx);
-        self.nodes[visitor.node].false_child = Some(false_child_idx);
-
-        self.depth = u16::max(self.depth, visitor.level + 1);
-
-        let mut true_visitor = NodeVisitor::<TX, TY, X, Y>::new(
-            true_child_idx,
-            true_samples,
-            visitor.order,
-            visitor.x,
-            visitor.y,
-            visitor.level + 1,
-        );
-
-        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
-            visitor_queue.push_back(true_visitor);
-        }
-
-        let mut false_visitor = NodeVisitor::<TX, TY, X, Y>::new(
-            false_child_idx,
-            visitor.samples,
-            visitor.order,
-            visitor.x,
-            visitor.y,
-            visitor.level + 1,
-        );
-
-        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
-            visitor_queue.push_back(false_visitor);
-        }
-
-        true
+        self.tree_regressor
+            .as_ref()
+            .unwrap()
+            .predict_for_row(x, row)
     }
 }
diff --git a/src/tree/mod.rs b/src/tree/mod.rs
index 340b0a8..b325e96 100644
--- a/src/tree/mod.rs
+++ b/src/tree/mod.rs
@@ -19,6 +19,7 @@
 //!
 //!
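+// Shared implementation behind `decision_tree_regressor`.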
+mod base_tree_regressor;
 /// Classification tree for dependent variables that take a finite number of unordered values.
 pub mod decision_tree_classifier;
 /// Regression tree for for dependent variables that take continuous or ordered discrete values.