From 53da31711cbbcba9a5af230dfb4b57abd49bc92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 10 Jan 2026 15:26:54 +0000 Subject: [PATCH 1/8] Remove some `Core.Box`es --- src/irgen.jl | 280 +++++++++++++++++++++++----------------------- src/jlgen.jl | 5 +- src/mcgen.jl | 44 ++++---- src/metal.jl | 102 +++++++++-------- test/Project.toml | 5 +- 5 files changed, 223 insertions(+), 213 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index a7c36a60..2a310227 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -509,7 +509,87 @@ end # once LLVM supports this pattern, consider going back to passing the state by reference, # so that the julia.gpu.state_getter` can be simplified to return an opaque pointer. +function kernel_state_check_user!(additions, val, worklist) + if val isa Instruction + bb = LLVM.parent(val) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + elseif val isa ConstantExpr + # constant expressions don't have a parent; we need to look up their uses + for use in uses(val) + kernel_state_check_user!(additions, user(use), worklist) + end + else + error("Don't know how to check uses of $val. Please file an issue.") + end +end + +function kernel_state_rewrite_uses!(f, ft) + # update uses + @dispose builder=IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallBase && called_operand(val) == f + # NOTE: we don't rewrite calls using Julia's jlcall calling convention, + # as those have a fixed argument list, passing actual arguments + # in an array of objects. that doesn't matter, for now, since + # GPU back-ends don't support such calls anyhow. but if we ever + # want to support kernel state passing on more capable back-ends, + # we'll need to update the argument array instead. + if callconv(val) == 37 || callconv(val) == 38 + # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. + continue + end + + # forward the state argument + position!(builder, val) + state = call!(builder, state_intr_ft, state_intr, Value[], "state") + new_val = if val isa LLVM.CallInst + call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.CallBase + # the function is being passed as an argument. to avoid having to + # rewrite the target function, instead case the rewritten function to + # the old stateless type. + # XXX: we won't have to do this with opaque pointers. + position!(builder, val) + target_ft = called_type(val) + new_args = map(zip(parameters(target_ft), + arguments(val))) do (param_typ, arg) + if value_type(arg) != param_typ + const_bitcast(arg, param_typ) + else + arg + end + end + new_val = call!(builder, called_type(val), called_operand(val), new_args, + operand_bundles(val)) + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.StoreInst + # the function is being stored, which again we'll permit like before. + elseif val isa ConstantExpr + kernel_state_rewrite_uses!(val, ft) + else + error("Cannot rewrite $(typeof(val)) use of function: $val") + end + end + end +end + # add a state argument to every function in the module, starting from the kernel entry point +# update uses of the new function, modifying call sites to include the kernel state function add_kernel_state!(mod::LLVM.Module) job = current_job::CompilerJob @@ -537,22 +617,8 @@ function add_kernel_state!(mod::LLVM.Module) # iteratively discover functions that use the intrinsic or any function calling it worklist_length = length(worklist) additions = LLVM.Function[] - function check_user(val) - if val isa Instruction - bb = LLVM.parent(val) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) - elseif val isa ConstantExpr - # constant expressions don't have a parent; we need to look up their uses - for use in uses(val) - check_user(user(use)) - end - else - error("Don't know how to check uses of $val. Please file an issue.") - end - end for f in worklist, use in uses(f) - check_user(user(use)) + kernel_state_check_user!(additions, user(use), worklist) end for f in additions push!(worklist, f) @@ -639,73 +705,9 @@ function add_kernel_state!(mod::LLVM.Module) erase!(f) end - # update uses of the new function, modifying call sites to include the kernel state - function rewrite_uses!(f, ft) - # update uses - @dispose builder=IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallBase && called_operand(val) == f - # NOTE: we don't rewrite calls using Julia's jlcall calling convention, - # as those have a fixed argument list, passing actual arguments - # in an array of objects. that doesn't matter, for now, since - # GPU back-ends don't support such calls anyhow. but if we ever - # want to support kernel state passing on more capable back-ends, - # we'll need to update the argument array instead. - if callconv(val) == 37 || callconv(val) == 38 - # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. - continue - end - - # forward the state argument - position!(builder, val) - state = call!(builder, state_intr_ft, state_intr, Value[], "state") - new_val = if val isa LLVM.CallInst - call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.CallBase - # the function is being passed as an argument. to avoid having to - # rewrite the target function, instead case the rewritten function to - # the old stateless type. - # XXX: we won't have to do this with opaque pointers. - position!(builder, val) - target_ft = called_type(val) - new_args = map(zip(parameters(target_ft), - arguments(val))) do (param_typ, arg) - if value_type(arg) != param_typ - const_bitcast(arg, param_typ) - else - arg - end - end - new_val = call!(builder, called_type(val), called_operand(val), new_args, - operand_bundles(val)) - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.StoreInst - # the function is being stored, which again we'll permit like before. - elseif val isa ConstantExpr - rewrite_uses!(val, ft) - else - error("Cannot rewrite $(typeof(val)) use of function: $val") - end - end - end - end for f in values(workmap) ft = function_type(f) - rewrite_uses!(f, ft) + kernel_state_rewrite_uses!(f, ft) end return true @@ -916,6 +918,64 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M end end +function scan_uses!(additions, val, worklist) + for use in uses(val) + candidate = user(use) + if isa(candidate, Instruction) + bb = LLVM.parent(candidate) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + elseif isa(candidate, ConstantExpr) + scan_uses!(additions, candidate, worklist) + else + error("Don't know how to check uses of $candidate. Please file an issue.") + end + end +end + +function input_arguments_rewrite_uses!(f, new_f) + # update uses + @dispose builder=IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst + callee_f = LLVM.parent(LLVM.parent(val)) + # forward the arguments + position!(builder, val) + new_val = if val isa LLVM.CallInst + call!(builder, function_type(new_f), new_f, + [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], + operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast + # XXX: why isn't this caught by the value materializer above? + target = operands(val)[1] + @assert target == f + new_val = LLVM.const_bitcast(new_f, value_type(val)) + input_arguments_rewrite_uses!(val, new_val) + # we can't simply replace this constant expression, as it may be used + # as a call, taking arguments (so we need to rewrite it to pass the input arguments) + + # drop the old constant if it is unused + # XXX: can we do this differently? + if isempty(uses(val)) + LLVM.unsafe_destroy!(val) + end + else + error("Cannot rewrite unknown use of function: $val") + end + end + end +end + function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function, kernel_intrinsics::Dict) entry_fn = LLVM.name(entry) @@ -936,22 +996,8 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # iteratively discover functions that use an intrinsic or any function calling it worklist_length = length(worklist) additions = Set{LLVM.Function}() - function scan_uses(val) - for use in uses(val) - candidate = user(use) - if isa(candidate, Instruction) - bb = LLVM.parent(candidate) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) - elseif isa(candidate, ConstantExpr) - scan_uses(candidate) - else - error("Don't know how to check uses of $candidate. Please file an issue.") - end - end - end for f in worklist - scan_uses(f) + scan_uses!(additions, f, worklist) end for f in additions push!(worklist, f) @@ -1014,50 +1060,8 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end # update other uses of the old function, modifying call sites to pass the arguments - function rewrite_uses!(f, new_f) - # update uses - @dispose builder=IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - callee_f = LLVM.parent(LLVM.parent(val)) - # forward the arguments - position!(builder, val) - new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, - [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], - operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast - # XXX: why isn't this caught by the value materializer above? - target = operands(val)[1] - @assert target == f - new_val = LLVM.const_bitcast(new_f, value_type(val)) - rewrite_uses!(val, new_val) - # we can't simply replace this constant expression, as it may be used - # as a call, taking arguments (so we need to rewrite it to pass the input arguments) - - # drop the old constant if it is unused - # XXX: can we do this differently? - if isempty(uses(val)) - LLVM.unsafe_destroy!(val) - end - else - error("Cannot rewrite unknown use of function: $val") - end - end - end - end for (f, new_f) in workmap - rewrite_uses!(f, new_f) + input_arguments_rewrite_uses!(f, new_f) @assert isempty(uses(f)) replace_metadata_uses!(f, new_f) erase!(f) diff --git a/src/jlgen.jl b/src/jlgen.jl index 0d380cbf..fe11d1de 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -757,9 +757,8 @@ function compile_method_instance(@nospecialize(job::CompilerJob)) # instead always go through the callback in order to unlock it properly. # rework this once we depend on Julia 1.9 or later. llvm_ts_mod = LLVM.ThreadSafeModule(llvm_mod_ref) - llvm_mod = nothing - llvm_ts_mod() do mod - llvm_mod = mod + llvm_mod = llvm_ts_mod() do mod + mod end end if !(Sys.ARCH == :x86 || Sys.ARCH == :x86_64) diff --git a/src/mcgen.jl b/src/mcgen.jl index 77a40d85..b16848e0 100644 --- a/src/mcgen.jl +++ b/src/mcgen.jl @@ -29,6 +29,27 @@ end # but at the same time the GPU can't resolve them at run-time. # # this pass performs that resolution at link time. +function replace_bindings!(value, dereferenced) + changed = false + for use in uses(value) + val = user(use) + changed |= if isa(val, LLVM.ConstantExpr) + # recurse + replace_bindings!(val, dereferenced) + elseif isa(val, LLVM.LoadInst) + # resolve + replace_uses!(val, dereferenced) + erase!(val) + # FIXME: iterator invalidation? + true + else + # `changed` didn't change + changed + end + end + changed +end + function resolve_cpu_references!(mod::LLVM.Module) job = current_job::CompilerJob changed = false @@ -38,28 +59,9 @@ function resolve_cpu_references!(mod::LLVM.Module) if isdeclaration(f) && !LLVM.isintrinsic(f) && startswith(fn, "jl_") # eagerly resolve the address of the binding address = ccall(:jl_cglobal, Any, (Any, Any), fn, UInt) - dereferenced = unsafe_load(address) - dereferenced = LLVM.ConstantInt(dereferenced) - - function replace_bindings!(value) - changed = false - for use in uses(value) - val = user(use) - if isa(val, LLVM.ConstantExpr) - # recurse - changed |= replace_bindings!(val) - elseif isa(val, LLVM.LoadInst) - # resolve - replace_uses!(val, dereferenced) - erase!(val) - # FIXME: iterator invalidation? - changed = true - end - end - changed - end + dereferenced = LLVM.ConstantInt(unsafe_load(address)) - changed |= replace_bindings!(f) + changed |= replace_bindings!(f, dereferenced) end end diff --git a/src/metal.jl b/src/metal.jl index d3a83d61..ef5da56f 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -230,6 +230,23 @@ end # NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to # be executed after optimization (where Julia's address spaces are stripped). If we ever # want to execute it earlier, adapt remapType to rewrite all pointer types. +function remapType(src) + # TODO: shouldn't we recurse into structs here, making sure the parent object's + # address space matches the contained one? doesn't matter right now as we + # only use LLVMPtr (i.e. no rewriting of contained pointers needed) in the + # device addrss space (i.e. no mismatch between parent and field possible) + dst = if src isa LLVM.PointerType && addrspace(src) == 0 + if supports_typed_pointers(context()) + LLVM.PointerType(remapType(eltype(src)), #=device=# 1) + else + LLVM.PointerType(#=device=# 1) + end + else + src + end + return dst +end + function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function) ft = function_type(f) @@ -244,23 +261,6 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV byref[arg.idx] = (arg.cc == BITS_REF || arg.cc == KERNEL_STATE) end - function remapType(src) - # TODO: shouldn't we recurse into structs here, making sure the parent object's - # address space matches the contained one? doesn't matter right now as we - # only use LLVMPtr (i.e. no rewriting of contained pointers needed) in the - # device addrss space (i.e. no mismatch between parent and field possible) - dst = if src isa LLVM.PointerType && addrspace(src) == 0 - if supports_typed_pointers(context()) - LLVM.PointerType(remapType(eltype(src)), #=device=# 1) - else - LLVM.PointerType(#=device=# 1) - end - else - src - end - return dst - end - # generate the new function type & definition new_types = LLVMType[] for (i, param) in enumerate(parameters(ft)) @@ -342,6 +342,19 @@ end # # global constant objects need to reside in address space 2, so we clone each function # that uses global objects and rewrite the globals used by it +function metal_check_user!(function_worklist, val) + if val isa LLVM.Instruction + bb = LLVM.parent(val) + f = LLVM.parent(bb) + + push!(function_worklist, f) + elseif val isa LLVM.ConstantExpr + for use in uses(val) + metal_check_user!(function_worklist, user(use)) + end + end +end + function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function) # determine global variables we need to update @@ -374,20 +387,8 @@ function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.M # determine which functions we need to update function_worklist = Set{LLVM.Function}() - function check_user(val) - if val isa LLVM.Instruction - bb = LLVM.parent(val) - f = LLVM.parent(bb) - - push!(function_worklist, f) - elseif val isa LLVM.ConstantExpr - for use in uses(val) - check_user(user(use)) - end - end - end for gv in keys(global_map), use in uses(gv) - check_user(user(use)) + metal_check_user!(function_worklist, user(use)) end # update functions that use the global @@ -737,6 +738,23 @@ end # # we don't have a proper back-end, so we're missing out on intrinsics-related functionality. +function type_suffix(typ) + # XXX: can't we use LLVM to do this kind of mangling? + if typ isa LLVM.IntegerType + "i$(width(typ))" + elseif typ == LLVM.HalfType() + "f16" + elseif typ == LLVM.FloatType() + "f32" + elseif typ == LLVM.DoubleType() + "f64" + elseif typ isa LLVM.VectorType + "v$(length(typ))$(type_suffix(eltype(typ)))" + else + error("Unsupported intrinsic type: $typ") + end +end + # replace LLVM intrinsics with AIR equivalents function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Function) isdeclaration(fun) && return false @@ -796,23 +814,6 @@ function lower_llvm_intrinsics!(@nospecialize(job::CompilerJob), fun::LLVM.Funct # determine type of the intrinsic typ = value_type(call) - function type_suffix(typ) - # XXX: can't we use LLVM to do this kind of mangling? - if typ isa LLVM.IntegerType - "i$(width(typ))" - elseif typ == LLVM.HalfType() - "f16" - elseif typ == LLVM.FloatType() - "f32" - elseif typ == LLVM.DoubleType() - "f64" - elseif typ isa LLVM.VectorType - "v$(length(typ))$(type_suffix(eltype(typ)))" - else - error("Unsupported intrinsic type: $typ") - end - end - if typ isa LLVM.IntegerType || (typ isa LLVM.VectorType && eltype(typ) isa LLVM.IntegerType) fn *= "." * (signed::Bool ? "s" : "u") * "." * type_suffix(typ) else @@ -1013,11 +1014,11 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod end push!(attrs, EnumAttribute(name, 0)) end - changed = true + return true end # synchronization - if fn == "air.wg.barrier" || fn == "air.simdgroup.barrier" + changed |= if fn == "air.wg.barrier" || fn == "air.simdgroup.barrier" add_attributes("nounwind", "convergent") # atomics @@ -1033,6 +1034,9 @@ function annotate_air_intrinsics!(@nospecialize(job::CompilerJob), mod::LLVM.Mod elseif match(r"^air.atomic.(local|global).(add|sub|min|max|and|or|xor)", fn) !== nothing # TODO: "memory(argmem: readwrite)" on LLVM 16+ add_attributes("argmemonly", "nounwind") + else + # `changed` didn't change + changed end end diff --git a/test/Project.toml b/test/Project.toml index 0c5499ed..f6119573 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" @@ -19,5 +20,5 @@ demumble_jll = "1e29f10c-031c-5a83-9565-69cddfc27673" Aqua = "0.8" ParallelTestRunner = "1" -[extras] -GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +[sources] +GPUCompiler = {path = ".."} From aa36b56f22962d1f32373ae695b993d3585534ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 10 Jan 2026 15:43:35 +0000 Subject: [PATCH 2/8] Make Runic happier --- src/irgen.jl | 41 +++++++++++++++++++++++++---------------- src/mcgen.jl | 2 +- src/metal.jl | 4 ++-- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index 2a310227..359681ce 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -510,7 +510,7 @@ end # so that the julia.gpu.state_getter` can be simplified to return an opaque pointer. function kernel_state_check_user!(additions, val, worklist) - if val isa Instruction + return if val isa Instruction bb = LLVM.parent(val) new_f = LLVM.parent(bb) in(new_f, worklist) || push!(additions, new_f) @@ -526,7 +526,7 @@ end function kernel_state_rewrite_uses!(f, ft) # update uses - @dispose builder=IRBuilder() begin + return @dispose builder = IRBuilder() begin for use in uses(f) val = user(use) if val isa LLVM.CallBase && called_operand(val) == f @@ -562,16 +562,22 @@ function kernel_state_rewrite_uses!(f, ft) # XXX: we won't have to do this with opaque pointers. position!(builder, val) target_ft = called_type(val) - new_args = map(zip(parameters(target_ft), - arguments(val))) do (param_typ, arg) - if value_type(arg) != param_typ - const_bitcast(arg, param_typ) - else - arg - end - end - new_val = call!(builder, called_type(val), called_operand(val), new_args, - operand_bundles(val)) + new_args = map( + zip( + parameters(target_ft), + arguments(val) + ) + ) do (param_typ, arg) + if value_type(arg) != param_typ + const_bitcast(arg, param_typ) + else + arg + end + end + new_val = call!( + builder, called_type(val), called_operand(val), new_args, + operand_bundles(val) + ) callconv!(new_val, callconv(val)) replace_uses!(val, new_val) @@ -931,11 +937,12 @@ function scan_uses!(additions, val, worklist) error("Don't know how to check uses of $candidate. Please file an issue.") end end + return end function input_arguments_rewrite_uses!(f, new_f) # update uses - @dispose builder=IRBuilder() begin + return @dispose builder = IRBuilder() begin for use in uses(f) val = user(use) if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst @@ -943,9 +950,11 @@ function input_arguments_rewrite_uses!(f, new_f) # forward the arguments position!(builder, val) new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, - [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], - operand_bundles(val)) + call!( + builder, function_type(new_f), new_f, + [arguments(val)..., parameters(callee_f)[(end - nargs + 1):end]...], + operand_bundles(val) + ) else # TODO: invoke and callbr error("Rewrite of $(typeof(val))-based calls is not implemented: $val") diff --git a/src/mcgen.jl b/src/mcgen.jl index b16848e0..66bcb818 100644 --- a/src/mcgen.jl +++ b/src/mcgen.jl @@ -47,7 +47,7 @@ function replace_bindings!(value, dereferenced) changed end end - changed + return changed end function resolve_cpu_references!(mod::LLVM.Module) diff --git a/src/metal.jl b/src/metal.jl index ef5da56f..9c1163e1 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -343,7 +343,7 @@ end # global constant objects need to reside in address space 2, so we clone each function # that uses global objects and rewrite the globals used by it function metal_check_user!(function_worklist, val) - if val isa LLVM.Instruction + return if val isa LLVM.Instruction bb = LLVM.parent(val) f = LLVM.parent(bb) @@ -740,7 +740,7 @@ end function type_suffix(typ) # XXX: can't we use LLVM to do this kind of mangling? - if typ isa LLVM.IntegerType + return if typ isa LLVM.IntegerType "i$(width(typ))" elseif typ == LLVM.HalfType() "f16" From 8b9ec83028ec885ee4195b588036d7884f977c78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 10 Jan 2026 16:12:58 +0000 Subject: [PATCH 3/8] Revert change to `rewrite_uses!` --- src/irgen.jl | 137 ++++++++++++++++++++++++--------------------------- 1 file changed, 65 insertions(+), 72 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index 359681ce..e9997b22 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -524,78 +524,7 @@ function kernel_state_check_user!(additions, val, worklist) end end -function kernel_state_rewrite_uses!(f, ft) - # update uses - return @dispose builder = IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallBase && called_operand(val) == f - # NOTE: we don't rewrite calls using Julia's jlcall calling convention, - # as those have a fixed argument list, passing actual arguments - # in an array of objects. that doesn't matter, for now, since - # GPU back-ends don't support such calls anyhow. but if we ever - # want to support kernel state passing on more capable back-ends, - # we'll need to update the argument array instead. - if callconv(val) == 37 || callconv(val) == 38 - # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. - continue - end - - # forward the state argument - position!(builder, val) - state = call!(builder, state_intr_ft, state_intr, Value[], "state") - new_val = if val isa LLVM.CallInst - call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.CallBase - # the function is being passed as an argument. to avoid having to - # rewrite the target function, instead case the rewritten function to - # the old stateless type. - # XXX: we won't have to do this with opaque pointers. - position!(builder, val) - target_ft = called_type(val) - new_args = map( - zip( - parameters(target_ft), - arguments(val) - ) - ) do (param_typ, arg) - if value_type(arg) != param_typ - const_bitcast(arg, param_typ) - else - arg - end - end - new_val = call!( - builder, called_type(val), called_operand(val), new_args, - operand_bundles(val) - ) - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.StoreInst - # the function is being stored, which again we'll permit like before. - elseif val isa ConstantExpr - kernel_state_rewrite_uses!(val, ft) - else - error("Cannot rewrite $(typeof(val)) use of function: $val") - end - end - end -end - # add a state argument to every function in the module, starting from the kernel entry point -# update uses of the new function, modifying call sites to include the kernel state function add_kernel_state!(mod::LLVM.Module) job = current_job::CompilerJob @@ -711,9 +640,73 @@ function add_kernel_state!(mod::LLVM.Module) erase!(f) end + # update uses of the new function, modifying call sites to include the kernel state + function rewrite_uses!(f, ft) + # update uses + @dispose builder=IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallBase && called_operand(val) == f + # NOTE: we don't rewrite calls using Julia's jlcall calling convention, + # as those have a fixed argument list, passing actual arguments + # in an array of objects. that doesn't matter, for now, since + # GPU back-ends don't support such calls anyhow. but if we ever + # want to support kernel state passing on more capable back-ends, + # we'll need to update the argument array instead. + if callconv(val) == 37 || callconv(val) == 38 + # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. + continue + end + + # forward the state argument + position!(builder, val) + state = call!(builder, state_intr_ft, state_intr, Value[], "state") + new_val = if val isa LLVM.CallInst + call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.CallBase + # the function is being passed as an argument. to avoid having to + # rewrite the target function, instead case the rewritten function to + # the old stateless type. + # XXX: we won't have to do this with opaque pointers. + position!(builder, val) + target_ft = called_type(val) + new_args = map(zip(parameters(target_ft), + arguments(val))) do (param_typ, arg) + if value_type(arg) != param_typ + const_bitcast(arg, param_typ) + else + arg + end + end + new_val = call!(builder, called_type(val), called_operand(val), new_args, + operand_bundles(val)) + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.StoreInst + # the function is being stored, which again we'll permit like before. + elseif val isa ConstantExpr + rewrite_uses!(val, ft) + else + error("Cannot rewrite $(typeof(val)) use of function: $val") + end + end + end + end for f in values(workmap) ft = function_type(f) - kernel_state_rewrite_uses!(f, ft) + rewrite_uses!(f, ft) end return true From a8ae8460747c3f2699a24e2be7168784f375f9f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 10 Jan 2026 16:13:18 +0000 Subject: [PATCH 4/8] Mark `state` inside `rewrite_uses!` as `local` to avoid boxing --- src/irgen.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/irgen.jl b/src/irgen.jl index e9997b22..558c4a02 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -660,7 +660,7 @@ function add_kernel_state!(mod::LLVM.Module) # forward the state argument position!(builder, val) - state = call!(builder, state_intr_ft, state_intr, Value[], "state") + local state = call!(builder, state_intr_ft, state_intr, Value[], "state") new_val = if val isa LLVM.CallInst call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) else From 610bfc3f2ad1e7202585b5b5606fbeab7e71d49b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sat, 10 Jan 2026 18:19:00 +0000 Subject: [PATCH 5/8] Pass `nargs` as argument to `input_arguments_rewrite_uses!` --- src/irgen.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index 558c4a02..1a91541f 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -933,7 +933,7 @@ function scan_uses!(additions, val, worklist) return end -function input_arguments_rewrite_uses!(f, new_f) +function input_arguments_rewrite_uses!(f, new_f, nargs) # update uses return @dispose builder = IRBuilder() begin for use in uses(f) @@ -962,7 +962,7 @@ function input_arguments_rewrite_uses!(f, new_f) target = operands(val)[1] @assert target == f new_val = LLVM.const_bitcast(new_f, value_type(val)) - input_arguments_rewrite_uses!(val, new_val) + input_arguments_rewrite_uses!(val, new_val, nargs) # we can't simply replace this constant expression, as it may be used # as a call, taking arguments (so we need to rewrite it to pass the input arguments) @@ -1063,7 +1063,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # update other uses of the old function, modifying call sites to pass the arguments for (f, new_f) in workmap - input_arguments_rewrite_uses!(f, new_f) + input_arguments_rewrite_uses!(f, new_f, nargs) @assert isempty(uses(f)) replace_metadata_uses!(f, new_f) erase!(f) From 98af08dc8864e72b5c77e11aa5df2f54f73860f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sun, 11 Jan 2026 11:29:46 +0000 Subject: [PATCH 6/8] Revert "Revert change to `rewrite_uses!`" This reverts commit 8b9ec83028ec885ee4195b588036d7884f977c78. --- src/irgen.jl | 137 +++++++++++++++++++++++++++------------------------ 1 file changed, 72 insertions(+), 65 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index 1a91541f..6307119b 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -524,6 +524,77 @@ function kernel_state_check_user!(additions, val, worklist) end end +# update uses of the new function, modifying call sites to include the kernel state +function kernel_state_rewrite_uses!(f, ft) + # update uses + return @dispose builder = IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallBase && called_operand(val) == f + # NOTE: we don't rewrite calls using Julia's jlcall calling convention, + # as those have a fixed argument list, passing actual arguments + # in an array of objects. that doesn't matter, for now, since + # GPU back-ends don't support such calls anyhow. but if we ever + # want to support kernel state passing on more capable back-ends, + # we'll need to update the argument array instead. + if callconv(val) == 37 || callconv(val) == 38 + # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. + continue + end + + # forward the state argument + position!(builder, val) + state = call!(builder, state_intr_ft, state_intr, Value[], "state") + new_val = if val isa LLVM.CallInst + call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.CallBase + # the function is being passed as an argument. to avoid having to + # rewrite the target function, instead case the rewritten function to + # the old stateless type. + # XXX: we won't have to do this with opaque pointers. + position!(builder, val) + target_ft = called_type(val) + new_args = map( + zip( + parameters(target_ft), + arguments(val) + ) + ) do (param_typ, arg) + if value_type(arg) != param_typ + const_bitcast(arg, param_typ) + else + arg + end + end + new_val = call!( + builder, called_type(val), called_operand(val), new_args, + operand_bundles(val) + ) + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.StoreInst + # the function is being stored, which again we'll permit like before. + elseif val isa ConstantExpr + kernel_state_rewrite_uses!(val, ft) + else + error("Cannot rewrite $(typeof(val)) use of function: $val") + end + end + end +end + # add a state argument to every function in the module, starting from the kernel entry point function add_kernel_state!(mod::LLVM.Module) job = current_job::CompilerJob @@ -640,73 +711,9 @@ function add_kernel_state!(mod::LLVM.Module) erase!(f) end - # update uses of the new function, modifying call sites to include the kernel state - function rewrite_uses!(f, ft) - # update uses - @dispose builder=IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallBase && called_operand(val) == f - # NOTE: we don't rewrite calls using Julia's jlcall calling convention, - # as those have a fixed argument list, passing actual arguments - # in an array of objects. that doesn't matter, for now, since - # GPU back-ends don't support such calls anyhow. but if we ever - # want to support kernel state passing on more capable back-ends, - # we'll need to update the argument array instead. - if callconv(val) == 37 || callconv(val) == 38 - # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. - continue - end - - # forward the state argument - position!(builder, val) - local state = call!(builder, state_intr_ft, state_intr, Value[], "state") - new_val = if val isa LLVM.CallInst - call!(builder, ft, f, [state, arguments(val)...], operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.CallBase - # the function is being passed as an argument. to avoid having to - # rewrite the target function, instead case the rewritten function to - # the old stateless type. - # XXX: we won't have to do this with opaque pointers. - position!(builder, val) - target_ft = called_type(val) - new_args = map(zip(parameters(target_ft), - arguments(val))) do (param_typ, arg) - if value_type(arg) != param_typ - const_bitcast(arg, param_typ) - else - arg - end - end - new_val = call!(builder, called_type(val), called_operand(val), new_args, - operand_bundles(val)) - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.StoreInst - # the function is being stored, which again we'll permit like before. - elseif val isa ConstantExpr - rewrite_uses!(val, ft) - else - error("Cannot rewrite $(typeof(val)) use of function: $val") - end - end - end - end for f in values(workmap) ft = function_type(f) - rewrite_uses!(f, ft) + kernel_state_rewrite_uses!(f, ft) end return true From 21704740a882fce96b30ef9816e62d45b1553a83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sun, 11 Jan 2026 11:32:25 +0000 Subject: [PATCH 7/8] Pass more arguments to `kernel_state_rewrite_uses!` These variables were being captured from the outer scope before. --- src/irgen.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/irgen.jl b/src/irgen.jl index 6307119b..cca67882 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -525,7 +525,7 @@ function kernel_state_check_user!(additions, val, worklist) end # update uses of the new function, modifying call sites to include the kernel state -function kernel_state_rewrite_uses!(f, ft) +function kernel_state_rewrite_uses!(f, ft, state_intr_ft, state_intr) # update uses return @dispose builder = IRBuilder() begin for use in uses(f) @@ -537,7 +537,7 @@ function kernel_state_rewrite_uses!(f, ft) # GPU back-ends don't support such calls anyhow. but if we ever # want to support kernel state passing on more capable back-ends, # we'll need to update the argument array instead. - if callconv(val) == 37 || callconv(val) == 38 + if callconv(val) in (37, 38) # TODO: update for LLVM 15 when JuliaLang/julia#45088 is merged. continue end @@ -587,7 +587,7 @@ function kernel_state_rewrite_uses!(f, ft) elseif val isa LLVM.StoreInst # the function is being stored, which again we'll permit like before. elseif val isa ConstantExpr - kernel_state_rewrite_uses!(val, ft) + kernel_state_rewrite_uses!(val, ft, state_intr_ft, state_intr) else error("Cannot rewrite $(typeof(val)) use of function: $val") end @@ -713,7 +713,7 @@ function add_kernel_state!(mod::LLVM.Module) for f in values(workmap) ft = function_type(f) - kernel_state_rewrite_uses!(f, ft) + kernel_state_rewrite_uses!(f, ft, state_intr_ft, state_intr) end return true From 32c335be182a24bfc5d7f660c949ed75d63275ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Sun, 11 Jan 2026 11:45:45 +0000 Subject: [PATCH 8/8] Resolve some boxes in `emit_llvm` --- src/driver.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index 950ea272..8923a660 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -179,12 +179,14 @@ end const __llvm_initialized = Ref(false) -@locked function emit_llvm(@nospecialize(job::CompilerJob); kwargs...) +@locked function emit_llvm(@nospecialize(compiler_job::CompilerJob); kwargs...) # XXX: remove on next major version - if !isempty(kwargs) + job = if !isempty(kwargs) Base.depwarn("The GPUCompiler `emit_llvm` function is an internal API. Use `GPUCompiler.compile` (with any kwargs passed to `CompilerConfig`) instead.", :emit_llvm) - config = CompilerConfig(job.config; kwargs...) - job = CompilerJob(job.source, config) + config = CompilerConfig(compiler_job.config; kwargs...) + CompilerJob(compiler_job.source, config) + else + compiler_job end if !__llvm_initialized[] @@ -299,10 +301,11 @@ const __llvm_initialized = Ref(false) if job.config.toplevel && job.config.libraries # load the runtime outside of a timing block (because it recurses into the compiler) - if !uses_julia_runtime(job) + runtime_fns, runtime_intrinsics = if !uses_julia_runtime(job) runtime = load_runtime(job) - runtime_fns = LLVM.name.(defs(runtime)) - runtime_intrinsics = ["julia.gc_alloc_obj"] + LLVM.name.(defs(runtime)), ["julia.gc_alloc_obj"] + else + String[], String[] end @tracepoint "Library linking" begin