feat: adds serialization/deserialization methods

Volodymyr Orlov
2020-04-03 11:12:15 -07:00
parent 5766364311
commit eb0c36223f
16 changed files with 555 additions and 159 deletions
+89 -16
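This commit derives serde's `Serialize`/`Deserialize` on the classifier, its parameters, `SplitCriterion`, and `Node`, and replaces the `T::nan()` "no split" sentinels with `Option<T>` so that fitted models round-trip cleanly. A minimal usage sketch of the new capability, mirroring the `serde` test at the bottom of this diff; the `smartcore` module paths and the `bincode` dependency are assumptions, not part of this commit:

```rust
// Sketch only: module paths are illustrative guesses; `bincode` is assumed
// to be a dependency, as in the test added at the bottom of this diff.
use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;

fn main() {
    // Fit a small tree on a toy dataset in the spirit of the new test.
    let x = DenseMatrix::from_array(&[
        &[1., 1., 1., 0.],
        &[1., 1., 0., 1.],
        &[0., 0., 1., 0.],
        &[0., 0., 0., 1.],
    ]);
    let y = vec![1., 0., 1., 0.];
    let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());

    // Round-trip the fitted model through a byte buffer.
    let bytes = bincode::serialize(&tree).unwrap();
    let restored: DecisionTreeClassifier<f64> = bincode::deserialize(&bytes).unwrap();

    // Equality uses the epsilon-tolerant PartialEq added in this commit.
    assert_eq!(tree, restored);
}
```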
@@ -3,11 +3,13 @@ use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::collections::LinkedList;
 
+use serde::{Serialize, Deserialize};
+
 use crate::math::num::FloatExt;
 use crate::linalg::Matrix;
 use crate::algorithm::sort::quick_sort::QuickArgSort;
 
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
     pub min_samples_split: usize
 }
 
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeClassifier<T: FloatExt> {
     nodes: Vec<Node<T>>,
     parameters: DecisionTreeClassifierParameters,
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
     depth: u16
 }
 
-#[derive(Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub enum SplitCriterion {
     Gini,
     Entropy,
     ClassificationError
 }
 
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct Node<T: FloatExt> {
     index: usize,
     output: usize,
     split_feature: usize,
-    split_value: T,
-    split_score: T,
+    split_value: Option<T>,
+    split_score: Option<T>,
     true_child: Option<usize>,
     false_child: Option<usize>,
 }
 
+impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
+    fn eq(&self, other: &Self) -> bool {
+        if self.depth != other.depth ||
+            self.num_classes != other.num_classes ||
+            self.nodes.len() != other.nodes.len() {
+            return false
+        } else {
+            for i in 0..self.classes.len() {
+                if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
+                    return false
+                }
+            }
+            for i in 0..self.nodes.len() {
+                if self.nodes[i] != other.nodes[i] {
+                    return false
+                }
+            }
+            return true
+        }
+    }
+}
+
+impl<T: FloatExt> PartialEq for Node<T> {
+    fn eq(&self, other: &Self) -> bool {
+        self.output == other.output &&
+            self.split_feature == other.split_feature &&
+            match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            } &&
+            match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+
 impl Default for DecisionTreeClassifierParameters {
     fn default() -> Self {
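The hunks below apply the mechanical consequence of the `Option<T>` change: every place the old code stored `T::nan()` as a "no split yet" sentinel and probed it with `is_nan()` now stores `Option::None` and checks for presence. A standalone illustration of why the sentinel was fragile (hypothetical helper names, not from this commit): NaN never equals itself under IEEE 754, and formats like JSON cannot represent it at all.

```rust
// Hypothetical illustration, not from this commit.
fn has_split_sentinel(split_score: f64) -> bool {
    !split_score.is_nan() // old style: "no split" encoded as NaN
}

fn has_split_explicit(split_score: Option<f64>) -> bool {
    split_score.is_some() // new style: "no split" is an explicit None
}

fn main() {
    assert!(f64::NAN != f64::NAN); // the sentinel never equals itself
    assert!(!has_split_sentinel(f64::NAN));
    assert!(!has_split_explicit(None));
    assert!(has_split_explicit(Some(0.5)));
}
```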
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
             index: index,
             output: output,
             split_feature: 0,
-            split_value: T::nan(),
-            split_score: T::nan(),
+            split_value: Option::None,
+            split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None
         }
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             if node.true_child == None && node.false_child == None {
                 result = node.output;
             } else {
-                if x.get(row, node.split_feature) <= node.split_value {
+                if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
                     queue.push_back(node.true_child.unwrap());
                 } else {
                     queue.push_back(node.false_child.unwrap());
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
         }
 
-        !self.nodes[visitor.node].split_score.is_nan()
+        self.nodes[visitor.node].split_score != Option::None
     }
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
                 let false_label = which_max(false_count);
                 let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
 
-                if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
+                if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
                     self.nodes[visitor.node].split_feature = j;
-                    self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
-                    self.nodes[visitor.node].split_score = gain;
+                    self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
+                    self.nodes[visitor.node].split_score = Option::Some(gain);
                     visitor.true_child_output = true_label;
                     visitor.false_child_output = false_label;
                 }
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         for i in 0..n {
             if visitor.samples[i] > 0 {
-                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
                     true_samples[i] = visitor.samples[i];
                     tc += true_samples[i];
                     visitor.samples[i] = 0;
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
             self.nodes[visitor.node].split_feature = 0;
-            self.nodes[visitor.node].split_value = T::nan();
-            self.nodes[visitor.node].split_score = T::nan();
+            self.nodes[visitor.node].split_value = Option::None;
+            self.nodes[visitor.node].split_score = Option::None;
             return false;
         }
@@ -477,4 +517,37 @@ mod tests {
         assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
     }
 
+    #[test]
+    fn serde() {
+        let x = DenseMatrix::from_array(&[
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,1.],
+            &[1.,1.,0.,0.],
+            &[1.,1.,0.,1.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,1.],
+            &[1.,0.,0.,0.],
+            &[1.,0.,0.,1.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,1.],
+            &[0.,1.,0.,0.],
+            &[0.,1.,0.,1.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,1.],
+            &[0.,0.,0.,0.],
+            &[0.,0.,0.,1.]]);
+        let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
+
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
+
+        let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+
+        assert_eq!(tree, deserialized_tree);
+    }
 }
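Since the support comes from serde derives rather than a hand-rolled codec, the `bincode` round trip in the test generalizes to any serde format. A hypothetical JSON variant, assuming `serde_json` as an extra dependency and the same illustrative module path:

```rust
// Hypothetical, not part of this commit: requires `serde_json` in Cargo.toml;
// the module path is an illustrative guess.
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;

/// Round-trips a fitted tree through JSON instead of bincode.
fn json_round_trip(tree: &DecisionTreeClassifier<f64>) -> DecisionTreeClassifier<f64> {
    let json = serde_json::to_string(tree).unwrap();
    serde_json::from_str(&json).unwrap()
}
```

JSON has no NaN, so the old `T::nan()` sentinels would not have survived such a trip; the `Option<T>` fields introduced above do. The epsilon-tolerant `PartialEq` also matters here, since a textual format may round the stored floats where bitwise equality would then fail.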