88#include < scratchcpp/iengine.h>
99#include < scratchcpp/variable.h>
1010#include < scratchcpp/list.h>
11+ #include < scratchcpp/blockprototype.h>
1112#include < scratchcpp/dev/compilerconstant.h>
1213#include < scratchcpp/dev/compilerlocalvariable.h>
1314
@@ -24,14 +25,15 @@ using namespace libscratchcpp;
2425static std::unordered_map<ValueType, Compiler::StaticType>
2526 TYPE_MAP = { { ValueType::Number, Compiler::StaticType::Number }, { ValueType::Bool, Compiler::StaticType::Bool }, { ValueType::String, Compiler::StaticType::String } };
2627
27- LLVMCodeBuilder::LLVMCodeBuilder (LLVMCompilerContext *ctx, bool warp ) :
28+ LLVMCodeBuilder::LLVMCodeBuilder (LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype ) :
2829 m_ctx(ctx),
2930 m_target(ctx->target ()),
3031 m_llvmCtx(*ctx->llvmCtx ()),
3132 m_module(ctx->module ()),
3233 m_builder(m_llvmCtx),
33- m_defaultWarp(warp),
34- m_warp(warp)
34+ m_procedurePrototype(procedurePrototype),
35+ m_defaultWarp(procedurePrototype ? procedurePrototype->warp () : false),
36+ m_warp(m_defaultWarp)
3537{
3638 initTypes ();
3739 createVariableMap ();
@@ -40,9 +42,11 @@ LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, bool warp) :
4042
4143std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize ()
4244{
43- // Do not create coroutine if there are no yield instructions
4445 if (!m_warp) {
45- auto it = std::find_if (m_instructions.begin (), m_instructions.end (), [](const LLVMInstruction &step) { return step.type == LLVMInstruction::Type::Yield; });
46+ // Do not create coroutine if there are no yield instructions nor non-warp procedure calls
47+ auto it = std::find_if (m_instructions.begin (), m_instructions.end (), [](const LLVMInstruction &step) {
48+ return step.type == LLVMInstruction::Type::Yield || (step.type == LLVMInstruction::Type::CallProcedure && step.procedurePrototype && !step.procedurePrototype ->warp ());
49+ });
4650
4751 if (it == m_instructions.end ())
4852 m_warp = true ;
@@ -57,14 +61,25 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
5761 m_builder.setFastMathFlags (fmf);
5862
5963 // Create function
60- // void *f(ExecutionContext *, Target *, ValueData **, List **)
61- llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
62- llvm::FunctionType *funcType = llvm::FunctionType::get (pointerType, { pointerType, pointerType, pointerType, pointerType }, false );
63- llvm::Function *func = llvm::Function::Create (funcType, llvm::Function::ExternalLinkage, " f" , m_module);
64+ std::string funcName = getMainFunctionName (m_procedurePrototype);
65+ llvm::FunctionType *funcType = getMainFunctionType (m_procedurePrototype);
66+ llvm::Function *func;
67+
68+ if (m_procedurePrototype)
69+ func = getOrCreateFunction (funcName, funcType);
70+ else
71+ func = llvm::Function::Create (funcType, llvm::Function::ExternalLinkage, funcName, m_module);
72+
6473 llvm::Value *executionContextPtr = func->getArg (0 );
6574 llvm::Value *targetPtr = func->getArg (1 );
6675 llvm::Value *targetVariables = func->getArg (2 );
6776 llvm::Value *targetLists = func->getArg (3 );
77+ llvm::Value *warpArg = nullptr ;
78+
79+ if (m_procedurePrototype) {
80+ func->addFnAttr (llvm::Attribute::AlwaysInline);
81+ warpArg = func->getArg (4 );
82+ }
6883
6984 llvm::BasicBlock *entry = llvm::BasicBlock::Create (m_llvmCtx, " entry" , func);
7085 llvm::BasicBlock *endBranch = llvm::BasicBlock::Create (m_llvmCtx, " end" , func);
@@ -787,15 +802,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
787802 }
788803
789804 case LLVMInstruction::Type::Yield:
790- if (!m_warp) {
791- // TODO: Do not allow use after suspend (use after free)
792- freeScopeHeap ();
793- syncVariables (targetVariables);
794- coro->createSuspend ();
795- reloadVariables (targetVariables);
796- reloadLists ();
797- }
798-
805+ // TODO: Do not allow use after suspend (use after free)
806+ createSuspend (coro.get (), func, warpArg, targetVariables);
799807 break ;
800808
801809 case LLVMInstruction::Type::BeginIf: {
@@ -1019,6 +1027,47 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10191027 m_builder.SetInsertPoint (nextBranch);
10201028 break ;
10211029 }
1030+
1031+ case LLVMInstruction::Type::CallProcedure: {
1032+ assert (step.procedurePrototype );
1033+ freeScopeHeap ();
1034+ syncVariables (targetVariables);
1035+
1036+ std::string name = getMainFunctionName (step.procedurePrototype );
1037+ llvm::FunctionType *type = getMainFunctionType (step.procedurePrototype );
1038+ std::vector<llvm::Value *> args;
1039+ const size_t argCount = type->getNumParams () - 1 - step.procedurePrototype ->argumentTypes ().size (); // omit warp arg and procedure args
1040+
1041+ for (size_t i = 0 ; i < argCount; i++)
1042+ args.push_back (func->getArg (i));
1043+
1044+ // Add warp arg
1045+ if (m_warp)
1046+ args.push_back (m_builder.getInt1 (true ));
1047+ else
1048+ args.push_back (m_procedurePrototype ? warpArg : m_builder.getInt1 (false ));
1049+
1050+ // TODO: Add procedure args
1051+ llvm::Value *handle = m_builder.CreateCall (resolveFunction (name, type), args);
1052+
1053+ if (!m_warp && !step.procedurePrototype ->warp ()) {
1054+ llvm::BasicBlock *suspendBranch = llvm::BasicBlock::Create (m_llvmCtx, " " , func);
1055+ llvm::BasicBlock *nextBranch = llvm::BasicBlock::Create (m_llvmCtx, " " , func);
1056+ m_builder.CreateCondBr (m_builder.CreateIsNull (handle), nextBranch, suspendBranch);
1057+
1058+ m_builder.SetInsertPoint (suspendBranch);
1059+ createSuspend (coro.get (), func, warpArg, targetVariables);
1060+ name = getResumeFunctionName (step.procedurePrototype );
1061+ llvm::Value *done = m_builder.CreateCall (resolveFunction (name, m_resumeFuncType), { handle });
1062+ m_builder.CreateCondBr (done, nextBranch, suspendBranch);
1063+
1064+ m_builder.SetInsertPoint (nextBranch);
1065+ }
1066+
1067+ reloadVariables (targetVariables);
1068+ reloadLists ();
1069+ break ;
1070+ }
10221071 }
10231072 }
10241073
@@ -1030,6 +1079,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10301079 syncVariables (targetVariables);
10311080
10321081 // End and verify the function
1082+ llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
1083+
10331084 if (m_warp)
10341085 m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
10351086 else
@@ -1039,8 +1090,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10391090
10401091 // Create resume function
10411092 // bool resume(void *)
1042- funcType = llvm::FunctionType::get (m_builder. getInt1Ty (), pointerType, false );
1043- llvm::Function *resumeFunc = llvm::Function::Create (funcType, llvm::Function::ExternalLinkage, " resume " , m_module );
1093+ funcName = getResumeFunctionName (m_procedurePrototype );
1094+ llvm::Function *resumeFunc = getOrCreateFunction (funcName, m_resumeFuncType );
10441095
10451096 entry = llvm::BasicBlock::Create (m_llvmCtx, " entry" , resumeFunc);
10461097 m_builder.SetInsertPoint (entry);
@@ -1451,9 +1502,18 @@ void LLVMCodeBuilder::createStop()
14511502 m_instructions.push_back ({ LLVMInstruction::Type::Stop });
14521503}
14531504
1505+ void LLVMCodeBuilder::createProcedureCall (BlockPrototype *prototype)
1506+ {
1507+ LLVMInstruction ins (LLVMInstruction::Type::CallProcedure);
1508+ ins.procedurePrototype = prototype;
1509+ m_instructions.push_back (ins);
1510+ }
1511+
14541512void LLVMCodeBuilder::initTypes ()
14551513{
1514+ llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
14561515 m_valueDataType = LLVMTypes::createValueDataType (&m_builder);
1516+ m_resumeFuncType = llvm::FunctionType::get (m_builder.getInt1Ty (), pointerType, false );
14571517}
14581518
14591519void LLVMCodeBuilder::createVariableMap ()
@@ -1557,6 +1617,47 @@ void LLVMCodeBuilder::popScopeLevel()
15571617 m_heap.pop_back ();
15581618}
15591619
1620+ std::string LLVMCodeBuilder::getMainFunctionName (BlockPrototype *procedurePrototype)
1621+ {
1622+ return procedurePrototype ? " f." + procedurePrototype->procCode () : " f" ;
1623+ }
1624+
1625+ std::string LLVMCodeBuilder::getResumeFunctionName (BlockPrototype *procedurePrototype)
1626+ {
1627+ return procedurePrototype ? " resume." + procedurePrototype->procCode () : " resume" ;
1628+ }
1629+
1630+ llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType (BlockPrototype *procedurePrototype)
1631+ {
1632+ // void *f(ExecutionContext *, Target *, ValueData **, List **, (warp arg), (procedure args...))
1633+ llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
1634+ std::vector<llvm::Type *> argTypes = { pointerType, pointerType, pointerType, pointerType };
1635+
1636+ if (procedurePrototype) {
1637+ argTypes.push_back (m_builder.getInt1Ty ()); // warp arg (only in procedures)
1638+ const auto &types = procedurePrototype->argumentTypes ();
1639+
1640+ for (BlockPrototype::ArgType type : types) {
1641+ if (type == BlockPrototype::ArgType::Bool)
1642+ argTypes.push_back (m_builder.getInt1Ty ());
1643+ else
1644+ argTypes.push_back (m_valueDataType->getPointerTo ());
1645+ }
1646+ }
1647+
1648+ return llvm::FunctionType::get (pointerType, argTypes, false );
1649+ }
1650+
1651+ llvm::Function *LLVMCodeBuilder::getOrCreateFunction (const std::string &name, llvm::FunctionType *type)
1652+ {
1653+ llvm::Function *func = m_module->getFunction (name);
1654+
1655+ if (func)
1656+ return func;
1657+ else
1658+ return llvm::Function::Create (type, llvm::Function::ExternalLinkage, name, m_module);
1659+ }
1660+
15601661void LLVMCodeBuilder::verifyFunction (llvm::Function *func)
15611662{
15621663 if (llvm::verifyFunction (*func, &llvm::errs ())) {
@@ -2346,6 +2447,31 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
23462447 }
23472448}
23482449
2450+ void LLVMCodeBuilder::createSuspend (LLVMCoroutine *coro, llvm::Function *func, llvm::Value *warpArg, llvm::Value *targetVariables)
2451+ {
2452+ if (!m_warp) {
2453+ llvm::BasicBlock *suspendBranch, *nextBranch;
2454+
2455+ if (warpArg) {
2456+ suspendBranch = llvm::BasicBlock::Create (m_llvmCtx, " " , func);
2457+ nextBranch = llvm::BasicBlock::Create (m_llvmCtx, " " , func);
2458+ m_builder.CreateCondBr (warpArg, nextBranch, suspendBranch);
2459+ m_builder.SetInsertPoint (suspendBranch);
2460+ }
2461+
2462+ freeScopeHeap ();
2463+ syncVariables (targetVariables);
2464+ coro->createSuspend ();
2465+ reloadVariables (targetVariables);
2466+ reloadLists ();
2467+
2468+ if (warpArg) {
2469+ m_builder.CreateBr (nextBranch);
2470+ m_builder.SetInsertPoint (nextBranch);
2471+ }
2472+ }
2473+ }
2474+
23492475llvm::FunctionCallee LLVMCodeBuilder::resolveFunction (const std::string name, llvm::FunctionType *type)
23502476{
23512477 return m_module->getOrInsertFunction (name, type);
0 commit comments