Skip to content

Commit a8fa68c

Browse files
committed
Fix issue due to assumption of node.constant == false
1 parent 74d1753 commit a8fa68c

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

src/Equation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ end
221221
T = node_factory_type(N, T1, T2)
222222
n = allocator(N, T)
223223
n.degree = 0
224-
n.val = convert(T, val)
225224
n.constant = true
225+
n.val = convert(T, val)
226226
return n
227227
end
228228
"""Create a variable leaf, to store data."""
@@ -328,7 +328,7 @@ function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {
328328
end
329329

330330
# TODO: Verify using this helps with garbage collection
331-
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N(; feature=zero(UInt16))
331+
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N()
332332

333333
"""
334334
set_node!(tree::AbstractExpressionNode{T}, new_tree::AbstractExpressionNode{T}) where {T}

src/EquationUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function set_constants!(
9191
Base.require_one_based_indexing(constants)
9292
i = Ref(0)
9393
foreach(tree) do node
94-
if node.degree == 0 && node.constant
94+
if is_node_constant(node)
9595
@inbounds node.val = constants[i[] += 1]
9696
end
9797
end

src/SimplifyEquation.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
3333
tree.r = combine_operators(tree.r, operators)
3434
end
3535

36-
top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
36+
top_level_constant =
37+
tree.degree == 2 && (is_node_constant(tree.l) || is_node_constant(tree.r))
3738
if tree.degree == 2 && is_commutative(operators.binops[tree.op]) && top_level_constant
3839
# TODO: Does this break SymbolicRegression.jl due to the different names of operators?
3940
op = tree.op
4041
# Put the constant in r. Need to assume var in left for simplification assumption.
41-
if tree.l.constant
42+
if is_node_constant(tree.l)
4243
tmp = tree.r
4344
tree.r = tree.l
4445
tree.l = tmp
@@ -47,12 +48,12 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
4748
# Simplify down first
4849
below = tree.l
4950
if below.degree == 2 && below.op == op
50-
if below.l.constant
51+
if is_node_constant(below.l)
5152
tree = below
5253
tree.l.val = _bin_op_kernel(
5354
operators.binops[op], tree.l.val::T, topconstant
5455
)
55-
elseif below.r.constant
56+
elseif is_node_constant(below.r)
5657
tree = below
5758
tree.r.val = _bin_op_kernel(
5859
operators.binops[op], tree.r.val::T, topconstant
@@ -65,17 +66,17 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
6566

6667
# Currently just simplifies subtraction. (can't assume both plus and sub are operators)
6768
# Not commutative, so use different op.
68-
if tree.l.constant
69+
if is_node_constant(tree.l)
6970
if tree.r.degree == 2 && tree.op == tree.r.op
70-
if tree.r.l.constant
71+
if is_node_constant(tree.r.l)
7172
#(const - (const - var)) => (var - const)
7273
l = tree.l
7374
r = tree.r
7475
simplified_const = (r.l.val::T - l.val::T) #neg(sub(l.val, r.l.val))
7576
tree.l = tree.r.r
7677
tree.r = l
7778
tree.r.val = simplified_const
78-
elseif tree.r.r.constant
79+
elseif is_node_constant(tree.r.r)
7980
#(const - (var - const)) => (const - var)
8081
l = tree.l
8182
r = tree.r
@@ -84,17 +85,17 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
8485
tree.l.val = simplified_const
8586
end
8687
end
87-
else #tree.r.constant is true
88+
else #tree.r is a constant
8889
if tree.l.degree == 2 && tree.op == tree.l.op
89-
if tree.l.l.constant
90+
if is_node_constant(tree.l.l)
9091
#((const - var) - const) => (const - var)
9192
l = tree.l
9293
r = tree.r
9394
simplified_const = l.l.val::T - r.val::T#sub(l.l.val, r.val)
9495
tree.r = tree.l.r
9596
tree.l = r
9697
tree.l.val = simplified_const
97-
elseif tree.l.r.constant
98+
elseif is_node_constant(tree.l.r)
9899
#((var - const) - const) => (var - const)
99100
l = tree.l
100101
r = tree.r

0 commit comments

Comments
 (0)