fix: minor refactoring
This commit is contained in:
@@ -3,7 +3,7 @@ extern crate rand;
|
|||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::default::Default;
|
use std::default::Default;
|
||||||
use crate::linalg::Matrix;
|
use crate::linalg::Matrix;
|
||||||
use crate::tree::decision_tree::{DecisionTree, DecisionTreeParameters, SplitCriterion, which_max};
|
use crate::tree::decision_tree_classifier::{DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, which_max};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RandomForestParameters {
|
pub struct RandomForestParameters {
|
||||||
@@ -17,7 +17,7 @@ pub struct RandomForestParameters {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct RandomForest {
|
pub struct RandomForest {
|
||||||
parameters: RandomForestParameters,
|
parameters: RandomForestParameters,
|
||||||
trees: Vec<DecisionTree>,
|
trees: Vec<DecisionTreeClassifier>,
|
||||||
classes: Vec<f64>
|
classes: Vec<f64>
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,16 +51,16 @@ impl RandomForest {
|
|||||||
|
|
||||||
let classes = y_m.unique();
|
let classes = y_m.unique();
|
||||||
let k = classes.len();
|
let k = classes.len();
|
||||||
let mut trees: Vec<DecisionTree> = Vec::new();
|
let mut trees: Vec<DecisionTreeClassifier> = Vec::new();
|
||||||
|
|
||||||
for _ in 0..parameters.n_trees {
|
for _ in 0..parameters.n_trees {
|
||||||
let samples = RandomForest::sample_with_replacement(&yi, k);
|
let samples = RandomForest::sample_with_replacement(&yi, k);
|
||||||
let params = DecisionTreeParameters{
|
let params = DecisionTreeClassifierParameters{
|
||||||
criterion: parameters.criterion.clone(),
|
criterion: parameters.criterion.clone(),
|
||||||
max_depth: parameters.max_depth,
|
max_depth: parameters.max_depth,
|
||||||
min_samples_leaf: parameters.min_samples_leaf
|
min_samples_leaf: parameters.min_samples_leaf
|
||||||
};
|
};
|
||||||
let tree = DecisionTree::fit_weak_learner(x, y, samples, mtry, params);
|
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params);
|
||||||
trees.push(tree);
|
trees.push(tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,16 +4,16 @@ use crate::linalg::Matrix;
|
|||||||
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
use crate::algorithm::sort::quick_sort::QuickArgSort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct DecisionTreeParameters {
|
pub struct DecisionTreeClassifierParameters {
|
||||||
pub criterion: SplitCriterion,
|
pub criterion: SplitCriterion,
|
||||||
pub max_depth: Option<u16>,
|
pub max_depth: Option<u16>,
|
||||||
pub min_samples_leaf: u16
|
pub min_samples_leaf: u16
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct DecisionTree {
|
pub struct DecisionTreeClassifier {
|
||||||
nodes: Vec<Node>,
|
nodes: Vec<Node>,
|
||||||
parameters: DecisionTreeParameters,
|
parameters: DecisionTreeClassifierParameters,
|
||||||
num_classes: usize,
|
num_classes: usize,
|
||||||
classes: Vec<f64>,
|
classes: Vec<f64>,
|
||||||
depth: u16
|
depth: u16
|
||||||
@@ -38,9 +38,9 @@ pub struct Node {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Default for DecisionTreeParameters {
|
impl Default for DecisionTreeClassifierParameters {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
DecisionTreeParameters {
|
DecisionTreeClassifierParameters {
|
||||||
criterion: SplitCriterion::Gini,
|
criterion: SplitCriterion::Gini,
|
||||||
max_depth: None,
|
max_depth: None,
|
||||||
min_samples_leaf: 1
|
min_samples_leaf: 1
|
||||||
@@ -139,15 +139,15 @@ pub(in crate) fn which_max(x: &Vec<u32>) -> usize {
|
|||||||
return which;
|
return which;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecisionTree {
|
impl DecisionTreeClassifier {
|
||||||
|
|
||||||
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeParameters) -> DecisionTree {
|
pub fn fit<M: Matrix>(x: &M, y: &M::RowVector, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||||
let (x_nrows, num_attributes) = x.shape();
|
let (x_nrows, num_attributes) = x.shape();
|
||||||
let samples = vec![1; x_nrows];
|
let samples = vec![1; x_nrows];
|
||||||
DecisionTree::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeParameters) -> DecisionTree {
|
pub fn fit_weak_learner<M: Matrix>(x: &M, y: &M::RowVector, samples: Vec<u32>, mtry: usize, parameters: DecisionTreeClassifierParameters) -> DecisionTreeClassifier {
|
||||||
let y_m = M::from_row_vector(y.clone());
|
let y_m = M::from_row_vector(y.clone());
|
||||||
let (_, y_ncols) = y_m.shape();
|
let (_, y_ncols) = y_m.shape();
|
||||||
let (_, num_attributes) = x.shape();
|
let (_, num_attributes) = x.shape();
|
||||||
@@ -179,7 +179,7 @@ impl DecisionTree {
|
|||||||
order.push(x.get_col_as_vec(i).quick_argsort());
|
order.push(x.get_col_as_vec(i).quick_argsort());
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tree = DecisionTree{
|
let mut tree = DecisionTreeClassifier{
|
||||||
nodes: nodes,
|
nodes: nodes,
|
||||||
parameters: parameters,
|
parameters: parameters,
|
||||||
num_classes: k,
|
num_classes: k,
|
||||||
@@ -435,9 +435,9 @@ mod tests {
|
|||||||
&[5.2, 2.7, 3.9, 1.4]]);
|
&[5.2, 2.7, 3.9, 1.4]]);
|
||||||
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
let y = vec![0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.];
|
||||||
|
|
||||||
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
|
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||||
|
|
||||||
assert_eq!(3, DecisionTree::fit(&x, &y, DecisionTreeParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
|
assert_eq!(3, DecisionTreeClassifier::fit(&x, &y, DecisionTreeClassifierParameters{criterion: SplitCriterion::Entropy, max_depth: Some(3), min_samples_leaf: 1}).depth);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,7 +467,7 @@ mod tests {
|
|||||||
&[0.,0.,0.,1.]]);
|
&[0.,0.,0.,1.]]);
|
||||||
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
let y = vec![1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 0., 0., 0.];
|
||||||
|
|
||||||
assert_eq!(y, DecisionTree::fit(&x, &y, Default::default()).predict(&x));
|
assert_eq!(y, DecisionTreeClassifier::fit(&x, &y, Default::default()).predict(&x));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
+1
-1
@@ -1,2 +1,2 @@
|
|||||||
pub mod decision_tree_regressor;
|
pub mod decision_tree_regressor;
|
||||||
pub mod decision_tree;
|
pub mod decision_tree_classifier;
|
||||||
Reference in New Issue
Block a user