@@ -303,13 +303,22 @@ cdef class _DeviceDefaultQueueCache:
303303 self .__device_queue_map__ = dict ()
304304
305305 def get_or_create (self , key ):
306- """ Return instance of SyclQueue and indicator if cache has been modified"""
307- if isinstance (key, tuple ) and len (key) == 2 and isinstance (key[0 ], SyclContext) and isinstance (key[1 ], SyclDevice):
306+ """ Return instance of SyclQueue and indicator if cache
307+ has been modified"""
308+ if (
309+ isinstance (key, tuple )
310+ and len (key) == 2
311+ and isinstance (key[0 ], SyclContext)
312+ and isinstance (key[1 ], SyclDevice)
313+ ):
308314 ctx_dev = key
309315 q = None
310316 elif isinstance (key, SyclDevice):
311317 q = SyclQueue(key)
312318 ctx_dev = q.sycl_context, key
319+ elif isinstance (key, str ):
320+ q = SyclQueue(key)
321+ ctx_dev = q.sycl_context, q.sycl_device
313322 else :
314323 raise TypeError
315324 if ctx_dev in self .__device_queue_map__:
@@ -322,12 +331,16 @@ cdef class _DeviceDefaultQueueCache:
322331 self .__device_queue_map__.update(dev_queue_map)
323332
324333 def __copy__ (self ):
325- cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__ (_DeviceDefaultQueueCache)
334+ cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__ (
335+ _DeviceDefaultQueueCache)
326336 _copy._update_map(self .__device_queue_map__)
327337 return _copy
328338
329339
330- _global_device_queue_cache = ContextVar(' global_device_queue_cache' , default = _DeviceDefaultQueueCache())
340+ _global_device_queue_cache = ContextVar(
341+ ' global_device_queue_cache' ,
342+ default = _DeviceDefaultQueueCache()
343+ )
331344
332345
333346cpdef object get_device_cached_queue(object key):
0 commit comments