Skip to content

Commit 2268896

Browse files
committed
test: add unittests for copy_node!
1 parent bc65a20 commit 2268896

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import .ValueInterfaceModule:
4141
GraphNode,
4242
Node,
4343
copy_node,
44+
copy_node!,
4445
set_node!,
4546
tree_mapreduce,
4647
filter_map,

src/base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ function copy_node!(
499499
break_sharing::Val{BS}=Val(false),
500500
ref::Base.RefValue{<:Integer}=Ref(0),
501501
) where {BS,N<:AbstractExpressionNode}
502+
ref.x = 0
502503
return tree_mapreduce(
503504
leaf -> leaf_copy!(@inbounds(dest[ref.x += 1]), leaf),
504505
identity,

test/test_copy_inplace.jl

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
@testitem "copy_node! - random trees" begin
2+
using DynamicExpressions
3+
using DynamicExpressions: copy_node!
4+
include("tree_gen_utils.jl")
5+
6+
operators = OperatorEnum(; binary_operators=[+, *, /], unary_operators=[sin, cos])
7+
8+
for size in [1, 2, 5, 10, 20], _ in 1:10, N in (Node, ParametricNode)
9+
tree = gen_random_tree_fixed_size(size, operators, 5, Float64, N)
10+
n_nodes = count_nodes(tree)
11+
@test n_nodes == size # Verify gen_random_tree_fixed_size worked
12+
13+
# Make array larger than needed to test bounds:
14+
dest_array = [N{Float64}() for _ in 1:(n_nodes + 10)]
15+
orig_nodes = dest_array[(n_nodes + 1):end] # Save reference to unused nodes
16+
17+
ref = Ref(0)
18+
result = copy_node!(dest_array, tree; ref)
19+
20+
@test ref[] == n_nodes # Increment once per node
21+
22+
# Should be the same tree:
23+
@test result == tree
24+
@test hash(result) == hash(tree)
25+
26+
# The root should be the last node in the destination array:
27+
@test result === dest_array[n_nodes]
28+
29+
# Every node in the resultant tree should be from an allocated
30+
# node in the destination array:
31+
@test all(n -> any(n === x for x in dest_array[1:n_nodes]), result)
32+
33+
# There should be no aliasing:
34+
@test Set(map(objectid, result)) == Set(map(objectid, dest_array[1:n_nodes]))
35+
end
36+
end
37+
38+
@testitem "copy_node! - leaf nodes" begin
39+
using DynamicExpressions
40+
using DynamicExpressions: copy_node!
41+
42+
leaf_constant = Node{Float64}(; val=1.0)
43+
leaf_feature = Node{Float64}(; feature=1)
44+
45+
for leaf in [leaf_constant, leaf_feature]
46+
dest_array = [Node{Float64}() for _ in 1:1]
47+
ref = Ref(0)
48+
result = copy_node!(dest_array, leaf; ref=ref)
49+
@test ref[] == 1
50+
@test result == leaf
51+
@test result === dest_array[1]
52+
end
53+
end

test/unittest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ end
101101
include("test_base.jl")
102102
end
103103
include("test_base_2.jl")
104+
include("test_copy_inplace.jl")
104105

105106
@testitem "Test extra node fields" begin
106107
include("test_extra_node_fields.jl")

0 commit comments

Comments
 (0)