From 4309e900cdbd4e001d6bb1de1443ef6f4f840ba2 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sun, 24 Nov 2024 11:53:25 +0100 Subject: [PATCH 1/9] Add execution context parameter to LLVM function --- src/dev/engine/internal/llvm/llvmcodebuilder.cpp | 11 ++++++----- .../engine/internal/llvm/llvmexecutablecode.cpp | 2 +- src/dev/engine/internal/llvm/llvmexecutablecode.h | 2 +- test/dev/llvm/llvmcodebuilder_test.cpp | 2 +- test/dev/llvm/llvmexecutablecode_test.cpp | 14 +++++++------- test/dev/llvm/testfunctions.cpp | 4 ++-- test/dev/llvm/testfunctions.h | 3 ++- test/dev/llvm/testmock.h | 2 +- 8 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp index cfb4b29e..45c0b3d1 100644 --- a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp @@ -58,13 +58,14 @@ std::shared_ptr LLVMCodeBuilder::finalize() m_builder.setFastMathFlags(fmf); // Create function - // void *f(Target *, ValueData **, List **) + // void *f(ExecutionContext *, Target *, ValueData **, List **) llvm::PointerType *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); - llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType }, false); + llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType, pointerType }, false); llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); - llvm::Value *targetPtr = func->getArg(0); - llvm::Value *targetVariables = func->getArg(1); - llvm::Value *targetLists = func->getArg(2); + llvm::Value *executionContextPtr = func->getArg(0); + llvm::Value *targetPtr = func->getArg(1); + llvm::Value *targetVariables = func->getArg(2); + llvm::Value *targetLists = func->getArg(3); llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); m_builder.SetInsertPoint(entry); diff --git a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp index eb8b8b03..e4f9d443 100644 --- a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp +++ b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp @@ -55,7 +55,7 @@ void LLVMExecutableCode::run(ExecutionContext *context) ctx->setFinished(done); } else { Target *target = ctx->target(); - void *handle = m_mainFunction(target, target->variableData(), target->listData()); + void *handle = m_mainFunction(context, target, target->variableData(), target->listData()); if (!handle) ctx->setFinished(true); diff --git a/src/dev/engine/internal/llvm/llvmexecutablecode.h b/src/dev/engine/internal/llvm/llvmexecutablecode.h index 26e412dd..63524156 100644 --- a/src/dev/engine/internal/llvm/llvmexecutablecode.h +++ b/src/dev/engine/internal/llvm/llvmexecutablecode.h @@ -33,7 +33,7 @@ class LLVMExecutableCode : public ExecutableCode private: uint64_t lookupFunction(const std::string &name); - using MainFunctionType = void *(*)(Target *, ValueData **, List **); + using MainFunctionType = void *(*)(ExecutionContext *, Target *, ValueData **, List **); using ResumeFunctionType = bool (*)(void *); static LLVMExecutionContext *getContext(ExecutionContext *context); diff --git a/test/dev/llvm/llvmcodebuilder_test.cpp b/test/dev/llvm/llvmcodebuilder_test.cpp index e9739cd1..b12f4dc0 100644 --- a/test/dev/llvm/llvmcodebuilder_test.cpp +++ b/test/dev/llvm/llvmcodebuilder_test.cpp @@ -52,7 +52,7 @@ class LLVMCodeBuilderTest : public testing::Test void SetUp() override { - test_function(nullptr, nullptr, nullptr, nullptr); // force dependency + test_function(nullptr, nullptr, nullptr, nullptr, nullptr); // force dependency } void createBuilder(Target *target, bool warp) { m_builder = std::make_unique(target, "test", warp); } diff --git a/test/dev/llvm/llvmexecutablecode_test.cpp b/test/dev/llvm/llvmexecutablecode_test.cpp index 84e86d15..4564d94e 100644 --- a/test/dev/llvm/llvmexecutablecode_test.cpp +++ b/test/dev/llvm/llvmexecutablecode_test.cpp @@ -20,7 +20,7 @@ class LLVMExecutableCodeTest : public testing::Test { m_module = std::make_unique("test", m_ctx); m_builder = std::make_unique>(m_ctx); - test_function(nullptr, nullptr, nullptr, nullptr); // force dependency + test_function(nullptr, nullptr, nullptr, nullptr, nullptr); // force dependency llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -31,9 +31,9 @@ class LLVMExecutableCodeTest : public testing::Test llvm::Function *beginMainFunction() { - // void *f(Target *, ValueData **, List **) + // void *f(ExecutionContext *, Target *, ValueData **, List **) llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); - llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType }, false); + llvm::FunctionType *funcType = llvm::FunctionType::get(pointerType, { pointerType, pointerType, pointerType, pointerType }, false); llvm::Function *func = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, "f", m_module.get()); llvm::BasicBlock *entry = llvm::BasicBlock::Create(m_ctx, "entry", func); @@ -57,12 +57,12 @@ class LLVMExecutableCodeTest : public testing::Test void addTestFunction(llvm::Function *mainFunc) { auto ptrType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0); - auto func = m_module->getOrInsertFunction("test_function", llvm::FunctionType::get(m_builder->getVoidTy(), { ptrType, ptrType, ptrType, ptrType }, false)); + auto func = m_module->getOrInsertFunction("test_function", llvm::FunctionType::get(m_builder->getVoidTy(), { ptrType, ptrType, ptrType, ptrType, ptrType }, false)); llvm::Constant *mockInt = llvm::ConstantInt::get(llvm::Type::getInt64Ty(m_ctx), (uintptr_t)&m_mock, false); llvm::Constant *mockPtr = llvm::ConstantExpr::getIntToPtr(mockInt, ptrType); - m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2) }); + m_builder->CreateCall(func, { mockPtr, mainFunc->getArg(0), mainFunc->getArg(1), mainFunc->getArg(2), mainFunc->getArg(3) }); } void addTestPrintFunction(llvm::Value *arg1, llvm::Value *arg2) @@ -110,7 +110,7 @@ TEST_F(LLVMExecutableCodeTest, MainFunction) auto ctx = code.createExecutionContext(&m_target); ASSERT_FALSE(code.isFinished(ctx.get())); - EXPECT_CALL(m_mock, f(&m_target, m_target.variableData(), m_target.listData())); + EXPECT_CALL(m_mock, f(ctx.get(), &m_target, m_target.variableData(), m_target.listData())); code.run(ctx.get()); ASSERT_TRUE(code.isFinished(ctx.get())); @@ -135,7 +135,7 @@ TEST_F(LLVMExecutableCodeTest, MainFunction) ASSERT_FALSE(code.isFinished(anotherCtx.get())); ASSERT_TRUE(code.isFinished(ctx.get())); - EXPECT_CALL(m_mock, f(&anotherTarget, anotherTarget.variableData(), anotherTarget.listData())); + EXPECT_CALL(m_mock, f(anotherCtx.get(), &anotherTarget, anotherTarget.variableData(), anotherTarget.listData())); code.run(anotherCtx.get()); ASSERT_TRUE(code.isFinished(anotherCtx.get())); ASSERT_TRUE(code.isFinished(ctx.get())); diff --git a/test/dev/llvm/testfunctions.cpp b/test/dev/llvm/testfunctions.cpp index 8555ee24..df7e3879 100644 --- a/test/dev/llvm/testfunctions.cpp +++ b/test/dev/llvm/testfunctions.cpp @@ -11,10 +11,10 @@ static int counter = 0; extern "C" { - void test_function(TestMock *mock, Target *target, ValueData **varData, List **listData) + void test_function(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData) { if (mock) - mock->f(target, varData, listData); + mock->f(ctx, target, varData, listData); } void test_print_function(ValueData *arg1, ValueData *arg2) diff --git a/test/dev/llvm/testfunctions.h b/test/dev/llvm/testfunctions.h index c20c8efa..dd864328 100644 --- a/test/dev/llvm/testfunctions.h +++ b/test/dev/llvm/testfunctions.h @@ -7,10 +7,11 @@ class TestMock; class Target; class ValueData; class List; +class ExecutionContext; extern "C" { - void test_function(TestMock *mock, Target *target, ValueData **varData, List **listData); + void test_function(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData); void test_print_function(ValueData *arg1, ValueData *arg2); void test_function_no_args(Target *target); diff --git a/test/dev/llvm/testmock.h b/test/dev/llvm/testmock.h index 176eeff2..5573f247 100644 --- a/test/dev/llvm/testmock.h +++ b/test/dev/llvm/testmock.h @@ -12,7 +12,7 @@ class List; class TestMock { public: - MOCK_METHOD(void, f, (Target *, ValueData **, List **)); + MOCK_METHOD(void, f, (ExecutionContext * ctx, Target *, ValueData **, List **)); }; } // namespace libscratchcpp From beb0f98f34379e3d9683b3645aa53f479f92c790 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sun, 24 Nov 2024 12:22:04 +0100 Subject: [PATCH 2/9] LLVMCodeBuilder: Add more function call parameter options --- src/dev/engine/internal/icodebuilder.h | 2 + .../engine/internal/llvm/llvmcodebuilder.cpp | 26 +++- .../engine/internal/llvm/llvmcodebuilder.h | 2 + .../engine/internal/llvm/llvminstruction.h | 2 + test/dev/llvm/llvmcodebuilder_test.cpp | 129 ++++++++++-------- test/dev/llvm/testfunctions.cpp | 37 +++-- test/dev/llvm/testfunctions.h | 29 ++-- test/mocks/codebuildermock.h | 2 + 8 files changed, 144 insertions(+), 85 deletions(-) diff --git a/src/dev/engine/internal/icodebuilder.h b/src/dev/engine/internal/icodebuilder.h index 8e47ea40..5132f1a9 100644 --- a/src/dev/engine/internal/icodebuilder.h +++ b/src/dev/engine/internal/icodebuilder.h @@ -20,6 +20,8 @@ class ICodeBuilder virtual std::shared_ptr finalize() = 0; virtual CompilerValue *addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) = 0; + virtual CompilerValue *addTargetFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) = 0; + virtual CompilerValue *addFunctionCallWithCtx(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) = 0; virtual CompilerConstant *addConstValue(const Value &value) = 0; virtual CompilerValue *addVariableValue(Variable *variable) = 0; virtual CompilerValue *addListContents(List *list) = 0; diff --git a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp index 45c0b3d1..da31cea5 100644 --- a/src/dev/engine/internal/llvm/llvmcodebuilder.cpp +++ b/src/dev/engine/internal/llvm/llvmcodebuilder.cpp @@ -118,9 +118,17 @@ std::shared_ptr LLVMCodeBuilder::finalize() std::vector types; std::vector args; + // Add execution context arg + if (step.functionCtxArg) { + types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0)); + args.push_back(executionContextPtr); + } + // Add target pointer arg - types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0)); - args.push_back(targetPtr); + if (step.functionTargetArg) { + types.push_back(llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0)); + args.push_back(targetPtr); + } // Args for (auto &arg : step.args) { @@ -951,6 +959,20 @@ CompilerValue *LLVMCodeBuilder::addFunctionCall(const std::string &functionName, return nullptr; } +CompilerValue *LLVMCodeBuilder::addTargetFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) +{ + CompilerValue *ret = addFunctionCall(functionName, returnType, argTypes, args); + m_instructions.back().functionTargetArg = true; + return ret; +} + +CompilerValue *LLVMCodeBuilder::addFunctionCallWithCtx(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) +{ + CompilerValue *ret = addFunctionCall(functionName, returnType, argTypes, args); + m_instructions.back().functionCtxArg = true; + return ret; +} + CompilerConstant *LLVMCodeBuilder::addConstValue(const Value &value) { auto constReg = std::make_shared(TYPE_MAP[value.type()], value); diff --git a/src/dev/engine/internal/llvm/llvmcodebuilder.h b/src/dev/engine/internal/llvm/llvmcodebuilder.h index 7933dc52..1978ef30 100644 --- a/src/dev/engine/internal/llvm/llvmcodebuilder.h +++ b/src/dev/engine/internal/llvm/llvmcodebuilder.h @@ -27,6 +27,8 @@ class LLVMCodeBuilder : public ICodeBuilder std::shared_ptr finalize() override; CompilerValue *addFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) override; + CompilerValue *addTargetFunctionCall(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) override; + CompilerValue *addFunctionCallWithCtx(const std::string &functionName, Compiler::StaticType returnType, const Compiler::ArgTypes &argTypes, const Compiler::Args &args) override; CompilerConstant *addConstValue(const Value &value) override; CompilerValue *addVariableValue(Variable *variable) override; CompilerValue *addListContents(List *list) override; diff --git a/src/dev/engine/internal/llvm/llvminstruction.h b/src/dev/engine/internal/llvm/llvminstruction.h index e661cff5..b770008e 100644 --- a/src/dev/engine/internal/llvm/llvminstruction.h +++ b/src/dev/engine/internal/llvm/llvminstruction.h @@ -73,6 +73,8 @@ struct LLVMInstruction std::string functionName; std::vector> args; // target type, register LLVMRegister *functionReturnReg = nullptr; + bool functionTargetArg = false; // whether to add target ptr to function parameters + bool functionCtxArg = false; // whether to add execution context ptr to function parameters Variable *workVariable = nullptr; // for variables List *workList = nullptr; // for lists }; diff --git a/test/dev/llvm/llvmcodebuilder_test.cpp b/test/dev/llvm/llvmcodebuilder_test.cpp index b12f4dc0..e624fb9c 100644 --- a/test/dev/llvm/llvmcodebuilder_test.cpp +++ b/test/dev/llvm/llvmcodebuilder_test.cpp @@ -314,30 +314,44 @@ TEST_F(LLVMCodeBuilderTest, FunctionCalls) for (bool warp : warpList) { createBuilder(warp); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); - CompilerValue *v = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addFunctionCall("test_empty_function", Compiler::StaticType::Void, {}, {}); + + CompilerValue *v = m_builder->addConstValue("test"); + m_builder->addFunctionCallWithCtx("test_ctx_function", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + + v = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); v = m_builder->addConstValue("1"); - v = m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }, { v }); + v = m_builder->addTargetFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }, { v }); CompilerValue *v1 = m_builder->addConstValue("2"); CompilerValue *v2 = m_builder->addConstValue("3"); - m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); + m_builder + ->addTargetFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); v = m_builder->addConstValue("test"); v1 = m_builder->addConstValue("4"); v2 = m_builder->addConstValue("5"); - v = m_builder->addFunctionCall( + v = m_builder->addTargetFunctionCall( "test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); auto code = m_builder->finalize(); auto ctx = code->createExecutionContext(&m_target); - static const std::string expected = + std::stringstream s; + s << ctx.get(); + std::string ctxPtr = s.str(); + + const std::string expected = + "empty\n" + ctxPtr + + "\n" + "test\n" "no_args\n" "no_args_ret\n" "1_arg no_args_output\n" @@ -2148,28 +2162,29 @@ TEST_F(LLVMCodeBuilderTest, ListContainsItem) TEST_F(LLVMCodeBuilderTest, Yield) { auto build = [this]() { - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); - CompilerValue *v = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + CompilerValue *v = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->yield(); v = m_builder->addConstValue("1"); - v = m_builder->addFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }, { v }); + v = m_builder->addTargetFunctionCall("test_function_1_arg_ret", Compiler::StaticType::String, { Compiler::StaticType::String }, { v }); CompilerValue *v1 = m_builder->addConstValue("2"); CompilerValue *v2 = m_builder->addConstValue(3); - m_builder->addFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); + m_builder + ->addTargetFunctionCall("test_function_3_args", Compiler::StaticType::Void, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); v = m_builder->addConstValue("test"); v1 = m_builder->addConstValue("4"); v2 = m_builder->addConstValue("5"); - v = m_builder->addFunctionCall( + v = m_builder->addTargetFunctionCall( "test_function_3_args_ret", Compiler::StaticType::String, { Compiler::StaticType::String, Compiler::StaticType::String, Compiler::StaticType::String }, { v, v1, v2 }); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); }; // Without warp @@ -2367,71 +2382,71 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) // Without else branch (const condition) CompilerValue *v = m_builder->addConstValue("true"); m_builder->beginIfStatement(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endIf(); v = m_builder->addConstValue("false"); m_builder->beginIfStatement(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endIf(); // Without else branch (condition returned by function) - CompilerValue *v1 = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + CompilerValue *v1 = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); CompilerValue *v2 = m_builder->addConstValue("no_args_output"); v = m_builder->addFunctionCall("test_equals", Compiler::StaticType::Bool, { Compiler::StaticType::String, Compiler::StaticType::String }, { v1, v2 }); m_builder->beginIfStatement(v); v = m_builder->addConstValue(0); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); - v1 = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + v1 = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); v2 = m_builder->addConstValue(""); v = m_builder->addFunctionCall("test_equals", Compiler::StaticType::Bool, { Compiler::StaticType::String, Compiler::StaticType::String }, { v1, v2 }); m_builder->beginIfStatement(v); v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); // With else branch (const condition) v = m_builder->addConstValue("true"); m_builder->beginIfStatement(v); v = m_builder->addConstValue(2); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginElseBranch(); v = m_builder->addConstValue(3); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); v = m_builder->addConstValue("false"); m_builder->beginIfStatement(v); v = m_builder->addConstValue(4); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginElseBranch(); v = m_builder->addConstValue(5); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); // With else branch (condition returned by function) - v1 = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + v1 = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); v2 = m_builder->addConstValue("no_args_output"); v = m_builder->addFunctionCall("test_equals", Compiler::StaticType::Bool, { Compiler::StaticType::String, Compiler::StaticType::String }, { v1, v2 }); m_builder->beginIfStatement(v); v = m_builder->addConstValue(6); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginElseBranch(); v = m_builder->addConstValue(7); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); - v1 = m_builder->addFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); + v1 = m_builder->addTargetFunctionCall("test_function_no_args_ret", Compiler::StaticType::String, {}, {}); v2 = m_builder->addConstValue(""); v = m_builder->addFunctionCall("test_equals", Compiler::StaticType::Bool, { Compiler::StaticType::String, Compiler::StaticType::String }, { v1, v2 }); m_builder->beginIfStatement(v); v = m_builder->addConstValue(8); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginElseBranch(); v = m_builder->addConstValue(9); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endIf(); // Nested 1 @@ -2442,19 +2457,19 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) m_builder->beginIfStatement(v); { v = m_builder->addConstValue(0); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->beginElseBranch(); { v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); v = m_builder->addConstValue(false); m_builder->beginIfStatement(v); m_builder->beginElseBranch(); { v = m_builder->addConstValue(2); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->endIf(); } @@ -2466,12 +2481,12 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) m_builder->beginIfStatement(v); { v = m_builder->addConstValue(3); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->beginElseBranch(); { v = m_builder->addConstValue(4); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->endIf(); } @@ -2485,12 +2500,12 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) m_builder->beginIfStatement(v); { v = m_builder->addConstValue(5); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->beginElseBranch(); { v = m_builder->addConstValue(6); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->endIf(); } @@ -2500,7 +2515,7 @@ TEST_F(LLVMCodeBuilderTest, IfStatement) m_builder->beginIfStatement(v); { v = m_builder->addConstValue(7); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->beginElseBranch(); m_builder->endIf(); @@ -2769,36 +2784,36 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop) // Const count CompilerValue *v = m_builder->addConstValue("-5"); m_builder->beginRepeatLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); v = m_builder->addConstValue(0); m_builder->beginRepeatLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); v = m_builder->addConstValue(3); m_builder->beginRepeatLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); v = m_builder->addConstValue("2.4"); m_builder->beginRepeatLoop(v); v = m_builder->addConstValue(0); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endLoop(); v = m_builder->addConstValue("2.5"); m_builder->beginRepeatLoop(v); v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->endLoop(); // Count returned by function v = m_builder->addConstValue(2); v = callConstFuncForType(ValueType::Number, v); m_builder->beginRepeatLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); // Nested @@ -2809,18 +2824,18 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop) m_builder->beginRepeatLoop(v); { v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->endLoop(); v = m_builder->addConstValue(2); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); v = m_builder->addConstValue(3); m_builder->beginRepeatLoop(v); { v = m_builder->addConstValue(3); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); } m_builder->endLoop(); } @@ -2863,7 +2878,7 @@ TEST_F(LLVMCodeBuilderTest, RepeatLoop) v = m_builder->addConstValue(3); m_builder->beginRepeatLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); code = m_builder->finalize(); @@ -2923,7 +2938,7 @@ TEST_F(LLVMCodeBuilderTest, WhileLoop) v = m_builder->addFunctionCall("test_lower_than", Compiler::StaticType::Bool, { Compiler::StaticType::Number, Compiler::StaticType::Number }, { v1, v2 }); m_builder->beginWhileLoop(v); v = m_builder->addConstValue(0); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->addFunctionCall("test_increment_counter", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); @@ -2942,13 +2957,13 @@ TEST_F(LLVMCodeBuilderTest, WhileLoop) m_builder->beginWhileLoop(v); { v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->addFunctionCall("test_increment_counter", Compiler::StaticType::Void, {}, {}); } m_builder->endLoop(); v = m_builder->addConstValue(2); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginLoopCondition(); v = m_builder->addConstValue(false); @@ -2982,7 +2997,7 @@ TEST_F(LLVMCodeBuilderTest, WhileLoop) m_builder->beginLoopCondition(); v = m_builder->addConstValue(true); m_builder->beginWhileLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); code = m_builder->finalize(); @@ -3026,7 +3041,7 @@ TEST_F(LLVMCodeBuilderTest, RepeatUntilLoop) v = m_builder->addFunctionCall("test_not", Compiler::StaticType::Bool, { Compiler::StaticType::Bool }, { v }); m_builder->beginRepeatUntilLoop(v); v = m_builder->addConstValue(0); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->addFunctionCall("test_increment_counter", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); @@ -3047,13 +3062,13 @@ TEST_F(LLVMCodeBuilderTest, RepeatUntilLoop) m_builder->beginRepeatUntilLoop(v); { v = m_builder->addConstValue(1); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->addFunctionCall("test_increment_counter", Compiler::StaticType::Void, {}, {}); } m_builder->endLoop(); v = m_builder->addConstValue(2); - m_builder->addFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); + m_builder->addTargetFunctionCall("test_function_1_arg", Compiler::StaticType::Void, { Compiler::StaticType::String }, { v }); m_builder->beginLoopCondition(); v = m_builder->addConstValue(true); @@ -3087,7 +3102,7 @@ TEST_F(LLVMCodeBuilderTest, RepeatUntilLoop) m_builder->beginLoopCondition(); v = m_builder->addConstValue(false); m_builder->beginRepeatUntilLoop(v); - m_builder->addFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); + m_builder->addTargetFunctionCall("test_function_no_args", Compiler::StaticType::Void, {}, {}); m_builder->endLoop(); code = m_builder->finalize(); diff --git a/test/dev/llvm/testfunctions.cpp b/test/dev/llvm/testfunctions.cpp index df7e3879..0a78634e 100644 --- a/test/dev/llvm/testfunctions.cpp +++ b/test/dev/llvm/testfunctions.cpp @@ -25,6 +25,17 @@ extern "C" std::cout << s1 << " " << s2 << std::endl; } + void test_empty_function() + { + std::cout << "empty" << std::endl; + } + + void test_ctx_function(ExecutionContext *ctx, const char *arg) + { + std::cout << ctx << std::endl; + std::cout << arg << std::endl; + } + void test_function_no_args(Target *target) { target->isStage(); @@ -67,69 +78,69 @@ extern "C" return value_toCString(&v.data()); } - bool test_equals(Target *target, const char *a, const char *b) + bool test_equals(const char *a, const char *b) { return strcmp(a, b) == 0; } - void test_unreachable(Target *target) + void test_unreachable() { std::cout << "error: unreachable reached" << std::endl; exit(1); } - bool test_lower_than(Target *target, double a, double b) + bool test_lower_than(double a, double b) { return a < b; } - double test_const_number(Target *target, double v) + double test_const_number(double v) { return v; } - bool test_const_bool(Target *target, bool v) + bool test_const_bool(bool v) { return v; } - char *test_const_string(Target *target, const char *v) + char *test_const_string(const char *v) { Value value(v); return value_toCString(&value.data()); } - bool test_not(Target *target, bool arg) + bool test_not(bool arg) { return !arg; } - void test_reset_counter(Target *target) + void test_reset_counter() { counter = 0; } - void test_increment_counter(Target *target) + void test_increment_counter() { counter++; } - double test_get_counter(Target *target) + double test_get_counter() { return counter; } - void test_print_number(Target *target, double v) + void test_print_number(double v) { std::cout << v << std::endl; } - void test_print_bool(Target *target, bool v) + void test_print_bool(bool v) { std::cout << v << std::endl; } - void test_print_string(Target *target, const char *v) + void test_print_string(const char *v) { std::cout << v << std::endl; } diff --git a/test/dev/llvm/testfunctions.h b/test/dev/llvm/testfunctions.h index dd864328..49693f6b 100644 --- a/test/dev/llvm/testfunctions.h +++ b/test/dev/llvm/testfunctions.h @@ -14,6 +14,9 @@ extern "C" void test_function(TestMock *mock, ExecutionContext *ctx, Target *target, ValueData **varData, List **listData); void test_print_function(ValueData *arg1, ValueData *arg2); + void test_empty_function(); + void test_ctx_function(ExecutionContext *ctx, const char *arg); + void test_function_no_args(Target *target); char *test_function_no_args_ret(Target *target); void test_function_1_arg(Target *target, const char *arg1); @@ -21,22 +24,22 @@ extern "C" void test_function_3_args(Target *target, const char *arg1, const char *arg2, const char *arg3); char *test_function_3_args_ret(Target *target, const char *arg1, const char *arg2, const char *arg3); - bool test_equals(Target *target, const char *a, const char *b); - bool test_lower_than(Target *target, double a, double b); - bool test_not(Target *target, bool arg); - double test_const_number(Target *target, double v); - bool test_const_bool(Target *target, bool v); - char *test_const_string(Target *target, const char *v); + bool test_equals(const char *a, const char *b); + bool test_lower_than(double a, double b); + bool test_not(bool arg); + double test_const_number(double v); + bool test_const_bool(bool v); + char *test_const_string(const char *v); - void test_unreachable(Target *target); + void test_unreachable(); - void test_reset_counter(Target *target); - void test_increment_counter(Target *target); - double test_get_counter(Target *target); + void test_reset_counter(); + void test_increment_counter(); + double test_get_counter(); - void test_print_number(Target *target, double v); - void test_print_bool(Target *target, bool v); - void test_print_string(Target *target, const char *v); + void test_print_number(double v); + void test_print_bool(bool v); + void test_print_string(const char *v); } } // namespace libscratchcpp diff --git a/test/mocks/codebuildermock.h b/test/mocks/codebuildermock.h index 3097812c..cf1738ea 100644 --- a/test/mocks/codebuildermock.h +++ b/test/mocks/codebuildermock.h @@ -10,6 +10,8 @@ class CodeBuilderMock : public ICodeBuilder public: MOCK_METHOD(std::shared_ptr, finalize, (), (override)); MOCK_METHOD(CompilerValue *, addFunctionCall, (const std::string &, Compiler::StaticType, const Compiler::ArgTypes &, const Compiler::Args &), (override)); + MOCK_METHOD(CompilerValue *, addTargetFunctionCall, (const std::string &, Compiler::StaticType, const Compiler::ArgTypes &, const Compiler::Args &), (override)); + MOCK_METHOD(CompilerValue *, addFunctionCallWithCtx, (const std::string &, Compiler::StaticType, const Compiler::ArgTypes &, const Compiler::Args &), (override)); MOCK_METHOD(CompilerConstant *, addConstValue, (const Value &), (override)); MOCK_METHOD(CompilerValue *, addVariableValue, (Variable *), (override)); MOCK_METHOD(CompilerValue *, addListContents, (List *), (override)); From 19711234d04183817cf3e0f1d587f81d8d9b9d02 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sun, 24 Nov 2024 12:26:11 +0100 Subject: [PATCH 3/9] Compiler: Add target and ctx function call methods --- include/scratchcpp/dev/compiler.h | 2 ++ src/dev/engine/compiler.cpp | 22 ++++++++++++- test/dev/compiler/compiler_test.cpp | 50 +++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/include/scratchcpp/dev/compiler.h b/include/scratchcpp/dev/compiler.h index 178d3b9a..06cd716c 100644 --- a/include/scratchcpp/dev/compiler.h +++ b/include/scratchcpp/dev/compiler.h @@ -48,6 +48,8 @@ class LIBSCRATCHCPP_EXPORT Compiler std::shared_ptr compile(std::shared_ptr startBlock); CompilerValue *addFunctionCall(const std::string &functionName, StaticType returnType = StaticType::Void, const ArgTypes &argTypes = {}, const Args &args = {}); + CompilerValue *addTargetFunctionCall(const std::string &functionName, StaticType returnType = StaticType::Void, const ArgTypes &argTypes = {}, const Args &args = {}); + CompilerValue *addFunctionCallWithCtx(const std::string &functionName, StaticType returnType = StaticType::Void, const ArgTypes &argTypes = {}, const Args &args = {}); CompilerConstant *addConstValue(const Value &value); CompilerValue *addVariableValue(Variable *variable); CompilerValue *addListContents(List *list); diff --git a/src/dev/engine/compiler.cpp b/src/dev/engine/compiler.cpp index 2b41a71c..10e68e7d 100644 --- a/src/dev/engine/compiler.cpp +++ b/src/dev/engine/compiler.cpp @@ -76,7 +76,7 @@ std::shared_ptr Compiler::compile(std::shared_ptr startBl /*! * Adds a call to the given function.\n - * For example: extern "C" bool some_block(Target *target, double arg1, const char *arg2) + * For example: extern "C" bool some_block(double arg1, const char *arg2) */ CompilerValue *Compiler::addFunctionCall(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) { @@ -84,6 +84,26 @@ CompilerValue *Compiler::addFunctionCall(const std::string &functionName, Static return impl->builder->addFunctionCall(functionName, returnType, argTypes, args); } +/*! + * Adds a call to the given function with a target parameter.\n + * For example: extern "C" bool some_block(Target *target, double arg1, const char *arg2) + */ +CompilerValue *Compiler::addTargetFunctionCall(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) +{ + assert(argTypes.size() == args.size()); + return impl->builder->addTargetFunctionCall(functionName, returnType, argTypes, args); +} + +/*! + * Adds a call to the given function with an execution context parameter.\n + * For example: extern "C" bool some_block(ExecutionContext *ctx, double arg1, const char *arg2) + */ +CompilerValue *Compiler::addFunctionCallWithCtx(const std::string &functionName, StaticType returnType, const ArgTypes &argTypes, const Args &args) +{ + assert(argTypes.size() == args.size()); + return impl->builder->addFunctionCallWithCtx(functionName, returnType, argTypes, args); +} + /*! Adds the given constant to the compiled code. */ CompilerConstant *Compiler::addConstValue(const Value &value) { diff --git a/test/dev/compiler/compiler_test.cpp b/test/dev/compiler/compiler_test.cpp index 62b52baa..8e87191b 100644 --- a/test/dev/compiler/compiler_test.cpp +++ b/test/dev/compiler/compiler_test.cpp @@ -89,6 +89,56 @@ TEST_F(CompilerTest, AddFunctionCall) compile(compiler, block); } +TEST_F(CompilerTest, AddTargetFunctionCall) +{ + Compiler compiler(&m_engine, &m_target); + auto block = std::make_shared("a", ""); + m_compareBlock = block; + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + EXPECT_EQ(compiler->block(), m_compareBlock); + CompilerValue arg1(Compiler::StaticType::Unknown); + CompilerValue arg2(Compiler::StaticType::Unknown); + Compiler::ArgTypes argTypes = { Compiler::StaticType::Number, Compiler::StaticType::Bool }; + Compiler::Args args = { &arg1, &arg2 }; + EXPECT_CALL(*m_builder, addTargetFunctionCall("test1", Compiler::StaticType::Void, argTypes, args)); + compiler->addTargetFunctionCall("test1", Compiler::StaticType::Void, argTypes, args); + + args = { &arg1 }; + argTypes = { Compiler::StaticType::String }; + EXPECT_CALL(*m_builder, addTargetFunctionCall("test2", Compiler::StaticType::Bool, argTypes, args)); + compiler->addTargetFunctionCall("test2", Compiler::StaticType::Bool, argTypes, args); + + return nullptr; + }); + + compile(compiler, block); +} + +TEST_F(CompilerTest, AddFunctionCallWithCtx) +{ + Compiler compiler(&m_engine, &m_target); + auto block = std::make_shared("a", ""); + m_compareBlock = block; + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + EXPECT_EQ(compiler->block(), m_compareBlock); + CompilerValue arg1(Compiler::StaticType::Unknown); + CompilerValue arg2(Compiler::StaticType::Unknown); + Compiler::ArgTypes argTypes = { Compiler::StaticType::Number, Compiler::StaticType::Bool }; + Compiler::Args args = { &arg1, &arg2 }; + EXPECT_CALL(*m_builder, addFunctionCallWithCtx("test1", Compiler::StaticType::Void, argTypes, args)); + compiler->addFunctionCallWithCtx("test1", Compiler::StaticType::Void, argTypes, args); + + args = { &arg1 }; + argTypes = { Compiler::StaticType::String }; + EXPECT_CALL(*m_builder, addFunctionCallWithCtx("test2", Compiler::StaticType::Bool, argTypes, args)); + compiler->addFunctionCallWithCtx("test2", Compiler::StaticType::Bool, argTypes, args); + + return nullptr; + }); + + compile(compiler, block); +} + TEST_F(CompilerTest, AddConstValue) { Compiler compiler(&m_engine, &m_target); From 6ea564431c529af64d5223dbd0acd29a420f3d8e Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Mon, 25 Nov 2024 20:24:06 +0100 Subject: [PATCH 4/9] Compiler: Add API for custom while/until loops --- include/scratchcpp/dev/compiler.h | 6 ++- src/dev/engine/compiler.cpp | 50 ++++++++++++++++++++--- src/dev/engine/compiler_p.h | 1 + test/dev/compiler/compiler_test.cpp | 62 +++++++++++++++++++++++------ 4 files changed, 99 insertions(+), 20 deletions(-) diff --git a/include/scratchcpp/dev/compiler.h b/include/scratchcpp/dev/compiler.h index 06cd716c..1c1d441e 100644 --- a/include/scratchcpp/dev/compiler.h +++ b/include/scratchcpp/dev/compiler.h @@ -103,12 +103,16 @@ class LIBSCRATCHCPP_EXPORT Compiler void beginElseBranch(); void endIf(); + void beginWhileLoop(CompilerValue *cond); + void beginRepeatUntilLoop(CompilerValue *cond); + void beginLoopCondition(); + void endLoop(); + void moveToIf(CompilerValue *cond, std::shared_ptr substack); void moveToIfElse(CompilerValue *cond, std::shared_ptr substack1, std::shared_ptr substack2); void moveToRepeatLoop(CompilerValue *count, std::shared_ptr substack); void moveToWhileLoop(CompilerValue *cond, std::shared_ptr substack); void moveToRepeatUntilLoop(CompilerValue *cond, std::shared_ptr substack); - void beginLoopCondition(); void warp(); Input *input(const std::string &name) const; diff --git a/src/dev/engine/compiler.cpp b/src/dev/engine/compiler.cpp index 10e68e7d..db4c0b54 100644 --- a/src/dev/engine/compiler.cpp +++ b/src/dev/engine/compiler.cpp @@ -54,6 +54,11 @@ std::shared_ptr Compiler::compile(std::shared_ptr startBl std::cerr << "error: if statement created by block '" << impl->block->opcode() << "' not terminated" << std::endl; assert(false); } + + if (impl->customLoopCount > 0) { + std::cerr << "error: loop created by block '" << impl->block->opcode() << "' not terminated" << std::endl; + assert(false); + } } else { std::cout << "warning: unsupported block: " << impl->block->opcode() << std::endl; impl->unsupportedBlocks.insert(impl->block->opcode()); @@ -382,6 +387,45 @@ void Compiler::endIf() impl->customIfStatementCount--; } +/*! + * Begins a custom while loop. + * \note The loop must be terminated with endLoop() after compiling your block. + */ +void Compiler::beginWhileLoop(CompilerValue *cond) +{ + impl->builder->beginWhileLoop(cond); + impl->customLoopCount++; +} + +/*! + * Begins a custom repeat until loop. + * \note The loop must be terminated with endLoop() after compiling your block. + */ +void Compiler::beginRepeatUntilLoop(CompilerValue *cond) +{ + impl->builder->beginRepeatUntilLoop(cond); + impl->customLoopCount++; +} + +/*! Begins a while/until loop condition. */ +void Compiler::beginLoopCondition() +{ + impl->builder->beginLoopCondition(); +} + +/*! Ends custom loop. */ +void Compiler::endLoop() +{ + if (impl->customLoopCount == 0) { + std::cerr << "error: called Compiler::endLoop() without a loop"; + assert(false); + return; + } + + impl->builder->endLoop(); + impl->customLoopCount--; +} + /*! Jumps to the given if substack. */ void Compiler::moveToIf(CompilerValue *cond, std::shared_ptr substack) { @@ -445,12 +489,6 @@ void Compiler::moveToRepeatUntilLoop(CompilerValue *cond, std::shared_ptr impl->substackEnd(); } -/*! Begins a while/until loop condition. */ -void Compiler::beginLoopCondition() -{ - impl->builder->beginLoopCondition(); -} - /*! Makes current script run without screen refresh. */ void Compiler::warp() { diff --git a/src/dev/engine/compiler_p.h b/src/dev/engine/compiler_p.h index 080afc96..740d3be2 100644 --- a/src/dev/engine/compiler_p.h +++ b/src/dev/engine/compiler_p.h @@ -33,6 +33,7 @@ struct CompilerPrivate std::shared_ptr block; int customIfStatementCount = 0; + int customLoopCount = 0; std::vector, std::shared_ptr>, SubstackType>> substackTree; bool substackHit = false; bool warp = false; diff --git a/test/dev/compiler/compiler_test.cpp b/test/dev/compiler/compiler_test.cpp index 8e87191b..7e7ed85f 100644 --- a/test/dev/compiler/compiler_test.cpp +++ b/test/dev/compiler/compiler_test.cpp @@ -892,6 +892,55 @@ TEST_F(CompilerTest, CustomIfStatement) compile(compiler, block); } +TEST_F(CompilerTest, CustomWhileLoop) +{ + Compiler compiler(&m_engine, &m_target); + auto block = std::make_shared("", ""); + + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + CompilerValue arg(Compiler::StaticType::Unknown); + EXPECT_CALL(*m_builder, beginWhileLoop(&arg)); + compiler->beginWhileLoop(&arg); + EXPECT_CALL(*m_builder, endLoop()); + compiler->endLoop(); + + return nullptr; + }); + + compile(compiler, block); +} + +TEST_F(CompilerTest, CustomRepeatUntilLoop) +{ + Compiler compiler(&m_engine, &m_target); + auto block = std::make_shared("", ""); + + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + CompilerValue arg(Compiler::StaticType::Unknown); + EXPECT_CALL(*m_builder, beginRepeatUntilLoop(&arg)); + compiler->beginRepeatUntilLoop(&arg); + EXPECT_CALL(*m_builder, endLoop()); + compiler->endLoop(); + + return nullptr; + }); + + compile(compiler, block); +} + +TEST_F(CompilerTest, BeginLoopCondition) +{ + Compiler compiler(&m_engine, &m_target); + auto block = std::make_shared("a", ""); + block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { + EXPECT_CALL(*m_builder, beginLoopCondition()); + compiler->beginLoopCondition(); + return nullptr; + }); + + compile(compiler, block); +} + TEST_F(CompilerTest, MoveToIf) { Compiler compiler(&m_engine, &m_target); @@ -1369,19 +1418,6 @@ TEST_F(CompilerTest, MoveToRepeatUntilLoop) compile(compiler, l1); } -TEST_F(CompilerTest, BeginLoopCondition) -{ - Compiler compiler(&m_engine, &m_target); - auto block = std::make_shared("a", ""); - block->setCompileFunction([](Compiler *compiler) -> CompilerValue * { - EXPECT_CALL(*m_builder, beginLoopCondition()); - compiler->beginLoopCondition(); - return nullptr; - }); - - compile(compiler, block); -} - TEST_F(CompilerTest, Input) { Compiler compiler(&m_engine, &m_target); From 826935f0cff1bb066685a7c45fd32fe560359898 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:25:44 +0100 Subject: [PATCH 5/9] Add Promise class --- CMakeLists.txt | 1 + include/scratchcpp/dev/promise.h | 27 +++++++++++++++++++++++++++ src/dev/engine/CMakeLists.txt | 3 +++ src/dev/engine/promise.cpp | 25 +++++++++++++++++++++++++ src/dev/engine/promise_p.cpp | 5 +++++ src/dev/engine/promise_p.h | 13 +++++++++++++ test/dev/CMakeLists.txt | 1 + test/dev/promise/CMakeLists.txt | 12 ++++++++++++ test/dev/promise/promise_test.cpp | 14 ++++++++++++++ 9 files changed, 101 insertions(+) create mode 100644 include/scratchcpp/dev/promise.h create mode 100644 src/dev/engine/promise.cpp create mode 100644 src/dev/engine/promise_p.cpp create mode 100644 src/dev/engine/promise_p.h create mode 100644 test/dev/promise/CMakeLists.txt create mode 100644 test/dev/promise/promise_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 138f5540..b7723d88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,6 +76,7 @@ if (LIBSCRATCHCPP_USE_LLVM) include/scratchcpp/dev/compilerconstant.h include/scratchcpp/dev/executablecode.h include/scratchcpp/dev/executioncontext.h + include/scratchcpp/dev/promise.h ) if(LIBSCRATCHCPP_PRINT_LLVM_IR) diff --git a/include/scratchcpp/dev/promise.h b/include/scratchcpp/dev/promise.h new file mode 100644 index 00000000..2c0b6010 --- /dev/null +++ b/include/scratchcpp/dev/promise.h @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../global.h" +#include "../spimpl.h" + +namespace libscratchcpp +{ + +class PromisePrivate; + +/*! \brief The Promise class represents the eventual completion of an asynchronous operation. */ +class LIBSCRATCHCPP_EXPORT Promise +{ + public: + Promise(); + Promise(const Promise &) = delete; + + bool isResolved() const; + void resolve(); + + private: + spimpl::unique_impl_ptr impl; +}; + +} // namespace libscratchcpp diff --git a/src/dev/engine/CMakeLists.txt b/src/dev/engine/CMakeLists.txt index a4933a98..a3e2790e 100644 --- a/src/dev/engine/CMakeLists.txt +++ b/src/dev/engine/CMakeLists.txt @@ -12,6 +12,9 @@ target_sources(scratchcpp executioncontext.cpp executioncontext_p.cpp executioncontext_p.h + promise.cpp + promise_p.cpp + promise_p.h internal/icodebuilder.h internal/icodebuilderfactory.h internal/codebuilderfactory.cpp diff --git a/src/dev/engine/promise.cpp b/src/dev/engine/promise.cpp new file mode 100644 index 00000000..4f6ed6ae --- /dev/null +++ b/src/dev/engine/promise.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "promise_p.h" + +using namespace libscratchcpp; + +/*! Constructs Promise. */ +Promise::Promise() : + impl(spimpl::make_unique_impl()) +{ +} + +/*! Returns true if the promise is resolved. */ +bool Promise::isResolved() const +{ + return impl->isResolved; +} + +/*! Marks the promise as resolved. */ +void Promise::resolve() +{ + impl->isResolved = true; +} diff --git a/src/dev/engine/promise_p.cpp b/src/dev/engine/promise_p.cpp new file mode 100644 index 00000000..bf4f7f0a --- /dev/null +++ b/src/dev/engine/promise_p.cpp @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "promise_p.h" + +using namespace libscratchcpp; diff --git a/src/dev/engine/promise_p.h b/src/dev/engine/promise_p.h new file mode 100644 index 00000000..1c956537 --- /dev/null +++ b/src/dev/engine/promise_p.h @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace libscratchcpp +{ + +struct PromisePrivate +{ + bool isResolved = false; +}; + +} // namespace libscratchcpp diff --git a/test/dev/CMakeLists.txt b/test/dev/CMakeLists.txt index 79ee0e67..60bec96f 100644 --- a/test/dev/CMakeLists.txt +++ b/test/dev/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(blocks) add_subdirectory(executioncontext) add_subdirectory(llvm) add_subdirectory(compiler) +add_subdirectory(promise) diff --git a/test/dev/promise/CMakeLists.txt b/test/dev/promise/CMakeLists.txt new file mode 100644 index 00000000..71f94d71 --- /dev/null +++ b/test/dev/promise/CMakeLists.txt @@ -0,0 +1,12 @@ +add_executable( + promise_test + promise_test.cpp +) + +target_link_libraries( + promise_test + GTest::gtest_main + scratchcpp +) + +gtest_discover_tests(promise_test) diff --git a/test/dev/promise/promise_test.cpp b/test/dev/promise/promise_test.cpp new file mode 100644 index 00000000..d2c28c6b --- /dev/null +++ b/test/dev/promise/promise_test.cpp @@ -0,0 +1,14 @@ +#include + +#include "../../common.h" + +using namespace libscratchcpp; + +TEST(PromiseTest, Resolve) +{ + Promise promise; + ASSERT_FALSE(promise.isResolved()); + + promise.resolve(); + ASSERT_TRUE(promise.isResolved()); +} From a415da4504759c00f7db7621d824da4cf2358652 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:50:46 +0100 Subject: [PATCH 6/9] ExecutionContext: Add promise API --- include/scratchcpp/dev/executioncontext.h | 4 ++++ src/dev/engine/executioncontext.cpp | 12 ++++++++++++ src/dev/engine/executioncontext_p.h | 4 ++++ .../executioncontext/executioncontext_test.cpp | 15 +++++++++++++++ 4 files changed, 35 insertions(+) diff --git a/include/scratchcpp/dev/executioncontext.h b/include/scratchcpp/dev/executioncontext.h index 7d7ac877..4651d079 100644 --- a/include/scratchcpp/dev/executioncontext.h +++ b/include/scratchcpp/dev/executioncontext.h @@ -9,6 +9,7 @@ namespace libscratchcpp { class Target; +class Promise; class ExecutionContextPrivate; /*! \brief The ExecutionContext represents the execution context of a target (can be a clone) with variables, lists, etc. */ @@ -21,6 +22,9 @@ class LIBSCRATCHCPP_EXPORT ExecutionContext Target *target() const; + std::shared_ptr promise() const; + void setPromise(std::shared_ptr promise); + private: spimpl::unique_impl_ptr impl; }; diff --git a/src/dev/engine/executioncontext.cpp b/src/dev/engine/executioncontext.cpp index c45156c6..81f8cb06 100644 --- a/src/dev/engine/executioncontext.cpp +++ b/src/dev/engine/executioncontext.cpp @@ -17,3 +17,15 @@ Target *ExecutionContext::target() const { return impl->target; } + +/*! Returns the script promise. */ +std::shared_ptr ExecutionContext::promise() const +{ + return impl->promise; +} + +/*! Sets the script promise (yields until the promise is resolved). */ +void ExecutionContext::setPromise(std::shared_ptr promise) +{ + impl->promise = promise; +} diff --git a/src/dev/engine/executioncontext_p.h b/src/dev/engine/executioncontext_p.h index fd7fad51..19a22e96 100644 --- a/src/dev/engine/executioncontext_p.h +++ b/src/dev/engine/executioncontext_p.h @@ -2,16 +2,20 @@ #pragma once +#include + namespace libscratchcpp { class Target; +class Promise; struct ExecutionContextPrivate { ExecutionContextPrivate(Target *target); Target *target = nullptr; + std::shared_ptr promise; }; } // namespace libscratchcpp diff --git a/test/dev/executioncontext/executioncontext_test.cpp b/test/dev/executioncontext/executioncontext_test.cpp index 3429ffef..d1145373 100644 --- a/test/dev/executioncontext/executioncontext_test.cpp +++ b/test/dev/executioncontext/executioncontext_test.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "../../common.h" @@ -11,3 +12,17 @@ TEST(ExecutionContextTest, Constructor) ExecutionContext ctx(&target); ASSERT_EQ(ctx.target(), &target); } + +TEST(ExecutionContextTest, Promise) +{ + Target target; + ExecutionContext ctx(&target); + ASSERT_EQ(ctx.promise(), nullptr); + + auto promise = std::make_shared(); + ctx.setPromise(promise); + ASSERT_EQ(ctx.promise(), promise); + + ctx.setPromise(nullptr); + ASSERT_EQ(ctx.promise(), nullptr); +} From 4dfbc525bf6627644dffeeaf966d2361bdd00c49 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:51:52 +0100 Subject: [PATCH 7/9] Update promise API in Thread class --- include/scratchcpp/thread.h | 8 ++++++++ src/engine/internal/engine.cpp | 21 +++++++++++++++++++-- src/engine/thread.cpp | 23 +++++++++++++++-------- test/thread/thread_test.cpp | 17 ++++++++++++++++- 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/include/scratchcpp/thread.h b/include/scratchcpp/thread.h index f3f7693e..b5936db7 100644 --- a/include/scratchcpp/thread.h +++ b/include/scratchcpp/thread.h @@ -10,6 +10,9 @@ namespace libscratchcpp class VirtualMachine; class Target; +#ifdef USE_LLVM +class Promise; +#endif class IEngine; class Script; class ThreadPrivate; @@ -32,8 +35,13 @@ class LIBSCRATCHCPP_EXPORT Thread bool isFinished() const; +#ifdef USE_LLVM + std::shared_ptr promise() const; + void setPromise(std::shared_ptr promise); +#else void promise(); void resolvePromise(); +#endif private: spimpl::unique_impl_ptr impl; diff --git a/src/engine/internal/engine.cpp b/src/engine/internal/engine.cpp index 590bbe17..748797ce 100644 --- a/src/engine/internal/engine.cpp +++ b/src/engine/internal/engine.cpp @@ -9,6 +9,7 @@ #include #ifdef USE_LLVM #include +#include #else #include #endif @@ -593,8 +594,16 @@ void Engine::step() } else { Thread *th = senderThread; - if (std::find_if(m_threads.begin(), m_threads.end(), [th](std::shared_ptr thread) { return thread.get() == th; }) != m_threads.end()) + if (std::find_if(m_threads.begin(), m_threads.end(), [th](std::shared_ptr thread) { return thread.get() == th; }) != m_threads.end()) { +#ifdef USE_LLVM + auto promise = th->promise(); + + if (promise) + promise->resolve(); +#else th->resolvePromise(); +#endif + } resolved.push_back(broadcast); resolvedThreads.push_back(th); @@ -2018,8 +2027,16 @@ void Engine::addBroadcastPromise(Broadcast *broadcast, Thread *sender, bool wait // Resolve broadcast promise if it's already running auto it = m_broadcastSenders.find(broadcast); - if (it != m_broadcastSenders.cend() && std::find_if(m_threads.begin(), m_threads.end(), [&it](std::shared_ptr thread) { return thread.get() == it->second; }) != m_threads.end()) + if (it != m_broadcastSenders.cend() && std::find_if(m_threads.begin(), m_threads.end(), [&it](std::shared_ptr thread) { return thread.get() == it->second; }) != m_threads.end()) { +#ifdef USE_LLVM + auto promise = it->second->promise(); + + if (promise) + promise->resolve(); +#else it->second->resolvePromise(); +#endif + } if (wait) m_broadcastSenders[broadcast] = sender; diff --git a/src/engine/thread.cpp b/src/engine/thread.cpp index 2ffe2e56..3c1cbfc5 100644 --- a/src/engine/thread.cpp +++ b/src/engine/thread.cpp @@ -5,6 +5,7 @@ #ifdef USE_LLVM #include #include +#include #endif #include "thread_p.h" @@ -87,22 +88,28 @@ bool Thread::isFinished() const #endif } +#ifdef USE_LLVM +/*! Returns the script promise. */ +std::shared_ptr Thread::promise() const +{ + return impl->executionContext->promise(); +} + +/*! Sets the script promise (yields until the promise is resolved). */ +void Thread::setPromise(std::shared_ptr promise) +{ + impl->executionContext->setPromise(promise); +} +#else /*! Pauses the script (when it's executed using run() again) until resolvePromise() is called. */ void Thread::promise() { -#ifdef USE_LLVM - impl->code->promise(); -#else impl->vm->promise(); -#endif } /*! Resolves the promise and resumes the script. */ void Thread::resolvePromise() { -#ifdef USE_LLVM - impl->code->resolvePromise(); -#else impl->vm->resolvePromise(); -#endif } +#endif // USE_LLVM diff --git a/test/thread/thread_test.cpp b/test/thread/thread_test.cpp index 08f1c943..65114c0d 100644 --- a/test/thread/thread_test.cpp +++ b/test/thread/thread_test.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -76,5 +77,19 @@ TEST_F(ThreadTest, IsFinished) ASSERT_TRUE(m_thread->isFinished()); } -// TODO: Test promise() and resolvePromise() +TEST_F(ThreadTest, Promise) +{ + ASSERT_EQ(m_thread->promise(), m_ctx->promise()); + + auto promise = std::make_shared(); + m_ctx->setPromise(promise); + ASSERT_EQ(m_thread->promise(), m_ctx->promise()); + + m_ctx->setPromise(nullptr); + ASSERT_EQ(m_thread->promise(), m_ctx->promise()); + + m_thread->setPromise(promise); + ASSERT_EQ(m_thread->promise(), promise); + ASSERT_EQ(m_ctx->promise(), promise); +} #endif // USE_LLVM From e0143a6206f3aabb4203a189fc64c35bf7180a8d Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 7 Dec 2024 18:52:51 +0100 Subject: [PATCH 8/9] Remove old promise API from ExecutableCode --- include/scratchcpp/dev/executablecode.h | 6 ------ src/dev/engine/internal/llvm/llvmexecutablecode.cpp | 8 -------- src/dev/engine/internal/llvm/llvmexecutablecode.h | 3 --- test/mocks/executablecodemock.h | 3 --- 4 files changed, 20 deletions(-) diff --git a/include/scratchcpp/dev/executablecode.h b/include/scratchcpp/dev/executablecode.h index 54efea83..ba5d03e4 100644 --- a/include/scratchcpp/dev/executablecode.h +++ b/include/scratchcpp/dev/executablecode.h @@ -30,12 +30,6 @@ class LIBSCRATCHCPP_EXPORT ExecutableCode /*! Returns true if the code is stopped or finished. */ virtual bool isFinished(ExecutionContext *context) const = 0; - /*! Pauses the script (when it's executed using run() again) until resolvePromise() is called. */ - virtual void promise() = 0; - - /*! Resolves the promise and resumes the script. */ - virtual void resolvePromise() = 0; - /*! Creates an execution context for the given Target. */ virtual std::shared_ptr createExecutionContext(Target *target) const = 0; }; diff --git a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp index e4f9d443..6d1b0a8b 100644 --- a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp +++ b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp @@ -83,14 +83,6 @@ bool LLVMExecutableCode::isFinished(ExecutionContext *context) const return getContext(context)->finished(); } -void LLVMExecutableCode::promise() -{ -} - -void LLVMExecutableCode::resolvePromise() -{ -} - std::shared_ptr LLVMExecutableCode::createExecutionContext(Target *target) const { return std::make_shared(target); diff --git a/src/dev/engine/internal/llvm/llvmexecutablecode.h b/src/dev/engine/internal/llvm/llvmexecutablecode.h index 63524156..ecdc09c3 100644 --- a/src/dev/engine/internal/llvm/llvmexecutablecode.h +++ b/src/dev/engine/internal/llvm/llvmexecutablecode.h @@ -25,9 +25,6 @@ class LLVMExecutableCode : public ExecutableCode bool isFinished(ExecutionContext *context) const override; - void promise() override; - void resolvePromise() override; - std::shared_ptr createExecutionContext(Target *target) const override; private: diff --git a/test/mocks/executablecodemock.h b/test/mocks/executablecodemock.h index f01b9f97..0958c391 100644 --- a/test/mocks/executablecodemock.h +++ b/test/mocks/executablecodemock.h @@ -14,8 +14,5 @@ class ExecutableCodeMock : public ExecutableCode MOCK_METHOD(bool, isFinished, (ExecutionContext *), (const, override)); - MOCK_METHOD(void, promise, (), (override)); - MOCK_METHOD(void, resolvePromise, (), (override)); - MOCK_METHOD(std::shared_ptr, createExecutionContext, (Target *), (const, override)); }; From 4477e02bcd595f7c16269e8e332b7e4ca8f93b03 Mon Sep 17 00:00:00 2001 From: adazem009 <68537469+adazem009@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:47:35 +0100 Subject: [PATCH 9/9] LLVMExecutableCode: Implement promises --- .../internal/llvm/llvmexecutablecode.cpp | 12 +++ test/dev/llvm/llvmexecutablecode_test.cpp | 92 +++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp index 6d1b0a8b..944e5d44 100644 --- a/src/dev/engine/internal/llvm/llvmexecutablecode.cpp +++ b/src/dev/engine/internal/llvm/llvmexecutablecode.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -46,6 +47,15 @@ void LLVMExecutableCode::run(ExecutionContext *context) if (ctx->finished()) return; + auto promise = ctx->promise(); + + if (promise) { + if (promise->isResolved()) + ctx->setPromise(nullptr); + else + return; + } + if (ctx->coroutineHandle()) { bool done = m_resumeFunction(ctx->coroutineHandle()); @@ -69,6 +79,7 @@ void LLVMExecutableCode::kill(ExecutionContext *context) LLVMExecutionContext *ctx = getContext(context); ctx->setCoroutineHandle(nullptr); ctx->setFinished(true); + ctx->setPromise(nullptr); } void LLVMExecutableCode::reset(ExecutionContext *context) @@ -76,6 +87,7 @@ void LLVMExecutableCode::reset(ExecutionContext *context) LLVMExecutionContext *ctx = getContext(context); ctx->setCoroutineHandle(nullptr); ctx->setFinished(false); + ctx->setPromise(nullptr); } bool LLVMExecutableCode::isFinished(ExecutionContext *context) const diff --git a/test/dev/llvm/llvmexecutablecode_test.cpp b/test/dev/llvm/llvmexecutablecode_test.cpp index 4564d94e..def0f985 100644 --- a/test/dev/llvm/llvmexecutablecode_test.cpp +++ b/test/dev/llvm/llvmexecutablecode_test.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -144,3 +145,94 @@ TEST_F(LLVMExecutableCodeTest, MainFunction) ASSERT_TRUE(code.isFinished(anotherCtx.get())); ASSERT_FALSE(code.isFinished(ctx.get())); } + +TEST_F(LLVMExecutableCodeTest, Promise) +{ + auto f = beginMainFunction(); + addTestFunction(f); + endFunction(nullPointer()); + + beginResumeFunction(); + endFunction(m_builder->getInt1(true)); + + LLVMExecutableCode code(std::move(m_module)); + auto ctx = code.createExecutionContext(&m_target); + ASSERT_FALSE(code.isFinished(ctx.get())); + + // run() + auto promise = std::make_shared(); + ctx->setPromise(promise); + EXPECT_CALL(m_mock, f).Times(0); + + for (int i = 0; i < 10; i++) { + code.run(ctx.get()); + ASSERT_FALSE(code.isFinished(ctx.get())); + } + + promise->resolve(); + + EXPECT_CALL(m_mock, f); + code.run(ctx.get()); + ASSERT_TRUE(code.isFinished(ctx.get())); + ASSERT_EQ(ctx->promise(), nullptr); + code.reset(ctx.get()); + + // kill() + promise = std::make_shared(); + ctx->setPromise(promise); + EXPECT_CALL(m_mock, f).Times(0); + + for (int i = 0; i < 10; i++) { + code.run(ctx.get()); + ASSERT_FALSE(code.isFinished(ctx.get())); + } + + code.kill(ctx.get()); + ASSERT_TRUE(code.isFinished(ctx.get())); + ASSERT_EQ(ctx->promise(), nullptr); + code.reset(ctx.get()); + + // reset() + promise = std::make_shared(); + ctx->setPromise(promise); + EXPECT_CALL(m_mock, f).Times(0); + + for (int i = 0; i < 10; i++) { + code.run(ctx.get()); + ASSERT_FALSE(code.isFinished(ctx.get())); + } + + code.reset(ctx.get()); + ASSERT_FALSE(code.isFinished(ctx.get())); + ASSERT_EQ(ctx->promise(), nullptr); + + EXPECT_CALL(m_mock, f); + code.run(ctx.get()); + ASSERT_TRUE(code.isFinished(ctx.get())); + + // Test with another context + Target anotherTarget; + auto anotherCtx = code.createExecutionContext(&anotherTarget); + ASSERT_FALSE(code.isFinished(anotherCtx.get())); + ASSERT_TRUE(code.isFinished(ctx.get())); + + promise = std::make_shared(); + anotherCtx->setPromise(promise); + EXPECT_CALL(m_mock, f).Times(0); + + for (int i = 0; i < 10; i++) { + code.run(anotherCtx.get()); + ASSERT_FALSE(code.isFinished(anotherCtx.get())); + } + + promise->resolve(); + + EXPECT_CALL(m_mock, f); + code.run(anotherCtx.get()); + ASSERT_TRUE(code.isFinished(anotherCtx.get())); + ASSERT_TRUE(code.isFinished(ctx.get())); + + code.reset(ctx.get()); + ASSERT_TRUE(code.isFinished(anotherCtx.get())); + ASSERT_FALSE(code.isFinished(ctx.get())); +}