fix: minor refactoring
This commit is contained in:
@@ -3,7 +3,7 @@ extern crate rand;
|
||||
use rand::Rng;
|
||||
use std::default::Default;
|
||||
use crate::linalg::Matrix;
|
||||
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
|
||||
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RandomForestParameters {
|
||||
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
|
||||
#[derive(Debug)]
|
||||
pub struct RandomForest {
|
||||
parameters: RandomForestParameters,
|
||||
trees: Vec<DecisionTree>,
|
||||
trees: Vec<DecisionTreeClassifier>,
|
||||
classes: Vec<f64>
|
||||
}
|
||||
|
||||
@@ -51,16 +51,16 @@ impl RandomForest {
|
||||
|
||||
let classes = y_m.unique();
|
||||
let k = classes.len();
|
||||
let mut trees: Vec<DecisionTree> = Vec::new();
|
||||
let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
|
||||
|
||||
for _ in 0..parameters.n_trees {
|
||||
let samples = RandomForest::sample_with_replacement(&yi, k);
|
||||
let params = DecisionTreeParameters{
|
||||
let params = DecisionTreeClassifierParameters{
|
||||
criterion: parameters.criterion.clone(),
|
||||
max_depth: parameters.max_depth,
|
||||
min_samples_leaf: parameters.min_samples_leaf
|
||||
};
|
||||
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
|
||||
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
||||
trees.push(tree);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,16 +4,16 @@ use crate::linalg::Matrix;
|
||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTreeParameters {
|
||||
pub struct DecisionTreeClassifierParameters {
|
||||
pub criterion: SplitCriterion,
|
||||
pub max_depth: Option<u16>,
|
||||
pub min_samples_leaf: u16
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionTree {
|
||||
pub struct DecisionTreeClassifier {
|
||||
nodes: Vec<Node>,
|
||||
parameters: DecisionTreeParameters,
|
||||
parameters: DecisionTreeClassifierParameters,
|
||||
num_classes: usize,
|
||||
classes: Vec<f64>,
|
||||
depth: u16
|
||||
@@ -38,9 +38,9 @@ pub struct Node {
|
||||
}
|
||||
|
||||
|
||||
impl Default for DecisionTreeParameters {
|
||||
impl Default for DecisionTreeClassifierParameters {
|
||||
fn default() -> Self {
|
||||
DecisionTreeParameters {
|
||||
DecisionTreeClassifierParameters {
|
||||
criterion: SplitCriterion::Gini,
|
||||
max_depth: None,
|
||||
min_samples_leaf: 1
|
||||
@@ -139,15 +139,15 @@ pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
|
||||
return which;
|
||||
}
|
||||
|
||||
impl DecisionTree {
|
||||
impl DecisionTreeClassifier {
|
||||
|
||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
let (x_nrows, num_attributes) = x.shape();
|
||||
let samples = vec![1; x_nrows];
|
||||
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||
}
|
||||
|
||||
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
|
||||
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||
let y_m = M::from_row_vector(y.clone());
|
||||
let (_, y_ncols) = y_m.shape();
|
||||
let (_, num_attributes) = x.shape();
|
||||
@@ -179,7 +179,7 @@ impl DecisionTree {
|
||||
order.push(x.get_col_as_vec(i).quick_argsort());
|
||||
}
|
||||
|
||||
let mut tree = DecisionTree{
|
||||
let mut tree = DecisionTreeClassifier{
|
||||
nodes: nodes,
|
||||
parameters: parameters,
|
||||
num_classes: k,
|
||||
@@ -435,9 +435,9 @@ mod tests {
|
||||
&[5.2, 2.7, 3.9, 1.4]]);
|
||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||
|
||||
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
|
||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||
|
||||
assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
|
||||
assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
|
||||
|
||||
}
|
||||
|
||||
@@ -467,7 +467,7 @@ mod tests {
|
||||
&[0.,0.,0.,1.]]);
|
||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
||||
|
||||
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
|
||||
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||
|
||||
}
|
||||
}
|
||||
+1
-1
@@ -1,2 +1,2 @@
|
||||
pub mod decision_tree_regressor;
|
||||
pub mod decision_tree;
|
||||
pub mod decision_tree_classifier;
|
||||
Reference in New Issue
Block a user