Skip to content

Commit 59cab46

Browse files
committed
Follow up gh-2728 and extend muting to any GPU with non LTS2 version
1 parent fd71f50 commit 59cab46

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

dpnp/tests/helper.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib.util
2+
from enum import Enum
23
from sys import platform
34

45
import dpctl
@@ -11,6 +12,11 @@
1112
from . import config
1213

1314

15+
class LTS_VERSION(Enum):
16+
V1_3 = "1.3"
17+
V1_6 = "1.6"
18+
19+
1420
def _assert_dtype(a_dt, b_dt, check_only_type_kind=False):
1521
if check_only_type_kind:
1622
assert a_dt.kind == b_dt.kind, f"{a_dt.kind} != {b_dt.kind}"
@@ -475,13 +481,13 @@ def is_lnl(device=None):
475481
return _get_dev_mask(device) == 0x6400
476482

477483

478-
def is_lts_driver(device=None):
484+
def is_lts_driver(version=LTS_VERSION.V1_3, device=None):
479485
"""
480486
Return True if a test is running on a GPU device with LTS driver version,
481487
False otherwise.
482488
"""
483489
dev = dpctl.select_default_device() if device is None else device
484-
return dev.has_aspect_gpu and "1.3" in dev.driver_version
490+
return dev.has_aspect_gpu and version.value in dev.driver_version
485491

486492

487493
def is_ptl(device=None):

dpnp/tests/test_mathematical.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dpnp.dpnp_utils import map_dtype_to_device
2121

2222
from .helper import (
23+
LTS_VERSION,
2324
assert_dtype_allclose,
2425
generate_random_numpy_array,
2526
get_abs_array,
@@ -33,7 +34,7 @@
3334
has_support_aspect16,
3435
has_support_aspect64,
3536
is_intel_numpy,
36-
is_ptl,
37+
is_lts_driver,
3738
numpy_version,
3839
)
3940
from .third_party.cupy import testing
@@ -218,7 +219,7 @@ def _get_exp_array(self, a, axis, dtype):
218219
@pytest.mark.parametrize("axis", [None, 2, -1])
219220
@pytest.mark.parametrize("include_initial", [True, False])
220221
def test_basic(self, dtype, axis, include_initial):
221-
if axis is None and is_ptl():
222+
if axis is None and not is_lts_driver(version=LTS_VERSION.V1_6):
222223
pytest.skip("due to SAT-8336")
223224

224225
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)
@@ -238,7 +239,7 @@ def test_basic(self, dtype, axis, include_initial):
238239
@pytest.mark.parametrize("axis", [None, 2, -1])
239240
@pytest.mark.parametrize("include_initial", [True, False])
240241
def test_include_initial(self, dtype, axis, include_initial):
241-
if axis is None and is_ptl():
242+
if axis is None and not is_lts_driver(version=LTS_VERSION.V1_6):
242243
pytest.skip("due to SAT-8336")
243244

244245
a = dpnp.ones((3, 4, 5, 6, 7), dtype=dtype)

0 commit comments

Comments
 (0)