Skip to content

Commit 1c36433

Browse files
committed
LLVMCodeBuilder: Implement loop type analysis for lists
1 parent ea97a92 commit 1c36433

File tree

3 files changed

+907
-40
lines changed

3 files changed

+907
-40
lines changed

src/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 79 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ using namespace libscratchcpp;
2323
static std::unordered_map<ValueType, Compiler::StaticType>
2424
TYPE_MAP = { { ValueType::Number, Compiler::StaticType::Number }, { ValueType::Bool, Compiler::StaticType::Bool }, { ValueType::String, Compiler::StaticType::String } };
2525

26+
static const std::unordered_set<LLVMInstruction::Type>
27+
VAR_LIST_READ_INSTRUCTIONS = { LLVMInstruction::Type::ReadVariable, LLVMInstruction::Type::GetListItem, LLVMInstruction::Type::GetListItemIndex, LLVMInstruction::Type::ListContainsItem };
28+
2629
LLVMCodeBuilder::LLVMCodeBuilder(LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) :
2730
m_ctx(ctx),
2831
m_target(ctx->target()),
@@ -632,7 +635,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
632635
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
633636
varPtr.changed = true;
634637

635-
const bool safe = isVariableTypeSafe(insPtr, varPtr.type);
638+
const bool safe = isVarOrListTypeSafe(insPtr, varPtr.type);
636639

637640
// Initialize stack variable on first assignment
638641
if (!varPtr.onStack) {
@@ -667,7 +670,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
667670
assert(step.args.size() == 0);
668671
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
669672

670-
if (!isVariableTypeSafe(insPtr, varPtr.type))
673+
if (!isVarOrListTypeSafe(insPtr, varPtr.type))
671674
varPtr.type = Compiler::StaticType::Unknown;
672675

673676
step.functionReturnReg->value = varPtr.onStack && !(step.loopCondition && !m_warp) ? varPtr.stackPtr : varPtr.heapPtr;
@@ -693,7 +696,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
693696
case LLVMInstruction::Type::RemoveListItem: {
694697
assert(step.args.size() == 1);
695698
const auto &arg = step.args[0];
696-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
699+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
700+
701+
if (!isVarOrListTypeSafe(insPtr, listPtr.type))
702+
listPtr.type = Compiler::StaticType::Unknown;
697703

698704
// Range check
699705
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
@@ -722,6 +728,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
722728
Compiler::StaticType type = optimizeRegisterType(arg.second);
723729
LLVMListPtr &listPtr = m_listPtrs[step.workList];
724730

731+
const bool safe = isVarOrListTypeSafe(insPtr, listPtr.type);
725732
auto &typeMap = m_scopeLists.back();
726733

727734
if (typeMap.find(&listPtr) == typeMap.cend()) {
@@ -732,6 +739,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
732739
typeMap[&listPtr] = listPtr.type;
733740
}
734741

742+
if (!safe)
743+
listPtr.type = Compiler::StaticType::Unknown;
744+
735745
// Check if enough space is allocated
736746
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
737747
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
@@ -767,6 +777,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
767777
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
768778
LLVMListPtr &listPtr = m_listPtrs[step.workList];
769779

780+
const bool safe = isVarOrListTypeSafe(insPtr, listPtr.type);
770781
auto &typeMap = m_scopeLists.back();
771782

772783
if (typeMap.find(&listPtr) == typeMap.cend()) {
@@ -777,6 +788,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
777788
typeMap[&listPtr] = listPtr.type;
778789
}
779790

791+
if (!safe)
792+
listPtr.type = Compiler::StaticType::Unknown;
793+
780794
llvm::Value *oldAllocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
781795

782796
// Range check
@@ -813,6 +827,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
813827
Compiler::StaticType type = optimizeRegisterType(valueArg.second);
814828
LLVMListPtr &listPtr = m_listPtrs[step.workList];
815829

830+
if (!isVarOrListTypeSafe(insPtr, listPtr.type))
831+
listPtr.type = Compiler::StaticType::Unknown;
832+
816833
// Range check
817834
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
818835
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
@@ -856,7 +873,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
856873
case LLVMInstruction::Type::GetListItem: {
857874
assert(step.args.size() == 1);
858875
const auto &arg = step.args[0];
859-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
876+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
877+
878+
if (!isVarOrListTypeSafe(insPtr, listPtr.type))
879+
listPtr.type = Compiler::StaticType::Unknown;
860880

861881
llvm::Value *min = llvm::ConstantFP::get(m_llvmCtx, llvm::APFloat(0.0));
862882
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
@@ -884,15 +904,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
884904
case LLVMInstruction::Type::GetListItemIndex: {
885905
assert(step.args.size() == 1);
886906
const auto &arg = step.args[0];
887-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
907+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
908+
909+
if (!isVarOrListTypeSafe(insPtr, listPtr.type))
910+
listPtr.type = Compiler::StaticType::Unknown;
911+
888912
step.functionReturnReg->value = m_builder.CreateSIToFP(getListItemIndex(listPtr, arg.second), m_builder.getDoubleTy());
889913
break;
890914
}
891915

892916
case LLVMInstruction::Type::ListContainsItem: {
893917
assert(step.args.size() == 1);
894918
const auto &arg = step.args[0];
895-
const LLVMListPtr &listPtr = m_listPtrs[step.workList];
919+
LLVMListPtr &listPtr = m_listPtrs[step.workList];
920+
921+
if (!isVarOrListTypeSafe(insPtr, listPtr.type))
922+
listPtr.type = Compiler::StaticType::Unknown;
923+
896924
llvm::Value *index = getListItemIndex(listPtr, arg.second);
897925
step.functionReturnReg->value = m_builder.CreateICmpSGT(index, llvm::ConstantInt::get(m_builder.getInt64Ty(), -1, true));
898926
break;
@@ -2258,30 +2286,31 @@ void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
22582286
m_builder.CreateStore(m_builder.getInt1(false), listPtr.dataPtrDirty);
22592287
}
22602288

2261-
bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
2289+
bool LLVMCodeBuilder::isVarOrListTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
22622290
{
22632291
std::unordered_set<LLVMInstruction *> processed;
2264-
return isVariableTypeSafe(ins, expectedType, processed);
2292+
return isVarOrListTypeSafe(ins, expectedType, processed);
22652293
}
22662294

2267-
bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
2295+
bool LLVMCodeBuilder::isVarOrListTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
22682296
{
22692297
/*
22702298
* The main part of the loop type analyzer.
22712299
*
2272-
* This is a recursive function which is called when variable read
2273-
* instruction is created. It checks the last write to the
2274-
* variable in one of the loop scopes.
2300+
* This is a recursive function which is called when a variable
2301+
* or list instruction is created. It checks the last write to
2302+
* the variable or list in one of the loop scopes.
22752303
*
22762304
* If the last write operation writes a value with a different
22772305
* type, it will return false, otherwise true.
22782306
*
2279-
* If the last written value is from a variable, this function
2280-
* is called for it to check its type safety (that's why it is
2281-
* recursive).
2307+
* If the last written value is from a variable or list, this
2308+
* function is called for it to check its type safety (that's
2309+
* why it is recursive).
22822310
*
2283-
* If the variable had a write operation before (in the same,
2284-
* parent or child loop scope), it is checked recursively.
2311+
* If the variable or list had a write operation before (in
2312+
* the same, parent or child loop scope), it is checked
2313+
* recursively.
22852314
*/
22862315

22872316
if (!ins)
@@ -2307,7 +2336,9 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23072336
processed.insert(ins.get());
23082337

23092338
assert(std::find(m_instructions.begin(), m_instructions.end(), ins) != m_instructions.end());
2310-
const LLVMVariablePtr &varPtr = m_variablePtrs.at(ins->workVariable);
2339+
const LLVMVariablePtr *varPtr = ins->workVariable ? &m_variablePtrs.at(ins->workVariable) : nullptr;
2340+
const LLVMListPtr *listPtr = ins->workList ? &m_listPtrs.at(ins->workList) : nullptr;
2341+
assert((varPtr || listPtr) && !(varPtr && listPtr));
23112342
auto scope = ins->loopScope;
23122343

23132344
// If we aren't in a loop, we're safe
@@ -2319,21 +2350,23 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23192350
return false;
23202351

23212352
std::shared_ptr<LLVMInstruction> write;
2353+
const auto &instructions = varPtr ? m_variableInstructions : m_listInstructions;
23222354

23232355
// Find this instruction
2324-
auto it = std::find(m_variableInstructions.begin(), m_variableInstructions.end(), ins);
2325-
assert(it != m_variableInstructions.end());
2356+
auto it = std::find(instructions.begin(), instructions.end(), ins);
2357+
assert(it != instructions.end());
23262358

23272359
// Find previous write instruction in this, parent or child loop scope
2328-
size_t index = it - m_variableInstructions.begin();
2360+
size_t index = it - instructions.begin();
23292361

23302362
if (index > 0) {
23312363
bool found = false;
23322364

23332365
do {
23342366
index--;
2335-
write = m_variableInstructions[index];
2336-
found = (write->loopScope && write->type == LLVMInstruction::Type::WriteVariable && write->workVariable == ins->workVariable);
2367+
write = instructions[index];
2368+
const bool isWrite = (VAR_LIST_READ_INSTRUCTIONS.find(write->type) == VAR_LIST_READ_INSTRUCTIONS.cend());
2369+
found = (write->loopScope && isWrite && write->workVariable == ins->workVariable);
23372370
} while (index > 0 && !found);
23382371

23392372
if (found) {
@@ -2355,13 +2388,15 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23552388
// If there was a write operation before this instruction (in this, parent or child scope), check it
23562389
if (parentScope) {
23572390
if (parentScope == scope)
2358-
return isVariableWriteResultTypeSafe(write, expectedType, true, processed);
2391+
return isVarOrListWriteResultTypeSafe(write, expectedType, true, processed);
23592392
else
2360-
return isVariableTypeSafe(write, expectedType, processed);
2393+
return isVarOrListTypeSafe(write, expectedType, processed);
23612394
}
23622395
}
23632396
}
23642397

2398+
const auto &loopWrites = varPtr ? varPtr->loopVariableWrites : listPtr->loopListWrites;
2399+
23652400
// Get last write operation
23662401
write = nullptr;
23672402

@@ -2374,9 +2409,9 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23742409

23752410
// Find last loop scope (may be a parent or child scope)
23762411
while (checkScope) {
2377-
auto it = varPtr.loopVariableWrites.find(checkScope);
2412+
auto it = loopWrites.find(checkScope);
23782413

2379-
if (it != varPtr.loopVariableWrites.cend()) {
2414+
if (it != loopWrites.cend()) {
23802415
assert(!it->second.empty());
23812416
write = it->second.back();
23822417
}
@@ -2393,29 +2428,36 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23932428

23942429
bool safe = true;
23952430

2396-
if (ins->type == LLVMInstruction::Type::WriteVariable)
2397-
safe = isVariableWriteResultTypeSafe(ins, expectedType, false, processed);
2431+
if (VAR_LIST_READ_INSTRUCTIONS.find(ins->type) == VAR_LIST_READ_INSTRUCTIONS.cend()) // write
2432+
safe = isVarOrListWriteResultTypeSafe(ins, expectedType, false, processed);
23982433

23992434
if (safe)
2400-
return isVariableWriteResultTypeSafe(write, expectedType, false, processed);
2435+
return isVarOrListWriteResultTypeSafe(write, expectedType, false, processed);
24012436
else
24022437
return false;
24032438
}
24042439

2405-
bool LLVMCodeBuilder::isVariableWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
2440+
bool LLVMCodeBuilder::isVarOrListWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
24062441
const
24072442
{
2408-
const LLVMVariablePtr &varPtr = m_variablePtrs.at(ins->workVariable);
2443+
const LLVMVariablePtr *varPtr = ins->workVariable ? &m_variablePtrs.at(ins->workVariable) : nullptr;
2444+
const LLVMListPtr *listPtr = ins->workList ? &m_listPtrs.at(ins->workList) : nullptr;
2445+
assert((varPtr || listPtr) && !(varPtr && listPtr));
24092446

24102447
// If the write operation writes the value of another variable or list item, recursively check its type safety
2411-
// TODO: Check get list item instruction
2412-
auto argIns = ins->args[0].second->instruction;
2448+
const auto arg = ins->args.back().second; // value is always the last argument
2449+
auto argIns = arg->instruction;
24132450

2414-
if (argIns && argIns->type == LLVMInstruction::Type::ReadVariable)
2415-
return isVariableTypeSafe(argIns, expectedType, processed);
2451+
if (argIns && (argIns->type == LLVMInstruction::Type::ReadVariable || argIns->type == LLVMInstruction::Type::GetListItem))
2452+
return isVarOrListTypeSafe(argIns, expectedType, processed);
24162453

24172454
// Check written type
2418-
return optimizeRegisterType(ins->args[0].second) == expectedType && (varPtr.type == expectedType || ignoreSavedType);
2455+
const bool typeMatches = (optimizeRegisterType(arg) == expectedType);
2456+
2457+
if (varPtr)
2458+
return typeMatches && (varPtr->type == expectedType || ignoreSavedType);
2459+
else
2460+
return typeMatches && (listPtr->type == expectedType || ignoreSavedType);
24192461
}
24202462

24212463
LLVMRegister *LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args)

src/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ class LLVMCodeBuilder : public ICodeBuilder
149149
void reloadVariables(llvm::Value *targetVariables);
150150
void reloadLists();
151151
void updateListDataPtr(const LLVMListPtr &listPtr);
152-
bool isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const;
153-
bool isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const;
154-
bool isVariableWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed) const;
152+
bool isVarOrListTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const;
153+
bool isVarOrListTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const;
154+
bool isVarOrListWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed) const;
155155

156156
LLVMRegister *createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args);
157157
LLVMRegister *createOp(LLVMInstruction::Type type, Compiler::StaticType retType, const Compiler::ArgTypes &argTypes = {}, const Compiler::Args &args = {});

0 commit comments

Comments
 (0)