@@ -15,18 +15,25 @@ import Adapt
 export oneAPIBackend
 
 struct oneAPIBackend <: KA.GPU
+    prefer_blocks::Bool
+    always_inline::Bool
 end
 
-KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
-KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
-KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
+oneAPIBackend(; prefer_blocks=false, always_inline=false) = oneAPIBackend(prefer_blocks, always_inline)
+
+@inline KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where {T} = oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims)
+@inline KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where {T} = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), zero(T))
+@inline KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool=false) where {T} = fill!(oneArray{T, length(dims), unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer}(undef, dims), one(T))
 
 KA.get_backend(::oneArray) = oneAPIBackend()
 # TODO: should be non-blocking
-KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
+KA.synchronize(::oneAPIBackend) = oneAPI.oneL0.synchronize()
 KA.supports_float64(::oneAPIBackend) = false # TODO: check if this is device dependent
+KA.supports_unified(::oneAPIBackend) = true
+
+KA.functional(::oneAPIBackend) = oneAPI.functional()
 
-Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
+Adapt.adapt_storage(::oneAPIBackend, a::AbstractArray) = Adapt.adapt(oneArray, a)
 Adapt.adapt_storage(::oneAPIBackend, a::oneArray) = a
 Adapt.adapt_storage(::KA.CPU, a::oneArray) = convert(Array, a)
 
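As a usage sketch of the two new knobs and the unified-memory allocators (a minimal example assuming the constructor and the `KA.allocate`/`KA.zeros` methods added above; array sizes are illustrative):

    import KernelAbstractions as KA
    using oneAPI

    backend = oneAPIBackend(prefer_blocks=true, always_inline=true)

    # Regular device-resident allocation (DeviceBuffer).
    A = KA.allocate(backend, Float32, (1024, 1024))

    # Host-accessible unified allocation (SharedBuffer), zero-initialized.
    B = KA.zeros(backend, Float32, (1024,); unified=true)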
@@ -39,6 +46,24 @@ function KA.copyto!(::oneAPIBackend, A, B)
 end
 
 
+## Device Operations
+
+function KA.ndevices(::oneAPIBackend)
+    return length(oneAPI.devices())
+end
+
+function KA.device(::oneAPIBackend)::Int
+    dev = oneAPI.device()
+    devs = oneAPI.devices()
+    idx = findfirst(==(dev), devs)
+    return idx === nothing ? 1 : idx
+end
+
+function KA.device!(backend::oneAPIBackend, id::Int)
+    return oneAPI.device!(id)
+end
+
+
 ## Kernel Launch
 
 function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, _ndrange, iterspace)
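A hedged sketch of the device-selection round trip these helpers support (assuming, as the methods above do, that device ids are 1-based indices into `oneAPI.devices()`):

    import KernelAbstractions as KA
    using oneAPI

    backend = oneAPIBackend()
    for id in 1:KA.ndevices(backend)
        KA.device!(backend, id)           # select the id-th Level Zero device
        @assert KA.device(backend) == id  # KA.device recovers the index via findfirst
    end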
@@ -83,14 +108,42 @@ function threads_to_workgroupsize(threads, ndrange)
 end
 
 function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+    backend = KA.backend(obj)
+
     ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
     # this might not be the final context, since we may tune the workgroupsize
     ctx = KA.mkcontext(obj, ndrange, iterspace)
-    kernel = @oneapi launch=false obj.f(ctx, args...)
+
+    # If the kernel is statically sized we can tell the compiler about that
+    if KA.workgroupsize(obj) <: KA.StaticSize
+        # TODO: maxthreads
+        # maxthreads = prod(KA.get(KA.workgroupsize(obj)))
+    else
+        # maxthreads = nothing
+    end
+
+    kernel = @oneapi launch=false always_inline=backend.always_inline obj.f(ctx, args...)
 
     # figure out the optimal workgroupsize automatically
     if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
         items = oneAPI.launch_configuration(kernel)
+
+        if backend.prefer_blocks
+            # Prefer blocks over threads:
+            # reducing the workgroup size (items) increases the number of workgroups (blocks).
+            # We use a simple heuristic here, since we lack full occupancy info (max_blocks) from launch_configuration.
+
+            # If the total range is large enough, full workgroups are fine.
+            # If the range is small, we may want to reduce `items` to create more blocks to fill the GPU.
+            # (Simplified logic compared to CUDA.jl, which uses explicit occupancy calculators.)
+            total_items = prod(ndrange)
+            if total_items < items * 16  # heuristic factor
+                # force at least a few blocks if possible by reducing items per block
+                target_blocks = 16  # target at least 16 blocks
+                items = max(1, min(items, cld(total_items, target_blocks)))
+            end
+        end
+
         workgroupsize = threads_to_workgroupsize(items, ndrange)
         iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
         ctx = KA.mkcontext(obj, ndrange, iterspace)
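To make the heuristic above concrete, a worked example with illustrative numbers: if `launch_configuration` suggests `items = 256` but `prod(ndrange) = 1_000`, then `1_000 < 256 * 16`, so `items` shrinks to `max(1, min(256, cld(1_000, 16))) = 63`, giving `cld(1_000, 63) = 16` workgroups instead of the 4 that full 256-item groups would produce:

    items       = 256     # suggested by oneAPI.launch_configuration
    total_items = 1_000   # prod(ndrange)
    if total_items < items * 16                            # 1_000 < 4_096: heuristic fires
        items = max(1, min(items, cld(total_items, 16)))   # items == 63
    end
    cld(total_items, items)                                # 16 workgroups, up from 4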
@@ -171,6 +224,43 @@
 
 ## Other
 
+Adapt.adapt_storage(to::KA.ConstAdaptor, a::oneDeviceArray) = Base.Experimental.Const(a)
+
 KA.argconvert(::KA.Kernel{oneAPIBackend}, arg) = kernel_convert(arg)
 
+function KA.priority!(::oneAPIBackend, prio::Symbol)
+    if !(prio in (:high, :normal, :low))
+        error("priority must be one of :high, :normal, :low")
+    end
+
+    priority_enum = if prio == :high
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_HIGH
+    elseif prio == :low
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_LOW
+    else
+        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_NORMAL
+    end
+
+    ctx = oneAPI.context()
+    dev = oneAPI.device()
+
+    # Update the cached queue.
+    # We synchronize the current queue first to ensure safety.
+    current_queue = oneAPI.global_queue(ctx, dev)
+    oneAPI.oneL0.synchronize(current_queue)
+
+    # Replace the queue in task_local_storage;
+    # the key used by global_queue is (:ZeCommandQueue, ctx, dev).
+
+    new_queue = oneAPI.oneL0.ZeCommandQueue(
+        ctx, dev;
+        flags=oneAPI.oneL0.ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
+        priority=priority_enum
+    )
+
+    task_local_storage((:ZeCommandQueue, ctx, dev), new_queue)
+
+    return nothing
+end
+
 end
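Finally, a usage sketch for the new priority control (assuming the `KA.priority!` method above; kernels enqueued after the call run on the re-created queue):

    import KernelAbstractions as KA
    using oneAPI

    backend = oneAPIBackend()
    KA.priority!(backend, :high)    # drain the current queue, then swap in a high-priority one
    # ... launch latency-sensitive kernels ...
    KA.priority!(backend, :normal)  # restore the default priority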