Skip to content

Commit 86ded21

Browse files
authored
Update KA API and fix unified API (#556)
1 parent 93bb643 commit 86ded21

File tree

10 files changed

+131
-22
lines changed

10 files changed

+131
-22
lines changed

Project.toml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "oneAPI"
22
uuid = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
3-
version = "2.6.0"
43
authors = ["Tim Besard <tim.besard@gmail.com>", "Alexis Montoison", "Michel Schanen <michel.schanen@gmail.com>"]
4+
version = "2.6.0"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
8+
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
89
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
910
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
1011
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
@@ -29,22 +30,16 @@ oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
2930
oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
3031
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
3132

32-
[weakdeps]
33-
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
34-
35-
[extensions]
36-
oneAPIAcceleratedKernelsExt = "AcceleratedKernels"
37-
3833
[compat]
3934
AbstractFFTs = "1.5.0"
40-
AcceleratedKernels = "0.4.3"
35+
AcceleratedKernels = "0.3.1, 0.4"
4136
Adapt = "4"
4237
CEnum = "0.4, 0.5"
4338
ExprTools = "0.1"
4439
GPUArrays = "11.2.1"
4540
GPUCompiler = "1.6"
4641
GPUToolbox = "0.1, 0.2, 0.3, 1"
47-
KernelAbstractions = "0.9.1"
42+
KernelAbstractions = "0.9.39"
4843
LLVM = "6, 7, 8, 9"
4944
NEO_jll = "=25.44.36015"
5045
Preferences = "1"

lib/level-zero/device.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ Base.length(iter::ZeDevices) = length(iter.handles)
204204

205205
# Iterator-size trait for the device iterator: it reports a known length.
# NOTE(review): Base's iteration interface conventionally defines this trait on the
# type (`::Type{ZeDevices}`); this method dispatches on an instance — confirm intent.
function Base.IteratorSize(::ZeDevices)
    return Base.HasLength()
end
206206

207+
# Index set of the device iterator: the 1-based positions `1:length(iter)`.
# Presumably added so generic search helpers (e.g. `findfirst`) can enumerate
# positions of a `ZeDevices` collection — verify against callers.
function Base.keys(iter::ZeDevices)
    return 1:length(iter)
end
208+
207209
function Base.show(io::IO, ::MIME"text/plain", iter::ZeDevices)
208210
print(io, "ZeDevice iterator for $(length(iter)) devices")
209211
if !isempty(iter)
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
module oneAPIAcceleratedKernelsExt
2-
31
import oneAPI
42
import oneAPI: oneArray, oneAPIBackend
53
import AcceleratedKernels as AK
@@ -13,5 +11,3 @@ Base.accumulate(op, A::oneArray; init = zero(eltype(A)), kwargs...) =
1311

1412
# Cumulative sum and product of a `oneArray` are forwarded to AcceleratedKernels,
# always selecting a default-constructed `oneAPIBackend`. Keyword arguments are
# passed through untouched.
function Base.cumsum(src::oneArray; kwargs...)
    return AK.cumsum(src, oneAPIBackend(); kwargs...)
end

function Base.cumprod(src::oneArray; kwargs...)
    return AK.cumprod(src, oneAPIBackend(); kwargs...)
end
16-
17-
end # module

src/array.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,19 @@ function Base.unsafe_convert(::Type{ZePtr{T}}, x::oneArray{T}) where {T}
345345
end
346346

347347

348+
## indexing
349+
350+
# Host-accessible arrays can be indexed from CPU, bypassing GPUArrays restrictions
351+
# Scalar read from an array backed by a host or shared buffer.
# The element is loaded directly through a pointer requested with
# `type = oneL0.HostBuffer`, after an (elidable) bounds check on the linear index.
function Base.getindex(
        x::oneArray{<:Any, <:Any, <:Union{oneL0.HostBuffer, oneL0.SharedBuffer}},
        I::Int,
    )
    @boundscheck checkbounds(x, I)
    host_ptr = pointer(x, I; type = oneL0.HostBuffer)
    return unsafe_load(host_ptr)
end
355+
356+
# Scalar write into an array backed by a host or shared buffer.
# The value is stored directly through a pointer requested with
# `type = oneL0.HostBuffer`, after an (elidable) bounds check on the linear index.
# NOTE(review): the result of `unsafe_store!` is returned as-is, which is not
# necessarily `x` — confirm no caller relies on `setindex!` returning the array.
function Base.setindex!(
        x::oneArray{<:Any, <:Any, <:Union{oneL0.HostBuffer, oneL0.SharedBuffer}},
        v,
        I::Int,
    )
    @boundscheck checkbounds(x, I)
    host_ptr = pointer(x, I; type = oneL0.HostBuffer)
    return unsafe_store!(host_ptr, v)
end
360+
348361

349362
## interop with GPU arrays
350363

src/context.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,13 @@ See also: [`device`](@ref), [`devices`](@ref)
103103
function device!(drv::ZeDevice)
    # The selection is recorded per task, under the `:ZeDevice` key of the
    # task-local storage; the result of that call is returned unchanged.
    return task_local_storage(:ZeDevice, drv)
end
106-
device!(i::Int) = device!(devices(driver())[i])
106+
# Select a device by its 1-based position among the devices of the current driver.
# Throws `ArgumentError` when `i` lies outside `1:length(devices(driver()))`;
# otherwise defers to the `ZeDevice` method.
function device!(i::Int)
    devs = devices(driver())
    i in 1:length(devs) ||
        throw(ArgumentError("Invalid device index $i (must be between 1 and $(length(devs)))"))
    return device!(devs[i])
end
107113

108114
const global_contexts = Dict{ZeDriver,ZeContext}()
109115

src/device/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ end
195195
end
196196
end
197197

198+
# "Cached" load of element `i` from `ptr` with the given alignment.
# SPIR-V/Level Zero expose no explicit cache-control intrinsic comparable to CUDA's
# `__ldg`, so this is simply a regular `unsafe_load`; any further optimization is
# left to the SPIR-V compiler.
@device_function @inline function unsafe_cached_load(
        ptr::LLVMPtr{T, A}, i::Integer, align::Val
    ) where {T, A}
    return unsafe_load(ptr, i, align)
end
204+
198205
@device_function @inline function const_arrayref(A::oneDeviceArray{T}, index::Integer) where {T}
199206
# simplified bounds check (see `arrayset`)
200207
#@boundscheck checkbounds(A, index)

src/oneAPI.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ include("utils.jl")
6969

7070
include("oneAPIKernels.jl")
7171
import .oneAPIKernels: oneAPIBackend
72+
include("accumulate.jl")
7273
include("indexing.jl")
7374
export oneAPIBackend
7475

src/oneAPIKernels.jl

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,25 @@ import Adapt
1515
export oneAPIBackend
1616

1717
"""
    oneAPIBackend(; prefer_blocks = false, always_inline = false)

KernelAbstractions GPU backend for oneAPI.

# Fields
- `prefer_blocks`: when the workgroup size is chosen automatically at kernel launch,
  allow it to be reduced for small ranges so that more workgroups are created.
- `always_inline`: forwarded as the `always_inline` option of `@oneapi` when the
  kernel is compiled.
"""
struct oneAPIBackend <: KA.GPU
    prefer_blocks::Bool
    always_inline::Bool
end
1921

20-
KA.allocate(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneArray{T}(undef, dims)
21-
KA.zeros(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.zeros(T, dims)
22-
KA.ones(::oneAPIBackend, ::Type{T}, dims::Tuple) where T = oneAPI.ones(T, dims)
22+
# Keyword constructor: both options default to `false`.
function oneAPIBackend(; prefer_blocks = false, always_inline = false)
    return oneAPIBackend(prefer_blocks, always_inline)
end
23+
24+
# Allocate an uninitialized `oneArray{T}` of size `dims`.
# `unified = true` selects a shared buffer, otherwise a device buffer is used.
@inline function KA.allocate(
        ::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false
    ) where {T}
    buffer = unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer
    return oneArray{T, length(dims), buffer}(undef, dims)
end
25+
# Allocate a `oneArray{T}` of size `dims` and fill it with `zero(T)`.
# `unified = true` selects a shared buffer, otherwise a device buffer is used.
@inline function KA.zeros(
        ::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false
    ) where {T}
    buffer = unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer
    out = oneArray{T, length(dims), buffer}(undef, dims)
    return fill!(out, zero(T))
end
26+
# Allocate a `oneArray{T}` of size `dims` and fill it with `one(T)`.
# `unified = true` selects a shared buffer, otherwise a device buffer is used.
@inline function KA.ones(
        ::oneAPIBackend, ::Type{T}, dims::Tuple; unified::Bool = false
    ) where {T}
    buffer = unified ? oneAPI.oneL0.SharedBuffer : oneAPI.oneL0.DeviceBuffer
    out = oneArray{T, length(dims), buffer}(undef, dims)
    return fill!(out, one(T))
end
2327

2428
# Every `oneArray` maps to a default-constructed `oneAPIBackend`.
function KA.get_backend(::oneArray)
    return oneAPIBackend()
end
2529
# TODO should be non-blocking
26-
KA.synchronize(::oneAPIBackend) = oneL0.synchronize()
30+
function KA.synchronize(::oneAPIBackend)
    # Delegates to the Level Zero wrapper's argument-less `synchronize`.
    return oneAPI.oneL0.synchronize()
end
2731
# Backend capability queries.

# Float64 is reported as unsupported. TODO: Check if this is device dependent
function KA.supports_float64(::oneAPIBackend)
    return false
end

# Unified allocations are available (see the `unified` keyword of `KA.allocate`).
function KA.supports_unified(::oneAPIBackend)
    return true
end

# The backend is functional exactly when `oneAPI.functional()` says so.
function KA.functional(::oneAPIBackend)
    return oneAPI.functional()
end
2835

29-
Adapt.adapt_storage(::oneAPIBackend, a::Array) = Adapt.adapt(oneArray, a)
36+
# Storage adaption rules between host arrays and `oneArray`s.

# Any array adapted to the oneAPI backend is converted via `Adapt.adapt(oneArray, a)`.
function Adapt.adapt_storage(::oneAPIBackend, a::AbstractArray)
    return Adapt.adapt(oneArray, a)
end

# A `oneArray` adapted to the oneAPI backend is returned as-is.
function Adapt.adapt_storage(::oneAPIBackend, a::oneArray)
    return a
end

# A `oneArray` adapted to the CPU backend is converted to a plain `Array`.
function Adapt.adapt_storage(::KA.CPU, a::oneArray)
    return convert(Array, a)
end
3239

@@ -39,6 +46,24 @@ function KA.copyto!(::oneAPIBackend, A, B)
3946
end
4047

4148

49+
## Device Operations
50+
51+
# Number of devices reported by `oneAPI.devices()`.
KA.ndevices(::oneAPIBackend) = length(oneAPI.devices())
54+
55+
# 1-based position of the currently selected device within `oneAPI.devices()`.
# When the current device is not found in that list, index 1 is returned.
function KA.device(::oneAPIBackend)::Int
    current = oneAPI.device()
    position = findfirst(==(current), oneAPI.devices())
    return something(position, 1)
end
61+
62+
# Select the active device by index; simply forwards `id` to `oneAPI.device!`.
KA.device!(backend::oneAPIBackend, id::Int) = oneAPI.device!(id)
65+
66+
4267
## Kernel Launch
4368

4469
function KA.mkcontext(kernel::KA.Kernel{oneAPIBackend}, _ndrange, iterspace)
@@ -83,14 +108,42 @@ function threads_to_workgroupsize(threads, ndrange)
83108
end
84109

85110
function (obj::KA.Kernel{oneAPIBackend})(args...; ndrange=nothing, workgroupsize=nothing)
111+
backend = KA.backend(obj)
112+
86113
ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange, workgroupsize)
87114
# this might not be the final context, since we may tune the workgroupsize
88115
ctx = KA.mkcontext(obj, ndrange, iterspace)
89-
kernel = @oneapi launch=false obj.f(ctx, args...)
116+
117+
# If the kernel is statically sized we can tell the compiler about that
118+
if KA.workgroupsize(obj) <: KA.StaticSize
119+
# TODO: maxthreads
120+
# maxthreads = prod(KA.get(KA.workgroupsize(obj)))
121+
else
122+
# maxthreads = nothing
123+
end
124+
125+
kernel = @oneapi launch = false always_inline = backend.always_inline obj.f(ctx, args...)
90126

91127
# figure out the optimal workgroupsize automatically
92128
if KA.workgroupsize(obj) <: KA.DynamicSize && workgroupsize === nothing
93129
items = oneAPI.launch_configuration(kernel)
130+
131+
if backend.prefer_blocks
132+
# Prefer blocks over threads:
133+
# Reducing the workgroup size (items) increases the number of workgroups (blocks).
134+
# We use a simple heuristic here since we lack full occupancy info (max_blocks) from launch_configuration.
135+
136+
# If the total range is large enough, full workgroups are fine.
137+
# If the range is small, we might want to reduce 'items' to create more blocks to fill the GPU.
138+
# (Simplified logic compared to CUDA.jl which uses explicit occupancy calculators)
139+
total_items = prod(ndrange)
140+
if total_items < items * 16 # Heuristic factor
141+
# Force at least a few blocks if possible by reducing items per block
142+
target_blocks = 16 # Target at least 16 blocks
143+
items = max(1, min(items, cld(total_items, target_blocks)))
144+
end
145+
end
146+
94147
workgroupsize = threads_to_workgroupsize(items, ndrange)
95148
iterspace, dynamic = KA.partition(obj, ndrange, workgroupsize)
96149
ctx = KA.mkcontext(obj, ndrange, iterspace)
@@ -171,6 +224,43 @@ end
171224

172225
## Other
173226

227+
# Under KernelAbstractions' `ConstAdaptor`, device arrays are wrapped in
# `Base.Experimental.Const`.
function Adapt.adapt_storage(to::KA.ConstAdaptor, a::oneDeviceArray)
    return Base.Experimental.Const(a)
end
228+
174229
# Kernel arguments are converted with `kernel_convert` before launch.
function KA.argconvert(::KA.Kernel{oneAPIBackend}, arg)
    return kernel_convert(arg)
end
175230

231+
# Set the command-queue priority for the current context and device.
#
# `prio` must be `:high`, `:normal` or `:low`; any other symbol raises an error.
# The current queue is synchronized, then a new in-order queue with the requested
# priority is created and stored in task-local storage. Returns `nothing`.
function KA.priority!(::oneAPIBackend, prio::Symbol)
    # Map the symbol onto the Level Zero priority enum, rejecting anything else.
    priority_enum = if prio === :high
        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_HIGH
    elseif prio === :low
        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_PRIORITY_LOW
    elseif prio === :normal
        oneAPI.oneL0.ZE_COMMAND_QUEUE_PRIORITY_NORMAL
    else
        error("priority must be one of :high, :normal, :low")
    end

    ctx = oneAPI.context()
    dev = oneAPI.device()

    # Drain the queue currently in use before it is replaced.
    old_queue = oneAPI.global_queue(ctx, dev)
    oneAPI.oneL0.synchronize(old_queue)

    replacement = oneAPI.oneL0.ZeCommandQueue(
        ctx, dev;
        flags = oneAPI.oneL0.ZE_COMMAND_QUEUE_FLAG_IN_ORDER,
        priority = priority_enum
    )

    # Presumably `(:ZeCommandQueue, ctx, dev)` is the key `global_queue` looks up,
    # so later launches pick up the new queue — verify against `global_queue`.
    task_local_storage((:ZeCommandQueue, ctx, dev), replacement)

    return nothing
end
265+
176266
end

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[deps]
22
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
3-
AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
43
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
54
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
65
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

test/setup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Distributed, Test, oneAPI, AcceleratedKernels
1+
using Distributed, Test, oneAPI
22

33
oneAPI.functional() || error("oneAPI.jl is not functional on this system")
44

0 commit comments

Comments
 (0)