feat: adds serialization/deserialization methods

Author: Volodymyr Orlov
Date: 2020-04-03 11:12:15 -07:00
parent 5766364311
commit eb0c36223f
16 changed files with 555 additions and 159 deletions
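The pattern is the same across the touched modules: derive serde's Serialize and Deserialize on the model types, and replace the NaN sentinels in split_value and split_score with Option<T>, since NaN never compares equal to itself and would make a round-trip equality test impossible. A minimal sketch of that idea on a hypothetical struct, using the same bincode 1.x serialize/deserialize calls the new tests use (the Split type here is illustrative, not from this commit; assumes serde with the derive feature and bincode 1.x as dependencies):

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Split {
    feature: usize,
    // Option<f64> instead of an f64 NaN sentinel: None survives a
    // round trip and compares equal, while NaN != NaN breaks assert_eq!.
    value: Option<f64>,
}

fn main() {
    let split = Split { feature: 2, value: None };
    let bytes = bincode::serialize(&split).unwrap();
    let restored: Split = bincode::deserialize(&bytes).unwrap();
    assert_eq!(split, restored);
}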
+89 -16
@@ -3,11 +3,13 @@ use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::collections::LinkedList;
+use serde::{Serialize, Deserialize};
 use crate::math::num::FloatExt;
 use crate::linalg::Matrix;
 use crate::algorithm::sort::quick_sort::QuickArgSort;
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeClassifierParameters {
     pub criterion: SplitCriterion,
     pub max_depth: Option<u16>,
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
     pub min_samples_split: usize
 }
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeClassifier<T: FloatExt> {
     nodes: Vec<Node<T>>,
     parameters: DecisionTreeClassifierParameters,
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
     depth: u16
 }
-#[derive(Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
 pub enum SplitCriterion {
     Gini,
     Entropy,
     ClassificationError
 }
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct Node<T: FloatExt> {
     index: usize,
     output: usize,
     split_feature: usize,
-    split_value: T,
-    split_score: T,
+    split_value: Option<T>,
+    split_score: Option<T>,
     true_child: Option<usize>,
     false_child: Option<usize>,
 }
+impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
+    fn eq(&self, other: &Self) -> bool {
+        if self.depth != other.depth ||
+            self.num_classes != other.num_classes ||
+            self.nodes.len() != other.nodes.len(){
+            return false
+        } else {
+            for i in 0..self.classes.len() {
+                if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
+                    return false
+                }
+            }
+            for i in 0..self.nodes.len() {
+                if self.nodes[i] != other.nodes[i] {
+                    return false
+                }
+            }
+            return true
+        }
+    }
+}
+impl<T: FloatExt> PartialEq for Node<T> {
+    fn eq(&self, other: &Self) -> bool {
+        self.output == other.output &&
+            self.split_feature == other.split_feature &&
+            match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            } &&
+            match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
 impl Default for DecisionTreeClassifierParameters {
     fn default() -> Self {
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
             index: index,
             output: output,
             split_feature: 0,
-            split_value: T::nan(),
-            split_score: T::nan(),
+            split_value: Option::None,
+            split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None
         }
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             if node.true_child == None && node.false_child == None {
                 result = node.output;
             } else {
-                if x.get(row, node.split_feature) <= node.split_value {
+                if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
                     queue.push_back(node.true_child.unwrap());
                 } else {
                     queue.push_back(node.false_child.unwrap());
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
             self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
         }
-        !self.nodes[visitor.node].split_score.is_nan()
+        self.nodes[visitor.node].split_score != Option::None
     }
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         let false_label = which_max(false_count);
         let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
-        if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
+        if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
             self.nodes[visitor.node].split_feature = j;
-            self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
-            self.nodes[visitor.node].split_score = gain;
+            self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
+            self.nodes[visitor.node].split_score = Option::Some(gain);
             visitor.true_child_output = true_label;
             visitor.false_child_output = false_label;
         }
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         for i in 0..n {
             if visitor.samples[i] > 0 {
-                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
                     true_samples[i] = visitor.samples[i];
                     tc += true_samples[i];
                     visitor.samples[i] = 0;
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
         if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
             self.nodes[visitor.node].split_feature = 0;
-            self.nodes[visitor.node].split_value = T::nan();
-            self.nodes[visitor.node].split_score = T::nan();
+            self.nodes[visitor.node].split_value = Option::None;
+            self.nodes[visitor.node].split_score = Option::None;
             return false;
         }
@@ -477,4 +517,37 @@ mod tests {
         assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
     }
+    #[test]
+    fn serde() {
+        let x = DenseMatrix::from_array(&[
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,0.],
+            &[1.,1.,1.,1.],
+            &[1.,1.,0.,0.],
+            &[1.,1.,0.,1.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,0.],
+            &[1.,0.,1.,1.],
+            &[1.,0.,0.,0.],
+            &[1.,0.,0.,1.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,0.],
+            &[0.,1.,1.,1.],
+            &[0.,1.,0.,0.],
+            &[0.,1.,0.,1.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,0.],
+            &[0.,0.,1.,1.],
+            &[0.,0.,0.,0.],
+            &[0.,0.,0.,1.]]);
+        let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
+        let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
+        let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+        assert_eq!(tree, deserialized_tree);
+    }
}
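With the derives in place, a fitted classifier can be persisted through any serde backend, not just round-tripped in memory. A sketch of a file round trip following the serde test above, reusing the x and y defined there (the path is illustrative):

let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
// Serialize with bincode 1.x, as the test does, and write to disk.
std::fs::write("tree.bin", bincode::serialize(&tree).unwrap()).unwrap();
// Read the bytes back and restore the model.
let bytes = std::fs::read("tree.bin").unwrap();
let restored: DecisionTreeClassifier<f64> = bincode::deserialize(&bytes).unwrap();
assert_eq!(tree, restored);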
+78 -15
@@ -2,31 +2,33 @@ use std::default::Default;
 use std::fmt::Debug;
 use std::collections::LinkedList;
+use serde::{Serialize, Deserialize};
 use crate::math::num::FloatExt;
 use crate::linalg::Matrix;
 use crate::algorithm::sort::quick_sort::QuickArgSort;
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeRegressorParameters {
     pub max_depth: Option<u16>,
     pub min_samples_leaf: usize,
     pub min_samples_split: usize
 }
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct DecisionTreeRegressor<T: FloatExt> {
     nodes: Vec<Node<T>>,
     parameters: DecisionTreeRegressorParameters,
     depth: u16
 }
-#[derive(Debug)]
+#[derive(Serialize, Deserialize, Debug)]
 pub struct Node<T: FloatExt> {
     index: usize,
     output: T,
     split_feature: usize,
-    split_value: T,
-    split_score: T,
+    split_value: Option<T>,
+    split_score: Option<T>,
     true_child: Option<usize>,
     false_child: Option<usize>,
 }
@@ -48,14 +50,46 @@ impl<T: FloatExt> Node<T> {
             index: index,
             output: output,
             split_feature: 0,
-            split_value: T::nan(),
-            split_score: T::nan(),
+            split_value: Option::None,
+            split_score: Option::None,
             true_child: Option::None,
             false_child: Option::None
         }
     }
 }
+impl<T: FloatExt> PartialEq for Node<T> {
+    fn eq(&self, other: &Self) -> bool {
+        (self.output - other.output).abs() < T::epsilon() &&
+            self.split_feature == other.split_feature &&
+            match (self.split_value, other.split_value) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            } &&
+            match (self.split_score, other.split_score) {
+                (Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
+                (None, None) => true,
+                _ => false,
+            }
+    }
+}
+impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
+    fn eq(&self, other: &Self) -> bool {
+        if self.depth != other.depth || self.nodes.len() != other.nodes.len(){
+            return false
+        } else {
+            for i in 0..self.nodes.len() {
+                if self.nodes[i] != other.nodes[i] {
+                    return false
+                }
+            }
+            return true
+        }
+    }
+}
 struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
     x: &'a M,
     y: &'a M,
@@ -169,7 +203,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
             if node.true_child == None && node.false_child == None {
                 result = node.output;
             } else {
-                if x.get(row, node.split_feature) <= node.split_value {
+                if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
                     queue.push_back(node.true_child.unwrap());
                 } else {
                     queue.push_back(node.false_child.unwrap());
@@ -207,7 +241,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
             self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
         }
-        !self.nodes[visitor.node].split_score.is_nan()
+        self.nodes[visitor.node].split_score != Option::None
     }
@@ -240,10 +274,10 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
         let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
-        if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
+        if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
             self.nodes[visitor.node].split_feature = j;
-            self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
-            self.nodes[visitor.node].split_score = gain;
+            self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
+            self.nodes[visitor.node].split_score = Option::Some(gain);
             visitor.true_child_output = true_mean;
             visitor.false_child_output = false_mean;
         }
@@ -264,7 +298,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
         for i in 0..n {
             if visitor.samples[i] > 0 {
-                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
+                if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
                     true_samples[i] = visitor.samples[i];
                     tc += true_samples[i];
                     visitor.samples[i] = 0;
@@ -276,8 +310,8 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
         if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
             self.nodes[visitor.node].split_feature = 0;
-            self.nodes[visitor.node].split_value = T::nan();
-            self.nodes[visitor.node].split_score = T::nan();
+            self.nodes[visitor.node].split_value = Option::None;
+            self.nodes[visitor.node].split_score = Option::None;
             return false;
         }
@@ -357,4 +391,33 @@ mod tests {
     }
+    #[test]
+    fn serde() {
+        let x = DenseMatrix::from_array(&[
+            &[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
+            &[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+            &[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+            &[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
+            &[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+            &[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
+            &[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
+            &[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
+            &[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+            &[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
+            &[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+            &[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
+            &[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+            &[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+            &[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+            &[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
+        let y: Vec<f64> = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
+        let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
+        let deserialized_tree: DecisionTreeRegressor<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
+        assert_eq!(tree, deserialized_tree);
+    }
}
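Both files also implement PartialEq by hand rather than deriving it, so that floating-point fields compare within T::epsilon() and two absent splits (None) compare equal. The Option-matching pattern in isolation, as a self-contained sketch (num-traits' Float stands in here for the crate's own FloatExt trait; this helper is illustrative, not code from this commit):

use num_traits::Float;

// Equal when both values are absent, or both are present and differ
// by less than machine epsilon; Some vs. None is never equal.
fn approx_eq_opt<T: Float>(a: Option<T>, b: Option<T>) -> bool {
    match (a, b) {
        (Some(x), Some(y)) => (x - y).abs() < T::epsilon(),
        (None, None) => true,
        _ => false,
    }
}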