@@ -6,82 +6,131 @@ import ..EvaluateEquationModule: eval_tree_array
66import .. EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
77import .. EvaluationHelpersModule: _grad_evaluator
88
9- function create_evaluation_helpers! (operators:: OperatorEnum )
10- @eval begin
11- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
12- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
13- function (tree:: Node )(X; kws... )
14- Base. depwarn (
15- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
16- :Node ,
17- )
18- return tree (X, $ operators; kws... )
19- end
20- # Gradients:
21- function _grad_evaluator (tree:: Node , X; kws... )
22- Base. depwarn (
23- " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
24- :Node ,
25- )
26- return _grad_evaluator (tree, X, $ operators; kws... )
27- end
9+ """ Used to set a default value for `operators` for ease of use."""
10+ @enum AvailableOperatorTypes begin
11+ IsNothing
12+ IsOperatorEnum
13+ IsGenericOperatorEnum
14+ end
15+
16+ # These constants are purely for convenience. Internal code
17+ # should make use of `Node`, `string_tree`, `eval_tree_array`,
18+ # and `eval_grad_tree_array` directly.
19+
20+ const LATEST_OPERATORS = Ref {Union{Nothing,AbstractOperatorEnum}} (nothing )
21+ const LATEST_OPERATORS_TYPE = Ref {AvailableOperatorTypes} (IsNothing)
22+ const LATEST_UNARY_OPERATOR_MAPPING = Dict {Function,Int} ()
23+ const LATEST_BINARY_OPERATOR_MAPPING = Dict {Function,Int} ()
24+ const ALREADY_DEFINED_UNARY_OPERATORS = Dict {Function,Bool} ()
25+ const ALREADY_DEFINED_BINARY_OPERATORS = Dict {Function,Bool} ()
26+
27+ function Base. show (io:: IO , tree:: Node )
28+ latest_operators_type = LATEST_OPERATORS_TYPE. x
29+ if latest_operators_type == IsNothing
30+ return print (io, string (tree))
31+ elseif latest_operators_type == IsOperatorEnum
32+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
33+ return print (io, string_tree (tree, latest_operators))
34+ else
35+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
36+ return print (io, string_tree (tree, latest_operators))
2837 end
2938end
39+ function (tree:: Node )(X; kws... )
40+ Base. depwarn (
41+ " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
42+ :Node ,
43+ )
44+ latest_operators_type = LATEST_OPERATORS_TYPE. x
45+ if latest_operators_type == IsNothing
46+ error (" Please use the `tree(X, operators; kws...)` syntax instead." )
47+ elseif latest_operators_type == IsOperatorEnum
48+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
49+ return tree (X, latest_operators; kws... )
50+ else
51+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
52+ return tree (X, latest_operators; kws... )
53+ end
54+ end
55+
56+ function _grad_evaluator (tree:: Node , X; kws... )
57+ Base. depwarn (
58+ " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
59+ :Node ,
60+ )
61+ latest_operators_type = LATEST_OPERATORS_TYPE. x
62+ # return _grad_evaluator(tree, X, $operators; kws...)
63+ if latest_operators_type == IsNothing
64+ error (" Please use the `tree'(X, operators; kws...)` syntax instead." )
65+ elseif latest_operators_type == IsOperatorEnum
66+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
67+ return _grad_evaluator (tree, X, latest_operators; kws... )
68+ else
69+ error (" Gradients are not implemented for `GenericOperatorEnum`." )
70+ end
71+ end
72+
73+ function create_evaluation_helpers! (operators:: OperatorEnum )
74+ LATEST_OPERATORS. x = operators
75+ return LATEST_OPERATORS_TYPE. x = IsOperatorEnum
76+ end
3077
3178function create_evaluation_helpers! (operators:: GenericOperatorEnum )
32- @eval begin
33- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
34- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
35-
36- function (tree:: Node )(X; kws... )
37- Base. depwarn (
38- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
39- :Node ,
40- )
41- return tree (X, $ operators; kws... )
42- end
43- function _grad_evaluator (tree:: Node , X; kws... )
44- return error (" Gradients are not implemented for `GenericOperatorEnum`." )
45- end
79+ LATEST_OPERATORS. x = operators
80+ return LATEST_OPERATORS_TYPE. x = IsGenericOperatorEnum
81+ end
82+ function lookup_op (@nospecialize (f), :: Val{degree} ) where {degree}
83+ mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
84+ if ! haskey (mapping, f)
85+ error (
86+ " Convenience constructor for `Node` using operator `$(f) ` is out-of-date. " *
87+ " Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
88+ " `define_helper_functions=true` and pass `$(f) `." ,
89+ )
4690 end
91+ return mapping[f]
4792end
4893
49- function _extend_unary_operator (f:: Symbol , op, type_requirements)
94+ function _extend_unary_operator (f:: Symbol , type_requirements)
5095 quote
5196 quote
5297 function $ ($ f)(l:: Node{T} ):: Node{T} where {T<: $ ($ type_requirements)}
5398 return if (l. degree == 0 && l. constant)
5499 Node (T; val= $ ($ f)(l. val:: T ))
55100 else
56- Node ($ ($ op), l)
101+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (1 ))
102+ Node (latest_op_idx, l)
57103 end
58104 end
59105 end
60106 end
61107end
62108
63- function _extend_binary_operator (f:: Symbol , op, type_requirements, build_converters)
109+ function _extend_binary_operator (f:: Symbol , type_requirements, build_converters)
64110 quote
65111 quote
66112 function $ ($ f)(l:: Node{T} , r:: Node{T} ) where {T<: $ ($ type_requirements)}
67113 if (l. degree == 0 && l. constant && r. degree == 0 && r. constant)
68114 Node (T; val= $ ($ f)(l. val:: T , r. val:: T ))
69115 else
70- Node ($ ($ op), l, r)
116+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
117+ Node (latest_op_idx, l, r)
71118 end
72119 end
73120 function $ ($ f)(l:: Node{T} , r:: T ) where {T<: $ ($ type_requirements)}
74121 if l. degree == 0 && l. constant
75122 Node (T; val= $ ($ f)(l. val:: T , r))
76123 else
77- Node ($ ($ op), l, Node (T; val= r))
124+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
125+ Node (latest_op_idx, l, Node (T; val= r))
78126 end
79127 end
80128 function $ ($ f)(l:: T , r:: Node{T} ) where {T<: $ ($ type_requirements)}
81129 if r. degree == 0 && r. constant
82130 Node (T; val= $ ($ f)(l, r. val:: T ))
83131 else
84- Node ($ ($ op), Node (T; val= l), r)
132+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
133+ Node (latest_op_idx, Node (T; val= l), r)
85134 end
86135 end
87136 if $ ($ build_converters)
@@ -116,8 +165,8 @@ function _extend_binary_operator(f::Symbol, op, type_requirements, build_convert
116165end
117166
118167function _extend_operators (operators, skip_user_operators, __module__:: Module )
119- binary_ex = _extend_binary_operator (:f , :op , : type_requirements , :build_converters )
120- unary_ex = _extend_unary_operator (:f , :op , : type_requirements )
168+ binary_ex = _extend_binary_operator (:f , :type_requirements , :build_converters )
169+ unary_ex = _extend_unary_operator (:f , :type_requirements )
121170 return quote
122171 local type_requirements
123172 local build_converters
@@ -128,25 +177,44 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
128177 type_requirements = Any
129178 build_converters = false
130179 end
131- for (op, f) in enumerate (map (Symbol, $ (operators). binops))
180+ # Trigger errors if operators are not yet defined:
181+ empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
182+ empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
183+ for (op, func) in enumerate ($ (operators). binops)
184+ local f = Symbol (func)
185+ local skip = false
132186 if isdefined (Base, f)
133187 f = :(Base.$ (f))
134188 elseif $ (skip_user_operators)
135- continue
189+ skip = true
136190 else
137191 f = :($ ($ __module__). $ (f))
138192 end
139- eval ($ binary_ex)
193+ $ (LATEST_BINARY_OPERATOR_MAPPING)[func] = op
194+ skip && continue
195+ # Avoid redefining methods:
196+ if ! haskey ($ (ALREADY_DEFINED_UNARY_OPERATORS), func)
197+ eval ($ binary_ex)
198+ $ (ALREADY_DEFINED_UNARY_OPERATORS)[func] = true
199+ end
140200 end
141- for (op, f) in enumerate (map (Symbol, $ (operators). unaops))
201+ for (op, func) in enumerate ($ (operators). unaops)
202+ local f = Symbol (func)
203+ local skip = false
142204 if isdefined (Base, f)
143205 f = :(Base.$ (f))
144206 elseif $ (skip_user_operators)
145- continue
207+ skip = true
146208 else
147209 f = :($ ($ __module__). $ (f))
148210 end
149- eval ($ unary_ex)
211+ $ (LATEST_UNARY_OPERATOR_MAPPING)[func] = op
212+ skip && continue
213+ # Avoid redefining methods:
214+ if ! haskey ($ (ALREADY_DEFINED_BINARY_OPERATORS), func)
215+ eval ($ unary_ex)
216+ $ (ALREADY_DEFINED_BINARY_OPERATORS)[func] = true
217+ end
150218 end
151219 end
152220end
@@ -162,14 +230,16 @@ apply this macro to the operator enum in the same module you have the operators
162230defined.
163231"""
164232macro extend_operators (operators)
165- ex = _extend_operators (esc ( operators) , false , __module__)
233+ ex = _extend_operators (operators, false , __module__)
166234 expected_type = AbstractOperatorEnum
167- quote
168- if ! isa ($ (esc (operators)), $ expected_type)
169- error (" You must pass an operator enum to `@extend_operators`." )
170- end
171- $ ex
172- end
235+ return esc (
236+ quote
237+ if ! isa ($ (operators), $ expected_type)
238+ error (" You must pass an operator enum to `@extend_operators`." )
239+ end
240+ $ ex
241+ end ,
242+ )
173243end
174244
175245"""
@@ -179,14 +249,16 @@ Similar to `@extend_operators`, but only extends operators already
179249defined in `Base`.
180250"""
181251macro extend_operators_base (operators)
182- ex = _extend_operators (esc ( operators) , true , __module__)
252+ ex = _extend_operators (operators, true , __module__)
183253 expected_type = AbstractOperatorEnum
184- quote
185- if ! isa ($ (esc (operators)), $ expected_type)
186- error (" You must pass an operator enum to `@extend_operators_base`." )
187- end
188- $ ex
189- end
254+ return esc (
255+ quote
256+ if ! isa ($ (operators), $ expected_type)
257+ error (" You must pass an operator enum to `@extend_operators_base`." )
258+ end
259+ $ ex
260+ end ,
261+ )
190262end
191263
192264"""
0 commit comments