add serde support for XGRegressor (#337)
* add serde support for XGBoostRegressor * add traits to dependent structs
This commit is contained in:
@@ -53,10 +53,14 @@ use crate::{
|
|||||||
rand_custom::get_rng_impl,
|
rand_custom::get_rng_impl,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "serde")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Defines the objective function to be optimized.
|
/// Defines the objective function to be optimized.
|
||||||
/// The objective function provides the loss, gradient (first derivative), and
|
/// The objective function provides the loss, gradient (first derivative), and
|
||||||
/// hessian (second derivative) required for the XGBoost algorithm.
|
/// hessian (second derivative) required for the XGBoost algorithm.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
pub enum Objective {
|
pub enum Objective {
|
||||||
/// The objective for regression tasks using Mean Squared Error.
|
/// The objective for regression tasks using Mean Squared Error.
|
||||||
/// Loss: 0.5 * (y_true - y_pred)^2
|
/// Loss: 0.5 * (y_true - y_pred)^2
|
||||||
@@ -122,6 +126,8 @@ impl Objective {
|
|||||||
/// This is a recursive data structure where each `TreeRegressor` is a node
|
/// This is a recursive data structure where each `TreeRegressor` is a node
|
||||||
/// that can have a left and a right child, also of type `TreeRegressor`.
|
/// that can have a left and a right child, also of type `TreeRegressor`.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
struct TreeRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
left: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
right: Option<Box<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
@@ -374,6 +380,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
/// Parameters for the `jRegressor` model.
|
/// Parameters for the `jRegressor` model.
|
||||||
///
|
///
|
||||||
/// This struct holds all the hyperparameters that control the training process.
|
/// This struct holds all the hyperparameters that control the training process.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct XGRegressorParameters {
|
pub struct XGRegressorParameters {
|
||||||
/// The number of boosting rounds or trees to build.
|
/// The number of boosting rounds or trees to build.
|
||||||
@@ -494,6 +501,8 @@ impl XGRegressorParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
/// An Extreme Gradient Boosting (XGBoost) model for regression and classification tasks.
|
||||||
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
pub struct XGRegressor<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> {
|
||||||
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
regressors: Option<Vec<TreeRegressor<TX, TY, X, Y>>>,
|
||||||
parameters: Option<XGRegressorParameters>,
|
parameters: Option<XGRegressorParameters>,
|
||||||
|
|||||||
Reference in New Issue
Block a user