Skip to content

Commit 5fa74fc

Browse files
committed
refactor: Max2Tensor to more general DynamicTensor
1 parent 1b1a674 commit 5fa74fc

File tree

1 file changed

+104
-73
lines changed

1 file changed

+104
-73
lines changed

test/test_non_number_eval_tree_array.jl

Lines changed: 104 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
using Base.Cartesian: @nif
23
using DynamicExpressions:
34
DynamicExpressions as DE,
45
ValueInterface,
@@ -14,110 +15,140 @@ using DynamicExpressions:
1415

1516
using Interfaces: Interfaces, @implements, Arguments
1617

17-
# Max2Tensor (Tensor with a maximum of 3 dimensions) - struct that contains all three datatypes
18-
mutable struct Max2Tensor{T}
19-
dims::UInt8 # number of dimmentions
20-
scalar::T
21-
vector::Vector{T}
22-
matrix::Matrix{T}
18+
"""
19+
DynamicTensor{T,N,A<:Tuple{Base.RefValue{T},Vararg}}
20+
21+
A tensor with a maximum of `N` dimensions, where `N` is a positive integer.
22+
`.data[1]` is a `Ref` to the scalar value (so that it is mutable), while
23+
`.data[n]` is the array of dimension `n-1`.
24+
"""
25+
struct DynamicTensor{T,N,A<:Tuple{Base.RefValue{T},Vararg}}
26+
dims::UInt8
27+
data::A
2328

2429
#! format: off
25-
function Max2Tensor{T}(x::A=nothing) where {T,A<:Union{Nothing,Number,Vector{<:Number},Matrix{<:Number}}}
26-
return new(
30+
function DynamicTensor{T,N}(x::A=nothing) where {T,N,A<:Union{Nothing,Number,Array{<:Number}}}
31+
nd = x === nothing ? 0 : ndims(x)
32+
data = (
33+
Ref(x isa Number ? Base.convert(T, x) : zero(T)),
34+
ntuple(
35+
i -> if nd == i
36+
Array{T,i}(x)
37+
else
38+
Array{T,i}(undef, ntuple(_ -> 0, Val(i))...)
39+
end,
40+
Val(N)
41+
)...,
42+
)
43+
return new{T,N,typeof(data)}(
2744
x === nothing ? 0 : ndims(x),
28-
x isa Number ? Base.convert(T, x) : zero(T),
29-
x isa Vector ? Vector{T}(x) : Vector{T}(undef, 0),
30-
x isa Matrix ? Matrix{T}(x) : Matrix{T}(undef, 0, 0),
45+
data
3146
)
3247
end
3348
#! format: on
3449
end
35-
36-
DE.get_number_type(::Type{<:Max2Tensor{T}}) where {T} = T
37-
38-
function DE.is_valid(val::T) where {Q<:Number,T<:Max2Tensor{Q}}
39-
if val.dims == 0
40-
return is_valid(val.scalar)
41-
elseif val.dims == 1
42-
return is_valid_array(val.vector)
50+
const Max2Tensor{T} = DynamicTensor{T,2,Tuple{Base.RefValue{T},Vector{T},Matrix{T}}}
51+
Max2Tensor{T}(x) where {T} = DynamicTensor{T,2}(x)
52+
53+
DE.get_number_type(::Type{<:DynamicTensor{T}}) where {T} = T
54+
55+
@generated function DE.is_valid(val::DynamicTensor{<:Any,N}) where {N}
56+
quote
57+
@nif($N, i -> i == val.dims + 1, i -> if i == 1
58+
is_valid(val.data[i][])
59+
else
60+
is_valid_array(val.data[i])
61+
end)
4362
end
44-
return is_valid_array(val.matrix)
4563
end
46-
47-
function Base.:(==)(x::Max2Tensor{T}, y::Max2Tensor{T}) where {T}
48-
if x.dims != y.dims
49-
return false
50-
elseif x.dims == 0
51-
return x.scalar == y.scalar
52-
elseif x.dims == 1
53-
return x.vector == y.vector
64+
@generated function Base.:(==)(
65+
x::DynamicTensor{<:Any,N}, y::DynamicTensor{<:Any,N}
66+
) where {N}
67+
quote
68+
x.dims != y.dims && return false
69+
@nif($N, i -> i == x.dims + 1, i -> if i == 1
70+
x.data[i][] == y.data[i][]
71+
else
72+
x.data[i] == y.data[i]
73+
end)
5474
end
55-
return x.matrix == y.matrix
5675
end
5776

58-
function DE.count_scalar_constants(val::T) where {BT,T<:Max2Tensor{BT}}
59-
if val.dims == 0
60-
return 1
61-
elseif val.dims == 1
62-
return length(val.vector)
77+
@generated function DE.count_scalar_constants(val::DynamicTensor{<:Any,N}) where {N}
78+
quote
79+
@nif($N, i -> i == val.dims + 1, i -> i == 1 ? 1 : length(val.data[i]))
6380
end
64-
return length(val.matrix)
6581
end
6682

67-
function DE.pack_scalar_constants!(
68-
nvals::AbstractVector{BT}, idx::Int64, val::Max2Tensor{BT}
69-
) where {BT<:Number}
70-
if val.dims == 0
71-
nvals[idx] = val.scalar
72-
return idx + 1
73-
elseif val.dims == 1
74-
@view(nvals[idx:(idx + length(val.vector) - 1)]) .= val.vector
75-
return idx + length(val.vector)
83+
@generated function DE.pack_scalar_constants!(
84+
nvals::AbstractVector{BT}, idx::Int64, val::DynamicTensor{BT,N}
85+
) where {BT<:Number,N}
86+
quote
87+
@nif($N, i -> i == val.dims + 1, i -> if i == 1
88+
nvals[idx] = val.data[i][]
89+
idx + 1
90+
else
91+
data = val.data[i]
92+
num = length(data)
93+
copyto!(nvals, idx, @view(data[:]))
94+
idx + num
95+
end)
7696
end
77-
@view(nvals[idx:(idx + length(val.matrix) - 1)]) .= reshape(
78-
val.matrix, length(val.matrix)
79-
)
80-
return idx + length(val.matrix)
8197
end
8298

83-
function DE.unpack_scalar_constants(
84-
nvals::AbstractVector{BT}, idx::Int64, val::Max2Tensor{BT}
85-
) where {BT<:Number}
86-
if val.dims == 0
87-
val.scalar = nvals[idx]
88-
return idx + 1, val
89-
elseif val.dims == 1
90-
val.vector .= @view(nvals[idx:(idx + length(val.vector) - 1)])
91-
return idx + length(val.vector), val
99+
@generated function DE.unpack_scalar_constants(
100+
nvals::AbstractVector{BT}, idx::Int64, val::DynamicTensor{BT,N}
101+
) where {BT<:Number,N}
102+
quote
103+
@nif(
104+
$N,
105+
i -> i == val.dims + 1,
106+
i -> if i == 1
107+
val.data[i][] = nvals[idx]
108+
(idx + 1, val)
109+
else
110+
data = val.data[i]
111+
num = length(data)
112+
copyto!(data, @view(nvals[idx:(idx + num - 1)]))
113+
(idx + num, val)
114+
end
115+
)
92116
end
93-
reshape(val.matrix, length(val.matrix)) .= @view(
94-
nvals[idx:(idx + length(val.matrix) - 1)]
95-
)
96-
return idx + length(val.matrix), val
97117
end
98118

99-
# Declare that `Max2Tensor` implements `ValueInterface`
100-
@implements(ValueInterface, Max2Tensor, [Arguments()])
119+
# Declare that `DynamicTensor` implements `ValueInterface`
120+
@implements(ValueInterface, DynamicTensor, [Arguments()])
101121
# Run the interface tests
102122
@test Interfaces.test(
103123
ValueInterface,
104-
Max2Tensor,
124+
DynamicTensor,
105125
[
106-
Max2Tensor{Float64}(1.0),
107-
Max2Tensor{Float64}([1, 2, 3]),
108-
Max2Tensor{Float64}([1 2 3; 4 5 6]),
126+
# up to 1D
127+
DynamicTensor{Float64,1}(1.0),
128+
DynamicTensor{Float64,1}([1, 2, 3]),
129+
# up to 2D
130+
DynamicTensor{Float64,2}(1.0),
131+
DynamicTensor{Float64,2}([1, 2, 3]),
132+
DynamicTensor{Float64,2}([1 2 3; 4 5 6]),
133+
# up to 3D
134+
DynamicTensor{Float64,3}(1.0),
135+
DynamicTensor{Float64,3}([1, 2, 3]),
136+
DynamicTensor{Float64,3}([1 2 3; 4 5 6]),
137+
DynamicTensor{Float64,3}(rand(1, 2, 3)),
109138
],
110139
)
111140

112141
# testing is_valid functions
113-
@test is_valid(Max2Tensor{Float64}())
114-
@test !is_valid(Max2Tensor{Float64}(NaN))
115-
@test is_valid_array([Max2Tensor{Float64}(1), Max2Tensor{Float64}([1, 2, 3])])
116-
@test !is_valid_array([Max2Tensor{Float64}(1), Max2Tensor{Float64}([1, 2, NaN])])
142+
@test is_valid(DynamicTensor{Float64,2}())
143+
@test !is_valid(DynamicTensor{Float64,2}(NaN))
144+
@test is_valid_array([DynamicTensor{Float64,2}(1), DynamicTensor{Float64,2}([1, 2, 3])])
145+
@test !is_valid_array([DynamicTensor{Float64,2}(1), DynamicTensor{Float64,2}([1, 2, NaN])])
117146

118147
# dummy operators
119-
q(x::Max2Tensor{T}) where {T} = Max2Tensor{T}(x.scalar)
120-
a(x::Max2Tensor{T}, y::Max2Tensor{T}) where {T} = Max2Tensor{T}(x.scalar + y.scalar)
148+
q(x::DynamicTensor{T,N}) where {T,N} = DynamicTensor{T,N}(x.data[1])
149+
function a(x::DynamicTensor{T,N}, y::DynamicTensor{T,N}) where {T,N}
150+
return DynamicTensor{T,N}(x.data[1][] + y.data[1][])
151+
end
121152

122153
operators = OperatorEnum(; binary_operators=[a], unary_operators=[q])
123154
@extend_operators(operators, on_type = Max2Tensor{Float64})

0 commit comments

Comments
 (0)