@@ -26,15 +26,16 @@ static std::unordered_map<ValueType, Compiler::StaticType>
2626static const std::unordered_set<LLVMInstruction::Type>
2727 VAR_LIST_READ_INSTRUCTIONS = { LLVMInstruction::Type::ReadVariable, LLVMInstruction::Type::GetListItem, LLVMInstruction::Type::GetListItemIndex, LLVMInstruction::Type::ListContainsItem };
2828
29- LLVMCodeBuilder::LLVMCodeBuilder (LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) :
29+ LLVMCodeBuilder::LLVMCodeBuilder (LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype, bool isPredicate ) :
3030 m_ctx(ctx),
3131 m_target(ctx->target ()),
3232 m_llvmCtx(*ctx->llvmCtx ()),
3333 m_module(ctx->module ()),
3434 m_builder(m_llvmCtx),
3535 m_procedurePrototype(procedurePrototype),
3636 m_defaultWarp(procedurePrototype ? procedurePrototype->warp () : false),
37- m_warp(m_defaultWarp)
37+ m_warp(m_defaultWarp),
38+ m_isPredicate(isPredicate)
3839{
3940 initTypes ();
4041 createVariableMap ();
@@ -54,6 +55,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
5455
5556 if (it == m_instructions.end ())
5657 m_warp = true ;
58+
59+ // Do not create coroutine in hat predicates
60+ if (m_isPredicate)
61+ m_warp = true ;
5762 }
5863
5964 // Set fast math flags
@@ -1314,10 +1319,16 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13141319 // End and verify the function
13151320 llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
13161321
1317- if (m_warp)
1318- m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
1319- else
1320- coro->end ();
1322+ if (m_isPredicate) {
1323+ // Use last instruction return value
1324+ assert (!m_instructions.empty ());
1325+ m_builder.CreateRet (m_instructions.back ()->functionReturnReg ->value );
1326+ } else {
1327+ if (m_warp)
1328+ m_builder.CreateRet (llvm::ConstantPointerNull::get (pointerType));
1329+ else
1330+ coro->end ();
1331+ }
13211332
13221333 verifyFunction (m_function);
13231334
@@ -1338,7 +1349,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13381349
13391350 verifyFunction (resumeFunc);
13401351
1341- return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName ().str (), resumeFunc->getName ().str ());
1352+ return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName ().str (), resumeFunc->getName ().str (), m_isPredicate );
13421353}
13431354
13441355CompilerValue *LLVMCodeBuilder::addFunctionCall (const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args)
@@ -2008,7 +2019,7 @@ void LLVMCodeBuilder::popLoopScope()
20082019
20092020std::string LLVMCodeBuilder::getMainFunctionName (BlockPrototype *procedurePrototype)
20102021{
2011- return procedurePrototype ? " proc." + procedurePrototype->procCode () : " script" ;
2022+ return procedurePrototype ? " proc." + procedurePrototype->procCode () : (m_isPredicate ? " predicate " : " script" ) ;
20122023}
20132024
20142025std::string LLVMCodeBuilder::getResumeFunctionName (BlockPrototype *procedurePrototype)
@@ -2019,7 +2030,8 @@ std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedureProt
20192030llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType (BlockPrototype *procedurePrototype)
20202031{
20212032 // void *f(ExecutionContext *, Target *, ValueData **, List **, (warp arg), (procedure args...))
2022- llvm::PointerType *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
2033+ // bool f(...) (hat predicates)
2034+ llvm::Type *pointerType = llvm::PointerType::get (llvm::Type::getInt8Ty (m_llvmCtx), 0 );
20232035 std::vector<llvm::Type *> argTypes = { pointerType, pointerType, pointerType, pointerType };
20242036
20252037 if (procedurePrototype) {
@@ -2034,7 +2046,7 @@ llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedu
20342046 }
20352047 }
20362048
2037- return llvm::FunctionType::get (pointerType, argTypes, false );
2049+ return llvm::FunctionType::get (m_isPredicate ? m_builder. getInt1Ty () : pointerType, argTypes, false );
20382050}
20392051
20402052llvm::Function *LLVMCodeBuilder::getOrCreateFunction (const std::string &name, llvm::FunctionType *type)
0 commit comments