Skip to content

Commit b37535b

Browse files
niaowdeadprogram
authored and committed
transform (gc): create stack slots in callers of external functions
This updates the stack slot pass to include callers of external functions which may access non-argument memory. Without this change, a use-after-free could occur on WASM when calling a reentrant function or switching to another goroutine.
1 parent ee0a10e commit b37535b

File tree

3 files changed

+154
-32
lines changed

3 files changed

+154
-32
lines changed

transform/gc.go

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ import (
44
"tinygo.org/x/go-llvm"
55
)
66

7+
// This is somewhat ugly to access through the API.
8+
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L62
9+
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L87
10+
const shiftExcludeArgMem = 2
11+
712
// MakeGCStackSlots converts all calls to runtime.trackPointer to explicit
813
// stores to stack slots that are scannable by the GC.
914
func MakeGCStackSlots(mod llvm.Module) bool {
@@ -36,54 +41,63 @@ func MakeGCStackSlots(mod llvm.Module) bool {
3641
defer targetData.Dispose()
3742
uintptrType := ctx.IntType(targetData.PointerSize() * 8)
3843

39-
// Look at *all* functions to see whether they are free of function pointer
44+
// All functions that call runtime.alloc need stack objects.
45+
trackFuncs := map[llvm.Value]struct{}{}
46+
markParentFunctions(trackFuncs, alloc)
47+
48+
// External functions may indirectly suspend the goroutine or perform a heap allocation.
49+
// Their callers should get stack objects.
50+
memAttr := llvm.AttributeKindID("memory")
51+
for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
52+
if _, ok := trackFuncs[fn]; ok {
53+
continue // already found
54+
}
55+
if !fn.FirstBasicBlock().IsNil() {
56+
// This is not an external function.
57+
continue
58+
}
59+
if fn == trackPointer {
60+
// Manually exclude trackPointer.
61+
continue
62+
}
63+
64+
mem := fn.GetEnumFunctionAttribute(memAttr)
65+
if !mem.IsNil() && mem.GetEnumValue()>>shiftExcludeArgMem == 0 {
66+
// This does not access non-argument memory.
67+
// Exclude it.
68+
continue
69+
}
70+
71+
// The callers need stack objects.
72+
markParentFunctions(trackFuncs, fn)
73+
}
74+
75+
// Look at all other functions to see whether they contain function pointer
4076
// calls.
4177
// This takes less than 5ms for ~100kB of WebAssembly but would perhaps be
4278
// faster when written in C++ (to avoid the CGo overhead).
43-
funcsWithFPCall := map[llvm.Value]struct{}{}
44-
n := 0
4579
for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
46-
n++
47-
if _, ok := funcsWithFPCall[fn]; ok {
80+
if _, ok := trackFuncs[fn]; ok {
4881
continue // already found
4982
}
50-
done := false
51-
for bb := fn.FirstBasicBlock(); !bb.IsNil() && !done; bb = llvm.NextBasicBlock(bb) {
52-
for call := bb.FirstInstruction(); !call.IsNil() && !done; call = llvm.NextInstruction(call) {
83+
84+
scanBody:
85+
for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
86+
for call := bb.FirstInstruction(); !call.IsNil(); call = llvm.NextInstruction(call) {
5387
if call.IsACallInst().IsNil() {
5488
continue // only looking at calls
5589
}
5690
called := call.CalledValue()
5791
if !called.IsAFunction().IsNil() {
5892
continue // only looking for function pointers
5993
}
60-
funcsWithFPCall[fn] = struct{}{}
61-
markParentFunctions(funcsWithFPCall, fn)
62-
done = true
94+
trackFuncs[fn] = struct{}{}
95+
markParentFunctions(trackFuncs, fn)
96+
break scanBody
6397
}
6498
}
6599
}
66100

67-
// Determine which functions need stack objects. Many leaf functions don't
68-
// need it: it only causes overhead for them.
69-
// Actually, in one test it was only able to eliminate stack object from 12%
70-
// of functions that had a call to runtime.trackPointer (8 out of 68
71-
// functions), so this optimization is not as big as it may seem.
72-
allocatingFunctions := map[llvm.Value]struct{}{} // set of allocating functions
73-
74-
// Work from runtime.alloc and trace all parents to check which functions do
75-
// a heap allocation (and thus which functions do not).
76-
markParentFunctions(allocatingFunctions, alloc)
77-
78-
// Also trace all functions that call a function pointer.
79-
for fn := range funcsWithFPCall {
80-
// Assume that functions that call a function pointer do a heap
81-
// allocation as a conservative guess because the called function might
82-
// do a heap allocation.
83-
allocatingFunctions[fn] = struct{}{}
84-
markParentFunctions(allocatingFunctions, fn)
85-
}
86-
87101
// Collect some variables used below in the loop.
88102
stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
89103
if stackChainStart.IsNil() {
@@ -110,7 +124,7 @@ func MakeGCStackSlots(mod llvm.Module) bool {
110124
// Pick the parent function.
111125
fn := call.InstructionParent().Parent()
112126

113-
if _, ok := allocatingFunctions[fn]; !ok {
127+
if _, ok := trackFuncs[fn]; !ok {
114128
// This function nor any of the functions it calls (recursively)
115129
// allocate anything from the heap, so it will not trigger a garbage
116130
// collection cycle. Thus, it does not need to track local pointer

transform/testdata/gc-stackslots.ll

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ target triple = "wasm32-unknown-unknown-wasm"
44
@runtime.stackChainStart = external global ptr
55
@someGlobal = global i8 3
66
@ptrGlobal = global ptr null
7+
@arrGlobal = global [8 x i8] zeroinitializer
78

89
declare void @runtime.trackPointer(ptr nocapture readonly)
910

@@ -116,3 +117,51 @@ define void @allocAndSave(ptr %x) {
116117
store ptr %x, ptr @ptrGlobal
117118
ret void
118119
}
120+
121+
declare void @"(internal/task).Pause"()
122+
123+
define ptr @getAndPause() {
124+
%ptr = call ptr @getPointer()
125+
call void @runtime.trackPointer(ptr %ptr)
126+
; Calling a function with unknown memory access forces stack slot creation.
127+
call void @"(internal/task).Pause"()
128+
ret ptr %ptr
129+
}
130+
131+
; Function Attrs: memory(readwrite)
132+
declare void @externCallWithMemAttr() #0
133+
134+
define ptr @getAndCallWithMemAttr() {
135+
%ptr = call ptr @getPointer()
136+
call void @runtime.trackPointer(ptr %ptr)
137+
; Calling an external function which may access non-arg memory forces stack slot creation.
138+
call void @externCallWithMemAttr()
139+
ret ptr %ptr
140+
}
141+
142+
; Generic function that returns a slice (that must be tracked).
143+
define {ptr, i32, i32} @getSlice() {
144+
ret {ptr, i32, i32} {ptr @someGlobal, i32 8, i32 8}
145+
}
146+
147+
define i32 @copyToSlice(ptr %src.ptr, i32 %src.len, i32 %src.cap) {
148+
%dst = call {ptr, i32, i32} @getSlice()
149+
%dst.ptr = extractvalue {ptr, i32, i32} %dst, 0
150+
call void @runtime.trackPointer(ptr %dst.ptr)
151+
%dst.len = extractvalue {ptr, i32, i32} %dst, 1
152+
; Math intrinsics do not need stack slots.
153+
%minLen = call i32 @llvm.umin.i32(i32 %dst.len, i32 %src.len)
154+
; Intrinsics which only access argument memory do not need stack slots.
155+
call void @llvm.memmove.p0.p0.i32(ptr %dst.ptr, ptr %src.ptr, i32 %minLen, i1 false)
156+
ret i32 %minLen
157+
}
158+
159+
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
160+
declare i32 @llvm.umin.i32(i32, i32) #1
161+
162+
; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
163+
declare void @llvm.memmove.p0.p0.i32(ptr nocapture writeonly, ptr nocapture readonly, i32, i1 immarg) #2
164+
165+
attributes #0 = { memory(readwrite) }
166+
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
167+
attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }

transform/testdata/gc-stackslots.out.ll

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ target triple = "wasm32-unknown-unknown-wasm"
44
@runtime.stackChainStart = internal global ptr null
55
@someGlobal = global i8 3
66
@ptrGlobal = global ptr null
7+
@arrGlobal = global [8 x i8] zeroinitializer
78

89
declare void @runtime.trackPointer(ptr nocapture readonly)
910

@@ -166,3 +167,61 @@ define void @allocAndSave(ptr %x) {
166167
store ptr %1, ptr @runtime.stackChainStart, align 4
167168
ret void
168169
}
170+
171+
declare void @"(internal/task).Pause"()
172+
173+
define ptr @getAndPause() {
174+
%gc.stackobject = alloca { ptr, i32, ptr }, align 8
175+
store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4
176+
%1 = load ptr, ptr @runtime.stackChainStart, align 4
177+
%2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0
178+
store ptr %1, ptr %2, align 4
179+
store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4
180+
%ptr = call ptr @getPointer()
181+
%3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2
182+
store ptr %ptr, ptr %3, align 4
183+
call void @"(internal/task).Pause"()
184+
store ptr %1, ptr @runtime.stackChainStart, align 4
185+
ret ptr %ptr
186+
}
187+
188+
; Function Attrs: memory(readwrite)
189+
declare void @externCallWithMemAttr() #0
190+
191+
define ptr @getAndCallWithMemAttr() {
192+
%gc.stackobject = alloca { ptr, i32, ptr }, align 8
193+
store { ptr, i32, ptr } { ptr null, i32 1, ptr null }, ptr %gc.stackobject, align 4
194+
%1 = load ptr, ptr @runtime.stackChainStart, align 4
195+
%2 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 0
196+
store ptr %1, ptr %2, align 4
197+
store ptr %gc.stackobject, ptr @runtime.stackChainStart, align 4
198+
%ptr = call ptr @getPointer()
199+
%3 = getelementptr { ptr, i32, ptr }, ptr %gc.stackobject, i32 0, i32 2
200+
store ptr %ptr, ptr %3, align 4
201+
call void @externCallWithMemAttr()
202+
store ptr %1, ptr @runtime.stackChainStart, align 4
203+
ret ptr %ptr
204+
}
205+
206+
define { ptr, i32, i32 } @getSlice() {
207+
ret { ptr, i32, i32 } { ptr @someGlobal, i32 8, i32 8 }
208+
}
209+
210+
define i32 @copyToSlice(ptr %src.ptr, i32 %src.len, i32 %src.cap) {
211+
%dst = call { ptr, i32, i32 } @getSlice()
212+
%dst.ptr = extractvalue { ptr, i32, i32 } %dst, 0
213+
%dst.len = extractvalue { ptr, i32, i32 } %dst, 1
214+
%minLen = call i32 @llvm.umin.i32(i32 %dst.len, i32 %src.len)
215+
call void @llvm.memmove.p0.p0.i32(ptr %dst.ptr, ptr %src.ptr, i32 %minLen, i1 false)
216+
ret i32 %minLen
217+
}
218+
219+
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
220+
declare i32 @llvm.umin.i32(i32, i32) #1
221+
222+
; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
223+
declare void @llvm.memmove.p0.p0.i32(ptr nocapture writeonly, ptr nocapture readonly, i32, i1 immarg) #2
224+
225+
attributes #0 = { memory(readwrite) }
226+
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
227+
attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }

0 commit comments

Comments
 (0)