diff --git a/src/algorithm/sort/mod.rs b/src/algorithm/sort/mod.rs
index 7984c1d..994c47e 100644
--- a/src/algorithm/sort/mod.rs
+++ b/src/algorithm/sort/mod.rs
@@ -1 +1,2 @@
-pub mod heap_select;
\ No newline at end of file
+pub mod heap_select;
+pub mod quick_sort;
\ No newline at end of file
diff --git a/src/algorithm/sort/quick_sort.rs b/src/algorithm/sort/quick_sort.rs
new file mode 100644
index 0000000..969f41a
--- /dev/null
+++ b/src/algorithm/sort/quick_sort.rs
@@ -0,0 +1,133 @@
+/// Argsort: sort a container in place and report where each element came from.
+pub trait QuickArgSort {
+    /// Sorts `self` ascending in place and returns, for each sorted position,
+    /// the index that element occupied before sorting.
+    fn quick_argsort(&mut self) -> Vec<usize>;
+}
+
+impl QuickArgSort for Vec<f64> {
+
+    /// Iterative quicksort (median-of-three pivot, explicit stack); ranges of at
+    /// most 7 elements are finished with insertion sort.
+    ///
+    /// An empty vector yields an empty result. NaN values are not supported: the
+    /// partition scans stop on `>=`/`<=` comparisons, which NaN never satisfies.
+    ///
+    /// # Panics
+    /// Panics with "stack size is too small." if the range stack would overflow.
+    fn quick_argsort(&mut self) -> Vec<usize> {
+        let stack_size = 64;
+        let mut jstack = -1;
+        let mut l = 0;
+        let mut istack = vec![0; stack_size];
+        let mut ir = self.len().saturating_sub(1); // `len() - 1` underflows usize on empty input
+        let mut index: Vec<usize> = (0..self.len()).collect();
+
+        loop {
+            if ir - l < 7 {
+                for j in l + 1..=ir {
+                    let a = self[j];
+                    let b = index[j];
+                    let mut i: i32 = (j - 1) as i32;
+                    while i >= l as i32 {
+                        if self[i as usize] <= a {
+                            break;
+                        }
+                        self[(i + 1) as usize] = self[i as usize];
+                        index[(i + 1) as usize] = index[i as usize];
+                        i -= 1;
+                    }
+                    self[(i + 1) as usize] = a;
+                    index[(i + 1) as usize] = b;
+                }
+                if jstack < 0 {
+                    break;
+                }
+                ir = istack[jstack as usize];
+                jstack -= 1;
+                l = istack[jstack as usize];
+                jstack -= 1;
+            } else {
+                let k = (l + ir) >> 1;
+                self.swap(k, l + 1);
+                index.swap(k, l + 1);
+                if self[l] > self[ir] {
+                    self.swap(l, ir);
+                    index.swap(l, ir);
+                }
+                if self[l + 1] > self[ir] {
+                    self.swap(l + 1, ir);
+                    index.swap(l + 1, ir);
+                }
+                if self[l] > self[l + 1] {
+                    self.swap(l, l + 1);
+                    index.swap(l, l + 1);
+                }
+                let mut i = l + 1;
+                let mut j = ir;
+                let a = self[l + 1];
+                let b = index[l + 1];
+                loop {
+                    loop {
+                        i += 1;
+                        if self[i] >= a {
+                            break;
+                        }
+                    }
+                    loop {
+                        j -= 1;
+                        if self[j] <= a {
+                            break;
+                        }
+                    }
+                    if j < i {
+                        break;
+                    }
+                    self.swap(i, j);
+                    index.swap(i, j);
+                }
+                self[l + 1] = self[j];
+                self[j] = a;
+                index[l + 1] = index[j];
+                index[j] = b;
+                jstack += 2;
+
+                if jstack >= stack_size as i32 { // keep the bound tied to the allocated stack
+                    panic!("stack size is too small.");
+                }
+
+                if ir - i + 1 >= j - l {
+                    istack[jstack as usize] = ir;
+                    istack[jstack as usize - 1] = i;
+                    ir = j - 1;
+                } else {
+                    istack[jstack as usize] = j - 1;
+                    istack[jstack as usize - 1] = l;
+                    l = i;
+                }
+            }
+        }
+
+        index
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn quick_argsort_returns_permutation() {
+        let mut arr1 = vec![0.3, 0.1, 0.2, 0.4, 0.9, 0.5, 0.7, 0.6, 0.8];
+        assert_eq!(vec![1, 2, 0, 3, 5, 7, 6, 8, 4], arr1.quick_argsort());
+
+        let mut arr2 = vec![0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 1.4, 1.5, 1.5, 1.3, 1.5, 1.3, 1.6, 1.0, 1.3, 1.4];
+        assert_eq!(vec![9, 7, 1, 8, 0, 2, 4, 3, 6, 5, 17, 18, 15, 13, 19, 10, 14, 11, 12, 16], arr2.quick_argsort());
+    }
+
+    #[test]
+    fn quick_argsort_empty() {
+        let mut empty: Vec<f64> = Vec::new();
+        assert!(empty.quick_argsort().is_empty());
+    }
+}
\ No newline at end of file
diff --git a/src/classification/decision_tree.rs b/src/classification/decision_tree.rs
new file mode 100644
index 0000000..a9cba83
--- /dev/null
+++ b/src/classification/decision_tree.rs
@@ -0,0 +1,492 @@
+use std::default::Default;
+use std::collections::LinkedList;
+use crate::linalg::Matrix;
+use crate::algorithm::sort::quick_sort::QuickArgSort;
+
+/// Hyperparameters controlling how a `DecisionTree` is grown.
+#[derive(Debug)]
+pub struct DecisionTreeParameters {
+    /// Impurity measure used to score candidate splits.
+    pub criterion: SplitCriterion,
+    /// Splitting stops once the recorded tree depth reaches this value; `None` falls back to `u16::MAX`.
+    pub max_depth: Option<u16>,
+    /// A split is rejected when either side would hold fewer samples than this.
+    pub min_samples_leaf: u16
+}
+
+/// Classification tree stored as a flat vector of nodes; node 0 is the root.
+#[derive(Debug)]
+pub struct DecisionTree {
+    nodes: Vec<Node>,
+    parameters: DecisionTreeParameters,
+    num_classes: usize,
+    classes: Vec<f64>,
+    depth: u16
+}
+
+/// Impurity measure used to evaluate a split.
+#[derive(Debug)]
+pub enum SplitCriterion {
+    Gini,
+    Entropy,
+    ClassificationError
+}
+
+/// A tree node; a node with no children is a leaf predicting class index `output`.
+#[derive(Debug)]
+pub struct Node {
+    index: usize,
+    output: usize,
+    split_feature: usize,
+    split_value: f64,
+    split_score: f64,
+    true_child: Option<usize>,
+    false_child: Option<usize>,
+}
+
+
+impl Default for DecisionTreeParameters {
+    fn default() -> Self {
+        DecisionTreeParameters {
+            criterion: SplitCriterion::Gini,
+            max_depth: None,
+            min_samples_leaf: 1
+        }
+    }
+}
+
+impl Node {
+    fn new(index: usize, output: usize) -> Self {
+        Node {
+            index: index,
+            output: output,
+            split_feature: 0,
+            split_value: std::f64::NAN,
+            split_score: std::f64::NAN,
+            true_child: Option::None,
+            false_child: Option::None
+        }
+    }
+}
+
+/// Per-node work item: which samples reach `node` and the outputs its children would get.
+struct NodeVisitor<'a, M: Matrix> {
+    x: &'a M,
+    y: &'a Vec<usize>,
+    node: usize,
+    samples: Vec<u32>,
+    order: &'a Vec<Vec<usize>>,
+    true_child_output: usize,
+    false_child_output: usize,
+    level: u16
+}
+
+/// Impurity of a node holding `n` samples with per-class totals `count`.
+fn impurity(criterion: &SplitCriterion, count: &Vec<u32>, n: u32) -> f64 {
+    let mut impurity = 0.;
+
+    match criterion {
+        SplitCriterion::Gini => {
+            impurity = 1.0;
+            for i in 0..count.len() {
+                if count[i] > 0 {
+                    let p = count[i] as f64 / n as f64;
+                    impurity -= p * p;
+                }
+            }
+        }
+
+        SplitCriterion::Entropy => {
+            for i in 0..count.len() {
+                if count[i] > 0 {
+                    let p = count[i] as f64 / n as f64;
+                    impurity -= p * p.log2();
+                }
+            }
+        }
+        SplitCriterion::ClassificationError => {
+            for i in 0..count.len() {
+                if count[i] > 0 {
+                    impurity = impurity.max(count[i] as f64 / n as f64);
+                }
+            }
+            impurity = (1. - impurity).abs();
+        }
+    }
+
+    return impurity;
+}
+
+impl<'a, M: Matrix> NodeVisitor<'a, M> {
+
+    fn new(node_id: usize, samples: Vec<u32>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
+        NodeVisitor {
+            x: x,
+            y: y,
+            node: node_id,
+            samples: samples,
+            order: order,
+            true_child_output: 0,
+            false_child_output: 0,
+            level: level
+        }
+    }
+
+}
+
+/// Index of the first maximum of `x`; panics if `x` is empty.
+fn which_max(x: &Vec<u32>) -> usize {
+    let mut m = x[0];
+    let mut which = 0;
+
+    for i in 1..x.len() {
+        if x[i] > m {
+            m = x[i];
+            which = i;
+        }
+    }
+
+    return which;
+}
+
+impl DecisionTree {
+
+    /// Grows a tree from the rows of `x` and the labels in `y`.
+    ///
+    /// # Panics
+    /// Panics if `y` holds fewer than two distinct classes, or if the number of
+    /// labels differs from the number of rows in `x`.
+    pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
+        let y_m = M::from_row_vector(y.clone());
+        let (_, y_ncols) = y_m.shape();
+        let (x_nrows, num_attributes) = x.shape();
+        if x_nrows != y_ncols {
+            panic!("Number of labels ({}) does not match number of rows in x ({}).", y_ncols, x_nrows);
+        }
+        let classes = y_m.unique();
+        let k = classes.len();
+        if k < 2 {
+            panic!("Incorrect number of classes: {}. Should be >= 2.", k);
+        }
+
+        let mut yi: Vec<usize> = vec![0; y_ncols];
+
+        for i in 0..y_ncols {
+            let yc = y_m.get(0, i);
+            yi[i] = classes.iter().position(|c| yc == *c).unwrap();
+        }
+
+        let mut nodes: Vec<Node> = Vec::new();
+        let samples = vec![1; x_nrows];
+
+        let mut count = vec![0; k];
+        for i in 0..y_ncols {
+            count[yi[i]] += samples[i];
+        }
+
+        let root = Node::new(0, which_max(&count));
+        nodes.push(root);
+        let mut order: Vec<Vec<usize>> = Vec::new();
+
+        for i in 0..num_attributes {
+            order.push(x.get_col_as_vec(i).quick_argsort());
+        }
+
+        let mut tree = DecisionTree{
+            nodes: nodes,
+            parameters: parameters,
+            num_classes: k,
+            classes: classes,
+            depth: 0
+        };
+
+        let mut visitor = NodeVisitor::<M>::new(0, samples, &order, &x, &yi, 1);
+
+        let mut visitor_queue: LinkedList<NodeVisitor<M>> = LinkedList::new();
+
+        if tree.find_best_cutoff(&mut visitor) {
+            visitor_queue.push_back(visitor);
+        }
+
+        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
+            match visitor_queue.pop_front() {
+                Some(node) => tree.split(node, &mut visitor_queue),
+                None => break
+            };
+        }
+
+        tree
+    }
+
+    /// Predicts a class label for every row of `x`.
+    pub fn predict<M: Matrix>(&self, x: &M) -> M::RowVector {
+        let mut result = M::zeros(1, x.shape().0);
+
+        let (n, _) = x.shape();
+
+        for i in 0..n {
+            result.set(0, i, self.classes[self.predict_for_row(x, i)]);
+        }
+
+        result.to_row_vector()
+    }
+
+    /// Walks from the root to a leaf for row `row` of `x` and returns the leaf's class index.
+    fn predict_for_row<M: Matrix>(&self, x: &M, row: usize) -> usize {
+        let mut result = 0;
+        let mut queue: LinkedList<usize> = LinkedList::new();
+
+        queue.push_back(0);
+
+        while !queue.is_empty() {
+            match queue.pop_front() {
+                Some(node_id) => {
+                    let node = &self.nodes[node_id];
+                    if node.true_child == None && node.false_child == None {
+                        result = node.output;
+                    } else {
+                        if x.get(row, node.split_feature) <= node.split_value {
+                            queue.push_back(node.true_child.unwrap());
+                        } else {
+                            queue.push_back(node.false_child.unwrap());
+                        }
+                    }
+                },
+                None => break
+            };
+        }
+
+        return result
+
+    }
+
+    /// Searches every feature for the best split of the visitor's node; returns `true` if one was recorded.
+    fn find_best_cutoff<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>) -> bool {
+
+        let (n_rows, n_attr) = visitor.x.shape();
+
+        let mut label = Option::None;
+        let mut is_pure = true;
+        for i in 0..n_rows {
+            if visitor.samples[i] > 0 {
+                if label == Option::None {
+                    label = Option::Some(visitor.y[i]);
+                } else if visitor.y[i] != label.unwrap() {
+                    is_pure = false;
+                    break;
+                }
+            }
+        }
+
+        if is_pure {
+            return false;
+        }
+
+        let n = visitor.samples.iter().sum();
+
+        if n <= self.parameters.min_samples_leaf as u32 {
+            return false;
+        }
+
+        let mut count = vec![0; self.num_classes];
+        let mut false_count = vec![0; self.num_classes];
+        for i in 0..n_rows {
+            if visitor.samples[i] > 0 {
+                count[visitor.y[i]] += visitor.samples[i];
+            }
+        }
+
+        let parent_impurity = impurity(&self.parameters.criterion, &count, n);
+
+        let mut variables = vec![0; n_attr];
+        for i in 0..n_attr {
+            variables[i] = i;
+        }
+
+        for j in 0..n_attr {
+            self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
+        }
+
+        !self.nodes[visitor.node].split_score.is_nan()
+
+    }
+
+    /// Scans feature `j` in sorted order and records the split with the highest impurity gain seen so far.
+    fn find_best_split<M: Matrix>(&mut self, visitor: &mut NodeVisitor<M>, n: u32, count: &Vec<u32>, false_count: &mut Vec<u32>, parent_impurity: f64, j: usize){
+
+        let mut true_count = vec![0; self.num_classes];
+        let mut prevx = std::f64::NAN;
+        let mut prevy = 0;
+        let node_size = 1;
+
+        for i in visitor.order[j].iter() {
+            if visitor.samples[*i] > 0 {
+                if prevx.is_nan() || visitor.x.get(*i, j) == prevx || visitor.y[*i] == prevy {
+                    prevx = visitor.x.get(*i, j);
+                    prevy = visitor.y[*i];
+                    true_count[visitor.y[*i]] += visitor.samples[*i];
+                    continue;
+                }
+
+                let tc = true_count.iter().sum();
+                let fc = n - tc;
+
+                if tc < node_size || fc < node_size {
+                    prevx = visitor.x.get(*i, j);
+                    prevy = visitor.y[*i];
+                    true_count[visitor.y[*i]] += visitor.samples[*i];
+                    continue;
+                }
+
+                for l in 0..self.num_classes {
+                    false_count[l] = count[l] - true_count[l];
+                }
+
+                let true_label = which_max(&true_count);
+                let false_label = which_max(false_count);
+                let gain = parent_impurity - tc as f64 / n as f64 * impurity(&self.parameters.criterion, &true_count, tc) - fc as f64 / n as f64 * impurity(&self.parameters.criterion, &false_count, fc);
+
+                if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
+                    self.nodes[visitor.node].split_feature = j;
+                    self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / 2.;
+                    self.nodes[visitor.node].split_score = gain;
+                    visitor.true_child_output = true_label;
+                    visitor.false_child_output = false_label;
+                }
+
+                prevx = visitor.x.get(*i, j);
+                prevy = visitor.y[*i];
+                true_count[visitor.y[*i]] += visitor.samples[*i];
+            }
+        }
+
+    }
+
+    /// Splits the visitor's node into two children and queues each child that can be split further.
+    /// Returns `false` (and clears the recorded split) when a child would fall below `min_samples_leaf`.
+    fn split<'a, M: Matrix>(&mut self, mut visitor: NodeVisitor<'a, M>, visitor_queue: &mut LinkedList<NodeVisitor<'a, M>>) -> bool {
+        let (n, _) = visitor.x.shape();
+        let mut tc = 0;
+        let mut fc = 0;
+        let mut true_samples: Vec<u32> = vec![0; n];
+
+        for i in 0..n {
+            if visitor.samples[i] > 0 {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
+                    true_samples[i] = visitor.samples[i];
+                    tc += true_samples[i];
+                    visitor.samples[i] = 0;
+                } else {
+                    fc += visitor.samples[i];
+                }
+            }
+        }
+
+        if tc < self.parameters.min_samples_leaf as u32 || fc < self.parameters.min_samples_leaf as u32 {
+            self.nodes[visitor.node].split_feature = 0;
+            self.nodes[visitor.node].split_value = std::f64::NAN;
+            self.nodes[visitor.node].split_score = std::f64::NAN;
+            return false;
+        }
+
+        let true_child_idx = self.nodes.len();
+        self.nodes.push(Node::new(true_child_idx, visitor.true_child_output));
+        let false_child_idx = self.nodes.len();
+        self.nodes.push(Node::new(false_child_idx, visitor.false_child_output));
+
+        self.nodes[visitor.node].true_child = Some(true_child_idx);
+        self.nodes[visitor.node].false_child = Some(false_child_idx);
+
+        self.depth = u16::max(self.depth, visitor.level + 1);
+
+        let mut true_visitor = NodeVisitor::<M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
+
+        if tc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut true_visitor) {
+            visitor_queue.push_back(true_visitor);
+        }
+
+        let mut false_visitor = NodeVisitor::<M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
+
+        if fc > self.parameters.min_samples_leaf as u32 && self.find_best_cutoff(&mut false_visitor) {
+            visitor_queue.push_back(false_visitor);
+        }
+
+        true
+    }
+
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::naive::dense_matrix::DenseMatrix;
+
+    #[test]
+    fn gini_impurity() {
+        assert!((impurity(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
+        assert!((impurity(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
+        assert!((impurity(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
+    }
+
+    #[test]
+    fn fit_predict_iris() {
+
+        let x = DenseMatrix::from_2d_array(&[
+            &[5.1, 3.5, 1.4, 0.2],
+            &[4.9, 3.0, 1.4, 0.2],
+            &[4.7, 3.2, 1.3, 0.2],
+            &[4.6, 3.1, 1.5, 0.2],
+            &[5.0, 3.6, 1.4, 0.2],
+            &[5.4, 3.9, 1.7, 0.4],
+            &[4.6, 3.4, 1.4, 0.3],
+            &[5.0, 3.4, 1.5, 0.2],
+            &[4.4, 2.9, 1.4, 0.2],
+            &[4.9, 3.1, 1.5, 0.1],
+            &[7.0, 3.2, 4.7, 1.4],
+            &[6.4, 3.2, 4.5, 1.5],
+            &[6.9, 3.1, 4.9, 1.5],
+            &[5.5, 2.3, 4.0, 1.3],
+            &[6.5, 2.8, 4.6, 1.5],
+            &[5.7, 2.8, 4.5, 1.3],
+            &[6.3, 3.3, 4.7, 1.6],
+            &[4.9, 2.4, 3.3, 1.0],
+            &[6.6, 2.9, 4.6, 1.3],
+            &[5.2, 2.7, 3.9, 1.4]]);
+        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
+
+        assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
+
+        assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
+
+    }
+
+    #[test]
+    fn fit_predict_balloons() {
+
+        let x = DenseMatrix::from_2d_array(&[
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,1.],
+            &[1.,1.,0.,0.],
+            &[1.,1.,0.,1.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,1.],
+            &[1.,0.,0.,0.],
+            &[1.,0.,0.,1.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,1.],
+            &[0.,1.,0.,0.],
+            &[0.,1.,0.,1.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,1.],
+            &[0.,0.,0.,0.],
+            &[0.,0.,0.,1.]]);
+        let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
+
+        assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
+
+    }
+}
\ No newline at end of file
diff --git a/src/classification/logistic_regression.rs b/src/classification/logistic_regression.rs
index a297f36..b979499 100644
--- a/src/classification/logistic_regression.rs
+++ b/src/classification/logistic_regression.rs
@@ -354,45 +354,10 @@ mod tests {
         assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
 
 
-    }
+    }
 
     #[test]
-    fn lr_fit_predict_iris() {
-
-        let x = DenseMatrix::from_2d_array(&[
-            &[5.1, 3.5, 1.4, 0.2],
-            &[4.9, 3.0, 1.4, 0.2],
-            &[4.7, 3.2, 1.3, 0.2],
-            &[4.6, 3.1, 1.5, 0.2],
-            &[5.0, 3.6, 1.4, 0.2],
-            &[5.4, 3.9, 1.7, 0.4],
-            &[4.6, 3.4, 1.4, 0.3],
-            &[5.0, 3.4, 1.5, 0.2],
-            &[4.4, 2.9, 1.4, 0.2],
-            &[4.9, 3.1, 1.5, 0.1],
-            &[7.0, 3.2, 4.7, 1.4],
-            &[6.4, 3.2, 4.5, 1.5],
-            &[6.9, 3.1, 4.9, 1.5],
-            &[5.5, 2.3, 4.0, 1.3],
-            &[6.5, 2.8, 4.6, 1.5],
-            &[5.7, 2.8, 4.5, 1.3],
-            &[6.3, 3.3, 4.7, 1.6],
-            &[4.9, 2.4, 3.3, 1.0],
-            &[6.6, 2.9, 4.6, 1.3],
-            &[5.2, 2.7, 3.9, 1.4]]);
-        let y =vec![0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
-
-        let lr = LogisticRegression::fit(&x, &y);
-
-        let y_hat = lr.predict(&x);
-
-        assert_eq!(y_hat, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
-
-
-    }
-
-    #[test]
-    fn tt() {
+    fn lr_fit_predict_iris() {
         let x = arr2(&[
             [5.1, 3.5, 1.4, 0.2],
             [4.9, 3.0, 1.4, 0.2],
diff --git a/src/classification/mod.rs b/src/classification/mod.rs
index 0d14356..c181db8 100644
--- a/src/classification/mod.rs
+++ b/src/classification/mod.rs
@@ -2,6 +2,7 @@ use crate::common::Nominal;
 
 pub mod knn;
 pub mod logistic_regression;
+pub mod decision_tree;
 
 pub trait Classifier
 where
diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs
index 896f8d6..636e1ac 100644
--- a/src/linalg/mod.rs
+++ b/src/linalg/mod.rs
@@ -18,6 +18,10 @@ pub trait Matrix: Clone + Debug {
 
     fn get(&self, row: usize, col: usize) -> f64;
 
+    fn get_row_as_vec(&self, row: usize) -> Vec<f64>;
+
+    fn get_col_as_vec(&self, col: usize) -> Vec<f64>;
+
     fn set(&mut self, row: usize, col: usize, x: f64);
 
     fn qr_solve_mut(&mut self, b: Self) -> Self;
diff --git a/src/linalg/naive/dense_matrix.rs b/src/linalg/naive/dense_matrix.rs
index 5165339..431a1d2 100644
--- a/src/linalg/naive/dense_matrix.rs
+++ b/src/linalg/naive/dense_matrix.rs
@@ -135,6 +135,22 @@ impl Matrix for DenseMatrix {
         self.values[col*self.nrows + row]
     }
 
+    fn get_row_as_vec(&self, row: usize) -> Vec<f64>{
+        let mut result = vec![0f64; self.ncols];
+        for c in 0..self.ncols {
+            result[c] = self.get(row, c);
+        }
+        result
+    }
+
+    fn get_col_as_vec(&self, col: usize) -> Vec<f64>{
+        let mut result = vec![0f64; self.nrows];
+        for r in 0..self.nrows {
+            result[r] = self.get(r, col);
+        }
+        result
+    }
+
     fn set(&mut self, row: usize, col: usize, x: f64) {
         self.values[col*self.nrows + row] = x;
     }
diff --git a/src/linalg/ndarray_bindings.rs b/src/linalg/ndarray_bindings.rs
index a3d2723..6f36552 100644
--- a/src/linalg/ndarray_bindings.rs
+++ b/src/linalg/ndarray_bindings.rs
@@ -28,6 +28,14 @@ impl Matrix for ArrayBase<OwnedRepr<f64>, Ix2>
         self[[row, col]]
     }
 
+    fn get_row_as_vec(&self, row: usize) -> Vec<f64> {
+        self.row(row).to_vec()
+    }
+
+    fn get_col_as_vec(&self, col: usize) -> Vec<f64> {
+        self.column(col).to_vec()
+    }
+
     fn set(&mut self, row: usize, col: usize, x: f64) {
         self[[row, col]] = x;
     }
@@ -509,4 +517,18 @@ mod tests {
         assert_eq!(res.len(), 7);
         assert_eq!(res, vec![-7., -6., -2., 1., 2., 3., 4.]);
     }
+
+    #[test]
+    fn get_row_as_vector(){
+        let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
+        let res = a.get_row_as_vec(1);
+        assert_eq!(res, vec![4., 5., 6.]);
+    }
+
+    #[test]
+    fn get_col_as_vector(){
+        let a = arr2(&[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]);
+        let res = a.get_col_as_vec(1);
+        assert_eq!(res, vec![2., 5., 8.]);
+    }
 }
\ No newline at end of file