2020
2121import logging
2222from contextlib import ExitStack, contextmanager
23+ from contextvars import ContextVar
2324
2425from .enum_types import backend_type, device_type
2526
@@ -35,6 +36,7 @@ from ._backend cimport ( # noqa: E211
3536 _device_type,
3637)
3738from ._sycl_context cimport SyclContext
39+ from ._sycl_device cimport SyclDevice
3840
3941__all__ = [
4042 " device_context" ,
@@ -44,6 +46,7 @@ __all__ = [
4446 " get_num_activated_queues" ,
4547 " is_in_device_context" ,
4648 " set_global_queue" ,
49+ " _global_device_queue_cache" ,
4750]
4851
4952_logger = logging.getLogger(__name__ )
@@ -291,3 +294,45 @@ def device_context(arg):
291294 _mgr._remove_current_queue()
292295 else :
293296 _logger.debug(" No queue was created so nothing to do" )
297+
298+
299+ cdef class _DeviceDefaultQueueCache:
300+ cdef dict __device_queue_map__
301+
302+ def __cinit__ (self ):
303+ self .__device_queue_map__ = dict ()
304+
305+ def get_or_create (self , key ):
306+ """ Return instance of SyclQueue and indicator if cache has been modified"""
307+ if isinstance (key, tuple ) and len (key) == 2 and isinstance (key[0 ], SyclContext) and isinstance (key[1 ], SyclDevice):
308+ ctx_dev = key
309+ q = None
310+ elif isinstance (key, SyclDevice):
311+ q = SyclQueue(key)
312+ ctx_dev = q.sycl_context, key
313+ else :
314+ raise TypeError
315+ if ctx_dev in self .__device_queue_map__:
316+ return self .__device_queue_map__[ctx_dev], False
317+ if q is None : q = SyclQueue(* ctx_dev)
318+ self .__device_queue_map__[ctx_dev] = q
319+ return q, True
320+
321+ cdef _update_map(self , dev_queue_map):
322+ self .__device_queue_map__.update(dev_queue_map)
323+
324+ def __copy__ (self ):
325+ cdef _DeviceDefaultQueueCache _copy = _DeviceDefaultQueueCache.__new__ (_DeviceDefaultQueueCache)
326+ _copy._update_map(self .__device_queue_map__)
327+ return _copy
328+
329+
330+ _global_device_queue_cache = ContextVar(' global_device_queue_cache' , default = _DeviceDefaultQueueCache())
331+
332+
333+ cpdef object get_device_cached_queue(object key):
334+ """ Get cached queue associated with given device"""
335+ _cache = _global_device_queue_cache.get()
336+ q_, changed_ = _cache.get_or_create(key)
337+ if changed_: _global_device_queue_cache.set(_cache)
338+ return q_
0 commit comments