fix: minor refactoring
This commit is contained in:
@@ -3,7 +3,7 @@ extern crate rand;
|
||||
use rand::Rng;
|
||||
use std::default::Default;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
|
||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestParameters {
|
||||
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForest {
|
||||
parameters: RandomForestParameters,
|
||||
trees: Vec<DecisionTree>,
|
||||
trees: Vec<DecisionTreeClassifier>,
|
||||
classes: Vec<f64>
|
||||
}
|
||||
|
||||
@@ -51,16 +51,16 @@ impl RandomForest {
|
||||
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTree> = Vec::new();
|
||||
let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForest::sample_with_replacement(&yi, k);
|
||||
let params = DecisionTreeParameters{
|
||||
let params = DecisionTreeClassifierParameters{
|
||||
criterion: parameters.criterion.clone(),
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf
|
||||
};
|
||||
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user