fix: minor refactoring

This commit is contained in:
Volodymyr Orlov
2020-03-20 15:58:10 -07:00
parent a96f303dea
commit 6577e22111
3 changed files with 19 additions and 19 deletions
+5 -5
View File
@@ -3,7 +3,7 @@ extern crate rand;
use rand::Rng;
use std::default::Default;
use crate::linalg::Matrix;
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
#[derive(Debug, Clone)]
pub struct RandomForestParameters {
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
#[derive(Debug)]
pub struct RandomForest {
parameters: RandomForestParameters,
trees: Vec<DecisionTree>,
trees: Vec<DecisionTreeClassifier>,
classes: Vec<f64>
}
@@ -51,16 +51,16 @@ impl RandomForest {
let classes = y_m.unique();
let k = classes.len();
let mut trees: Vec<DecisionTree> = Vec::new();
let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
for _ in 0..parameters.n_trees {
let samples = RandomForest::sample_with_replacement(&yi, k);
let params = DecisionTreeParameters{
let params = DecisionTreeClassifierParameters{
criterion: parameters.criterion.clone(),
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf
};
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
trees.push(tree);
}