@@ -5,9 +5,6 @@ using DynamicExpressions
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
# TODO: More operators will trigger a segfault in Enzyme

-# These options are required for Enzyme to work:
-const eval_options = (turbo=Val(false),)
-
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)

tree = Node(1, x1, Node(1, x2))  # == x1 + cos(x2)
@@ -16,32 +13,37 @@ X = randn(3, 100);
dX = zero(X)

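+# `f` writes its result into `output[]` and returns `nothing`, so Enzyme can
+# propagate the gradient through the Duplicated output buffer below.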
function f(tree, X, operators, output)
-    output[] = sum(eval_tree_array(tree, X, operators; eval_options...)[1])
+    output[] = sum(eval_tree_array(tree, X, operators)[1])
    return nothing
end

output = [0.0]
doutput = [1.0]
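+# doutput seeds the reverse pass: the adjoint of `output` starts at 1.0.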

-autodiff(
-    Reverse,
-    f,
-    Const(tree),
-    Duplicated(X, dX),
-    Const(operators),
-    Duplicated(output, doutput),
-)
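+# Run autodiff in a Task with a 64 MiB reserved stack; Enzyme's reverse
+# pass can overflow the default task stack.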
+fetch(schedule(Task(64 * 1024^2) do
+    autodiff(
+        Reverse,
+        f,
+        Const(tree),
+        Duplicated(X, dX),
+        Const(operators),
+        Duplicated(output, doutput),
+    )
+end))

true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)'

@test true_dX ≈ dX

-#! format: off
-@static if false
-# Broken test (see https://github.com/EnzymeAD/Enzyme.jl/issues/1241)
function my_loss_function(tree, X, operators)
    # Get the outputs
-    y = tree(X, operators)
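+    # eval_tree_array returns (values, completed); only the values are needed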
+    y, _ = eval_tree_array(tree, X, operators)
    # Sum them (so we can take a gradient, rather than a jacobian)
    return sum(y)
end
@@ -53,23 +49,21 @@ X = [1.0; 1.0;;]
d_tree = begin
    storage_tree = copy(tree)
    # Set all constants to zero:
-    foreach(storage_tree) do node
-        if node.degree == 0 && node.constant
-            node.val = 0.0
-        end
-    end
-    autodiff(
-        Reverse,
-        my_loss_function,
-        Active,
-        Duplicated(tree, storage_tree),
-        Const(X),
-        Const(operators),
-    )
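+    # storage_tree acts as the shadow of `tree` in Duplicated below; after the
+    # call it holds the gradient with respect to each constant.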
+    Enzyme.make_zero!(storage_tree)
+    fetch(schedule(Task(64 * 1024^2) do
+        autodiff(
+            Reverse,
+            my_loss_function,
+            Active,
+            Duplicated(tree, storage_tree),
+            Const(X),
+            Const(operators),
+        )
+    end))
    storage_tree
end

-@test_broken get_scalar_constants(d_tree) ≈ [1.0, 0.717356]
-end
-
-#! format: on
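+# get_scalar_constants returns (constants, refs); compare only the constants.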
+@test isapprox(first(get_scalar_constants(d_tree)), [1.0, 0.717356]; atol=1e-3)