Skip to content

Commit 2a6eb3a

Browse files
authored
allow cassette transform for cglobal (#64)
We need to special case it since it still has non-linearized form.
1 parent c3b4311 commit 2a6eb3a

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

CassetteBase/src/CassetteBase.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospe
7575
if isa(x, Expr)
7676
head = x.head
7777
if head === :call
78-
return Expr(:call, SlotNumber(1), map(transform, x.args[1:end])...)
78+
arg1 = x.args[1]
79+
if ((arg1 === Base.cglobal || (arg1 isa GlobalRef && arg1.name === :cglobal)) ||
80+
(arg1 === Core.tuple || (arg1 isa GlobalRef && arg1.name === :tuple)))
81+
return Expr(:call, map(transform, x.args)...) # don't cassette this -- we still have non-linearized cglobal
82+
end
83+
return Expr(:call, SlotNumber(1), map(transform, x.args)...)
7984
elseif head === :foreigncall
8085
arg1 = x.args[1]
8186
if Meta.isexpr(arg1, :call)

CassetteBase/test/test_basic.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ let pass = BasicPass()
5050
@test pass(sin, 1) == sin(1)
5151
@test_throws MethodError pass("1") do x; sin(x); end
5252
end
53+
const libccalltest = "libccalltest"
54+
fcglobal() = unsafe_load(cglobal((:global_var, libccalltest), Cint))
55+
let pass = BasicPass()
56+
@test pass(fcglobal) isa Cint
57+
end
5358

5459
struct RaisePass end
5560
@eval function (pass::RaisePass)(fargs...)

0 commit comments

Comments
 (0)