Implement realnum::rand (#251)
Co-authored-by: Luis Moreno <morenol@users.noreply.github.com> Co-authored-by: Lorenzo <tunedconsulting@gmail.com> * Implement rand. Use the new derive [#default] * Use custom range * Use range seed * Bump version * Add array length checks for
This commit is contained in:
+1
-1
@@ -2,7 +2,7 @@
|
|||||||
name = "smartcore"
|
name = "smartcore"
|
||||||
description = "Machine Learning in Rust."
|
description = "Machine Learning in Rust."
|
||||||
homepage = "https://smartcorelib.org"
|
homepage = "https://smartcorelib.org"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
authors = ["smartcore Developers"]
|
authors = ["smartcore Developers"]
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|||||||
@@ -49,20 +49,15 @@ pub mod linear_search;
|
|||||||
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
/// Both, KNN classifier and regressor benefits from underlying search algorithms that helps to speed up queries.
|
||||||
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
/// `KNNAlgorithmName` maintains a list of supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub enum KNNAlgorithmName {
|
pub enum KNNAlgorithmName {
|
||||||
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
|
||||||
LinearSearch,
|
LinearSearch,
|
||||||
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
|
||||||
|
#[default]
|
||||||
CoverTree,
|
CoverTree,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for KNNAlgorithmName {
|
|
||||||
fn default() -> Self {
|
|
||||||
KNNAlgorithmName::CoverTree
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
|
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
//!
|
//!
|
||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```ignore
|
||||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
//! use smartcore::linalg::basic::arrays::Array2;
|
//! use smartcore::linalg::basic::arrays::Array2;
|
||||||
//! use smartcore::cluster::dbscan::*;
|
//! use smartcore::cluster::dbscan::*;
|
||||||
|
|||||||
@@ -454,8 +454,12 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
|
|||||||
y: &Y,
|
y: &Y,
|
||||||
parameters: RandomForestClassifierParameters,
|
parameters: RandomForestClassifierParameters,
|
||||||
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
|
) -> Result<RandomForestClassifier<TX, TY, X, Y>, Failed> {
|
||||||
let (_, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let y_ncols = y.shape();
|
let y_ncols = y.shape();
|
||||||
|
if x_nrows != y_ncols {
|
||||||
|
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||||
|
}
|
||||||
|
|
||||||
let mut yi: Vec<usize> = vec![0; y_ncols];
|
let mut yi: Vec<usize> = vec![0; y_ncols];
|
||||||
let classes = y.unique();
|
let classes = y.unique();
|
||||||
|
|
||||||
@@ -678,6 +682,30 @@ mod tests {
|
|||||||
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
assert!(accuracy(&y, &classifier.predict(&x).unwrap()) >= 0.95);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_random_matrix_with_wrong_rownum() {
|
||||||
|
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
|
||||||
|
|
||||||
|
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||||
|
|
||||||
|
let fail = RandomForestClassifier::fit(
|
||||||
|
&x_rand,
|
||||||
|
&y,
|
||||||
|
RandomForestClassifierParameters {
|
||||||
|
criterion: SplitCriterion::Gini,
|
||||||
|
max_depth: Option::None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 100,
|
||||||
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 87,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(fail.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
@@ -399,6 +399,10 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1
|
|||||||
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
|
) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
|
||||||
let (n_rows, num_attributes) = x.shape();
|
let (n_rows, num_attributes) = x.shape();
|
||||||
|
|
||||||
|
if n_rows != y.shape() {
|
||||||
|
return Err(Failed::fit("Number of rows in X should = len(y)"));
|
||||||
|
}
|
||||||
|
|
||||||
let mtry = parameters
|
let mtry = parameters
|
||||||
.m
|
.m
|
||||||
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);
|
||||||
@@ -595,6 +599,32 @@ mod tests {
|
|||||||
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
assert!(mean_absolute_error(&y, &y_hat) < 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_random_matrix_with_wrong_rownum() {
|
||||||
|
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(17, 200);
|
||||||
|
|
||||||
|
let y = vec![
|
||||||
|
83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, 112.6,
|
||||||
|
114.2, 115.7, 116.9,
|
||||||
|
];
|
||||||
|
|
||||||
|
let fail = RandomForestRegressor::fit(
|
||||||
|
&x_rand,
|
||||||
|
&y,
|
||||||
|
RandomForestRegressorParameters {
|
||||||
|
max_depth: Option::None,
|
||||||
|
min_samples_leaf: 1,
|
||||||
|
min_samples_split: 2,
|
||||||
|
n_trees: 1000,
|
||||||
|
m: Option::None,
|
||||||
|
keep_samples: false,
|
||||||
|
seed: 87,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(fail.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
+1
-1
@@ -30,7 +30,7 @@ pub enum FailedError {
|
|||||||
DecompositionFailed,
|
DecompositionFailed,
|
||||||
/// Can't solve for x
|
/// Can't solve for x
|
||||||
SolutionFailed,
|
SolutionFailed,
|
||||||
/// Erro in input
|
/// Error in input parameters
|
||||||
ParametersError,
|
ParametersError,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,19 +71,14 @@ use crate::optimization::line_search::Backtracking;
|
|||||||
use crate::optimization::FunctionOrder;
|
use crate::optimization::FunctionOrder;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq, Default)]
|
||||||
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
/// Solver options for Logistic regression. Right now only LBFGS solver is supported.
|
||||||
pub enum LogisticRegressionSolverName {
|
pub enum LogisticRegressionSolverName {
|
||||||
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
/// Limited-memory Broyden–Fletcher–Goldfarb–Shanno method, see [LBFGS paper](http://users.iems.northwestern.edu/~nocedal/lbfgsb.html)
|
||||||
|
#[default]
|
||||||
LBFGS,
|
LBFGS,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for LogisticRegressionSolverName {
|
|
||||||
fn default() -> Self {
|
|
||||||
LogisticRegressionSolverName::LBFGS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Logistic Regression parameters
|
/// Logistic Regression parameters
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|||||||
@@ -71,21 +71,16 @@ use crate::numbers::basenum::Number;
|
|||||||
use crate::numbers::realnum::RealNumber;
|
use crate::numbers::realnum::RealNumber;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone, Eq, PartialEq)]
|
#[derive(Debug, Clone, Eq, PartialEq, Default)]
|
||||||
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
|
||||||
pub enum RidgeRegressionSolverName {
|
pub enum RidgeRegressionSolverName {
|
||||||
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
|
||||||
|
#[default]
|
||||||
Cholesky,
|
Cholesky,
|
||||||
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
|
||||||
SVD,
|
SVD,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RidgeRegressionSolverName {
|
|
||||||
fn default() -> Self {
|
|
||||||
RidgeRegressionSolverName::Cholesky
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ridge Regression parameters
|
/// Ridge Regression parameters
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
|||||||
@@ -49,20 +49,15 @@ pub type KNNAlgorithmName = crate::algorithm::neighbour::KNNAlgorithmName;
|
|||||||
|
|
||||||
/// Weight function that is used to determine estimated value.
|
/// Weight function that is used to determine estimated value.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub enum KNNWeightFunction {
|
pub enum KNNWeightFunction {
|
||||||
/// All k nearest points are weighted equally
|
/// All k nearest points are weighted equally
|
||||||
|
#[default]
|
||||||
Uniform,
|
Uniform,
|
||||||
/// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
|
/// k nearest points are weighted by the inverse of their distance. Closer neighbors will have a greater influence than neighbors which are further away.
|
||||||
Distance,
|
Distance,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for KNNWeightFunction {
|
|
||||||
fn default() -> Self {
|
|
||||||
KNNWeightFunction::Uniform
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl KNNWeightFunction {
|
impl KNNWeightFunction {
|
||||||
fn calc_weights(&self, distances: Vec<f64>) -> std::vec::Vec<f64> {
|
fn calc_weights(&self, distances: Vec<f64>) -> std::vec::Vec<f64> {
|
||||||
match *self {
|
match *self {
|
||||||
|
|||||||
+26
-3
@@ -2,9 +2,13 @@
|
|||||||
//! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
|
//! Most algorithms in `smartcore` rely on basic linear algebra operations like dot product, matrix decomposition and other subroutines that are defined for a set of real numbers, ℝ.
|
||||||
//! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
|
//! This module defines real number and some useful functions that are used in [Linear Algebra](../../linalg/index.html) module.
|
||||||
|
|
||||||
|
use rand::rngs::SmallRng;
|
||||||
|
use rand::{Rng, SeedableRng};
|
||||||
|
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
|
||||||
use crate::numbers::basenum::Number;
|
use crate::numbers::basenum::Number;
|
||||||
|
use crate::rand_custom::get_rng_impl;
|
||||||
|
|
||||||
/// Defines real number
|
/// Defines real number
|
||||||
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
|
||||||
@@ -63,8 +67,12 @@ impl RealNumber for f64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f64 {
|
fn rand() -> f64 {
|
||||||
// TODO: to be implemented, see issue smartcore#214
|
let mut small_rng = get_rng_impl(None);
|
||||||
1.0
|
|
||||||
|
let mut rngs: Vec<SmallRng> = (0..3)
|
||||||
|
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
|
||||||
|
.collect();
|
||||||
|
rngs[0].gen::<f64>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn two() -> Self {
|
fn two() -> Self {
|
||||||
@@ -108,7 +116,12 @@ impl RealNumber for f32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn rand() -> f32 {
|
fn rand() -> f32 {
|
||||||
1.0
|
let mut small_rng = get_rng_impl(None);
|
||||||
|
|
||||||
|
let mut rngs: Vec<SmallRng> = (0..3)
|
||||||
|
.map(|_| SmallRng::from_rng(&mut small_rng).unwrap())
|
||||||
|
.collect();
|
||||||
|
rngs[0].gen::<f32>()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn two() -> Self {
|
fn two() -> Self {
|
||||||
@@ -149,4 +162,14 @@ mod tests {
|
|||||||
fn f64_from_string() {
|
fn f64_from_string() {
|
||||||
assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
|
assert_eq!(f64::from_str("1.111111111").unwrap(), 1.111111111)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn f64_rand() {
|
||||||
|
f64::rand();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn f32_rand() {
|
||||||
|
f32::rand();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -137,16 +137,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
self.classes.as_ref()
|
self.classes.as_ref()
|
||||||
}
|
}
|
||||||
/// Get depth of tree
|
/// Get depth of tree
|
||||||
fn depth(&self) -> u16 {
|
pub fn depth(&self) -> u16 {
|
||||||
self.depth
|
self.depth
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The function to measure the quality of a split.
|
/// The function to measure the quality of a split.
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub enum SplitCriterion {
|
pub enum SplitCriterion {
|
||||||
/// [Gini index](../decision_tree_classifier/index.html)
|
/// [Gini index](../decision_tree_classifier/index.html)
|
||||||
|
#[default]
|
||||||
Gini,
|
Gini,
|
||||||
/// [Entropy](../decision_tree_classifier/index.html)
|
/// [Entropy](../decision_tree_classifier/index.html)
|
||||||
Entropy,
|
Entropy,
|
||||||
@@ -154,12 +155,6 @@ pub enum SplitCriterion {
|
|||||||
ClassificationError,
|
ClassificationError,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SplitCriterion {
|
|
||||||
fn default() -> Self {
|
|
||||||
SplitCriterion::Gini
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Node {
|
struct Node {
|
||||||
@@ -543,6 +538,10 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
parameters: DecisionTreeClassifierParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
|
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
|
if x_nrows != y.shape() {
|
||||||
|
return Err(Failed::fit("Size of x should equal size of y"));
|
||||||
|
}
|
||||||
|
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
}
|
}
|
||||||
@@ -968,6 +967,17 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_random_matrix_with_wrong_rownum() {
|
||||||
|
let x_rand: DenseMatrix<f64> = DenseMatrix::<f64>::rand(21, 200);
|
||||||
|
|
||||||
|
let y: Vec<u32> = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
|
||||||
|
|
||||||
|
let fail = DecisionTreeClassifier::fit(&x_rand, &y, Default::default());
|
||||||
|
|
||||||
|
assert!(fail.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg_attr(
|
#[cfg_attr(
|
||||||
all(target_arch = "wasm32", not(target_os = "wasi")),
|
all(target_arch = "wasm32", not(target_os = "wasi")),
|
||||||
wasm_bindgen_test::wasm_bindgen_test
|
wasm_bindgen_test::wasm_bindgen_test
|
||||||
|
|||||||
@@ -18,7 +18,6 @@
|
|||||||
//! Example:
|
//! Example:
|
||||||
//!
|
//!
|
||||||
//! ```
|
//! ```
|
||||||
//! use rand::thread_rng;
|
|
||||||
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
//! use smartcore::linalg::basic::matrix::DenseMatrix;
|
||||||
//! use smartcore::tree::decision_tree_regressor::*;
|
//! use smartcore::tree::decision_tree_regressor::*;
|
||||||
//!
|
//!
|
||||||
@@ -422,6 +421,10 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
parameters: DecisionTreeRegressorParameters,
|
parameters: DecisionTreeRegressorParameters,
|
||||||
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
|
if x_nrows != y.shape() {
|
||||||
|
return Err(Failed::fit("Size of x should equal size of y"));
|
||||||
|
}
|
||||||
|
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user