Skip to content

Commit 4e1a8d8

Browse files
committed
tests: keep python tests alive, they're still useful sometimes
1 parent b88b6f8 commit 4e1a8d8

File tree

5 files changed

+45
-38
lines changed

5 files changed

+45
-38
lines changed

tests/test_birefnet.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,10 +228,11 @@ def __init__(
228228
drop=drop,
229229
)
230230

231-
self.H = None
232-
self.W = None
231+
self.H: int | None = None
232+
self.W: int | None = None
233233

234234
def forward(self, x, mask_matrix):
235+
assert self.W is not None and self.H is not None, "W and H must be set before forward"
235236
B, L, C = x.shape
236237
H, W = self.H, self.W
237238
assert L == H * W, "input feature has wrong size"
@@ -297,7 +298,7 @@ def test_swin_block():
297298

298299
x = input_tensor(1, 36, 8)
299300
mask = torch.zeros(2, 9, 9).masked_fill(torch.rand(2, 9, 9) > 0.5, -100.0)
300-
state["mask"] = mask
301+
state["mask"] = mask.half()
301302
swin_block.W, swin_block.H = 6, 6
302303
expected = swin_block(x, None)
303304

@@ -421,7 +422,7 @@ def attention_mask(self, H, W):
421422
mask_windows = window_partition(img_mask, self.window_size)
422423
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
423424
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
424-
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
425+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float("-inf"))
425426
attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
426427
return attn_mask
427428

@@ -453,7 +454,7 @@ def forward(self, x, H, W):
453454
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
454455
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
455456
attn_mask = (
456-
attn_mask.masked_fill(attn_mask != 0, float(-100.0))
457+
attn_mask.masked_fill(attn_mask != 0, float("-inf"))
457458
.masked_fill(attn_mask == 0, float(0.0))
458459
.to(x.dtype)
459460
)
@@ -475,8 +476,8 @@ def test_attention_mask():
475476
swin_layer = BasicLayer(8, 2, 2, window_size=window_size)
476477
expected = swin_layer.attention_mask(h, w)
477478

478-
result = torch.zeros_like(expected)
479-
result = workbench.invoke_test("biref_attention_mask", result, {})
479+
x = torch.zeros_like(expected)
480+
result = workbench.invoke_test("biref_attention_mask", x, {})
480481

481482
assert torch.allclose(result, expected)
482483

tests/test_mobile_sam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1325,7 +1325,8 @@ def test_output_upscaling():
13251325
result = workbench.invoke_test("sam_output_upscaling", x, state, nhwc_layout, backend="vulkan")
13261326
result = to_nchw(result)
13271327

1328-
assert torch.allclose(result, expected, atol=1e-4, rtol=1e-2) # fp16 weights
1328+
workbench.print_results(result, expected)
1329+
assert torch.allclose(result, expected, rtol=0.1) # fp16 weights
13291330

13301331

13311332
class MaskDecoder(torch.nn.Module):

tests/test_primitives.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33

44
from . import workbench
5-
from .workbench import to_nchw, to_nhwc
5+
from .workbench import input_tensor, to_nchw, to_nhwc
66

77

88
def test_linear():
@@ -43,56 +43,59 @@ def test_conv_2d_depthwise(scenario: str, memory_layout: str, batch: str, backen
4343
x = to_nhwc(x)
4444
k = k.permute(2, 3, 1, 0)
4545
test_case = f"conv_2d_depthwise_{memory_layout}"
46-
params = dict(stride=stride, pad=pad, dilation=dilate)
46+
params = dict(stride=stride, pad=pad, dilation=dilate, memory_layout=memory_layout)
4747
result = workbench.invoke_test(test_case, x, dict(weight=k), params, backend)
4848
if memory_layout == "nhwc":
4949
result = to_nchw(result)
5050

5151
assert torch.allclose(result, expected)
5252

5353

54-
@pytest.mark.parametrize("scenario", ["3x3", "5x5", "stride2"])
54+
@pytest.mark.parametrize("scenario", ["3x3", "5x5", "stride2", "nhwc"])
5555
def test_conv_transpose_2d(scenario: str):
5656
ksize, stride = {
5757
"3x3": (3, 1),
5858
"5x5": (5, 1),
5959
"stride2": (3, 2),
60-
"nchw": (3, 1),
60+
"nhwc": (3, 1),
6161
}[scenario]
62-
x = torch.arange(2 * 11 * 4 * 5).reshape(2, 11, 4, 5).float()
63-
weight = torch.arange(11 * 2 * ksize * ksize).reshape(11, 2, ksize, ksize).float()
62+
x = input_tensor(2, 11, 4, 5)
63+
weight = input_tensor(11, 2, ksize, ksize)
6464
bias = None
6565
expected = torch.nn.functional.conv_transpose2d(x, weight, bias, stride=stride)
6666

67-
x = to_nhwc(x) # -> [N, H, W, C_in]
67+
if scenario == "nhwc":
68+
x = to_nhwc(x) # -> [N, H, W, C_in]
6869
result = workbench.invoke_test(
6970
"conv_transpose_2d",
7071
x,
7172
dict(weight=weight),
72-
dict(stride=stride),
73+
dict(stride=stride, memory_layout="nhwc" if scenario == "nhwc" else "nchw"),
7374
backend="vulkan",
7475
)
75-
result = to_nchw(result)
76+
if scenario == "nhwc":
77+
result = to_nchw(result)
7678

77-
assert torch.allclose(result, expected)
79+
workbench.print_results(result, expected)
80+
assert torch.allclose(result, expected, rtol=1e-2)
7881

7982

80-
def test_batch_norm_2d():
81-
x = torch.rand(1, 3, 4, 5)
82-
weight = torch.rand(3)
83-
bias = torch.rand(3)
84-
mean = torch.rand(3)
85-
var = torch.arange(1, 4).float()
86-
expected = torch.nn.functional.batch_norm(x, mean, var, weight, bias, eps=1e-5)
83+
# def test_batch_norm_2d():
84+
# x = torch.rand(1, 3, 4, 5)
85+
# weight = torch.rand(3)
86+
# bias = torch.rand(3)
87+
# mean = torch.rand(3)
88+
# var = torch.arange(1, 4).float()
89+
# expected = torch.nn.functional.batch_norm(x, mean, var, weight, bias, eps=1e-5)
8790

88-
x = to_nhwc(x)
91+
# x = to_nhwc(x)
8992

90-
var = (var + 1e-5).sqrt()
91-
state = dict(weight=weight, bias=bias, running_mean=mean, running_var=var)
92-
result = workbench.invoke_test("batch_norm_2d", x, state)
93-
result = to_nchw(result)
93+
# var = (var + 1e-5).sqrt()
94+
# state = dict(weight=weight, bias=bias, running_mean=mean, running_var=var)
95+
# result = workbench.invoke_test("batch_norm_2d", x, state, dict(memory_layout="nhwc"))
96+
# result = to_nchw(result)
9497

95-
assert torch.allclose(result, expected)
98+
# assert torch.allclose(result, expected)
9699

97100

98101
def test_layer_norm():

tests/workbench.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,10 +284,11 @@ DEF(biref_patch_merging)(model_ref m, span<tensor> input, param_dict const& p) {
284284
return {swin::patch_merging(m, input[0], 6, 4)};
285285
}
286286

287-
DEF(biref_attention_mask)(model_ref m, span<tensor> input, param_dict const& p) {
288-
auto dst = span((byte*)input[0]->data, ggml_nbytes(input[0]));
289-
swin::compute_attention_mask(dst, 18, 18, 6);
290-
return {input[0]};
287+
DEF(biref_attention_mask)(model_ref m, span<tensor> /*input*/, param_dict const& p) {
288+
auto mask = swin::create_attention_mask(m, 18, 18, 6);
289+
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
290+
transfer_to_backend(mask);
291+
return {ggml_cast(m, mask.x, GGML_TYPE_F32)};
291292
}
292293

293294
DEF(biref_swin_layer)(model_ref m, span<tensor> input, param_dict const& p) {

tests/workbench.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ctypes
22
from functools import reduce
3+
from typing import Mapping
34
import torch
45
import os
56

@@ -66,7 +67,7 @@ def raw_to_torch_tensor(raw_tensor: RawTensor):
6667
).reshape(shape)
6768

6869

69-
def encode_params(params: dict[str, str | int | float]):
70+
def encode_params(params: Mapping[str, str | int | float]):
7071
raw_params = []
7172
for name, value in params.items():
7273
ptype = 0
@@ -109,7 +110,7 @@ def invoke_test(
109110
test_case: str,
110111
input: torch.Tensor | list[torch.Tensor],
111112
state: dict[str, torch.Tensor],
112-
params: dict[str, str | int | float] = {},
113+
params: Mapping[str, str | int | float] = {},
113114
backend: str = "cpu",
114115
):
115116
input = input if isinstance(input, list) else [input]
@@ -142,7 +143,7 @@ def invoke_test(
142143
return output
143144

144145

145-
def input_tensor(*shape: tuple[int]):
146+
def input_tensor(*shape: int):
146147
end = reduce(lambda x, y: x * y, shape, 1)
147148
return torch.arange(0, end).reshape(*shape) / end
148149

0 commit comments

Comments
 (0)