@@ -7,6 +7,7 @@ using DynamicExpressions:
77 @extend_operators ,
88 OperatorEnum,
99 get_scalar_constants,
10+ set_scalar_constants!,
1011 pack_scalar_constants!,
1112 unpack_scalar_constants,
1213 count_scalar_constants,
@@ -25,6 +26,8 @@ A tensor with a maximum of `N` dimensions, where `N` is a positive integer.
2526struct DynamicTensor{T,N,A<: Tuple{Base.RefValue{T},Vararg} }
2627 dims:: UInt8
2728 data:: A
29+ # ^For example, when N is 2, this is (Ref(0.0), Vector{Float64}[...], Matrix{Float64}[...])
30+ # See `Max2Tensor` below for an example
2831
2932 # ! format: off
3033 function DynamicTensor {T,N} (x:: A = nothing ) where {T,N,A<: Union{Nothing,Number,Array{<:Number}} }
@@ -54,7 +57,7 @@ DE.get_number_type(::Type{<:DynamicTensor{T}}) where {T} = T
5457
5558@generated function DE. is_valid (val:: DynamicTensor{<:Any,N} ) where {N}
5659 quote
57- @nif ($ N , i -> i == val. dims + 1 , i -> if i == 1
60+ @nif ($ (N + 1 ) , i -> i == val. dims + 1 , i -> if i == 1
5861 is_valid (val. data[i][])
5962 else
6063 is_valid_array (val. data[i])
6669) where {N}
6770 quote
6871 x. dims != y. dims && return false
69- @nif ($ N , i -> i == x. dims + 1 , i -> if i == 1
72+ @nif ($ (N + 1 ) , i -> i == x. dims + 1 , i -> if i == 1
7073 x. data[i][] == y. data[i][]
7174 else
7275 x. data[i] == y. data[i]
7679
7780@generated function DE. count_scalar_constants (val:: DynamicTensor{<:Any,N} ) where {N}
7881 quote
79- @nif ($ N , i -> i == val. dims + 1 , i -> i == 1 ? 1 : length (val. data[i]))
82+ @nif ($ (N + 1 ) , i -> i == val. dims + 1 , i -> i == 1 ? 1 : length (val. data[i]))
8083 end
8184end
8285
8386@generated function DE. pack_scalar_constants! (
8487 nvals:: AbstractVector{BT} , idx:: Int64 , val:: DynamicTensor{BT,N}
8588) where {BT<: Number ,N}
8689 quote
87- @nif ($ N , i -> i == val. dims + 1 , i -> if i == 1
90+ @nif ($ (N + 1 ) , i -> i == val. dims + 1 , i -> if i == 1
8891 nvals[idx] = val. data[i][]
8992 idx + 1
9093 else
101104) where {BT<: Number ,N}
102105 quote
103106 @nif (
104- $ N ,
107+ $ (N + 1 ) ,
105108 i -> i == val. dims + 1 ,
106109 i -> if i == 1
107110 val. data[i][] = nvals[idx]
0 commit comments