feat: adds train/test split function; fixes bug in random forest
This commit is contained in:
@@ -67,6 +67,7 @@ use std::default::Default;
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
@@ -431,6 +432,10 @@ impl<T: RealNumber> DecisionTreeClassifier<T> {
|
||||
variables[i] = i;
|
||||
}
|
||||
|
||||
if mtry < n_attr {
|
||||
variables.shuffle(&mut rand::thread_rng());
|
||||
}
|
||||
|
||||
for j in 0..mtry {
|
||||
self.find_best_split(
|
||||
visitor,
|
||||
|
||||
Reference in New Issue
Block a user