11
2- using DynamicExpressions
32using DynamicExpressions:
43 DynamicExpressions as DE,
4+ ValueInterface,
55 Node,
66 @extend_operators ,
77 OperatorEnum,
@@ -12,6 +12,8 @@ using DynamicExpressions:
1212 is_valid,
1313 is_valid_array
1414
15+ using Interfaces: Interfaces, @implements , Arguments
16+
1517# Max2Tensor (Tensor with a maximum of 3 dimensions) - struct that contains all three datatypes
1618mutable struct Max2Tensor{T}
1719 dims:: UInt8 # number of dimmentions
@@ -31,6 +33,8 @@ mutable struct Max2Tensor{T}
3133 # ! format: on
3234end
3335
36+ DE. get_number_type (:: Type{<:Max2Tensor{T}} ) where {T} = T
37+
3438function DE. is_valid (val:: T ) where {Q<: Number ,T<: Max2Tensor{Q} }
3539 if val. dims == 0
3640 return is_valid (val. scalar)
@@ -45,7 +49,7 @@ function Base.:(==)(x::Max2Tensor{T}, y::Max2Tensor{T}) where {T}
4549 return false
4650 elseif x. dims == 0
4751 return x. scalar == y. scalar
48- elseif val . dims == 1
52+ elseif x . dims == 1
4953 return x. vector == y. vector
5054 end
5155 return x. matrix == y. matrix
@@ -61,8 +65,8 @@ function DE.count_scalar_constants(val::T) where {BT,T<:Max2Tensor{BT}}
6165end
6266
6367function DE. pack_scalar_constants! (
64- nvals:: AbstractVector{BT} , idx:: Int64 , val:: T
65- ) where {BT<: Number ,T <: Max2Tensor{BT} }
68+ nvals:: AbstractVector{BT} , idx:: Int64 , val:: Max2Tensor{BT}
69+ ) where {BT<: Number }
6670 if val. dims == 0
6771 nvals[idx] = val. scalar
6872 return idx + 1
@@ -77,8 +81,8 @@ function DE.pack_scalar_constants!(
7781end
7882
7983function DE. unpack_scalar_constants (
80- nvals:: AbstractVector{BT} , idx:: Int64 , val:: T
81- ) where {BT<: Number ,T <: Max2Tensor{BT} }
84+ nvals:: AbstractVector{BT} , idx:: Int64 , val:: Max2Tensor{BT}
85+ ) where {BT<: Number }
8286 if val. dims == 0
8387 val. scalar = nvals[idx]
8488 return idx + 1 , val
@@ -92,6 +96,19 @@ function DE.unpack_scalar_constants(
9296 return idx + length (val. matrix), val
9397end
9498
99+ # Declare that `Max2Tensor` implements `ValueInterface`
100+ @implements (ValueInterface, Max2Tensor, [Arguments ()])
101+ # Run the interface tests
102+ @test Interfaces. test (
103+ ValueInterface,
104+ Max2Tensor,
105+ [
106+ Max2Tensor {Float64} (1.0 ),
107+ Max2Tensor {Float64} ([1 , 2 , 3 ]),
108+ Max2Tensor {Float64} ([1 2 3 ; 4 5 6 ]),
109+ ],
110+ )
111+
95112# testing is_valid functions
96113@test is_valid (Max2Tensor {Float64} ())
97114@test ! is_valid (Max2Tensor {Float64} (NaN ))
0 commit comments