Skip to content

Commit 6e3db19

Browse files
committed
refactor: clean up use of is_valid throughout library
1 parent 455e80f commit 6e3db19

File tree

6 files changed

+28
-32
lines changed

6 files changed

+28
-32
lines changed

ext/DynamicExpressionsBumperExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module DynamicExpressionsBumperExt
22

33
using Bumper: @no_escape, @alloc
4-
using DynamicExpressions: OperatorEnum, AbstractExpressionNode, tree_mapreduce
5-
using DynamicExpressions.UtilsModule: ResultOk, counttuple, is_bad_array
4+
using DynamicExpressions:
5+
OperatorEnum, AbstractExpressionNode, tree_mapreduce, is_valid_array
6+
using DynamicExpressions.UtilsModule: ResultOk, counttuple
67

78
import DynamicExpressions.ExtensionInterfaceModule:
89
bumper_eval_tree_array, bumper_kern1!, bumper_kern2!
@@ -52,7 +53,7 @@ function dispatch_kerns!(operators, branch_node, cumulator, ::Val{turbo}) where
5253
cumulator.ok || return cumulator
5354

5455
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x, Val(turbo))
55-
return ResultOk(out, !is_bad_array(out))
56+
return ResultOk(out, is_valid_array(out))
5657
end
5758
function dispatch_kerns!(
5859
operators, branch_node, cumulator1, cumulator2, ::Val{turbo}
@@ -63,7 +64,7 @@ function dispatch_kerns!(
6364
out = dispatch_kern2!(
6465
operators.binops, branch_node.op, cumulator1.x, cumulator2.x, Val(turbo)
6566
)
66-
return ResultOk(out, !is_bad_array(out))
67+
return ResultOk(out, is_valid_array(out))
6768
end
6869

6970
@generated function dispatch_kern1!(unaops, op_idx, cumulator, ::Val{turbo}) where {turbo}

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ using SymbolicUtils
44
import DynamicExpressions.NodeModule:
55
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
66
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
7-
import DynamicExpressions.UtilsModule: isgood, isbad, deprecate_varmap
7+
import DynamicExpressions.ValueInterfaceModule: is_valid
8+
import DynamicExpressions.UtilsModule: deprecate_varmap
89
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
910
import DynamicExpressions: AbstractExpression, get_tree, get_operators
1011

@@ -19,14 +20,14 @@ macro return_on_false(flag, retval)
1920
)
2021
end
2122

22-
function isgood(x::SymbolicUtils.Symbolic)
23+
function is_valid(x::SymbolicUtils.Symbolic)
2324
return if SymbolicUtils.istree(x)
24-
all(isgood.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
25+
all(is_valid.([SymbolicUtils.operation(x); SymbolicUtils.arguments(x)]))
2526
else
2627
true
2728
end
2829
end
29-
subs_bad(x) = isgood(x) ? x : Inf
30+
subs_bad(x) = is_valid(x) ? x : Inf
3031

3132
function parse_tree_to_eqs(
3233
tree::AbstractExpressionNode{T},
@@ -197,7 +198,7 @@ function node_to_symbolic(
197198
variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic)
198199
expr = subs_bad(parse_tree_to_eqs(tree, operators, index_functions))
199200
# Check for NaN and Inf
200-
@assert isgood(expr) "The recovered equation contains NaN or Inf."
201+
@assert is_valid(expr) "The recovered equation contains NaN or Inf."
201202
# Return if no variable_names is given
202203
variable_names === nothing && return expr
203204
# Create a substitution tuple
@@ -248,12 +249,12 @@ function multiply_powers(
248249
if nargs == 1
249250
l, complete = multiply_powers(args[1])
250251
@return_on_false complete eqn
251-
@return_on_false isgood(l) eqn
252+
@return_on_false is_valid(l) eqn
252253
return op(l), true
253254
elseif op == ^
254255
l, complete = multiply_powers(args[1])
255256
@return_on_false complete eqn
256-
@return_on_false isgood(l) eqn
257+
@return_on_false is_valid(l) eqn
257258
n = args[2]
258259
if typeof(n) <: Integer
259260
if n == 1
@@ -275,23 +276,23 @@ function multiply_powers(
275276
elseif nargs == 2
276277
l, complete = multiply_powers(args[1])
277278
@return_on_false complete eqn
278-
@return_on_false isgood(l) eqn
279+
@return_on_false is_valid(l) eqn
279280
r, complete2 = multiply_powers(args[2])
280281
@return_on_false complete2 eqn
281-
@return_on_false isgood(r) eqn
282+
@return_on_false is_valid(r) eqn
282283
return op(l, r), true
283284
else
284285
# return tree_mapreduce(multiply_powers, op, args)
285286
# ## reduce(op, map(multiply_powers, args))
286287
out = map(multiply_powers, args) #vector of tuples
287288
for i in 1:size(out, 1)
288289
@return_on_false out[i][2] eqn
289-
@return_on_false isgood(out[i][1]) eqn
290+
@return_on_false is_valid(out[i][1]) eqn
290291
end
291292
cumulator = out[1][1]
292293
for i in 2:size(out, 1)
293294
cumulator = op(cumulator, out[i][1])
294-
@return_on_false isgood(cumulator) eqn
295+
@return_on_false is_valid(cumulator) eqn
295296
end
296297
return cumulator, true
297298
end

src/Evaluate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using DispatchDoctor: @unstable
55
import ..NodeModule: AbstractExpressionNode, constructorof
66
import ..StringsModule: string_tree
77
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
8-
import ..UtilsModule: is_bad_array, fill_similar, counttuple, ResultOk
8+
import ..UtilsModule: fill_similar, counttuple, ResultOk
99
import ..NodeUtilsModule: is_constant
1010
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
1111
import ..ValueInterfaceModule: is_valid, is_valid_array

src/EvaluateDerivative.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ module EvaluateDerivativeModule
22

33
import ..NodeModule: AbstractExpressionNode, constructorof
44
import ..OperatorEnumModule: OperatorEnum
5-
import ..UtilsModule: is_bad_array, fill_similar, ResultOk2
6-
import ..NodeUtilsModule: count_scalar_constants, index_constant_nodes, NodeIndex
5+
import ..UtilsModule: fill_similar, ResultOk2
6+
import ..ValueInterfaceModule: is_valid_array
7+
import ..NodeUtilsModule: count_constant_nodes, index_constant_nodes, NodeIndex
78
import ..EvaluateModule: deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
89
import ..ExtensionInterfaceModule: _zygote_gradient
910

@@ -105,7 +106,7 @@ end
105106
end
106107
!result.ok && return result
107108
return ResultOk2(
108-
result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))
109+
result.x, result.dx, is_valid_array(result.x) && is_valid_array(result.dx)
109110
)
110111
end
111112
end
@@ -213,9 +214,9 @@ function eval_grad_tree_array(
213214
n_gradients = if variable_mode
214215
size(cX, 1)::Int
215216
elseif constant_mode
216-
count_scalar_constants(tree)::Int
217+
count_constant_nodes(tree)::Int
217218
elseif both_mode
218-
size(cX, 1) + count_scalar_constants(tree)
219+
size(cX, 1) + count_constant_nodes(tree)
219220
end
220221

221222
result = if variable_mode
@@ -247,7 +248,7 @@ function eval_grad_tree_array(
247248
result = _eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(mode))
248249
!result.ok && return result
249250
return ResultOk2(
250-
result.x, result.dx, !(is_bad_array(result.x) || is_bad_array(result.dx))
251+
result.x, result.dx, is_valid_array(result.x) && is_valid_array(result.dx)
251252
)
252253
end
253254

src/Simplify.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module SimplifyModule
33
import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node!
44
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
55
import ..OperatorEnumModule: AbstractOperatorEnum
6-
import ..UtilsModule: isbad, isgood
6+
import ..ValueInterfaceModule: is_valid
77

88
_una_op_kernel(f::F, l::T) where {F,T} = f(l)
99
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
@@ -109,13 +109,13 @@ end
109109
function combine_children!(operators, p::N, c::N...) where {T,N<:AbstractExpressionNode{T}}
110110
all(is_node_constant, c) || return p
111111
vals = map(n -> n.val, c)
112-
all(isgood, vals) || return p
112+
all(is_valid, vals) || return p
113113
out = if length(c) == 1
114114
_una_op_kernel(operators.unaops[p.op], vals...)
115115
else
116116
_bin_op_kernel(operators.binops[p.op], vals...)
117117
end
118-
isgood(out) || return p
118+
is_valid(out) || return p
119119
new_node = constructorof(N)(T; val=convert(T, out))
120120
set_node!(p, new_node)
121121
return p

src/Utils.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ macro return_on_false2(flag, retval, retval2)
1212
)
1313
end
1414

15-
# Fastest way to check for NaN in an array.
16-
# (due to optimizations in sum())
17-
is_bad_array(array) = !(isempty(array) || isfinite(sum(array)))
18-
isgood(x::T) where {T<:Number} = !(isnan(x) || !isfinite(x))
19-
isgood(x) = true
20-
isbad(x) = !isgood(x)
21-
2215
"""
2316
@memoize_on tree [postprocess] function my_function_on_tree(tree::AbstractExpressionNode)
2417
...

0 commit comments

Comments
 (0)