@@ -373,7 +373,11 @@ def test_datapi_device():
373373
374374
375375def _pyx_capi_fnptr_to_callable (
376- X , pyx_capi_name , caps_name , fn_restype = ctypes .c_void_p
376+ X ,
377+ pyx_capi_name ,
378+ caps_name ,
379+ fn_restype = ctypes .c_void_p ,
380+ fn_argtypes = (ctypes .py_object ,),
377381):
378382 import sys
379383
@@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
388392 cap_ptr_fn .restype = ctypes .c_void_p
389393 cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
390394 fn_ptr = cap_ptr_fn (cap , caps_name )
391- callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , ctypes . py_object )
395+ callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , * fn_argtypes )
392396 return callable_maker_ptr (fn_ptr )
393397
394398
@@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
399403 "UsmNDArray_GetData" ,
400404 b"char *(struct PyUSMArrayObject *)" ,
401405 fn_restype = ctypes .c_void_p ,
406+ fn_argtypes = (ctypes .py_object ,),
402407 )
403408 r1 = get_data_fn (X )
404409 sua_iface = X .__sycl_usm_array_interface__
@@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
412417 "UsmNDArray_GetShape" ,
413418 b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
414419 fn_restype = ctypes .c_void_p ,
420+ fn_argtypes = (ctypes .py_object ,),
415421 )
416422 c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
417423 shape0 = ctypes .cast (get_shape_fn (X ), c_longlong_p ).contents .value
@@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
425431 "UsmNDArray_GetStrides" ,
426432 b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
427433 fn_restype = ctypes .c_void_p ,
434+ fn_argtypes = (ctypes .py_object ,),
428435 )
429436 c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
430437 strides0_p = get_strides_fn (X )
@@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
441448 "UsmNDArray_GetNDim" ,
442449 b"int (struct PyUSMArrayObject *)" ,
443450 fn_restype = ctypes .c_int ,
451+ fn_argtypes = (ctypes .py_object ,),
444452 )
445453 assert get_ndim_fn (X ) == X .ndim
446454
@@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
452460 "UsmNDArray_GetTypenum" ,
453461 b"int (struct PyUSMArrayObject *)" ,
454462 fn_restype = ctypes .c_int ,
463+ fn_argtypes = (ctypes .py_object ,),
455464 )
456465 typenum = get_typenum_fn (X )
457466 assert type (typenum ) is int
@@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
465474 "UsmNDArray_GetElementSize" ,
466475 b"int (struct PyUSMArrayObject *)" ,
467476 fn_restype = ctypes .c_int ,
477+ fn_argtypes = (ctypes .py_object ,),
468478 )
469479 itemsize = get_elemsize_fn (X )
470480 assert type (itemsize ) is int
@@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
478488 "UsmNDArray_GetFlags" ,
479489 b"int (struct PyUSMArrayObject *)" ,
480490 fn_restype = ctypes .c_int ,
491+ fn_argtypes = (ctypes .py_object ,),
481492 )
482493 flags = get_flags_fn (X )
483494 assert type (flags ) is int and X .flags == flags
@@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
490501 "UsmNDArray_GetOffset" ,
491502 b"Py_ssize_t (struct PyUSMArrayObject *)" ,
492503 fn_restype = ctypes .c_longlong ,
504+ fn_argtypes = (ctypes .py_object ,),
493505 )
494506 offset = get_offset_fn (X )
495507 assert type (offset ) is int
@@ -503,11 +515,123 @@ def test_pyx_capi_get_queue_ref():
503515 "UsmNDArray_GetQueueRef" ,
504516 b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)" ,
505517 fn_restype = ctypes .c_void_p ,
518+ fn_argtypes = (ctypes .py_object ,),
506519 )
507520 queue_ref = get_queue_ref_fn (X ) # address of a copy, should be unequal
508521 assert queue_ref != X .sycl_queue .addressof_ref ()
509522
510523
524+ def test_pyx_capi_make_from_memory ():
525+ q = get_queue_or_skip ()
526+ n0 , n1 = 4 , 6
527+ c_tuple = (ctypes .c_ssize_t * 2 )(n0 , n1 )
528+ mem = dpm .MemoryUSMShared (n0 * n1 * 4 , queue = q )
529+ typenum = dpt .dtype ("single" ).num
530+ any_usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
531+ make_from_memory_fn = _pyx_capi_fnptr_to_callable (
532+ any_usm_ndarray ,
533+ "UsmNDArray_MakeFromMemory" ,
534+ b"PyObject *(int, Py_ssize_t const *, int, "
535+ b"struct Py_MemoryObject *, Py_ssize_t, char)" ,
536+ fn_restype = ctypes .py_object ,
537+ fn_argtypes = (
538+ ctypes .c_int ,
539+ ctypes .POINTER (ctypes .c_ssize_t ),
540+ ctypes .c_int ,
541+ ctypes .py_object ,
542+ ctypes .c_ssize_t ,
543+ ctypes .c_char ,
544+ ),
545+ )
546+ r = make_from_memory_fn (
547+ ctypes .c_int (2 ),
548+ c_tuple ,
549+ ctypes .c_int (typenum ),
550+ mem ,
551+ ctypes .c_ssize_t (0 ),
552+ ctypes .c_char (b"C" ),
553+ )
554+ assert isinstance (r , dpt .usm_ndarray )
555+ assert r .ndim == 2
556+ assert r .shape == (n0 , n1 )
557+ assert r ._pointer == mem ._pointer
558+ assert r .usm_type == "shared"
559+ assert r .sycl_queue == q
560+ assert r .flags ["C" ]
561+ r2 = make_from_memory_fn (
562+ ctypes .c_int (2 ),
563+ c_tuple ,
564+ ctypes .c_int (typenum ),
565+ mem ,
566+ ctypes .c_ssize_t (0 ),
567+ ctypes .c_char (b"F" ),
568+ )
569+ ptr = mem ._pointer
570+ del mem
571+ del r
572+ assert isinstance (r2 , dpt .usm_ndarray )
573+ assert r2 ._pointer == ptr
574+ assert r2 .usm_type == "shared"
575+ assert r2 .sycl_queue == q
576+ assert r2 .flags ["F" ]
577+
578+
579+ def test_pyx_capi_set_writable_flag ():
580+ q = get_queue_or_skip ()
581+ usm_ndarray = dpt .empty ((4 , 5 ), dtype = "i4" , sycl_queue = q )
582+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
583+ assert usm_ndarray .flags ["WRITABLE" ] is True
584+ set_writable = _pyx_capi_fnptr_to_callable (
585+ usm_ndarray ,
586+ "UsmNDArray_SetWritableFlag" ,
587+ b"void (struct PyUSMArrayObject *, int)" ,
588+ fn_restype = None ,
589+ fn_argtypes = (ctypes .py_object , ctypes .c_int ),
590+ )
591+ set_writable (usm_ndarray , ctypes .c_int (0 ))
592+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
593+ assert usm_ndarray .flags ["WRITABLE" ] is False
594+ set_writable (usm_ndarray , ctypes .c_int (1 ))
595+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
596+ assert usm_ndarray .flags ["WRITABLE" ] is True
597+
598+
599+ def test_pyx_capi_make_from_ptr ():
600+ q = get_queue_or_skip ()
601+ usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
602+ make_from_ptr = _pyx_capi_fnptr_to_callable (
603+ usm_ndarray ,
604+ "UsmNDArray_MakeFromPtr" ,
605+ b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
606+ b"DPCTLSyclQueueRef, PyObject *)" ,
607+ fn_restype = ctypes .py_object ,
608+ fn_argtypes = (
609+ ctypes .c_size_t ,
610+ ctypes .c_int ,
611+ ctypes .c_void_p ,
612+ ctypes .c_void_p ,
613+ ctypes .py_object ,
614+ ),
615+ )
616+ nelems = 10
617+ dt = dpt .int64
618+ mem = dpm .MemoryUSMDevice (nelems * dt .itemsize , queue = q )
619+ arr = make_from_ptr (
620+ ctypes .c_size_t (nelems ),
621+ dt .num ,
622+ mem ._pointer ,
623+ mem .sycl_queue .addressof_ref (),
624+ mem ,
625+ )
626+ assert isinstance (arr , dpt .usm_ndarray )
627+ assert arr .shape == (nelems ,)
628+ assert arr .dtype == dt
629+ assert arr .sycl_queue == q
630+ assert arr ._pointer == mem ._pointer
631+ del mem
632+ assert isinstance (arr .__repr__ (), str )
633+
634+
511635def _pyx_capi_int (X , pyx_capi_name , caps_name = b"int" , val_restype = ctypes .c_int ):
512636 import sys
513637
0 commit comments