Skip to content

Commit 74d1753

Browse files
committed
Refactor constructors
1 parent 88084d2 commit 74d1753

File tree

2 files changed

+111
-73
lines changed

2 files changed

+111
-73
lines changed

src/Equation.jl

Lines changed: 107 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ to `Node`s `l` and `r`.
116116
mutable struct Node{T} <: AbstractExpressionNode{T}
117117
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
118118
constant::Bool # false if variable
119-
val::Union{T,Nothing} # If is a constant, this stores the actual value
119+
val::T # If is a constant, this stores the actual value
120120
# ------------------- (possibly undefined below)
121121
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
122122
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
@@ -126,12 +126,7 @@ mutable struct Node{T} <: AbstractExpressionNode{T}
126126
#################
127127
## Constructors:
128128
#################
129-
Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
130-
Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
131-
Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
132-
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
133-
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)
134-
129+
Node{_T}() where {_T} = new{_T}()
135130
end
136131

137132
"""
@@ -168,26 +163,22 @@ when constructing or setting properties.
168163
mutable struct GraphNode{T} <: AbstractExpressionNode{T}
169164
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
170165
constant::Bool # false if variable
171-
val::Union{T,Nothing} # If is a constant, this stores the actual value
166+
val::T # If is a constant, this stores the actual value
172167
# ------------------- (possibly undefined below)
173168
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
174169
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
175170
l::GraphNode{T} # Left child node. Only defined for degree=1 or degree=2.
176171
r::GraphNode{T} # Right child node. Only defined for degree=2.
177172

178-
#################
179-
## Constructors:
180-
#################
181-
GraphNode(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
182-
GraphNode(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
183-
GraphNode(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
184-
GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
185-
GraphNode(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::GraphNode{_T}, r::GraphNode{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)
173+
GraphNode{_T}() where {_T} = new{_T}()
186174
end
187175

188176
################################################################################
189177
#! format: on
190178

179+
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
180+
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
181+
191182
constructorof(::Type{N}) where {N<:AbstractNode} = Base.typename(N).wrapper
192183
constructorof(::Type{<:Node}) = Node
193184
constructorof(::Type{<:GraphNode}) = GraphNode
@@ -198,48 +189,119 @@ end
198189
with_type_parameters(::Type{<:Node}, ::Type{T}) where {T} = Node{T}
199190
with_type_parameters(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}
200191

192+
function default_allocator(::Type{N}, ::Type{T}) where {N<:AbstractExpressionNode,T}
193+
return with_type_parameters(N, T)()
194+
end
195+
default_allocator(::Type{<:Node}, ::Type{T}) where {T} = Node{T}()
196+
default_allocator(::Type{<:GraphNode}, ::Type{T}) where {T} = GraphNode{T}()
197+
201198
"""Trait declaring whether nodes share children or not."""
202199
preserve_sharing(::Type{<:AbstractNode}) = false
203200
preserve_sharing(::Type{<:Node}) = false
204201
preserve_sharing(::Type{<:GraphNode}) = true
205202

206203
include("base.jl")
207204

208-
function (::Type{N})(
209-
::Type{T}=Undefined; val::T1=nothing, feature::T2=nothing
210-
) where {T,T1,T2<:Union{Integer,Nothing},N<:AbstractExpressionNode}
211-
((T1 <: Nothing) (T2 <: Nothing)) || error(
212-
"You must specify exactly one of `val` or `feature` when creating a leaf node."
213-
)
214-
Tout = compute_value_output_type(N, T, T1)
215-
if T2 <: Nothing
216-
if !(T1 <: T)
217-
# Only convert if not already in the type union.
218-
val = convert(Tout, val)
219-
end
220-
return constructorof(N)(Tout, 0, true, val)
205+
@inline function (::Type{N})(
206+
::Type{T1}=Undefined;
207+
val=nothing,
208+
feature=nothing,
209+
op=nothing,
210+
l=nothing,
211+
r=nothing,
212+
allocator=default_allocator,
213+
) where {T1,N<:AbstractExpressionNode}
214+
return node_factory(N, T1, val, feature, op, l, r, allocator)
215+
end
216+
217+
"""Create a constant leaf."""
218+
@inline function node_factory(
219+
::Type{N}, ::Type{T1}, val::T2, ::Nothing, ::Nothing, ::Nothing, ::Nothing, allocator
220+
) where {N,T1,T2}
221+
T = node_factory_type(N, T1, T2)
222+
n = allocator(N, T)
223+
n.degree = 0
224+
n.val = convert(T, val)
225+
n.constant = true
226+
return n
227+
end
228+
"""Create a variable leaf, to store data."""
229+
@inline function node_factory(
230+
::Type{N},
231+
::Type{T1},
232+
::Nothing,
233+
feature::Integer,
234+
::Nothing,
235+
::Nothing,
236+
::Nothing,
237+
allocator,
238+
) where {N,T1}
239+
T = node_factory_type(N, T1, DEFAULT_NODE_TYPE)
240+
n = allocator(N, T)
241+
n.degree = 0
242+
n.constant = false
243+
n.feature = feature
244+
return n
245+
end
246+
"""Create a unary operator node."""
247+
@inline function node_factory(
248+
::Type{N},
249+
::Type{T1},
250+
::Nothing,
251+
::Nothing,
252+
op::Integer,
253+
l::AbstractExpressionNode{T2},
254+
::Nothing,
255+
allocator,
256+
) where {N,T1,T2}
257+
@assert l isa N
258+
T = T2 # Always prefer existing nodes, so we don't mess up references from conversion
259+
n = allocator(N, T)
260+
n.degree = 1
261+
n.op = op
262+
n.l = l
263+
return n
264+
end
265+
"""Create a binary operator node."""
266+
@inline function node_factory(
267+
::Type{N},
268+
::Type{T1},
269+
::Nothing,
270+
::Nothing,
271+
op::Integer,
272+
l::AbstractExpressionNode{T2},
273+
r::AbstractExpressionNode{T3},
274+
allocator,
275+
) where {N,T1,T2,T3}
276+
T = promote_type(T2, T3)
277+
n = allocator(N, T)
278+
n.degree = 2
279+
n.op = op
280+
n.l = T2 === T ? l : convert(with_type_parameters(N, T), l)
281+
n.r = T3 === T ? r : convert(with_type_parameters(N, T), r)
282+
return n
283+
end
284+
@inline function node_factory_type(::Type{N}, ::Type{T1}, ::Type{T2}) where {N,T1,T2}
285+
if T1 === Undefined && N isa UnionAll
286+
T2
287+
elseif T1 === Undefined
288+
eltype(N)
289+
elseif N isa UnionAll
290+
T1
221291
else
222-
return constructorof(N)(Tout, 0, false, nothing, feature)
292+
eltype(N)
223293
end
224294
end
295+
225296
function (::Type{N})(
226-
op::Integer, l::AbstractExpressionNode{T}
227-
) where {T,N<:AbstractExpressionNode}
228-
@assert l isa N
229-
return constructorof(N)(1, false, nothing, 0, op, l)
297+
op::Integer, l::AbstractExpressionNode
298+
) where {N<:AbstractExpressionNode}
299+
return N(; op=op, l=l)
230300
end
231301
function (::Type{N})(
232-
op::Integer, l::AbstractExpressionNode{T1}, r::AbstractExpressionNode{T2}
233-
) where {T1,T2,N<:AbstractExpressionNode}
234-
@assert l isa N && r isa N
235-
# Get highest type:
236-
if T1 != T2
237-
T = promote_type(T1, T2)
238-
# TODO: This might slow things down
239-
l = convert(with_type_parameters(N, T), l)
240-
r = convert(with_type_parameters(N, T), r)
241-
end
242-
return constructorof(N)(2, false, nothing, 0, op, l, r)
302+
op::Integer, l::AbstractExpressionNode, r::AbstractExpressionNode
303+
) where {N<:AbstractExpressionNode}
304+
return N(; op=op, l=l, r=r)
243305
end
244306
function (::Type{N})(var_string::String) where {N<:AbstractExpressionNode}
245307
Base.depwarn(
@@ -255,28 +317,6 @@ function (::Type{N})(
255317
return N(; feature=i)
256318
end
257319

258-
@inline function compute_value_output_type(
259-
::Type{N}, ::Type{T}, ::Type{T1}
260-
) where {N<:AbstractExpressionNode,T,T1}
261-
!(N isa UnionAll) &&
262-
T !== Undefined &&
263-
error(
264-
"Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`.",
265-
)
266-
267-
if T === Undefined && N isa UnionAll
268-
if T1 <: Nothing
269-
return DEFAULT_NODE_TYPE
270-
else
271-
return T1
272-
end
273-
elseif T === Undefined
274-
return eltype(N)
275-
else
276-
return T
277-
end
278-
end
279-
280320
function Base.promote_rule(::Type{Node{T1}}, ::Type{Node{T2}}) where {T1,T2}
281321
return Node{promote_type(T1, T2)}
282322
end
@@ -286,8 +326,6 @@ end
286326
function Base.promote_rule(::Type{GraphNode{T1}}, ::Type{GraphNode{T2}}) where {T1,T2}
287327
return GraphNode{promote_type(T1, T2)}
288328
end
289-
Base.eltype(::Type{<:AbstractExpressionNode{T}}) where {T} = T
290-
Base.eltype(::AbstractExpressionNode{T}) where {T} = T
291329

292330
# TODO: Verify using this helps with garbage collection
293331
create_dummy_node(::Type{N}) where {N<:AbstractExpressionNode} = N(; feature=zero(UInt16))

src/base.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,20 +478,20 @@ function convert(
478478
end
479479
return tree_mapreduce(
480480
t -> if t.constant
481-
constructorof(N1)(T1, 0, true, convert(T1, t.val::T2))
481+
constructorof(N1)(; val=convert(T1, t.val::T2))
482482
else
483-
constructorof(N1)(T1, 0, false, nothing, t.feature)
483+
constructorof(N1)(T1; feature=t.feature)
484484
end,
485485
identity,
486-
(p, c...) -> constructorof(N1)(p.degree, false, nothing, 0, p.op, c...),
486+
((p, c::Vararg{Any,M}) where {M}) -> constructorof(N1)(p.op, c...),
487487
tree,
488488
N1,
489489
)
490490
end
491491
function convert(
492492
::Type{N1}, tree::N2
493493
) where {T2,N1<:AbstractExpressionNode,N2<:AbstractExpressionNode{T2}}
494-
return convert(constructorof(N1){T2}, tree)
494+
return convert(with_type_parameters(N1, T2), tree)
495495
end
496496
function (::Type{N})(tree::AbstractExpressionNode) where {N<:AbstractExpressionNode}
497497
return convert(N, tree)

0 commit comments

Comments
 (0)