@@ -32,6 +32,7 @@ from .._sycl_queue_manager cimport get_current_queue
3232
3333from cpython cimport Py_buffer
3434from cpython.bytes cimport PyBytes_AS_STRING, PyBytes_FromStringAndSize
35+ from cpython cimport pycapsule
3536
3637import numpy as np
3738
@@ -41,10 +42,63 @@ __all__ = [
4142 " MemoryUSMDevice"
4243]
4344
44- cdef _throw_sycl_usm_ary_iface ():
45- raise ValueError (" __sycl_usm_array_interface__ is malformed" )
45+ cdef object _sycl_usm_ary_iface_error ():
46+ return ValueError (" __sycl_usm_array_interface__ is malformed" )
4647
4748
49+ cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(SyclQueue q):
50+ return DPCTLQueue_Copy(q.get_queue_ref())
51+
52+
53+ cdef DPCTLSyclQueueRef _queue_ref_copy_from_USMRef_and_SyclContext(
54+ DPCTLSyclUSMRef ptr, SyclContext ctx):
55+ """ Obtain device from pointer and sycl context, use
56+ context and device to create a queue from which this memory
57+ can be accessible.
58+ """
59+ cdef SyclDevice dev = _Memory.get_pointer_device(ptr, ctx)
60+ cdef DPCTLSyclContextRef CRef = NULL
61+ cdef DPCTLSyclDeviceRef DRef = NULL
62+ CRef = ctx.get_context_ref()
63+ DRef = dev.get_device_ref()
64+ return DPCTLQueue_Create(CRef, DRef, NULL , 0 )
65+
66+
67+ cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
68+ DPCTLSyclUSMRef ptr, object syclobj):
69+ """ Constructs queue from pointer and syclobject from
70+ __sycl_usm_array_interface__
71+ """
72+ cdef DPCTLSyclQueueRef QRef = NULL
73+ cdef SyclContext ctx
74+ if type (syclobj) is SyclQueue:
75+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> syclobj)
76+ elif type (syclobj) is SyclContext:
77+ ctx = < SyclContext> syclobj
78+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
79+ elif type (syclobj) is str :
80+ q = SyclQueue(syclobj)
81+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
82+ elif pycapsule.PyCapsule_IsValid(syclobj, " SyclQueueRef" ):
83+ q = SyclQueue(syclobj)
84+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
85+ elif pycapsule.PyCapsule_IsValid(syclobj, " SyclContextRef" ):
86+ ctx = < SyclContext> SyclContext(syclobj)
87+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
88+ elif hasattr (syclobj, ' _get_capsule' ):
89+ cap = syclobj._get_capsule()
90+ if pycapsule.PyCapsule_IsValid(cap, " SyclQueueRef" ):
91+ q = SyclQueue(cap)
92+ return _queue_ref_copy_from_SyclQueue(< SyclQueue> q)
93+ elif pycapsule.PyCapsule_IsValid(cap, " SyclContexRef" ):
94+ ctx = < SyclContext> SyclContext(cap)
95+ return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
96+ else :
97+ return QRef
98+ else :
99+ return QRef
100+
101+
48102cdef void copy_via_host(void * dest_ptr, SyclQueue dest_queue,
49103 void * src_ptr, SyclQueue src_queue, size_t nbytes):
50104 """
@@ -98,25 +152,26 @@ cdef class _BufferData:
98152 cdef Py_ssize_t arr_data_ptr
99153 cdef SyclDevice dev
100154 cdef SyclContext ctx
155+ cdef DPCTLSyclQueueRef QRef = NULL
101156
102157 if ary_version != 1 :
103- _throw_sycl_usm_ary_iface ()
158+ raise _sycl_usm_ary_iface_error ()
104159 if not ary_data_tuple or len (ary_data_tuple) != 2 :
105- _throw_sycl_usm_ary_iface ()
160+ raise _sycl_usm_ary_iface_error ()
106161 if not ary_shape or len (ary_shape) != 1 or ary_shape[0 ] < 1 :
107162 raise ValueError
108163 try :
109164 dt = np.dtype(ary_typestr)
110165 except TypeError :
111- _throw_sycl_usm_ary_iface ()
166+ raise _sycl_usm_ary_iface_error ()
112167 if (ary_strides and len (ary_strides) != 1
113168 and ary_strides[0 ] != dt.itemsize):
114169 raise ValueError (" Must be contiguous" )
115170
116171 if (not ary_syclobj or
117172 not isinstance (ary_syclobj,
118173 (dpctl.SyclQueue, dpctl.SyclContext))):
119- _throw_sycl_usm_ary_iface ()
174+ raise _sycl_usm_ary_iface_error ()
120175
121176 buf = _BufferData.__new__ (_BufferData)
122177 arr_data_ptr = < Py_ssize_t> ary_data_tuple[0 ]
@@ -125,15 +180,8 @@ cdef class _BufferData:
125180 buf.itemsize = < Py_ssize_t> (dt.itemsize)
126181 buf.nbytes = (< Py_ssize_t> ary_shape[0 ]) * buf.itemsize
127182
128- if isinstance (ary_syclobj, dpctl.SyclQueue):
129- buf.queue = < SyclQueue> ary_syclobj
130- else :
131- # Obtain device from pointer and context
132- ctx = < SyclContext> ary_syclobj
133- dev = _Memory.get_pointer_device(buf.p, ctx)
134- # Use context and device to create a queue to
135- # be able to copy memory
136- buf.queue = SyclQueue._create_from_context_and_device(ctx, dev)
183+ QRef = get_queue_ref_from_ptr_and_syclobj(buf.p, ary_syclobj)
184+ buf.queue = SyclQueue._create(QRef)
137185
138186 return buf
139187
0 commit comments