From a55fb95174224d8e73e4c18c3a1edb92867d5ced Mon Sep 17 00:00:00 2001
From: ckvermaAI
Date: Wed, 18 Jun 2025 10:56:48 +0300
Subject: [PATCH] Update unit tests for HPU

---
 tests/test_modules.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/tests/test_modules.py b/tests/test_modules.py
index e35afb214..8946522d3 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -284,7 +284,8 @@ def test_linear_kbit_fp32_bias(device, module):
 
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("module", module_dict.values(), ids=module_dict.keys())
-def test_kbit_backprop(device, module):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_kbit_backprop(device, module, dtype):
     b = 16
     dim1 = 36
     dim2 = 84
@@ -298,24 +299,28 @@ def test_kbit_backprop(device, module):
     kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 128)])
 
-    if device == "hpu" and isinstance(kbit[1], bnb.nn.Linear4bit) and kbit[1].weight.quant_type == "fp4":
-        pytest.skip("FP4 is not supported on HPU")
+    if (
+        device == "hpu"
+        and isinstance(kbit[1], bnb.nn.Linear4bit)
+        and not is_supported_on_hpu(kbit[1].weight.quant_type, dtype)
+    ):
+        pytest.skip("This configuration not supported on HPU")
 
     kbit[0].weight.detach().copy_(ref[0].weight)
     kbit[1].weight.detach().copy_(ref[1].weight)
     kbit[0].bias.detach().copy_(ref[0].bias)
     kbit[1].bias.detach().copy_(ref[1].bias)
     kbit[1].weight.requires_grad_(False)
-    ref = ref.half().to(device)
-    kbit = kbit.half().to(device)
-    kbit = kbit.half().to(device)
+    ref = ref.to(device=device, dtype=dtype)
+    kbit = kbit.to(device=device, dtype=dtype)
+    kbit = kbit.to(device=device, dtype=dtype)
     errs1 = []
     errs2 = []
     relerrs1 = []
     relerrs2 = []
     for i in range(100):
-        batch = torch.randn(b, dim1, device=device, dtype=torch.float16)
+        batch = torch.randn(b, dim1, device=device, dtype=dtype)
         out1 = ref(batch)
         out2 = kbit(batch)
         out1.mean().backward()
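
The new skip condition calls is_supported_on_hpu(), a helper that is not part of this
patch and is presumably defined in the shared test helpers. A minimal sketch of what
such a helper could look like, assuming the only rule carried over from the old inline
check is that FP4 quantization is unavailable on HPU; the real helper's signature and
support matrix may differ:

    import torch

    def is_supported_on_hpu(quant_type: str, dtype: torch.dtype) -> bool:
        # FP4 quantization is not available on HPU; this is the one rule
        # preserved from the original inline check that this patch replaces.
        if quant_type == "fp4":
            return False
        # Assumption: the dtype gate mirrors the new parametrization, so only
        # the two dtypes the test exercises are treated as supported.
        return dtype in (torch.float16, torch.bfloat16)

With a helper of this shape, the test still skips all fp4 variants on HPU regardless of
dtype (matching the old behavior), while the new dtype axis can be gated in one place
instead of being duplicated across tests.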