Skip to content

Commit 55ae519

Browse files
committed
test: fix Enzyme test
1 parent d7d5fd5 commit 55ae519

File tree

1 file changed

+24
-33
lines changed

1 file changed

+24
-33
lines changed

test/test_enzyme.jl

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ using DynamicExpressions
55
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
66
# TODO: More operators will trigger a segfault in Enzyme
77

8-
# These options are required for Enzyme to work:
9-
const eval_options = (turbo=Val(false),)
10-
118
x1, x2, x3 = (i -> Node(Float64; feature=i)).(1:3)
129

1310
tree = Node(1, x1, Node(1, x2)) # == x1 + cos(x2)
@@ -16,32 +13,31 @@ X = randn(3, 100);
1613
dX = zero(X)
1714

1815
function f(tree, X, operators, output)
19-
output[] = sum(eval_tree_array(tree, X, operators; eval_options...)[1])
16+
output[] = sum(eval_tree_array(tree, X, operators)[1])
2017
return nothing
2118
end
2219

2320
output = [0.0]
2421
doutput = [1.0]
2522

26-
autodiff(
27-
Reverse,
28-
f,
29-
Const(tree),
30-
Duplicated(X, dX),
31-
Const(operators),
32-
Duplicated(output, doutput),
33-
)
23+
fetch(schedule(Task(64 * 1024^2) do
24+
autodiff(
25+
Reverse,
26+
f,
27+
Const(tree),
28+
Duplicated(X, dX),
29+
Const(operators),
30+
Duplicated(output, doutput),
31+
)
32+
end))
3433

3534
true_dX = cat(ones(100), -sin.(X[2, :]), zeros(100); dims=2)'
3635

3736
@test true_dX dX
3837

39-
#! format: off
40-
@static if false
41-
# Broken test (see https://github.com/EnzymeAD/Enzyme.jl/issues/1241)
4238
function my_loss_function(tree, X, operators)
4339
# Get the outputs
44-
y = tree(X, operators)
40+
y, _ = eval_tree_array(tree, X, operators)
4541
# Sum them (so we can take a gradient, rather than a jacobian)
4642
return sum(y)
4743
end
@@ -53,23 +49,18 @@ X = [1.0; 1.0;;]
5349
d_tree = begin
5450
storage_tree = copy(tree)
5551
# Set all constants to zero:
56-
foreach(storage_tree) do node
57-
if node.degree == 0 && node.constant
58-
node.val = 0.0
59-
end
60-
end
61-
autodiff(
62-
Reverse,
63-
my_loss_function,
64-
Active,
65-
Duplicated(tree, storage_tree),
66-
Const(X),
67-
Const(operators),
68-
)
52+
Enzyme.make_zero!(storage_tree)
53+
fetch(schedule(Task(64 * 1024^2) do
54+
autodiff(
55+
Reverse,
56+
my_loss_function,
57+
Active,
58+
Duplicated(tree, storage_tree),
59+
Const(X),
60+
Const(operators),
61+
)
62+
end))
6963
storage_tree
7064
end
7165

72-
@test_broken get_scalar_constants(d_tree) [1.0, 0.717356]
73-
end
74-
75-
#! format: on
66+
@test isapprox(first(get_scalar_constants(d_tree)), [1.0, 0.717356]; atol=1e-3)

0 commit comments

Comments
 (0)