@@ -632,6 +632,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
632632 LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable ];
633633 varPtr.changed = true ;
634634
635+ const bool safe = isVariableTypeSafe (insPtr, varPtr.type );
636+
635637 // Initialize stack variable on first assignment
636638 if (!varPtr.onStack ) {
637639 varPtr.onStack = true ;
@@ -652,6 +654,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
652654 m_builder.CreateStore (m_builder.getInt32 (static_cast <uint32_t >(mappedType)), typeField);
653655 }
654656
657+ if (!safe)
658+ varPtr.type = Compiler::StaticType::Unknown;
659+
655660 createValueStore (arg.second , varPtr.stackPtr , type, varPtr.type );
656661 varPtr.type = type;
657662 m_scopeVariables.back ()[&varPtr] = varPtr.type ;
@@ -660,7 +665,11 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
660665
661666 case LLVMInstruction::Type::ReadVariable: {
662667 assert (step.args .size () == 0 );
663- const LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable ];
668+ LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable ];
669+
670+ if (!isVariableTypeSafe (insPtr, varPtr.type ))
671+ varPtr.type = Compiler::StaticType::Unknown;
672+
664673 step.functionReturnReg ->value = varPtr.onStack ? varPtr.stackPtr : varPtr.heapPtr ;
665674 step.functionReturnReg ->setType (varPtr.type );
666675 break ;
@@ -2091,7 +2100,7 @@ llvm::Constant *LLVMCodeBuilder::castConstValue(const Value &value, Compiler::St
20912100 }
20922101}
20932102
2094- Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType (LLVMRegister *reg)
2103+ Compiler::StaticType LLVMCodeBuilder::optimizeRegisterType (LLVMRegister *reg) const
20952104{
20962105 Compiler::StaticType ret = reg->type ();
20972106
@@ -2213,6 +2222,166 @@ void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
22132222 m_builder.CreateStore (m_builder.getInt1 (false ), listPtr.dataPtrDirty );
22142223}
22152224
2225+ bool LLVMCodeBuilder::isVariableTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
2226+ {
2227+ std::unordered_set<LLVMInstruction *> processed;
2228+ return isVariableTypeSafe (ins, expectedType, processed);
2229+ }
2230+
2231+ bool LLVMCodeBuilder::isVariableTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
2232+ {
2233+ /*
2234+ * The main part of the loop type analyzer.
2235+ *
2236+ * This is a recursive function which is called when variable read
2237+ * instruction is created. It checks the last write to the
2238+ * variable in one of the loop scopes.
2239+ *
2240+ * If the last write operation writes a value with a different
2241+ * type, it will return false, otherwise true.
2242+ *
2243+ * If the last written value is from a variable, this function
2244+ * is called for it to check its type safety (that's why it is
2245+ * recursive).
2246+ *
2247+ * If the variable had a write operation before (in the same,
2248+ * parent or child loop scope), it is checked recursively.
2249+ */
2250+
2251+ if (!ins)
2252+ return false ;
2253+
2254+ /*
2255+ * If we are processing something that has been already
2256+ * processed, give up to avoid infinite recursion.
2257+ *
2258+ * This can happen in edge cases like this:
2259+ * var = var
2260+ *
2261+ * or this:
2262+ * x = y
2263+ * y = x
2264+ *
2265+ * This code isn't considered valid, so don't bother
2266+ * optimizing.
2267+ */
2268+ if (processed.find (ins.get ()) != processed.cend ())
2269+ return false ;
2270+
2271+ processed.insert (ins.get ());
2272+
2273+ assert (std::find (m_instructions.begin (), m_instructions.end (), ins) != m_instructions.end ());
2274+ const LLVMVariablePtr &varPtr = m_variablePtrs.at (ins->workVariable );
2275+ auto scope = ins->loopScope ;
2276+
2277+ // If we aren't in a loop, we're safe
2278+ if (!scope)
2279+ return true ;
2280+
2281+ // If the loop scope contains a suspend and this is a non-warp script, the type may change between suspend and resume
2282+ if (scope->containsYield && !m_warp)
2283+ return false ;
2284+
2285+ std::shared_ptr<LLVMInstruction> write;
2286+
2287+ // Find this instruction
2288+ auto it = std::find (m_variableInstructions.begin (), m_variableInstructions.end (), ins);
2289+ assert (it != m_variableInstructions.end ());
2290+
2291+ // Find previous write instruction in this, parent or child loop scope
2292+ size_t index = it - m_variableInstructions.begin ();
2293+
2294+ if (index > 0 ) {
2295+ bool found = false ;
2296+
2297+ do {
2298+ index--;
2299+ write = m_variableInstructions[index];
2300+ found = (write->loopScope && write->type == LLVMInstruction::Type::WriteVariable && write->workVariable == ins->workVariable );
2301+ } while (index > 0 && !found);
2302+
2303+ if (found) {
2304+ // Check if the write operation is in this or child scope
2305+ auto parentScope = write->loopScope ;
2306+
2307+ while (parentScope && parentScope != scope)
2308+ parentScope = parentScope->parentScope ;
2309+
2310+ if (!parentScope) {
2311+ // Check if the write operation is in any of the parent scopes
2312+ parentScope = scope;
2313+
2314+ do {
2315+ parentScope = parentScope->parentScope ;
2316+ } while (parentScope && parentScope != write->loopScope );
2317+ }
2318+
2319+ // If there was a write operation before this instruction (in this, parent or child scope), check it
2320+ if (parentScope) {
2321+ if (parentScope == scope)
2322+ return isVariableWriteResultTypeSafe (write, expectedType, true , processed);
2323+ else
2324+ return isVariableTypeSafe (write, expectedType, processed);
2325+ }
2326+ }
2327+ }
2328+
2329+ // Get last write operation
2330+ write = nullptr ;
2331+
2332+ // Find root loop scope
2333+ auto checkScope = scope;
2334+
2335+ while (checkScope->parentScope ) {
2336+ checkScope = checkScope->parentScope ;
2337+ }
2338+
2339+ // Find last loop scope (may be a parent or child scope)
2340+ while (checkScope) {
2341+ auto it = varPtr.loopVariableWrites .find (checkScope);
2342+
2343+ if (it != varPtr.loopVariableWrites .cend ()) {
2344+ assert (!it->second .empty ());
2345+ write = it->second .back ();
2346+ }
2347+
2348+ if (checkScope->childScopes .empty ())
2349+ checkScope = nullptr ;
2350+ else
2351+ checkScope = checkScope->childScopes .back ();
2352+ }
2353+
2354+ // If there aren't any write operations, we're safe
2355+ if (!write)
2356+ return true ;
2357+
2358+ bool safe = true ;
2359+
2360+ if (ins->type == LLVMInstruction::Type::WriteVariable)
2361+ safe = isVariableWriteResultTypeSafe (ins, expectedType, false , processed);
2362+
2363+ if (safe)
2364+ return isVariableWriteResultTypeSafe (write, expectedType, false , processed);
2365+ else
2366+ return false ;
2367+ }
2368+
2369+ bool LLVMCodeBuilder::isVariableWriteResultTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
2370+ const
2371+ {
2372+ const LLVMVariablePtr &varPtr = m_variablePtrs.at (ins->workVariable );
2373+
2374+ // If the write operation writes the value of another variable, recursively check its type safety
2375+ // TODO: Check get list item instruction
2376+ auto argIns = ins->args [0 ].second ->instruction ;
2377+
2378+ if (argIns && argIns->type == LLVMInstruction::Type::ReadVariable)
2379+ return isVariableTypeSafe (argIns, expectedType, processed);
2380+
2381+ // Check written type
2382+ return optimizeRegisterType (ins->args [0 ].second ) == expectedType && (varPtr.type == expectedType || ignoreSavedType);
2383+ }
2384+
22162385LLVMRegister *LLVMCodeBuilder::createOp (LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args)
22172386{
22182387 return createOp ({ type, currentLoopScope () }, retType, argType, args);
0 commit comments