Skip to content

Commit 3380e7c

Browse files
committed
LLVMCodeBuilder: Fix use after free when accessing from another scope
1 parent 297ae3a commit 3380e7c

File tree

3 files changed

+71
-22
lines changed

3 files changed

+71
-22
lines changed

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
146146
step.functionReturnReg->value = ret;
147147

148148
if (step.functionReturnReg->type() == Compiler::StaticType::String)
149-
m_heap.push_back(step.functionReturnReg->value);
149+
freeLater(step.functionReturnReg->value);
150150
}
151151

152152
break;
@@ -749,7 +749,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
749749
assert(step.args.size() == 0);
750750
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
751751
llvm::Value *ptr = m_builder.CreateCall(resolve_list_to_string(), listPtr.ptr);
752-
m_heap.push_back(ptr); // deallocate later
752+
freeLater(ptr);
753753
step.functionReturnReg->value = ptr;
754754
break;
755755
}
@@ -791,7 +791,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
791791

792792
case LLVMInstruction::Type::Yield:
793793
if (!m_warp) {
794-
freeHeap();
794+
// TODO: Do not allow use after suspend (use after free)
795+
freeScopeHeap();
795796
syncVariables(targetVariables);
796797
coro->createSuspend();
797798
reloadVariables(targetVariables);
@@ -809,7 +810,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
809810
assert(step.args.size() == 1);
810811
const auto &reg = step.args[0];
811812
assert(reg.first == Compiler::StaticType::Bool);
812-
freeHeap();
813813
statement.condition = castValue(reg.second, reg.first);
814814

815815
// Switch to body branch
@@ -836,7 +836,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
836836
// Jump to the branch after the if statement
837837
assert(!statement.afterIf);
838838
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
839-
freeHeap();
839+
freeScopeHeap();
840840
m_builder.CreateBr(statement.afterIf);
841841

842842
// Create else branch
@@ -855,12 +855,12 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
855855
case LLVMInstruction::Type::EndIf: {
856856
assert(!ifStatements.empty());
857857
LLVMIfStatement &statement = ifStatements.back();
858+
freeScopeHeap();
858859

859860
// Jump to the branch after the if statement
860861
if (!statement.afterIf)
861862
statement.afterIf = llvm::BasicBlock::Create(m_ctx, "", func);
862863

863-
freeHeap();
864864
m_builder.CreateBr(statement.afterIf);
865865

866866
if (statement.elseBranch) {
@@ -901,7 +901,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
901901

902902
// Clamp count if <= 0 (we can skip the loop if count is not positive)
903903
llvm::Value *comparison = m_builder.CreateFCmpULE(count, llvm::ConstantFP::get(m_ctx, llvm::APFloat(0.0)));
904-
freeHeap();
905904
m_builder.CreateCondBr(comparison, loop.afterLoop, roundBranch);
906905

907906
// Round (Scratch-specific behavior)
@@ -955,7 +954,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
955954
const auto &reg = step.args[0];
956955
assert(reg.first == Compiler::StaticType::Bool);
957956
llvm::Value *condition = castValue(reg.second, reg.first);
958-
freeHeap();
959957
m_builder.CreateCondBr(condition, body, loop.afterLoop);
960958

961959
// Switch to body branch
@@ -977,7 +975,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
977975
const auto &reg = step.args[0];
978976
assert(reg.first == Compiler::StaticType::Bool);
979977
llvm::Value *condition = castValue(reg.second, reg.first);
980-
freeHeap();
981978
m_builder.CreateCondBr(condition, loop.afterLoop, body);
982979

983980
// Switch to body branch
@@ -990,7 +987,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
990987
LLVMLoop loop;
991988
loop.isRepeatLoop = false;
992989
loop.conditionBranch = llvm::BasicBlock::Create(m_ctx, "", func);
993-
freeHeap();
994990
m_builder.CreateBr(loop.conditionBranch);
995991
m_builder.SetInsertPoint(loop.conditionBranch);
996992
loops.push_back(loop);
@@ -1009,7 +1005,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10091005
}
10101006

10111007
// Jump to the condition branch
1012-
freeHeap();
1008+
freeScopeHeap();
10131009
m_builder.CreateBr(loop.conditionBranch);
10141010

10151011
// Switch to the branch after the loop
@@ -1032,7 +1028,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10321028
m_builder.CreateBr(endBranch);
10331029

10341030
m_builder.SetInsertPoint(endBranch);
1035-
freeHeap();
1031+
assert(m_heap.size() == 1);
1032+
freeScopeHeap();
10361033
syncVariables(targetVariables);
10371034

10381035
// End and verify the function
@@ -1535,6 +1532,8 @@ void LLVMCodeBuilder::pushScopeLevel()
15351532
m_scopeLists.push_back(listTypes);
15361533
} else
15371534
m_scopeLists.push_back(m_scopeLists.back());
1535+
1536+
m_heap.push_back({});
15381537
}
15391538

15401539
void LLVMCodeBuilder::popScopeLevel()
@@ -1556,6 +1555,9 @@ void LLVMCodeBuilder::popScopeLevel()
15561555
}
15571556

15581557
m_scopeLists.pop_back();
1558+
1559+
freeScopeHeap();
1560+
m_heap.pop_back();
15591561
}
15601562

15611563
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
@@ -1590,13 +1592,28 @@ LLVMRegister *LLVMCodeBuilder::addReg(std::shared_ptr<LLVMRegister> reg)
15901592
return reg.get();
15911593
}
15921594

1593-
void LLVMCodeBuilder::freeHeap()
1595+
void LLVMCodeBuilder::freeLater(llvm::Value *value)
1596+
{
1597+
assert(!m_heap.empty());
1598+
1599+
if (m_heap.empty())
1600+
return;
1601+
1602+
m_heap.back().push_back(value);
1603+
}
1604+
1605+
void LLVMCodeBuilder::freeScopeHeap()
15941606
{
1595-
// Free dynamically allocated memory
1596-
for (llvm::Value *ptr : m_heap)
1607+
if (m_heap.empty())
1608+
return;
1609+
1610+
// Free dynamically allocated memory in current scope
1611+
auto &heap = m_heap.back();
1612+
1613+
for (llvm::Value *ptr : heap)
15971614
m_builder.CreateFree(ptr);
15981615

1599-
m_heap.clear();
1616+
heap.clear();
16001617
}
16011618

16021619
llvm::Value *LLVMCodeBuilder::castValue(LLVMRegister *reg, Compiler::StaticType targetType)
@@ -1668,7 +1685,7 @@ llvm::Value *LLVMCodeBuilder::castValue(LLVMRegister *reg, Compiler::StaticType
16681685
case Compiler::StaticType::Unknown: {
16691686
// Cast to string
16701687
llvm::Value *ptr = m_builder.CreateCall(resolve_value_toCString(), reg->value);
1671-
m_heap.push_back(ptr); // deallocate later
1688+
freeLater(ptr);
16721689
return ptr;
16731690
}
16741691

@@ -1731,7 +1748,7 @@ llvm::Value *LLVMCodeBuilder::castRawValue(LLVMRegister *reg, Compiler::StaticTy
17311748
case Compiler::StaticType::Number: {
17321749
// Convert double to string
17331750
llvm::Value *ptr = m_builder.CreateCall(resolve_value_doubleToCString(), reg->value);
1734-
m_heap.push_back(ptr); // deallocate later
1751+
freeLater(ptr);
17351752
return ptr;
17361753
}
17371754

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ class LLVMCodeBuilder : public ICodeBuilder
118118

119119
LLVMRegister *addReg(std::shared_ptr<LLVMRegister> reg);
120120

121-
void freeHeap();
121+
void freeLater(llvm::Value *value);
122+
void freeScopeHeap();
122123
llvm::Value *castValue(LLVMRegister *reg, Compiler::StaticType targetType);
123124
llvm::Value *castRawValue(LLVMRegister *reg, Compiler::StaticType targetType);
124125
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType);
@@ -202,7 +203,7 @@ class LLVMCodeBuilder : public ICodeBuilder
202203
bool m_defaultWarp = false;
203204
bool m_warp = false;
204205

205-
std::vector<llvm::Value *> m_heap;
206+
std::vector<std::vector<llvm::Value *>> m_heap; // scopes
206207

207208
std::shared_ptr<ExecutableCode> m_output;
208209
};

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2765,6 +2765,7 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
27652765
m_builder->endIf();
27662766

27672767
// Nested 1
2768+
CompilerValue *str = m_builder->addFunctionCall("test_const_string", Compiler::StaticType::String, { Compiler::StaticType::String }, { m_builder->addConstValue("test") });
27682769
v = m_builder->addConstValue(true);
27692770
m_builder->beginIfStatement(v);
27702771
{
@@ -2779,6 +2780,9 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
27792780
v = m_builder->addConstValue(1);
27802781
m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
27812782

2783+
// str should still be allocated
2784+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
2785+
27822786
v = m_builder->addConstValue(false);
27832787
m_builder->beginIfStatement(v);
27842788
m_builder->beginElseBranch();
@@ -2787,6 +2791,9 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
27872791
m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
27882792
}
27892793
m_builder->endIf();
2794+
2795+
// str should still be allocated
2796+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
27902797
}
27912798
m_builder->endIf();
27922799
}
@@ -2807,6 +2814,9 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
28072814
}
28082815
m_builder->endIf();
28092816

2817+
// str should still be allocated
2818+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
2819+
28102820
// Nested 2
28112821
v = m_builder->addConstValue(false);
28122822
m_builder->beginIfStatement(v);
@@ -2826,6 +2836,8 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
28262836
}
28272837
m_builder->beginElseBranch();
28282838
{
2839+
str = m_builder->addFunctionCall("test_const_string", Compiler::StaticType::String, { Compiler::StaticType::String }, { m_builder->addConstValue("test") });
2840+
28292841
v = m_builder->addConstValue(true);
28302842
m_builder->beginIfStatement(v);
28312843
{
@@ -2834,6 +2846,9 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
28342846
}
28352847
m_builder->beginElseBranch();
28362848
m_builder->endIf();
2849+
2850+
// str should still be allocated
2851+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
28372852
}
28382853
m_builder->endIf();
28392854

@@ -2855,8 +2870,12 @@ TEST_F(LLVMCodeBuilderTest, IfStatement)
28552870
"no_args_ret\n"
28562871
"1_arg 9\n"
28572872
"1_arg 1\n"
2873+
"test\n"
28582874
"1_arg 2\n"
2859-
"1_arg 7\n";
2875+
"test\n"
2876+
"test\n"
2877+
"1_arg 7\n"
2878+
"test\n";
28602879

28612880
EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false));
28622881
testing::internal::CaptureStdout();
@@ -3144,6 +3163,7 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop)
31443163
m_builder->endLoop();
31453164

31463165
// Nested
3166+
CompilerValue *str = m_builder->addFunctionCall("test_const_string", Compiler::StaticType::String, { Compiler::StaticType::String }, { m_builder->addConstValue("test") });
31473167
v = m_builder->addConstValue(2);
31483168
m_builder->beginRepeatLoop(v);
31493169
{
@@ -3152,6 +3172,9 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop)
31523172
{
31533173
v = m_builder->addConstValue(1);
31543174
m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v });
3175+
3176+
// str should still be allocated
3177+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
31553178
}
31563179
m_builder->endLoop();
31573180

@@ -3168,6 +3191,9 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop)
31683191
}
31693192
m_builder->endLoop();
31703193

3194+
// str should still be allocated
3195+
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { str });
3196+
31713197
auto code = m_builder->finalize();
31723198
Script script(&m_target, nullptr, nullptr);
31733199
script.setCode(code);
@@ -3186,17 +3212,22 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop)
31863212
"0\n"
31873213
"1\n"
31883214
"1_arg 1\n"
3215+
"test\n"
31893216
"1_arg 1\n"
3217+
"test\n"
31903218
"1_arg 2\n"
31913219
"0\n"
31923220
"1\n"
31933221
"2\n"
31943222
"1_arg 1\n"
3223+
"test\n"
31953224
"1_arg 1\n"
3225+
"test\n"
31963226
"1_arg 2\n"
31973227
"0\n"
31983228
"1\n"
3199-
"2\n";
3229+
"2\n"
3230+
"test\n";
32003231

32013232
EXPECT_CALL(m_target, isStage).WillRepeatedly(Return(false));
32023233
testing::internal::CaptureStdout();

0 commit comments

Comments
 (0)