Skip to content

Commit a6e6dee

Browse files
committed
Avoid overwriting methods
1 parent c84014f commit a6e6dee

File tree

2 files changed

+133
-75
lines changed

2 files changed

+133
-75
lines changed

src/EvaluationHelpers.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,6 @@ function (tree::Node)(X, operators::GenericOperatorEnum; kws...)
5454
!did_finish && return nothing
5555
return out
5656
end
57-
function (tree::Node)(X; kws...)
58-
## This will be overwritten by OperatorEnumConstructionModule, and turned
59-
## into a depwarn.
60-
return error(
61-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
62-
)
63-
end
6457

6558
# Gradients:
6659
function _grad_evaluator(tree::Node, X, operators::OperatorEnum; variable=true, kws...)
@@ -73,13 +66,6 @@ end
7366
function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...)
7467
return error("Gradients are not implemented for `GenericOperatorEnum`.")
7568
end
76-
function _grad_evaluator(tree::Node, X; kws...)
77-
## This will be overwritten by OperatorEnumConstructionModule, and turned
78-
## into a depwarn
79-
return error(
80-
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
81-
)
82-
end
8369

8470
"""
8571
(tree::Node{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)

src/OperatorEnumConstruction.jl

Lines changed: 133 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,82 +6,131 @@ import ..EvaluateEquationModule: eval_tree_array
66
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
77
import ..EvaluationHelpersModule: _grad_evaluator
88

9-
function create_evaluation_helpers!(operators::OperatorEnum)
10-
@eval begin
11-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
12-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
13-
function (tree::Node)(X; kws...)
14-
Base.depwarn(
15-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
16-
:Node,
17-
)
18-
return tree(X, $operators; kws...)
19-
end
20-
# Gradients:
21-
function _grad_evaluator(tree::Node, X; kws...)
22-
Base.depwarn(
23-
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
24-
:Node,
25-
)
26-
return _grad_evaluator(tree, X, $operators; kws...)
27-
end
9+
"""Used to set a default value for `operators` for ease of use."""
10+
@enum AvailableOperatorTypes begin
11+
IsNothing
12+
IsOperatorEnum
13+
IsGenericOperatorEnum
14+
end
15+
16+
# These constants are purely for convenience. Internal code
17+
# should make use of `Node`, `string_tree`, `eval_tree_array`,
18+
# and `eval_grad_tree_array` directly.
19+
20+
const LATEST_OPERATORS = Ref{Union{Nothing,AbstractOperatorEnum}}(nothing)
21+
const LATEST_OPERATORS_TYPE = Ref{AvailableOperatorTypes}(IsNothing)
22+
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,Int}()
23+
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,Int}()
24+
const ALREADY_DEFINED_UNARY_OPERATORS = Dict{Function,Bool}()
25+
const ALREADY_DEFINED_BINARY_OPERATORS = Dict{Function,Bool}()
26+
27+
function Base.show(io::IO, tree::Node)
28+
latest_operators_type = LATEST_OPERATORS_TYPE.x
29+
if latest_operators_type == IsNothing
30+
return print(io, string(tree))
31+
elseif latest_operators_type == IsOperatorEnum
32+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
33+
return print(io, string_tree(tree, latest_operators))
34+
else
35+
latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum
36+
return print(io, string_tree(tree, latest_operators))
2837
end
2938
end
39+
function (tree::Node)(X; kws...)
40+
Base.depwarn(
41+
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
42+
:Node,
43+
)
44+
latest_operators_type = LATEST_OPERATORS_TYPE.x
45+
if latest_operators_type == IsNothing
46+
error("Please use the `tree(X, operators; kws...)` syntax instead.")
47+
elseif latest_operators_type == IsOperatorEnum
48+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
49+
return tree(X, latest_operators; kws...)
50+
else
51+
latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum
52+
return tree(X, latest_operators; kws...)
53+
end
54+
end
55+
56+
function _grad_evaluator(tree::Node, X; kws...)
57+
Base.depwarn(
58+
"The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead.",
59+
:Node,
60+
)
61+
latest_operators_type = LATEST_OPERATORS_TYPE.x
62+
# return _grad_evaluator(tree, X, $operators; kws...)
63+
if latest_operators_type == IsNothing
64+
error("Please use the `tree'(X, operators; kws...)` syntax instead.")
65+
elseif latest_operators_type == IsOperatorEnum
66+
latest_operators = LATEST_OPERATORS.x::OperatorEnum
67+
return _grad_evaluator(tree, X, latest_operators; kws...)
68+
else
69+
error("Gradients are not implemented for `GenericOperatorEnum`.")
70+
end
71+
end
72+
73+
function create_evaluation_helpers!(operators::OperatorEnum)
74+
LATEST_OPERATORS.x = operators
75+
return LATEST_OPERATORS_TYPE.x = IsOperatorEnum
76+
end
3077

3178
function create_evaluation_helpers!(operators::GenericOperatorEnum)
32-
@eval begin
33-
Base.print(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
34-
Base.show(io::IO, tree::Node) = print(io, string_tree(tree, $operators))
35-
36-
function (tree::Node)(X; kws...)
37-
Base.depwarn(
38-
"The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead.",
39-
:Node,
40-
)
41-
return tree(X, $operators; kws...)
42-
end
43-
function _grad_evaluator(tree::Node, X; kws...)
44-
return error("Gradients are not implemented for `GenericOperatorEnum`.")
45-
end
79+
LATEST_OPERATORS.x = operators
80+
return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum
81+
end
82+
function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
83+
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
84+
if !haskey(mapping, f)
85+
error(
86+
"Convenience constructor for `Node` using operator `$(f)` is out-of-date. " *
87+
"Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
88+
"`define_helper_functions=true` and pass `$(f)`.",
89+
)
4690
end
91+
return mapping[f]
4792
end
4893

49-
function _extend_unary_operator(f::Symbol, op, type_requirements)
94+
function _extend_unary_operator(f::Symbol, type_requirements)
5095
quote
5196
quote
5297
function $($f)(l::Node{T})::Node{T} where {T<:$($type_requirements)}
5398
return if (l.degree == 0 && l.constant)
5499
Node(T; val=$($f)(l.val::T))
55100
else
56-
Node($($op), l)
101+
latest_op_idx = $($lookup_op)($($f), Val(1))
102+
Node(latest_op_idx, l)
57103
end
58104
end
59105
end
60106
end
61107
end
62108

63-
function _extend_binary_operator(f::Symbol, op, type_requirements, build_converters)
109+
function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
64110
quote
65111
quote
66112
function $($f)(l::Node{T}, r::Node{T}) where {T<:$($type_requirements)}
67113
if (l.degree == 0 && l.constant && r.degree == 0 && r.constant)
68114
Node(T; val=$($f)(l.val::T, r.val::T))
69115
else
70-
Node($($op), l, r)
116+
latest_op_idx = $($lookup_op)($($f), Val(2))
117+
Node(latest_op_idx, l, r)
71118
end
72119
end
73120
function $($f)(l::Node{T}, r::T) where {T<:$($type_requirements)}
74121
if l.degree == 0 && l.constant
75122
Node(T; val=$($f)(l.val::T, r))
76123
else
77-
Node($($op), l, Node(T; val=r))
124+
latest_op_idx = $($lookup_op)($($f), Val(2))
125+
Node(latest_op_idx, l, Node(T; val=r))
78126
end
79127
end
80128
function $($f)(l::T, r::Node{T}) where {T<:$($type_requirements)}
81129
if r.degree == 0 && r.constant
82130
Node(T; val=$($f)(l, r.val::T))
83131
else
84-
Node($($op), Node(T; val=l), r)
132+
latest_op_idx = $($lookup_op)($($f), Val(2))
133+
Node(latest_op_idx, Node(T; val=l), r)
85134
end
86135
end
87136
if $($build_converters)
@@ -116,8 +165,8 @@ function _extend_binary_operator(f::Symbol, op, type_requirements, build_convert
116165
end
117166

118167
function _extend_operators(operators, skip_user_operators, __module__::Module)
119-
binary_ex = _extend_binary_operator(:f, :op, :type_requirements, :build_converters)
120-
unary_ex = _extend_unary_operator(:f, :op, :type_requirements)
168+
binary_ex = _extend_binary_operator(:f, :type_requirements, :build_converters)
169+
unary_ex = _extend_unary_operator(:f, :type_requirements)
121170
return quote
122171
local type_requirements
123172
local build_converters
@@ -128,25 +177,44 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
128177
type_requirements = Any
129178
build_converters = false
130179
end
131-
for (op, f) in enumerate(map(Symbol, $(operators).binops))
180+
# Trigger errors if operators are not yet defined:
181+
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
182+
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
183+
for (op, func) in enumerate($(operators).binops)
184+
local f = Symbol(func)
185+
local skip = false
132186
if isdefined(Base, f)
133187
f = :(Base.$(f))
134188
elseif $(skip_user_operators)
135-
continue
189+
skip = true
136190
else
137191
f = :($($__module__).$(f))
138192
end
139-
eval($binary_ex)
193+
$(LATEST_BINARY_OPERATOR_MAPPING)[func] = op
194+
skip && continue
195+
# Avoid redefining methods:
196+
if !haskey($(ALREADY_DEFINED_UNARY_OPERATORS), func)
197+
eval($binary_ex)
198+
$(ALREADY_DEFINED_UNARY_OPERATORS)[func] = true
199+
end
140200
end
141-
for (op, f) in enumerate(map(Symbol, $(operators).unaops))
201+
for (op, func) in enumerate($(operators).unaops)
202+
local f = Symbol(func)
203+
local skip = false
142204
if isdefined(Base, f)
143205
f = :(Base.$(f))
144206
elseif $(skip_user_operators)
145-
continue
207+
skip = true
146208
else
147209
f = :($($__module__).$(f))
148210
end
149-
eval($unary_ex)
211+
$(LATEST_UNARY_OPERATOR_MAPPING)[func] = op
212+
skip && continue
213+
# Avoid redefining methods:
214+
if !haskey($(ALREADY_DEFINED_BINARY_OPERATORS), func)
215+
eval($unary_ex)
216+
$(ALREADY_DEFINED_BINARY_OPERATORS)[func] = true
217+
end
150218
end
151219
end
152220
end
@@ -162,14 +230,16 @@ apply this macro to the operator enum in the same module you have the operators
162230
defined.
163231
"""
164232
macro extend_operators(operators)
165-
ex = _extend_operators(esc(operators), false, __module__)
233+
ex = _extend_operators(operators, false, __module__)
166234
expected_type = AbstractOperatorEnum
167-
quote
168-
if !isa($(esc(operators)), $expected_type)
169-
error("You must pass an operator enum to `@extend_operators`.")
170-
end
171-
$ex
172-
end
235+
return esc(
236+
quote
237+
if !isa($(operators), $expected_type)
238+
error("You must pass an operator enum to `@extend_operators`.")
239+
end
240+
$ex
241+
end,
242+
)
173243
end
174244

175245
"""
@@ -179,14 +249,16 @@ Similar to `@extend_operators`, but only extends operators already
179249
defined in `Base`.
180250
"""
181251
macro extend_operators_base(operators)
182-
ex = _extend_operators(esc(operators), true, __module__)
252+
ex = _extend_operators(operators, true, __module__)
183253
expected_type = AbstractOperatorEnum
184-
quote
185-
if !isa($(esc(operators)), $expected_type)
186-
error("You must pass an operator enum to `@extend_operators_base`.")
187-
end
188-
$ex
189-
end
254+
return esc(
255+
quote
256+
if !isa($(operators), $expected_type)
257+
error("You must pass an operator enum to `@extend_operators_base`.")
258+
end
259+
$ex
260+
end,
261+
)
190262
end
191263

192264
"""

0 commit comments

Comments
 (0)