Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...)
@locked function emit_llvm(@nospecialize(compiler_job::CompilerJob); kwargs...)
# XXX: remove on next major version
if !isempty(kwargs)
job = if !isempty(kwargs)
Base.depwarn("The GPUCompiler `emit_llvm` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :emit_llvm)
config = CompilerConfig(job.config; kwargs...)
job = CompilerJob(job.source, config)
config = CompilerConfig(compiler_job.config; kwargs...)
CompilerJob(compiler_job.source, config)
else
compiler_job
end

if !__llvm_initialized[]
Expand Down Expand Up @@ -299,10 +301,11 @@ const __llvm_initialized = Ref(false)

if job.config.toplevel && job.config.libraries
# load the runtime outside of a timing block (because it recurses into the compiler)
if !uses_julia_runtime(job)
runtime_fns, runtime_intrinsics = if !uses_julia_runtime(job)
runtime = load_runtime(job)
runtime_fns = LLVM.name.(defs(runtime))
runtime_intrinsics = ["julia.gc_alloc_obj"]
LLVM.name.(defs(runtime)), ["julia.gc_alloc_obj"]
else
String[], String[]
end

@tracepoint "Library linking" begin
Expand Down
289 changes: 151 additions & 138 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,92 @@ end
# once LLVM supports this pattern, consider going back to passing the state by reference,
# so that the julia.gpu.state_getter` can be simplified to return an opaque pointer.

function kernel_state_check_user!(additions, val, worklist)
return if val isa Instruction
bb = LLVM.parent(val)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
elseif val isa ConstantExpr
# constant expressions don't have a parent; we need to look up their uses
for use in uses(val)
kernel_state_check_user!(additions, user(use), worklist)
end
else
error("Don't know how to check uses of $val. Please file an issue.")
end
end

# update uses of the new function, modifying call sites to include the kernel state
function kernel_state_rewrite_uses!(f, ft, state_intr_ft, state_intr)
# update uses
return @dispose builder = IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallBase && called_operand(val) == f
# NOTE: we don't rewrite calls using Julia's jlcall calling convention,
# as those have a fixed argument list, passing actual arguments
# in an array of objects. that doesn't matter, for now, since
# GPU back-ends don't support such calls anyhow. but if we ever
# want to support kernel state passing on more capable back-ends,
# we'll need to update the argument array instead.
if callconv(val) in (37, 38)
# TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged.
continue
end

# forward the state argument
position!(builder, val)
state = call!(builder, state_intr_ft, state_intr, Value[], "state")
new_val = if val isa LLVM.CallInst
call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.CallBase
# the function is being passed as an argument. to avoid having to
# rewrite the target function, instead case the rewritten function to
# the old stateless type.
# XXX: we won't have to do this with opaque pointers.
position!(builder, val)
target_ft = called_type(val)
new_args = map(
zip(
parameters(target_ft),
arguments(val)
)
) do (param_typ, arg)
if value_type(arg) != param_typ
const_bitcast(arg, param_typ)
else
arg
end
end
new_val = call!(
builder, called_type(val), called_operand(val), new_args,
operand_bundles(val)
)
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.StoreInst
# the function is being stored, which again we'll permit like before.
elseif val isa ConstantExpr
kernel_state_rewrite_uses!(val, ft, state_intr_ft, state_intr)
else
error("Cannot rewrite $(typeof(val)) use of function: $val")
end
end
end
end

# add a state argument to every function in the module, starting from the kernel entry point
function add_kernel_state!(mod::LLVM.Module)
job = current_job::CompilerJob
Expand Down Expand Up @@ -537,22 +623,8 @@ function add_kernel_state!(mod::LLVM.Module)
# iteratively discover functions that use the intrinsic or any function calling it
worklist_length = length(worklist)
additions = LLVM.Function[]
function check_user(val)
if val isa Instruction
bb = LLVM.parent(val)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
elseif val isa ConstantExpr
# constant expressions don't have a parent; we need to look up their uses
for use in uses(val)
check_user(user(use))
end
else
error("Don't know how to check uses of $val. Please file an issue.")
end
end
for f in worklist, use in uses(f)
check_user(user(use))
kernel_state_check_user!(additions, user(use), worklist)
end
for f in additions
push!(worklist, f)
Expand Down Expand Up @@ -639,73 +711,9 @@ function add_kernel_state!(mod::LLVM.Module)
erase!(f)
end

# update uses of the new function, modifying call sites to include the kernel state
function rewrite_uses!(f, ft)
# update uses
@dispose builder=IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallBase && called_operand(val) == f
# NOTE: we don't rewrite calls using Julia's jlcall calling convention,
# as those have a fixed argument list, passing actual arguments
# in an array of objects. that doesn't matter, for now, since
# GPU back-ends don't support such calls anyhow. but if we ever
# want to support kernel state passing on more capable back-ends,
# we'll need to update the argument array instead.
if callconv(val) == 37 || callconv(val) == 38
# TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged.
continue
end

# forward the state argument
position!(builder, val)
state = call!(builder, state_intr_ft, state_intr, Value[], "state")
new_val = if val isa LLVM.CallInst
call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.CallBase
# the function is being passed as an argument. to avoid having to
# rewrite the target function, instead case the rewritten function to
# the old stateless type.
# XXX: we won't have to do this with opaque pointers.
position!(builder, val)
target_ft = called_type(val)
new_args = map(zip(parameters(target_ft),
arguments(val))) do (param_typ, arg)
if value_type(arg) != param_typ
const_bitcast(arg, param_typ)
else
arg
end
end
new_val = call!(builder, called_type(val), called_operand(val), new_args,
operand_bundles(val))
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.StoreInst
# the function is being stored, which again we'll permit like before.
elseif val isa ConstantExpr
rewrite_uses!(val, ft)
else
error("Cannot rewrite $(typeof(val)) use of function: $val")
end
end
end
end
for f in values(workmap)
ft = function_type(f)
rewrite_uses!(f, ft)
kernel_state_rewrite_uses!(f, ft, state_intr_ft, state_intr)
end

return true
Expand Down Expand Up @@ -916,6 +924,67 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
end
end

function scan_uses!(additions, val, worklist)
for use in uses(val)
candidate = user(use)
if isa(candidate, Instruction)
bb = LLVM.parent(candidate)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
elseif isa(candidate, ConstantExpr)
scan_uses!(additions, candidate, worklist)
else
error("Don't know how to check uses of $candidate. Please file an issue.")
end
end
return
end

function input_arguments_rewrite_uses!(f, new_f, nargs)
# update uses
return @dispose builder = IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
callee_f = LLVM.parent(LLVM.parent(val))
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
call!(
builder, function_type(new_f), new_f,
[arguments(val)..., parameters(callee_f)[(end - nargs + 1):end]...],
operand_bundles(val)
)
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
# XXX: why isn't this caught by the value materializer above?
target = operands(val)[1]
@assert target == f
new_val = LLVM.const_bitcast(new_f, value_type(val))
input_arguments_rewrite_uses!(val, new_val, nargs)
# we can't simply replace this constant expression, as it may be used
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)

# drop the old constant if it is unused
# XXX: can we do this differently?
if isempty(uses(val))
LLVM.unsafe_destroy!(val)
end
else
error("Cannot rewrite unknown use of function: $val")
end
end
end
end

function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function, kernel_intrinsics::Dict)
entry_fn = LLVM.name(entry)
Expand All @@ -936,22 +1005,8 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
# iteratively discover functions that use an intrinsic or any function calling it
worklist_length = length(worklist)
additions = Set{LLVM.Function}()
function scan_uses(val)
for use in uses(val)
candidate = user(use)
if isa(candidate, Instruction)
bb = LLVM.parent(candidate)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
elseif isa(candidate, ConstantExpr)
scan_uses(candidate)
else
error("Don't know how to check uses of $candidate. Please file an issue.")
end
end
end
for f in worklist
scan_uses(f)
scan_uses!(additions, f, worklist)
end
for f in additions
push!(worklist, f)
Expand Down Expand Up @@ -1014,50 +1069,8 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
end

# update other uses of the old function, modifying call sites to pass the arguments
function rewrite_uses!(f, new_f)
# update uses
@dispose builder=IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
callee_f = LLVM.parent(LLVM.parent(val))
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
call!(builder, function_type(new_f), new_f,
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
# XXX: why isn't this caught by the value materializer above?
target = operands(val)[1]
@assert target == f
new_val = LLVM.const_bitcast(new_f, value_type(val))
rewrite_uses!(val, new_val)
# we can't simply replace this constant expression, as it may be used
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)

# drop the old constant if it is unused
# XXX: can we do this differently?
if isempty(uses(val))
LLVM.unsafe_destroy!(val)
end
else
error("Cannot rewrite unknown use of function: $val")
end
end
end
end
for (f, new_f) in workmap
rewrite_uses!(f, new_f)
input_arguments_rewrite_uses!(f, new_f, nargs)
@assert isempty(uses(f))
replace_metadata_uses!(f, new_f)
erase!(f)
Expand Down
5 changes: 2 additions & 3 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
# instead always go through the callback in order to unlock it properly.
# rework this once we depend on Julia 1.9 or later.
llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref)
llvm_mod = nothing
llvm_ts_mod() do mod
llvm_mod = mod
llvm_mod = llvm_ts_mod() do mod
mod
end
end
if !(Sys.ARCH == :x86 || Sys.ARCH == :x86_64)
Expand Down
Loading
Loading