feat: adds serialization/deserialization methods
This commit is contained in:
@@ -3,11 +3,13 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
Gini,
|
||||
Entropy,
|
||||
ClassificationError
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
split_value: Option<T>,
|
||||
split_score: Option<T>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
|
||||
/// Equality between two classifiers.
///
/// Two classifiers compare equal when their `depth`, `num_classes`, number of
/// nodes and number of class labels match, every pair of class labels differs
/// by no more than `T::epsilon()`, and every pair of nodes compares equal
/// through `Node`'s own `PartialEq`.
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
    fn eq(&self, other: &Self) -> bool {
        if self.depth != other.depth
            || self.num_classes != other.num_classes
            || self.nodes.len() != other.nodes.len()
            // Without this guard a shorter `other.classes` would be indexed out
            // of bounds, and a longer one would have its extra labels ignored.
            || self.classes.len() != other.classes.len()
        {
            return false;
        }

        // Expressed as "any pair further apart than epsilon" so that the
        // comparison outcome for NaN labels stays what it was before.
        let labels_differ = self
            .classes
            .iter()
            .zip(other.classes.iter())
            .any(|(lhs, rhs)| (*lhs - *rhs).abs() > T::epsilon());
        if labels_differ {
            return false;
        }

        self.nodes
            .iter()
            .zip(other.nodes.iter())
            .all(|(lhs, rhs)| lhs == rhs)
    }
}
|
||||
|
||||
/// Equality between two classifier nodes.
///
/// Compares `output` and `split_feature` exactly, and `split_value` /
/// `split_score` with a strict `T::epsilon()` tolerance.
///
/// NOTE(review): `index`, `true_child` and `false_child` take no part in the
/// comparison - confirm that this is intended.
impl<T: FloatExt> PartialEq for Node<T> {
    fn eq(&self, other: &Self) -> bool {
        // Both absent, or both present and strictly closer than epsilon.
        fn same_within_eps<F: FloatExt>(lhs: Option<F>, rhs: Option<F>) -> bool {
            match (lhs, rhs) {
                (Some(a), Some(b)) => (a - b).abs() < F::epsilon(),
                (None, None) => true,
                _ => false,
            }
        }

        if self.output != other.output || self.split_feature != other.split_feature {
            return false;
        }

        same_within_eps(self.split_value, other.split_value)
            && same_within_eps(self.split_score, other.split_score)
    }
}
|
||||
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||
}
|
||||
|
||||
!self.nodes[visitor.node].split_score.is_nan()
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
|
||||
}
|
||||
|
||||
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
let false_label = which_max(false_count);
|
||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||
visitor.true_child_output = true_label;
|
||||
visitor.false_child_output = false_label;
|
||||
}
|
||||
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
visitor.samples[i] = 0;
|
||||
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
self.nodes[visitor.node].split_value = Option::None;
|
||||
self.nodes[visitor.node].split_score = Option::None;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -477,4 +517,37 @@ mod tests {
|
||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||
|
||||
}
|
||||
|
||||
/// Fits a classifier, pushes it through a bincode serialize/deserialize
/// round trip and checks that the restored tree equals the fitted one.
#[test]
fn serde() {
    let x = DenseMatrix::from_array(&[
        &[1., 1., 1., 0.],
        &[1., 1., 1., 0.],
        &[1., 1., 1., 1.],
        &[1., 1., 0., 0.],
        &[1., 1., 0., 1.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 0.],
        &[1., 0., 1., 1.],
        &[1., 0., 0., 0.],
        &[1., 0., 0., 1.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 0.],
        &[0., 1., 1., 1.],
        &[0., 1., 0., 0.],
        &[0., 1., 0., 1.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 0.],
        &[0., 0., 1., 1.],
        &[0., 0., 0., 0.],
        &[0., 0., 0., 1.],
    ]);
    let y = vec![
        1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
    ];

    let fitted = DecisionTreeClassifier::fit(&x, &y, Default::default());

    // Encode first, then decode from the intermediate byte buffer.
    let encoded = bincode::serialize(&fitted).unwrap();
    let restored: DecisionTreeClassifier<f64> = bincode::deserialize(&encoded).unwrap();

    assert_eq!(fitted, restored);
}
|
||||
}
|
||||
@@ -2,31 +2,33 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressorParameters {
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: usize,
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeRegressor<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: T,
|
||||
split_feature: usize,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
split_value: Option<T>,
|
||||
split_score: Option<T>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
@@ -48,14 +50,46 @@ impl<T: FloatExt> Node<T> {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Equality between two regressor nodes.
///
/// `output`, `split_value` and `split_score` are compared with a strict
/// `T::epsilon()` tolerance; `split_feature` is compared exactly.
///
/// NOTE(review): `index`, `true_child` and `false_child` take no part in the
/// comparison - confirm that this is intended.
impl<T: FloatExt> PartialEq for Node<T> {
    fn eq(&self, other: &Self) -> bool {
        // Both absent, or both present and strictly closer than epsilon.
        fn same_within_eps<F: FloatExt>(lhs: Option<F>, rhs: Option<F>) -> bool {
            match (lhs, rhs) {
                (Some(a), Some(b)) => (a - b).abs() < F::epsilon(),
                (None, None) => true,
                _ => false,
            }
        }

        let output_close = (self.output - other.output).abs() < T::epsilon();

        output_close
            && self.split_feature == other.split_feature
            && same_within_eps(self.split_value, other.split_value)
            && same_within_eps(self.split_score, other.split_score)
    }
}
|
||||
|
||||
/// Equality between two regressors.
///
/// Two regressors compare equal when their `depth` and node count match and
/// every pair of nodes at the same position compares equal through `Node`'s
/// own `PartialEq`.
impl<T: FloatExt> PartialEq for DecisionTreeRegressor<T> {
    fn eq(&self, other: &Self) -> bool {
        self.depth == other.depth
            && self.nodes.len() == other.nodes.len()
            && self
                .nodes
                .iter()
                .zip(other.nodes.iter())
                .all(|(lhs, rhs)| lhs == rhs)
    }
}
|
||||
|
||||
struct NodeVisitor<'a, T: FloatExt, M: Matrix<T>> {
|
||||
x: &'a M,
|
||||
y: &'a M,
|
||||
@@ -169,7 +203,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
@@ -207,7 +241,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
self.find_best_split(visitor, n, sum, parent_gain, variables[j]);
|
||||
}
|
||||
|
||||
!self.nodes[visitor.node].split_score.is_nan()
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
|
||||
}
|
||||
|
||||
@@ -240,10 +274,10 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
let gain = (T::from(true_count).unwrap() * true_mean * true_mean + T::from(false_count).unwrap() * false_mean * false_mean) - parent_gain;
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||
visitor.true_child_output = true_mean;
|
||||
visitor.false_child_output = false_mean;
|
||||
}
|
||||
@@ -264,7 +298,7 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
visitor.samples[i] = 0;
|
||||
@@ -276,8 +310,8 @@ impl<T: FloatExt> DecisionTreeRegressor<T> {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
self.nodes[visitor.node].split_value = Option::None;
|
||||
self.nodes[visitor.node].split_score = Option::None;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -357,4 +391,33 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
/// Fits a regressor, pushes it through a bincode serialize/deserialize
/// round trip and checks that the restored tree equals the fitted one.
#[test]
fn serde() {
    let x = DenseMatrix::from_array(&[
        &[234.289, 235.6, 159., 107.608, 1947., 60.323],
        &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
        &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
        &[284.599, 335.1, 165., 110.929, 1950., 61.187],
        &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
        &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
        &[365.385, 187., 354.7, 115.094, 1953., 64.989],
        &[363.112, 357.8, 335., 116.219, 1954., 63.761],
        &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
        &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
        &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
        &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
        &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
        &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
        &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
        &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
    ]);
    let y: Vec<f64> = vec![
        83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
        114.2, 115.7, 116.9,
    ];

    let fitted = DecisionTreeRegressor::fit(&x, &y, Default::default());

    // Encode first, then decode from the intermediate byte buffer.
    let encoded = bincode::serialize(&fitted).unwrap();
    let restored: DecisionTreeRegressor<f64> = bincode::deserialize(&encoded).unwrap();

    assert_eq!(fitted, restored);
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user