4545def test_allocate_usm_ndarray (shape , usm_type ):
4646 q = get_queue_or_skip ()
4747 X = dpt .usm_ndarray (
48- shape , dtype = "d " , buffer = usm_type , buffer_ctor_kwargs = {"queue" : q }
48+ shape , dtype = "i8 " , buffer = usm_type , buffer_ctor_kwargs = {"queue" : q }
4949 )
50- Xnp = np .ndarray (shape , dtype = "d " )
50+ Xnp = np .ndarray (shape , dtype = "i8 " )
5151 assert X .usm_type == usm_type
5252 assert X .sycl_context == q .sycl_context
5353 assert X .sycl_device == q .sycl_device
@@ -57,13 +57,17 @@ def test_allocate_usm_ndarray(shape, usm_type):
5757
5858
5959def test_usm_ndarray_flags ():
60- assert dpt .usm_ndarray ((5 ,)).flags .fc
61- assert dpt .usm_ndarray ((5 , 2 )).flags .c_contiguous
62- assert dpt .usm_ndarray ((5 , 2 ), order = "F" ).flags .f_contiguous
63- assert dpt .usm_ndarray ((5 , 1 , 2 ), order = "F" ).flags .f_contiguous
64- assert dpt .usm_ndarray ((5 , 1 , 2 ), strides = (2 , 0 , 1 )).flags .c_contiguous
65- assert dpt .usm_ndarray ((5 , 1 , 2 ), strides = (1 , 0 , 5 )).flags .f_contiguous
66- assert dpt .usm_ndarray ((5 , 1 , 1 ), strides = (1 , 0 , 1 )).flags .fc
60+ assert dpt .usm_ndarray ((5 ,), dtype = "i4" ).flags .fc
61+ assert dpt .usm_ndarray ((5 , 2 ), dtype = "i4" ).flags .c_contiguous
62+ assert dpt .usm_ndarray ((5 , 2 ), dtype = "i4" , order = "F" ).flags .f_contiguous
63+ assert dpt .usm_ndarray ((5 , 1 , 2 ), dtype = "i4" , order = "F" ).flags .f_contiguous
64+ assert dpt .usm_ndarray (
65+ (5 , 1 , 2 ), dtype = "i4" , strides = (2 , 0 , 1 )
66+ ).flags .c_contiguous
67+ assert dpt .usm_ndarray (
68+ (5 , 1 , 2 ), dtype = "i4" , strides = (1 , 0 , 5 )
69+ ).flags .f_contiguous
70+ assert dpt .usm_ndarray ((5 , 1 , 1 ), dtype = "i4" , strides = (1 , 0 , 1 )).flags .fc
6771
6872
6973@pytest .mark .parametrize (
@@ -88,6 +92,8 @@ def test_usm_ndarray_flags():
8892 ],
8993)
9094def test_dtypes (dtype ):
95+ q = get_queue_or_skip ()
96+ skip_if_dtype_not_supported (dtype , q )
9197 Xusm = dpt .usm_ndarray ((1 ,), dtype = dtype )
9298 assert Xusm .itemsize == dpt .dtype (dtype ).itemsize
9399 expected_fmt = (dpt .dtype (dtype ).str )[1 :]
@@ -169,15 +175,15 @@ def test_copy_scalar_with_method(method, shape, dtype):
169175@pytest .mark .parametrize ("func" , [bool , float , int , complex ])
170176@pytest .mark .parametrize ("shape" , [(2 ,), (1 , 2 ), (3 , 4 , 5 ), (0 ,)])
171177def test_copy_scalar_invalid_shape (func , shape ):
172- X = dpt .usm_ndarray (shape )
178+ X = dpt .usm_ndarray (shape , dtype = "i8" )
173179 with pytest .raises (ValueError ):
174180 func (X )
175181
176182
177183def test_index_noninteger ():
178184 import operator
179185
180- X = dpt .usm_ndarray (1 , "d " )
186+ X = dpt .usm_ndarray (1 , "f4 " )
181187 with pytest .raises (IndexError ):
182188 operator .index (X )
183189
@@ -283,7 +289,7 @@ def test_slice_suai(usm_type):
283289
284290
285291def test_slicing_basic ():
286- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " )
292+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " )
287293 Xusm [None ]
288294 Xusm [...]
289295 Xusm [8 ]
@@ -318,20 +324,20 @@ def test_ctor_invalid_order():
318324
319325
320326def test_ctor_buffer_kwarg ():
321- dpt .usm_ndarray (10 , buffer = b"device" )
327+ dpt .usm_ndarray (10 , dtype = "i8" , buffer = b"device" )
322328 with pytest .raises (ValueError ):
323329 dpt .usm_ndarray (10 , buffer = "invalid_param" )
324- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " )
330+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " )
325331 X2 = dpt .usm_ndarray (Xusm .shape , buffer = Xusm , dtype = Xusm .dtype )
326332 assert np .array_equal (
327333 Xusm .usm_data .copy_to_host (), X2 .usm_data .copy_to_host ()
328334 )
329335 with pytest .raises (ValueError ):
330- dpt .usm_ndarray (10 , buffer = dict ())
336+ dpt .usm_ndarray (10 , dtype = "i4" , buffer = dict ())
331337
332338
333339def test_usm_ndarray_props ():
334- Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c16 " , order = "F" )
340+ Xusm = dpt .usm_ndarray ((10 , 5 ), dtype = "c8 " , order = "F" )
335341 Xusm .ndim
336342 repr (Xusm )
337343 Xusm .flags
@@ -348,7 +354,7 @@ def test_usm_ndarray_props():
348354
349355
350356def test_datapi_device ():
351- X = dpt .usm_ndarray (1 )
357+ X = dpt .usm_ndarray (1 , dtype = "i4" )
352358 dev_t = type (X .device )
353359 with pytest .raises (TypeError ):
354360 dev_t ()
@@ -387,7 +393,7 @@ def _pyx_capi_fnptr_to_callable(
387393
388394
389395def test_pyx_capi_get_data ():
390- X = dpt .usm_ndarray (17 )[1 ::2 ]
396+ X = dpt .usm_ndarray (17 , dtype = "i8" )[1 ::2 ]
391397 get_data_fn = _pyx_capi_fnptr_to_callable (
392398 X ,
393399 "UsmNDArray_GetData" ,
@@ -400,7 +406,7 @@ def test_pyx_capi_get_data():
400406
401407
402408def test_pyx_capi_get_shape ():
403- X = dpt .usm_ndarray (17 )[1 ::2 ]
409+ X = dpt .usm_ndarray (17 , dtype = "u4" )[1 ::2 ]
404410 get_shape_fn = _pyx_capi_fnptr_to_callable (
405411 X ,
406412 "UsmNDArray_GetShape" ,
@@ -413,7 +419,7 @@ def test_pyx_capi_get_shape():
413419
414420
415421def test_pyx_capi_get_strides ():
416- X = dpt .usm_ndarray (17 )[1 ::2 ]
422+ X = dpt .usm_ndarray (17 , dtype = "f4" )[1 ::2 ]
417423 get_strides_fn = _pyx_capi_fnptr_to_callable (
418424 X ,
419425 "UsmNDArray_GetStrides" ,
@@ -429,7 +435,7 @@ def test_pyx_capi_get_strides():
429435
430436
431437def test_pyx_capi_get_ndim ():
432- X = dpt .usm_ndarray (17 )[1 ::2 ]
438+ X = dpt .usm_ndarray (17 , dtype = "?" )[1 ::2 ]
433439 get_ndim_fn = _pyx_capi_fnptr_to_callable (
434440 X ,
435441 "UsmNDArray_GetNDim" ,
@@ -440,7 +446,7 @@ def test_pyx_capi_get_ndim():
440446
441447
442448def test_pyx_capi_get_typenum ():
443- X = dpt .usm_ndarray (17 )[1 ::2 ]
449+ X = dpt .usm_ndarray (17 , dtype = "c8" )[1 ::2 ]
444450 get_typenum_fn = _pyx_capi_fnptr_to_callable (
445451 X ,
446452 "UsmNDArray_GetTypenum" ,
@@ -453,7 +459,7 @@ def test_pyx_capi_get_typenum():
453459
454460
455461def test_pyx_capi_get_elemsize ():
456- X = dpt .usm_ndarray (17 )[1 ::2 ]
462+ X = dpt .usm_ndarray (17 , dtype = "u8" )[1 ::2 ]
457463 get_elemsize_fn = _pyx_capi_fnptr_to_callable (
458464 X ,
459465 "UsmNDArray_GetElementSize" ,
@@ -466,7 +472,7 @@ def test_pyx_capi_get_elemsize():
466472
467473
468474def test_pyx_capi_get_flags ():
469- X = dpt .usm_ndarray (17 )[1 ::2 ]
475+ X = dpt .usm_ndarray (17 , dtype = "i8" )[1 ::2 ]
470476 get_flags_fn = _pyx_capi_fnptr_to_callable (
471477 X ,
472478 "UsmNDArray_GetFlags" ,
@@ -478,7 +484,7 @@ def test_pyx_capi_get_flags():
478484
479485
480486def test_pyx_capi_get_offset ():
481- X = dpt .usm_ndarray (17 )[1 ::2 ]
487+ X = dpt .usm_ndarray (17 , dtype = "u2" )[1 ::2 ]
482488 get_offset_fn = _pyx_capi_fnptr_to_callable (
483489 X ,
484490 "UsmNDArray_GetOffset" ,
@@ -491,7 +497,7 @@ def test_pyx_capi_get_offset():
491497
492498
493499def test_pyx_capi_get_queue_ref ():
494- X = dpt .usm_ndarray (17 )[1 ::2 ]
500+ X = dpt .usm_ndarray (17 , dtype = "i2" )[1 ::2 ]
495501 get_queue_ref_fn = _pyx_capi_fnptr_to_callable (
496502 X ,
497503 "UsmNDArray_GetQueueRef" ,
@@ -521,7 +527,7 @@ def _pyx_capi_int(X, pyx_capi_name, caps_name=b"int", val_restype=ctypes.c_int):
521527
522528
523529def test_pyx_capi_check_constants ():
524- X = dpt .usm_ndarray (17 )[1 ::2 ]
530+ X = dpt .usm_ndarray (17 , dtype = "i1" )[1 ::2 ]
525531 cc_flag = _pyx_capi_int (X , "USM_ARRAY_C_CONTIGUOUS" )
526532 assert cc_flag > 0 and 0 == (cc_flag & (cc_flag - 1 ))
527533 fc_flag = _pyx_capi_int (X , "USM_ARRAY_F_CONTIGUOUS" )
@@ -598,6 +604,7 @@ def test_pyx_capi_check_constants():
598604@pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
599605def test_tofrom_numpy (shape , dtype , usm_type ):
600606 q = get_queue_or_skip ()
607+ skip_if_dtype_not_supported (dtype , q )
601608 Xnp = np .zeros (shape , dtype = dtype )
602609 Xusm = dpt .from_numpy (Xnp , usm_type = usm_type , sycl_queue = q )
603610 Ynp = np .ones (shape , dtype = dtype )
@@ -733,7 +740,7 @@ def relaxed_strides_equal(st1, st2, sh):
733740 4 ,
734741 5 ,
735742 )
736- X = dpt .usm_ndarray (sh_s , dtype = "d " )
743+ X = dpt .usm_ndarray (sh_s , dtype = "i8 " )
737744 X .shape = sh_f
738745 assert X .shape == sh_f
739746 assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
@@ -750,27 +757,27 @@ def relaxed_strides_equal(st1, st2, sh):
750757 4 ,
751758 5 ,
752759 )
753- X = dpt .usm_ndarray (sh_s , dtype = "d " , order = "C" )
760+ X = dpt .usm_ndarray (sh_s , dtype = "u4 " , order = "C" )
754761 X .shape = sh_f
755762 assert X .shape == sh_f
756763 assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
757764
758765 sh_s = (2 , 3 , 4 , 5 )
759766 sh_f = (4 , 3 , 2 , 5 )
760- X = dpt .usm_ndarray (sh_s , dtype = "d " )
767+ X = dpt .usm_ndarray (sh_s , dtype = "f4 " )
761768 X .shape = sh_f
762769 assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
763770
764771 sh_s = (2 , 3 , 4 , 5 )
765772 sh_f = (4 , 3 , 1 , 2 , 5 )
766- X = dpt .usm_ndarray (sh_s , dtype = "d " )
773+ X = dpt .usm_ndarray (sh_s , dtype = "? " )
767774 X .shape = sh_f
768775 assert relaxed_strides_equal (X .strides , cc_strides (sh_f ), sh_f )
769776
770- X = dpt .usm_ndarray (sh_s , dtype = "d " )
777+ X = dpt .usm_ndarray (sh_s , dtype = "u4 " )
771778 with pytest .raises (TypeError ):
772779 X .shape = "abcbe"
773- X = dpt .usm_ndarray ((4 , 4 ), dtype = "d " )[::2 , ::2 ]
780+ X = dpt .usm_ndarray ((4 , 4 ), dtype = "u1 " )[::2 , ::2 ]
774781 with pytest .raises (AttributeError ):
775782 X .shape = (4 ,)
776783 X = dpt .usm_ndarray ((0 ,), dtype = "i4" )
@@ -814,7 +821,7 @@ def test_dlpack():
814821
815822
816823def test_to_device ():
817- X = dpt .usm_ndarray (1 , "d " )
824+ X = dpt .usm_ndarray (1 , "f4 " )
818825 for dev in dpctl .get_devices ():
819826 if dev .default_selector_score > 0 :
820827 Y = X .to_device (dev )
@@ -900,7 +907,7 @@ def test_reshape():
900907 W = dpt .reshape (Z , (- 1 ,), order = "C" )
901908 assert W .shape == (Z .size ,)
902909
903- X = dpt .usm_ndarray ((1 ,))
910+ X = dpt .usm_ndarray ((1 ,), dtype = "i8" )
904911 Y = dpt .reshape (X , X .shape )
905912 assert Y .flags == X .flags
906913
@@ -970,7 +977,9 @@ def test_real_imag_views():
970977 _all_dtypes ,
971978)
972979def test_zeros (dtype ):
973- X = dpt .zeros (10 , dtype = dtype )
980+ q = get_queue_or_skip ()
981+ skip_if_dtype_not_supported (dtype , q )
982+ X = dpt .zeros (10 , dtype = dtype , sycl_queue = q )
974983 assert np .array_equal (dpt .asnumpy (X ), np .zeros (10 , dtype = dtype ))
975984
976985
@@ -1197,6 +1206,7 @@ def test_linspace_fp_max(dtype):
11971206)
11981207def test_empty_like (dt , usm_kind ):
11991208 q = get_queue_or_skip ()
1209+ skip_if_dtype_not_supported (dt , q )
12001210
12011211 X = dpt .empty ((4 , 5 ), dtype = dt , usm_type = usm_kind , sycl_queue = q )
12021212 Y = dpt .empty_like (X )
@@ -1232,6 +1242,7 @@ def test_empty_unexpected_data_type():
12321242)
12331243def test_zeros_like (dt , usm_kind ):
12341244 q = get_queue_or_skip ()
1245+ skip_if_dtype_not_supported (dt , q )
12351246
12361247 X = dpt .empty ((4 , 5 ), dtype = dt , usm_type = usm_kind , sycl_queue = q )
12371248 Y = dpt .zeros_like (X )
0 commit comments