Skip to content

Commit 1a68af0

Browse files
committed
LLVMCodeBuilder: Add createRandom() method
1 parent 0479309 commit 1a68af0

File tree

9 files changed

+234
-5
lines changed

9 files changed

+234
-5
lines changed

src/dev/engine/internal/icodebuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class ICodeBuilder
3636
virtual CompilerValue *createMul(CompilerValue *operand1, CompilerValue *operand2) = 0;
3737
virtual CompilerValue *createDiv(CompilerValue *operand1, CompilerValue *operand2) = 0;
3838

39+
virtual CompilerValue *createRandom(CompilerValue *from, CompilerValue *to) = 0;
40+
3941
virtual CompilerValue *createCmpEQ(CompilerValue *operand1, CompilerValue *operand2) = 0;
4042
virtual CompilerValue *createCmpGT(CompilerValue *operand1, CompilerValue *operand2) = 0;
4143
virtual CompilerValue *createCmpLT(CompilerValue *operand1, CompilerValue *operand2) = 0;

src/dev/engine/internal/llvm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ target_sources(scratchcpp
1515
llvmprocedure.h
1616
llvmtypes.cpp
1717
llvmtypes.h
18+
llvmfunctions.cpp
19+
llvmfunctions.h
1820
llvmexecutablecode.cpp
1921
llvmexecutablecode.h
2022
llvmexecutioncontext.cpp

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,38 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
191191
break;
192192
}
193193

194+
case LLVMInstruction::Type::Random: {
195+
assert(step.args.size() == 2);
196+
const auto &arg1 = step.args[0];
197+
const auto &arg2 = step.args[1];
198+
LLVMRegister *reg1 = arg1.second;
199+
LLVMRegister *reg2 = arg2.second;
200+
201+
if (reg1->type() == Compiler::StaticType::Bool && reg2->type() == Compiler::StaticType::Bool) {
202+
llvm::Value *bool1 = castValue(arg1.second, Compiler::StaticType::Bool);
203+
llvm::Value *bool2 = castValue(arg2.second, Compiler::StaticType::Bool);
204+
step.functionReturnReg->value = m_builder.CreateCall(resolve_llvm_random_bool(), { bool1, bool2 });
205+
} else {
206+
llvm::Constant *inf = llvm::ConstantFP::getInfinity(m_builder.getDoubleTy(), false);
207+
llvm::Value *num1 = removeNaN(castValue(arg1.second, Compiler::StaticType::Number));
208+
llvm::Value *num2 = removeNaN(castValue(arg2.second, Compiler::StaticType::Number));
209+
llvm::Value *sum = m_builder.CreateFAdd(num1, num2);
210+
llvm::Value *sumDiv = m_builder.CreateFDiv(sum, inf);
211+
llvm::Value *isInfOrNaN = isNaN(sumDiv);
212+
213+
// NOTE: The random function will be called even in edge cases where it isn't needed, but they're rare, so it shouldn't be an issue
214+
if (reg1->type() == Compiler::StaticType::Number && reg2->type() == Compiler::StaticType::Number)
215+
step.functionReturnReg->value = m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(resolve_llvm_random_double(), { num1, num2 }));
216+
else {
217+
llvm::Value *value1 = createValue(reg1);
218+
llvm::Value *value2 = createValue(reg2);
219+
step.functionReturnReg->value = m_builder.CreateSelect(isInfOrNaN, sum, m_builder.CreateCall(resolve_llvm_random(), { value1, value2 }));
220+
}
221+
}
222+
223+
break;
224+
}
225+
194226
case LLVMInstruction::Type::CmpEQ: {
195227
assert(step.args.size() == 2);
196228
const auto &arg1 = step.args[0].second;
@@ -1089,6 +1121,11 @@ CompilerValue *LLVMCodeBuilder::createDiv(CompilerValue *operand1, CompilerValue
10891121
return createOp(LLVMInstruction::Type::Div, Compiler::StaticType::Number, Compiler::StaticType::Number, { operand1, operand2 });
10901122
}
10911123

1124+
CompilerValue *LLVMCodeBuilder::createRandom(CompilerValue *from, CompilerValue *to)
1125+
{
1126+
return createOp(LLVMInstruction::Type::Random, Compiler::StaticType::Number, Compiler::StaticType::Unknown, { from, to });
1127+
}
1128+
10921129
CompilerValue *LLVMCodeBuilder::createCmpEQ(CompilerValue *operand1, CompilerValue *operand2)
10931130
{
10941131
return createOp(LLVMInstruction::Type::CmpEQ, Compiler::StaticType::Bool, Compiler::StaticType::Number, { operand1, operand2 });
@@ -2334,6 +2371,22 @@ llvm::FunctionCallee LLVMCodeBuilder::resolve_list_to_string()
23342371
return resolveFunction("list_to_string", llvm::FunctionType::get(pointerType, { pointerType }, false));
23352372
}
23362373

2374+
llvm::FunctionCallee LLVMCodeBuilder::resolve_llvm_random()
2375+
{
2376+
llvm::Type *valuePtr = m_valueDataType->getPointerTo();
2377+
return resolveFunction("llvm_random", llvm::FunctionType::get(m_builder.getDoubleTy(), { valuePtr, valuePtr }, false));
2378+
}
2379+
2380+
llvm::FunctionCallee LLVMCodeBuilder::resolve_llvm_random_double()
2381+
{
2382+
return resolveFunction("llvm_random_double", llvm::FunctionType::get(m_builder.getDoubleTy(), { m_builder.getDoubleTy(), m_builder.getDoubleTy() }, false));
2383+
}
2384+
2385+
llvm::FunctionCallee LLVMCodeBuilder::resolve_llvm_random_bool()
2386+
{
2387+
return resolveFunction("llvm_random_bool", llvm::FunctionType::get(m_builder.getDoubleTy(), { m_builder.getInt1Ty(), m_builder.getInt1Ty() }, false));
2388+
}
2389+
23372390
llvm::FunctionCallee LLVMCodeBuilder::resolve_strcasecmp()
23382391
{
23392392
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_ctx), 0);

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class LLVMCodeBuilder : public ICodeBuilder
4343
CompilerValue *createMul(CompilerValue *operand1, CompilerValue *operand2) override;
4444
CompilerValue *createDiv(CompilerValue *operand1, CompilerValue *operand2) override;
4545

46+
CompilerValue *createRandom(CompilerValue *from, CompilerValue *to) override;
47+
4648
CompilerValue *createCmpEQ(CompilerValue *operand1, CompilerValue *operand2) override;
4749
CompilerValue *createCmpGT(CompilerValue *operand1, CompilerValue *operand2) override;
4850
CompilerValue *createCmpLT(CompilerValue *operand1, CompilerValue *operand2) override;
@@ -166,6 +168,9 @@ class LLVMCodeBuilder : public ICodeBuilder
166168
llvm::FunctionCallee resolve_list_size_ptr();
167169
llvm::FunctionCallee resolve_list_alloc_size_ptr();
168170
llvm::FunctionCallee resolve_list_to_string();
171+
llvm::FunctionCallee resolve_llvm_random();
172+
llvm::FunctionCallee resolve_llvm_random_double();
173+
llvm::FunctionCallee resolve_llvm_random_bool();
169174
llvm::FunctionCallee resolve_strcasecmp();
170175

171176
Target *m_target = nullptr;
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include <scratchcpp/value_functions.h>
4+
5+
#include "llvmfunctions.h"
6+
#include "../../../../engine/internal/randomgenerator.h"
7+
8+
namespace libscratchcpp
9+
{
10+
11+
extern "C"
12+
{
13+
double llvm_random(ValueData *from, ValueData *to)
14+
{
15+
if (!llvm_rng)
16+
llvm_rng = RandomGenerator::instance().get();
17+
18+
return value_isInt(from) && value_isInt(to) ? llvm_rng->randint(value_toLong(from), value_toLong(to)) : llvm_rng->randintDouble(value_toDouble(from), value_toDouble(to));
19+
}
20+
21+
double llvm_random_double(double from, double to)
22+
{
23+
if (!llvm_rng)
24+
llvm_rng = RandomGenerator::instance().get();
25+
26+
return value_doubleIsInt(from) && value_doubleIsInt(to) ? llvm_rng->randint(from, to) : llvm_rng->randintDouble(from, to);
27+
}
28+
29+
double llvm_random_bool(bool from, bool to)
30+
{
31+
if (!llvm_rng)
32+
llvm_rng = RandomGenerator::instance().get();
33+
34+
return llvm_rng->randint(from, to);
35+
}
36+
}
37+
38+
} // namespace libscratchcpp
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
#pragma once
3+
4+
namespace libscratchcpp
5+
{
6+
7+
class IRandomGenerator;
8+
9+
IRandomGenerator *llvm_rng = nullptr;
10+
11+
} // namespace libscratchcpp

src/dev/engine/internal/llvm/llvminstruction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ struct LLVMInstruction
1818
Sub,
1919
Mul,
2020
Div,
21+
Random,
2122
CmpEQ,
2223
CmpGT,
2324
CmpLT,

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#include <scratchcpp/variable.h>
88
#include <scratchcpp/list.h>
99
#include <dev/engine/internal/llvm/llvmcodebuilder.h>
10+
#include <dev/engine/internal/llvm/llvmfunctions.h>
1011
#include <gmock/gmock.h>
1112
#include <targetmock.h>
1213
#include <enginemock.h>
14+
#include <randomgeneratormock.h>
1315

1416
#include "testfunctions.h"
1517

@@ -27,6 +29,7 @@ class LLVMCodeBuilderTest : public testing::Test
2729
Sub,
2830
Mul,
2931
Div,
32+
Random,
3033
CmpEQ,
3134
CmpGT,
3235
CmpLT,
@@ -49,7 +52,6 @@ class LLVMCodeBuilderTest : public testing::Test
4952
Log10,
5053
Exp,
5154
Exp10
52-
5355
};
5456

5557
void SetUp() override
@@ -93,6 +95,9 @@ class LLVMCodeBuilderTest : public testing::Test
9395
case OpType::Div:
9496
return m_builder->createDiv(arg1, arg2);
9597

98+
case OpType::Random:
99+
return m_builder->createRandom(arg1, arg2);
100+
96101
case OpType::CmpEQ:
97102
return m_builder->createCmpEQ(arg1, arg2);
98103

@@ -189,6 +194,15 @@ class LLVMCodeBuilderTest : public testing::Test
189194
case OpType::Div:
190195
return v1 / v2;
191196

197+
case OpType::Random: {
198+
const double sum = v1.toDouble() + v2.toDouble();
199+
200+
if (std::isnan(sum) || std::isinf(sum))
201+
return sum;
202+
203+
return v1.isInt() && v2.isInt() ? m_rng.randint(v1.toLong(), v2.toLong()) : m_rng.randintDouble(v1.toDouble(), v2.toDouble());
204+
}
205+
192206
case OpType::CmpEQ:
193207
return v1 == v2;
194208

@@ -225,7 +239,7 @@ class LLVMCodeBuilderTest : public testing::Test
225239
}
226240
}
227241

228-
void runOpTest(OpType type, const Value &v1, const Value &v2)
242+
void runOpTestCommon(OpType type, const Value &v1, const Value &v2)
229243
{
230244
createBuilder(true);
231245

@@ -241,9 +255,6 @@ class LLVMCodeBuilderTest : public testing::Test
241255
ret = addOp(type, arg1, arg2);
242256
m_builder->addFunctionCall("test_print_string", Compiler::StaticType::Void, { Compiler::StaticType::String }, { ret });
243257

244-
std::string str = doOp(type, v1, v2).toString() + '\n';
245-
std::string expected = str + str;
246-
247258
auto code = m_builder->finalize();
248259
Script script(&m_target, nullptr, nullptr);
249260
script.setCode(code);
@@ -252,9 +263,28 @@ class LLVMCodeBuilderTest : public testing::Test
252263

253264
testing::internal::CaptureStdout();
254265
code->run(ctx.get());
266+
}
267+
268+
void checkOpTest(const Value &v1, const Value &v2, const std::string &expected)
269+
{
255270
const std::string quotes1 = v1.isString() ? "\"" : "";
256271
const std::string quotes2 = v2.isString() ? "\"" : "";
257272
ASSERT_THAT(testing::internal::GetCapturedStdout(), Eq(expected)) << quotes1 << v1.toString() << quotes1 << " " << quotes2 << v2.toString() << quotes2;
273+
}
274+
275+
void runOpTest(OpType type, const Value &v1, const Value &v2, const Value &expected)
276+
{
277+
std::string str = expected.toString();
278+
runOpTestCommon(type, v1, v2);
279+
checkOpTest(v1, v2, str + '\n' + str + '\n');
280+
};
281+
282+
void runOpTest(OpType type, const Value &v1, const Value &v2)
283+
{
284+
runOpTestCommon(type, v1, v2);
285+
std::string str = doOp(type, v1, v2).toString() + '\n';
286+
std::string expected = str + str;
287+
checkOpTest(v1, v2, expected);
258288
};
259289

260290
void runOpTest(OpType type, const Value &v)
@@ -317,6 +347,7 @@ class LLVMCodeBuilderTest : public testing::Test
317347

318348
std::unique_ptr<LLVMCodeBuilder> m_builder;
319349
TargetMock m_target; // NOTE: isStage() is used for call expectations
350+
RandomGeneratorMock m_rng;
320351
};
321352

322353
TEST_F(LLVMCodeBuilderTest, FunctionCalls)
@@ -605,6 +636,90 @@ TEST_F(LLVMCodeBuilderTest, Divide)
605636
runOpTest(OpType::Div, 0, 0);
606637
}
607638

639+
TEST_F(LLVMCodeBuilderTest, Random)
640+
{
641+
llvm_rng = &m_rng;
642+
643+
EXPECT_CALL(m_rng, randint(-45, 12)).Times(3).WillRepeatedly(Return(-18));
644+
runOpTest(OpType::Random, -45, 12);
645+
646+
EXPECT_CALL(m_rng, randint(-45, 12)).Times(3).WillRepeatedly(Return(5));
647+
runOpTest(OpType::Random, -45.0, 12.0);
648+
649+
EXPECT_CALL(m_rng, randintDouble(12, 6.05)).Times(3).WillRepeatedly(Return(3.486789));
650+
runOpTest(OpType::Random, 12, 6.05);
651+
652+
EXPECT_CALL(m_rng, randintDouble(-78.686, -45)).Times(3).WillRepeatedly(Return(-59.468873));
653+
runOpTest(OpType::Random, -78.686, -45);
654+
655+
EXPECT_CALL(m_rng, randintDouble(6.05, -78.686)).Times(3).WillRepeatedly(Return(-28.648764));
656+
runOpTest(OpType::Random, 6.05, -78.686);
657+
658+
EXPECT_CALL(m_rng, randint(-45, 12)).Times(3).WillRepeatedly(Return(0));
659+
runOpTest(OpType::Random, "-45", "12");
660+
661+
EXPECT_CALL(m_rng, randintDouble(-45, 12)).Times(3).WillRepeatedly(Return(5.2));
662+
runOpTest(OpType::Random, "-45.0", "12");
663+
664+
EXPECT_CALL(m_rng, randintDouble(-45, 12)).Times(3).WillRepeatedly(Return(-15.5787));
665+
runOpTest(OpType::Random, "-45", "12.0");
666+
667+
EXPECT_CALL(m_rng, randintDouble(-45, 12)).Times(3).WillRepeatedly(Return(2.587964));
668+
runOpTest(OpType::Random, "-45.0", "12.0");
669+
670+
EXPECT_CALL(m_rng, randintDouble(6.05, -78.686)).Times(3).WillRepeatedly(Return(5.648764));
671+
runOpTest(OpType::Random, "6.05", "-78.686");
672+
673+
EXPECT_CALL(m_rng, randint(-45, 12)).Times(3).WillRepeatedly(Return(0));
674+
runOpTest(OpType::Random, "-45", 12);
675+
676+
EXPECT_CALL(m_rng, randint(-45, 12)).Times(3).WillRepeatedly(Return(0));
677+
runOpTest(OpType::Random, -45, "12");
678+
679+
EXPECT_CALL(m_rng, randintDouble(-45, 12)).Times(3).WillRepeatedly(Return(5.2));
680+
runOpTest(OpType::Random, "-45.0", 12);
681+
682+
EXPECT_CALL(m_rng, randintDouble(-45, 12)).Times(3).WillRepeatedly(Return(-15.5787));
683+
runOpTest(OpType::Random, -45, "12.0");
684+
685+
EXPECT_CALL(m_rng, randintDouble(6.05, -78.686)).Times(3).WillRepeatedly(Return(5.648764));
686+
runOpTest(OpType::Random, 6.05, "-78.686");
687+
688+
EXPECT_CALL(m_rng, randintDouble(6.05, -78.686)).Times(3).WillRepeatedly(Return(5.648764));
689+
runOpTest(OpType::Random, "6.05", -78.686);
690+
691+
EXPECT_CALL(m_rng, randint(0, 1)).Times(3).WillRepeatedly(Return(1));
692+
runOpTest(OpType::Random, false, true);
693+
694+
EXPECT_CALL(m_rng, randint(1, 5)).Times(3).WillRepeatedly(Return(1));
695+
runOpTest(OpType::Random, true, 5);
696+
697+
EXPECT_CALL(m_rng, randint(8, 0)).Times(3).WillRepeatedly(Return(1));
698+
runOpTest(OpType::Random, 8, false);
699+
700+
const double inf = std::numeric_limits<double>::infinity();
701+
const double nan = std::numeric_limits<double>::quiet_NaN();
702+
EXPECT_CALL(m_rng, randint).WillRepeatedly(Return(0));
703+
EXPECT_CALL(m_rng, randintDouble).WillRepeatedly(Return(0));
704+
705+
runOpTest(OpType::Random, inf, 2, inf);
706+
runOpTest(OpType::Random, -8, inf, inf);
707+
runOpTest(OpType::Random, -inf, -2, -inf);
708+
runOpTest(OpType::Random, 8, -inf, -inf);
709+
710+
runOpTest(OpType::Random, inf, 2.5, inf);
711+
runOpTest(OpType::Random, -8.09, inf, inf);
712+
runOpTest(OpType::Random, -inf, -2.5, -inf);
713+
runOpTest(OpType::Random, 8.09, -inf, -inf);
714+
715+
runOpTest(OpType::Random, inf, inf, inf);
716+
runOpTest(OpType::Random, -inf, -inf, -inf);
717+
runOpTest(OpType::Random, inf, -inf, nan);
718+
runOpTest(OpType::Random, -inf, inf, nan);
719+
720+
llvm_rng = nullptr;
721+
}
722+
608723
TEST_F(LLVMCodeBuilderTest, EqualComparison)
609724
{
610725
runOpTest(OpType::CmpEQ, 10, 10);

test/mocks/codebuildermock.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class CodeBuilderMock : public ICodeBuilder
2626
MOCK_METHOD(CompilerValue *, createMul, (CompilerValue *, CompilerValue *), (override));
2727
MOCK_METHOD(CompilerValue *, createDiv, (CompilerValue *, CompilerValue *), (override));
2828

29+
MOCK_METHOD(CompilerValue *, createRandom, (CompilerValue *, CompilerValue *), (override));
30+
2931
MOCK_METHOD(CompilerValue *, createCmpEQ, (CompilerValue *, CompilerValue *), (override));
3032
MOCK_METHOD(CompilerValue *, createCmpGT, (CompilerValue *, CompilerValue *), (override));
3133
MOCK_METHOD(CompilerValue *, createCmpLT, (CompilerValue *, CompilerValue *), (override));

0 commit comments

Comments
 (0)