Improve options conditionals

This commit is contained in:
Lorenzo (Mec-iS)
2022-11-03 14:58:05 +00:00
parent ba70bb941f
commit b66afa9222
3 changed files with 6 additions and 9 deletions
+3 -6
View File
@@ -518,12 +518,9 @@ impl<TX: Number + FloatNumber + RealNumber, TY: Number + Ord, X: Array2<TX>, Y:
for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) { for (i, y_hat_i) in y_hat.iterator(0).enumerate().take(n) {
result.set( result.set(
i, i,
self.classes()[if RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half() self.classes()[usize::from(
{ RealNumber::sigmoid(*y_hat_i + intercept) > RealNumber::half(),
1 )],
} else {
0
}],
); );
} }
} else { } else {
+1 -1
View File
@@ -673,7 +673,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
let mut is_pure = true; let mut is_pure = true;
for i in 0..n_rows { for i in 0..n_rows {
if visitor.samples[i] > 0 { if visitor.samples[i] > 0 {
if label == Option::None { if label.is_none() {
label = Option::Some(visitor.y[i]); label = Option::Some(visitor.y[i]);
} else if visitor.y[i] != label.unwrap() { } else if visitor.y[i] != label.unwrap() {
is_pure = false; is_pure = false;
+2 -2
View File
@@ -511,7 +511,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
match queue.pop_front() { match queue.pop_front() {
Some(node_id) => { Some(node_id) => {
let node = &self.nodes()[node_id]; let node = &self.nodes()[node_id];
if node.true_child == None && node.false_child == None { if node.true_child.is_none() && node.false_child.is_none() {
result = node.output; result = node.output;
} else if x.get((row, node.split_feature)).to_f64().unwrap() } else if x.get((row, node.split_feature)).to_f64().unwrap()
<= node.split_value.unwrap_or(std::f64::NAN) <= node.split_value.unwrap_or(std::f64::NAN)
@@ -557,7 +557,7 @@ impl<TX: Number + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
self.find_best_split(visitor, n, sum, parent_gain, *variable); self.find_best_split(visitor, n, sum, parent_gain, *variable);
} }
self.nodes()[visitor.node].split_score != Option::None self.nodes()[visitor.node].split_score.is_some()
} }
fn find_best_split( fn find_best_split(