From 412ff6d265adb060295edc892ccc8ac0e8e98fa7 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:28:29 +0100 Subject: [PATCH 1/9] Add support for hat predicate LLVM code --- include/scratchcpp/executablecode.h | 3 + .../internal/llvm/llvmexecutablecode.cpp | 24 ++++-- src/engine/internal/llvm/llvmexecutablecode.h | 9 +- test/llvm/llvmexecutablecode_test.cpp | 86 ++++++++++++++++--- test/llvm/testfunctions.cpp | 8 ++ test/llvm/testfunctions.h | 1 + test/llvm/testmock.h | 1 + test/mocks/executablecodemock.h | 1 + 8 files changed, 115 insertions(+), 18 deletions(-) diff --git a/include/scratchcpp/executablecode.h b/include/scratchcpp/executablecode.h index 094d0181..c55a4983 100644 --- a/include/scratchcpp/executablecode.h +++ b/include/scratchcpp/executablecode.h @@ -21,6 +21,9 @@ class LIBSCRATCHCPP_EXPORT ExecutableCode /*! Runs the script until it finishes or yields. */ virtual void run(ExecutionContext *context) = 0; + /*! Runs the hat predicate and returns its return value. */ + virtual bool runPredicate(ExecutionContext *context) = 0; + /*! Stops the code. isFinished() will return true. */ virtual void kill(ExecutionContext *context) = 0; diff --git a/src/engine/internal/llvm/llvmexecutablecode.cpp b/src/engine/internal/llvm/llvmexecutablecode.cpp index 60f69531..f41c660b 100644 --- a/src/engine/internal/llvm/llvmexecutablecode.cpp +++ b/src/engine/internal/llvm/llvmexecutablecode.cpp @@ -16,10 +16,11 @@ using namespace libscratchcpp; -LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName) : +LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName, bool isPredicate) : m_ctx(ctx), m_mainFunctionName(mainFunctionName), - m_resumeFunctionName(resumeFunctionName) + m_resumeFunctionName(resumeFunctionName), + m_isPredicate(isPredicate) { assert(m_ctx); @@ -31,7 +32,7 @@ LLVMExecutableCode::LLVMExecutableCode(LLVMCompilerContext *ctx, const std::stri void LLVMExecutableCode::run(ExecutionContext *context) { - assert(m_mainFunction); + assert(std::holds_alternative(m_mainFunction)); assert(m_resumeFunction); LLVMExecutionContext *ctx = getContext(context); @@ -56,7 +57,8 @@ void LLVMExecutableCode::run(ExecutionContext *context) ctx->setFinished(done); } else { Target *target = ctx->thread()->target(); - void *handle = m_mainFunction(context, target, target->variableData(), target->listData()); + MainFunctionType f = std::get(m_mainFunction); + void *handle = f(context, target, target->variableData(), target->listData()); if (!handle) ctx->setFinished(true); @@ -65,6 +67,14 @@ void LLVMExecutableCode::run(ExecutionContext *context) } } +bool LLVMExecutableCode::runPredicate(ExecutionContext *context) +{ + assert(std::holds_alternative(m_mainFunction)); + Target *target = context->thread()->target(); + PredicateFunctionType f = std::get(m_mainFunction); + return f(context, target, target->variableData(), target->listData()); +} + void LLVMExecutableCode::kill(ExecutionContext *context) { LLVMExecutionContext *ctx = getContext(context); @@ -91,7 +101,11 @@ std::shared_ptr LLVMExecutableCode::createExecutionContext(Thr if (!m_ctx->jitInitialized()) m_ctx->initJit(); - m_mainFunction = m_ctx->lookupFunction(m_mainFunctionName); + if (m_isPredicate) + m_mainFunction = m_ctx->lookupFunction(m_mainFunctionName); + else + m_mainFunction = m_ctx->lookupFunction(m_mainFunctionName); + m_resumeFunction = m_ctx->lookupFunction(m_resumeFunctionName); return std::make_shared(thread); } diff --git a/src/engine/internal/llvm/llvmexecutablecode.h b/src/engine/internal/llvm/llvmexecutablecode.h index 3242e315..5526bed5 100644 --- a/src/engine/internal/llvm/llvmexecutablecode.h +++ b/src/engine/internal/llvm/llvmexecutablecode.h @@ -16,9 +16,10 @@ class LLVMExecutionContext; class LLVMExecutableCode : public ExecutableCode { public: - LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName); + LLVMExecutableCode(LLVMCompilerContext *ctx, const std::string &mainFunctionName, const std::string &resumeFunctionName, bool isPredicate); void run(ExecutionContext *context) override; + bool runPredicate(ExecutionContext *context) override; void kill(libscratchcpp::ExecutionContext *context) override; void reset(ExecutionContext *context) override; @@ -28,14 +29,18 @@ class LLVMExecutableCode : public ExecutableCode private: using MainFunctionType = void *(*)(ExecutionContext *, Target *, ValueData **, List **); + using PredicateFunctionType = bool (*)(ExecutionContext *, Target *, ValueData **, List **); using ResumeFunctionType = bool (*)(void *); static LLVMExecutionContext *getContext(ExecutionContext *context); LLVMCompilerContext *m_ctx = nullptr; std::string m_mainFunctionName; + std::string m_predicateFunctionName; std::string m_resumeFunctionName; - mutable MainFunctionType m_mainFunction = nullptr; + bool m_isPredicate = false; + + mutable std::variant m_mainFunction; mutable ResumeFunctionType m_resumeFunction = nullptr; }; diff --git a/test/llvm/llvmexecutablecode_test.cpp b/test/llvm/llvmexecutablecode_test.cpp index 17712d8e..f025831e 100644 --- a/test/llvm/llvmexecutablecode_test.cpp +++ b/test/llvm/llvmexecutablecode_test.cpp @@ -17,6 +17,8 @@ using namespace libscratchcpp; +using ::testing::Return; + class LLVMExecutableCodeTest : public testing::Test { public: @@ -34,11 +36,12 @@ class LLVMExecutableCodeTest : public testing::Test inline llvm::Constant *nullPointer() { return llvm::ConstantPointerNull::get(llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0)); } - llvm::Function *beginMainFunction() + llvm::Function *beginMainFunction(bool predicate = false) { // void *f(ExecutionContext *, Target *, ValueData **, List **) + // bool f(...) (hat predicates) llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0); - llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType, pointerType }, false); + llvm::FunctionType *funcType = llvm::FunctionType::get(predicate ? m_builder->getInt1Ty() : pointerType, { pointerType, pointerType, pointerType, pointerType }, false); llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module); llvm::BasicBlock *entry = llvm::BasicBlock::Create(*m_llvmCtx, "entry", func); @@ -70,6 +73,17 @@ class LLVMExecutableCodeTest : public testing::Test m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2), mainFunc->getArg(3) }); } + llvm::Value *addPredicateFunction(llvm::Function *mainFunc) + { + auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0); + auto func = m_module->getOrInsertFunction("test_predicate", llvm::FunctionType::get(m_builder->getInt1Ty(), { ptrType, ptrType, ptrType, ptrType, ptrType }, false)); + + llvm::Constant *mockInt = llvm::ConstantInt::get(llvm::Type::getInt64Ty(*m_llvmCtx), (uintptr_t)&m_mock, false); + llvm::Constant *mockPtr = llvm::ConstantExpr::getIntToPtr(mockInt, ptrType); + + return m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2), mainFunc->getArg(3) }); + } + void addTestPrintFunction(llvm::Value *arg1, llvm::Value *arg2) { auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(*m_llvmCtx), 0); @@ -95,13 +109,34 @@ TEST_F(LLVMExecutableCodeTest, CreateExecutionContext) llvm::Function *resumeFunc = beginResumeFunction(); endFunction(m_builder->getInt1(true)); - auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str()); - m_script->setCode(code); - Thread thread(&m_target, &m_engine, m_script.get()); - auto ctx = code->createExecutionContext(&thread); - ASSERT_TRUE(ctx); - ASSERT_EQ(ctx->thread(), &thread); - ASSERT_TRUE(dynamic_cast(ctx.get())); + { + auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false); + m_script->setCode(code); + Thread thread(&m_target, &m_engine, m_script.get()); + auto ctx = code->createExecutionContext(&thread); + ASSERT_TRUE(ctx); + ASSERT_EQ(ctx->thread(), &thread); + ASSERT_TRUE(dynamic_cast(ctx.get())); + } +} + +TEST_F(LLVMExecutableCodeTest, CreatePredicateExecutionContext) +{ + llvm::Function *mainFunc = beginMainFunction(true); + endFunction(m_builder->getInt1(false)); + + llvm::Function *resumeFunc = beginResumeFunction(); + endFunction(m_builder->getInt1(true)); + + { + auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), true); + m_script->setCode(code); + Thread thread(&m_target, &m_engine, m_script.get()); + auto ctx = code->createExecutionContext(&thread); + ASSERT_TRUE(ctx); + ASSERT_EQ(ctx->thread(), &thread); + ASSERT_TRUE(dynamic_cast(ctx.get())); + } } TEST_F(LLVMExecutableCodeTest, MainFunction) @@ -116,7 +151,7 @@ TEST_F(LLVMExecutableCodeTest, MainFunction) llvm::Function *resumeFunc = beginResumeFunction(); endFunction(m_builder->getInt1(true)); - auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str()); + auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false); m_script->setCode(code); Thread thread(&m_target, &m_engine, m_script.get()); auto ctx = code->createExecutionContext(&thread); @@ -160,6 +195,35 @@ TEST_F(LLVMExecutableCodeTest, MainFunction) ASSERT_FALSE(code->isFinished(ctx.get())); } +TEST_F(LLVMExecutableCodeTest, PredicateFunction) +{ + m_target.addVariable(std::make_shared("", "")); + m_target.addList(std::make_shared("", "")); + + llvm::Function *mainFunc = beginMainFunction(true); + endFunction(addPredicateFunction(mainFunc)); + + llvm::Function *resumeFunc = beginResumeFunction(); + endFunction(m_builder->getInt1(true)); + + auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), true); + m_script->setCode(code); + Thread thread(&m_target, &m_engine, m_script.get()); + auto ctx = code->createExecutionContext(&thread); + + EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(true)); + ASSERT_TRUE(code->runPredicate(ctx.get())); + + EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(true)); + ASSERT_TRUE(code->runPredicate(ctx.get())); + + EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(false)); + ASSERT_FALSE(code->runPredicate(ctx.get())); + + EXPECT_CALL(m_mock, predicate(ctx.get(), &m_target, m_target.variableData(), m_target.listData())).WillOnce(Return(false)); + ASSERT_FALSE(code->runPredicate(ctx.get())); +} + TEST_F(LLVMExecutableCodeTest, Promise) { llvm::Function *mainFunc = beginMainFunction(); @@ -169,7 +233,7 @@ TEST_F(LLVMExecutableCodeTest, Promise) llvm::Function *resumeFunc = beginResumeFunction(); endFunction(m_builder->getInt1(true)); - auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str()); + auto code = std::make_shared(m_ctx.get(), mainFunc->getName().str(), resumeFunc->getName().str(), false); m_script->setCode(code); Thread thread(&m_target, &m_engine, m_script.get()); auto ctx = code->createExecutionContext(&thread); diff --git a/test/llvm/testfunctions.cpp b/test/llvm/testfunctions.cpp index cf88194a..3aaccb10 100644 --- a/test/llvm/testfunctions.cpp +++ b/test/llvm/testfunctions.cpp @@ -21,6 +21,14 @@ extern "C" mock->f(ctx, target, varData, listData); } + bool test_predicate(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData) + { + if (mock) + return mock->predicate(ctx, target, varData, listData); + + return false; + } + void test_print_function(ValueData *arg1, ValueData *arg2) { std::string s1, s2; diff --git a/test/llvm/testfunctions.h b/test/llvm/testfunctions.h index 490ffdf1..0a8ded8c 100644 --- a/test/llvm/testfunctions.h +++ b/test/llvm/testfunctions.h @@ -13,6 +13,7 @@ struct StringPtr; extern "C" { void test_function(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData); + bool test_predicate(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData); void test_print_function(ValueData *arg1, ValueData *arg2); void test_empty_function(); diff --git a/test/llvm/testmock.h b/test/llvm/testmock.h index 5573f247..14ef3453 100644 --- a/test/llvm/testmock.h +++ b/test/llvm/testmock.h @@ -13,6 +13,7 @@ class TestMock { public: MOCK_METHOD(void, f, (ExecutionContext * ctx, Target *, ValueData **, List **)); + MOCK_METHOD(bool, predicate, (ExecutionContext * ctx, Target *, ValueData **, List **)); }; } // namespace libscratchcpp diff --git a/test/mocks/executablecodemock.h b/test/mocks/executablecodemock.h index b2997460..4d113c26 100644 --- a/test/mocks/executablecodemock.h +++ b/test/mocks/executablecodemock.h @@ -9,6 +9,7 @@ class ExecutableCodeMock : public ExecutableCode { public: MOCK_METHOD(void, run, (ExecutionContext *), (override)); + MOCK_METHOD(bool, runPredicate, (ExecutionContext *), (override)); MOCK_METHOD(void, kill, (ExecutionContext *), (override)); MOCK_METHOD(void, reset, (ExecutionContext *), (override)); From 66978e03bae04a45e0b252add6e5bbb6a1dead36 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:46:57 +0100 Subject: [PATCH 2/9] LLVMCodeBuilder: Implement hat predicates --- src/engine/internal/llvm/llvmcodebuilder.cpp | 32 ++++++++++----- src/engine/internal/llvm/llvmcodebuilder.h | 3 +- test/llvm/llvmcodebuilder_test.cpp | 41 +++++++++++++++++++- 3 files changed, 63 insertions(+), 13 deletions(-) diff --git a/src/engine/internal/llvm/llvmcodebuilder.cpp b/src/engine/internal/llvm/llvmcodebuilder.cpp index dd452c7f..e212b3cf 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/engine/internal/llvm/llvmcodebuilder.cpp @@ -26,7 +26,7 @@ static std::unordered_map static const std::unordered_set VAR_LIST_READ_INSTRUCTIONS = { LLVMInstruction::Type::ReadVariable, LLVMInstruction::Type::GetListItem, LLVMInstruction::Type::GetListItemIndex, LLVMInstruction::Type::ListContainsItem }; -LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) : +LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype, bool isPredicate) : m_ctx(ctx), m_target(ctx->target()), m_llvmCtx(*ctx->llvmCtx()), @@ -34,7 +34,8 @@ LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *proce m_builder(m_llvmCtx), m_procedurePrototype(procedurePrototype), m_defaultWarp(procedurePrototype ? procedurePrototype->warp() : false), - m_warp(m_defaultWarp) + m_warp(m_defaultWarp), + m_isPredicate(isPredicate) { initTypes(); createVariableMap(); @@ -54,6 +55,10 @@ std::shared_ptr LLVMCodeBuilder::finalize() if (it == m_instructions.end()) m_warp = true; + + // Do not create coroutine in hat predicates + if (m_isPredicate) + m_warp = true; } // Set fast math flags @@ -1314,10 +1319,16 @@ std::shared_ptr LLVMCodeBuilder::finalize() // End and verify the function llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); - if (m_warp) - m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType)); - else - coro->end(); + if (m_isPredicate) { + // Use last instruction return value + assert(!m_instructions.empty()); + m_builder.CreateRet(m_instructions.back()->functionReturnReg->value); + } else { + if (m_warp) + m_builder.CreateRet(llvm::ConstantPointerNull::get(pointerType)); + else + coro->end(); + } verifyFunction(m_function); @@ -1338,7 +1349,7 @@ std::shared_ptr LLVMCodeBuilder::finalize() verifyFunction(resumeFunc); - return std::make_shared(m_ctx, m_function->getName().str(), resumeFunc->getName().str()); + return std::make_shared(m_ctx, m_function->getName().str(), resumeFunc->getName().str(), m_isPredicate); } CompilerValue *LLVMCodeBuilder::addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) @@ -2008,7 +2019,7 @@ void LLVMCodeBuilder::popLoopScope() std::string LLVMCodeBuilder::getMainFunctionName(BlockPrototype *procedurePrototype) { - return procedurePrototype ? "proc." + procedurePrototype->procCode() : "script"; + return procedurePrototype ? "proc." + procedurePrototype->procCode() : (m_isPredicate ? "predicate" : "script"); } std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedurePrototype) @@ -2019,7 +2030,8 @@ std::string LLVMCodeBuilder::getResumeFunctionName(BlockPrototype *procedureProt llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedurePrototype) { // void *f(ExecutionContext *, Target *, ValueData **, List **, (warp arg), (procedure args...)) - llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); + // bool f(...) (hat predicates) + llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0); std::vector argTypes = { pointerType, pointerType, pointerType, pointerType }; if (procedurePrototype) { @@ -2034,7 +2046,7 @@ llvm::FunctionType *LLVMCodeBuilder::getMainFunctionType(BlockPrototype *procedu } } - return llvm::FunctionType::get(pointerType, argTypes, false); + return llvm::FunctionType::get(m_isPredicate ? m_builder.getInt1Ty() : pointerType, argTypes, false); } llvm::Function *LLVMCodeBuilder::getOrCreateFunction(const std::string &name, llvm::FunctionType *type) diff --git a/src/engine/internal/llvm/llvmcodebuilder.h b/src/engine/internal/llvm/llvmcodebuilder.h index f0bee6a7..5a765270 100644 --- a/src/engine/internal/llvm/llvmcodebuilder.h +++ b/src/engine/internal/llvm/llvmcodebuilder.h @@ -25,7 +25,7 @@ class LLVMLoopScope; class LLVMCodeBuilder : public ICodeBuilder { public: - LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr); + LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr, bool isPredicate = false); std::shared_ptr finalize() override; @@ -239,6 +239,7 @@ class LLVMCodeBuilder : public ICodeBuilder bool m_defaultWarp = false; bool m_warp = false; int m_defaultArgCount = 0; + bool m_isPredicate = false; // for hat predicates long m_loopScope = -1; // index std::vector> m_loopScopes; diff --git a/test/llvm/llvmcodebuilder_test.cpp b/test/llvm/llvmcodebuilder_test.cpp index 7f3b9f51..86cbfe0b 100644 --- a/test/llvm/llvmcodebuilder_test.cpp +++ b/test/llvm/llvmcodebuilder_test.cpp @@ -67,12 +67,12 @@ class LLVMCodeBuilderTest : public testing::Test test_function(nullptr, nullptr, nullptr, nullptr, nullptr); // force dependency } - void createBuilder(Target *target, BlockPrototype *procedurePrototype) + void createBuilder(Target *target, BlockPrototype *procedurePrototype, bool isPredicate = false) { if (m_contexts.find(target) == m_contexts.cend() || !target) m_contexts[target] = std::make_unique(&m_engine, target); - m_builder = std::make_unique(m_contexts[target].get(), procedurePrototype); + m_builder = std::make_unique(m_contexts[target].get(), procedurePrototype, isPredicate); } void createBuilder(Target *target, bool warp) @@ -82,6 +82,8 @@ class LLVMCodeBuilderTest : public testing::Test createBuilder(target, m_procedurePrototype.get()); } + void createPredicateBuilder(Target *target) { createBuilder(target, nullptr, true); } + void createBuilder(bool warp) { createBuilder(nullptr, warp); } CompilerValue *callConstFuncForType(ValueType type, CompilerValue *arg) @@ -5995,3 +5997,38 @@ TEST_F(LLVMCodeBuilderTest, Procedures) ASSERT_EQ(testing::internal::GetCapturedStdout(), expected2 + expected3); ASSERT_TRUE(code->isFinished(ctx.get())); } + +TEST_F(LLVMCodeBuilderTest, HatPredicates) +{ + Sprite sprite; + + // Predicate 1 + createPredicateBuilder(&sprite); + + CompilerValue *v = m_builder->addConstValue(true); + m_builder->addFunctionCall("test_const_bool", Compiler::StaticType::Bool, { Compiler::StaticType::Bool }, { v }); + + auto code1 = m_builder->finalize(); + + // Predicate 2 + createPredicateBuilder(&sprite); + + v = m_builder->addConstValue(false); + m_builder->addFunctionCall("test_const_bool", Compiler::StaticType::Bool, { Compiler::StaticType::Bool }, { v }); + + auto code2 = m_builder->finalize(); + + Script script1(&sprite, nullptr, nullptr); + script1.setCode(code1); + Thread thread1(&sprite, nullptr, &script1); + auto ctx = code1->createExecutionContext(&thread1); + + ASSERT_TRUE(code1->runPredicate(ctx.get())); + + Script script2(&sprite, nullptr, nullptr); + script2.setCode(code2); + Thread thread2(&sprite, nullptr, &script2); + ctx = code2->createExecutionContext(&thread2); + + ASSERT_FALSE(code2->runPredicate(ctx.get())); +} From 22661492f67f4ff7bd35643204e430f72345e26b Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 19:52:42 +0100 Subject: [PATCH 3/9] Implement hat predicates in Compiler --- include/scratchcpp/compiler.h | 2 +- src/engine/compiler.cpp | 23 ++++++++++-- src/engine/internal/codebuilderfactory.cpp | 4 +-- src/engine/internal/codebuilderfactory.h | 2 +- src/engine/internal/icodebuilderfactory.h | 2 +- test/compiler/compiler_test.cpp | 42 +++++++++++++++++++--- test/mocks/codebuilderfactorymock.h | 2 +- 7 files changed, 65 insertions(+), 12 deletions(-) diff --git a/include/scratchcpp/compiler.h b/include/scratchcpp/compiler.h index 9a604702..cc7db35a 100644 --- a/include/scratchcpp/compiler.h +++ b/include/scratchcpp/compiler.h @@ -49,7 +49,7 @@ class LIBSCRATCHCPP_EXPORT Compiler Target *target() const; std::shared_ptr block() const; - std::shared_ptr compile(std::shared_ptr startBlock); + std::shared_ptr compile(std::shared_ptr startBlock, bool isHatPredicate = false); void preoptimize(); CompilerValue *addFunctionCall(const std::string &functionName, StaticType returnType = StaticType::Void, const ArgTypes &argTypes = {}, const Args &args = {}); diff --git a/src/engine/compiler.cpp b/src/engine/compiler.cpp index dda27e2f..17ea93ce 100644 --- a/src/engine/compiler.cpp +++ b/src/engine/compiler.cpp @@ -44,7 +44,7 @@ std::shared_ptr Compiler::block() const } /*! Compiles the script starting with the given block. */ -std::shared_ptr Compiler::compile(std::shared_ptr startBlock) +std::shared_ptr Compiler::compile(std::shared_ptr startBlock, bool isHatPredicate) { BlockPrototype *procedurePrototype = nullptr; @@ -60,13 +60,32 @@ std::shared_ptr Compiler::compile(std::shared_ptr startBl } } - impl->builder = impl->builderFactory->create(impl->ctx, procedurePrototype); + impl->builder = impl->builderFactory->create(impl->ctx, procedurePrototype, isHatPredicate); impl->substackTree.clear(); impl->substackHit = false; impl->emptySubstack = false; impl->warp = false; impl->block = startBlock; + if (impl->block && isHatPredicate) { + auto f = impl->block->hatPredicateCompileFunction(); + + if (f) { + CompilerValue *ret = f(this); + assert(ret); + + if (!ret) + std::cout << "warning: '" << impl->block->opcode() << "' hat predicate compile function doesn't return a valid value" << std::endl; + } else { + std::cout << "warning: unsupported hat predicate: " << impl->block->opcode() << std::endl; + impl->unsupportedBlocks.insert(impl->block->opcode()); + addConstValue(false); // return false if unsupported + } + + impl->block = nullptr; + return impl->builder->finalize(); + } + while (impl->block) { if (impl->block->compileFunction()) { assert(impl->customIfStatementCount == 0); diff --git a/src/engine/internal/codebuilderfactory.cpp b/src/engine/internal/codebuilderfactory.cpp index bb3cb1e9..42ea1122 100644 --- a/src/engine/internal/codebuilderfactory.cpp +++ b/src/engine/internal/codebuilderfactory.cpp @@ -13,10 +13,10 @@ std::shared_ptr CodeBuilderFactory::instance() return m_instance; } -std::shared_ptr CodeBuilderFactory::create(CompilerContext *ctx, BlockPrototype *procedurePrototype) const +std::shared_ptr CodeBuilderFactory::create(CompilerContext *ctx, BlockPrototype *procedurePrototype, bool isPredicate) const { assert(dynamic_cast(ctx)); - return std::make_shared(static_cast(ctx), procedurePrototype); + return std::make_shared(static_cast(ctx), procedurePrototype, isPredicate); } std::shared_ptr CodeBuilderFactory::createCtx(IEngine *engine, Target *target) const diff --git a/src/engine/internal/codebuilderfactory.h b/src/engine/internal/codebuilderfactory.h index 942a9a1e..c0c29568 100644 --- a/src/engine/internal/codebuilderfactory.h +++ b/src/engine/internal/codebuilderfactory.h @@ -11,7 +11,7 @@ class CodeBuilderFactory : public ICodeBuilderFactory { public: static std::shared_ptr instance(); - std::shared_ptr create(CompilerContext *ctx, BlockPrototype *procedurePrototype) const override; + std::shared_ptr create(CompilerContext *ctx, BlockPrototype *procedurePrototype, bool isPredicate) const override; std::shared_ptr createCtx(IEngine *engine, Target *target) const override; private: diff --git a/src/engine/internal/icodebuilderfactory.h b/src/engine/internal/icodebuilderfactory.h index 6e031b2d..eecc3c0a 100644 --- a/src/engine/internal/icodebuilderfactory.h +++ b/src/engine/internal/icodebuilderfactory.h @@ -18,7 +18,7 @@ class ICodeBuilderFactory public: virtual ~ICodeBuilderFactory() { } - virtual std::shared_ptr create(CompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr) const = 0; + virtual std::shared_ptr create(CompilerContext *ctx, BlockPrototype *procedurePrototype = nullptr, bool isPredicate = false) const = 0; virtual std::shared_ptr createCtx(IEngine *engine, Target *target) const = 0; }; diff --git a/test/compiler/compiler_test.cpp b/test/compiler/compiler_test.cpp index 3b098284..3b823b7b 100644 --- a/test/compiler/compiler_test.cpp +++ b/test/compiler/compiler_test.cpp @@ -46,12 +46,12 @@ class CompilerTest : public testing::Test m_testVar.reset(); } - void compile(Compiler *compiler, std::shared_ptr block, BlockPrototype *procedurePrototype = nullptr) + void compile(Compiler *compiler, std::shared_ptr block, BlockPrototype *procedurePrototype = nullptr, bool isHatPredicate = false) { ASSERT_EQ(compiler->block(), nullptr); - EXPECT_CALL(m_builderFactory, create(m_ctx.get(), procedurePrototype)).WillOnce(Return(m_builder)); + EXPECT_CALL(m_builderFactory, create(m_ctx.get(), procedurePrototype, isHatPredicate)).WillOnce(Return(m_builder)); EXPECT_CALL(*m_builder, finalize()).WillOnce(Return(m_code)); - ASSERT_EQ(compiler->compile(block), m_code); + ASSERT_EQ(compiler->compile(block, isHatPredicate), m_code); ASSERT_EQ(compiler->block(), nullptr); } @@ -1778,7 +1778,23 @@ TEST_F(CompilerTest, UnsupportedBlocks) EXPECT_CALL(*m_builder, addConstValue).WillRepeatedly(Return(nullptr)); compile(m_compiler.get(), block1); - ASSERT_EQ(m_compiler->unsupportedBlocks(), std::unordered_set({ "block1", "block2", "value_block1", "value_block3", "value_block5", "block4" })); + // Hat predicates + auto block5 = std::make_shared("b5", "block5"); + compile(m_compiler.get(), block5, nullptr, true); + + auto block6 = std::make_shared("b6", "block6"); + block6->setCompileFunction([](Compiler *) -> CompilerValue * { return nullptr; }); + compile(m_compiler.get(), block6, nullptr, true); + + auto block7 = std::make_shared("b7", "block7"); + CompilerConstant ret(Compiler::StaticType::Bool, Value(true)); + block7->setCompileFunction([](Compiler *) -> CompilerValue * { return nullptr; }); + block7->setHatPredicateCompileFunction([](Compiler *compiler) -> CompilerValue * { return compiler->addConstValue(true); }); + EXPECT_CALL(*m_builder, addConstValue(Value(true))).WillOnce(Return(&ret)); + compile(m_compiler.get(), block7, nullptr, true); + + // Check + ASSERT_EQ(m_compiler->unsupportedBlocks(), std::unordered_set({ "block1", "block2", "value_block1", "value_block3", "value_block5", "block4", "block5", "block6" })); } TEST_F(CompilerTest, Procedure) @@ -1818,3 +1834,21 @@ TEST_F(CompilerTest, Preoptimize) EXPECT_CALL(*ctx, preoptimize()); compiler.preoptimize(); } + +TEST_F(CompilerTest, HatPredicate) +{ + auto block = std::make_shared("", ""); + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { return compiler->addConstValue(true); }); + block->setHatPredicateCompileFunction([](Compiler *compiler) -> CompilerValue * { return compiler->addConstValue(false); }); + + CompilerConstant ret1(Compiler::StaticType::Bool, Value(true)); + CompilerConstant ret2(Compiler::StaticType::Bool, Value(false)); + + // Script + EXPECT_CALL(*m_builder, addConstValue(Value(true))).WillOnce(Return(&ret1)); + compile(m_compiler.get(), block, nullptr, false); + + // Hat predicate + EXPECT_CALL(*m_builder, addConstValue(Value(false))).WillOnce(Return(&ret2)); + compile(m_compiler.get(), block, nullptr, true); +} diff --git a/test/mocks/codebuilderfactorymock.h b/test/mocks/codebuilderfactorymock.h index b585bf48..3489995d 100644 --- a/test/mocks/codebuilderfactorymock.h +++ b/test/mocks/codebuilderfactorymock.h @@ -8,6 +8,6 @@ using namespace libscratchcpp; class CodeBuilderFactoryMock : public ICodeBuilderFactory { public: - MOCK_METHOD(std::shared_ptr, create, (CompilerContext *, BlockPrototype *), (const, override)); + MOCK_METHOD(std::shared_ptr, create, (CompilerContext *, BlockPrototype *, bool), (const, override)); MOCK_METHOD(std::shared_ptr, createCtx, (IEngine *, Target *), (const, override)); }; From a30bce3215189cbc097f9285b714f026503ce90c Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 19:56:38 +0100 Subject: [PATCH 4/9] Script: Add hat predicate code --- include/scratchcpp/script.h | 3 +++ src/engine/script.cpp | 12 ++++++++++++ src/engine/script_p.h | 1 + test/script/script_test.cpp | 10 ++++++++++ 4 files changed, 26 insertions(+) diff --git a/include/scratchcpp/script.h b/include/scratchcpp/script.h index 3032cd4c..13d3b825 100644 --- a/include/scratchcpp/script.h +++ b/include/scratchcpp/script.h @@ -31,6 +31,9 @@ class LIBSCRATCHCPP_EXPORT Script ExecutableCode *code() const; void setCode(std::shared_ptr code); + ExecutableCode *hatPredicateCode() const; + void setHatPredicateCode(std::shared_ptr code); + bool runHatPredicate(Target *target); std::shared_ptr start(); diff --git a/src/engine/script.cpp b/src/engine/script.cpp index 868ad18e..a56d33bb 100644 --- a/src/engine/script.cpp +++ b/src/engine/script.cpp @@ -42,6 +42,18 @@ void Script::setCode(std::shared_ptr code) impl->code = code; } +/*! Returns the executable code of the hat predicate. */ +ExecutableCode *Script::hatPredicateCode() const +{ + return impl->hatPredicateCode.get(); +} + +/*! Sets the executable code of the hat predicate. */ +void Script::setHatPredicateCode(std::shared_ptr code) +{ + impl->hatPredicateCode = code; +} + /*! * Runs the edge-activated hat predicate as the given target and returns the reported value. * \note If there isn't any predicate, nothing will happen and the returned value will be false. diff --git a/src/engine/script_p.h b/src/engine/script_p.h index b067b17a..207123d2 100644 --- a/src/engine/script_p.h +++ b/src/engine/script_p.h @@ -22,6 +22,7 @@ struct ScriptPrivate ScriptPrivate(const ScriptPrivate &) = delete; std::shared_ptr code; + std::shared_ptr hatPredicateCode; Target *target = nullptr; std::shared_ptr topBlock; diff --git a/test/script/script_test.cpp b/test/script/script_test.cpp index 48744545..102e52b3 100644 --- a/test/script/script_test.cpp +++ b/test/script/script_test.cpp @@ -41,6 +41,16 @@ TEST_F(ScriptTest, Code) ASSERT_EQ(script.code(), code.get()); } +TEST_F(ScriptTest, HatPredicateCode) +{ + Script script(nullptr, nullptr, nullptr); + ASSERT_EQ(script.hatPredicateCode(), nullptr); + + auto code = std::make_shared(); + script.setHatPredicateCode(code); + ASSERT_EQ(script.hatPredicateCode(), code.get()); +} + TEST_F(ScriptTest, Start) { Script script1(nullptr, nullptr, nullptr); From 8a1cc326e84228fce5e45c16abfb3e8e2d84d395 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:12:42 +0100 Subject: [PATCH 5/9] Thread: Add hat predicate support --- include/scratchcpp/thread.h | 1 + src/engine/thread.cpp | 16 ++++++++++++++++ src/engine/thread_p.h | 2 ++ test/thread/thread_test.cpp | 25 +++++++++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/include/scratchcpp/thread.h b/include/scratchcpp/thread.h index bc0578d7..994de694 100644 --- a/include/scratchcpp/thread.h +++ b/include/scratchcpp/thread.h @@ -27,6 +27,7 @@ class LIBSCRATCHCPP_EXPORT Thread Script *script() const; void run(); + bool runPredicate(); void kill(); void reset(); diff --git a/src/engine/thread.cpp b/src/engine/thread.cpp index 0cb6b2ed..ddd96647 100644 --- a/src/engine/thread.cpp +++ b/src/engine/thread.cpp @@ -18,9 +18,13 @@ Thread::Thread(Target *target, IEngine *engine, Script *script) : if (impl->script) { impl->code = impl->script->code(); + impl->hatPredicateCode = impl->script->hatPredicateCode(); if (impl->code) impl->executionContext = impl->code->createExecutionContext(this); + + if (impl->hatPredicateCode) + impl->hatPredicateExecutionContext = impl->hatPredicateCode->createExecutionContext(this); } } @@ -56,6 +60,18 @@ void Thread::run() string_pool_set_thread(nullptr); } +/*! Runs the hat predicate and returns its return value. */ +bool Thread::runPredicate() +{ + if (!impl->hatPredicateCode) + return false; + + string_pool_set_thread(this); + const bool ret = impl->hatPredicateCode->runPredicate(impl->hatPredicateExecutionContext.get()); + string_pool_set_thread(nullptr); + return ret; +} + /*! Stops the script. */ void Thread::kill() { diff --git a/src/engine/thread_p.h b/src/engine/thread_p.h index cf340e2f..fd3e009e 100644 --- a/src/engine/thread_p.h +++ b/src/engine/thread_p.h @@ -21,7 +21,9 @@ struct ThreadPrivate IEngine *engine = nullptr; Script *script = nullptr; ExecutableCode *code = nullptr; + ExecutableCode *hatPredicateCode = nullptr; std::shared_ptr executionContext; + std::shared_ptr hatPredicateExecutionContext; }; } // namespace libscratchcpp diff --git a/test/thread/thread_test.cpp b/test/thread/thread_test.cpp index f44e5402..687bd76b 100644 --- a/test/thread/thread_test.cpp +++ b/test/thread/thread_test.cpp @@ -50,6 +50,31 @@ TEST_F(ThreadTest, Run) m_thread->run(); } +TEST_F(ThreadTest, RunPredicate) +{ + ASSERT_FALSE(m_thread->runPredicate()); + + auto predicateCode = std::make_shared(); + std::shared_ptr predicateCtx; + m_script->setHatPredicateCode(predicateCode); + EXPECT_CALL(*m_code, createExecutionContext(_)).WillOnce(Invoke([this](Thread *thread) { + m_ctx = std::make_shared(thread); + return m_ctx; + })); + EXPECT_CALL(*predicateCode, createExecutionContext(_)).WillOnce(Invoke([&predicateCtx](Thread *thread) { + predicateCtx = std::make_shared(thread); + return predicateCtx; + })); + + Thread thread(&m_target, &m_engine, m_script.get()); + + EXPECT_CALL(*predicateCode, runPredicate(predicateCtx.get())).WillOnce(Return(false)); + ASSERT_FALSE(thread.runPredicate()); + + EXPECT_CALL(*predicateCode, runPredicate(predicateCtx.get())).WillOnce(Return(true)); + ASSERT_TRUE(thread.runPredicate()); +} + TEST_F(ThreadTest, Kill) { EXPECT_CALL(*m_code, kill(m_ctx.get())); From c25145610e61c5cc3af2cf82a42323edc1a93151 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:46:33 +0100 Subject: [PATCH 6/9] Script: Implement runHatPredicate() --- src/engine/script.cpp | 6 ++---- test/script/script_test.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/engine/script.cpp b/src/engine/script.cpp index a56d33bb..44fa4177 100644 --- a/src/engine/script.cpp +++ b/src/engine/script.cpp @@ -63,10 +63,8 @@ bool Script::runHatPredicate(Target *target) if (!target || !impl->engine) return false; - // TODO: Implement this - // auto thread = std::make_shared(target, impl->engine, this); - - return false; + auto thread = std::make_shared(target, impl->engine, this); + return thread->runPredicate(); } /*! Starts the script (creates a thread). */ diff --git a/test/script/script_test.cpp b/test/script/script_test.cpp index 102e52b3..7fb6f380 100644 --- a/test/script/script_test.cpp +++ b/test/script/script_test.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,7 @@ using namespace libscratchcpp; using ::testing::Return; using ::testing::ReturnRef; +using ::testing::Invoke; using ::testing::_; class ScriptTest : public testing::Test @@ -51,6 +53,28 @@ TEST_F(ScriptTest, HatPredicateCode) ASSERT_EQ(script.hatPredicateCode(), code.get()); } +TEST_F(ScriptTest, RunHatPredicate) +{ + Script script(nullptr, nullptr, &m_engine); + auto code = std::make_shared(); + std::shared_ptr ctx; + script.setHatPredicateCode(code); + + EXPECT_CALL(*code, createExecutionContext(_)).WillRepeatedly(Invoke([&ctx](Thread *thread) { + ctx = std::make_shared(thread); + return ctx; + })); + + EXPECT_CALL(*code, runPredicate(_)).WillOnce(Return(true)); + ASSERT_TRUE(script.runHatPredicate(&m_target)); + + EXPECT_CALL(*code, runPredicate(_)).WillOnce(Return(true)); + ASSERT_TRUE(script.runHatPredicate(&m_target)); + + EXPECT_CALL(*code, runPredicate(_)).WillOnce(Return(false)); + ASSERT_FALSE(script.runHatPredicate(&m_target)); +} + TEST_F(ScriptTest, Start) { Script script1(nullptr, nullptr, nullptr); From ce7cd5d747cc93107601594fbb8d74215635a3b1 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Tue, 18 Feb 2025 20:47:12 +0100 Subject: [PATCH 7/9] Engine: Compile hat predicates --- src/engine/internal/engine.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/engine/internal/engine.cpp b/src/engine/internal/engine.cpp index 24a1a36c..cf0e831e 100644 --- a/src/engine/internal/engine.cpp +++ b/src/engine/internal/engine.cpp @@ -274,6 +274,9 @@ void Engine::compile() auto script = std::make_shared