Skip to content

Commit 09b7a3d

Browse files
committed
feat: also include early_exit in scalar checks
1 parent 3d7b529 commit 09b7a3d

File tree

1 file changed

+33
-38
lines changed

1 file changed

+33
-38
lines changed

src/Evaluate.jl

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ import ..ValueInterfaceModule: is_valid, is_valid_array
1212

1313
const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
1414

15-
macro return_on_check(val, X)
15+
macro return_on_nonfinite_val(eval_options, val, X)
1616
:(
17-
if !is_valid($(esc(val)))
17+
if $(esc(eval_options)).early_exit isa Val{true} && !is_valid($(esc(val)))
1818
return $(ResultOk)(similar($(esc(X)), axes($(esc(X)), 2)), false)
1919
end
2020
)
2121
end
2222

23-
macro return_on_nonfinite_array(array)
23+
macro return_on_nonfinite_array(eval_options, array)
2424
:(
25-
if !is_valid_array($(esc(array)))
25+
if $(esc(eval_options)).early_exit isa Val{true} && !is_valid_array($(esc(array)))
2626
return $(ResultOk)($(esc(array)), false)
2727
end
2828
)
@@ -257,10 +257,10 @@ end
257257
return quote
258258
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
259259
!result_l.ok && return result_l
260-
eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_l.x
260+
@return_on_nonfinite_array(eval_options, result_l.x)
261261
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
262262
!result_r.ok && return result_r
263-
eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result_r.x
263+
@return_on_nonfinite_array(eval_options, result_r.x)
264264
# op(x, y), for any x or y
265265
deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options)
266266
end
@@ -275,26 +275,22 @@ end
275275
elseif tree.r.degree == 0
276276
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
277277
!result_l.ok && return result_l
278-
eval_options.early_exit isa Val{true} &&
279-
@return_on_nonfinite_array result_l.x
278+
@return_on_nonfinite_array(eval_options, result_l.x)
280279
# op(x, y), where y is a constant or variable but x is not.
281280
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
282281
elseif tree.l.degree == 0
283282
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
284283
!result_r.ok && return result_r
285-
eval_options.early_exit isa Val{true} &&
286-
@return_on_nonfinite_array result_r.x
284+
@return_on_nonfinite_array(eval_options, result_r.x)
287285
# op(x, y), where x is a constant or variable but y is not.
288286
deg2_l0_eval(tree, result_r.x, cX, op, eval_options)
289287
else
290288
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
291289
!result_l.ok && return result_l
292-
eval_options.early_exit isa Val{true} &&
293-
@return_on_nonfinite_array result_l.x
290+
@return_on_nonfinite_array(eval_options, result_l.x)
294291
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
295292
!result_r.ok && return result_r
296-
eval_options.early_exit isa Val{true} &&
297-
@return_on_nonfinite_array result_r.x
293+
@return_on_nonfinite_array(eval_options, result_r.x)
298294
# op(x, y), for any x or y
299295
deg2_eval(result_l.x, result_r.x, op, eval_options)
300296
end
@@ -315,7 +311,7 @@ end
315311
return quote
316312
result = _eval_tree_array(tree.l, cX, operators, eval_options)
317313
!result.ok && return result
318-
eval_options.early_exit isa Val{true} && @return_on_nonfinite_array result.x
314+
@return_on_nonfinite_array(eval_options, result.x)
319315
deg1_eval(result.x, operators.unaops[op_idx], eval_options)
320316
end
321317
end
@@ -342,8 +338,7 @@ end
342338
# op(x), for any x.
343339
result = _eval_tree_array(tree.l, cX, operators, eval_options)
344340
!result.ok && return result
345-
eval_options.early_exit isa Val{true} &&
346-
@return_on_nonfinite_array result.x
341+
@return_on_nonfinite_array(eval_options, result.x)
347342
deg1_eval(result.x, op, eval_options)
348343
end
349344
end
@@ -396,21 +391,21 @@ function deg1_l2_ll0_lr0_eval(
396391
cX::AbstractMatrix{T},
397392
op::F,
398393
op_l::F2,
399-
::EvalOptions{false,false},
394+
eval_options::EvalOptions{false,false},
400395
) where {T,F,F2}
401396
if tree.l.l.constant && tree.l.r.constant
402397
val_ll = tree.l.l.val
403398
val_lr = tree.l.r.val
404-
@return_on_check val_ll cX
405-
@return_on_check val_lr cX
399+
@return_on_nonfinite_val(eval_options, val_ll, cX)
400+
@return_on_nonfinite_val(eval_options, val_lr, cX)
406401
x_l = op_l(val_ll, val_lr)::T
407-
@return_on_check x_l cX
402+
@return_on_nonfinite_val(eval_options, x_l, cX)
408403
x = op(x_l)::T
409-
@return_on_check x cX
404+
@return_on_nonfinite_val(eval_options, x, cX)
410405
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
411406
elseif tree.l.l.constant
412407
val_ll = tree.l.l.val
413-
@return_on_check val_ll cX
408+
@return_on_nonfinite_val(eval_options, val_ll, cX)
414409
feature_lr = tree.l.r.feature
415410
cumulator = similar(cX, axes(cX, 2))
416411
@inbounds @simd for j in axes(cX, 2)
@@ -422,7 +417,7 @@ function deg1_l2_ll0_lr0_eval(
422417
elseif tree.l.r.constant
423418
feature_ll = tree.l.l.feature
424419
val_lr = tree.l.r.val
425-
@return_on_check val_lr cX
420+
@return_on_nonfinite_val(eval_options, val_lr, cX)
426421
cumulator = similar(cX, axes(cX, 2))
427422
@inbounds @simd for j in axes(cX, 2)
428423
x_l = op_l(cX[feature_ll, j], val_lr)::T
@@ -449,15 +444,15 @@ function deg1_l1_ll0_eval(
449444
cX::AbstractMatrix{T},
450445
op::F,
451446
op_l::F2,
452-
::EvalOptions{false,false},
447+
eval_options::EvalOptions{false,false},
453448
) where {T,F,F2}
454449
if tree.l.l.constant
455450
val_ll = tree.l.l.val
456-
@return_on_check val_ll cX
451+
@return_on_nonfinite_val(eval_options, val_ll, cX)
457452
x_l = op_l(val_ll)::T
458-
@return_on_check x_l cX
453+
@return_on_nonfinite_val(eval_options, x_l, cX)
459454
x = op(x_l)::T
460-
@return_on_check x cX
455+
@return_on_nonfinite_val(eval_options, x, cX)
461456
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
462457
else
463458
feature_ll = tree.l.l.feature
@@ -476,20 +471,20 @@ function deg2_l0_r0_eval(
476471
tree::AbstractExpressionNode{T},
477472
cX::AbstractMatrix{T},
478473
op::F,
479-
::EvalOptions{false,false},
474+
eval_options::EvalOptions{false,false},
480475
) where {T,F}
481476
if tree.l.constant && tree.r.constant
482477
val_l = tree.l.val
483-
@return_on_check val_l cX
478+
@return_on_nonfinite_val(eval_options, val_l, cX)
484479
val_r = tree.r.val
485-
@return_on_check val_r cX
480+
@return_on_nonfinite_val(eval_options, val_r, cX)
486481
x = op(val_l, val_r)::T
487-
@return_on_check x cX
482+
@return_on_nonfinite_val(eval_options, x, cX)
488483
return ResultOk(fill_similar(x, cX, axes(cX, 2)), true)
489484
elseif tree.l.constant
490485
cumulator = similar(cX, axes(cX, 2))
491486
val_l = tree.l.val
492-
@return_on_check val_l cX
487+
@return_on_nonfinite_val(eval_options, val_l, cX)
493488
feature_r = tree.r.feature
494489
@inbounds @simd for j in axes(cX, 2)
495490
x = op(val_l, cX[feature_r, j])::T
@@ -500,7 +495,7 @@ function deg2_l0_r0_eval(
500495
cumulator = similar(cX, axes(cX, 2))
501496
feature_l = tree.l.feature
502497
val_r = tree.r.val
503-
@return_on_check val_r cX
498+
@return_on_nonfinite_val(eval_options, val_r, cX)
504499
@inbounds @simd for j in axes(cX, 2)
505500
x = op(cX[feature_l, j], val_r)::T
506501
cumulator[j] = x
@@ -524,11 +519,11 @@ function deg2_l0_eval(
524519
cumulator::AbstractVector{T},
525520
cX::AbstractArray{T},
526521
op::F,
527-
::EvalOptions{false,false},
522+
eval_options::EvalOptions{false,false},
528523
) where {T,F}
529524
if tree.l.constant
530525
val = tree.l.val
531-
@return_on_check val cX
526+
@return_on_nonfinite_val(eval_options, val, cX)
532527
@inbounds @simd for j in eachindex(cumulator)
533528
x = op(val, cumulator[j])::T
534529
cumulator[j] = x
@@ -550,11 +545,11 @@ function deg2_r0_eval(
550545
cumulator::AbstractVector{T},
551546
cX::AbstractArray{T},
552547
op::F,
553-
::EvalOptions{false,false},
548+
eval_options::EvalOptions{false,false},
554549
) where {T,F}
555550
if tree.r.constant
556551
val = tree.r.val
557-
@return_on_check val cX
552+
@return_on_nonfinite_val(eval_options, val, cX)
558553
@inbounds @simd for j in eachindex(cumulator)
559554
x = op(cumulator[j], val)::T
560555
cumulator[j] = x

0 commit comments

Comments
 (0)