@@ -116,7 +116,7 @@ to `Node`s `l` and `r`.
116116mutable struct Node{T} <: AbstractExpressionNode{T}
117117 degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
118118 constant:: Bool # false if variable
119- val:: Union{T,Nothing} # If is a constant, this stores the actual value
119+ val:: T # If is a constant, this stores the actual value
120120 # ------------------- (possibly undefined below)
121121 feature:: UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
122122 op:: UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
@@ -126,12 +126,7 @@ mutable struct Node{T} <: AbstractExpressionNode{T}
126126 # ################
127127 # # Constructors:
128128 # ################
129- Node (d:: Integer , c:: Bool , v:: _T ) where {_T} = new {_T} (UInt8 (d), c, v)
130- Node (:: Type{_T} , d:: Integer , c:: Bool , v:: _T ) where {_T} = new {_T} (UInt8 (d), c, v)
131- Node (:: Type{_T} , d:: Integer , c:: Bool , v:: Nothing , f:: Integer ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f))
132- Node (d:: Integer , c:: Bool , v:: Nothing , f:: Integer , o:: Integer , l:: Node{_T} ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f), UInt8 (o), l)
133- Node (d:: Integer , c:: Bool , v:: Nothing , f:: Integer , o:: Integer , l:: Node{_T} , r:: Node{_T} ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f), UInt8 (o), l, r)
134-
129+ Node {_T} () where {_T} = new {_T} ()
135130end
136131
137132"""
@@ -168,26 +163,22 @@ when constructing or setting properties.
168163mutable struct GraphNode{T} <: AbstractExpressionNode{T}
169164 degree:: UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
170165 constant:: Bool # false if variable
171- val:: Union{T,Nothing} # If is a constant, this stores the actual value
166+ val:: T # If is a constant, this stores the actual value
172167 # ------------------- (possibly undefined below)
173168 feature:: UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
174169 op:: UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
175170 l:: GraphNode{T} # Left child node. Only defined for degree=1 or degree=2.
176171 r:: GraphNode{T} # Right child node. Only defined for degree=2.
177172
178- # ################
179- # # Constructors:
180- # ################
181- GraphNode (d:: Integer , c:: Bool , v:: _T ) where {_T} = new {_T} (UInt8 (d), c, v)
182- GraphNode (:: Type{_T} , d:: Integer , c:: Bool , v:: _T ) where {_T} = new {_T} (UInt8 (d), c, v)
183- GraphNode (:: Type{_T} , d:: Integer , c:: Bool , v:: Nothing , f:: Integer ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f))
184- GraphNode (d:: Integer , c:: Bool , v:: Nothing , f:: Integer , o:: Integer , l:: GraphNode{_T} ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f), UInt8 (o), l)
185- GraphNode (d:: Integer , c:: Bool , v:: Nothing , f:: Integer , o:: Integer , l:: GraphNode{_T} , r:: GraphNode{_T} ) where {_T} = new {_T} (UInt8 (d), c, v, UInt16 (f), UInt8 (o), l, r)
173+ GraphNode {_T} () where {_T} = new {_T} ()
186174end
187175
188176# ###############################################################################
189177# ! format: on
190178
179+ Base. eltype (:: Type{<:AbstractExpressionNode{T}} ) where {T} = T
180+ Base. eltype (:: AbstractExpressionNode{T} ) where {T} = T
181+
191182constructorof (:: Type{N} ) where {N<: AbstractNode } = Base. typename (N). wrapper
192183constructorof (:: Type{<:Node} ) = Node
193184constructorof (:: Type{<:GraphNode} ) = GraphNode
@@ -198,48 +189,119 @@ end
198189with_type_parameters (:: Type{<:Node} , :: Type{T} ) where {T} = Node{T}
199190with_type_parameters (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode{T}
200191
192+ function default_allocator (:: Type{N} , :: Type{T} ) where {N<: AbstractExpressionNode ,T}
193+ return with_type_parameters (N, T)()
194+ end
195+ default_allocator (:: Type{<:Node} , :: Type{T} ) where {T} = Node {T} ()
196+ default_allocator (:: Type{<:GraphNode} , :: Type{T} ) where {T} = GraphNode {T} ()
197+
201198""" Trait declaring whether nodes share children or not."""
202199preserve_sharing (:: Type{<:AbstractNode} ) = false
203200preserve_sharing (:: Type{<:Node} ) = false
204201preserve_sharing (:: Type{<:GraphNode} ) = true
205202
206203include (" base.jl" )
207204
208- function (:: Type{N} )(
209- :: Type{T} = Undefined; val:: T1 = nothing , feature:: T2 = nothing
210- ) where {T,T1,T2<: Union{Integer,Nothing} ,N<: AbstractExpressionNode }
211- ((T1 <: Nothing ) ⊻ (T2 <: Nothing )) || error (
212- " You must specify exactly one of `val` or `feature` when creating a leaf node."
213- )
214- Tout = compute_value_output_type (N, T, T1)
215- if T2 <: Nothing
216- if ! (T1 <: T )
217- # Only convert if not already in the type union.
218- val = convert (Tout, val)
219- end
220- return constructorof (N)(Tout, 0 , true , val)
205+ @inline function (:: Type{N} )(
206+ :: Type{T1} = Undefined;
207+ val= nothing ,
208+ feature= nothing ,
209+ op= nothing ,
210+ l= nothing ,
211+ r= nothing ,
212+ allocator= default_allocator,
213+ ) where {T1,N<: AbstractExpressionNode }
214+ return node_factory (N, T1, val, feature, op, l, r, allocator)
215+ end
216+
217+ """ Create a constant leaf."""
218+ @inline function node_factory (
219+ :: Type{N} , :: Type{T1} , val:: T2 , :: Nothing , :: Nothing , :: Nothing , :: Nothing , allocator
220+ ) where {N,T1,T2}
221+ T = node_factory_type (N, T1, T2)
222+ n = allocator (N, T)
223+ n. degree = 0
224+ n. val = convert (T, val)
225+ n. constant = true
226+ return n
227+ end
228+ """ Create a variable leaf, to store data."""
229+ @inline function node_factory (
230+ :: Type{N} ,
231+ :: Type{T1} ,
232+ :: Nothing ,
233+ feature:: Integer ,
234+ :: Nothing ,
235+ :: Nothing ,
236+ :: Nothing ,
237+ allocator,
238+ ) where {N,T1}
239+ T = node_factory_type (N, T1, DEFAULT_NODE_TYPE)
240+ n = allocator (N, T)
241+ n. degree = 0
242+ n. constant = false
243+ n. feature = feature
244+ return n
245+ end
246+ """ Create a unary operator node."""
247+ @inline function node_factory (
248+ :: Type{N} ,
249+ :: Type{T1} ,
250+ :: Nothing ,
251+ :: Nothing ,
252+ op:: Integer ,
253+ l:: AbstractExpressionNode{T2} ,
254+ :: Nothing ,
255+ allocator,
256+ ) where {N,T1,T2}
257+ @assert l isa N
258+ T = T2 # Always prefer existing nodes, so we don't mess up references from conversion
259+ n = allocator (N, T)
260+ n. degree = 1
261+ n. op = op
262+ n. l = l
263+ return n
264+ end
265+ """ Create a binary operator node."""
266+ @inline function node_factory (
267+ :: Type{N} ,
268+ :: Type{T1} ,
269+ :: Nothing ,
270+ :: Nothing ,
271+ op:: Integer ,
272+ l:: AbstractExpressionNode{T2} ,
273+ r:: AbstractExpressionNode{T3} ,
274+ allocator,
275+ ) where {N,T1,T2,T3}
276+ T = promote_type (T2, T3)
277+ n = allocator (N, T)
278+ n. degree = 2
279+ n. op = op
280+ n. l = T2 === T ? l : convert (with_type_parameters (N, T), l)
281+ n. r = T3 === T ? r : convert (with_type_parameters (N, T), r)
282+ return n
283+ end
284+ @inline function node_factory_type (:: Type{N} , :: Type{T1} , :: Type{T2} ) where {N,T1,T2}
285+ if T1 === Undefined && N isa UnionAll
286+ T2
287+ elseif T1 === Undefined
288+ eltype (N)
289+ elseif N isa UnionAll
290+ T1
221291 else
222- return constructorof (N)(Tout, 0 , false , nothing , feature )
292+ eltype (N )
223293 end
224294end
295+
225296function (:: Type{N} )(
226- op:: Integer , l:: AbstractExpressionNode{T}
227- ) where {T,N<: AbstractExpressionNode }
228- @assert l isa N
229- return constructorof (N)(1 , false , nothing , 0 , op, l)
297+ op:: Integer , l:: AbstractExpressionNode
298+ ) where {N<: AbstractExpressionNode }
299+ return N (; op= op, l= l)
230300end
231301function (:: Type{N} )(
232- op:: Integer , l:: AbstractExpressionNode{T1} , r:: AbstractExpressionNode{T2}
233- ) where {T1,T2,N<: AbstractExpressionNode }
234- @assert l isa N && r isa N
235- # Get highest type:
236- if T1 != T2
237- T = promote_type (T1, T2)
238- # TODO : This might slow things down
239- l = convert (with_type_parameters (N, T), l)
240- r = convert (with_type_parameters (N, T), r)
241- end
242- return constructorof (N)(2 , false , nothing , 0 , op, l, r)
302+ op:: Integer , l:: AbstractExpressionNode , r:: AbstractExpressionNode
303+ ) where {N<: AbstractExpressionNode }
304+ return N (; op= op, l= l, r= r)
243305end
244306function (:: Type{N} )(var_string:: String ) where {N<: AbstractExpressionNode }
245307 Base. depwarn (
@@ -255,28 +317,6 @@ function (::Type{N})(
255317 return N (; feature= i)
256318end
257319
258- @inline function compute_value_output_type (
259- :: Type{N} , :: Type{T} , :: Type{T1}
260- ) where {N<: AbstractExpressionNode ,T,T1}
261- ! (N isa UnionAll) &&
262- T != = Undefined &&
263- error (
264- " Ambiguous type for node. Please either use `Node{T}(; val, feature)` or `Node(T; val, feature)`." ,
265- )
266-
267- if T === Undefined && N isa UnionAll
268- if T1 <: Nothing
269- return DEFAULT_NODE_TYPE
270- else
271- return T1
272- end
273- elseif T === Undefined
274- return eltype (N)
275- else
276- return T
277- end
278- end
279-
280320function Base. promote_rule (:: Type{Node{T1}} , :: Type{Node{T2}} ) where {T1,T2}
281321 return Node{promote_type (T1, T2)}
282322end
286326function Base. promote_rule (:: Type{GraphNode{T1}} , :: Type{GraphNode{T2}} ) where {T1,T2}
287327 return GraphNode{promote_type (T1, T2)}
288328end
289- Base. eltype (:: Type{<:AbstractExpressionNode{T}} ) where {T} = T
290- Base. eltype (:: AbstractExpressionNode{T} ) where {T} = T
291329
292330# TODO : Verify using this helps with garbage collection
293331create_dummy_node (:: Type{N} ) where {N<: AbstractExpressionNode } = N (; feature= zero (UInt16))
0 commit comments