Implement realnum::rand (#251)

Co-authored-by: Luis Moreno <morenol@users.noreply.github.com>
Co-authored-by: Lorenzo <tunedconsulting@gmail.com>

* Implement rand. Use the new derive [#default]
* Use custom range
* Use range seed
* Bump version
* Add array length checks for
This commit is contained in:
Lorenzo
2023-03-20 23:45:44 +09:00
committed by GitHub
parent 7d059c4fb1
commit f498f9629e
12 changed files with 118 additions and 44 deletions
+18 -8
View File
@@ -137,16 +137,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
self.classes.as_ref()
}
/// Get depth of tree
fn depth(&self) -> u16 {
pub fn depth(&self) -> u16 {
self.depth
}
}
/// The function to measure the quality of a split.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub enum SplitCriterion {
/// [Gini index](../decision_tree_classifier/index.html)
#[default]
Gini,
/// [Entropy](../decision_tree_classifier/index.html)
Entropy,
@@ -154,12 +155,6 @@ pub enum SplitCriterion {
ClassificationError,
}
impl Default for SplitCriterion {
fn default() -> Self {
SplitCriterion::Gini
}
}
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
@@ -543,6 +538,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
parameters: DecisionTreeClassifierParameters,
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}
let samples = vec![1; x_nrows];
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
}
@@ -968,6 +967,17 @@ mod tests {
);
}
#[test]
fn test_random_matrix_with_wrong_rownum() {
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
assert!(fail.is_err());
}
#[cfg_attr(
all(target_arch = "wasm32", not(target_os = "wasi")),
wasm_bindgen_test::wasm_bindgen_test
+4 -1
View File
@@ -18,7 +18,6 @@
//! Example:
//!
//! ```
//! use rand::thread_rng;
//! use smartcore::linalg::basic::matrix::DenseMatrix;
//! use smartcore::tree::decision_tree_regressor::*;
//!
@@ -422,6 +421,10 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
parameters: DecisionTreeRegressorParameters,
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let (x_nrows, num_attributes) = x.shape();
if x_nrows != y.shape() {
return Err(Failed::fit("Size of x should equal size of y"));
}
let samples = vec![1; x_nrows];
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
}