Release 0.3 (#235)

This commit is contained in:
Lorenzo
2022-11-08 15:22:34 +00:00
committed by GitHub
parent aab3817c58
commit 161d249917
30 changed files with 133 additions and 103 deletions
+5 -8
View File
@@ -163,7 +163,6 @@ impl Default for SplitCriterion {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
struct Node {
index: usize,
output: usize,
split_feature: usize,
split_value: Option<f64>,
@@ -406,9 +405,8 @@ impl Default for DecisionTreeClassifierSearchParameters {
}
impl Node {
fn new(index: usize, output: usize) -> Self {
fn new(output: usize) -> Self {
Node {
index,
output,
split_feature: 0,
split_value: Option::None,
@@ -582,7 +580,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
count[yi[i]] += samples[i];
}
let root = Node::new(0, which_max(&count));
let root = Node::new(which_max(&count));
change_nodes.push(root);
let mut order: Vec<Vec<usize>> = Vec::new();
@@ -831,11 +829,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
let true_child_idx = self.nodes().len();
self.nodes
.push(Node::new(true_child_idx, visitor.true_child_output));
self.nodes.push(Node::new(visitor.true_child_output));
let false_child_idx = self.nodes().len();
self.nodes
.push(Node::new(false_child_idx, visitor.false_child_output));
self.nodes.push(Node::new(visitor.false_child_output));
self.nodes[visitor.node].true_child = Some(true_child_idx);
self.nodes[visitor.node].false_child = Some(false_child_idx);
@@ -923,6 +919,7 @@ mod tests {
wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
#[cfg(feature = "datasets")]
fn fit_predict_iris() {
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],