Skip to content

Commit f2cdb0c

Browse files
committed
LLVMCodeBuilder: Add createStrCmpEQ() method
1 parent 26a4540 commit f2cdb0c

File tree

6 files changed

+239
-4
lines changed

6 files changed

+239
-4
lines changed

src/dev/engine/internal/icodebuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class ICodeBuilder
4646
virtual CompilerValue *createCmpGT(CompilerValue *operand1, CompilerValue *operand2) = 0;
4747
virtual CompilerValue *createCmpLT(CompilerValue *operand1, CompilerValue *operand2) = 0;
4848

49+
virtual CompilerValue *createStrCmpEQ(CompilerValue *string1, CompilerValue *string2, bool caseSensitive = false) = 0;
50+
4951
virtual CompilerValue *createAnd(CompilerValue *operand1, CompilerValue *operand2) = 0;
5052
virtual CompilerValue *createOr(CompilerValue *operand1, CompilerValue *operand2) = 0;
5153
virtual CompilerValue *createNot(CompilerValue *operand) = 0;

src/dev/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,22 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
270270
break;
271271
}
272272

273+
case LLVMInstruction::Type::StrCmpEQCS: {
274+
assert(step.args.size() == 2);
275+
const auto &arg1 = step.args[0].second;
276+
const auto &arg2 = step.args[1].second;
277+
step.functionReturnReg->value = createStringComparison(arg1, arg2, true);
278+
break;
279+
}
280+
281+
case LLVMInstruction::Type::StrCmpEQCI: {
282+
assert(step.args.size() == 2);
283+
const auto &arg1 = step.args[0].second;
284+
const auto &arg2 = step.args[1].second;
285+
step.functionReturnReg->value = createStringComparison(arg1, arg2, false);
286+
break;
287+
}
288+
273289
case LLVMInstruction::Type::And: {
274290
assert(step.args.size() == 2);
275291
const auto &arg1 = step.args[0];
@@ -1363,6 +1379,11 @@ CompilerValue *LLVMCodeBuilder::createCmpLT(CompilerValue *operand1, CompilerVal
13631379
return createOp(LLVMInstruction::Type::CmpLT, Compiler::StaticType::Bool, Compiler::StaticType::Number, { operand1, operand2 });
13641380
}
13651381

1382+
// Creates a string equality operation on string1 and string2 by delegating to createOp().
// caseSensitive selects the instruction type: StrCmpEQCS when true, StrCmpEQCI when false.
// The StaticType arguments passed to createOp() are Bool and String
// (presumably the result type and the operand type, as in the other createCmp* wrappers - confirm against createOp()).
CompilerValue *LLVMCodeBuilder::createStrCmpEQ(CompilerValue *string1, CompilerValue *string2, bool caseSensitive)
1383+
{
1384+
return createOp(caseSensitive ? LLVMInstruction::Type::StrCmpEQCS : LLVMInstruction::Type::StrCmpEQCI, Compiler::StaticType::Bool, Compiler::StaticType::String, { string1, string2 });
1385+
}
1386+
13661387
CompilerValue *LLVMCodeBuilder::createAnd(CompilerValue *operand1, CompilerValue *operand2)
13671388
{
13681389
return createOp(LLVMInstruction::Type::And, Compiler::StaticType::Bool, Compiler::StaticType::Bool, { operand1, operand2 });
@@ -2399,10 +2420,8 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
23992420
// Optimize number and string constant comparison
24002421
// TODO: GT and LT comparison can be optimized here (e. g. by checking the string constant characters and comparing with numbers and .+-e)
24012422
if (type == Comparison::EQ) {
2402-
if (type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::String && arg2->isConst() && !arg2->constValue().isValidNumber())
2403-
return m_builder.getInt1(false);
2404-
2405-
if (type1 == Compiler::StaticType::String && type2 == Compiler::StaticType::Number && arg1->isConst() && !arg1->constValue().isValidNumber())
2423+
if ((type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::String && arg2->isConst() && !arg2->constValue().isValidNumber()) ||
2424+
(type1 == Compiler::StaticType::String && type2 == Compiler::StaticType::Number && arg1->isConst() && !arg1->constValue().isValidNumber()))
24062425
return m_builder.getInt1(false);
24072426
}
24082427

@@ -2535,6 +2554,39 @@ llvm::Value *LLVMCodeBuilder::createComparison(LLVMRegister *arg1, LLVMRegister
25352554
}
25362555
}
25372556

2557+
// Produces the LLVM value of a string equality test between two registers.
//
// arg1, arg2: the operands; unless the result can be folded, both are cast to Compiler::StaticType::String.
// caseSensitive: true compares with strcmp, false compares with strcasecmp.
// Returns a constant built with m_builder.getInt1() when the result is known while building the code,
// otherwise the ICmpEQ of the comparison call's result against 0.
//
// NOTE(review): strcasecmp is a POSIX function, not standard C++ - confirm it is available on every supported platform.
llvm::Value *LLVMCodeBuilder::createStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, bool caseSensitive)
2558+
{
2559+
auto type1 = arg1->type();
2560+
auto type2 = arg2->type();
2561+
2562+
if (arg1->isConst() && arg2->isConst()) {
2563+
// If both operands are constant, perform the comparison at compile time
2564+
bool result;
2565+
2566+
if (caseSensitive)
2567+
result = arg1->constValue().toString() == arg2->constValue().toString();
2568+
else {
2569+
// Keep the converted strings in locals so the c_str() pointers stay valid for the strcasecmp call
2570+
std::string str1 = arg1->constValue().toString();
2571+
std::string str2 = arg2->constValue().toString();
2572+
result = strcasecmp(str1.c_str(), str2.c_str()) == 0;
2573+
}
2574+
2575+
return m_builder.getInt1(result);
2576+
} else {
2577+
// Optimize number and string constant comparison
2578+
// TODO: Optimize bool and string constant comparison (in compare() as well)
2579+
// One operand is statically a Number and the other is a String constant that is not a valid number:
2580+
// the result is folded to false and no comparison call is emitted
2581+
if ((type1 == Compiler::StaticType::Number && type2 == Compiler::StaticType::String && arg2->isConst() && !arg2->constValue().isValidNumber()) ||
2579+
(type1 == Compiler::StaticType::String && type2 == Compiler::StaticType::Number && arg1->isConst() && !arg1->constValue().isValidNumber()))
2580+
return m_builder.getInt1(false);
2581+
2582+
// Explicitly cast to string
2583+
llvm::Value *string1 = castValue(arg1, Compiler::StaticType::String);
2584+
llvm::Value *string2 = castValue(arg2, Compiler::StaticType::String);
2585+
// Emit a call to strcmp or strcasecmp; a result of 0 means the strings are equal
2586+
llvm::Value *cmp = m_builder.CreateCall(caseSensitive ? resolve_strcmp() : resolve_strcasecmp(), { string1, string2 });
2586+
return m_builder.CreateICmpEQ(cmp, m_builder.getInt32(0));
2587+
}
2588+
}
2589+
25382590
void LLVMCodeBuilder::createSuspend(LLVMCoroutine *coro, llvm::Value *warpArg, llvm::Value *targetVariables)
25392591
{
25402592
if (!m_warp) {
@@ -2769,6 +2821,15 @@ llvm::FunctionCallee LLVMCodeBuilder::resolve_llvm_random_bool()
27692821
return resolveFunction("llvm_random_bool", llvm::FunctionType::get(m_builder.getDoubleTy(), { pointerType, m_builder.getInt1Ty(), m_builder.getInt1Ty() }, false));
27702822
}
27712823

2824+
// Resolves the "strcmp" function for use in generated code.
// The declared signature takes two i8 pointers and returns a 32-bit integer (non-variadic).
// The resolved function is tagged with the ReadOnly attribute before the callee is returned.
// NOTE(review): newer LLVM releases express function-level memory effects through the memory(...) attribute -
// confirm that Attribute::ReadOnly is accepted as a function attribute by the LLVM version in use.
llvm::FunctionCallee LLVMCodeBuilder::resolve_strcmp()
2825+
{
2826+
// Pointer to i8 in address space 0, used for both string parameters
2827+
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);
2827+
llvm::FunctionCallee callee = resolveFunction("strcmp", llvm::FunctionType::get(m_builder.getInt32Ty(), { pointerType, pointerType }, false));
2828+
// llvm::cast asserts that the callee really is an llvm::Function
2829+
llvm::Function *func = llvm::cast<llvm::Function>(callee.getCallee());
2829+
func->addFnAttr(llvm::Attribute::ReadOnly);
2830+
return callee;
2831+
}
2832+
27722833
llvm::FunctionCallee LLVMCodeBuilder::resolve_strcasecmp()
27732834
{
27742835
llvm::Type *pointerType = llvm::PointerType::get(llvm::Type::getInt8Ty(m_llvmCtx), 0);

src/dev/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class LLVMCodeBuilder : public ICodeBuilder
5454
CompilerValue *createCmpGT(CompilerValue *operand1, CompilerValue *operand2) override;
5555
CompilerValue *createCmpLT(CompilerValue *operand1, CompilerValue *operand2) override;
5656

57+
CompilerValue *createStrCmpEQ(CompilerValue *string1, CompilerValue *string2, bool caseSensitive = false) override;
58+
5759
CompilerValue *createAnd(CompilerValue *operand1, CompilerValue *operand2) override;
5860
CompilerValue *createOr(CompilerValue *operand1, CompilerValue *operand2) override;
5961
CompilerValue *createNot(CompilerValue *operand) override;
@@ -156,6 +158,7 @@ class LLVMCodeBuilder : public ICodeBuilder
156158
llvm::Value *getListItemIndex(const LLVMListPtr &listPtr, LLVMRegister *item);
157159
llvm::Value *createValue(LLVMRegister *reg);
158160
llvm::Value *createComparison(LLVMRegister *arg1, LLVMRegister *arg2, Comparison type);
161+
llvm::Value *createStringComparison(LLVMRegister *arg1, LLVMRegister *arg2, bool caseSensitive);
159162

160163
void createSuspend(LLVMCoroutine *coro, llvm::Value *warpArg, llvm::Value *targetVariables);
161164

@@ -190,6 +193,7 @@ class LLVMCodeBuilder : public ICodeBuilder
190193
llvm::FunctionCallee resolve_llvm_random_double();
191194
llvm::FunctionCallee resolve_llvm_random_long();
192195
llvm::FunctionCallee resolve_llvm_random_bool();
196+
llvm::FunctionCallee resolve_strcmp();
193197
llvm::FunctionCallee resolve_strcasecmp();
194198

195199
Target *m_target = nullptr;

src/dev/engine/internal/llvm/llvminstruction.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ struct LLVMInstruction
2525
CmpEQ,
2626
CmpGT,
2727
CmpLT,
28+
StrCmpEQCS, // case sensitive
29+
StrCmpEQCI, // case insensitive
2830
And,
2931
Or,
3032
Not,

test/dev/llvm/llvmcodebuilder_test.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class LLVMCodeBuilderTest : public testing::Test
3636
CmpEQ,
3737
CmpGT,
3838
CmpLT,
39+
StrCmpEQCS,
40+
StrCmpEQCI,
3941
And,
4042
Or,
4143
Not,
@@ -127,6 +129,12 @@ class LLVMCodeBuilderTest : public testing::Test
127129
case OpType::CmpLT:
128130
return m_builder->createCmpLT(arg1, arg2);
129131

132+
case OpType::StrCmpEQCS:
133+
return m_builder->createStrCmpEQ(arg1, arg2, true);
134+
135+
case OpType::StrCmpEQCI:
136+
return m_builder->createStrCmpEQ(arg1, arg2, false);
137+
130138
case OpType::And:
131139
return m_builder->createAnd(arg1, arg2);
132140

@@ -235,6 +243,16 @@ class LLVMCodeBuilderTest : public testing::Test
235243
case OpType::CmpLT:
236244
return v1 < v2;
237245

246+
case OpType::StrCmpEQCS:
247+
return v1.toString() == v2.toString();
248+
249+
case OpType::StrCmpEQCI: {
250+
// TODO: Use a custom function for string comparison
251+
std::string str1 = v1.toString();
252+
std::string str2 = v2.toString();
253+
return strcasecmp(str1.c_str(), str2.c_str()) == 0;
254+
}
255+
238256
case OpType::And:
239257
return v1.toBool() && v2.toBool();
240258

@@ -1089,6 +1107,152 @@ TEST_F(LLVMCodeBuilderTest, GreaterAndLowerThanComparison)
10891107
}
10901108
}
10911109

1110+
// Runs runOpTest() for both string equality operations (StrCmpEQCS and StrCmpEQCI)
// over pairs of numbers, booleans, strings, infinities and NaN, including mixed-type pairs.
// runOpTest() is defined elsewhere in this fixture; presumably it checks the builder's result
// against the fixture's reference implementation - confirm there.
TEST_F(LLVMCodeBuilderTest, StringEqualComparison)
1111+
{
1112+
std::vector<OpType> types = { OpType::StrCmpEQCS, OpType::StrCmpEQCI };
1113+
1114+
for (OpType type : types) {
1115+
// Integer operands
runOpTest(type, 10, 10);
1116+
runOpTest(type, 10, 8);
1117+
runOpTest(type, 8, 10);
1118+
1119+
// Fractional operands
runOpTest(type, -4.25, -4.25);
1120+
runOpTest(type, -4.25, 5.312);
1121+
runOpTest(type, 5.312, -4.25);
1122+
1123+
// Boolean operands
runOpTest(type, true, true);
1124+
runOpTest(type, true, false);
1125+
runOpTest(type, false, true);
1126+
1127+
// Number compared with boolean
runOpTest(type, 1, true);
1128+
runOpTest(type, 1, false);
1129+
1130+
// String operands, including a pair that differs only by letter case
runOpTest(type, "abC def", "abC def");
1131+
runOpTest(type, "abC def", "abc dEf");
1132+
runOpTest(type, "abC def", "ghi Jkl");
1133+
runOpTest(type, "abC def", "hello world");
1134+
1135+
// Blank and empty strings compared with zero
runOpTest(type, " ", "");
1136+
runOpTest(type, " ", "0");
1137+
runOpTest(type, " ", 0);
1138+
runOpTest(type, 0, " ");
1139+
runOpTest(type, "", "0");
1140+
runOpTest(type, "", 0);
1141+
runOpTest(type, 0, "");
1142+
runOpTest(type, "0", 0);
1143+
runOpTest(type, 0, "0");
1144+
1145+
// Numbers compared with numeric strings, with and without surrounding spaces
runOpTest(type, 5.25, "5.25");
1146+
runOpTest(type, "5.25", 5.25);
1147+
runOpTest(type, 5.25, " 5.25");
1148+
runOpTest(type, " 5.25", 5.25);
1149+
runOpTest(type, 5.25, "5.25 ");
1150+
runOpTest(type, "5.25 ", 5.25);
1151+
runOpTest(type, 5.25, " 5.25 ");
1152+
runOpTest(type, " 5.25 ", 5.25);
1153+
runOpTest(type, 5.25, "5.26");
1154+
runOpTest(type, "5.26", 5.25);
1155+
runOpTest(type, "5.25", "5.26");
1156+
runOpTest(type, 5, "5 ");
1157+
runOpTest(type, "5 ", 5);
1158+
runOpTest(type, 0, "1");
1159+
runOpTest(type, "1", 0);
1160+
runOpTest(type, 0, "test");
1161+
runOpTest(type, "test", 0);
1162+
1163+
static const double inf = std::numeric_limits<double>::infinity();
1164+
static const double nan = std::numeric_limits<double>::quiet_NaN();
1165+
1166+
// Infinity and NaN operands
runOpTest(type, inf, inf);
1167+
runOpTest(type, -inf, -inf);
1168+
runOpTest(type, nan, nan);
1169+
runOpTest(type, inf, -inf);
1170+
runOpTest(type, -inf, inf);
1171+
runOpTest(type, inf, nan);
1172+
runOpTest(type, nan, inf);
1173+
runOpTest(type, -inf, nan);
1174+
runOpTest(type, nan, -inf);
1175+
1176+
runOpTest(type, 5, inf);
1177+
runOpTest(type, 5, -inf);
1178+
runOpTest(type, 5, nan);
1179+
runOpTest(type, 0, nan);
1180+
1181+
// Booleans compared with the strings "true"/"false" in lower and upper case
runOpTest(type, true, "true");
1182+
runOpTest(type, "true", true);
1183+
runOpTest(type, false, "false");
1184+
runOpTest(type, "false", false);
1185+
runOpTest(type, false, "true");
1186+
runOpTest(type, "true", false);
1187+
runOpTest(type, true, "false");
1188+
runOpTest(type, "false", true);
1189+
runOpTest(type, true, "TRUE");
1190+
runOpTest(type, "TRUE", true);
1191+
runOpTest(type, false, "FALSE");
1192+
runOpTest(type, "FALSE", false);
1193+
1194+
// Booleans compared with zero-padded numeric strings
runOpTest(type, true, "00001");
1195+
runOpTest(type, "00001", true);
1196+
runOpTest(type, true, "00000");
1197+
runOpTest(type, "00000", true);
1198+
runOpTest(type, false, "00000");
1199+
runOpTest(type, "00000", false);
1200+
1201+
// The strings "true"/"false" compared with the numbers 1 and 0
runOpTest(type, "true", 1);
1202+
runOpTest(type, 1, "true");
1203+
runOpTest(type, "true", 0);
1204+
runOpTest(type, 0, "true");
1205+
runOpTest(type, "false", 0);
1206+
runOpTest(type, 0, "false");
1207+
runOpTest(type, "false", 1);
1208+
runOpTest(type, 1, "false");
1209+
1210+
runOpTest(type, "true", "TRUE");
1211+
runOpTest(type, "true", "FALSE");
1212+
runOpTest(type, "false", "FALSE");
1213+
runOpTest(type, "false", "TRUE");
1214+
1215+
// Booleans compared with infinity and NaN
runOpTest(type, true, inf);
1216+
runOpTest(type, true, -inf);
1217+
runOpTest(type, true, nan);
1218+
runOpTest(type, false, inf);
1219+
runOpTest(type, false, -inf);
1220+
runOpTest(type, false, nan);
1221+
1222+
// Textual spellings of infinity and NaN compared with the numeric values
runOpTest(type, "Infinity", inf);
1223+
runOpTest(type, "Infinity", -inf);
1224+
runOpTest(type, "Infinity", nan);
1225+
runOpTest(type, "infinity", inf);
1226+
runOpTest(type, "infinity", -inf);
1227+
runOpTest(type, "infinity", nan);
1228+
runOpTest(type, "-Infinity", inf);
1229+
runOpTest(type, "-Infinity", -inf);
1230+
runOpTest(type, "-Infinity", nan);
1231+
runOpTest(type, "-infinity", inf);
1232+
runOpTest(type, "-infinity", -inf);
1233+
runOpTest(type, "-infinity", nan);
1234+
runOpTest(type, "NaN", inf);
1235+
runOpTest(type, "NaN", -inf);
1236+
runOpTest(type, "NaN", nan);
1237+
runOpTest(type, "nan", inf);
1238+
runOpTest(type, "nan", -inf);
1239+
runOpTest(type, "nan", nan);
1240+
1241+
// Infinity and NaN compared with other strings
runOpTest(type, inf, "abc");
1242+
runOpTest(type, inf, " ");
1243+
runOpTest(type, inf, "");
1244+
runOpTest(type, inf, "0");
1245+
runOpTest(type, -inf, "abc");
1246+
runOpTest(type, -inf, " ");
1247+
runOpTest(type, -inf, "");
1248+
runOpTest(type, -inf, "0");
1249+
runOpTest(type, nan, "abc");
1250+
runOpTest(type, nan, " ");
1251+
runOpTest(type, nan, "");
1252+
runOpTest(type, nan, "0");
1253+
}
1254+
}
1255+
10921256
TEST_F(LLVMCodeBuilderTest, AndOr)
10931257
{
10941258
std::vector<OpType> opTypes = { OpType::And, OpType::Or };

test/mocks/codebuildermock.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class CodeBuilderMock : public ICodeBuilder
3535
MOCK_METHOD(CompilerValue *, createCmpGT, (CompilerValue *, CompilerValue *), (override));
3636
MOCK_METHOD(CompilerValue *, createCmpLT, (CompilerValue *, CompilerValue *), (override));
3737

38+
MOCK_METHOD(CompilerValue *, createStrCmpEQ, (CompilerValue *, CompilerValue *, bool), (override));
39+
3840
MOCK_METHOD(CompilerValue *, createAnd, (CompilerValue *, CompilerValue *), (override));
3941
MOCK_METHOD(CompilerValue *, createOr, (CompilerValue *, CompilerValue *), (override));
4042
MOCK_METHOD(CompilerValue *, createNot, (CompilerValue *), (override));

0 commit comments

Comments
 (0)