Skip to content

Commit abff3d7

Browse files
committed
Stored list size locally in warp scripts
1 parent f2de9fb commit abff3d7

File tree

4 files changed

+51
-10
lines changed

4 files changed

+51
-10
lines changed

src/engine/internal/llvm/instructions/lists.cpp

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ LLVMInstruction *Lists::buildClearList(LLVMInstruction *ins)
6767
assert(ins->args.size() == 0);
6868
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
6969
m_builder.CreateCall(m_utils.functions().resolve_list_clear(), listPtr.ptr);
70+
71+
if (listPtr.size) {
72+
// Update size
73+
m_builder.CreateStore(m_builder.getInt64(0), listPtr.size);
74+
}
7075
}
7176

7277
return ins->next;
@@ -88,7 +93,7 @@ LLVMInstruction *Lists::buildRemoveListItem(LLVMInstruction *ins)
8893

8994
// Range check
9095
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
91-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
96+
llvm::Value *size = m_utils.getListSize(listPtr);
9297
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
9398
llvm::Value *index = m_utils.castValue(arg.second, arg.first);
9499
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
@@ -100,6 +105,14 @@ LLVMInstruction *Lists::buildRemoveListItem(LLVMInstruction *ins)
100105
m_builder.SetInsertPoint(removeBlock);
101106
index = m_builder.CreateFPToUI(m_utils.castValue(arg.second, arg.first), m_builder.getInt64Ty());
102107
m_builder.CreateCall(m_utils.functions().resolve_list_remove(), { listPtr.ptr, index });
108+
109+
if (listPtr.size) {
110+
// Update size
111+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.size);
112+
size = m_builder.CreateSub(size, m_builder.getInt64(1));
113+
m_builder.CreateStore(size, listPtr.size);
114+
}
115+
103116
m_builder.CreateBr(nextBlock);
104117

105118
m_builder.SetInsertPoint(nextBlock);
@@ -118,7 +131,7 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins)
118131

119132
// Check if enough space is allocated
120133
llvm::Value *allocatedSize = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.allocatedSizePtr);
121-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
134+
llvm::Value *size = m_utils.getListSize(listPtr);
122135
llvm::Value *isAllocated = m_builder.CreateICmpUGT(allocatedSize, size);
123136
llvm::BasicBlock *ifBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
124137
llvm::BasicBlock *elseBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
@@ -129,7 +142,7 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins)
129142
m_builder.SetInsertPoint(ifBlock);
130143
llvm::Value *itemPtr = m_utils.getListItem(listPtr, size);
131144
m_utils.createValueStore(arg.second, itemPtr, type);
132-
m_builder.CreateStore(m_builder.CreateAdd(size, m_builder.getInt64(1)), listPtr.sizePtr);
145+
m_builder.CreateStore(m_builder.CreateAdd(size, m_builder.getInt64(1)), listPtr.sizePtr); // update size stored in *sizePtr
133146
m_builder.CreateBr(nextBlock);
134147

135148
// Otherwise call appendEmpty()
@@ -140,6 +153,14 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins)
140153
m_builder.CreateBr(nextBlock);
141154

142155
m_builder.SetInsertPoint(nextBlock);
156+
157+
if (listPtr.size) {
158+
// Update local size
159+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.size);
160+
size = m_builder.CreateAdd(size, m_builder.getInt64(1));
161+
m_builder.CreateStore(size, listPtr.size);
162+
}
163+
143164
return ins->next;
144165
}
145166

@@ -155,7 +176,7 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins)
155176
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
156177

157178
// Range check
158-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
179+
llvm::Value *size = m_utils.getListSize(listPtr);
159180
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
160181
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
161182
llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first);
@@ -170,6 +191,13 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins)
170191
llvm::Value *itemPtr = m_builder.CreateCall(m_utils.functions().resolve_list_insert_empty(), { listPtr.ptr, index });
171192
m_utils.createValueStore(valueArg.second, itemPtr, type);
172193

194+
if (listPtr.size) {
195+
// Update size
196+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.size);
197+
size = m_builder.CreateAdd(size, m_builder.getInt64(1));
198+
m_builder.CreateStore(size, listPtr.size);
199+
}
200+
173201
m_builder.CreateBr(nextBlock);
174202

175203
m_builder.SetInsertPoint(nextBlock);
@@ -195,7 +223,7 @@ LLVMInstruction *Lists::buildListReplace(LLVMInstruction *ins)
195223

196224
// Range check
197225
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
198-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
226+
llvm::Value *size = m_utils.getListSize(listPtr);
199227
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
200228
llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first);
201229
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
@@ -238,10 +266,8 @@ LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins)
238266
const auto &arg = ins->args[0];
239267
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
240268

241-
Compiler::StaticType listType = ins->functionReturnReg->type();
242-
243269
llvm::Value *min = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0));
244-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
270+
llvm::Value *size = m_utils.getListSize(listPtr);
245271
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
246272
llvm::Value *index = m_utils.castValue(arg.second, arg.first);
247273
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
@@ -259,7 +285,7 @@ LLVMInstruction *Lists::buildGetListSize(LLVMInstruction *ins)
259285
{
260286
assert(ins->args.size() == 0);
261287
const LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
262-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
288+
llvm::Value *size = m_utils.getListSize(listPtr);
263289
ins->functionReturnReg->value = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
264290

265291
return ins->next;

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePro
9494

9595
listPtr.sizePtr = m_builder.CreateCall(m_functions.resolve_list_size_ptr(), listPtr.ptr);
9696
listPtr.allocatedSizePtr = m_builder.CreateCall(m_functions.resolve_list_alloc_size_ptr(), listPtr.ptr);
97+
98+
if (m_warp) {
99+
// Store list size locally to allow some optimizations
100+
listPtr.size = m_builder.CreateAlloca(m_builder.getInt64Ty(), nullptr, list->name() + ".size");
101+
102+
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
103+
m_builder.CreateStore(size, listPtr.size);
104+
}
97105
}
98106

99107
// Create end branch
@@ -882,14 +890,19 @@ void LLVMBuildUtils::createValueStore(LLVMRegister *reg, llvm::Value *destPtr, C
882890
createValueStore(reg, destPtr, Compiler::StaticType::Unknown, targetType);
883891
}
884892

893+
llvm::Value *LLVMBuildUtils::getListSize(const LLVMListPtr &listPtr)
894+
{
895+
return m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.size ? listPtr.size : listPtr.sizePtr);
896+
}
897+
885898
llvm::Value *LLVMBuildUtils::getListItem(const LLVMListPtr &listPtr, llvm::Value *index)
886899
{
887900
return m_builder.CreateGEP(m_valueDataType, getListDataPtr(listPtr), index);
888901
}
889902

890903
llvm::Value *LLVMBuildUtils::getListItemIndex(const LLVMListPtr &listPtr, Compiler::StaticType listType, LLVMRegister *item)
891904
{
892-
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
905+
llvm::Value *size = getListSize(listPtr);
893906
llvm::BasicBlock *condBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
894907
llvm::BasicBlock *bodyBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);
895908
llvm::BasicBlock *cmpIfBlock = llvm::BasicBlock::Create(m_llvmCtx, "", m_function);

src/engine/internal/llvm/llvmbuildutils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class LLVMBuildUtils
8686
void createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType destType, Compiler::StaticType targetType);
8787
void createValueStore(LLVMRegister *reg, llvm::Value *destPtr, Compiler::StaticType targetType);
8888

89+
llvm::Value *getListSize(const LLVMListPtr &listPtr);
8990
llvm::Value *getListItem(const LLVMListPtr &listPtr, llvm::Value *index);
9091
llvm::Value *getListItemIndex(const LLVMListPtr &listPtr, Compiler::StaticType listType, LLVMRegister *item);
9192
llvm::Value *createValue(LLVMRegister *reg);

src/engine/internal/llvm/llvmlistptr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ struct LLVMListPtr
2222
llvm::Value *dataPtr = nullptr;
2323
llvm::Value *sizePtr = nullptr;
2424
llvm::Value *allocatedSizePtr = nullptr;
25+
llvm::Value *size = nullptr;
2526
};
2627

2728
} // namespace libscratchcpp

0 commit comments

Comments
 (0)