 from torch import Tensor
 
 from . import workbench
-from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm
+from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm, tensors_match
 
 torch.set_printoptions(precision=2, linewidth=100, sci_mode=False)
 
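Aside, not part of this commit: `tensors_match` is imported from the workbench helper module and replaces the bare `torch.allclose` calls throughout this file; its implementation is not shown here. A minimal sketch of what such a helper might look like, assuming it simply wraps `torch.allclose` and reports the error on mismatch (the real helper in `workbench.py` may differ):

```python
# Hypothetical sketch only -- the actual tensors_match in .workbench may behave differently.
import torch
from torch import Tensor

def tensors_match(result: Tensor, expected: Tensor, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Compare two tensors like torch.allclose, but print the largest error on failure."""
    ok = torch.allclose(result, expected, rtol=rtol, atol=atol)
    if not ok:
        diff = (result - expected).abs()
        print(f"max abs diff: {diff.max():.4g}, mean abs diff: {diff.mean():.4g}")
    return ok
```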
@@ -53,7 +53,7 @@ def test_conv_2d_batch_norm(bias: bool):
     result = workbench.invoke_test("sam_conv_2d_batch_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PatchEmbed(torch.nn.Module):
@@ -98,7 +98,7 @@ def test_patch_embed():
     result = workbench.invoke_test("sam_patch_embed", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class LayerNorm2d(torch.nn.Module):
@@ -130,7 +130,7 @@ def test_layer_norm_2d():
     result = workbench.invoke_test("layer_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class MBConv(torch.nn.Module):
@@ -193,7 +193,7 @@ def test_mb_conv():
     result = to_nchw(result)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class PatchMerging(torch.nn.Module):
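Aside, not part of this commit: the precision comments in these tests refer to ggml's GELU, which uses a tanh-based approximation plus an fp16 look-up table instead of the exact erf formulation. An illustrative way to see the size of the tanh-approximation part of that error in PyTorch (the fp16 table adds a little more on top):

```python
import torch

# Compare exact GELU with the tanh approximation that ggml's gelu is based on.
x = torch.linspace(-6, 6, steps=2001)
exact = torch.nn.functional.gelu(x)
approx = torch.nn.functional.gelu(x, approximate="tanh")
print((exact - approx).abs().max())  # small but nonzero -- hence the relaxed atol=0.02 here
```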
@@ -244,7 +244,7 @@ def test_patch_merging():
     result = result.transpose(1, 2).reshape_as(expected)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class Mlp(torch.nn.Module):
@@ -288,7 +288,7 @@ def test_mlp():
     result = workbench.invoke_test("sam_mlp", x, state)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class AttentionRelBias(torch.nn.Module):
@@ -370,8 +370,8 @@ def forward(self, x): # x (B,N,C)
         x = self.proj(x)
         return x
 
-
-def test_attention_rel_bias():
+@pytest.mark.parametrize("attn", ["default", "flash_attn"])
+def test_attention_rel_bias(attn: str):
     attention = AttentionRelBias(4, 2, num_heads=2, attn_ratio=1, resolution=(3, 3))
     state = workbench.randomize(attention.state_dict())
     attention.load_state_dict(state)
@@ -381,9 +381,9 @@ def test_attention_rel_bias():
     expected = attention(x)
 
     state["attention_biases_indexed"] = state["attention_biases"][:, attention.attention_bias_idxs]
-    result = workbench.invoke_test("sam_attention_rel_bias", x, state)
+    result = workbench.invoke_test("sam_attention_rel_bias", x, state, {"attn": attn})
 
-    assert torch.allclose(result, expected, atol=0.001)
+    assert tensors_match(result, expected, atol=0.001)
 
 
 class TinyViTBlock(torch.nn.Module):
@@ -495,7 +495,7 @@ def test_tiny_vit_block():
     state = convert_to_nhwc(state)
     result = workbench.invoke_test("sam_tiny_vit_block", x, state, nhwc_layout)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class ConvLayer(torch.nn.Module):
@@ -787,7 +787,7 @@ def test_tiny_vit():
     # result = torch.zeros_like(expected).contiguous()
     # result = workbench.invoke_test("sam_tiny_vit", x, state)
 
-    # assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    # assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 #
@@ -835,7 +835,7 @@ def test_position_embedding_random():
 
     result = workbench.invoke_test("sam_position_embedding_random", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PromptEncoder(torch.nn.Module):
@@ -951,7 +951,7 @@ def test_prompt_encoder_points():
     points = torch.cat([points, -torch.ones(1, 1, 2)], dim=1)
     result = workbench.invoke_test("sam_embed_points", points, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def test_prompt_encoder_box():
@@ -970,7 +970,7 @@ def test_prompt_encoder_box():
 
     result = workbench.invoke_test("sam_embed_box", boxes, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 #
@@ -1046,7 +1046,7 @@ def test_attention():
     state["input_v"] = v
     result = workbench.invoke_test("sam_attention", q, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class MLPBlock(torch.nn.Module):
@@ -1155,8 +1155,8 @@ def test_two_way_attention_block(mode):
         "sam_two_way_attention_block", queries, state, {"mode": mode}
     )
 
-    assert torch.allclose(result_queries, expected_queries)
-    assert torch.allclose(result_keys, expected_keys)
+    assert tensors_match(result_queries, expected_queries)
+    assert tensors_match(result_keys, expected_keys)
 
 
 class TwoWayTransformer(torch.nn.Module):
@@ -1257,8 +1257,8 @@ def test_two_way_transformer():
         "sam_two_way_transformer", image_embedding, state, nhwc_layout
     )
 
-    assert torch.allclose(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
-    assert torch.allclose(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
 
 
 class HypernetworkMLP(torch.nn.Module):
@@ -1297,7 +1297,7 @@ def test_hypernetwork_mlp():
 
     result = workbench.invoke_test("sam_hypernetwork_mlp", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def output_upscaling(transformer_dim: int, activation=torch.nn.GELU):
@@ -1325,8 +1325,7 @@ def test_output_upscaling():
     result = workbench.invoke_test("sam_output_upscaling", x, state, nhwc_layout, backend="vulkan")
     result = to_nchw(result)
 
-    workbench.print_results(result, expected)
-    assert torch.allclose(result, expected, rtol=0.1)  # fp16 weights
+    assert tensors_match(result, expected, rtol=0.1)  # fp16 weights
 
 
 class MaskDecoder(torch.nn.Module):
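Aside, not part of this commit: the `# fp16 weights` comment is why the upscaling test keeps a loose rtol=0.1; weights stored in fp16 carry roughly three decimal digits, and the error accumulates through the transposed convolutions. An illustrative snippet showing the per-element round-trip error before any accumulation:

```python
import torch

# Values in [0.5, 1.5) stay well inside the fp16 normal range, so the maximum
# relative round-trip error is bounded by half an ulp, about 2**-11.
w = torch.rand(4096) + 0.5
w_fp16 = w.half().float()
rel_err = ((w - w_fp16).abs() / w).max()
print(f"max relative fp16 round-trip error: {rel_err:.1e}")  # about 5e-4
```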
@@ -1465,5 +1464,5 @@ def test_predict_masks():
         "sam_predict_masks", image_embeddings, state, nhwc_layout, backend="vulkan"
     )
 
-    assert torch.allclose(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(result_iou_pred, iou_pred, rtol=1e-2)
+    assert tensors_match(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
+    assert tensors_match(result_iou_pred, iou_pred, rtol=1e-2)