feat: refactoring, adds Result to most public API
This commit is contained in:
@@ -50,9 +50,9 @@
|
||||
//! let y = vec![ 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
//! 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
//!
|
||||
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||
//! let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
||||
//! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//!
|
||||
@@ -71,6 +71,7 @@ use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -276,7 +277,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
) -> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
@@ -288,14 +289,17 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
) -> DecisionTreeClassifier<T> {
|
||||
) -> Result<DecisionTreeClassifier<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
if k < 2 {
|
||||
panic!("Incorrect number of classes: {}. Should be >= 2.", k);
|
||||
return Err(Failed::fit(&format!(
|
||||
"Incorrect number of classes: {}. Should be >= 2.",
|
||||
k
|
||||
)));
|
||||
}
|
||||
|
||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||
@@ -343,12 +347,12 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
};
|
||||
}
|
||||
|
||||
tree
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Predict class value for `x`.
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
@@ -357,7 +361,7 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
result.set(0, i, self.classes[self.predict_for_row(x, i)]);
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> usize {
|
||||
@@ -637,7 +641,9 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
y,
|
||||
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
||||
DecisionTreeClassifier::fit(&x, &y, Default::default())
|
||||
.and_then(|t| t.predict(&x))
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
@@ -652,6 +658,7 @@ mod tests {
|
||||
min_samples_split: 2
|
||||
}
|
||||
)
|
||||
.unwrap()
|
||||
.depth
|
||||
);
|
||||
}
|
||||
@@ -686,7 +693,9 @@ mod tests {
|
||||
|
||||
assert_eq!(
|
||||
y,
|
||||
DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x)
|
||||
DecisionTreeClassifier::fit(&x, &y, Default::default())
|
||||
.and_then(|t| t.predict(&x))
|
||||
.unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -718,7 +727,7 @@ mod tests {
|
||||
1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.,
|
||||
];
|
||||
|
||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default());
|
||||
let tree = DecisionTreeClassifier::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_tree: DecisionTreeClassifier<f64> =
|
||||
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||
|
||||
@@ -45,9 +45,9 @@
|
||||
//! 101.2, 104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
|
||||
//! ];
|
||||
//!
|
||||
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||
//! let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
//!
|
||||
//! let y_hat = tree.predict(&x); // use the same data for prediction
|
||||
//! let y_hat = tree.predict(&x).unwrap(); // use the same data for prediction
|
||||
//! ```
|
||||
//!
|
||||
//! ## References:
|
||||
@@ -66,6 +66,7 @@ use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
use crate::error::Failed;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::math::num::RealNumber;
|
||||
|
||||
@@ -196,7 +197,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
x: &M,
|
||||
y: &M::RowVector,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
) -> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
@@ -208,16 +209,11 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
samples: Vec<usize>,
|
||||
mtry: usize,
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
) -> DecisionTreeRegressor<T> {
|
||||
) -> Result<DecisionTreeRegressor<T>, Failed> {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
if k < 2 {
|
||||
panic!("Incorrect number of classes: {}. Should be >= 2.", k);
|
||||
}
|
||||
|
||||
let mut nodes: Vec<Node<T>> = Vec::new();
|
||||
|
||||
@@ -257,12 +253,12 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
};
|
||||
}
|
||||
|
||||
tree
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// Predict regression value for `x`.
|
||||
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> M::RowVector {
|
||||
pub fn predict<M: Matrix<T>>(&self, x: &M) -> Result<M::RowVector, Failed> {
|
||||
let mut result = M::zeros(1, x.shape().0);
|
||||
|
||||
let (n, _) = x.shape();
|
||||
@@ -271,7 +267,7 @@ impl<T: RealNumber> DecisionTreeRegressor<T> {
|
||||
result.set(0, i, self.predict_for_row(x, i));
|
||||
}
|
||||
|
||||
result.to_row_vector()
|
||||
Ok(result.to_row_vector())
|
||||
}
|
||||
|
||||
pub(in crate) fn predict_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> T {
|
||||
@@ -498,7 +494,9 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default()).predict(&x);
|
||||
let y_hat = DecisionTreeRegressor::fit(&x, &y, Default::default())
|
||||
.and_then(|t| t.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - y[i]).abs() < 0.1);
|
||||
@@ -517,7 +515,8 @@ mod tests {
|
||||
min_samples_split: 6,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.and_then(|t| t.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||
@@ -536,7 +535,8 @@ mod tests {
|
||||
min_samples_split: 3,
|
||||
},
|
||||
)
|
||||
.predict(&x);
|
||||
.and_then(|t| t.predict(&x))
|
||||
.unwrap();
|
||||
|
||||
for i in 0..y_hat.len() {
|
||||
assert!((y_hat[i] - expected_y[i]).abs() < 0.1);
|
||||
@@ -568,7 +568,7 @@ mod tests {
|
||||
114.2, 115.7, 116.9,
|
||||
];
|
||||
|
||||
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default());
|
||||
let tree = DecisionTreeRegressor::fit(&x, &y, Default::default()).unwrap();
|
||||
|
||||
let deserialized_tree: DecisionTreeRegressor<f64> =
|
||||
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
|
||||
|
||||
Reference in New Issue
Block a user