@@ -146,7 +146,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
146146 step.functionReturnReg ->value = ret;
147147
148148 if (step.functionReturnReg ->type () == Compiler::StaticType::String)
149- m_heap. push_back (step.functionReturnReg ->value );
149+ freeLater (step.functionReturnReg ->value );
150150 }
151151
152152 break ;
@@ -749,7 +749,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
749749 assert (step.args .size () == 0 );
750750 const LLVMListPtr &listPtr = m_listPtrs[step.workList ];
751751 llvm::Value *ptr = m_builder.CreateCall (resolve_list_to_string (), listPtr.ptr );
752- m_heap. push_back (ptr); // deallocate later
752+ freeLater (ptr);
753753 step.functionReturnReg ->value = ptr;
754754 break ;
755755 }
@@ -791,7 +791,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
791791
792792 case LLVMInstruction::Type::Yield:
793793 if (!m_warp) {
794- freeHeap ();
794+ // TODO: Do not allow use after suspend (use after free)
795+ freeScopeHeap ();
795796 syncVariables (targetVariables);
796797 coro->createSuspend ();
797798 reloadVariables (targetVariables);
@@ -809,7 +810,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
809810 assert (step.args .size () == 1 );
810811 const auto ® = step.args [0 ];
811812 assert (reg.first == Compiler::StaticType::Bool);
812- freeHeap ();
813813 statement.condition = castValue (reg.second , reg.first );
814814
815815 // Switch to body branch
@@ -836,7 +836,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
836836 // Jump to the branch after the if statement
837837 assert (!statement.afterIf );
838838 statement.afterIf = llvm::BasicBlock::Create (m_ctx, " " , func);
839- freeHeap ();
839+ freeScopeHeap ();
840840 m_builder.CreateBr (statement.afterIf );
841841
842842 // Create else branch
@@ -855,12 +855,12 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
855855 case LLVMInstruction::Type::EndIf: {
856856 assert (!ifStatements.empty ());
857857 LLVMIfStatement &statement = ifStatements.back ();
858+ freeScopeHeap ();
858859
859860 // Jump to the branch after the if statement
860861 if (!statement.afterIf )
861862 statement.afterIf = llvm::BasicBlock::Create (m_ctx, " " , func);
862863
863- freeHeap ();
864864 m_builder.CreateBr (statement.afterIf );
865865
866866 if (statement.elseBranch ) {
@@ -901,7 +901,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
901901
902902 // Clamp count if <= 0 (we can skip the loop if count is not positive)
903903 llvm::Value *comparison = m_builder.CreateFCmpULE (count, llvm::ConstantFP::get (m_ctx, llvm::APFloat (0.0 )));
904- freeHeap ();
905904 m_builder.CreateCondBr (comparison, loop.afterLoop , roundBranch);
906905
907906 // Round (Scratch-specific behavior)
@@ -955,7 +954,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
955954 const auto ® = step.args [0 ];
956955 assert (reg.first == Compiler::StaticType::Bool);
957956 llvm::Value *condition = castValue (reg.second , reg.first );
958- freeHeap ();
959957 m_builder.CreateCondBr (condition, body, loop.afterLoop );
960958
961959 // Switch to body branch
@@ -977,7 +975,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
977975 const auto ® = step.args [0 ];
978976 assert (reg.first == Compiler::StaticType::Bool);
979977 llvm::Value *condition = castValue (reg.second , reg.first );
980- freeHeap ();
981978 m_builder.CreateCondBr (condition, loop.afterLoop , body);
982979
983980 // Switch to body branch
@@ -990,7 +987,6 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
990987 LLVMLoop loop;
991988 loop.isRepeatLoop = false ;
992989 loop.conditionBranch = llvm::BasicBlock::Create (m_ctx, " " , func);
993- freeHeap ();
994990 m_builder.CreateBr (loop.conditionBranch );
995991 m_builder.SetInsertPoint (loop.conditionBranch );
996992 loops.push_back (loop);
@@ -1009,7 +1005,7 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10091005 }
10101006
10111007 // Jump to the condition branch
1012- freeHeap ();
1008+ freeScopeHeap ();
10131009 m_builder.CreateBr (loop.conditionBranch );
10141010
10151011 // Switch to the branch after the loop
@@ -1032,7 +1028,8 @@ std::shared_ptr<ExecutableCode> LLVMCodeBuilder::finalize()
10321028 m_builder.CreateBr (endBranch);
10331029
10341030 m_builder.SetInsertPoint (endBranch);
1035- freeHeap ();
1031+ assert (m_heap.size () == 1 );
1032+ freeScopeHeap ();
10361033 syncVariables (targetVariables);
10371034
10381035 // End and verify the function
@@ -1535,6 +1532,8 @@ void LLVMCodeBuilder::pushScopeLevel()
15351532 m_scopeLists.push_back (listTypes);
15361533 } else
15371534 m_scopeLists.push_back (m_scopeLists.back ());
1535+
1536+ m_heap.push_back ({});
15381537}
15391538
15401539void LLVMCodeBuilder::popScopeLevel ()
@@ -1556,6 +1555,9 @@ void LLVMCodeBuilder::popScopeLevel()
15561555 }
15571556
15581557 m_scopeLists.pop_back ();
1558+
1559+ freeScopeHeap ();
1560+ m_heap.pop_back ();
15591561}
15601562
15611563void LLVMCodeBuilder::verifyFunction (llvm::Function *func)
@@ -1590,13 +1592,28 @@ LLVMRegister *LLVMCodeBuilder::addReg(std::shared_ptr<LLVMRegister> reg)
15901592 return reg.get ();
15911593}
15921594
1593- void LLVMCodeBuilder::freeHeap ()
1595+ void LLVMCodeBuilder::freeLater (llvm::Value *value)
1596+ {
1597+ assert (!m_heap.empty ());
1598+
1599+ if (m_heap.empty ())
1600+ return ;
1601+
1602+ m_heap.back ().push_back (value);
1603+ }
1604+
1605+ void LLVMCodeBuilder::freeScopeHeap ()
15941606{
1595- // Free dynamically allocated memory
1596- for (llvm::Value *ptr : m_heap)
1607+ if (m_heap.empty ())
1608+ return ;
1609+
1610+ // Free dynamically allocated memory in current scope
1611+ auto &heap = m_heap.back ();
1612+
1613+ for (llvm::Value *ptr : heap)
15971614 m_builder.CreateFree (ptr);
15981615
1599- m_heap .clear ();
1616+ heap .clear ();
16001617}
16011618
16021619llvm::Value *LLVMCodeBuilder::castValue (LLVMRegister *reg, Compiler::StaticType targetType)
@@ -1668,7 +1685,7 @@ llvm::Value *LLVMCodeBuilder::castValue(LLVMRegister *reg, Compiler::StaticType
16681685 case Compiler::StaticType::Unknown: {
16691686 // Cast to string
16701687 llvm::Value *ptr = m_builder.CreateCall (resolve_value_toCString (), reg->value );
1671- m_heap. push_back (ptr); // deallocate later
1688+ freeLater (ptr);
16721689 return ptr;
16731690 }
16741691
@@ -1731,7 +1748,7 @@ llvm::Value *LLVMCodeBuilder::castRawValue(LLVMRegister *reg, Compiler::StaticTy
17311748 case Compiler::StaticType::Number: {
17321749 // Convert double to string
17331750 llvm::Value *ptr = m_builder.CreateCall (resolve_value_doubleToCString (), reg->value );
1734- m_heap. push_back (ptr); // deallocate later
1751+ freeLater (ptr);
17351752 return ptr;
17361753 }
17371754
0 commit comments