Skip to content

Commit

Permalink
[LV] Fold tail by masking - handle reductions
Browse files Browse the repository at this point in the history
Allow vectorizing loops that have reductions when tail is folded by masking.
A select is introduced in VPlan, choosing between the last value carried by the
loop-exit/live-out instruction of the reduction, and the penultimate value
carried by the reduction phi, according to the "i < n" mask of fold-tail.
This select replaces the last value as the live-out value of the loop.

Differential Revision: https://reviews.llvm.org/D66720

llvm-svn: 370173
  • Loading branch information
azaks committed Aug 28, 2019
1 parent 8fbe81f commit d15df0e
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
19 changes: 9 additions & 10 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1176,18 +1176,17 @@ bool LoopVectorizationLegality::prepareToFoldTailByMasking() {
return false;
}

// TODO: handle reductions when tail is folded by masking.
if (!Reductions.empty()) {
reportVectorizationFailure(
"Loop has reductions, cannot fold tail by masking",
"Cannot fold tail by masking in the presence of reductions.",
"ReductionFoldingTailByMasking", ORE, TheLoop);
return false;
}
SmallPtrSet<const Value *, 8> ReductionLiveOuts;

// TODO: handle outside users when tail is folded by masking.
for (auto &Reduction : *getReductionVars())
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());

// TODO: handle non-reduction outside users when tail is folded by masking.
for (auto *AE : AllowedExit) {
// Check that all users of allowed exit values are inside the loop.
// Check that all users of allowed exit values are inside the loop or
// are the live-out of a reduction.
if (ReductionLiveOuts.count(AE))
continue;
for (User *U : AE->users()) {
Instruction *UI = cast<Instruction>(U);
if (TheLoop->contains(UI))
Expand Down
41 changes: 40 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3678,6 +3678,26 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {

setDebugLocFromInst(Builder, LoopExitInst);

// If tail is folded by masking, the vector value to leave the loop should be
// a Select choosing between the vectorized LoopExitInst and vectorized Phi,
// instead of the former.
if (Cost->foldTailByMasking()) {
for (unsigned Part = 0; Part < UF; ++Part) {
Value *VecLoopExitInst =
VectorLoopValueMap.getVectorValue(LoopExitInst, Part);
Value *Sel = nullptr;
for (User *U : VecLoopExitInst->users()) {
if (isa<SelectInst>(U)) {
assert(!Sel && "Reduction exit feeding two selects");
Sel = U;
} else
assert(isa<PHINode>(U) && "Reduction exit must feed Phi's or select");
}
assert(Sel && "Reduction exit feeds no select");
VectorLoopValueMap.resetVectorValue(LoopExitInst, Part, Sel);
}
}

// If the vector reduction can be performed in a smaller type, we truncate
// then extend the loop exit value to enable InstCombine to evaluate the
// entire expression in the smaller type.
Expand Down Expand Up @@ -6939,8 +6959,15 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF,

// If the tail is to be folded by masking, the primary induction variable
// needs to be represented in VPlan for it to model early-exit masking.
if (CM.foldTailByMasking())
// Also, both the Phi and the live-out instruction of each reduction are
// required in order to introduce a select between them in VPlan.
if (CM.foldTailByMasking()) {
NeedDef.insert(Legal->getPrimaryInduction());
for (auto &Reduction : *Legal->getReductionVars()) {
NeedDef.insert(Reduction.first);
NeedDef.insert(Reduction.second.getLoopExitInstr());
}
}

// Collect instructions from the original loop that will become trivially dead
// in the vectorized loop. We don't need to vectorize these instructions. For
Expand Down Expand Up @@ -7067,6 +7094,18 @@ VPlanPtr LoopVectorizationPlanner::buildVPlanWithVPRecipes(
VPBlockUtils::disconnectBlocks(PreEntry, Entry);
delete PreEntry;

// Finally, if tail is folded by masking, introduce selects between the phi
// and the live-out instruction of each reduction, at the end of the latch.
if (CM.foldTailByMasking()) {
Builder.setInsertPoint(VPBB);
auto *Cond = RecipeBuilder.createBlockInMask(OrigLoop->getHeader(), Plan);
for (auto &Reduction : *Legal->getReductionVars()) {
VPValue *Phi = Plan->getVPValue(Reduction.first);
VPValue *Red = Plan->getVPValue(Reduction.second.getLoopExitInstr());
Builder.createNaryOp(Instruction::Select, {Cond, Red, Phi});
}
}

std::string PlanName;
raw_string_ostream RSO(PlanName);
unsigned VF = Range.Start;
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,14 @@ void VPInstruction::generateInstruction(VPTransformState &State,
State.set(this, V, Part);
break;
}
case Instruction::Select: {
Value *Cond = State.get(getOperand(0), Part);
Value *Op1 = State.get(getOperand(1), Part);
Value *Op2 = State.get(getOperand(2), Part);
Value *V = Builder.CreateSelect(Cond, Op1, Op2);
State.set(this, V, Part);
break;
}
default:
llvm_unreachable("Unsupported opcode for instruction");
}
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/Transforms/LoopVectorize/X86/tail_loop_folding.ll
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,62 @@ for.body:
br i1 %exitcond, label %for.cond.cleanup, label %for.body, !llvm.loop !10
}

; Check that fold tail under optsize passes the reduction live-out value
; through a select.
; int reduction_i32(int *A, int *B, int N) {
; int sum = 0;
; for (int i = 0; i < N; ++i)
; sum += (A[i] + B[i]);
; return sum;
; }
;
define i32 @reduction_i32(i32* nocapture readonly %A, i32* nocapture readonly %B, i32 %N) #0 {
; CHECK-LABEL: @reduction_i32(
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %vector.ph ], [ [[INDEX_NEXT:%.*]], %vector.body ]
; CHECK-NEXT: [[ACCUM_PHI:%.*]] = phi <8 x i32> [ zeroinitializer, %vector.ph ], [ [[ACCUM:%.*]], %vector.body ]
; CHECK: [[ICMPULE:%.*]] = icmp ule <8 x i64>
; CHECK: [[LOAD1:%.*]] = call <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>* {{.*}}, i32 4, <8 x i1> [[ICMPULE]], <8 x i32> undef)
; CHECK: [[LOAD2:%.*]] = call <8 x i32> @llvm.masked.load.v8i32.p0v8i32(<8 x i32>* {{.*}}, i32 4, <8 x i1> [[ICMPULE]], <8 x i32> undef)
; CHECK-NEXT: [[ADD:%.*]] = add nsw <8 x i32> [[LOAD2]], [[LOAD1]]
; CHECK-NEXT: [[ACCUM]] = add nuw nsw <8 x i32> [[ADD]], [[ACCUM_PHI]]
; CHECK: [[LIVEOUT:%.*]] = select <8 x i1> [[ICMPULE]], <8 x i32> [[ACCUM]], <8 x i32> [[ACCUM_PHI]]
; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], 8
; CHECK: middle.block:
; CHECK-NEXT: [[RDX_SHUF:%.*]] = shufflevector <8 x i32> [[LIVEOUT]], <8 x i32> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[BIN_RDX:%.*]] = add <8 x i32> [[LIVEOUT]], [[RDX_SHUF]]
; CHECK-NEXT: [[RDX_SHUF4:%.*]] = shufflevector <8 x i32> [[BIN_RDX]], <8 x i32> undef, <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[BIN_RDX5:%.*]] = add <8 x i32> [[BIN_RDX]], [[RDX_SHUF4]]
; CHECK-NEXT: [[RDX_SHUF6:%.*]] = shufflevector <8 x i32> [[BIN_RDX5]], <8 x i32> undef, <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
; CHECK-NEXT: [[BIN_RDX7:%.*]] = add <8 x i32> [[BIN_RDX5]], [[RDX_SHUF6]]
; CHECK-NEXT: [[TMP17:%.*]] = extractelement <8 x i32> [[BIN_RDX7]], i32 0
; CHECK-NEXT: br i1 true, label %for.cond.cleanup, label %scalar.ph
; CHECK: scalar.ph:
; CHECK: for.cond.cleanup:
; CHECK-NEXT: [[SUM_1_LCSSA:%.*]] = phi i32 [ {{.*}}, %for.body ], [ [[TMP17]], %middle.block ]
; CHECK-NEXT: ret i32 [[SUM_1_LCSSA]]
;
entry:
br label %for.body

for.body:
%indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
%sum.0 = phi i32 [ %sum.1, %for.body ], [ 0, %entry ]
%indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
%arrayidxA = getelementptr inbounds i32, i32* %A, i64 %indvars.iv
%0 = load i32, i32* %arrayidxA, align 4
%arrayidxB = getelementptr inbounds i32, i32* %B, i64 %indvars.iv
%1 = load i32, i32* %arrayidxB, align 4
%add = add nsw i32 %1, %0
%sum.1 = add nuw nsw i32 %add, %sum.0
%lftr.wideiv = trunc i64 %indvars.iv.next to i32
%exitcond = icmp eq i32 %lftr.wideiv, %N
br i1 %exitcond, label %for.cond.cleanup, label %for.body

for.cond.cleanup:
ret i32 %sum.1
}

; CHECK: !0 = distinct !{!0, !1}
; CHECK-NEXT: !1 = !{!"llvm.loop.isvectorized", i32 1}
; CHECK-NEXT: !2 = distinct !{!2, !3, !1}
Expand Down

0 comments on commit d15df0e

Please sign in to comment.