Skip to content

Commit a3aa0f8

Browse files
committed
Implement LLVM procedures
1 parent 9567a04 commit a3aa0f8

File tree

12 files changed

+277
-33
lines changed

12 files changed

+277
-33
lines changed

src/dev/engine/compiler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ std::shared_ptr<libscratchcpp::Block> Compiler::block() const
4646
/*! Compiles the script starting with the given block. */
4747
std::shared_ptr<ExecutableCode> Compiler::compile(std::shared_ptr<Block> startBlock)
4848
{
49-
impl->builder = impl->builderFactory->create(impl->ctx, false);
49+
impl->builder = impl->builderFactory->create(impl->ctx, nullptr);
5050
impl->substackTree.clear();
5151
impl->substackHit = false;
5252
impl->emptySubstack = false;

src/dev/engine/internal/codebuilderfactory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ std::shared_ptr<CodeBuilderFactory> CodeBuilderFactory::instance()
1313
return m_instance;
1414
}
1515

16-
std::shared_ptr<ICodeBuilder> CodeBuilderFactory::create(CompilerContext *ctx, bool warp) const
16+
std::shared_ptr<ICodeBuilder> CodeBuilderFactory::create(CompilerContext *ctx, BlockPrototype *procedurePrototype) const
1717
{
1818
assert(dynamic_cast<LLVMCompilerContext *>(ctx));
19-
return std::make_shared<LLVMCodeBuilder>(static_cast<LLVMCompilerContext *>(ctx), warp);
19+
return std::make_shared<LLVMCodeBuilder>(static_cast<LLVMCompilerContext *>(ctx), procedurePrototype);
2020
}
2121

2222
std::shared_ptr<CompilerContext> CodeBuilderFactory::createCtx(IEngine *engine, Target *target) const

src/dev/engine/internal/codebuilderfactory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CodeBuilderFactory : public ICodeBuilderFactory
1111
{
1212
public:
1313
static std::shared_ptr<CodeBuilderFactory> instance();
14-
std::shared_ptr<ICodeBuilder> create(CompilerContext *ctx, bool warp) const override;
14+
std::shared_ptr<ICodeBuilder> create(CompilerContext *ctx, BlockPrototype *procedurePrototype) const override;
1515
std::shared_ptr<CompilerContext> createCtx(IEngine *engine, Target *target) const override;
1616

1717
private:

src/dev/engine/internal/icodebuilder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class Value;
1111
class Variable;
1212
class List;
1313
class ExecutableCode;
14+
class BlockPrototype;
1415

1516
class ICodeBuilder
1617
{
@@ -91,6 +92,8 @@ class ICodeBuilder
9192
virtual void yield() = 0;
9293

9394
virtual void createStop() = 0;
95+
96+
virtual void createProcedureCall(BlockPrototype *prototype) = 0;
9497
};
9598

9699
} // namespace libscratchcpp

src/dev/engine/internal/icodebuilderfactory.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace libscratchcpp
99

1010
class ICodeBuilder;
1111
class CompilerContext;
12+
class BlockPrototype;
1213
class Target;
1314
class IEngine;
1415

@@ -17,7 +18,7 @@ class ICodeBuilderFactory
1718
public:
1819
virtual ~ICodeBuilderFactory() { }
1920

20-
virtual std::shared_ptr<ICodeBuilder> create(CompilerContext *ctx, bool warp) const = 0;
21+
virtual std::shared_ptr<ICodeBuilder> create(CompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr) const = 0;
2122
virtual std::shared_ptr<CompilerContext> createCtx(IEngine *engine, Target *target) const = 0;
2223
};
2324

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 146 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <scratchcpp/iengine.h>
99
#include <scratchcpp/variable.h>
1010
#include <scratchcpp/list.h>
11+
#include <scratchcpp/blockprototype.h>
1112
#include <scratchcpp/dev/compilerconstant.h>
1213
#include <scratchcpp/dev/compilerlocalvariable.h>
1314

@@ -24,14 +25,15 @@ using namespace libscratchcpp;
2425
static std::unordered_map<ValueType, Compiler::StaticType>
2526
TYPE_MAP = { { ValueType::Number, Compiler::StaticType::Number }, { ValueType::Bool, Compiler::StaticType::Bool }, { ValueType::String, Compiler::StaticType::String } };
2627

27-
LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, bool warp) :
28+
LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) :
2829
m_ctx(ctx),
2930
m_target(ctx->target()),
3031
m_llvmCtx(*ctx->llvmCtx()),
3132
m_module(ctx->module()),
3233
m_builder(m_llvmCtx),
33-
m_defaultWarp(warp),
34-
m_warp(warp)
34+
m_procedurePrototype(procedurePrototype),
35+
m_defaultWarp(procedurePrototype ? procedurePrototype->warp() : false),
36+
m_warp(m_defaultWarp)
3537
{
3638
initTypes();
3739
createVariableMap();
@@ -40,9 +42,11 @@ LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, bool warp) :
4042

4143
std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
4244
{
43-
// Do not create coroutine if there are no yield instructions
4445
if (!m_warp) {
45-
auto it = std::find_if(m_instructions.begin(), m_instructions.end(), [](const LLVMInstruction &step) { return step.type == LLVMInstruction::Type::Yield; });
46+
// Do not create coroutine if there are no yield instructions nor non-warp procedure calls
47+
auto it = std::find_if(m_instructions.begin(), m_instructions.end(), [](const LLVMInstruction &step) {
48+
return step.type == LLVMInstruction::Type::Yield || (step.type == LLVMInstruction::Type::CallProcedure && step.procedurePrototype && !step.procedurePrototype->warp());
49+
});
4650

4751
if (it == m_instructions.end())
4852
m_warp = true;
@@ -57,14 +61,25 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
5761
m_builder.setFastMathFlags(fmf);
5862

5963
// Create function
60-
// void *f(ExecutionContext *, Target *, ValueData **, List **)
61-
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
62-
llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType, pointerType }, false);
63-
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module);
64+
std::string funcName = getMainFunctionName(m_procedurePrototype);
65+
llvm::FunctionType *funcType = getMainFunctionType(m_procedurePrototype);
66+
llvm::Function *func;
67+
68+
if (m_procedurePrototype)
69+
func = getOrCreateFunction(funcName, funcType);
70+
else
71+
func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, funcName, m_module);
72+
6473
llvm::Value *executionContextPtr = func->getArg(0);
6574
llvm::Value *targetPtr = func->getArg(1);
6675
llvm::Value *targetVariables = func->getArg(2);
6776
llvm::Value *targetLists = func->getArg(3);
77+
llvm::Value *warpArg = nullptr;
78+
79+
if (m_procedurePrototype) {
80+
func->addFnAttr(llvm::Attribute::AlwaysInline);
81+
warpArg = func->getArg(4);
82+
}
6883

6984
llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", func);
7085
llvm::BasicBlock *endBranch = llvm::BasicBlock::Create(m_llvmCtx, "end", func);
@@ -787,15 +802,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
787802
}
788803

789804
case LLVMInstruction::Type::Yield:
790-
if (!m_warp) {
791-
// TODO: Do not allow use after suspend (use after free)
792-
freeScopeHeap();
793-
syncVariables(targetVariables);
794-
coro->createSuspend();
795-
reloadVariables(targetVariables);
796-
reloadLists();
797-
}
798-
805+
// TODO: Do not allow use after suspend (use after free)
806+
createSuspend(coro.get(), func, warpArg, targetVariables);
799807
break;
800808

801809
case LLVMInstruction::Type::BeginIf: {
@@ -1019,6 +1027,47 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10191027
m_builder.SetInsertPoint(nextBranch);
10201028
break;
10211029
}
1030+
1031+
case LLVMInstruction::Type::CallProcedure: {
1032+
assert(step.procedurePrototype);
1033+
freeScopeHeap();
1034+
syncVariables(targetVariables);
1035+
1036+
std::string name = getMainFunctionName(step.procedurePrototype);
1037+
llvm::FunctionType *type = getMainFunctionType(step.procedurePrototype);
1038+
std::vector<llvm::Value *> args;
1039+
const size_t argCount = type->getNumParams() - 1 - step.procedurePrototype->argumentTypes().size(); // omit warp arg and procedure args
1040+
1041+
for (size_t i = 0; i < argCount; i++)
1042+
args.push_back(func->getArg(i));
1043+
1044+
// Add warp arg
1045+
if (m_warp)
1046+
args.push_back(m_builder.getInt1(true));
1047+
else
1048+
args.push_back(m_procedurePrototype ? warpArg : m_builder.getInt1(false));
1049+
1050+
// TODO: Add procedure args
1051+
llvm::Value *handle = m_builder.CreateCall(resolveFunction(name, type), args);
1052+
1053+
if (!m_warp && !step.procedurePrototype->warp()) {
1054+
llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1055+
llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
1056+
m_builder.CreateCondBr(m_builder.CreateIsNull(handle), nextBranch, suspendBranch);
1057+
1058+
m_builder.SetInsertPoint(suspendBranch);
1059+
createSuspend(coro.get(), func, warpArg, targetVariables);
1060+
name = getResumeFunctionName(step.procedurePrototype);
1061+
llvm::Value *done = m_builder.CreateCall(resolveFunction(name, m_resumeFuncType), { handle });
1062+
m_builder.CreateCondBr(done, nextBranch, suspendBranch);
1063+
1064+
m_builder.SetInsertPoint(nextBranch);
1065+
}
1066+
1067+
reloadVariables(targetVariables);
1068+
reloadLists();
1069+
break;
1070+
}
10221071
}
10231072
}
10241073

@@ -1030,6 +1079,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10301079
syncVariables(targetVariables);
10311080

10321081
// End and verify the function
1082+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
1083+
10331084
if (m_warp)
10341085
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
10351086
else
@@ -1039,8 +1090,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10391090

10401091
// Create resume function
10411092
// bool resume(void *)
1042-
funcType = llvm::FunctionType::get(m_builder.getInt1Ty(), pointerType, false);
1043-
llvm::Function *resumeFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "resume", m_module);
1093+
funcName = getResumeFunctionName(m_procedurePrototype);
1094+
llvm::Function *resumeFunc = getOrCreateFunction(funcName, m_resumeFuncType);
10441095

10451096
entry = llvm::BasicBlock::Create(m_llvmCtx, "entry", resumeFunc);
10461097
m_builder.SetInsertPoint(entry);
@@ -1451,9 +1502,18 @@ void LLVMCodeBuilder::createStop()
14511502
m_instructions.push_back({ LLVMInstruction::Type::Stop });
14521503
}
14531504

1505+
void LLVMCodeBuilder::createProcedureCall(BlockPrototype *prototype)
1506+
{
1507+
LLVMInstruction ins(LLVMInstruction::Type::CallProcedure);
1508+
ins.procedurePrototype = prototype;
1509+
m_instructions.push_back(ins);
1510+
}
1511+
14541512
void LLVMCodeBuilder::initTypes()
14551513
{
1514+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
14561515
m_valueDataType = LLVMTypes::createValueDataType(&m_builder);
1516+
m_resumeFuncType = llvm::FunctionType::get(m_builder.getInt1Ty(), pointerType, false);
14571517
}
14581518

14591519
void LLVMCodeBuilder::createVariableMap()
@@ -1557,6 +1617,47 @@ void LLVMCodeBuilder::popScopeLevel()
15571617
m_heap.pop_back();
15581618
}
15591619

1620+
std::string LLVMCodeBuilder::getMainFunctionName(BlockPrototype *procedurePrototype)
1621+
{
1622+
return procedurePrototype ? "f." + procedurePrototype->procCode() : "f";
1623+
}
1624+
1625+
std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedurePrototype)
1626+
{
1627+
return procedurePrototype ? "resume." + procedurePrototype->procCode() : "resume";
1628+
}
1629+
1630+
llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedurePrototype)
1631+
{
1632+
// void *f(ExecutionContext *, Target *, ValueData **, List **, (warp arg), (procedure args...))
1633+
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
1634+
std::vector<llvm::Type *> argTypes = { pointerType, pointerType, pointerType, pointerType };
1635+
1636+
if (procedurePrototype) {
1637+
argTypes.push_back(m_builder.getInt1Ty()); // warp arg (only in procedures)
1638+
const auto &types = procedurePrototype->argumentTypes();
1639+
1640+
for (BlockPrototype::ArgType type : types) {
1641+
if (type == BlockPrototype::ArgType::Bool)
1642+
argTypes.push_back(m_builder.getInt1Ty());
1643+
else
1644+
argTypes.push_back(m_valueDataType->getPointerTo());
1645+
}
1646+
}
1647+
1648+
return llvm::FunctionType::get(pointerType, argTypes, false);
1649+
}
1650+
1651+
llvm::Function *LLVMCodeBuilder::getOrCreateFunction(const std::string &name, llvm::FunctionType *type)
1652+
{
1653+
llvm::Function *func = m_module->getFunction(name);
1654+
1655+
if (func)
1656+
return func;
1657+
else
1658+
return llvm::Function::Create(type, llvm::Function::ExternalLinkage, name, m_module);
1659+
}
1660+
15601661
void LLVMCodeBuilder::verifyFunction(llvm::Function *func)
15611662
{
15621663
if (llvm::verifyFunction(*func, &llvm::errs())) {
@@ -2346,6 +2447,31 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
23462447
}
23472448
}
23482449

2450+
void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Function *func, llvm::Value *warpArg, llvm::Value *targetVariables)
2451+
{
2452+
if (!m_warp) {
2453+
llvm::BasicBlock *suspendBranch, *nextBranch;
2454+
2455+
if (warpArg) {
2456+
suspendBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2457+
nextBranch = llvm::BasicBlock::Create(m_llvmCtx, "", func);
2458+
m_builder.CreateCondBr(warpArg, nextBranch, suspendBranch);
2459+
m_builder.SetInsertPoint(suspendBranch);
2460+
}
2461+
2462+
freeScopeHeap();
2463+
syncVariables(targetVariables);
2464+
coro->createSuspend();
2465+
reloadVariables(targetVariables);
2466+
reloadLists();
2467+
2468+
if (warpArg) {
2469+
m_builder.CreateBr(nextBranch);
2470+
m_builder.SetInsertPoint(nextBranch);
2471+
}
2472+
}
2473+
}
2474+
23492475
llvm::FunctionCallee LLVMCodeBuilder::resolveFunction(const std::string name, llvm::FunctionType *type)
23502476
{
23512477
return m_module->getOrInsertFunction(name, type);

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class LLVMConstantRegister;
2222
class LLVMCodeBuilder : public ICodeBuilder
2323
{
2424
public:
25-
LLVMCodeBuilder(LLVMCompilerContext *ctx, bool warp);
25+
LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr);
2626

2727
std::shared_ptr<ExecutableCode> finalize() override;
2828

@@ -99,6 +99,8 @@ class LLVMCodeBuilder : public ICodeBuilder
9999

100100
void createStop() override;
101101

102+
void createProcedureCall(BlockPrototype *prototype) override;
103+
102104
private:
103105
enum class Comparison
104106
{
@@ -113,6 +115,10 @@ class LLVMCodeBuilder : public ICodeBuilder
113115
void pushScopeLevel();
114116
void popScopeLevel();
115117

118+
std::string getMainFunctionName(BlockPrototype *procedurePrototype);
119+
std::string getResumeFunctionName(BlockPrototype *procedurePrototype);
120+
llvm::FunctionType *getMainFunctionType(BlockPrototype *procedurePrototype);
121+
llvm::Function *getOrCreateFunction(const std::string &name, llvm::FunctionType *type);
116122
void verifyFunction(llvm::Function *func);
117123
void optimize();
118124

@@ -147,6 +153,8 @@ class LLVMCodeBuilder : public ICodeBuilder
147153
llvm::Value *createValue(LLVMRegister *reg);
148154
llvm::Value *createComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type);
149155

156+
void createSuspend(LLVMCoroutine *coro, llvm::Function *func, llvm::Value *warpArg, llvm::Value *targetVariables);
157+
150158
llvm::FunctionCallee resolveFunction(const std::string name, llvm::FunctionType *type);
151159
llvm::FunctionCallee resolve_value_init();
152160
llvm::FunctionCallee resolve_value_free();
@@ -196,10 +204,12 @@ class LLVMCodeBuilder : public ICodeBuilder
196204
llvm::IRBuilder<> m_builder;
197205

198206
llvm::StructType *m_valueDataType = nullptr;
207+
llvm::FunctionType *m_resumeFuncType = nullptr;
199208

200209
std::vector<LLVMInstruction> m_instructions;
201210
std::vector<std::shared_ptr<LLVMRegister>> m_regs;
202211
std::vector<std::shared_ptr<CompilerLocalVariable>> m_localVars;
212+
BlockPrototype *m_procedurePrototype = nullptr;
203213
bool m_defaultWarp = false;
204214
bool m_warp = false;
205215

src/dev/engine/internal/llvm/llvminstruction.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
namespace libscratchcpp
1010
{
1111

12+
class BlockPrototype;
13+
1214
struct LLVMInstruction
1315
{
1416
enum class Type
@@ -68,7 +70,8 @@ struct LLVMInstruction
6870
BeginRepeatUntilLoop,
6971
BeginLoopCondition,
7072
EndLoop,
71-
Stop
73+
Stop,
74+
CallProcedure
7275
};
7376

7477
LLVMInstruction(Type type) :
@@ -84,6 +87,7 @@ struct LLVMInstruction
8487
bool functionCtxArg = false; // whether to add execution context ptr to function parameters
8588
Variable *workVariable = nullptr; // for variables
8689
List *workList = nullptr; // for lists
90+
BlockPrototype *procedurePrototype = nullptr;
8791
};
8892

8993
} // namespace libscratchcpp

0 commit comments

Comments
 (0)