@@ -24,6 +24,7 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint16_t, uint64_t
2424
2525cimport dpctl as c_dpctl
2626cimport dpctl.memory as c_dpmem
27+ from dpctl._sycl_queue_manager cimport get_device_cached_queue
2728
2829from .._backend cimport (
2930 DPCTLDevice_Delete,
@@ -344,12 +345,12 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
344345 if _IS_LINUX:
345346 default_context = root_device.sycl_platform.default_context
346347 else :
347- default_context = dpctl.SyclQueue (root_device).sycl_context
348+ default_context = get_device_cached_queue (root_device).sycl_context
348349 except RuntimeError :
349- default_context = dpctl.SyclQueue (root_device).sycl_context
350+ default_context = get_device_cached_queue (root_device).sycl_context
350351 if dlm_tensor.dl_tensor.data is NULL :
351352 usm_type = b" device"
352- q = dpctl.SyclQueue( default_context, root_device)
353+ q = get_device_cached_queue(( default_context, root_device,) )
353354 else :
354355 usm_type = c_dpmem._Memory.get_pointer_type(
355356 < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
@@ -364,7 +365,7 @@ cpdef usm_ndarray from_dlpack_capsule(object py_caps) except +:
364365 < DPCTLSyclUSMRef> dlm_tensor.dl_tensor.data,
365366 < c_dpctl.SyclContext> default_context
366367 )
367- q = dpctl.SyclQueue( default_context, alloc_device)
368+ q = get_device_cached_queue(( default_context, alloc_device,) )
368369 if dlm_tensor.dl_tensor.dtype.bits % 8 :
369370 raise BufferError(
370371 " Can not import DLPack tensor whose element's "
0 commit comments