fix: cargo fmt
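This commit mechanically reformats the tree module with rustfmt; no logic changes are intended. As a rough illustration (the type name `Params` below is made up for this note, it is not from the crate), the patterns applied throughout the diff are trailing commas on multi-line items, `&&`/`||` moved to the head of continuation lines, and long signatures broken one parameter per line:

    // before rustfmt
    pub struct Params {
        pub min_samples_split: usize
    }

    // after `cargo fmt` (rustfmt defaults)
    pub struct Params {
        pub min_samples_split: usize,
    }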
@@ -1,111 +1,112 @@
+use std::collections::LinkedList;
 use std::default::Default;
 use std::fmt::Debug;
 use std::marker::PhantomData;
-use std::collections::LinkedList;
 
-use serde::{Serialize, Deserialize};
+use serde::{Deserialize, Serialize};
 
-use crate::math::num::FloatExt;
-use crate::linalg::Matrix;
 use crate::algorithm::sort::quick_sort::QuickArgSort;
+use crate::linalg::Matrix;
+use crate::math::num::FloatExt;
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct DecisionTreeClassifierParameters {
+pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
     pub min_samples_leaf: usize,
-    pub min_samples_split: usize
+    pub min_samples_split: usize,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct DecisionTreeClassifier<T: FloatExt> {
-    nodes: Vec<Node<T>>,
-    parameters: DecisionTreeClassifierParameters,
+pub struct DecisionTreeClassifier<T: FloatExt> {
+    nodes: Vec<Node<T>>,
+    parameters: DecisionTreeClassifierParameters,
     num_classes: usize,
     classes: Vec<T>,
-    depth: u16
+    depth: u16,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub enum SplitCriterion {
     Gini,
     Entropy,
-    ClassificationError
+    ClassificationError,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct Node<T: FloatExt> {
-    index: usize,
-    output: usize,
+    index: usize,
+    output: usize,
     split_feature: usize,
     split_value: Option<T>,
     split_score: Option<T>,
     true_child: Option<usize>,
-    false_child: Option<usize>,
+    false_child: Option<usize>,
 }
 
-impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
+impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
     fn eq(&self, other: &Self) -> bool {
-        if self.depth != other.depth ||
-            self.num_classes != other.num_classes ||
-            self.nodes.len() != other.nodes.len(){
-            return false
+        if self.depth != other.depth
+            || self.num_classes != other.num_classes
+            || self.nodes.len() != other.nodes.len()
+        {
+            return false;
         } else {
             for i in 0..self.classes.len() {
                 if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
-                    return false
+                    return false;
                 }
             }
             for i in 0..self.nodes.len() {
                 if self.nodes[i] != other.nodes[i] {
-                    return false
+                    return false;
                 }
             }
-            return true
-        }
+            return true;
+        }
     }
 }
 
-impl<T: FloatExt> PartialEq for Node<T> {
+impl<T: FloatExt> PartialEq for Node<T> {
     fn eq(&self, other: &Self) -> bool {
-        self.output == other.output &&
-        self.split_feature == other.split_feature &&
-        match (self.split_value, other.split_value) {
-            (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
-            (None, None) => true,
-            _ => false,
-        } &&
-        match (self.split_score, other.split_score) {
-            (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
-            (None, None) => true,
-            _ => false,
-        }
+        self.output == other.output
+            && self.split_feature == other.split_feature
+            && match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+            && match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
     }
 }
 
 impl Default for DecisionTreeClassifierParameters {
-    fn default() -> Self {
+    fn default() -> Self {
         DecisionTreeClassifierParameters {
             criterion: SplitCriterion::Gini,
             max_depth: None,
             min_samples_leaf: 1,
-            min_samples_split: 2
+            min_samples_split: 2,
         }
     }
-}
+}
 
 impl<T: FloatExt> Node<T> {
-    fn new(index: usize, output: usize) -> Self {
+    fn new(index: usize, output: usize) -> Self {
         Node {
-            index: index,
+            index: index,
             output: output,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
-            false_child: Option::None
+            false_child: Option::None,
         }
     }
-}
+}
 
 struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
@@ -113,11 +114,11 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
     y: &'a Vec<usize>,
     node: usize,
     samples: Vec<usize>,
-    order: &'a Vec<Vec<usize>>,
+    order: &'a Vec<Vec<usize>>,
     true_child_output: usize,
     false_child_output: usize,
     level: u16,
-    phantom: PhantomData<&'a T>
+    phantom: PhantomData<&'a T>,
 }
 
 fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usize) -> T {
@@ -131,7 +132,7 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
                     let p = T::from(count[i]).unwrap() / T::from(n).unwrap();
                     impurity = impurity - p * p;
                 }
             }
-        }
+        }
 
         SplitCriterion::Entropy => {
@@ -149,15 +150,21 @@ fn impurity<T: FloatExt>(criterion: &SplitCriterion, count: &Vec<usize>, n: usiz
                 }
             }
             impurity = (T::one() - impurity).abs();
         }
     }
 
     return impurity;
 }
 
-impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
-
-    fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a Vec<usize>, level: u16) -> Self {
+impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
+    fn new(
+        node_id: usize,
+        samples: Vec<usize>,
+        order: &'a Vec<Vec<usize>>,
+        x: &'a M,
+        y: &'a Vec<usize>,
+        level: u16,
+    ) -> Self {
         NodeVisitor {
             x: x,
             y: y,
@@ -167,10 +174,9 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
             true_child_output: 0,
             false_child_output: 0,
             level: level,
-            phantom: PhantomData
+            phantom: PhantomData,
         }
     }
-
 }
 
 pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
@@ -188,19 +194,28 @@ pub(in crate) fn which_max(x: &Vec<usize>) -> usize {
 }
 
 impl<T: FloatExt> DecisionTreeClassifier<T> {
-
-    pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
+    pub fn fit<M: Matrix<T>>(
+        x: &M,
+        y: &M::RowVector,
+        parameters: DecisionTreeClassifierParameters,
+    ) -> DecisionTreeClassifier<T> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
         DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
 
-    pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier<T> {
+    pub fn fit_weak_learner<M: Matrix<T>>(
+        x: &M,
+        y: &M::RowVector,
+        samples: Vec<usize>,
+        mtry: usize,
+        parameters: DecisionTreeClassifierParameters,
+    ) -> DecisionTreeClassifier<T> {
         let y_m = M::from_row_vector(y.clone());
         let (_, y_ncols) = y_m.shape();
         let (_, num_attributes) = x.shape();
-        let classes = y_m.unique();
-        let k = classes.len();
+        let classes = y_m.unique();
+        let k = classes.len();
         if k < 2 {
             panic!("Incorrect number of classes: {}. Should be >= 2.", k);
         }
@@ -208,31 +223,31 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         let mut yi: Vec<usize> = vec![0; y_ncols];
 
         for i in 0..y_ncols {
-            let yc = y_m.get(0, i);
+            let yc = y_m.get(0, i);
             yi[i] = classes.iter().position(|c| yc == *c).unwrap();
         }
 
-        let mut nodes: Vec<Node<T>> = Vec::new();
+        let mut nodes: Vec<Node<T>> = Vec::new();
 
         let mut count = vec![0; k];
         for i in 0..y_ncols {
             count[yi[i]] += samples[i];
-        }
+        }
 
-        let root = Node::new(0, which_max(&count));
+        let root = Node::new(0, which_max(&count));
         nodes.push(root);
         let mut order: Vec<Vec<usize>> = Vec::new();
 
         for i in 0..num_attributes {
             order.push(x.get_col_as_vec(i).quick_argsort());
-        }
+        }
 
-        let mut tree = DecisionTreeClassifier{
-            nodes: nodes,
-            parameters: parameters,
+        let mut tree = DecisionTreeClassifier {
+            nodes: nodes,
+            parameters: parameters,
             num_classes: k,
             classes: classes,
-            depth: 0
+            depth: 0,
         };
 
         let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &yi, 1);
@@ -243,12 +258,12 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             visitor_queue.push_back(visitor);
         }
 
-        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
+        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue,),
-                None => break
-            };
-        }
+                Some(node) => tree.split(node, mtry, &mut visitor_queue),
+                None => break,
+            };
+        }
 
         tree
     }
@@ -270,7 +285,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         let mut queue: LinkedList<usize> = LinkedList::new();
 
         queue.push_back(0);
-
+
         while !queue.is_empty() {
             match queue.pop_front() {
                 Some(node_id) => {
@@ -284,18 +299,20 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
                         queue.push_back(node.false_child.unwrap());
                     }
                 }
-                },
-                None => break
+                }
+                None => break,
+            };
         }
 
-        return result
-
-    }
+        return result;
+    }
 
-    fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
-        let (n_rows, n_attr) = visitor.x.shape();
+    fn find_best_cutoff<M: Matrix<T>>(
+        &mut self,
+        visitor: &mut NodeVisitor<T, M>,
+        mtry: usize,
+    ) -> bool {
+        let (n_rows, n_attr) = visitor.x.shape();
 
         let mut label = Option::None;
         let mut is_pure = true;
@@ -309,17 +326,17 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
                 }
             }
         }
-
+
         if is_pure {
             return false;
         }
 
-        let n = visitor.samples.iter().sum();
+        let n = visitor.samples.iter().sum();
 
         if n <= self.parameters.min_samples_split {
             return false;
         }
-
+
         let mut count = vec![0; self.num_classes];
         let mut false_count = vec![0; self.num_classes];
         for i in 0..n_rows {
@@ -329,25 +346,38 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         }
 
         let parent_impurity = impurity(&self.parameters.criterion, &count, n);
-
+
         let mut variables = vec![0; n_attr];
         for i in 0..n_attr {
             variables[i] = i;
         }
 
         for j in 0..mtry {
-            self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
-        }
+            self.find_best_split(
+                visitor,
+                n,
+                &count,
+                &mut false_count,
+                parent_impurity,
+                variables[j],
+            );
+        }
 
         self.nodes[visitor.node].split_score != Option::None
-    }
-
+    }
 
-    fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, count: &Vec<usize>, false_count: &mut Vec<usize>, parent_impurity: T, j: usize){
-
+    fn find_best_split<M: Matrix<T>>(
+        &mut self,
+        visitor: &mut NodeVisitor<T, M>,
+        n: usize,
+        count: &Vec<usize>,
+        false_count: &mut Vec<usize>,
+        parent_impurity: T,
+        j: usize,
+    ) {
         let mut true_count = vec![0; self.num_classes];
         let mut prevx = T::nan();
-        let mut prevy = 0;
+        let mut prevy = 0;
 
         for i in visitor.order[j].iter() {
             if visitor.samples[*i] > 0 {
@@ -360,7 +390,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
 
                 let tc = true_count.iter().sum();
                 let fc = n - tc;
-
+
                 if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
                     prevx = visitor.x.get(*i, j);
                     prevy = visitor.y[*i];
@@ -373,12 +403,19 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
                 }
 
                 let true_label = which_max(&true_count);
-                let false_label = which_max(false_count);
-                let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
+                let false_label = which_max(false_count);
+                let gain = parent_impurity
+                    - T::from(tc).unwrap() / T::from(n).unwrap()
+                        * impurity(&self.parameters.criterion, &true_count, tc)
+                    - T::from(fc).unwrap() / T::from(n).unwrap()
+                        * impurity(&self.parameters.criterion, &false_count, fc);
 
-                if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
+                if self.nodes[visitor.node].split_score == Option::None
+                    || gain > self.nodes[visitor.node].split_score.unwrap()
+                {
                     self.nodes[visitor.node].split_feature = j;
-                    self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
+                    self.nodes[visitor.node].split_value =
+                        Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
                     self.nodes[visitor.node].split_score = Option::Some(gain);
                     visitor.true_child_output = true_label;
                     visitor.false_child_output = false_label;
@@ -389,22 +426,28 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
                     true_count[visitor.y[*i]] += visitor.samples[*i];
                 }
             }
 
     }
 
-    fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
+    fn split<'a, M: Matrix<T>>(
+        &mut self,
+        mut visitor: NodeVisitor<'a, T, M>,
+        mtry: usize,
+        visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+    ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
-        let mut fc = 0;
+        let mut fc = 0;
         let mut true_samples: Vec<usize> = vec![0; n];
 
         for i in 0..n {
             if visitor.samples[i] > 0 {
-                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature)
+                    <= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
+                {
                     true_samples[i] = visitor.samples[i];
                     tc += true_samples[i];
                     visitor.samples[i] = 0;
-                } else {
+                } else {
                     fc += visitor.samples[i];
                 }
             }
@@ -415,50 +458,73 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             self.nodes[visitor.node].split_value = Option::None;
             self.nodes[visitor.node].split_score = Option::None;
             return false;
-        }
+        }
 
         let true_child_idx = self.nodes.len();
-        self.nodes.push(Node::new(true_child_idx, visitor.true_child_output));
+        self.nodes
+            .push(Node::new(true_child_idx, visitor.true_child_output));
         let false_child_idx = self.nodes.len();
-        self.nodes.push(Node::new(false_child_idx, visitor.false_child_output));
+        self.nodes
+            .push(Node::new(false_child_idx, visitor.false_child_output));
 
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
-
+
         self.depth = u16::max(self.depth, visitor.level + 1);
 
-        let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-
+        let mut true_visitor = NodeVisitor::<T, M>::new(
+            true_child_idx,
+            true_samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
 
         if self.find_best_cutoff(&mut true_visitor, mtry) {
             visitor_queue.push_back(true_visitor);
         }
 
-        let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-
+        let mut false_visitor = NodeVisitor::<T, M>::new(
+            false_child_idx,
+            visitor.samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
 
         if self.find_best_cutoff(&mut false_visitor, mtry) {
             visitor_queue.push_back(false_visitor);
         }
 
         true
     }
-
 }
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use super::*;
     use crate::linalg::naive::dense_matrix::DenseMatrix;
 
     #[test]
     fn gini_impurity() {
-        assert!((impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs() < std::f64::EPSILON);
-        assert!((impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs() < std::f64::EPSILON);
-        assert!((impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs() < std::f64::EPSILON);
+        assert!(
+            (impurity::<f64>(&SplitCriterion::Gini, &vec![7, 3], 10) - 0.42).abs()
+                < std::f64::EPSILON
+        );
+        assert!(
+            (impurity::<f64>(&SplitCriterion::Entropy, &vec![7, 3], 10) - 0.8812908992306927).abs()
+                < std::f64::EPSILON
+        );
+        assert!(
+            (impurity::<f64>(&SplitCriterion::ClassificationError, &vec![7, 3], 10) - 0.3).abs()
+                < std::f64::EPSILON
+        );
     }
 
     #[test]
-    fn fit_predict_iris() {
-
+    fn fit_predict_iris() {
         let x = DenseMatrix::from_array(&[
             &[5.1, 3.5, 1.4, 0.2],
             &[4.9, 3.0, 1.4, 0.2],
@@ -479,75 +545,100 @@ mod tests {
             &[6.3, 3.3, 4.7, 1.6],
             &[4.9, 2.4, 3.3, 1.0],
             &[6.6, 2.9, 4.6, 1.3],
-            &[5.2, 2.7, 3.9, 1.4]]);
-        let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
+            &[5.2, 2.7, 3.9, 1.4],
+        ]);
+        let y = vec![
+            0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
+        ];
 
-        assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
-
-        assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1, min_samples_split: 2}).depth);
-
+        assert_eq!(
+            y,
+            DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
+        );
+
+        assert_eq!(
+            3,
+            DecisionTreeClassifier::fit(
+                &x,
+                &y,
+                DecisionTreeClassifierParameters {
+                    criterion: SplitCriterion::Entropy,
+                    max_depth: Some(3),
+                    min_samples_leaf: 1,
+                    min_samples_split: 2
+                }
+            )
+            .depth
+        );
     }
 
     #[test]
-    fn fit_predict_baloons() {
-
+    fn fit_predict_baloons() {
         let x = DenseMatrix::from_array(&[
-            &[1.,1.,1.,0.],
-            &[1.,1.,1.,0.],
-            &[1.,1.,1.,1.],
-            &[1.,1.,0.,0.],
-            &[1.,1.,0.,1.],
-            &[1.,0.,1.,0.],
-            &[1.,0.,1.,0.],
-            &[1.,0.,1.,1.],
-            &[1.,0.,0.,0.],
-            &[1.,0.,0.,1.],
-            &[0.,1.,1.,0.],
-            &[0.,1.,1.,0.],
-            &[0.,1.,1.,1.],
-            &[0.,1.,0.,0.],
-            &[0.,1.,0.,1.],
-            &[0.,0.,1.,0.],
-            &[0.,0.,1.,0.],
-            &[0.,0.,1.,1.],
-            &[0.,0.,0.,0.],
-            &[0.,0.,0.,1.]]);
-        let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 1.],
+            &[1., 1., 0., 0.],
+            &[1., 1., 0., 1.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 1.],
+            &[1., 0., 0., 0.],
+            &[1., 0., 0., 1.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 1.],
+            &[0., 1., 0., 0.],
+            &[0., 1., 0., 1.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 1.],
+            &[0., 0., 0., 0.],
+            &[0., 0., 0., 1.],
+        ]);
+        let y = vec![
+            1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
+        ];
 
-        assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
-
+        assert_eq!(
+            y,
+            DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
+        );
     }
 
     #[test]
-    fn serde() {
+    fn serde() {
         let x = DenseMatrix::from_array(&[
-            &[1.,1.,1.,0.],
-            &[1.,1.,1.,0.],
-            &[1.,1.,1.,1.],
-            &[1.,1.,0.,0.],
-            &[1.,1.,0.,1.],
-            &[1.,0.,1.,0.],
-            &[1.,0.,1.,0.],
-            &[1.,0.,1.,1.],
-            &[1.,0.,0.,0.],
-            &[1.,0.,0.,1.],
-            &[0.,1.,1.,0.],
-            &[0.,1.,1.,0.],
-            &[0.,1.,1.,1.],
-            &[0.,1.,0.,0.],
-            &[0.,1.,0.,1.],
-            &[0.,0.,1.,0.],
-            &[0.,0.,1.,0.],
-            &[0.,0.,1.,1.],
-            &[0.,0.,0.,0.],
-            &[0.,0.,0.,1.]]);
-        let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 0.],
+            &[1., 1., 1., 1.],
+            &[1., 1., 0., 0.],
+            &[1., 1., 0., 1.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 0.],
+            &[1., 0., 1., 1.],
+            &[1., 0., 0., 0.],
+            &[1., 0., 0., 1.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 0.],
+            &[0., 1., 1., 1.],
+            &[0., 1., 0., 0.],
+            &[0., 1., 0., 1.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 0.],
+            &[0., 0., 1., 1.],
+            &[0., 0., 0., 0.],
+            &[0., 0., 0., 1.],
+        ]);
+        let y = vec![
+            1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
+        ];
 
         let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
 
-        let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+        let deserialized_tree: DecisionTreeClassifier<f64> =
+            bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
 
-        assert_eq!(tree, deserialized_tree);
-
+        assert_eq!(tree, deserialized_tree);
     }
-}
+}
+247 -167
@@ -1,91 +1,90 @@
+use std::collections::LinkedList;
 use std::default::Default;
 use std::fmt::Debug;
-use std::collections::LinkedList;
 
-use serde::{Serialize, Deserialize};
+use serde::{Deserialize, Serialize};
 
-use crate::math::num::FloatExt;
-use crate::linalg::Matrix;
 use crate::algorithm::sort::quick_sort::QuickArgSort;
+use crate::linalg::Matrix;
+use crate::math::num::FloatExt;
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct DecisionTreeRegressorParameters {
+pub struct DecisionTreeRegressorParameters {
     pub max_depth: Option<u16>,
     pub min_samples_leaf: usize,
-    pub min_samples_split: usize
+    pub min_samples_split: usize,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
-pub struct DecisionTreeRegressor<T: FloatExt> {
-    nodes: Vec<Node<T>>,
-    parameters: DecisionTreeRegressorParameters,
-    depth: u16
+pub struct DecisionTreeRegressor<T: FloatExt> {
+    nodes: Vec<Node<T>>,
+    parameters: DecisionTreeRegressorParameters,
+    depth: u16,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct Node<T: FloatExt> {
-    index: usize,
-    output: T,
+    index: usize,
+    output: T,
     split_feature: usize,
     split_value: Option<T>,
     split_score: Option<T>,
     true_child: Option<usize>,
-    false_child: Option<usize>,
+    false_child: Option<usize>,
 }
 
-
 impl Default for DecisionTreeRegressorParameters {
-    fn default() -> Self {
-        DecisionTreeRegressorParameters {
+    fn default() -> Self {
+        DecisionTreeRegressorParameters {
             max_depth: None,
             min_samples_leaf: 1,
-            min_samples_split: 2
+            min_samples_split: 2,
         }
     }
-}
+}
 
 impl<T: FloatExt> Node<T> {
-    fn new(index: usize, output: T) -> Self {
+    fn new(index: usize, output: T) -> Self {
         Node {
-            index: index,
+            index: index,
             output: output,
             split_feature: 0,
             split_value: Option::None,
             split_score: Option::None,
             true_child: Option::None,
-            false_child: Option::None
+            false_child: Option::None,
         }
     }
 }
 
-impl<T: FloatExt> PartialEq for Node<T> {
-    fn eq(&self, other: &Self) -> bool {
-        (self.output - other.output).abs() < T::epsilon() &&
-        self.split_feature == other.split_feature &&
-        match (self.split_value, other.split_value) {
-            (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
-            (None, None) => true,
-            _ => false,
-        } &&
-        match (self.split_score, other.split_score) {
-            (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
-            (None, None) => true,
-            _ => false,
-        }
-    }
-}
-
-impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
+impl<T: FloatExt> PartialEq for Node<T> {
     fn eq(&self, other: &Self) -> bool {
-        if self.depth != other.depth || self.nodes.len() != other.nodes.len(){
-            return false
-        } else {
+        (self.output - other.output).abs() < T::epsilon()
+            && self.split_feature == other.split_feature
+            && match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+            && match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+
+impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
+    fn eq(&self, other: &Self) -> bool {
+        if self.depth != other.depth || self.nodes.len() != other.nodes.len() {
+            return false;
+        } else {
             for i in 0..self.nodes.len() {
                 if self.nodes[i] != other.nodes[i] {
-                    return false
+                    return false;
                 }
             }
-            return true
-        }
+            return true;
+        }
     }
 }
@@ -95,15 +94,21 @@ struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
     y: &'a M,
     node: usize,
     samples: Vec<usize>,
-    order: &'a Vec<Vec<usize>>,
+    order: &'a Vec<Vec<usize>>,
     true_child_output: T,
     false_child_output: T,
-    level: u16
+    level: u16,
 }
 
-impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
-
-    fn new(node_id: usize, samples: Vec<usize>, order: &'a Vec<Vec<usize>>, x: &'a M, y: &'a M, level: u16) -> Self {
+impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
+    fn new(
+        node_id: usize,
+        samples: Vec<usize>,
+        order: &'a Vec<Vec<usize>>,
+        x: &'a M,
+        y: &'a M,
+        level: u16,
+    ) -> Self {
         NodeVisitor {
             x: x,
             y: y,
@@ -112,33 +117,41 @@ impl<'a, T: FloatExt, M: Matrix<T>> NodeVisitor<'a, T, M> {
             order: order,
             true_child_output: T::zero(),
             false_child_output: T::zero(),
-            level: level
+            level: level,
         }
     }
-
 }
 
 impl<T: FloatExt> DecisionTreeRegressor<T> {
-
-    pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
+    pub fn fit<M: Matrix<T>>(
+        x: &M,
+        y: &M::RowVector,
+        parameters: DecisionTreeRegressorParameters,
+    ) -> DecisionTreeRegressor<T> {
         let (x_nrows, num_attributes) = x.shape();
         let samples = vec![1; x_nrows];
         DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
     }
 
-    pub fn fit_weak_learner<M: Matrix<T>>(x: &M, y: &M::RowVector, samples: Vec<usize>, mtry: usize, parameters: DecisionTreeRegressorParameters) -> DecisionTreeRegressor<T> {
+    pub fn fit_weak_learner<M: Matrix<T>>(
+        x: &M,
+        y: &M::RowVector,
+        samples: Vec<usize>,
+        mtry: usize,
+        parameters: DecisionTreeRegressorParameters,
+    ) -> DecisionTreeRegressor<T> {
         let y_m = M::from_row_vector(y.clone());
-
+
         let (_, y_ncols) = y_m.shape();
         let (_, num_attributes) = x.shape();
-        let classes = y_m.unique();
-        let k = classes.len();
+        let classes = y_m.unique();
+        let k = classes.len();
         if k < 2 {
             panic!("Incorrect number of classes: {}. Should be >= 2.", k);
-        }
+        }
 
-        let mut nodes: Vec<Node<T>> = Vec::new();
-
+        let mut nodes: Vec<Node<T>> = Vec::new();
 
         let mut n = 0;
         let mut sum = T::zero();
         for i in 0..y_ncols {
@@ -146,18 +159,18 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
             sum = sum + T::from(samples[i]).unwrap() * y_m.get(0, i);
         }
 
-        let root = Node::new(0, sum / T::from(n).unwrap());
+        let root = Node::new(0, sum / T::from(n).unwrap());
         nodes.push(root);
         let mut order: Vec<Vec<usize>> = Vec::new();
 
         for i in 0..num_attributes {
             order.push(x.get_col_as_vec(i).quick_argsort());
-        }
+        }
 
-        let mut tree = DecisionTreeRegressor{
-            nodes: nodes,
-            parameters: parameters,
-            depth: 0
+        let mut tree = DecisionTreeRegressor {
+            nodes: nodes,
+            parameters: parameters,
+            depth: 0,
         };
 
         let mut visitor = NodeVisitor::<T, M>::new(0, samples, &order, &x, &y_m, 1);
@@ -168,12 +181,12 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
             visitor_queue.push_back(visitor);
         }
 
-        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
+        while tree.depth < tree.parameters.max_depth.unwrap_or(std::u16::MAX) {
             match visitor_queue.pop_front() {
                 Some(node) => tree.split(node, mtry, &mut visitor_queue),
-                None => break
-            };
-        }
+                None => break,
+            };
+        }
 
         tree
     }
@@ -195,7 +208,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
         let mut queue: LinkedList<usize> = LinkedList::new();
 
         queue.push_back(0);
-
+
         while !queue.is_empty() {
             match queue.pop_front() {
                 Some(node_id) => {
@@ -209,100 +222,123 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
                         queue.push_back(node.false_child.unwrap());
                     }
                 }
-                },
-                None => break
+                }
+                None => break,
+            };
         }
 
-        return result
-
-    }
+        return result;
+    }
 
-    fn find_best_cutoff<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, mtry: usize) -> bool {
-        let (_, n_attr) = visitor.x.shape();
+    fn find_best_cutoff<M: Matrix<T>>(
+        &mut self,
+        visitor: &mut NodeVisitor<T, M>,
+        mtry: usize,
+    ) -> bool {
+        let (_, n_attr) = visitor.x.shape();
 
-        let n: usize = visitor.samples.iter().sum();
+        let n: usize = visitor.samples.iter().sum();
 
         if n < self.parameters.min_samples_split {
             return false;
         }
 
-        let sum = self.nodes[visitor.node].output * T::from(n).unwrap();
-
+        let sum = self.nodes[visitor.node].output * T::from(n).unwrap();
 
         let mut variables = vec![0; n_attr];
         for i in 0..n_attr {
             variables[i] = i;
         }
 
-        let parent_gain = T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
+        let parent_gain =
+            T::from(n).unwrap() * self.nodes[visitor.node].output * self.nodes[visitor.node].output;
 
         for j in 0..mtry {
             self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
-        }
+        }
 
         self.nodes[visitor.node].split_score != Option::None
-    }
-
+    }
 
-    fn find_best_split<M: Matrix<T>>(&mut self, visitor: &mut NodeVisitor<T, M>, n: usize, sum: T, parent_gain: T, j: usize){
-
+    fn find_best_split<M: Matrix<T>>(
+        &mut self,
+        visitor: &mut NodeVisitor<T, M>,
+        n: usize,
+        sum: T,
+        parent_gain: T,
+        j: usize,
+    ) {
         let mut true_sum = T::zero();
         let mut true_count = 0;
-        let mut prevx = T::nan();
-
+        let mut prevx = T::nan();
 
         for i in visitor.order[j].iter() {
             if visitor.samples[*i] > 0 {
                 if prevx.is_nan() || visitor.x.get(*i, j) == prevx {
                     prevx = visitor.x.get(*i, j);
                     true_count += visitor.samples[*i];
-                    true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
+                    true_sum =
+                        true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
                     continue;
                 }
 
                 let false_count = n - true_count;
-
-                if true_count < self.parameters.min_samples_leaf || false_count < self.parameters.min_samples_leaf {
+
+                if true_count < self.parameters.min_samples_leaf
+                    || false_count < self.parameters.min_samples_leaf
+                {
                     prevx = visitor.x.get(*i, j);
                     true_count += visitor.samples[*i];
-                    true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
+                    true_sum =
+                        true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
                     continue;
                 }
 
                 let true_mean = true_sum / T::from(true_count).unwrap();
-                let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
+                let false_mean = (sum - true_sum) / T::from(false_count).unwrap();
 
-                let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
-
-                if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
+                let gain = (T::from(true_count).unwrap() * true_mean * true_mean
+                    + T::from(false_count).unwrap() * false_mean * false_mean)
+                    - parent_gain;
+
+                if self.nodes[visitor.node].split_score == Option::None
+                    || gain > self.nodes[visitor.node].split_score.unwrap()
+                {
                     self.nodes[visitor.node].split_feature = j;
-                    self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
+                    self.nodes[visitor.node].split_value =
+                        Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
                     self.nodes[visitor.node].split_score = Option::Some(gain);
                     visitor.true_child_output = true_mean;
                     visitor.false_child_output = false_mean;
                 }
 
                 prevx = visitor.x.get(*i, j);
-                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
+                true_sum = true_sum + T::from(visitor.samples[*i]).unwrap() * visitor.y.get(0, *i);
                 true_count += visitor.samples[*i];
             }
         }
 
     }
 
-    fn split<'a, M: Matrix<T>>(&mut self, mut visitor: NodeVisitor<'a, T, M>, mtry: usize, visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>) -> bool {
+    fn split<'a, M: Matrix<T>>(
+        &mut self,
+        mut visitor: NodeVisitor<'a, T, M>,
+        mtry: usize,
+        visitor_queue: &mut LinkedList<NodeVisitor<'a, T, M>>,
+    ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
-        let mut fc = 0;
+        let mut fc = 0;
         let mut true_samples: Vec<usize> = vec![0; n];
 
         for i in 0..n {
             if visitor.samples[i] > 0 {
-                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature)
+                    <= self.nodes[visitor.node].split_value.unwrap_or(T::nan())
+                {
                     true_samples[i] = visitor.samples[i];
                     tc += true_samples[i];
                     visitor.samples[i] = 0;
-                } else {
+                } else {
                     fc += visitor.samples[i];
                 }
             }
@@ -313,111 +349,155 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
             self.nodes[visitor.node].split_value = Option::None;
             self.nodes[visitor.node].split_score = Option::None;
             return false;
-        }
+        }
 
         let true_child_idx = self.nodes.len();
-        self.nodes.push(Node::new(true_child_idx, visitor.true_child_output));
+        self.nodes
+            .push(Node::new(true_child_idx, visitor.true_child_output));
         let false_child_idx = self.nodes.len();
-        self.nodes.push(Node::new(false_child_idx, visitor.false_child_output));
+        self.nodes
+            .push(Node::new(false_child_idx, visitor.false_child_output));
 
         self.nodes[visitor.node].true_child = Some(true_child_idx);
         self.nodes[visitor.node].false_child = Some(false_child_idx);
-
+
         self.depth = u16::max(self.depth, visitor.level + 1);
 
-        let mut true_visitor = NodeVisitor::<T, M>::new(true_child_idx, true_samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-
+        let mut true_visitor = NodeVisitor::<T, M>::new(
+            true_child_idx,
+            true_samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
 
         if self.find_best_cutoff(&mut true_visitor, mtry) {
             visitor_queue.push_back(true_visitor);
         }
 
-        let mut false_visitor = NodeVisitor::<T, M>::new(false_child_idx, visitor.samples, visitor.order, visitor.x, visitor.y, visitor.level + 1);
-
+        let mut false_visitor = NodeVisitor::<T, M>::new(
+            false_child_idx,
+            visitor.samples,
+            visitor.order,
+            visitor.x,
+            visitor.y,
+            visitor.level + 1,
+        );
 
         if self.find_best_cutoff(&mut false_visitor, mtry) {
             visitor_queue.push_back(false_visitor);
         }
 
         true
     }
-
 }
 
 #[cfg(test)]
 mod tests {
-    use super::*;
-    use crate::linalg::naive::dense_matrix::DenseMatrix;
+    use super::*;
+    use crate::linalg::naive::dense_matrix::DenseMatrix;
 
     #[test]
-    fn fit_longley() {
-
+    fn fit_longley() {
         let x = DenseMatrix::from_array(&[
-            &[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
-            &[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
-            &[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
-            &[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
-            &[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
-            &[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
-            &[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
-            &[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
-            &[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
-            &[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
-            &[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
-            &[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
-            &[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
-            &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
-            &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
-            &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
-        let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
+            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
+            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
+            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
+            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+        ]);
+        let y: Vec<f64> = vec![
+            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+            114.2, 115.7, 116.9,
+        ];
 
-        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
+        let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
 
         for i in 0..y_hat.len() {
             assert!((y_hat[i] - y[i]).abs() < 0.1);
         }
 
-        let expected_y = vec![87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85, 114.85, 114.85, 114.85];
-        let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 2, min_samples_split: 6}).predict(&x);
+        let expected_y = vec![
+            87.3, 87.3, 87.3, 87.3, 98.9, 98.9, 98.9, 98.9, 98.9, 107.9, 107.9, 107.9, 114.85,
+            114.85, 114.85, 114.85,
+        ];
+        let y_hat = DecisionTreeRegressor::fit(
+            &x,
+            &y,
+            DecisionTreeRegressorParameters {
+                max_depth: Option::None,
+                min_samples_leaf: 2,
+                min_samples_split: 6,
+            },
+        )
+        .predict(&x);
 
         for i in 0..y_hat.len() {
             assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
         }
 
-        let expected_y = vec![83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4, 113.4, 116.30, 116.30];
-        let y_hat = DecisionTreeRegressor::fit(&x, &y, DecisionTreeRegressorParameters{max_depth: Option::None, min_samples_leaf: 1, min_samples_split: 3}).predict(&x);
-
+        let expected_y = vec![
+            83.0, 88.35, 88.35, 89.5, 97.15, 97.15, 99.5, 99.5, 101.2, 104.6, 109.6, 109.6, 113.4,
+            113.4, 116.30, 116.30,
+        ];
+        let y_hat = DecisionTreeRegressor::fit(
+            &x,
+            &y,
+            DecisionTreeRegressorParameters {
+                max_depth: Option::None,
+                min_samples_leaf: 1,
+                min_samples_split: 3,
+            },
+        )
+        .predict(&x);
 
         for i in 0..y_hat.len() {
             assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
         }
-
     }
 
     #[test]
-    fn serde() {
+    fn serde() {
         let x = DenseMatrix::from_array(&[
-            &[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
-            &[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
-            &[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
-            &[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
-            &[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
-            &[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
-            &[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
-            &[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
-            &[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
-            &[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
-            &[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
-            &[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
-            &[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
-            &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
-            &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
-            &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
-        let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
+            &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+            &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+            &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+            &[284.599, 335.1, 165., 110.929, 1950., 61.187],
+            &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+            &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+            &[365.385, 187., 354.7, 115.094, 1953., 64.989],
+            &[363.112, 357.8, 335., 116.219, 1954., 63.761],
+            &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+            &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+            &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+            &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+            &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+            &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+            &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+            &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+        ]);
+        let y: Vec<f64> = vec![
+            83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
+            114.2, 115.7, 116.9,
+        ];
 
         let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
 
-        let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+        let deserialized_tree: DecisionTreeRegressor<f64> =
+            bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
 
-        assert_eq!(tree, deserialized_tree);
-
+        assert_eq!(tree, deserialized_tree);
     }
-}
+}
+1 -1
@@ -1,2 +1,2 @@
+pub mod decision_tree_classifier;
 pub mod decision_tree_regressor;
-pub mod decision_tree_classifier;