Skip to content

Commit 1b1a674

Browse files
committed
test: full ValueInterface for Max2Tensor
1 parent e2e9b21 commit 1b1a674

File tree

1 file changed

+23
-6
lines changed

1 file changed

+23
-6
lines changed

test/test_non_number_eval_tree_array.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

2-
using DynamicExpressions
32
using DynamicExpressions:
43
DynamicExpressions as DE,
4+
ValueInterface,
55
Node,
66
@extend_operators,
77
OperatorEnum,
@@ -12,6 +12,8 @@ using DynamicExpressions:
1212
is_valid,
1313
is_valid_array
1414

15+
using Interfaces: Interfaces, @implements, Arguments
16+
1517
# Max2Tensor (Tensor with a maximum of 3 dimensions) - struct that contains all three datatypes
1618
mutable struct Max2Tensor{T}
1719
dims::UInt8 # number of dimmentions
@@ -31,6 +33,8 @@ mutable struct Max2Tensor{T}
3133
#! format: on
3234
end
3335

36+
DE.get_number_type(::Type{<:Max2Tensor{T}}) where {T} = T
37+
3438
function DE.is_valid(val::T) where {Q<:Number,T<:Max2Tensor{Q}}
3539
if val.dims == 0
3640
return is_valid(val.scalar)
@@ -45,7 +49,7 @@ function Base.:(==)(x::Max2Tensor{T}, y::Max2Tensor{T}) where {T}
4549
return false
4650
elseif x.dims == 0
4751
return x.scalar == y.scalar
48-
elseif val.dims == 1
52+
elseif x.dims == 1
4953
return x.vector == y.vector
5054
end
5155
return x.matrix == y.matrix
@@ -61,8 +65,8 @@ function DE.count_scalar_constants(val::T) where {BT,T<:Max2Tensor{BT}}
6165
end
6266

6367
function DE.pack_scalar_constants!(
64-
nvals::AbstractVector{BT}, idx::Int64, val::T
65-
) where {BT<:Number,T<:Max2Tensor{BT}}
68+
nvals::AbstractVector{BT}, idx::Int64, val::Max2Tensor{BT}
69+
) where {BT<:Number}
6670
if val.dims == 0
6771
nvals[idx] = val.scalar
6872
return idx + 1
@@ -77,8 +81,8 @@ function DE.pack_scalar_constants!(
7781
end
7882

7983
function DE.unpack_scalar_constants(
80-
nvals::AbstractVector{BT}, idx::Int64, val::T
81-
) where {BT<:Number,T<:Max2Tensor{BT}}
84+
nvals::AbstractVector{BT}, idx::Int64, val::Max2Tensor{BT}
85+
) where {BT<:Number}
8286
if val.dims == 0
8387
val.scalar = nvals[idx]
8488
return idx + 1, val
@@ -92,6 +96,19 @@ function DE.unpack_scalar_constants(
9296
return idx + length(val.matrix), val
9397
end
9498

99+
# Declare that `Max2Tensor` implements `ValueInterface`
100+
@implements(ValueInterface, Max2Tensor, [Arguments()])
101+
# Run the interface tests
102+
@test Interfaces.test(
103+
ValueInterface,
104+
Max2Tensor,
105+
[
106+
Max2Tensor{Float64}(1.0),
107+
Max2Tensor{Float64}([1, 2, 3]),
108+
Max2Tensor{Float64}([1 2 3; 4 5 6]),
109+
],
110+
)
111+
95112
# testing is_valid functions
96113
@test is_valid(Max2Tensor{Float64}())
97114
@test !is_valid(Max2Tensor{Float64}(NaN))

0 commit comments

Comments
 (0)