diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs new file mode 100644 index 0000000..4b5e8f2 --- /dev/null +++ b/src/tree/decision_tree_regressor.rs @@ -0,0 +1,350 @@ +use std::default::Default; +use std::collections::LinkedList; +use crate::linalg::Matrix; +use crate::algorithm::sort::quick_sort::QuickArgSort; + +#[derive(Debug)] +pub struct DecisionTreeRegressorParameters { + pub max_depth: Option, + pub min_samples_leaf: usize, + pub min_samples_split: usize +} + +#[derive(Debug)] +pub struct DecisionTreeRegressor { + nodes: Vec, + parameters: DecisionTreeRegressorParameters, + depth: u16 +} + +#[derive(Debug)] +pub struct Node { + index: usize, + output: f64, + split_feature: usize, + split_value: f64, + split_score: f64, + true_child: Option, + false_child: Option, +} + + +impl Default for DecisionTreeRegressorParameters { + fn default() -> Self { + DecisionTreeRegressorParameters { + max_depth: None, + min_samples_leaf: 2, + min_samples_split: 2 + } + } +} + +impl Node { + fn new(index: usize, output: f64) -> Self { + Node { + index: index, + output: output, + split_feature: 0, + split_value: std::f64::NAN, + split_score: std::f64::NAN, + true_child: Option::None, + false_child: Option::None + } + } +} + +struct NodeVisitor<'a, M: Matrix> { + x: &'a M, + y: &'a M, + node: usize, + samples: Vec, + order: &'a Vec>, + true_child_output: f64, + false_child_output: f64, + level: u16 +} + +impl<'a, M: Matrix> NodeVisitor<'a, M> { + + fn new(node_id: usize, samples: Vec, order: &'a Vec>, x: &'a M, y: &'a M, level: u16) -> Self { + NodeVisitor { + x: x, + y: y, + node: node_id, + samples: samples, + order: order, + true_child_output: 0., + false_child_output: 0., + level: level + } + } + +} + +impl DecisionTreeRegressor { + + pub fn fit(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor { + let (x_nrows, num_attributes) = x.shape(); + let samples = vec![1; x_nrows]; + DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters) + } + + pub fn fit_weak_learner(x: &M, y: &M::RowVector, samples: Vec, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor { + let y_m = M::from_row_vector(y.clone()); + let (_, y_ncols) = y_m.shape(); + let (_, num_attributes) = x.shape(); + let classes = y_m.unique(); + let k = classes.len(); + if k < 2 { + panic!("Incorrect number of classes: {}. Should be >= 2.", k); + } + + let mut nodes: Vec = Vec::new(); + + let mut n = 0; + let mut sum = 0f64; + for i in 0..y_ncols { + n += samples[i]; + sum += samples[i] as f64 * y_m.get(i, 0); + } + + let root = Node::new(0, sum / n as f64); + nodes.push(root); + let mut order: Vec> = Vec::new(); + + for i in 0..num_attributes { + order.push(x.get_col_as_vec(i).quick_argsort()); + } + + let mut tree = DecisionTreeRegressor{ + nodes: nodes, + parameters: parameters, + depth: 0 + }; + + let mut visitor = NodeVisitor::::new(0, samples, &order, &x, &y_m, 1); + + let mut visitor_queue: LinkedList> = LinkedList::new(); + + if tree.find_best_cutoff(&mut visitor, mtry) { + visitor_queue.push_back(visitor); + } + + while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) { + match visitor_queue.pop_front() { + Some(node) => tree.split(node, mtry, &mut visitor_queue,), + None => break + }; + } + + tree + } + + pub fn predict(&self, x: &M) -> M::RowVector { + let mut result = M::zeros(1, x.shape().0); + + let (n, _) = x.shape(); + + for i in 0..n { + result.set(0, i, self.predict_for_row(x, i)); + } + + result.to_row_vector() + } + + pub(in crate) fn predict_for_row(&self, x: &M, row: usize) -> f64 { + let mut result = 0f64; + let mut queue: LinkedList = LinkedList::new(); + + queue.push_back(0); + + while !queue.is_empty() { + match queue.pop_front() { + Some(node_id) => { + let node = &self.nodes[node_id]; + if node.true_child == None && node.false_child == None { + result = node.output; + } else { + if x.get(row, node.split_feature) <= node.split_value { + queue.push_back(node.true_child.unwrap()); + } else { + queue.push_back(node.false_child.unwrap()); + } + } + }, + None => break + }; + } + + return result + + } + + fn find_best_cutoff(&mut self, visitor: &mut NodeVisitor, mtry: usize) -> bool { + + let (_, n_attr) = visitor.x.shape(); + + let n: usize = visitor.samples.iter().sum(); + + if n <= self.parameters.min_samples_split { + return false; + } + + let sum = self.nodes[visitor.node].output * n as f64; + + let mut variables = vec![0; n_attr]; + for i in 0..n_attr { + variables[i] = i; + } + + let parent_gain = n as f64 * self.nodes[visitor.node].output * self.nodes[visitor.node].output; + + for j in 0..mtry { + self.find_best_split(visitor, n, sum, parent_gain, variables[j]); + } + + !self.nodes[visitor.node].split_score.is_nan() + + } + + fn find_best_split(&mut self, visitor: &mut NodeVisitor, n: usize, sum: f64, parent_gain: f64, j: usize){ + + let mut true_sum = 0f64; + let mut true_count = 0; + let mut prevx = std::f64::NAN; + + for i in visitor.order[j].iter() { + if visitor.samples[*i] > 0 { + if prevx.is_nan() || visitor.x.get(*i, j) == prevx { + prevx = visitor.x.get(*i, j); + true_count += visitor.samples[*i]; + true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); + continue; + } + + let false_count = n - true_count; + + if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf { + prevx = visitor.x.get(*i, j); + true_count += visitor.samples[*i]; + true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); + continue; + } + + let true_mean = true_sum / true_count as f64; + let false_mean = (sum - true_sum) / false_count as f64; + + let gain = (true_count as f64 * true_mean * true_mean + false_count as f64 * false_mean * false_mean) - parent_gain; + + if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score { + self.nodes[visitor.node].split_feature = j; + self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / 2.; + self.nodes[visitor.node].split_score = gain; + visitor.true_child_output = true_mean; + visitor.false_child_output = false_mean; + } + + prevx = visitor.x.get(*i, j); + true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i, 0); + true_count += visitor.samples[*i]; + } + } + + } + + fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, mtry: usize, visitor_queue: &mut LinkedList>) -> bool { + let (n, _) = visitor.x.shape(); + let mut tc = 0; + let mut fc = 0; + let mut true_samples: Vec = vec![0; n]; + + for i in 0..n { + if visitor.samples[i] > 0 { + if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value { + true_samples[i] = visitor.samples[i]; + tc += true_samples[i]; + visitor.samples[i] = 0; + } else { + fc += visitor.samples[i]; + } + } + } + + if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf { + self.nodes[visitor.node].split_feature = 0; + self.nodes[visitor.node].split_value = std::f64::NAN; + self.nodes[visitor.node].split_score = std::f64::NAN; + return false; + } + + let true_child_idx = self.nodes.len(); + self.nodes.push(Node::new(true_child_idx, visitor.true_child_output)); + let false_child_idx = self.nodes.len(); + self.nodes.push(Node::new(false_child_idx, visitor.false_child_output)); + + self.nodes[visitor.node].true_child = Some(true_child_idx); + self.nodes[visitor.node].false_child = Some(false_child_idx); + + self.depth = u16::max(self.depth, visitor.level + 1); + + let mut true_visitor = NodeVisitor::::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); + + if tc > self.parameters.min_samples_split && self.find_best_cutoff(&mut true_visitor, mtry) { + visitor_queue.push_back(true_visitor); + } + + let mut false_visitor = NodeVisitor::::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1); + + if fc > self.parameters.min_samples_split && self.find_best_cutoff(&mut false_visitor, mtry) { + visitor_queue.push_back(false_visitor); + } + + true + } + +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::linalg::naive::dense_matrix::DenseMatrix; + + #[test] + fn fit_longley() { + + let x = DenseMatrix::from_array(&[ + &[ 234.289, 235.6, 159., 107.608, 1947., 60.323], + &[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122], + &[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171], + &[ 284.599, 335.1, 165., 110.929, 1950., 61.187], + &[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221], + &[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639], + &[ 365.385, 187., 354.7, 115.094, 1953., 64.989], + &[ 363.112, 357.8, 335., 116.219, 1954., 63.761], + &[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019], + &[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857], + &[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169], + &[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513], + &[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655], + &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564], + &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331], + &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]); + let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9]; + + let expected_y = vec![85.6, 89.0, 85.6, 89.0, 97.15, 97.15, 100.0, 100.0, 100.0, 107.9, 107.9, 107.9, 113.4, 113.4, 116.3, 116.3]; + let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x); + + for i in 0..y_hat.len() { + assert!((y_hat[i] - expected_y[i]).abs() < 0.1); + } + + let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85]; + let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6}).predict(&x); + + for i in 0..y_hat.len() { + assert!((y_hat[i] - expected_y[i]).abs() < 0.1); + } + + } + +} \ No newline at end of file diff --git a/src/tree/mod.rs b/src/tree/mod.rs index 1d5fd68..bea1d51 100644 --- a/src/tree/mod.rs +++ b/src/tree/mod.rs @@ -1 +1,2 @@ +pub mod decision_tree_regressor; pub mod decision_tree; \ No newline at end of file