Implement realnum::rand (#251)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com> Co-authored-by: Lorenzo <tunedconsulting@gmail.com> * Implement rand. Use the new derive [#default] * Use custom range * Use range seed * Bump version * Add array length checks for
This commit is contained in:
@@ -137,16 +137,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
self.classes.as_ref()
|
||||
}
|
||||
/// Get depth of tree
|
||||
fn depth(&self) -> u16 {
|
||||
pub fn depth(&self) -> u16 {
|
||||
self.depth
|
||||
}
|
||||
}
|
||||
|
||||
/// The function to measure the quality of a split.
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub enum SplitCriterion {
|
||||
/// [Gini index](../decision_tree_classifier/index.html)
|
||||
#[default]
|
||||
Gini,
|
||||
/// [Entropy](../decision_tree_classifier/index.html)
|
||||
Entropy,
|
||||
@@ -154,12 +155,6 @@ pub enum SplitCriterion {
|
||||
ClassificationError,
|
||||
}
|
||||
|
||||
impl Default for SplitCriterion {
|
||||
fn default() -> Self {
|
||||
SplitCriterion::Gini
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||
#[derive(Debug, Clone)]
|
||||
struct Node {
|
||||
@@ -543,6 +538,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
if x_nrows != y.shape() {
|
||||
return Err(Failed::fit("Size of x should equal size of y"));
|
||||
}
|
||||
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
@@ -968,6 +967,17 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_matrix_with_wrong_rownum() {
|
||||
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
|
||||
|
||||
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||
|
||||
let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
|
||||
|
||||
assert!(fail.is_err());
|
||||
}
|
||||
|
||||
#[cfg_attr(
|
||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||
wasm_bindgen_test::wasm_bindgen_test
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
//! Example:
|
||||
//!
|
||||
//! ```
|
||||
//! use rand::thread_rng;
|
||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||
//! use smartcore::tree::decision_tree_regressor::*;
|
||||
//!
|
||||
@@ -422,6 +421,10 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
||||
parameters: DecisionTreeRegressorParameters,
|
||||
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
if x_nrows != y.shape() {
|
||||
return Err(Failed::fit("Size of x should equal size of y"));
|
||||
}
|
||||
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user