feat: adds serialization/deserialization methods
This commit is contained in:
@@ -4,12 +4,13 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RandomForestClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
@@ -19,13 +20,34 @@ pub struct RandomForestClassifierParameters {
|
||||
pub mtry: Option<usize>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestClassifier<T: FloatExt> {
|
||||
parameters: RandomForestClassifierParameters,
|
||||
trees: Vec<DecisionTreeClassifier<T>>,
|
||||
classes: Vec<T>
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for RandomForestClassifier<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.classes.len() != other.classes.len() ||
|
||||
self.trees.len() != other.trees.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.classes.len() {
|
||||
if (self.classes[i] - other.classes[i]).abs() > T::epsilon() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RandomForestClassifierParameters {
|
||||
fn default() -> Self {
|
||||
RandomForestClassifierParameters {
|
||||
@@ -171,4 +193,37 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[5.1, 3.5, 1.4, 0.2],
|
||||
&[4.9, 3.0, 1.4, 0.2],
|
||||
&[4.7, 3.2, 1.3, 0.2],
|
||||
&[4.6, 3.1, 1.5, 0.2],
|
||||
&[5.0, 3.6, 1.4, 0.2],
|
||||
&[5.4, 3.9, 1.7, 0.4],
|
||||
&[4.6, 3.4, 1.4, 0.3],
|
||||
&[5.0, 3.4, 1.5, 0.2],
|
||||
&[4.4, 2.9, 1.4, 0.2],
|
||||
&[4.9, 3.1, 1.5, 0.1],
|
||||
&[7.0, 3.2, 4.7, 1.4],
|
||||
&[6.4, 3.2, 4.5, 1.5],
|
||||
&[6.9, 3.1, 4.9, 1.5],
|
||||
&[5.5, 2.3, 4.0, 1.3],
|
||||
&[6.5, 2.8, 4.6, 1.5],
|
||||
&[5.7, 2.8, 4.5, 1.3],
|
||||
&[6.3, 3.3, 4.7, 1.6],
|
||||
&[4.9, 2.4, 3.3, 1.0],
|
||||
&[6.6, 2.9, 4.6, 1.3],
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
|
||||
let forest = RandomForestClassifier::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_forest: RandomForestClassifier<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -4,12 +4,13 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::Rng;
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::math::num::FloatExt;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree_regressor::{DecisionTreeRegressor, DecisionTreeRegressorParameters};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct RandomForestRegressorParameters {
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: usize,
|
||||
@@ -18,7 +19,7 @@ pub struct RandomForestRegressorParameters {
|
||||
pub mtry: Option<usize>
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct RandomForestRegressor<T: FloatExt> {
|
||||
parameters: RandomForestRegressorParameters,
|
||||
trees: Vec<DecisionTreeRegressor<T>>
|
||||
@@ -36,6 +37,21 @@ impl Default for RandomForestRegressorParameters {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> PartialEq for RandomForestRegressor<T> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.trees.len() != other.trees.len() {
|
||||
return false
|
||||
} else {
|
||||
for i in 0..self.trees.len() {
|
||||
if self.trees[i] != other.trees[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FloatExt> RandomForestRegressor<T> {
|
||||
|
||||
pub fn fit<M: Matrix<T>>(x: &M, y: &M::RowVector, parameters: RandomForestRegressorParameters) -> RandomForestRegressor<T> {
|
||||
@@ -180,4 +196,33 @@ mod tests {
|
||||
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde() {
|
||||
let x = DenseMatrix::from_array(&[
|
||||
&[ 234.289, 235.6, 159., 107.608, 1947., 60.323],
|
||||
&[ 259.426, 232.5, 145.6, 108.632, 1948., 61.122],
|
||||
&[ 258.054, 368.2, 161.6, 109.773, 1949., 60.171],
|
||||
&[ 284.599, 335.1, 165., 110.929, 1950., 61.187],
|
||||
&[ 328.975, 209.9, 309.9, 112.075, 1951., 63.221],
|
||||
&[ 346.999, 193.2, 359.4, 113.27 , 1952., 63.639],
|
||||
&[ 365.385, 187., 354.7, 115.094, 1953., 64.989],
|
||||
&[ 363.112, 357.8, 335., 116.219, 1954., 63.761],
|
||||
&[ 397.469, 290.4, 304.8, 117.388, 1955., 66.019],
|
||||
&[ 419.18 , 282.2, 285.7, 118.734, 1956., 67.857],
|
||||
&[ 442.769, 293.6, 279.8, 120.445, 1957., 68.169],
|
||||
&[ 444.546, 468.1, 263.7, 121.95 , 1958., 66.513],
|
||||
&[ 482.704, 381.3, 255.2, 123.366, 1959., 68.655],
|
||||
&[ 502.601, 393.1, 251.4, 125.368, 1960., 69.564],
|
||||
&[ 518.173, 480.6, 257.2, 127.852, 1961., 69.331],
|
||||
&[ 554.894, 400.7, 282.7, 130.081, 1962., 70.551]]);
|
||||
let y = vec![83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9];
|
||||
|
||||
let forest = RandomForestRegressor::fit(&x, &y, Default::default());
|
||||
|
||||
let deserialized_forest: RandomForestRegressor<f64> = bincode::deserialize(&bincode::serialize(&forest).unwrap()).unwrap();
|
||||
|
||||
assert_eq!(forest, deserialized_forest);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user