Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2095,24 +2095,6 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
for (const auto &Reduction : getReductionVars())
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());

// TODO: handle non-reduction outside users when tail is folded by masking.
for (auto *AE : AllowedExit) {
// Check that all users of allowed exit values are inside the loop or
// are the live-out of a reduction.
if (ReductionLiveOuts.count(AE))
continue;
for (User *U : AE->users()) {
Instruction *UI = cast<Instruction>(U);
if (TheLoop->contains(UI))
continue;
LLVM_DEBUG(
dbgs()
<< "LV: Cannot fold tail by masking, loop has an outside user for "
<< *UI << "\n");
return false;
}
}

for (const auto &Entry : getInductionVars()) {
PHINode *OrigPhi = Entry.first;
for (User *U : OrigPhi->users()) {
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8914,7 +8914,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
if (FinalReductionResult == U || Parent->getParent())
continue;
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
if (match(U, m_ExtractLastElement(m_VPValue())))
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
m_ExtractLane(m_VPValue(), m_VPValue()))))
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
}

Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,13 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
// It produces the lane index across all unrolled iterations. Unrolling will
// add all copies of its original operand as additional operands.
FirstActiveLane,
// Calculates the last active lane index of the vector predicate operands.
// The predicates must be prefix-masks (all 1s before all 0s). Used when
// tail-folding to extract the correct live-out value from the last active
// iteration. It produces the lane index across all unrolled iterations.
// Unrolling will add all copies of its original operand as additional
// operands.
LastActiveLane,

// The opcodes below are used for VPInstructionWithType.
//
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
case VPInstruction::ExtractLane:
return inferScalarType(R->getOperand(1));
case VPInstruction::FirstActiveLane:
case VPInstruction::LastActiveLane:
return Type::getIntNTy(Ctx, 64);
case VPInstruction::ExtractLastElement:
case VPInstruction::ExtractLastLanePerPart:
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,12 +399,24 @@ m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1);
}

/// Build a matcher for a VPInstruction with opcode VPInstruction::ExtractLane.
/// \p Op0 is the sub-matcher for the first operand and \p Op1 the sub-matcher
/// for the second operand; both are forwarded unchanged to m_VPInstruction.
/// Callers in this patch pass the lane-index matcher (e.g. m_FirstActiveLane /
/// m_LastActiveLane) as \p Op0 and the matcher for the value being extracted
/// from as \p Op1.
/// NOTE(review): unrolling appends further operands to ExtractLane; whether
/// this two-operand matcher still matches such recipes depends on
/// m_VPInstruction, which is not visible here -- confirm.
template <typename Op0_t, typename Op1_t>
inline VPInstruction_match<VPInstruction::ExtractLane, Op0_t, Op1_t>
m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
return m_VPInstruction<VPInstruction::ExtractLane>(Op0, Op1);
}

/// Build a matcher for a VPInstruction with opcode
/// VPInstruction::ExtractLastLanePerPart whose single operand is matched by
/// \p Op0. Thin convenience wrapper around m_VPInstruction.
template <typename Op0_t>
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
m_ExtractLastLanePerPart(const Op0_t &Op0) {
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
}

/// Build a matcher for a VPInstruction with opcode
/// VPInstruction::ExtractPenultimateElement whose single operand is matched by
/// \p Op0. Thin convenience wrapper around m_VPInstruction, so call sites do
/// not have to spell out the opcode template argument.
template <typename Op0_t>
inline VPInstruction_match<VPInstruction::ExtractPenultimateElement, Op0_t>
m_ExtractPenultimateElement(const Op0_t &Op0) {
return m_VPInstruction<VPInstruction::ExtractPenultimateElement>(Op0);
}

template <typename Op0_t, typename Op1_t, typename Op2_t>
inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
Expand Down Expand Up @@ -433,6 +445,16 @@ m_FirstActiveLane(const Op0_t &Op0) {
return m_VPInstruction<VPInstruction::FirstActiveLane>(Op0);
}

/// Build a matcher for a VPInstruction with opcode
/// VPInstruction::LastActiveLane whose operand is matched by \p Op0.
/// Counterpart of m_FirstActiveLane; thin wrapper around m_VPInstruction.
/// NOTE(review): after unrolling, LastActiveLane carries one operand per
/// part; whether this single-operand matcher still matches then depends on
/// m_VPInstruction, which is not visible here -- confirm.
template <typename Op0_t>
inline VPInstruction_match<VPInstruction::LastActiveLane, Op0_t>
m_LastActiveLane(const Op0_t &Op0) {
return m_VPInstruction<VPInstruction::LastActiveLane>(Op0);
}

/// Build a matcher for a VPInstruction with opcode VPInstruction::StepVector.
/// No operand sub-matchers are supplied.
inline VPInstruction_match<VPInstruction::StepVector> m_StepVector() {
return m_VPInstruction<VPInstruction::StepVector>();
}

template <unsigned Opcode, typename Op0_t>
inline AllRecipe_match<Opcode, Op0_t> m_Unary(const Op0_t &Op0) {
return AllRecipe_match<Opcode, Op0_t>(Op0);
Expand Down
39 changes: 34 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class VPPredicator {
/// possibly inserting new recipes at \p Dst (using Builder's insertion point)
VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);

/// Returns the *entry* mask for \p VPBB, as recorded in BlockMaskCache.
/// NOTE(review): presumably yields a null VPValue when no mask has been
/// recorded for \p VPBB -- confirm against BlockMaskCache's lookup semantics.
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
return BlockMaskCache.lookup(VPBB);
}

/// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
/// already have a mask.
void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
Expand All @@ -68,6 +63,11 @@ class VPPredicator {
}

public:
/// Returns the *entry* mask for \p VPBB, as recorded in BlockMaskCache.
/// Public so that code outside the predicator (e.g. the tail-folding fix-up in
/// introduceMasksAndLinearize) can query the header mask.
/// NOTE(review): presumably yields a null VPValue when no mask has been
/// recorded for \p VPBB -- confirm against BlockMaskCache's lookup semantics.
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
return BlockMaskCache.lookup(VPBB);
}

/// Returns the precomputed predicate of the edge from \p Src to \p Dst.
VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
return EdgeMaskCache.lookup({Src, Dst});
Expand Down Expand Up @@ -301,5 +301,34 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {

PrevVPBB = VPBB;
}

// If we folded the tail and introduced a header mask, any extract of the
// last element must be updated to extract from the last active lane of the
// header mask instead (i.e., the lane corresponding to the last active
// iteration).
if (FoldTail) {
assert(Plan.getExitBlocks().size() == 1 &&
"only a single-exit block is supported currently");
VPBasicBlock *EB = Plan.getExitBlocks().front();
assert(EB->getSinglePredecessor() == Plan.getMiddleBlock() &&
"the exit block must have middle block as single predecessor");

VPBuilder B(Plan.getMiddleBlock()->getTerminator());
for (auto &P : EB->phis()) {
auto *ExitIRI = cast<VPIRPhi>(&P);
VPValue *Inc = ExitIRI->getIncomingValue(0);
VPValue *Op;
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
continue;

// Compute the index of the last active lane.
VPValue *HeaderMask = Predicator.getBlockInMask(Header);
VPValue *LastActiveLane =
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
auto *Ext =
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
Inc->replaceAllUsesWith(Ext);
}
}
return Predicator.getBlockMaskCache();
}
41 changes: 36 additions & 5 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
case VPInstruction::ExtractLastElement:
case VPInstruction::ExtractLastLanePerPart:
case VPInstruction::ExtractPenultimateElement:
case VPInstruction::FirstActiveLane:
case VPInstruction::Not:
case VPInstruction::ResumeForEpilogue:
case VPInstruction::Unpack:
Expand Down Expand Up @@ -591,6 +590,8 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
case Instruction::Switch:
case VPInstruction::SLPLoad:
case VPInstruction::SLPStore:
case VPInstruction::FirstActiveLane:
case VPInstruction::LastActiveLane:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with the test case added in #167897, this assertion:

assert((getNumOperandsForOpcode(Opcode) == -1u ||
          getNumOperandsForOpcode(Opcode) == getNumOperands()) &&
         "number of operands does not match opcode");

started firing on FirstActiveLane/LastActiveLane, I think because we are now hitting the unrolling path with first-order recurrences. These opcodes can have multiple operands after unrolling, so we need to return an unknown number of operands (-1u) for them here.

// Cannot determine the number of operands from the opcode.
return -1u;
}
Expand Down Expand Up @@ -1004,8 +1005,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
case VPInstruction::FirstActiveLane: {
if (getNumOperands() == 1) {
Value *Mask = State.get(getOperand(0));
// LastActiveLane is expanded to FirstActiveLane on the inverted mask. When
// every lane is active the inverted mask is all-zeros, so make sure a zero
// input returns VF and not poison.
return Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(), Mask,
true, Name);
/*ZeroIsPoison=*/false, Name);
}
// If there are multiple operands, create a chain of selects to pick the
// first operand with an active lane and add the number of lanes of the
Expand All @@ -1021,9 +1024,9 @@ Value *VPInstruction::generate(VPTransformState &State) {
Builder.CreateICmpEQ(State.get(getOperand(Idx)),
Builder.getFalse()),
Builder.getInt64Ty())
: Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(),
State.get(getOperand(Idx)),
true, Name);
: Builder.CreateCountTrailingZeroElems(
Builder.getInt64Ty(), State.get(getOperand(Idx)),
/*ZeroIsPoison=*/false, Name);
Value *Current = Builder.CreateAdd(
Builder.CreateMul(RuntimeVF, Builder.getInt64(Idx)), TrailingZeros);
if (Res) {
Expand Down Expand Up @@ -1174,6 +1177,29 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
}
case VPInstruction::LastActiveLane: {
Type *ScalarTy = Ctx.Types.inferScalarType(getOperand(0));
if (VF.isScalar())
return Ctx.TTI.getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
CmpInst::makeCmpResultType(ScalarTy),
CmpInst::ICMP_EQ, Ctx.CostKind);
// Calculate the cost of determining the lane index: NOT + cttz_elts + SUB.
auto *PredTy = toVectorTy(ScalarTy, VF);
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
Type::getInt64Ty(Ctx.LLVMCtx),
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
// Add cost of NOT operation on the predicate.
Cost += Ctx.TTI.getArithmeticInstrCost(
Instruction::Xor, PredTy, Ctx.CostKind,
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
{TargetTransformInfo::OK_UniformConstantValue,
TargetTransformInfo::OP_None});
// Add cost of SUB operation on the index.
Cost += Ctx.TTI.getArithmeticInstrCost(
Instruction::Sub, Type::getInt64Ty(Ctx.LLVMCtx), Ctx.CostKind);
return Cost;
}
case VPInstruction::FirstOrderRecurrenceSplice: {
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
SmallVector<int> Mask(VF.getKnownMinValue());
Expand Down Expand Up @@ -1228,6 +1254,7 @@ bool VPInstruction::isVectorToScalar() const {
getOpcode() == Instruction::ExtractElement ||
getOpcode() == VPInstruction::ExtractLane ||
getOpcode() == VPInstruction::FirstActiveLane ||
getOpcode() == VPInstruction::LastActiveLane ||
getOpcode() == VPInstruction::ComputeAnyOfResult ||
getOpcode() == VPInstruction::ComputeFindIVResult ||
getOpcode() == VPInstruction::ComputeReductionResult ||
Expand Down Expand Up @@ -1294,6 +1321,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
case VPInstruction::ActiveLaneMask:
case VPInstruction::ExplicitVectorLength:
case VPInstruction::FirstActiveLane:
case VPInstruction::LastActiveLane:
case VPInstruction::FirstOrderRecurrenceSplice:
case VPInstruction::LogicalAnd:
case VPInstruction::Not:
Expand Down Expand Up @@ -1470,6 +1498,9 @@ void VPInstruction::printRecipe(raw_ostream &O, const Twine &Indent,
case VPInstruction::FirstActiveLane:
O << "first-active-lane";
break;
case VPInstruction::LastActiveLane:
O << "last-active-lane";
break;
case VPInstruction::ReductionStartVector:
O << "reduction-start-vector";
break;
Expand Down
61 changes: 57 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
VPValue *Op,
ScalarEvolution &SE) {
VPValue *Incoming, *Mask;
if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>(
m_FirstActiveLane(m_VPValue(Mask)), m_VPValue(Incoming))))
if (!match(Op, m_ExtractLane(m_FirstActiveLane(m_VPValue(Mask)),
m_VPValue(Incoming))))
return nullptr;

auto *WideIV = getOptimizableIVOf(Incoming, SE);
Expand Down Expand Up @@ -1295,8 +1295,7 @@ static void simplifyRecipe(VPSingleDefRecipe *Def, VPTypeAnalysis &TypeInfo) {
}

// Look through ExtractPenultimateElement (BuildVector ....).
if (match(Def, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
m_BuildVector()))) {
if (match(Def, m_ExtractPenultimateElement(m_BuildVector()))) {
auto *BuildVector = cast<VPInstruction>(Def->getOperand(0));
Def->replaceAllUsesWith(
BuildVector->getOperand(BuildVector->getNumOperands() - 2));
Expand Down Expand Up @@ -2106,6 +2105,32 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
// Set the first operand of RecurSplice to FOR again, after replacing
// all users.
RecurSplice->setOperand(0, FOR);

// Check for users extracting at the penultimate active lane of the FOR.
// If only a single lane is active in the current iteration, we need to
// select the last element from the previous iteration (from the FOR phi
// directly).
for (VPUser *U : RecurSplice->users()) {
if (!match(U, m_ExtractLane(m_LastActiveLane(m_VPValue()),
m_Specific(RecurSplice))))
continue;

VPBuilder B(cast<VPInstruction>(U));
VPValue *LastActiveLane = cast<VPInstruction>(U)->getOperand(0);
Type *I64Ty = Type::getInt64Ty(Plan.getContext());
VPValue *Zero = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 0));
VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1));
VPValue *PenultimateIndex =
B.createNaryOp(Instruction::Sub, {LastActiveLane, One});
VPValue *PenultimateLastIter =
B.createNaryOp(VPInstruction::ExtractLane,
{PenultimateIndex, FOR->getBackedgeValue()});
VPValue *LastPrevIter =
B.createNaryOp(VPInstruction::ExtractLastElement, FOR);
VPValue *Cmp = B.createICmp(CmpInst::ICMP_EQ, LastActiveLane, Zero);
VPValue *Sel = B.createSelect(Cmp, LastPrevIter, PenultimateLastIter);
cast<VPInstruction>(U)->replaceAllUsesWith(Sel);
}
}
return true;
}
Expand Down Expand Up @@ -3492,6 +3517,34 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
ToRemove.push_back(Expr);
}

// Expand LastActiveLane into Not + FirstActiveLane + Sub.
auto *LastActiveL = dyn_cast<VPInstruction>(&R);
if (LastActiveL &&
LastActiveL->getOpcode() == VPInstruction::LastActiveLane) {
// Create Not(Mask) for all operands.
SmallVector<VPValue *, 2> NotMasks;
for (VPValue *Op : LastActiveL->operands()) {
VPValue *NotMask = Builder.createNot(Op, LastActiveL->getDebugLoc());
NotMasks.push_back(NotMask);
}

// Create FirstActiveLane on the inverted masks.
VPValue *FirstInactiveLane = Builder.createNaryOp(
VPInstruction::FirstActiveLane, NotMasks,
LastActiveL->getDebugLoc(), "first.inactive.lane");

// Subtract 1 to get the last active lane.
VPValue *One = Plan.getOrAddLiveIn(
ConstantInt::get(Type::getInt64Ty(Plan.getContext()), 1));
VPValue *LastLane = Builder.createNaryOp(
Instruction::Sub, {FirstInactiveLane, One},
LastActiveL->getDebugLoc(), "last.active.lane");

LastActiveL->replaceAllUsesWith(LastLane);
ToRemove.push_back(LastActiveL);
continue;
}

VPValue *VectorStep;
VPValue *ScalarStep;
if (!match(&R, m_VPInstruction<VPInstruction::WideIVStep>(
Expand Down
13 changes: 9 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
VPValue *Op1;
if (match(&R, m_VPInstruction<VPInstruction::AnyOf>(m_VPValue(Op1))) ||
match(&R, m_FirstActiveLane(m_VPValue(Op1))) ||
match(&R, m_LastActiveLane(m_VPValue(Op1))) ||
match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
Expand All @@ -364,17 +365,21 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
continue;
}
VPValue *Op0;
if (match(&R, m_VPInstruction<VPInstruction::ExtractLane>(
m_VPValue(Op0), m_VPValue(Op1)))) {
if (match(&R, m_ExtractLane(m_VPValue(Op0), m_VPValue(Op1)))) {
addUniformForAllParts(cast<VPInstruction>(&R));
for (unsigned Part = 1; Part != UF; ++Part)
R.addOperand(getValueForPart(Op1, Part));
continue;
}
if (match(&R, m_ExtractLastElement(m_VPValue(Op0))) ||
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
m_VPValue(Op0)))) {
match(&R, m_ExtractPenultimateElement(m_VPValue(Op0)))) {
addUniformForAllParts(cast<VPSingleDefRecipe>(&R));
if (isa<VPFirstOrderRecurrencePHIRecipe>(Op0)) {
assert(match(&R, m_ExtractLastElement(m_VPValue())) &&
"can only extract last element of FOR");
continue;
}

if (Plan.hasScalarVFOnly()) {
auto *I = cast<VPInstruction>(&R);
// Extracting from end with VF = 1 implies retrieving the last or
Expand Down
Loading
Loading