|
| 1 | +module EvaluationHelpersModule |
| 2 | + |
| 3 | +import Base: adjoint |
| 4 | +import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum |
| 5 | +import ..EquationModule: Node |
| 6 | +import ..EvaluateEquationModule: eval_tree_array |
| 7 | +import ..EvaluateEquationDerivativeModule: eval_grad_tree_array |
| 8 | + |
| 9 | +# Evaluation: |
| 10 | +function (tree::Node)(X, operators::OperatorEnum; kws...) |
| 11 | + out, did_finish = eval_tree_array(tree, X, operators; kws...) |
| 12 | + !did_finish && (out .= convert(eltype(out), NaN)) |
| 13 | + return out |
| 14 | +end |
| 15 | +function (tree::Node)(X, operators::GenericOperatorEnum; kws...) |
| 16 | + out, did_finish = eval_tree_array(tree, X, operators; kws...) |
| 17 | + !did_finish && return nothing |
| 18 | + return out |
| 19 | +end |
| 20 | +function (tree::Node)(X; kws...) |
| 21 | + @error "The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." |
| 22 | +end |
| 23 | + |
| 24 | +# Gradients: |
| 25 | +function _grad_evaluator(tree::Node, X, operators::OperatorEnum; kws...) |
| 26 | + _, grad, did_complete = eval_grad_tree_array(tree, X, operators; variable=true, kws...) |
| 27 | + !did_complete && (grad .= convert(eltype(grad), NaN)) |
| 28 | + return grad |
| 29 | +end |
| 30 | +function _grad_evaluator(tree::Node, X, operators::GenericOperatorEnum; kws...) |
| 31 | + @error "Gradients are not implemented for `GenericOperatorEnum`." |
| 32 | +end |
| 33 | +function _grad_evaluator(tree::Node, X; kws...) |
| 34 | + @error "The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." |
| 35 | +end |
| 36 | +Base.adjoint(tree::Node) = ((args...; kws...) -> _grad_evaluator(tree, args...; kws...)) |
| 37 | + |
| 38 | +end |
0 commit comments