Skip to content

Commit 66978e0

Browse files
committed
LLVMCodeBuilder: Implement hat predicates
1 parent 412ff6d commit 66978e0

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed

src/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,16 @@ static std::unordered_map<ValueType, Compiler::StaticType>
2626
static const std::unordered_set<LLVMInstruction::Type>
2727
VAR_LIST_READ_INSTRUCTIONS = { LLVMInstruction::Type::ReadVariable, LLVMInstruction::Type::GetListItem, LLVMInstruction::Type::GetListItemIndex, LLVMInstruction::Type::ListContainsItem };
2828

29-
LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) :
29+
LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype, bool isPredicate) :
3030
m_ctx(ctx),
3131
m_target(ctx->target()),
3232
m_llvmCtx(*ctx->llvmCtx()),
3333
m_module(ctx->module()),
3434
m_builder(m_llvmCtx),
3535
m_procedurePrototype(procedurePrototype),
3636
m_defaultWarp(procedurePrototype ? procedurePrototype->warp() : false),
37-
m_warp(m_defaultWarp)
37+
m_warp(m_defaultWarp),
38+
m_isPredicate(isPredicate)
3839
{
3940
initTypes();
4041
createVariableMap();
@@ -54,6 +55,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
5455

5556
if (it == m_instructions.end())
5657
m_warp = true;
58+
59+
// Do not create coroutine in hat predicates
60+
if (m_isPredicate)
61+
m_warp = true;
5762
}
5863

5964
// Set fast math flags
@@ -1314,10 +1319,16 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13141319
// End and verify the function
13151320
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
13161321

1317-
if (m_warp)
1318-
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
1319-
else
1320-
coro->end();
1322+
if (m_isPredicate) {
1323+
// Use last instruction return value
1324+
assert(!m_instructions.empty());
1325+
m_builder.CreateRet(m_instructions.back()->functionReturnReg->value);
1326+
} else {
1327+
if (m_warp)
1328+
m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType));
1329+
else
1330+
coro->end();
1331+
}
13211332

13221333
verifyFunction(m_function);
13231334

@@ -1338,7 +1349,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
13381349

13391350
verifyFunction(resumeFunc);
13401351

1341-
return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName().str(), resumeFunc->getName().str());
1352+
return std::make_shared<LLVMExecutableCode>(m_ctx, m_function->getName().str(), resumeFunc->getName().str(), m_isPredicate);
13421353
}
13431354

13441355
CompilerValue *LLVMCodeBuilder::addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args)
@@ -2008,7 +2019,7 @@ void LLVMCodeBuilder::popLoopScope()
20082019

20092020
std::string LLVMCodeBuilder::getMainFunctionName(BlockPrototype *procedurePrototype)
20102021
{
2011-
return procedurePrototype ? "proc." + procedurePrototype->procCode() : "script";
2022+
return procedurePrototype ? "proc." + procedurePrototype->procCode() : (m_isPredicate ? "predicate" : "script");
20122023
}
20132024

20142025
std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedurePrototype)
@@ -2019,7 +2030,8 @@ std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedureProt
20192030
llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedurePrototype)
20202031
{
20212032
// void *f(ExecutionContext *, Target *, ValueData **, List **, (warp arg), (procedure args...))
2022-
llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
2033+
// bool f(...) (hat predicates)
2034+
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
20232035
std::vector<llvm::Type *> argTypes = { pointerType, pointerType, pointerType, pointerType };
20242036

20252037
if (procedurePrototype) {
@@ -2034,7 +2046,7 @@ llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedu
20342046
}
20352047
}
20362048

2037-
return llvm::FunctionType::get(pointerType, argTypes, false);
2049+
return llvm::FunctionType::get(m_isPredicate ? m_builder.getInt1Ty() : pointerType, argTypes, false);
20382050
}
20392051

20402052
llvm::Function *LLVMCodeBuilder::getOrCreateFunction(const std::string &name, llvm::FunctionType *type)

src/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class LLVMLoopScope;
2525
class LLVMCodeBuilder : public ICodeBuilder
2626
{
2727
public:
28-
LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr);
28+
LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr, bool isPredicate = false);
2929

3030
std::shared_ptr<ExecutableCode> finalize() override;
3131

@@ -239,6 +239,7 @@ class LLVMCodeBuilder : public ICodeBuilder
239239
bool m_defaultWarp = false;
240240
bool m_warp = false;
241241
int m_defaultArgCount = 0;
242+
bool m_isPredicate = false; // for hat predicates
242243

243244
long m_loopScope = -1; // index
244245
std::vector<std::shared_ptr<LLVMLoopScope>> m_loopScopes;

test/llvm/llvmcodebuilder_test.cpp

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ class LLVMCodeBuilderTest : public testing::Test
6767
test_function(nullptr, nullptr, nullptr, nullptr, nullptr); // force dependency
6868
}
6969

70-
void createBuilder(Target *target, BlockPrototype *procedurePrototype)
70+
void createBuilder(Target *target, BlockPrototype *procedurePrototype, bool isPredicate = false)
7171
{
7272
if (m_contexts.find(target) == m_contexts.cend() || !target)
7373
m_contexts[target] = std::make_unique<LLVMCompilerContext>(&m_engine, target);
7474

75-
m_builder = std::make_unique<LLVMCodeBuilder>(m_contexts[target].get(), procedurePrototype);
75+
m_builder = std::make_unique<LLVMCodeBuilder>(m_contexts[target].get(), procedurePrototype, isPredicate);
7676
}
7777

7878
void createBuilder(Target *target, bool warp)
@@ -82,6 +82,8 @@ class LLVMCodeBuilderTest : public testing::Test
8282
createBuilder(target, m_procedurePrototype.get());
8383
}
8484

85+
void createPredicateBuilder(Target *target) { createBuilder(target, nullptr, true); }
86+
8587
void createBuilder(bool warp) { createBuilder(nullptr, warp); }
8688

8789
CompilerValue *callConstFuncForType(ValueType type, CompilerValue *arg)
@@ -5995,3 +5997,38 @@ TEST_F(LLVMCodeBuilderTest, Procedures)
59955997
ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2 + expected3);
59965998
ASSERT_TRUE(code->isFinished(ctx.get()));
59975999
}
6000+
6001+
TEST_F(LLVMCodeBuilderTest, HatPredicates)
6002+
{
6003+
Sprite sprite;
6004+
6005+
// Predicate 1
6006+
createPredicateBuilder(&sprite);
6007+
6008+
CompilerValue *v = m_builder->addConstValue(true);
6009+
m_builder->addFunctionCall("test_const_bool", Compiler::StaticType::Bool, { Compiler::StaticType::Bool }, { v });
6010+
6011+
auto code1 = m_builder->finalize();
6012+
6013+
// Predicate 2
6014+
createPredicateBuilder(&sprite);
6015+
6016+
v = m_builder->addConstValue(false);
6017+
m_builder->addFunctionCall("test_const_bool", Compiler::StaticType::Bool, { Compiler::StaticType::Bool }, { v });
6018+
6019+
auto code2 = m_builder->finalize();
6020+
6021+
Script script1(&sprite, nullptr, nullptr);
6022+
script1.setCode(code1);
6023+
Thread thread1(&sprite, nullptr, &script1);
6024+
auto ctx = code1->createExecutionContext(&thread1);
6025+
6026+
ASSERT_TRUE(code1->runPredicate(ctx.get()));
6027+
6028+
Script script2(&sprite, nullptr, nullptr);
6029+
script2.setCode(code2);
6030+
Thread thread2(&sprite, nullptr, &script2);
6031+
ctx = code2->createExecutionContext(&thread2);
6032+
6033+
ASSERT_FALSE(code2->runPredicate(ctx.get()));
6034+
}

0 commit comments

Comments
 (0)