@@ -23,6 +23,9 @@ using namespace libscratchcpp;
2323static std::unordered_map<ValueType, Compiler::StaticType>
2424 TYPE_MAP = { { ValueType::Number, Compiler::StaticType::Number }, { ValueType::Bool, Compiler::StaticType::Bool }, { ValueType::String, Compiler::StaticType::String } };
2525
26+ static const std::unordered_set<LLVMInstruction::Type>
27+ VAR_LIST_READ_INSTRUCTIONS = { LLVMInstruction::Type::ReadVariable, LLVMInstruction::Type::GetListItem, LLVMInstruction::Type::GetListItemIndex, LLVMInstruction::Type::ListContainsItem };
28+
2629LLVMCodeBuilder::LLVMCodeBuilder (LLVMCompilerContext *ctx, BlockPrototype *procedurePrototype) :
2730 m_ctx(ctx),
2831 m_target(ctx->target ()),
@@ -632,7 +635,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
632635 LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable ];
633636 varPtr.changed = true ;
634637
635- const bool safe = isVariableTypeSafe (insPtr, varPtr.type );
638+ const bool safe = isVarOrListTypeSafe (insPtr, varPtr.type );
636639
637640 // Initialize stack variable on first assignment
638641 if (!varPtr.onStack ) {
@@ -667,7 +670,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
667670 assert (step.args .size () == 0 );
668671 LLVMVariablePtr &varPtr = m_variablePtrs[step.workVariable ];
669672
670- if (!isVariableTypeSafe (insPtr, varPtr.type ))
673+ if (!isVarOrListTypeSafe (insPtr, varPtr.type ))
671674 varPtr.type = Compiler::StaticType::Unknown;
672675
673676 step.functionReturnReg ->value = varPtr.onStack && !(step.loopCondition && !m_warp) ? varPtr.stackPtr : varPtr.heapPtr ;
@@ -693,7 +696,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
693696 case LLVMInstruction::Type::RemoveListItem: {
694697 assert (step.args .size () == 1 );
695698 const auto &arg = step.args [0 ];
696- const LLVMListPtr &listPtr = m_listPtrs[step.workList ];
699+ LLVMListPtr &listPtr = m_listPtrs[step.workList ];
700+
701+ if (!isVarOrListTypeSafe (insPtr, listPtr.type ))
702+ listPtr.type = Compiler::StaticType::Unknown;
697703
698704 // Range check
699705 llvm::Value *min = llvm::ConstantFP::get (m_llvmCtx, llvm::APFloat (0.0 ));
@@ -722,6 +728,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
722728 Compiler::StaticType type = optimizeRegisterType (arg.second );
723729 LLVMListPtr &listPtr = m_listPtrs[step.workList ];
724730
731+ const bool safe = isVarOrListTypeSafe (insPtr, listPtr.type );
725732 auto &typeMap = m_scopeLists.back ();
726733
727734 if (typeMap.find (&listPtr) == typeMap.cend ()) {
@@ -732,6 +739,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
732739 typeMap[&listPtr] = listPtr.type ;
733740 }
734741
742+ if (!safe)
743+ listPtr.type = Compiler::StaticType::Unknown;
744+
735745 // Check if enough space is allocated
736746 llvm::Value *allocatedSize = m_builder.CreateLoad (m_builder.getInt64Ty (), listPtr.allocatedSizePtr );
737747 llvm::Value *size = m_builder.CreateLoad (m_builder.getInt64Ty (), listPtr.sizePtr );
@@ -767,6 +777,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
767777 Compiler::StaticType type = optimizeRegisterType (valueArg.second );
768778 LLVMListPtr &listPtr = m_listPtrs[step.workList ];
769779
780+ const bool safe = isVarOrListTypeSafe (insPtr, listPtr.type );
770781 auto &typeMap = m_scopeLists.back ();
771782
772783 if (typeMap.find (&listPtr) == typeMap.cend ()) {
@@ -777,6 +788,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
777788 typeMap[&listPtr] = listPtr.type ;
778789 }
779790
791+ if (!safe)
792+ listPtr.type = Compiler::StaticType::Unknown;
793+
780794 llvm::Value *oldAllocatedSize = m_builder.CreateLoad (m_builder.getInt64Ty (), listPtr.allocatedSizePtr );
781795
782796 // Range check
@@ -813,6 +827,9 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
813827 Compiler::StaticType type = optimizeRegisterType (valueArg.second );
814828 LLVMListPtr &listPtr = m_listPtrs[step.workList ];
815829
830+ if (!isVarOrListTypeSafe (insPtr, listPtr.type ))
831+ listPtr.type = Compiler::StaticType::Unknown;
832+
816833 // Range check
817834 llvm::Value *min = llvm::ConstantFP::get (m_llvmCtx, llvm::APFloat (0.0 ));
818835 llvm::Value *size = m_builder.CreateLoad (m_builder.getInt64Ty (), listPtr.sizePtr );
@@ -856,7 +873,10 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
856873 case LLVMInstruction::Type::GetListItem: {
857874 assert (step.args .size () == 1 );
858875 const auto &arg = step.args [0 ];
859- const LLVMListPtr &listPtr = m_listPtrs[step.workList ];
876+ LLVMListPtr &listPtr = m_listPtrs[step.workList ];
877+
878+ if (!isVarOrListTypeSafe (insPtr, listPtr.type ))
879+ listPtr.type = Compiler::StaticType::Unknown;
860880
861881 llvm::Value *min = llvm::ConstantFP::get (m_llvmCtx, llvm::APFloat (0.0 ));
862882 llvm::Value *size = m_builder.CreateLoad (m_builder.getInt64Ty (), listPtr.sizePtr );
@@ -884,15 +904,23 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
884904 case LLVMInstruction::Type::GetListItemIndex: {
885905 assert (step.args .size () == 1 );
886906 const auto &arg = step.args [0 ];
887- const LLVMListPtr &listPtr = m_listPtrs[step.workList ];
907+ LLVMListPtr &listPtr = m_listPtrs[step.workList ];
908+
909+ if (!isVarOrListTypeSafe (insPtr, listPtr.type ))
910+ listPtr.type = Compiler::StaticType::Unknown;
911+
888912 step.functionReturnReg ->value = m_builder.CreateSIToFP (getListItemIndex (listPtr, arg.second ), m_builder.getDoubleTy ());
889913 break ;
890914 }
891915
892916 case LLVMInstruction::Type::ListContainsItem: {
893917 assert (step.args .size () == 1 );
894918 const auto &arg = step.args [0 ];
895- const LLVMListPtr &listPtr = m_listPtrs[step.workList ];
919+ LLVMListPtr &listPtr = m_listPtrs[step.workList ];
920+
921+ if (!isVarOrListTypeSafe (insPtr, listPtr.type ))
922+ listPtr.type = Compiler::StaticType::Unknown;
923+
896924 llvm::Value *index = getListItemIndex (listPtr, arg.second );
897925 step.functionReturnReg ->value = m_builder.CreateICmpSGT (index, llvm::ConstantInt::get (m_builder.getInt64Ty (), -1 , true ));
898926 break ;
@@ -2258,30 +2286,31 @@ void LLVMCodeBuilder::updateListDataPtr(const LLVMListPtr &listPtr)
22582286 m_builder.CreateStore (m_builder.getInt1 (false ), listPtr.dataPtrDirty );
22592287}
22602288
2261- bool LLVMCodeBuilder::isVariableTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
2289+ bool LLVMCodeBuilder::isVarOrListTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType) const
22622290{
22632291 std::unordered_set<LLVMInstruction *> processed;
2264- return isVariableTypeSafe (ins, expectedType, processed);
2292+ return isVarOrListTypeSafe (ins, expectedType, processed);
22652293}
22662294
2267- bool LLVMCodeBuilder::isVariableTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
2295+ bool LLVMCodeBuilder::isVarOrListTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, std::unordered_set<LLVMInstruction *> &processed) const
22682296{
22692297 /*
22702298 * The main part of the loop type analyzer.
22712299 *
2272- * This is a recursive function which is called when variable read
2273- * instruction is created. It checks the last write to the
2274- * variable in one of the loop scopes.
2300+ * This is a recursive function which is called when variable
2301+ * or list instruction is created. It checks the last write to
2302+ * the variable or list in one of the loop scopes.
22752303 *
22762304 * If the last write operation writes a value with a different
22772305 * type, it will return false, otherwise true.
22782306 *
2279- * If the last written value is from a variable, this function
2280- * is called for it to check its type safety (that's why it is
2281- * recursive).
2307+ * If the last written value is from a variable or list , this
2308+ * function is called for it to check its type safety (that's
2309+ * why it is recursive).
22822310 *
2283- * If the variable had a write operation before (in the same,
2284- * parent or child loop scope), it is checked recursively.
2311+ * If the variable or list had a write operation before (in
2312+ * the same, parent or child loop scope), it is checked
2313+ * recursively.
22852314 */
22862315
22872316 if (!ins)
@@ -2307,7 +2336,9 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23072336 processed.insert (ins.get ());
23082337
23092338 assert (std::find (m_instructions.begin (), m_instructions.end (), ins) != m_instructions.end ());
2310- const LLVMVariablePtr &varPtr = m_variablePtrs.at (ins->workVariable );
2339+ const LLVMVariablePtr *varPtr = ins->workVariable ? &m_variablePtrs.at (ins->workVariable ) : nullptr ;
2340+ const LLVMListPtr *listPtr = ins->workList ? &m_listPtrs.at (ins->workList ) : nullptr ;
2341+ assert ((varPtr || listPtr) && !(varPtr && listPtr));
23112342 auto scope = ins->loopScope ;
23122343
23132344 // If we aren't in a loop, we're safe
@@ -2319,21 +2350,23 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23192350 return false ;
23202351
23212352 std::shared_ptr<LLVMInstruction> write;
2353+ const auto &instructions = varPtr ? m_variableInstructions : m_listInstructions;
23222354
23232355 // Find this instruction
2324- auto it = std::find (m_variableInstructions .begin (), m_variableInstructions .end (), ins);
2325- assert (it != m_variableInstructions .end ());
2356+ auto it = std::find (instructions .begin (), instructions .end (), ins);
2357+ assert (it != instructions .end ());
23262358
23272359 // Find previous write instruction in this, parent or child loop scope
2328- size_t index = it - m_variableInstructions .begin ();
2360+ size_t index = it - instructions .begin ();
23292361
23302362 if (index > 0 ) {
23312363 bool found = false ;
23322364
23332365 do {
23342366 index--;
2335- write = m_variableInstructions[index];
2336- found = (write->loopScope && write->type == LLVMInstruction::Type::WriteVariable && write->workVariable == ins->workVariable );
2367+ write = instructions[index];
2368+ const bool isWrite = (VAR_LIST_READ_INSTRUCTIONS.find (write->type ) == VAR_LIST_READ_INSTRUCTIONS.cend ());
2369+ found = (write->loopScope && isWrite && write->workVariable == ins->workVariable );
23372370 } while (index > 0 && !found);
23382371
23392372 if (found) {
@@ -2355,13 +2388,15 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23552388 // If there was a write operation before this instruction (in this, parent or child scope), check it
23562389 if (parentScope) {
23572390 if (parentScope == scope)
2358- return isVariableWriteResultTypeSafe (write, expectedType, true , processed);
2391+ return isVarOrListWriteResultTypeSafe (write, expectedType, true , processed);
23592392 else
2360- return isVariableTypeSafe (write, expectedType, processed);
2393+ return isVarOrListTypeSafe (write, expectedType, processed);
23612394 }
23622395 }
23632396 }
23642397
2398+ const auto &loopWrites = varPtr ? varPtr->loopVariableWrites : listPtr->loopListWrites ;
2399+
23652400 // Get last write operation
23662401 write = nullptr ;
23672402
@@ -2374,9 +2409,9 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23742409
23752410 // Find last loop scope (may be a parent or child scope)
23762411 while (checkScope) {
2377- auto it = varPtr. loopVariableWrites .find (checkScope);
2412+ auto it = loopWrites .find (checkScope);
23782413
2379- if (it != varPtr. loopVariableWrites .cend ()) {
2414+ if (it != loopWrites .cend ()) {
23802415 assert (!it->second .empty ());
23812416 write = it->second .back ();
23822417 }
@@ -2393,29 +2428,36 @@ bool LLVMCodeBuilder::isVariableTypeSafe(std::shared_ptr<LLVMInstruction> ins, C
23932428
23942429 bool safe = true ;
23952430
2396- if (ins->type == LLVMInstruction::Type::WriteVariable)
2397- safe = isVariableWriteResultTypeSafe (ins, expectedType, false , processed);
2431+ if (VAR_LIST_READ_INSTRUCTIONS. find ( ins->type ) == VAR_LIST_READ_INSTRUCTIONS. cend ()) // write
2432+ safe = isVarOrListWriteResultTypeSafe (ins, expectedType, false , processed);
23982433
23992434 if (safe)
2400- return isVariableWriteResultTypeSafe (write, expectedType, false , processed);
2435+ return isVarOrListWriteResultTypeSafe (write, expectedType, false , processed);
24012436 else
24022437 return false ;
24032438}
24042439
2405- bool LLVMCodeBuilder::isVariableWriteResultTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
2440+ bool LLVMCodeBuilder::isVarOrListWriteResultTypeSafe (std::shared_ptr<LLVMInstruction> ins, Compiler::StaticType expectedType, bool ignoreSavedType, std::unordered_set<LLVMInstruction *> &processed)
24062441 const
24072442{
2408- const LLVMVariablePtr &varPtr = m_variablePtrs.at (ins->workVariable );
2443+ const LLVMVariablePtr *varPtr = ins->workVariable ? &m_variablePtrs.at (ins->workVariable ) : nullptr ;
2444+ const LLVMListPtr *listPtr = ins->workList ? &m_listPtrs.at (ins->workList ) : nullptr ;
2445+ assert ((varPtr || listPtr) && !(varPtr && listPtr));
24092446
24102447 // If the write operation writes the value of another variable, recursively check its type safety
2411- // TODO: Check get list item instruction
2412- auto argIns = ins-> args [ 0 ]. second ->instruction ;
2448+ const auto arg = ins-> args . back (). second ; // value is always the last argument
2449+ auto argIns = arg ->instruction ;
24132450
2414- if (argIns && argIns->type == LLVMInstruction::Type::ReadVariable)
2415- return isVariableTypeSafe (argIns, expectedType, processed);
2451+ if (argIns && ( argIns->type == LLVMInstruction::Type::ReadVariable || argIns-> type == LLVMInstruction::Type::GetListItem) )
2452+ return isVarOrListTypeSafe (argIns, expectedType, processed);
24162453
24172454 // Check written type
2418- return optimizeRegisterType (ins->args [0 ].second ) == expectedType && (varPtr.type == expectedType || ignoreSavedType);
2455+ const bool typeMatches = (optimizeRegisterType (arg) == expectedType);
2456+
2457+ if (varPtr)
2458+ return typeMatches && (varPtr->type == expectedType || ignoreSavedType);
2459+ else
2460+ return typeMatches && (listPtr->type == expectedType || ignoreSavedType);
24192461}
24202462
24212463LLVMRegister *LLVMCodeBuilder::createOp (LLVMInstruction::Type type, Compiler::StaticType retType, Compiler::StaticType argType, const Compiler::Args &args)
0 commit comments