Skip to content

Commit 412ff6d

Browse files
committed
Add support for hat predicate LLVM code
1 parent 0da8733 commit 412ff6d

File tree

8 files changed

+115
-18
lines changed

8 files changed

+115
-18
lines changed

include/scratchcpp/executablecode.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class LIBSCRATCHCPP_EXPORT ExecutableCode
2121
/*! Runs the script until it finishes or yields. */
2222
virtual void run(ExecutionContext *context) = 0;
2323

24+
/*! Runs the hat predicate and returns its return value. */
25+
virtual bool runPredicate(ExecutionContext *context) = 0;
26+
2427
/*! Stops the code. isFinished() will return true. */
2528
virtual void kill(ExecutionContext *context) = 0;
2629

src/engine/internal/llvm/llvmexecutablecode.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
using namespace libscratchcpp;
1818

19-
LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName) :
19+
LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName, bool isPredicate) :
2020
m_ctx(ctx),
2121
m_mainFunctionName(mainFunctionName),
22-
m_resumeFunctionName(resumeFunctionName)
22+
m_resumeFunctionName(resumeFunctionName),
23+
m_isPredicate(isPredicate)
2324
{
2425
assert(m_ctx);
2526

@@ -31,7 +32,7 @@ LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::stri
3132

3233
void LLVMExecutableCode::run(ExecutionContext *context)
3334
{
34-
assert(m_mainFunction);
35+
assert(std::holds_alternative<MainFunctionType>(m_mainFunction));
3536
assert(m_resumeFunction);
3637
LLVMExecutionContext *ctx = getContext(context);
3738

@@ -56,7 +57,8 @@ void LLVMExecutableCode::run(ExecutionContext *context)
5657
ctx->setFinished(done);
5758
} else {
5859
Target *target = ctx->thread()->target();
59-
void *handle = m_mainFunction(context, target, target->variableData(), target->listData());
60+
MainFunctionType f = std::get<MainFunctionType>(m_mainFunction);
61+
void *handle = f(context, target, target->variableData(), target->listData());
6062

6163
if (!handle)
6264
ctx->setFinished(true);
@@ -65,6 +67,14 @@ void LLVMExecutableCode::run(ExecutionContext *context)
6567
}
6668
}
6769

70+
bool LLVMExecutableCode::runPredicate(ExecutionContext *context)
71+
{
72+
assert(std::holds_alternative<PredicateFunctionType>(m_mainFunction));
73+
Target *target = context->thread()->target();
74+
PredicateFunctionType f = std::get<PredicateFunctionType>(m_mainFunction);
75+
return f(context, target, target->variableData(), target->listData());
76+
}
77+
6878
void LLVMExecutableCode::kill(ExecutionContext *context)
6979
{
7080
LLVMExecutionContext *ctx = getContext(context);
@@ -91,7 +101,11 @@ std::shared_ptr<ExecutionContext> LLVMExecutableCode::createExecutionContext(Thr
91101
if (!m_ctx->jitInitialized())
92102
m_ctx->initJit();
93103

94-
m_mainFunction = m_ctx->lookupFunction<MainFunctionType>(m_mainFunctionName);
104+
if (m_isPredicate)
105+
m_mainFunction = m_ctx->lookupFunction<PredicateFunctionType>(m_mainFunctionName);
106+
else
107+
m_mainFunction = m_ctx->lookupFunction<MainFunctionType>(m_mainFunctionName);
108+
95109
m_resumeFunction = m_ctx->lookupFunction<ResumeFunctionType>(m_resumeFunctionName);
96110
return std::make_shared<LLVMExecutionContext>(thread);
97111
}

src/engine/internal/llvm/llvmexecutablecode.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ class LLVMExecutionContext;
1616
class LLVMExecutableCode : public ExecutableCode
1717
{
1818
public:
19-
LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName);
19+
LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName, bool isPredicate);
2020

2121
void run(ExecutionContext *context) override;
22+
bool runPredicate(ExecutionContext *context) override;
2223
void kill(libscratchcpp::ExecutionContext *context) override;
2324
void reset(ExecutionContext *context) override;
2425

@@ -28,14 +29,18 @@ class LLVMExecutableCode : public ExecutableCode
2829

2930
private:
3031
using MainFunctionType = void *(*)(ExecutionContext *, Target *, ValueData **, List **);
32+
using PredicateFunctionType = bool (*)(ExecutionContext *, Target *, ValueData **, List **);
3133
using ResumeFunctionType = bool (*)(void *);
3234

3335
static LLVMExecutionContext *getContext(ExecutionContext *context);
3436

3537
LLVMCompilerContext *m_ctx = nullptr;
3638
std::string m_mainFunctionName;
39+
std::string m_predicateFunctionName;
3740
std::string m_resumeFunctionName;
38-
mutable MainFunctionType m_mainFunction = nullptr;
41+
bool m_isPredicate = false;
42+
43+
mutable std::variant<MainFunctionType, PredicateFunctionType> m_mainFunction;
3944
mutable ResumeFunctionType m_resumeFunction = nullptr;
4045
};
4146

test/llvm/llvmexecutablecode_test.cpp

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
using namespace libscratchcpp;
1919

20+
using ::testing::Return;
21+
2022
class LLVMExecutableCodeTest : public testing::Test
2123
{
2224
public:
@@ -34,11 +36,12 @@ class LLVMExecutableCodeTest : public testing::Test
3436

3537
inline llvm::Constant *nullPointer() { return llvm::ConstantPointerNull::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0)); }
3638

37-
llvm::Function *beginMainFunction()
39+
llvm::Function *beginMainFunction(bool predicate = false)
3840
{
3941
// void *f(ExecutionContext *, Target *, ValueData **, List **)
42+
// bool f(...) (hat predicates)
4043
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0);
41-
llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType, pointerType }, false);
44+
llvm::FunctionType *funcType = llvm::FunctionType::get(predicate ? m_builder->getInt1Ty() : pointerType, { pointerType, pointerType, pointerType, pointerType }, false);
4245
llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module);
4346

4447
llvm::BasicBlock *entry = llvm::BasicBlock::Create(*m_llvmCtx, "entry", func);
@@ -70,6 +73,17 @@ class LLVMExecutableCodeTest : public testing::Test
7073
m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2), mainFunc->getArg(3) });
7174
}
7275

76+
llvm::Value *addPredicateFunction(llvm::Function *mainFunc)
77+
{
78+
auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0);
79+
auto func = m_module->getOrInsertFunction("test_predicate", llvm::FunctionType::get(m_builder->getInt1Ty(), { ptrType, ptrType, ptrType, ptrType, ptrType }, false));
80+
81+
llvm::Constant *mockInt = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*m_llvmCtx), (uintptr_t)&m_mock, false);
82+
llvm::Constant *mockPtr = llvm::ConstantExpr::getIntToPtr(mockInt, ptrType);
83+
84+
return m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2), mainFunc->getArg(3) });
85+
}
86+
7387
void addTestPrintFunction(llvm::Value *arg1, llvm::Value *arg2)
7488
{
7589
auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0);
@@ -95,13 +109,34 @@ TEST_F(LLVMExecutableCodeTest, CreateExecutionContext)
95109
llvm::Function *resumeFunc = beginResumeFunction();
96110
endFunction(m_builder->getInt1(true));
97111

98-
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str());
99-
m_script->setCode(code);
100-
Thread thread(&m_target, &m_engine, m_script.get());
101-
auto ctx = code->createExecutionContext(&thread);
102-
ASSERT_TRUE(ctx);
103-
ASSERT_EQ(ctx->thread(), &thread);
104-
ASSERT_TRUE(dynamic_cast<LLVMExecutionContext *>(ctx.get()));
112+
{
113+
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false);
114+
m_script->setCode(code);
115+
Thread thread(&m_target, &m_engine, m_script.get());
116+
auto ctx = code->createExecutionContext(&thread);
117+
ASSERT_TRUE(ctx);
118+
ASSERT_EQ(ctx->thread(), &thread);
119+
ASSERT_TRUE(dynamic_cast<LLVMExecutionContext *>(ctx.get()));
120+
}
121+
}
122+
123+
TEST_F(LLVMExecutableCodeTest, CreatePredicateExecutionContext)
124+
{
125+
llvm::Function *mainFunc = beginMainFunction(true);
126+
endFunction(m_builder->getInt1(false));
127+
128+
llvm::Function *resumeFunc = beginResumeFunction();
129+
endFunction(m_builder->getInt1(true));
130+
131+
{
132+
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), true);
133+
m_script->setCode(code);
134+
Thread thread(&m_target, &m_engine, m_script.get());
135+
auto ctx = code->createExecutionContext(&thread);
136+
ASSERT_TRUE(ctx);
137+
ASSERT_EQ(ctx->thread(), &thread);
138+
ASSERT_TRUE(dynamic_cast<LLVMExecutionContext *>(ctx.get()));
139+
}
105140
}
106141

107142
TEST_F(LLVMExecutableCodeTest, MainFunction)
@@ -116,7 +151,7 @@ TEST_F(LLVMExecutableCodeTest, MainFunction)
116151
llvm::Function *resumeFunc = beginResumeFunction();
117152
endFunction(m_builder->getInt1(true));
118153

119-
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str());
154+
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false);
120155
m_script->setCode(code);
121156
Thread thread(&m_target, &m_engine, m_script.get());
122157
auto ctx = code->createExecutionContext(&thread);
@@ -160,6 +195,35 @@ TEST_F(LLVMExecutableCodeTest, MainFunction)
160195
ASSERT_FALSE(code->isFinished(ctx.get()));
161196
}
162197

198+
TEST_F(LLVMExecutableCodeTest, PredicateFunction)
199+
{
200+
m_target.addVariable(std::make_shared<Variable>("", ""));
201+
m_target.addList(std::make_shared<List>("", ""));
202+
203+
llvm::Function *mainFunc = beginMainFunction(true);
204+
endFunction(addPredicateFunction(mainFunc));
205+
206+
llvm::Function *resumeFunc = beginResumeFunction();
207+
endFunction(m_builder->getInt1(true));
208+
209+
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), true);
210+
m_script->setCode(code);
211+
Thread thread(&m_target, &m_engine, m_script.get());
212+
auto ctx = code->createExecutionContext(&thread);
213+
214+
EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(true));
215+
ASSERT_TRUE(code->runPredicate(ctx.get()));
216+
217+
EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(true));
218+
ASSERT_TRUE(code->runPredicate(ctx.get()));
219+
220+
EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(false));
221+
ASSERT_FALSE(code->runPredicate(ctx.get()));
222+
223+
EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(false));
224+
ASSERT_FALSE(code->runPredicate(ctx.get()));
225+
}
226+
163227
TEST_F(LLVMExecutableCodeTest, Promise)
164228
{
165229
llvm::Function *mainFunc = beginMainFunction();
@@ -169,7 +233,7 @@ TEST_F(LLVMExecutableCodeTest, Promise)
169233
llvm::Function *resumeFunc = beginResumeFunction();
170234
endFunction(m_builder->getInt1(true));
171235

172-
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str());
236+
auto code = std::make_shared<LLVMExecutableCode>(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false);
173237
m_script->setCode(code);
174238
Thread thread(&m_target, &m_engine, m_script.get());
175239
auto ctx = code->createExecutionContext(&thread);

test/llvm/testfunctions.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ extern "C"
2121
mock->f(ctx, target, varData, listData);
2222
}
2323

24+
bool test_predicate(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData)
25+
{
26+
if (mock)
27+
return mock->predicate(ctx, target, varData, listData);
28+
29+
return false;
30+
}
31+
2432
void test_print_function(ValueData *arg1, ValueData *arg2)
2533
{
2634
std::string s1, s2;

test/llvm/testfunctions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct StringPtr;
1313
extern "C"
1414
{
1515
void test_function(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData);
16+
bool test_predicate(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData);
1617
void test_print_function(ValueData *arg1, ValueData *arg2);
1718

1819
void test_empty_function();

test/llvm/testmock.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TestMock
1313
{
1414
public:
1515
MOCK_METHOD(void, f, (ExecutionContext * ctx, Target *, ValueData **, List **));
16+
MOCK_METHOD(bool, predicate, (ExecutionContext * ctx, Target *, ValueData **, List **));
1617
};
1718

1819
} // namespace libscratchcpp

test/mocks/executablecodemock.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ class ExecutableCodeMock : public ExecutableCode
99
{
1010
public:
1111
MOCK_METHOD(void, run, (ExecutionContext *), (override));
12+
MOCK_METHOD(bool, runPredicate, (ExecutionContext *), (override));
1213
MOCK_METHOD(void, kill, (ExecutionContext *), (override));
1314
MOCK_METHOD(void, reset, (ExecutionContext *), (override));
1415

0 commit comments

Comments
 (0)