Skip to content

Commit b1a141d

Browse files
committed
Store LLVM script function arguments in LLVMBuildUtils
1 parent 30bfcc8 commit b1a141d

File tree

3 files changed

+74
-33
lines changed

3 files changed

+74
-33
lines changed

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,26 @@ LLVMBuildUtils::LLVMBuildUtils(LLVMCompilerContext *ctx, llvm::IRBuilder<> &buil
3030
createListMap();
3131
}
3232

33-
void LLVMBuildUtils::init(llvm::Function *function, llvm::Value *targetVariables, llvm::Value *targetLists)
33+
void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp)
3434
{
3535
m_function = function;
36-
m_targetVariables = targetVariables;
37-
m_targetLists = targetLists;
36+
m_procedurePrototype = procedurePrototype;
37+
38+
m_executionContextPtr = m_function->getArg(0);
39+
m_targetPtr = m_function->getArg(1);
40+
m_targetVariables = m_function->getArg(2);
41+
m_targetLists = m_function->getArg(3);
42+
m_warpArg = m_procedurePrototype ? m_function->getArg(4) : nullptr;
43+
44+
if (m_procedurePrototype && m_warp)
45+
m_function->addFnAttr(llvm::Attribute::InlineHint);
46+
3847
m_stringHeap.clear();
3948
pushScopeLevel();
4049

4150
// Create variable pointers
4251
for (auto &[var, varPtr] : m_variablePtrs) {
43-
llvm::Value *ptr = getVariablePtr(targetVariables, var);
52+
llvm::Value *ptr = getVariablePtr(m_targetVariables, var);
4453

4554
// Direct access
4655
varPtr.heapPtr = ptr;
@@ -69,7 +78,7 @@ void LLVMBuildUtils::init(llvm::Function *function, llvm::Value *targetVariables
6978

7079
// Create list pointers
7180
for (auto &[list, listPtr] : m_listPtrs) {
72-
listPtr.ptr = getListPtr(targetLists, list);
81+
listPtr.ptr = getListPtr(m_targetLists, list);
7382

7483
listPtr.dataPtr = m_builder.CreateAlloca(m_valueDataType->getPointerTo()->getPointerTo());
7584
m_builder.CreateStore(m_builder.CreateCall(m_functions.resolve_list_data_ptr(), listPtr.ptr), listPtr.dataPtr);
@@ -85,6 +94,11 @@ void LLVMBuildUtils::end()
8594
freeScopeHeap();
8695
}
8796

97+
llvm::LLVMContext &LLVMBuildUtils::llvmCtx()
98+
{
99+
return m_llvmCtx;
100+
}
101+
88102
llvm::IRBuilder<> &LLVMBuildUtils::builder()
89103
{
90104
return m_builder;
@@ -95,6 +109,26 @@ LLVMFunctions &LLVMBuildUtils::functions()
95109
return m_functions;
96110
}
97111

112+
BlockPrototype *LLVMBuildUtils::procedurePrototype() const
113+
{
114+
return m_procedurePrototype;
115+
}
116+
117+
bool LLVMBuildUtils::warp() const
118+
{
119+
return m_warp;
120+
}
121+
122+
llvm::Value *LLVMBuildUtils::executionContextPtr()
123+
{
124+
return m_executionContextPtr;
125+
}
126+
127+
llvm::Value *LLVMBuildUtils::targetPtr()
128+
{
129+
return m_targetPtr;
130+
}
131+
98132
llvm::Value *LLVMBuildUtils::targetVariables()
99133
{
100134
return m_targetVariables;
@@ -105,6 +139,11 @@ llvm::Value *LLVMBuildUtils::targetLists()
105139
return m_targetLists;
106140
}
107141

142+
llvm::Value *LLVMBuildUtils::warpArg()
143+
{
144+
return m_warpArg;
145+
}
146+
108147
void LLVMBuildUtils::createVariablePtr(Variable *variable)
109148
{
110149
if (m_variablePtrs.find(variable) == m_variablePtrs.cend())

src/engine/internal/llvm/llvmbuildutils.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,21 @@ class LLVMBuildUtils
2525

2626
LLVMBuildUtils(LLVMCompilerContext *ctx, llvm::IRBuilder<> &builder);
2727

28-
void init(llvm::Function *function, llvm::Value *targetVariables, llvm::Value *targetLists);
28+
void init(llvm::Function *function, BlockPrototype *procedurePrototype, bool warp);
2929
void end();
3030

31+
llvm::LLVMContext &llvmCtx();
3132
llvm::IRBuilder<> &builder();
3233
LLVMFunctions &functions();
3334

35+
BlockPrototype *procedurePrototype() const;
36+
bool warp() const;
37+
38+
llvm::Value *executionContextPtr();
39+
llvm::Value *targetPtr();
3440
llvm::Value *targetVariables();
3541
llvm::Value *targetLists();
42+
llvm::Value *warpArg();
3643

3744
void createVariablePtr(Variable *variable);
3845
void createListPtr(List *list);
@@ -93,11 +100,18 @@ class LLVMBuildUtils
93100
llvm::StructType *m_valueDataType = nullptr;
94101
llvm::StructType *m_stringPtrType = nullptr;
95102

103+
BlockPrototype *m_procedurePrototype = nullptr;
104+
bool m_warp = false;
105+
106+
llvm::Value *m_executionContextPtr = nullptr;
107+
llvm::Value *m_targetPtr = nullptr;
96108
llvm::Value *m_targetVariables = nullptr;
109+
llvm::Value *m_targetLists = nullptr;
110+
llvm::Value *m_warpArg = nullptr;
111+
97112
std::unordered_map<Variable *, size_t> m_targetVariableMap;
98113
std::unordered_map<Variable *, LLVMVariablePtr> m_variablePtrs;
99114

100-
llvm::Value *m_targetLists = nullptr;
101115
std::unordered_map<List *, size_t> m_targetListMap;
102116
std::unordered_map<List *, LLVMListPtr> m_listPtrs;
103117

src/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
7171
else
7272
m_function = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);
7373

74-
llvm::Value *executionContextPtr = m_function->getArg(0);
75-
llvm::Value *targetPtr = m_function->getArg(1);
76-
llvm::Value *targetVariables = m_function->getArg(2);
77-
llvm::Value *targetLists = m_function->getArg(3);
78-
llvm::Value *warpArg = nullptr;
79-
80-
if (m_procedurePrototype)
81-
warpArg = m_function->getArg(4);
82-
83-
if (m_procedurePrototype && m_warp)
84-
m_function->addFnAttr(llvm::Attribute::InlineHint);
85-
8674
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", m_function);
8775
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", m_function);
8876
m_builder.SetInsertPoint(entry);
@@ -96,7 +84,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
9684
std::vector<LLVMIfStatement> ifStatements;
9785
std::vector<LLVMLoop> loops;
9886

99-
m_utils.init(m_function, targetVariables, targetLists);
87+
m_utils.init(m_function, m_procedurePrototype, m_warp);
10088

10189
// Execute recorded instructions
10290
LLVMInstruction *ins = m_instructions.first();
@@ -108,18 +96,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10896
std::vector<llvm::Value *> args;
10997

11098
// Variables must be synchronized because the function can read them
111-
m_utils.syncVariables(targetVariables);
99+
m_utils.syncVariables(m_utils.targetVariables());
112100

113101
// Add execution context arg
114102
if (ins->functionCtxArg) {
115103
types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0));
116-
args.push_back(executionContextPtr);
104+
args.push_back(m_utils.executionContextPtr());
117105
}
118106

119107
// Add target pointer arg
120108
if (ins->functionTargetArg) {
121109
types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0));
122-
args.push_back(targetPtr);
110+
args.push_back(m_utils.targetPtr());
123111
}
124112

125113
// Args
@@ -200,7 +188,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
200188
if (reg1->type() == Compiler::StaticType::Bool && reg2->type() == Compiler::StaticType::Bool) {
201189
llvm::Value *bool1 = m_utils.castValue(arg1.second, Compiler::StaticType::Bool);
202190
llvm::Value *bool2 = m_utils.castValue(arg2.second, Compiler::StaticType::Bool);
203-
ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_bool(), { executionContextPtr, bool1, bool2 });
191+
ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_bool(), { m_utils.executionContextPtr(), bool1, bool2 });
204192
} else {
205193
llvm::Constant *inf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false);
206194
llvm::Value *num1 = m_utils.removeNaN(m_utils.castValue(arg1.second, Compiler::StaticType::Number));
@@ -212,12 +200,12 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
212200
// NOTE: The random function will be called even in edge cases where it isn't needed, but they're rare, so it shouldn't be an issue
213201
if (reg1->type() == Compiler::StaticType::Number && reg2->type() == Compiler::StaticType::Number)
214202
ins->functionReturnReg->value =
215-
m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(m_utils.functions().resolve_llvm_random_double(), { executionContextPtr, num1, num2 }));
203+
m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(m_utils.functions().resolve_llvm_random_double(), { m_utils.executionContextPtr(), num1, num2 }));
216204
else {
217205
llvm::Value *value1 = m_utils.createValue(reg1);
218206
llvm::Value *value2 = m_utils.createValue(reg2);
219207
ins->functionReturnReg->value =
220-
m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(m_utils.functions().resolve_llvm_random(), { executionContextPtr, value1, value2 }));
208+
m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(m_utils.functions().resolve_llvm_random(), { m_utils.executionContextPtr(), value1, value2 }));
221209
}
222210
}
223211

@@ -231,7 +219,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
231219
const auto &arg2 = ins->args[1];
232220
llvm::Value *from = m_builder.CreateFPToSI(m_utils.castValue(arg1.second, arg1.first), m_builder.getInt64Ty());
233221
llvm::Value *to = m_builder.CreateFPToSI(m_utils.castValue(arg2.second, arg2.first), m_builder.getInt64Ty());
234-
ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_long(), { executionContextPtr, from, to });
222+
ins->functionReturnReg->value = m_builder.CreateCall(m_utils.functions().resolve_llvm_random_long(), { m_utils.executionContextPtr(), from, to });
235223

236224
ins = ins->next;
237225
break;
@@ -1039,7 +1027,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10391027
}
10401028

10411029
case LLVMInstruction::Type::Yield:
1042-
createSuspend(coro.get(), warpArg, targetVariables);
1030+
createSuspend(coro.get(), m_utils.warpArg(), m_utils.targetVariables());
10431031

10441032
ins = ins->next;
10451033
break;
@@ -1282,7 +1270,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
12821270
assert(ins->procedurePrototype);
12831271
assert(ins->args.size() == ins->procedurePrototype->argumentTypes().size());
12841272
m_utils.freeScopeHeap();
1285-
m_utils.syncVariables(targetVariables);
1273+
m_utils.syncVariables(m_utils.targetVariables());
12861274

12871275
std::string name = getMainFunctionName(ins->procedurePrototype);
12881276
llvm::FunctionType *type = getMainFunctionType(ins->procedurePrototype);
@@ -1295,7 +1283,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
12951283
if (m_warp)
12961284
args.push_back(m_builder.getInt1(true));
12971285
else
1298-
args.push_back(m_procedurePrototype ? warpArg : m_builder.getInt1(false));
1286+
args.push_back(m_utils.procedurePrototype() ? m_utils.warpArg() : m_builder.getInt1(false));
12991287

13001288
// Add procedure args
13011289
for (const auto &arg : ins->args) {
@@ -1313,14 +1301,14 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13131301
m_builder.CreateCondBr(m_builder.CreateIsNull(handle), nextBranch, suspendBranch);
13141302

13151303
m_builder.SetInsertPoint(suspendBranch);
1316-
createSuspend(coro.get(), warpArg, targetVariables);
1304+
createSuspend(coro.get(), m_utils.warpArg(), m_utils.targetVariables());
13171305
llvm::Value *done = m_builder.CreateCall(m_ctx->coroutineResumeFunction(), { handle });
13181306
m_builder.CreateCondBr(done, nextBranch, suspendBranch);
13191307

13201308
m_builder.SetInsertPoint(nextBranch);
13211309
}
13221310

1323-
m_utils.reloadVariables(targetVariables);
1311+
m_utils.reloadVariables(m_utils.targetVariables());
13241312

13251313
ins = ins->next;
13261314
break;
@@ -1341,7 +1329,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13411329
m_builder.CreateBr(endBranch);
13421330

13431331
m_builder.SetInsertPoint(endBranch);
1344-
m_utils.syncVariables(targetVariables);
1332+
m_utils.syncVariables(m_utils.targetVariables());
13451333

13461334
// End and verify the function
13471335
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);

0 commit comments

Comments
 (0)