22import torch
33
44from . import workbench
5- from .workbench import input_tensor , to_nchw , to_nhwc
5+ from .workbench import input_tensor , to_nchw , to_nhwc , tensors_match
66
77
88def test_linear ():
@@ -13,7 +13,7 @@ def test_linear():
1313 result = workbench .invoke_test ("linear" , x , dict (weight = weight , bias = bias ))
1414
1515 expected = torch .nn .functional .linear (x , weight , bias )
16- assert torch . allclose (result , expected )
16+ assert tensors_match (result , expected )
1717
1818
1919@pytest .mark .parametrize ("scenario" , ["stride_1_pad_0" , "stride_2_pad_1" , "dilation_2_pad_2" ])
@@ -48,7 +48,7 @@ def test_conv_2d_depthwise(scenario: str, memory_layout: str, batch: str, backen
4848 if memory_layout == "nhwc" :
4949 result = to_nchw (result )
5050
51- assert torch . allclose (result , expected )
51+ assert tensors_match (result , expected )
5252
5353
5454@pytest .mark .parametrize ("scenario" , ["3x3" , "5x5" , "stride2" , "nhwc" ])
@@ -76,7 +76,7 @@ def test_conv_transpose_2d(scenario: str):
7676 if scenario == "nhwc" :
7777 result = to_nchw (result )
7878
79- assert torch . allclose (result , expected , rtol = 1e-2 )
79+ assert tensors_match (result , expected , rtol = 1e-2 )
8080
8181
8282# def test_batch_norm_2d():
@@ -106,7 +106,7 @@ def test_layer_norm():
106106 result = workbench .invoke_test ("layer_norm" , x , dict (weight = weight , bias = bias ))
107107
108108 expected = torch .nn .functional .layer_norm (x , [dim ], weight , bias , eps = 1e-5 )
109- assert torch . allclose (result , expected , atol = 1e-6 )
109+ assert tensors_match (result , expected , atol = 1e-6 )
110110
111111
112112@pytest .mark .parametrize ("backend" , ["cpu" , "vulkan" ])
@@ -133,7 +133,7 @@ def test_window_partition(backend: str):
133133
134134 result = workbench .invoke_test ("sam_window_partition" , x , {}, backend = backend )
135135
136- assert torch . allclose (result , expected )
136+ assert tensors_match (result , expected )
137137
138138
139139@pytest .mark .parametrize ("shift" , [(0 , 2 , - 1 , 0 ), (0 , - 2 , 0 , 3 )])
@@ -147,7 +147,7 @@ def test_roll(shift: tuple[int, int, int, int], backend: str):
147147 params = dict (s0 = shift [3 ], s1 = shift [2 ], s2 = shift [1 ], s3 = shift [0 ])
148148 result = workbench .invoke_test ("roll" , x , {}, params , backend )
149149
150- assert torch . allclose (result , expected )
150+ assert tensors_match (result , expected )
151151
152152
153153@pytest .mark .parametrize ("mode" , ["bilinear" , "bicubic" ])
@@ -169,4 +169,4 @@ def test_interpolate(mode: str, align_corners: bool, size: str, scale: float, ba
169169
170170 params = dict (mode = mode , h = target [0 ], w = target [1 ], align_corners = 1 if align_corners else 0 )
171171 result = workbench .invoke_test ("interpolate" , x , {}, params , backend )
172- assert torch . allclose (result , expected )
172+ assert tensors_match (result , expected )
0 commit comments