Skip to content

Commit 0a06c6c

Browse files
committed
Optimize list type info
1 parent 00b98b9 commit 0a06c6c

File tree

4 files changed

+182
-42
lines changed

4 files changed

+182
-42
lines changed

src/engine/internal/llvm/instructions/lists.cpp

Lines changed: 164 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ LLVMInstruction *Lists::buildClearList(LLVMInstruction *ins)
7474
m_builder.CreateStore(m_builder.getInt64(0), listPtr.size);
7575
}
7676

77-
if (listPtr.type) {
78-
// Update type
79-
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(ValueType::Void)), listPtr.type);
77+
if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) {
78+
// Reset type info
79+
m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasNumber);
80+
m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasBool);
81+
m_builder.CreateStore(m_builder.getInt1(false), listPtr.hasString);
8082
}
8183
}
8284

@@ -167,7 +169,7 @@ LLVMInstruction *Lists::buildAppendToList(LLVMInstruction *ins)
167169
m_builder.CreateStore(size, listPtr.size);
168170
}
169171

170-
createListTypeUpdate(listPtr, arg.second);
172+
createListTypeUpdate(listPtr, arg.second, type);
171173
return ins->next;
172174
}
173175

@@ -205,7 +207,7 @@ LLVMInstruction *Lists::buildInsertToList(LLVMInstruction *ins)
205207
m_builder.CreateStore(size, listPtr.size);
206208
}
207209

208-
createListTypeUpdate(listPtr, valueArg.second);
210+
createListTypeUpdate(listPtr, valueArg.second, type);
209211
m_builder.CreateBr(nextBlock);
210212

211213
m_builder.SetInsertPoint(nextBlock);
@@ -244,12 +246,16 @@ LLVMInstruction *Lists::buildListReplace(LLVMInstruction *ins)
244246
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
245247

246248
llvm::Value *itemPtr = m_utils.getListItem(listPtr, index);
247-
llvm::Value *itemType = m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr));
248-
llvm::Value *typeVar = createListTypeVar(listPtr, itemType);
249-
createListTypeAssumption(listPtr, itemType, typeVar);
249+
// llvm::Value *typeVar = createListTypeVar(listPtr, itemPtr);
250+
// createListTypeAssumption(listPtr, typeVar, ins->targetType);
250251

251-
m_utils.createValueStore(itemPtr, m_utils.getValueTypePtr(itemPtr), valueArg.second, listType, type);
252-
createListTypeUpdate(listPtr, valueArg.second);
252+
/*llvm::Value *typeVar = m_utils.addAlloca(m_builder.getInt32Ty());
253+
llvm::Value *t = m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr));
254+
m_builder.CreateStore(t, typeVar);*/
255+
llvm::Value *typeVar = m_utils.getValueTypePtr(itemPtr);
256+
257+
m_utils.createValueStore(itemPtr, typeVar, valueArg.second, listType, type);
258+
createListTypeUpdate(listPtr, valueArg.second, type);
253259
m_builder.CreateBr(nextBlock);
254260

255261
m_builder.SetInsertPoint(nextBlock);
@@ -292,13 +298,11 @@ LLVMInstruction *Lists::buildGetListItem(LLVMInstruction *ins)
292298
index = m_builder.CreateFPToUI(index, m_builder.getInt64Ty());
293299

294300
llvm::Value *itemPtr = m_builder.CreateSelect(inRange, m_utils.getListItem(listPtr, index), null);
295-
llvm::Value *stringType = m_builder.getInt32(static_cast<uint32_t>(ValueType::String));
296-
llvm::Value *type = m_builder.CreateSelect(inRange, m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr)), stringType);
297-
llvm::Value *typeVar = createListTypeVar(listPtr, type);
301+
llvm::Value *typeVar = createListTypeVar(listPtr, itemPtr, inRange);
298302

299303
ins->functionReturnReg->value = itemPtr;
300304
ins->functionReturnReg->typeVar = typeVar;
301-
createListTypeAssumption(listPtr, type, typeVar, inRange);
305+
createListTypeAssumption(listPtr, typeVar, ins->targetType, inRange);
302306

303307
return ins->next;
304308
}
@@ -339,50 +343,176 @@ LLVMInstruction *Lists::buildListContainsItem(LLVMInstruction *ins)
339343
return ins->next;
340344
}
341345

342-
void Lists::createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue)
346+
void Lists::createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType)
343347
{
344-
if (listPtr.type) {
345-
// Update type
346-
llvm::Value *currentType = m_builder.CreateLoad(m_builder.getInt32Ty(), listPtr.type);
347-
llvm::Value *newTypeFlag;
348+
if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) {
349+
// Get the new type
350+
llvm::Value *newType;
348351

349352
if (newValue->isRawValue)
350-
newTypeFlag = m_builder.getInt32(static_cast<uint32_t>(m_utils.mapType(newValue->type())));
353+
newType = m_builder.getInt32(static_cast<uint32_t>(m_utils.mapType(newValue->type())));
351354
else {
352355
llvm::Value *typeField = m_builder.CreateStructGEP(m_utils.compilerCtx()->valueDataType(), newValue->value, 1);
353-
newTypeFlag = m_builder.CreateLoad(m_builder.getInt32Ty(), typeField);
356+
newType = m_builder.CreateLoad(m_builder.getInt32Ty(), typeField);
357+
}
358+
359+
// Set the appropriate type flag
360+
llvm::BasicBlock *defaultBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "updateListType.default", m_utils.function());
361+
llvm::BasicBlock *mergeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "updateListType.merge", m_utils.function());
362+
llvm::SwitchInst *sw = m_builder.CreateSwitch(newType, defaultBlock, 4);
363+
364+
// Number case
365+
if ((newValueType & Compiler::StaticType::Number) == Compiler::StaticType::Number) {
366+
llvm::BasicBlock *numberBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "updateListType.number", m_utils.function());
367+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::Number)), numberBlock);
368+
m_builder.SetInsertPoint(numberBlock);
369+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasNumber);
370+
m_builder.CreateBr(mergeBlock);
371+
}
372+
373+
// Bool case
374+
if ((newValueType & Compiler::StaticType::Bool) == Compiler::StaticType::Bool) {
375+
llvm::BasicBlock *boolBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "updateListType.bool", m_utils.function());
376+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::Bool)), boolBlock);
377+
m_builder.SetInsertPoint(boolBlock);
378+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasBool);
379+
m_builder.CreateBr(mergeBlock);
354380
}
355381

356-
m_builder.CreateStore(m_builder.CreateOr(currentType, newTypeFlag), listPtr.type);
382+
// String case
383+
if ((newValueType & Compiler::StaticType::String) == Compiler::StaticType::String) {
384+
llvm::BasicBlock *stringBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "updateListType.string", m_utils.function());
385+
sw->addCase(m_builder.getInt32(static_cast<uint32_t>(ValueType::String)), stringBlock);
386+
m_builder.SetInsertPoint(stringBlock);
387+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasString);
388+
m_builder.CreateBr(mergeBlock);
389+
}
390+
391+
// Default case
392+
m_builder.SetInsertPoint(defaultBlock);
393+
m_builder.CreateBr(mergeBlock);
394+
395+
m_builder.SetInsertPoint(mergeBlock);
357396
}
358397
}
359398

360-
llvm::Value *Lists::createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *itemType)
399+
llvm::Value *Lists::createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *itemPtr, llvm::Value *inRange)
361400
{
362401
llvm::Value *typeVar = m_utils.addAlloca(m_builder.getInt32Ty());
363-
m_builder.CreateStore(itemType, typeVar);
402+
llvm::BasicBlock *inRangeBlock = nullptr;
403+
llvm::BasicBlock *outOfRangeBlock = nullptr;
404+
llvm::BasicBlock *nextBlock = nullptr;
405+
406+
if (inRange) {
407+
inRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
408+
outOfRangeBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
409+
nextBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "", m_utils.function());
410+
411+
m_builder.CreateCondBr(inRange, inRangeBlock, outOfRangeBlock);
412+
m_builder.SetInsertPoint(inRangeBlock);
413+
}
414+
415+
llvm::Value *type = m_builder.CreateLoad(m_builder.getInt32Ty(), m_utils.getValueTypePtr(itemPtr));
416+
m_builder.CreateStore(type, typeVar);
417+
418+
if (inRange) {
419+
m_builder.CreateBr(nextBlock);
420+
m_builder.SetInsertPoint(outOfRangeBlock);
421+
422+
llvm::Value *type = m_builder.getInt32(static_cast<uint32_t>(ValueType::String));
423+
m_builder.CreateStore(type, typeVar);
424+
m_builder.CreateBr(nextBlock);
425+
426+
m_builder.SetInsertPoint(nextBlock);
427+
}
428+
364429
return typeVar;
365430
}
366431

367-
void Lists::createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *itemType, llvm::Value *typeVar, llvm::Value *inRange)
432+
void Lists::createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *typeVar, Compiler::StaticType staticType, llvm::Value *inRange)
368433
{
369-
if (listPtr.type) {
434+
if (listPtr.hasNumber && listPtr.hasBool && listPtr.hasString) {
370435
llvm::Function *assumeIntrinsic = llvm::Intrinsic::getDeclaration(m_utils.module(), llvm::Intrinsic::assume);
371436

437+
// Load the compile time list type information
438+
bool staticHasNumber = (staticType & Compiler::StaticType::Number) == Compiler::StaticType::Number;
439+
bool staticHasBool = (staticType & Compiler::StaticType::Bool) == Compiler::StaticType::Bool;
440+
bool staticHasString = (staticType & Compiler::StaticType::String) == Compiler::StaticType::String;
441+
372442
// Load the runtime list type information
373-
llvm::Value *listTypeFlags = m_builder.CreateLoad(m_builder.getInt32Ty(), listPtr.type);
443+
llvm::Value *hasNumber;
374444

375-
// Create assumption that the item type is contained in the list type flags
376-
llvm::Value *typeIsValid = m_builder.CreateICmpEQ(m_builder.CreateAnd(listTypeFlags, itemType), itemType);
445+
if (staticHasNumber)
446+
hasNumber = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasNumber);
447+
else
448+
hasNumber = m_builder.getInt1(false);
449+
450+
llvm::Value *hasBool;
451+
452+
if (staticHasBool)
453+
hasBool = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasBool);
454+
else
455+
hasBool = m_builder.getInt1(false);
456+
457+
llvm::Value *hasString;
458+
459+
if (staticHasString)
460+
hasString = m_builder.CreateLoad(m_builder.getInt1Ty(), listPtr.hasString);
461+
else
462+
hasString = m_builder.getInt1(false);
463+
464+
llvm::Value *type = m_builder.CreateLoad(m_builder.getInt32Ty(), typeVar);
465+
466+
llvm::Value *numberType = m_builder.getInt32(static_cast<uint32_t>(ValueType::Number));
467+
llvm::Value *boolType = m_builder.getInt32(static_cast<uint32_t>(ValueType::Bool));
468+
llvm::Value *stringType = m_builder.getInt32(static_cast<uint32_t>(ValueType::String));
469+
470+
// Number
471+
llvm::BasicBlock *noNumberBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.noNumber", m_utils.function());
472+
llvm::BasicBlock *afterNoNumberBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.afterNoNumber", m_utils.function());
473+
m_builder.CreateCondBr(hasNumber, afterNoNumberBlock, noNumberBlock);
474+
475+
m_builder.SetInsertPoint(noNumberBlock);
476+
llvm::Value *isNotNumber = m_builder.CreateICmpNE(type, numberType);
477+
m_builder.CreateCall(assumeIntrinsic, isNotNumber);
478+
m_builder.CreateBr(afterNoNumberBlock);
479+
480+
m_builder.SetInsertPoint(afterNoNumberBlock);
481+
482+
// Bool
483+
llvm::BasicBlock *noBoolBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.noBool", m_utils.function());
484+
llvm::BasicBlock *afterNoBoolBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.afterNoBool", m_utils.function());
485+
m_builder.CreateCondBr(hasBool, afterNoBoolBlock, noBoolBlock);
486+
487+
m_builder.SetInsertPoint(noBoolBlock);
488+
llvm::Value *isNotBool = m_builder.CreateICmpNE(type, boolType);
489+
m_builder.CreateCall(assumeIntrinsic, isNotBool);
490+
m_builder.CreateBr(afterNoBoolBlock);
491+
492+
m_builder.SetInsertPoint(afterNoBoolBlock);
493+
494+
// String
495+
llvm::BasicBlock *noStringBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.noString", m_utils.function());
496+
llvm::BasicBlock *afterNoStringBlock = llvm::BasicBlock::Create(m_utils.llvmCtx(), "listTypeAssumption.afterNoString", m_utils.function());
497+
498+
if (inRange)
499+
m_builder.CreateCondBr(m_builder.CreateAnd(m_builder.CreateNot(hasString), inRange), noStringBlock, afterNoStringBlock);
500+
else
501+
m_builder.CreateCondBr(hasString, afterNoStringBlock, noStringBlock);
502+
503+
m_builder.SetInsertPoint(noStringBlock);
504+
llvm::Value *isNotString = m_builder.CreateICmpNE(type, stringType);
505+
m_builder.CreateCall(assumeIntrinsic, isNotString);
506+
m_builder.CreateBr(afterNoStringBlock);
507+
508+
m_builder.SetInsertPoint(afterNoStringBlock);
377509

378510
if (inRange) {
379511
llvm::Value *stringType = m_builder.getInt32(static_cast<uint32_t>(ValueType::String));
380-
llvm::Value *canNotBeString = m_builder.CreateICmpNE(m_builder.CreateAnd(listTypeFlags, stringType), stringType);
381-
llvm::Value *isString = m_builder.CreateICmpEQ(itemType, stringType);
512+
llvm::Value *canNotBeString = m_builder.CreateNot(hasString);
513+
llvm::Value *isString = m_builder.CreateICmpEQ(type, stringType);
382514
llvm::Value *impossible = m_builder.CreateAnd(m_builder.CreateAnd(inRange, canNotBeString), isString);
383-
typeIsValid = m_builder.CreateAnd(typeIsValid, m_builder.CreateNot(impossible));
515+
m_builder.CreateCall(assumeIntrinsic, m_builder.CreateNot(impossible));
384516
}
385-
386-
m_builder.CreateCall(assumeIntrinsic, typeIsValid);
387517
}
388518
}

src/engine/internal/llvm/instructions/lists.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#pragma once
44

5+
#include <scratchcpp/compiler.h>
6+
57
#include "instructiongroup.h"
68

79
namespace libscratchcpp
@@ -32,9 +34,9 @@ class Lists : public InstructionGroup
3234
LLVMInstruction *buildGetListItemIndex(LLVMInstruction *ins);
3335
LLVMInstruction *buildListContainsItem(LLVMInstruction *ins);
3436

35-
void createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue);
36-
llvm::Value *createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *itemType);
37-
void createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *itemType, llvm::Value *typeVar, llvm::Value *inRange = nullptr);
37+
void createListTypeUpdate(const LLVMListPtr &listPtr, const LLVMRegister *newValue, Compiler::StaticType newValueType);
38+
llvm::Value *createListTypeVar(const LLVMListPtr &listPtr, llvm::Value *itemPtr, llvm::Value *inRange = nullptr);
39+
void createListTypeAssumption(const LLVMListPtr &listPtr, llvm::Value *typeVar, Compiler::StaticType staticType, llvm::Value *inRange);
3840
};
3941

4042
} // namespace llvmins

src/engine/internal/llvm/llvmbuildutils.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ void LLVMBuildUtils::init(llvm::Function *function, BlockPrototype *procedurePro
8989
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
9090
m_builder.CreateStore(size, listPtr.size);
9191

92-
// Store list type locally to leave static type analysis to LLVM
93-
listPtr.type = m_builder.CreateAlloca(m_builder.getInt32Ty(), nullptr, list->name() + ".type");
94-
m_builder.CreateStore(m_builder.getInt32(static_cast<uint32_t>(ValueType::Number | ValueType::Bool | ValueType::String)), listPtr.type);
92+
// Store list type info locally to leave static type analysis to LLVM
93+
listPtr.hasNumber = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasNumber");
94+
listPtr.hasBool = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasBool");
95+
listPtr.hasString = m_builder.CreateAlloca(m_builder.getInt1Ty(), nullptr, list->name() + ".hasString");
9596
}
9697
}
9798

@@ -321,11 +322,15 @@ void LLVMBuildUtils::reloadVariables()
321322

322323
void LLVMBuildUtils::reloadLists()
323324
{
324-
// Load list size info
325+
// Load list size and type info
325326
if (m_warp) {
326327
for (auto &[list, listPtr] : m_listPtrs) {
327328
llvm::Value *size = m_builder.CreateLoad(m_builder.getInt64Ty(), listPtr.sizePtr);
328329
m_builder.CreateStore(size, listPtr.size);
330+
331+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasNumber);
332+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasBool);
333+
m_builder.CreateStore(m_builder.getInt1(true), listPtr.hasString);
329334
}
330335
}
331336
}

src/engine/internal/llvm/llvmlistptr.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@ struct LLVMListPtr
2323
llvm::Value *sizePtr = nullptr;
2424
llvm::Value *allocatedSizePtr = nullptr;
2525
llvm::Value *size = nullptr;
26-
llvm::Value *type = nullptr;
26+
27+
llvm::Value *hasNumber = nullptr;
28+
llvm::Value *hasBool = nullptr;
29+
llvm::Value *hasString = nullptr;
2730
};
2831

2932
} // namespace libscratchcpp

0 commit comments

Comments
 (0)