Skip to content

Commit ae5a2b9

Browse files
committed
Move Zygote.jl to an extension
1 parent 6be88b7 commit ae5a2b9

File tree

6 files changed

+37
-41
lines changed

6 files changed

+37
-41
lines changed

Project.toml

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,47 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.10.1"
4+
version = "0.11.0"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
99
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1112
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1213
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1314
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1415
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
15-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1616
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
17+
18+
[weakdeps]
19+
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1720
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1821

22+
[extensions]
23+
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
24+
DynamicExpressionsZygoteExt = "Zygote"
25+
1926
[compat]
2027
Compat = "3.37, 4"
2128
LoopVectorization = "0.12"
2229
MacroTools = "0.4, 0.5"
30+
PackageExtensionCompat = "1"
2331
PrecompileTools = "1"
2432
Reexport = "1"
25-
Requires = "1.0, 1.1, 1.2, 1.3"
2633
SymbolicUtils = "0.19, ^1.0.5"
2734
Zygote = "0.6"
2835
julia = "1.6"
2936

30-
[extensions]
31-
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
32-
3337
[extras]
3438
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3539
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3640
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3741
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3842
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
3943
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
44+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4045

4146
[targets]
42-
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils"]
43-
44-
[weakdeps]
45-
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
47+
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils", "Zygote"]

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,12 @@
11
module DynamicExpressionsSymbolicUtilsExt
22

3-
import Base: convert
4-
#! format: off
5-
if isdefined(Base, :get_extension)
6-
using SymbolicUtils
7-
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
8-
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
9-
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
10-
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
11-
else
12-
using ..SymbolicUtils
13-
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
14-
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
15-
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
16-
import ..DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
17-
end
18-
#! format: on
3+
using SymbolicUtils
4+
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
5+
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
6+
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
7+
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
198

209
const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}
21-
2210
const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)
2311

2412
function isgood(x::SymbolicUtils.Symbolic)
@@ -106,7 +94,7 @@ function findoperation(op, ops)
10694
throw(error("Operation $(op) in expression not found in operations $(ops)!"))
10795
end
10896

109-
function convert(
97+
function Base.convert(
11098
::typeof(SymbolicUtils.Symbolic),
11199
tree::Node,
112100
operators::AbstractOperatorEnum;
@@ -121,11 +109,11 @@ function convert(
121109
)
122110
end
123111

124-
function convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
112+
function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
125113
return Node(; val=DEFAULT_NODE_TYPE(x))
126114
end
127115

128-
function convert(
116+
function Base.convert(
129117
::typeof(Node),
130118
expr::SymbolicUtils.Symbolic,
131119
operators::AbstractOperatorEnum;

ext/DynamicExpressionsZygoteExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module DynamicExpressionsZygoteExt
2+
3+
import Zygote: gradient
4+
import DynamicExpressions.EvaluateEquationDerivativeModule: _zygote_gradient
5+
6+
_zygote_gradient(op::F, ::Val{1}) where {F} = x -> gradient(op, x)[1]
7+
_zygote_gradient(op::F, ::Val{2}) where {F} = (x, y) -> gradient(op, x, y)
8+
9+
end

src/DynamicExpressions.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ include("SimplifyEquation.jl")
1111
include("OperatorEnumConstruction.jl")
1212
include("ExtensionInterface.jl")
1313

14-
import Requires: @init, @require
14+
import PackageExtensionCompat: @require_extensions
1515
import Reexport: @reexport
1616
@reexport import .EquationModule:
1717
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
@@ -35,11 +35,9 @@ import Reexport: @reexport
3535
@reexport import .EvaluationHelpersModule
3636
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
3737

38-
#! format: off
39-
if !isdefined(Base, :get_extension)
40-
@init @require SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/DynamicExpressionsSymbolicUtilsExt.jl")
38+
function __init__()
39+
@require_extensions
4140
end
42-
#! format: on
4341

4442
include("deprecated.jl")
4543

src/EvaluateEquationDerivative.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array, fill_simila
77
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
88
import ..EvaluateEquationModule: deg0_eval
99

10+
_zygote_gradient(args...) = error("Please load the Zygote.jl package.")
11+
1012
function assert_autodiff_enabled(operators::OperatorEnum)
1113
if length(operators.diff_binops) == 0 && length(operators.diff_unaops) == 0
1214
error(

src/OperatorEnumConstruction.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
module OperatorEnumConstructionModule
22

3-
import Zygote: gradient
43
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
54
import ..EquationModule: string_tree, Node
65
import ..EvaluateEquationModule: eval_tree_array
7-
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
6+
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
87
import ..EvaluationHelpersModule: _grad_evaluator
98

109
function create_evaluation_helpers!(operators::OperatorEnum)
@@ -223,12 +222,10 @@ function OperatorEnum(;
223222

224223
if enable_autodiff
225224
for op in binary_operators
226-
diff_op(x, y) = gradient(op, x, y)
227-
push!(diff_binary_operators, diff_op)
225+
push!(diff_binary_operators, _zygote_gradient(op, Val(2)))
228226
end
229227
for op in unary_operators
230-
diff_op(x) = gradient(op, x)[1]
231-
push!(diff_unary_operators, diff_op)
228+
push!(diff_unary_operators, _zygote_gradient(op, Val(1)))
232229
end
233230
end
234231

0 commit comments

Comments
 (0)