Skip to content

Commit a1f1c90

Browse files
committed
Clean up derivative tests
1 parent 6ce0d3a commit a1f1c90

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

test/test_derivatives.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ for type in [Float16, Float32, Float64]
4848
nfeatures = 3
4949
N = 100
5050

51+
local X, operators
5152
X = rand(rng, type, nfeatures, N) * 5
5253

5354
operators = OperatorEnum(;
@@ -63,6 +64,7 @@ for type in [Float16, Float32, Float64]
6364
continue
6465
end
6566

67+
local tree
6668
tree = convert(Node{type}, equation(nx1, nx2, nx3))
6769
predicted_output = eval_tree_array(tree, X, operators)[1]
6870
true_output = equation.([X[i, :] for i in 1:nfeatures]...)
@@ -96,6 +98,7 @@ for type in [Float16, Float32, Float64]
9698
# Test gradient with respect to constants:
9799
equation4(x1, x2, x3) = 3.2f0 * x1
98100
# The gradient should be: (C * x1) => x1 is gradient with respect to C.
101+
local tree
99102
tree = equation4(nx1, nx2, nx3)
100103
tree = convert(Node{type}, tree)
101104
predicted_grad = eval_grad_tree_array(tree, X, operators; variable=false)[2]

0 commit comments

Comments
 (0)