Skip to content

Commit 01006dc

Browse files
committed
Optimize list index range check
1 parent c848360 commit 01006dc

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

src/engine/internal/llvm/instructions/lists.cpp

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,17 @@ LLVMInstruction *Lists::buildRemoveListItem(LLVMInstruction *ins)
100100
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
101101

102102
// Range check
103-
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
104-
llvm::Value *size = m_utils.getListSize(listPtr);
105-
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
106-
llvm::Value *index = m_utils.castValue(arg.second, arg.first);
107-
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
103+
llvm::Value *indexDouble = m_utils.castValue(arg.second, arg.first);
104+
llvm::Value *indexInt = getIndex(listPtr, indexDouble);
105+
llvm::Value *inRange = createSizeRangeCheck(listPtr, indexInt, "removeListItem.indexInRange");
106+
108107
llvm::BasicBlock *removeBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
109108
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
110109
m_builder.CreateCondBr(inRange, removeBlock, nextBlock);
111110

112111
// Remove
113112
m_builder.SetInsertPoint(removeBlock);
114-
index = m_builder.CreateFPToUI(m_utils.castValue(arg.second, arg.first), m_builder.getInt64Ty());
115-
m_builder.CreateCall(m_utils.functions().resolve_list_remove(), { listPtr.ptr, index });
113+
m_builder.CreateCall(m_utils.functions().resolve_list_remove(), { listPtr.ptr, indexInt });
116114

117115
if (listPtr.size) {
118116
// Update size
@@ -185,19 +183,17 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins)
185183
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
186184

187185
// Range check
188-
llvm::Value *size = m_utils.getListSize(listPtr);
189-
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
190-
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
191-
llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first);
192-
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLE(index, size));
186+
llvm::Value *indexDouble = m_utils.castValue(indexArg.second, indexArg.first);
187+
llvm::Value *indexInt = getIndex(listPtr, indexDouble);
188+
llvm::Value *inRange = createSizeRangeCheck(listPtr, indexInt, "insertToList.indexInRange");
189+
193190
llvm::BasicBlock *insertBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
194191
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
195192
m_builder.CreateCondBr(inRange, insertBlock, nextBlock);
196193

197194
// Insert
198195
m_builder.SetInsertPoint(insertBlock);
199-
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
200-
llvm::Value *itemPtr = m_builder.CreateCall(m_utils.functions().resolve_list_insert_empty(), { listPtr.ptr, index });
196+
llvm::Value *itemPtr = m_builder.CreateCall(m_utils.functions().resolve_list_insert_empty(), { listPtr.ptr, indexInt });
201197
m_utils.createValueStore(itemPtr, m_utils.getValueTypePtr(itemPtr), valueArg.second, type);
202198

203199
if (listPtr.size) {
@@ -232,20 +228,18 @@ LLVMInstruction *Lists::buildListReplace(LLVMInstruction *ins)
232228
Compiler::StaticType listType = ins->targetType;
233229

234230
// Range check
235-
llvm::Value *min = llvm::ConstantFP::get(llvmCtx, llvm::APFloat(0.0));
236-
llvm::Value *size = m_utils.getListSize(listPtr);
237-
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
238-
llvm::Value *index = m_utils.castValue(indexArg.second, indexArg.first);
239-
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size));
231+
llvm::Value *indexDouble = m_utils.castValue(indexArg.second, indexArg.first);
232+
llvm::Value *indexInt = getIndex(listPtr, indexDouble);
233+
llvm::Value *inRange = createSizeRangeCheck(listPtr, indexInt, "listReplace.indexInRange");
234+
240235
llvm::BasicBlock *replaceBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
241236
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "", function);
242237
m_builder.CreateCondBr(inRange, replaceBlock, nextBlock);
243238

244239
// Replace
245240
m_builder.SetInsertPoint(replaceBlock);
246-
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
247241

248-
llvm::Value *itemPtr = m_utils.getListItem(listPtr, index);
242+
llvm::Value *itemPtr = m_utils.getListItem(listPtr, indexInt);
249243
// llvm::Value *typeVar = createListTypeVar(listPtr, itemPtr);
250244
// createListTypeAssumption(listPtr, typeVar, ins->targetType);
251245

@@ -275,6 +269,9 @@ LLVMInstruction *Lists::buildGetListContents(LLVMInstruction *ins)
275269

276270
LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins)
277271
{
272+
llvm::LLVMContext &llvmCtx = m_utils.llvmCtx();
273+
llvm::Function *function = m_utils.function();
274+
278275
// Return empty string for empty lists
279276
if (ins->targetType == Compiler::StaticType::Void) {
280277
LLVMConstantRegister nullReg(Compiler::StaticType::String, "");
@@ -286,21 +283,19 @@ LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins)
286283
const auto &arg = ins->args[0];
287284
LLVMListPtr &listPtr = m_utils.listPtr(ins->targetList);
288285

289-
llvm::Value *min = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0));
290-
llvm::Value *size = m_utils.getListSize(listPtr);
291-
size = m_builder.CreateUIToFP(size, m_builder.getDoubleTy());
292-
llvm::Value *index = m_utils.castValue(arg.second, arg.first);
293-
llvm::Value *inRange = m_builder.CreateAnd(m_builder.CreateFCmpOGE(index, min), m_builder.CreateFCmpOLT(index, size), "getListItem.indexInRange");
286+
// Range check
287+
llvm::Value *indexDouble = m_utils.castValue(arg.second, arg.first);
288+
llvm::Value *indexInt = getIndex(listPtr, indexDouble);
289+
llvm::Value *inRange = createSizeRangeCheck(listPtr, indexInt, "getListItem.indexInRange");
294290

295-
llvm::BasicBlock *inRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "getListItem.inRange", m_utils.function());
296-
llvm::BasicBlock *outOfRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "getListItem.outOfRange", m_utils.function());
297-
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "getListItem.next", m_utils.function());
291+
llvm::BasicBlock *inRangeBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.inRange", function);
292+
llvm::BasicBlock *outOfRangeBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.outOfRange", function);
293+
llvm::BasicBlock *nextBlock = llvm::BasicBlock::Create(llvmCtx, "getListItem.next", function);
298294
m_builder.CreateCondBr(inRange, inRangeBlock, outOfRangeBlock);
299295

300296
// In range
301297
m_builder.SetInsertPoint(inRangeBlock);
302-
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
303-
llvm::Value *itemPtr = m_utils.getListItem(listPtr, index);
298+
llvm::Value *itemPtr = m_utils.getListItem(listPtr, indexInt);
304299
llvm::Value *itemType = m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr));
305300
m_builder.CreateBr(nextBlock);
306301

@@ -366,6 +361,19 @@ LLVMInstruction *Lists::buildListContainsItem(LLVMInstruction *ins)
366361
return ins->next;
367362
}
368363

364+
llvm::Value *Lists::getIndex(const LLVMListPtr &listPtr, llvm::Value *indexDouble)
365+
{
366+
llvm::Value *zero = llvm::ConstantFP::get(m_utils.llvmCtx(), llvm::APFloat(0.0));
367+
llvm::Value *isNegative = m_builder.CreateFCmpOLT(indexDouble, zero, "listIndex.isNegative");
368+
return m_builder.CreateSelect(isNegative, llvm::ConstantInt::get(m_builder.getInt64Ty(), INT64_MAX), m_builder.CreateFPToUI(indexDouble, m_builder.getInt64Ty(), "listIndex.int"));
369+
}
370+
371+
llvm::Value *Lists::createSizeRangeCheck(const LLVMListPtr &listPtr, llvm::Value *indexInt, const std::string &name)
372+
{
373+
llvm::Value *size = m_utils.getListSize(listPtr);
374+
return m_builder.CreateICmpULT(indexInt, size, name);
375+
}
376+
369377
void Lists::createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType)
370378
{
371379
if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) {

src/engine/internal/llvm/instructions/lists.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class Lists : public InstructionGroup
3434
LLVMInstruction *buildGetListItemIndex(LLVMInstruction *ins);
3535
LLVMInstruction *buildListContainsItem(LLVMInstruction *ins);
3636

37+
llvm::Value *getIndex(const LLVMListPtr &listPtr, llvm::Value *indexDouble);
38+
llvm::Value *createSizeRangeCheck(const LLVMListPtr &listPtr, llvm::Value *indexInt, const std::string &name);
39+
3740
void createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType);
3841
llvm::Value *createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *type);
3942
void createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *typeVar, Compiler::StaticType staticType, llvm::Value *inRange);

0 commit comments

Comments
 (0)