Improve options conditionals
This commit is contained in:
@@ -518,12 +518,9 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
|
|||||||
for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
|
for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
|
||||||
result.set(
|
result.set(
|
||||||
i,
|
i,
|
||||||
self.classes()[if RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half()
|
self.classes()[usize::from(
|
||||||
{
|
RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half(),
|
||||||
1
|
)],
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}],
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -673,7 +673,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
let mut is_pure = true;
|
let mut is_pure = true;
|
||||||
for i in 0..n_rows {
|
for i in 0..n_rows {
|
||||||
if visitor.samples[i] > 0 {
|
if visitor.samples[i] > 0 {
|
||||||
if label == Option::None {
|
if label.is_none() {
|
||||||
label = Option::Some(visitor.y[i]);
|
label = Option::Some(visitor.y[i]);
|
||||||
} else if visitor.y[i] != label.unwrap() {
|
} else if visitor.y[i] != label.unwrap() {
|
||||||
is_pure = false;
|
is_pure = false;
|
||||||
|
|||||||
@@ -511,7 +511,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
match queue.pop_front() {
|
match queue.pop_front() {
|
||||||
Some(node_id) => {
|
Some(node_id) => {
|
||||||
let node = &self.nodes()[node_id];
|
let node = &self.nodes()[node_id];
|
||||||
if node.true_child == None && node.false_child == None {
|
if node.true_child.is_none() && node.false_child.is_none() {
|
||||||
result = node.output;
|
result = node.output;
|
||||||
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
} else if x.get((row, node.split_feature)).to_f64().unwrap()
|
||||||
<= node.split_value.unwrap_or(std::f64::NAN)
|
<= node.split_value.unwrap_or(std::f64::NAN)
|
||||||
@@ -557,7 +557,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
|
|||||||
self.find_best_split(visitor, n, sum, parent_gain, *variable);
|
self.find_best_split(visitor, n, sum, parent_gain, *variable);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.nodes()[visitor.node].split_score != Option::None
|
self.nodes()[visitor.node].split_score.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_best_split(
|
fn find_best_split(
|
||||||
|
|||||||
Reference in New Issue
Block a user