11
2+ using Base. Cartesian: @nif
23using DynamicExpressions:
34 DynamicExpressions as DE,
45 ValueInterface,
@@ -14,110 +15,140 @@ using DynamicExpressions:
1415
1516using Interfaces: Interfaces, @implements , Arguments
1617
17- # Max2Tensor (Tensor with a maximum of 3 dimensions) - struct that contains all three datatypes
18- mutable struct Max2Tensor{T}
19- dims:: UInt8 # number of dimmentions
20- scalar:: T
21- vector:: Vector{T}
22- matrix:: Matrix{T}
18+ """
19+ DynamicTensor{T,N,A<:Tuple{Base.RefValue{T},Vararg}}
20+
21+ A tensor with a maximum of `N` dimensions, where `N` is a positive integer.
22+ `.data[1]` is a `Ref` to the scalar value (so that it is mutable), while
23+ `.data[n]` is the array of dimension `n-1`.
24+ """
25+ struct DynamicTensor{T,N,A<: Tuple{Base.RefValue{T},Vararg} }
26+ dims:: UInt8
27+ data:: A
2328
2429 # ! format: off
25- function Max2Tensor {T} (x:: A = nothing ) where {T,A<: Union{Nothing,Number,Vector{<:Number},Matrix{<:Number}} }
26- return new (
30+ function DynamicTensor {T,N} (x:: A = nothing ) where {T,N,A<: Union{Nothing,Number,Array{<:Number}} }
31+ nd = x === nothing ? 0 : ndims (x)
32+ data = (
33+ Ref (x isa Number ? Base. convert (T, x) : zero (T)),
34+ ntuple (
35+ i -> if nd == i
36+ Array {T,i} (x)
37+ else
38+ Array {T,i} (undef, ntuple (_ -> 0 , Val (i))... )
39+ end ,
40+ Val (N)
41+ )... ,
42+ )
43+ return new {T,N,typeof(data)} (
2744 x === nothing ? 0 : ndims (x),
28- x isa Number ? Base. convert (T, x) : zero (T),
29- x isa Vector ? Vector {T} (x) : Vector {T} (undef, 0 ),
30- x isa Matrix ? Matrix {T} (x) : Matrix {T} (undef, 0 , 0 ),
45+ data
3146 )
3247 end
3348 # ! format: on
3449end
35-
36- DE. get_number_type (:: Type{<:Max2Tensor{T}} ) where {T} = T
37-
38- function DE. is_valid (val:: T ) where {Q<: Number ,T<: Max2Tensor{Q} }
39- if val. dims == 0
40- return is_valid (val. scalar)
41- elseif val. dims == 1
42- return is_valid_array (val. vector)
50+ const Max2Tensor{T} = DynamicTensor{T,2 ,Tuple{Base. RefValue{T},Vector{T},Matrix{T}}}
51+ Max2Tensor {T} (x) where {T} = DynamicTensor {T,2} (x)
52+
53+ DE. get_number_type (:: Type{<:DynamicTensor{T}} ) where {T} = T
54+
55+ @generated function DE. is_valid (val:: DynamicTensor{<:Any,N} ) where {N}
56+ quote
57+ @nif ($ N, i -> i == val. dims + 1 , i -> if i == 1
58+ is_valid (val. data[i][])
59+ else
60+ is_valid_array (val. data[i])
61+ end )
4362 end
44- return is_valid_array (val. matrix)
4563end
46-
47- function Base.:(== )(x:: Max2Tensor{T} , y:: Max2Tensor{T} ) where {T}
48- if x. dims != y. dims
49- return false
50- elseif x. dims == 0
51- return x. scalar == y. scalar
52- elseif x. dims == 1
53- return x. vector == y. vector
64+ @generated function Base.:(== )(
65+ x:: DynamicTensor{<:Any,N} , y:: DynamicTensor{<:Any,N}
66+ ) where {N}
67+ quote
68+ x. dims != y. dims && return false
69+ @nif ($ N, i -> i == x. dims + 1 , i -> if i == 1
70+ x. data[i][] == y. data[i][]
71+ else
72+ x. data[i] == y. data[i]
73+ end )
5474 end
55- return x. matrix == y. matrix
5675end
5776
58- function DE. count_scalar_constants (val:: T ) where {BT,T<: Max2Tensor{BT} }
59- if val. dims == 0
60- return 1
61- elseif val. dims == 1
62- return length (val. vector)
77+ @generated function DE. count_scalar_constants (val:: DynamicTensor{<:Any,N} ) where {N}
78+ quote
79+ @nif ($ N, i -> i == val. dims + 1 , i -> i == 1 ? 1 : length (val. data[i]))
6380 end
64- return length (val. matrix)
6581end
6682
67- function DE. pack_scalar_constants! (
68- nvals:: AbstractVector{BT} , idx:: Int64 , val:: Max2Tensor{BT}
69- ) where {BT<: Number }
70- if val. dims == 0
71- nvals[idx] = val. scalar
72- return idx + 1
73- elseif val. dims == 1
74- @view (nvals[idx: (idx + length (val. vector) - 1 )]) .= val. vector
75- return idx + length (val. vector)
83+ @generated function DE. pack_scalar_constants! (
84+ nvals:: AbstractVector{BT} , idx:: Int64 , val:: DynamicTensor{BT,N}
85+ ) where {BT<: Number ,N}
86+ quote
87+ @nif ($ N, i -> i == val. dims + 1 , i -> if i == 1
88+ nvals[idx] = val. data[i][]
89+ idx + 1
90+ else
91+ data = val. data[i]
92+ num = length (data)
93+ copyto! (nvals, idx, @view (data[:]))
94+ idx + num
95+ end )
7696 end
77- @view (nvals[idx: (idx + length (val. matrix) - 1 )]) .= reshape (
78- val. matrix, length (val. matrix)
79- )
80- return idx + length (val. matrix)
8197end
8298
83- function DE. unpack_scalar_constants (
84- nvals:: AbstractVector{BT} , idx:: Int64 , val:: Max2Tensor{BT}
85- ) where {BT<: Number }
86- if val. dims == 0
87- val. scalar = nvals[idx]
88- return idx + 1 , val
89- elseif val. dims == 1
90- val. vector .= @view (nvals[idx: (idx + length (val. vector) - 1 )])
91- return idx + length (val. vector), val
99+ @generated function DE. unpack_scalar_constants (
100+ nvals:: AbstractVector{BT} , idx:: Int64 , val:: DynamicTensor{BT,N}
101+ ) where {BT<: Number ,N}
102+ quote
103+ @nif (
104+ $ N,
105+ i -> i == val. dims + 1 ,
106+ i -> if i == 1
107+ val. data[i][] = nvals[idx]
108+ (idx + 1 , val)
109+ else
110+ data = val. data[i]
111+ num = length (data)
112+ copyto! (data, @view (nvals[idx: (idx + num - 1 )]))
113+ (idx + num, val)
114+ end
115+ )
92116 end
93- reshape (val. matrix, length (val. matrix)) .= @view (
94- nvals[idx: (idx + length (val. matrix) - 1 )]
95- )
96- return idx + length (val. matrix), val
97117end
98118
99- # Declare that `Max2Tensor ` implements `ValueInterface`
100- @implements (ValueInterface, Max2Tensor , [Arguments ()])
119+ # Declare that `DynamicTensor ` implements `ValueInterface`
120+ @implements (ValueInterface, DynamicTensor , [Arguments ()])
101121# Run the interface tests
102122@test Interfaces. test (
103123 ValueInterface,
104- Max2Tensor ,
124+ DynamicTensor ,
105125 [
106- Max2Tensor {Float64} (1.0 ),
107- Max2Tensor {Float64} ([1 , 2 , 3 ]),
108- Max2Tensor {Float64} ([1 2 3 ; 4 5 6 ]),
126+ # up to 1D
127+ DynamicTensor {Float64,1} (1.0 ),
128+ DynamicTensor {Float64,1} ([1 , 2 , 3 ]),
129+ # up to 2D
130+ DynamicTensor {Float64,2} (1.0 ),
131+ DynamicTensor {Float64,2} ([1 , 2 , 3 ]),
132+ DynamicTensor {Float64,2} ([1 2 3 ; 4 5 6 ]),
133+ # up to 3D
134+ DynamicTensor {Float64,3} (1.0 ),
135+ DynamicTensor {Float64,3} ([1 , 2 , 3 ]),
136+ DynamicTensor {Float64,3} ([1 2 3 ; 4 5 6 ]),
137+ DynamicTensor {Float64,3} (rand (1 , 2 , 3 )),
109138 ],
110139)
111140
112141# testing is_valid functions
113- @test is_valid (Max2Tensor {Float64} ())
114- @test ! is_valid (Max2Tensor {Float64} (NaN ))
115- @test is_valid_array ([Max2Tensor {Float64} (1 ), Max2Tensor {Float64} ([1 , 2 , 3 ])])
116- @test ! is_valid_array ([Max2Tensor {Float64} (1 ), Max2Tensor {Float64} ([1 , 2 , NaN ])])
142+ @test is_valid (DynamicTensor {Float64,2 } ())
143+ @test ! is_valid (DynamicTensor {Float64,2 } (NaN ))
144+ @test is_valid_array ([DynamicTensor {Float64,2 } (1 ), DynamicTensor {Float64,2 } ([1 , 2 , 3 ])])
145+ @test ! is_valid_array ([DynamicTensor {Float64,2 } (1 ), DynamicTensor {Float64,2 } ([1 , 2 , NaN ])])
117146
118147# dummy operators
119- q (x:: Max2Tensor{T} ) where {T} = Max2Tensor {T} (x. scalar)
120- a (x:: Max2Tensor{T} , y:: Max2Tensor{T} ) where {T} = Max2Tensor {T} (x. scalar + y. scalar)
148+ q (x:: DynamicTensor{T,N} ) where {T,N} = DynamicTensor {T,N} (x. data[1 ])
149+ function a (x:: DynamicTensor{T,N} , y:: DynamicTensor{T,N} ) where {T,N}
150+ return DynamicTensor {T,N} (x. data[1 ][] + y. data[1 ][])
151+ end
121152
122153operators = OperatorEnum (; binary_operators= [a], unary_operators= [q])
123154@extend_operators (operators, on_type = Max2Tensor{Float64})
0 commit comments