feat: adds serialization/deserialization methods
This commit is contained in:
@@ -3,11 +3,13 @@ use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::collections::LinkedList;
|
||||
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
@@ -15,7 +17,7 @@ pub struct DecisionTreeClassifierParameters {
|
||||
pub min_samples_split: usize
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
nodes: Vec<Node<T>>,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
@@ -24,24 +26,62 @@ pub struct DecisionTreeClassifier<T: FloatExt> {
|
||||
depth: u16
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub enum SplitCriterion {
|
||||
Gini,
|
||||
Entropy,
|
||||
ClassificationError
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct Node<T: FloatExt> {
|
||||
index: usize,
|
||||
output: usize,
|
||||
split_feature: usize,
|
||||
split_value: T,
|
||||
split_score: T,
|
||||
split_value: Option<T>,
|
||||
split_score: Option<T>,
|
||||
true_child: Option<usize>,
|
||||
false_child: Option<usize>,
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for DecisionTreeClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.depth != other.depth ||
|
||||
self.num_classes != other.num_classes ||
|
||||
self.nodes.len() != other.nodes.len(){
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.nodes.len() {
|
||||
if self.nodes[i] != other.nodes[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for Node<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.output == other.output &&
|
||||
self.split_feature == other.split_feature &&
|
||||
match (self.split_value, other.split_value) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
} &&
|
||||
match (self.split_score, other.split_score) {
|
||||
(Some(a), Some(b)) => (a - b).abs() < T::epsilon(),
|
||||
(None, None) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
@@ -60,8 +100,8 @@ impl<T: FloatExt> Node<T> {
|
||||
index: index,
|
||||
output: output,
|
||||
split_feature: 0,
|
||||
split_value: T::nan(),
|
||||
split_score: T::nan(),
|
||||
split_value: Option::None,
|
||||
split_score: Option::None,
|
||||
true_child: Option::None,
|
||||
false_child: Option::None
|
||||
}
|
||||
@@ -238,7 +278,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
if node.true_child == None && node.false_child == None {
|
||||
result = node.output;
|
||||
} else {
|
||||
if x.get(row, node.split_feature) <= node.split_value {
|
||||
if x.get(row, node.split_feature) <= node.split_value.unwrap_or(T::nan()) {
|
||||
queue.push_back(node.true_child.unwrap());
|
||||
} else {
|
||||
queue.push_back(node.false_child.unwrap());
|
||||
@@ -299,7 +339,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
self.find_best_split(visitor, n, &count, &mut false_count, parent_impurity, variables[j]);
|
||||
}
|
||||
|
||||
!self.nodes[visitor.node].split_score.is_nan()
|
||||
self.nodes[visitor.node].split_score != Option::None
|
||||
|
||||
}
|
||||
|
||||
@@ -336,10 +376,10 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
let false_label = which_max(false_count);
|
||||
let gain = parent_impurity - T::from(tc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &true_count, tc) - T::from(fc).unwrap() / T::from(n).unwrap() * impurity(&self.parameters.criterion, &false_count, fc);
|
||||
|
||||
if self.nodes[visitor.node].split_score.is_nan() || gain > self.nodes[visitor.node].split_score {
|
||||
if self.nodes[visitor.node].split_score == Option::None || gain > self.nodes[visitor.node].split_score.unwrap() {
|
||||
self.nodes[visitor.node].split_feature = j;
|
||||
self.nodes[visitor.node].split_value = (visitor.x.get(*i, j) + prevx) / T::two();
|
||||
self.nodes[visitor.node].split_score = gain;
|
||||
self.nodes[visitor.node].split_value = Option::Some((visitor.x.get(*i, j) + prevx) / T::two());
|
||||
self.nodes[visitor.node].split_score = Option::Some(gain);
|
||||
visitor.true_child_output = true_label;
|
||||
visitor.false_child_output = false_label;
|
||||
}
|
||||
@@ -360,7 +400,7 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
for i in 0..n {
|
||||
if visitor.samples[i] > 0 {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value {
|
||||
if visitor.x.get(i, self.nodes[visitor.node].split_feature) <= self.nodes[visitor.node].split_value.unwrap_or(T::nan()) {
|
||||
true_samples[i] = visitor.samples[i];
|
||||
tc += true_samples[i];
|
||||
visitor.samples[i] = 0;
|
||||
@@ -372,8 +412,8 @@ impl<T: FloatExt> DecisionTreeClassifier<T> {
|
||||
|
||||
if tc < self.parameters.min_samples_leaf || fc < self.parameters.min_samples_leaf {
|
||||
self.nodes[visitor.node].split_feature = 0;
|
||||
self.nodes[visitor.node].split_value = T::nan();
|
||||
self.nodes[visitor.node].split_score = T::nan();
|
||||
self.nodes[visitor.node].split_value = Option::None;
|
||||
self.nodes[visitor.node].split_score = Option::None;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -477,4 +517,37 @@ mod tests {
|
||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,0.],
|
||||
&[1.,1.,1.,1.],
|
||||
&[1.,1.,0.,0.],
|
||||
&[1.,1.,0.,1.],
|
||||
&[1.,0.,1.,0.],
|
||||
&[1.,0.,1.,0.],
|
||||
&[1.,0.,1.,1.],
|
||||
&[1.,0.,0.,0.],
|
||||
&[1.,0.,0.,1.],
|
||||
&[0.,1.,1.,0.],
|
||||
&[0.,1.,1.,0.],
|
||||
&[0.,1.,1.,1.],
|
||||
&[0.,1.,0.,0.],
|
||||
&[0.,1.,0.,1.],
|
||||
&[0.,0.,1.,0.],
|
||||
&[0.,0.,1.,0.],
|
||||
&[0.,0.,1.,1.],
|
||||
&[0.,0.,0.,0.],
|
||||
&[0.,0.,0.,1.]]);
|
||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
||||
|
||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_tree: DecisionTreeClassifier<f64> = bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(tree, deserialized_tree);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user