Skip to content

Commit fe6b3e0

Browse files
committed
Create default (::Node, ...)() and adjoint methods
1 parent 7835abf commit fe6b3e0

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

src/DynamicExpressions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include("Equation.jl")
66
include("EquationUtils.jl")
77
include("EvaluateEquation.jl")
88
include("EvaluateEquationDerivative.jl")
9+
include("EvaluationHelpers.jl")
910
include("InterfaceSymbolicUtils.jl")
1011
include("SimplifyEquation.jl")
1112
include("OperatorEnumConstruction.jl")
@@ -31,6 +32,7 @@ using Reexport
3132
eval_diff_tree_array, eval_grad_tree_array
3233
@reexport import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node
3334
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
35+
@reexport import .EvaluationHelpersModule
3436

3537
import TOML: parsefile
3638

src/EvaluationHelpers.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
module EvaluationHelpersModule
2+
3+
import Base: adjoint
4+
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
5+
import ..EquationModule: Node
6+
import ..EvaluateEquationModule: eval_tree_array
7+
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
8+
9+
# Evaluation:
10+
function (tree::Node)(X, operators::OperatorEnum; kws...)
11+
out, did_finish = eval_tree_array(tree, X, operators; kws...)
12+
!did_finish && (out .= convert(eltype(out), NaN))
13+
return out
14+
end
15+
function (tree::Node)(X, operators::GenericOperatorEnum; kws...)
16+
out, did_finish = eval_tree_array(tree, X, operators; kws...)
17+
!did_finish && return nothing
18+
return out
19+
end
20+
function (tree::Node)(X; kws...)
21+
@error "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead."
22+
end
23+
24+
# Gradients:
25+
function _grad_evaluator(tree::Node, X, operators::OperatorEnum; kws...)
26+
_, grad, did_complete = eval_grad_tree_array(tree, X, operators; variable=true, kws...)
27+
!did_complete && (grad .= convert(eltype(grad), NaN))
28+
return grad
29+
end
30+
function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...)
31+
@error "Gradients are not implemented for `GenericOperatorEnum`."
32+
end
33+
function _grad_evaluator(tree::Node, X; kws...)
34+
@error "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead."
35+
end
36+
Base.adjoint(tree::Node) = ((args...; kws...) -> _grad_evaluator(tree, args...; kws...))
37+
38+
end

0 commit comments

Comments
 (0)