@@ -60,10 +60,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
6060 m_builder.SetInsertPoint (entry);
6161
6262 // Init coroutine
63- LLVMCoroutine coro;
63+ std::unique_ptr< LLVMCoroutine> coro;
6464
6565 if (!m_warp)
66- coro = initCoroutine ( func);
66+ coro = std::make_unique<LLVMCoroutine>(m_module. get (), &m_builder, func);
6767
6868 std::vector<LLVMIfStatement> ifStatements;
6969 std::vector<LLVMLoop> loops;
@@ -439,14 +439,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
439439 if (!m_warp) {
440440 freeHeap ();
441441 syncVariables (targetVariables);
442- m_builder.CreateStore (m_builder.getInt1 (true ), coro.didSuspend );
443- llvm::BasicBlock *resumeBranch = llvm::BasicBlock::Create (m_ctx, " " , func);
444- llvm::Value *noneToken = llvm::ConstantTokenNone::get (m_ctx);
445- llvm::Value *suspendResult = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { noneToken, m_builder.getInt1 (false ) });
446- llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
447- sw->addCase (m_builder.getInt8 (0 ), resumeBranch);
448- sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
449- m_builder.SetInsertPoint (resumeBranch);
442+ coro->createSuspend ();
450443 }
451444
452445 break ;
@@ -651,34 +644,18 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
651644 freeHeap ();
652645 syncVariables (targetVariables);
653646
654- // Add final suspend point
655- if (!m_warp) {
656- llvm::BasicBlock *endBranch = llvm::BasicBlock::Create (m_ctx, " end" , func);
657- llvm::BasicBlock *finalSuspendBranch = llvm::BasicBlock::Create (m_ctx, " finalSuspend" , func);
658- m_builder.CreateCondBr (m_builder.CreateLoad (m_builder.getInt1Ty (), coro.didSuspend ), finalSuspendBranch, endBranch);
659-
660- m_builder.SetInsertPoint (finalSuspendBranch);
661- llvm::Value *suspendResult =
662- m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_suspend), { llvm::ConstantTokenNone::get (m_ctx), m_builder.getInt1 (true ) });
663- llvm::SwitchInst *sw = m_builder.CreateSwitch (suspendResult, coro.suspend , 2 );
664- sw->addCase (m_builder.getInt8 (0 ), endBranch); // unreachable
665- sw->addCase (m_builder.getInt8 (1 ), coro.cleanup );
666-
667- m_builder.SetInsertPoint (endBranch);
668- }
669-
670647 // End and verify the function
648+ if (m_warp)
649+ m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
650+ else
651+ coro->end ();
652+
671653 if (!m_tmpRegs.empty ()) {
672654 std::cout
673655 << " warning: " << m_tmpRegs.size () << " registers were leaked by script '" << m_module->getName ().str () << " ', function '" << func->getName ().str ()
674656 << " ' (if you see this as a regular user, this is a bug and should be reported)" << std::endl;
675657 }
676658
677- if (m_warp)
678- m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
679- else
680- m_builder.CreateBr (coro.freeMemRet );
681-
682659 verifyFunction (func);
683660
684661 // Create resume function
@@ -691,12 +668,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
691668
692669 if (m_warp)
693670 m_builder.CreateRet (m_builder.getInt1 (true ));
694- else {
695- llvm::Value *coroHandle = func->getArg (0 );
696- m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_resume), { coroHandle });
697- llvm::Value *done = m_builder.CreateCall (llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_done), { coroHandle });
698- m_builder.CreateRet (done);
699- }
671+ else
672+ m_builder.CreateRet (coro->createResume (func->getArg (0 )));
700673
701674 verifyFunction (func);
702675
@@ -1010,63 +983,6 @@ void LLVMCodeBuilder::createVariableMap()
1010983 }
1011984}
1012985
1013- LLVMCoroutine LLVMCodeBuilder::initCoroutine (llvm::Function *func)
1014- {
1015- // Set presplitcoroutine attribute
1016- func->setPresplitCoroutine ();
1017-
1018- // Coroutine intrinsics
1019- llvm::Function *coroId = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_id);
1020- llvm::Function *coroSize = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_size, m_builder.getInt64Ty ());
1021- llvm::Function *coroBegin = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_begin);
1022- llvm::Function *coroEnd = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_end);
1023- llvm::Function *coroFree = llvm::Intrinsic::getDeclaration (m_module.get (), llvm::Intrinsic::coro_free);
1024-
1025- // Init coroutine
1026- LLVMCoroutine coro;
1027- llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_ctx), 0 );
1028- llvm::Constant *nullPointer = llvm::ConstantPointerNull::get (pointerType);
1029- llvm::Value *coroIdRet = m_builder.CreateCall (coroId, { m_builder.getInt32 (8 ), nullPointer, nullPointer, nullPointer });
1030-
1031- // Allocate memory
1032- llvm::Value *coroSizeRet = m_builder.CreateCall (coroSize, std::nullopt , " size" );
1033- llvm::Function *mallocFunc = llvm::Function::Create (llvm::FunctionType::get (pointerType, { m_builder.getInt64Ty () }, false ), llvm::Function::ExternalLinkage, " malloc" , m_module.get ());
1034- llvm::Value *alloc = m_builder.CreateCall (mallocFunc, coroSizeRet, " mem" );
1035-
1036- // Begin
1037- coro.handle = m_builder.CreateCall (coroBegin, { coroIdRet, alloc });
1038- coro.didSuspend = m_builder.CreateAlloca (m_builder.getInt1Ty (), nullptr , " didSuspend" );
1039- m_builder.CreateStore (m_builder.getInt1 (false ), coro.didSuspend );
1040- llvm::BasicBlock *entry = m_builder.GetInsertBlock ();
1041-
1042- // Create suspend branch
1043- coro.suspend = llvm::BasicBlock::Create (m_ctx, " suspend" , func);
1044- m_builder.SetInsertPoint (coro.suspend );
1045- m_builder.CreateCall (coroEnd, { coro.handle , m_builder.getInt1 (false ), llvm::ConstantTokenNone::get (m_ctx) });
1046- m_builder.CreateRet (coro.handle );
1047-
1048- // Create free branches
1049- coro.freeMemRet = llvm::BasicBlock::Create (m_ctx, " freeMemRet" , func);
1050- m_builder.SetInsertPoint (coro.freeMemRet );
1051- m_builder.CreateFree (alloc);
1052- m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
1053-
1054- llvm::BasicBlock *freeBranch = llvm::BasicBlock::Create (m_ctx, " free" , func);
1055- m_builder.SetInsertPoint (freeBranch);
1056- m_builder.CreateFree (alloc);
1057- m_builder.CreateBr (coro.suspend );
1058-
1059- // Create cleanup branch
1060- coro.cleanup = llvm::BasicBlock::Create (m_ctx, " cleanup" , func);
1061- m_builder.SetInsertPoint (coro.cleanup );
1062- llvm::Value *mem = m_builder.CreateCall (coroFree, { coroIdRet, coro.handle });
1063- llvm::Value *needFree = m_builder.CreateIsNotNull (mem);
1064- m_builder.CreateCondBr (needFree, freeBranch, coro.suspend );
1065-
1066- m_builder.SetInsertPoint (entry);
1067- return coro;
1068- }
1069-
1070986void LLVMCodeBuilder::verifyFunction (llvm::Function *func)
1071987{
1072988 if (llvm::verifyFunction (*func, &llvm::errs ())) {
0 commit comments