Skip to content

Commit 95867e7

Browse files
committed
LLVMCodeBuilder: Implement loop type analysis
1 parent b73c5e1 commit 95867e7

File tree

3 files changed

+847
-3
lines changed

3 files changed

+847
-3
lines changed

src/engine/internal/llvm/llvmcodebuilder.cpp

Lines changed: 171 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
632632
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
633633
varPtr.changed = true;
634634

635+
const bool safe = isVariableTypeSafe(insPtr, varPtr.type);
636+
635637
// Initialize stack variable on first assignment
636638
if (!varPtr.onStack) {
637639
varPtr.onStack = true;
@@ -652,6 +654,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
652654
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(mappedType)), typeField);
653655
}
654656

657+
if (!safe)
658+
varPtr.type = Compiler::StaticType::Unknown;
659+
655660
createValueStore(arg.second, varPtr.stackPtr, type, varPtr.type);
656661
varPtr.type = type;
657662
m_scopeVariables.back()[&varPtr] = varPtr.type;
@@ -660,7 +665,11 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
660665

661666
case LLVMInstruction::Type::ReadVariable: {
662667
assert(step.args.size() == 0);
663-
const LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
668+
LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable];
669+
670+
if (!isVariableTypeSafe(insPtr, varPtr.type))
671+
varPtr.type = Compiler::StaticType::Unknown;
672+
664673
step.functionReturnReg->value = varPtr.onStack ? varPtr.stackPtr : varPtr.heapPtr;
665674
step.functionReturnReg->setType(varPtr.type);
666675
break;
@@ -2091,7 +2100,7 @@ llvm::Constant *LLVMCodeBuilder::castConstValue(const Value &value, Compiler::St
20912100
}
20922101
}
20932102

2094-
Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType(LLVMRegister *reg)
2103+
Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType(LLVMRegister *reg) const
20952104
{
20962105
Compiler::StaticType ret = reg->type();
20972106

@@ -2213,6 +2222,166 @@ void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
22132222
m_builder.CreateStore(m_builder.getInt1(false), listPtr.dataPtrDirty);
22142223
}
22152224

2225+
bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
2226+
{
2227+
std::unordered_set<LLVMInstruction *> processed;
2228+
return isVariableTypeSafe(ins, expectedType, processed);
2229+
}
2230+
2231+
bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
2232+
{
2233+
/*
2234+
* The main part of the loop type analyzer.
2235+
*
2236+
* This is a recursive function which is called when variable read
2237+
* instruction is created. It checks the last write to the
2238+
* variable in one of the loop scopes.
2239+
*
2240+
* If the last write operation writes a value with a different
2241+
* type, it will return false, otherwise true.
2242+
*
2243+
* If the last written value is from a variable, this function
2244+
* is called for it to check its type safety (that's why it is
2245+
* recursive).
2246+
*
2247+
* If the variable had a write operation before (in the same,
2248+
* parent or child loop scope), it is checked recursively.
2249+
*/
2250+
2251+
if (!ins)
2252+
return false;
2253+
2254+
/*
2255+
* If we are processing something that has been already
2256+
* processed, give up to avoid infinite recursion.
2257+
*
2258+
* This can happen in edge cases like this:
2259+
* var = var
2260+
*
2261+
* or this:
2262+
* x = y
2263+
* y = x
2264+
*
2265+
* This code isn't considered valid, so don't bother
2266+
* optimizing.
2267+
*/
2268+
if (processed.find(ins.get()) != processed.cend())
2269+
return false;
2270+
2271+
processed.insert(ins.get());
2272+
2273+
assert(std::find(m_instructions.begin(), m_instructions.end(), ins) != m_instructions.end());
2274+
const LLVMVariablePtr &varPtr = m_variablePtrs.at(ins->workVariable);
2275+
auto scope = ins->loopScope;
2276+
2277+
// If we aren't in a loop, we're safe
2278+
if (!scope)
2279+
return true;
2280+
2281+
// If the loop scope contains a suspend and this is a non-warp script, the type may change between suspend and resume
2282+
if (scope->containsYield && !m_warp)
2283+
return false;
2284+
2285+
std::shared_ptr<LLVMInstruction> write;
2286+
2287+
// Find this instruction
2288+
auto it = std::find(m_variableInstructions.begin(), m_variableInstructions.end(), ins);
2289+
assert(it != m_variableInstructions.end());
2290+
2291+
// Find previous write instruction in this, parent or child loop scope
2292+
size_t index = it - m_variableInstructions.begin();
2293+
2294+
if (index > 0) {
2295+
bool found = false;
2296+
2297+
do {
2298+
index--;
2299+
write = m_variableInstructions[index];
2300+
found = (write->loopScope && write->type == LLVMInstruction::Type::WriteVariable && write->workVariable == ins->workVariable);
2301+
} while (index > 0 && !found);
2302+
2303+
if (found) {
2304+
// Check if the write operation is in this or child scope
2305+
auto parentScope = write->loopScope;
2306+
2307+
while (parentScope && parentScope != scope)
2308+
parentScope = parentScope->parentScope;
2309+
2310+
if (!parentScope) {
2311+
// Check if the write operation is in any of the parent scopes
2312+
parentScope = scope;
2313+
2314+
do {
2315+
parentScope = parentScope->parentScope;
2316+
} while (parentScope && parentScope != write->loopScope);
2317+
}
2318+
2319+
// If there was a write operation before this instruction (in this, parent or child scope), check it
2320+
if (parentScope) {
2321+
if (parentScope == scope)
2322+
return isVariableWriteResultTypeSafe(write, expectedType, true, processed);
2323+
else
2324+
return isVariableTypeSafe(write, expectedType, processed);
2325+
}
2326+
}
2327+
}
2328+
2329+
// Get last write operation
2330+
write = nullptr;
2331+
2332+
// Find root loop scope
2333+
auto checkScope = scope;
2334+
2335+
while (checkScope->parentScope) {
2336+
checkScope = checkScope->parentScope;
2337+
}
2338+
2339+
// Find last loop scope (may be a parent or child scope)
2340+
while (checkScope) {
2341+
auto it = varPtr.loopVariableWrites.find(checkScope);
2342+
2343+
if (it != varPtr.loopVariableWrites.cend()) {
2344+
assert(!it->second.empty());
2345+
write = it->second.back();
2346+
}
2347+
2348+
if (checkScope->childScopes.empty())
2349+
checkScope = nullptr;
2350+
else
2351+
checkScope = checkScope->childScopes.back();
2352+
}
2353+
2354+
// If there aren't any write operations, we're safe
2355+
if (!write)
2356+
return true;
2357+
2358+
bool safe = true;
2359+
2360+
if (ins->type == LLVMInstruction::Type::WriteVariable)
2361+
safe = isVariableWriteResultTypeSafe(ins, expectedType, false, processed);
2362+
2363+
if (safe)
2364+
return isVariableWriteResultTypeSafe(write, expectedType, false, processed);
2365+
else
2366+
return false;
2367+
}
2368+
2369+
bool LLVMCodeBuilder::isVariableWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
2370+
const
2371+
{
2372+
const LLVMVariablePtr &varPtr = m_variablePtrs.at(ins->workVariable);
2373+
2374+
// If the write operation writes the value of another variable, recursively check its type safety
2375+
// TODO: Check get list item instruction
2376+
auto argIns = ins->args[0].second->instruction;
2377+
2378+
if (argIns && argIns->type == LLVMInstruction::Type::ReadVariable)
2379+
return isVariableTypeSafe(argIns, expectedType, processed);
2380+
2381+
// Check written type
2382+
return optimizeRegisterType(ins->args[0].second) == expectedType && (varPtr.type == expectedType || ignoreSavedType);
2383+
}
2384+
22162385
LLVMRegister *LLVMCodeBuilder::createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args)
22172386
{
22182387
return createOp({ type, currentLoopScope() }, retType, argType, args);

src/engine/internal/llvm/llvmcodebuilder.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class LLVMCodeBuilder : public ICodeBuilder
137137
llvm::Value *castValue(LLVMRegister *reg, Compiler::StaticType targetType);
138138
llvm::Value *castRawValue(LLVMRegister *reg, Compiler::StaticType targetType);
139139
llvm::Constant *castConstValue(const Value &value, Compiler::StaticType targetType);
140-
Compiler::StaticType optimizeRegisterType(LLVMRegister *reg);
140+
Compiler::StaticType optimizeRegisterType(LLVMRegister *reg) const;
141141
llvm::Type *getType(Compiler::StaticType type);
142142
Compiler::StaticType getProcedureArgType(BlockPrototype::ArgType type);
143143
llvm::Value *isNaN(llvm::Value *num);
@@ -149,6 +149,9 @@ class LLVMCodeBuilder : public ICodeBuilder
149149
void reloadVariables(llvm::Value *targetVariables);
150150
void reloadLists();
151151
void updateListDataPtr(const LLVMListPtr &listPtr);
152+
bool isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const;
153+
bool isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const;
154+
bool isVariableWriteResultTypeSafe(std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed) const;
152155

153156
LLVMRegister *createOp(LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args);
154157
LLVMRegister *createOp(LLVMInstruction::Type type, Compiler::StaticType retType, const Compiler::ArgTypes &argTypes = {}, const Compiler::Args &args = {});

0 commit comments

Comments
 (0)