Skip to content

Commit 455e80f

Browse files
committed
test: fix more generic DynamicTensor
1 parent 5fa74fc commit 455e80f

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

test/test_non_number_eval_tree_array.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using DynamicExpressions:
77
@extend_operators,
88
OperatorEnum,
99
get_scalar_constants,
10+
set_scalar_constants!,
1011
pack_scalar_constants!,
1112
unpack_scalar_constants,
1213
count_scalar_constants,
@@ -25,6 +26,8 @@ A tensor with a maximum of `N` dimensions, where `N` is a positive integer.
2526
struct DynamicTensor{T,N,A<:Tuple{Base.RefValue{T},Vararg}}
2627
dims::UInt8
2728
data::A
29+
# ^For example, when N is 2, this is (Ref(0.0), Vector{Float64}[...], Matrix{Float64}[...])
30+
# See `Max2Tensor` below for an example
2831

2932
#! format: off
3033
function DynamicTensor{T,N}(x::A=nothing) where {T,N,A<:Union{Nothing,Number,Array{<:Number}}}
@@ -54,7 +57,7 @@ DE.get_number_type(::Type{<:DynamicTensor{T}}) where {T} = T
5457

5558
@generated function DE.is_valid(val::DynamicTensor{<:Any,N}) where {N}
5659
quote
57-
@nif($N, i -> i == val.dims + 1, i -> if i == 1
60+
@nif($(N + 1), i -> i == val.dims + 1, i -> if i == 1
5861
is_valid(val.data[i][])
5962
else
6063
is_valid_array(val.data[i])
@@ -66,7 +69,7 @@ end
6669
) where {N}
6770
quote
6871
x.dims != y.dims && return false
69-
@nif($N, i -> i == x.dims + 1, i -> if i == 1
72+
@nif($(N + 1), i -> i == x.dims + 1, i -> if i == 1
7073
x.data[i][] == y.data[i][]
7174
else
7275
x.data[i] == y.data[i]
@@ -76,15 +79,15 @@ end
7679

7780
@generated function DE.count_scalar_constants(val::DynamicTensor{<:Any,N}) where {N}
7881
quote
79-
@nif($N, i -> i == val.dims + 1, i -> i == 1 ? 1 : length(val.data[i]))
82+
@nif($(N + 1), i -> i == val.dims + 1, i -> i == 1 ? 1 : length(val.data[i]))
8083
end
8184
end
8285

8386
@generated function DE.pack_scalar_constants!(
8487
nvals::AbstractVector{BT}, idx::Int64, val::DynamicTensor{BT,N}
8588
) where {BT<:Number,N}
8689
quote
87-
@nif($N, i -> i == val.dims + 1, i -> if i == 1
90+
@nif($(N + 1), i -> i == val.dims + 1, i -> if i == 1
8891
nvals[idx] = val.data[i][]
8992
idx + 1
9093
else
@@ -101,7 +104,7 @@ end
101104
) where {BT<:Number,N}
102105
quote
103106
@nif(
104-
$N,
107+
$(N + 1),
105108
i -> i == val.dims + 1,
106109
i -> if i == 1
107110
val.data[i][] = nvals[idx]

0 commit comments

Comments
 (0)