diff --git a/llvm/lib/CodeGen/MachineSink.cpp b/llvm/lib/CodeGen/MachineSink.cpp index 8da97dc7e74240fef7c0cd42a02e79a55f664ecf..a79ecb73bb580eebdc77aa7f0e67afd1e3aa6881 100644 --- a/llvm/lib/CodeGen/MachineSink.cpp +++ b/llvm/lib/CodeGen/MachineSink.cpp @@ -106,6 +106,8 @@ static cl::opt SinkIntoCycleLimit( cl::desc("The maximum number of instructions considered for cycle sinking."), cl::init(50), cl::Hidden); +extern cl::opt DoNotSinkPtrAddPostLoad; + STATISTIC(NumSunk, "Number of machine instructions sunk"); STATISTIC(NumCycleSunk, "Number of machine instructions sunk into a cycle"); STATISTIC(NumSplit, "Number of critical edges split"); @@ -238,6 +240,9 @@ namespace { SmallVectorImpl &Candidates); bool SinkIntoCycle(MachineCycle *Cycle, MachineInstr &I); + bool isProfitablePtrAddPostLoad(Register Reg, MachineInstr &MI, + MachineBasicBlock *MBB); + bool isProfitableToSinkTo(Register Reg, MachineInstr &MI, MachineBasicBlock *MBB, MachineBasicBlock *SuccToSinkTo, @@ -772,6 +777,45 @@ MachineSinking::getBBRegisterPressure(MachineBasicBlock &MBB) { return It.first->second; } +/// isProfitablePtrAddPostLoad - Return true if MI is not a post load PtrAdd. +/// When a pointer post-increment after loads to it in a loop. It may not be +/// profitable to sink the PtrAdd, which makes the distance between the load to +/// it closer and causes stall. +bool MachineSinking::isProfitablePtrAddPostLoad(Register Reg, MachineInstr &MI, + MachineBasicBlock *MBB) { + // Check if MI is inside a loop. + MachineCycle *MCycle = CI->getCycle(MBB); + if (!MCycle) + return true; + // Check if MI is a PtrAdd Instruction + const MCInstrDesc &InstrDesc = MI.getDesc(); + if (!InstrDesc.isAsCheapAsAMove() && !InstrDesc.isAdd()) + return true; + + // Collect Phi nodes take this PtrAdd as their incoming values. + SmallDenseSet Phis; + for (MachineInstr &UseInst : MRI->use_nodbg_instructions(Reg)) + if (UseInst.getOpcode() == TargetOpcode::COPY) + for (MachineInstr &UseInst2 : + MRI->use_nodbg_instructions(UseInst.getOperand(0).getReg())) + if (UseInst2.isPHI()) + Phis.insert(UseInst2.getOperand(0).getReg()); + if (Phis.empty()) + return true; + + // Check if any operand of MI takes value from Phi nodes and used by loads. + for (MachineOperand &MO : MI.all_uses()) + for (MachineInstr &UseInst : MRI->use_nodbg_instructions(MO.getReg())) + if (UseInst.mayLoad() && CI->getCycle(UseInst.getParent())) + for (MachineOperand &MO2 : UseInst.all_uses()) + if (Phis.count(MO2.getReg())) { + LLVM_DEBUG(dbgs() << "PtrAdd post load, not profitable.\n"); + return false; + } + + return true; +} + /// isProfitableToSinkTo - Return true if it is profitable to sink MI. bool MachineSinking::isProfitableToSinkTo(Register Reg, MachineInstr &MI, MachineBasicBlock *MBB, @@ -782,6 +826,10 @@ bool MachineSinking::isProfitableToSinkTo(Register Reg, MachineInstr &MI, if (MBB == SuccToSinkTo) return false; + if (DoNotSinkPtrAddPostLoad) + if (!isProfitablePtrAddPostLoad(Reg, MI, MBB)) + return false; + // It is profitable if SuccToSinkTo does not post dominate current block. if (!PDT->dominates(SuccToSinkTo, MBB)) return true; diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index a4369b83e732fc011c40e7f1cd8c9ee943f2301e..2b5c10ac63ca1da140f90077811162dff38c62e1 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -196,6 +196,10 @@ static cl::opt AllowDropSolutionIfLessProfitable( "lsr-drop-solution", cl::Hidden, cl::init(false), cl::desc("Attempt to drop solution if it is less profitable")); +cl::opt DoNotSinkPtrAddPostLoad( + "no-sink-ptradd-post-load", cl::Hidden, cl::init(false), + cl::desc("Avoid sinking post load PtrAdds to the loop latches")); + STATISTIC(NumTermFold, "Number of terminating condition fold recognized and performed"); @@ -2873,6 +2877,35 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, return !isHighCostExpansion(IncExpr, Processed, SE); } +/// Return true if the IVChain is incomplete or there is no relevant load +/// instruction exist before the IVOperand of the tail Phi. +/// When a pointer post-increment after loads to it in a loop. It may not be +/// profitable to sink it to the latch of the loop even with register saving, +/// which makes the distance between the PtrAdd closer to the load instructions +/// and causes stall. +static bool isProfitablePtrAddPostLoad(IVChain &Chain, DominatorTree &DT) { + // Only care about complete chains which GenerateIVChain may place the PtrAdd + // of its Phi to the latch of the loop. + IVInc &Tail = Chain.Incs.back(); + if (!isa(Tail.UserInst)) + return true; + + if (Tail.IncExpr->isZero()) + return true; + + GetElementPtrInst *PtrAdd = dyn_cast(Tail.IVOperand); + if (!PtrAdd) + return true; + + unsigned NumLoadPrePtrAdd = 0; + for (const IVInc &Inc : Chain.Incs) + if (isa(Inc.UserInst) && DT.dominates(Inc.UserInst, PtrAdd)) + ++NumLoadPrePtrAdd; + LLVM_DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst + << " NumLoadPrePtrAdd: " << NumLoadPrePtrAdd << "\n"); + return NumLoadPrePtrAdd == 0; +} + /// Return true if the number of registers needed for the chain is estimated to /// be less than the number required for the individual IV users. First prohibit /// any IV users that keep the IV live across increments (the Users set should @@ -2886,7 +2919,8 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, static bool isProfitableChain(IVChain &Chain, SmallPtrSetImpl &Users, ScalarEvolution &SE, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + DominatorTree &DT) { if (StressIVChain) return true; @@ -2919,6 +2953,10 @@ static bool isProfitableChain(IVChain &Chain, if (TTI.isProfitableLSRChainElement(Chain.Incs[0].UserInst)) return true; + if (DoNotSinkPtrAddPostLoad) + if (!isProfitablePtrAddPostLoad(Chain, DT)) + return false; + for (const IVInc &Inc : Chain) { if (TTI.isProfitableLSRChainElement(Inc.UserInst)) return true; @@ -3152,7 +3190,7 @@ void LSRInstance::CollectChains() { for (unsigned UsersIdx = 0, NChains = IVChainVec.size(); UsersIdx < NChains; ++UsersIdx) { if (!isProfitableChain(IVChainVec[UsersIdx], - ChainUsersVec[UsersIdx].FarUsers, SE, TTI)) + ChainUsersVec[UsersIdx].FarUsers, SE, TTI, DT)) continue; // Preserve the chain at UsesIdx. if (ChainIdx != UsersIdx) diff --git a/llvm/test/Transforms/LoopStrengthReduce/dont-sink-ptradd-post-load.ll b/llvm/test/Transforms/LoopStrengthReduce/dont-sink-ptradd-post-load.ll new file mode 100644 index 0000000000000000000000000000000000000000..d46804f302241b3a9f5a61c9b623e92f33fb37d9 --- /dev/null +++ b/llvm/test/Transforms/LoopStrengthReduce/dont-sink-ptradd-post-load.ll @@ -0,0 +1,64 @@ +; Check if the ptradd is not sunk into the loop latch when the feature is enabled. +; RUN: opt < %s -loop-reduce -no-sink-ptradd-post-load -S | FileCheck -check-prefix=ENABLE %s +; RUN: opt < %s -loop-reduce -S | FileCheck -check-prefix=DISABLE %s + +; ModuleID = 'foo.ll' +source_filename = "foo.cpp" +target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32" +target triple = "aarch64-unknown-linux-gnu" + +define dso_local noundef i16 @foo(ptr nocapture noundef readonly %ptr, i16 noundef %sum) local_unnamed_addr { +; ENABLE-LABEL: while.body: +; ENABLE-NEXT: [[PTR1:%.*]] = phi ptr [ [[GEP1:%.*]], %cleanup ], [ %{{.*}}, %while.body.preheader ] +; ENABLE: [[GEP1]] = getelementptr{{.*}}, ptr [[PTR1]], i64 8 +; ENABLE-LABEL: cleanup: +; ENABLE-NOT: [[GEP1]] = {{.*}} +; DISABLE-LABEL: while.body: +; DISABLE-NEXT: [[PTR2:%.*]] = phi ptr [ [[GEP2:%.*]], %cleanup ], [ %{{.*}}, %while.body.preheader ] +; DISABLE-NOT: [[GEP2]] = getelementptr{{.*}}, ptr [[PTR2]], i64 8 +; DISABLE-LABEL: cleanup: +; DISABLE: [[GEP2]] = getelementptr{{.*}}, ptr [[PTR2]], i64 8 +; +entry: + br label %while.body.preheader + +while.body.preheader: ; preds = %entry + br label %while.body + +while.body: ; preds = %while.body.preheader, %cleanup + %ptr.addr.046 = phi ptr [ %incdec.ptr, %cleanup ], [ %ptr, %while.body.preheader ] + %ret.045 = phi i16 [ %ret.1, %cleanup ], [ 0, %while.body.preheader ] + %temp.044 = phi i16 [ %temp.1, %cleanup ], [ 0, %while.body.preheader ] + %count.043 = phi i16 [ %inc, %cleanup ], [ 0, %while.body.preheader ] + %0 = load i32, ptr %ptr.addr.046, align 4 + %y3 = getelementptr inbounds i8, ptr %ptr.addr.046, i64 4 + %1 = load i16, ptr %y3, align 4 + %z4 = getelementptr inbounds i8, ptr %ptr.addr.046, i64 6 + %2 = load i16, ptr %z4, align 2 + %incdec.ptr = getelementptr inbounds i8, ptr %ptr.addr.046, i64 8 + %inc = add nuw i16 %count.043, 1 + %cmp5 = icmp eq i32 %0, 0 + br i1 %cmp5, label %if.then, label %if.end19 + +if.then: ; preds = %while.body + %add = add i16 %temp.044, 3 + %add8 = add i16 %add, %1 + %add10 = add i16 %add8, %2 + br label %cleanup + +if.end19: ; preds = %while.body + %add22 = add i16 %temp.044, -5 + %sub = add i16 %add22, %ret.045 + %add14 = add i16 %sub, %1 + %add24 = add i16 %add14, %2 + br label %cleanup + +cleanup: ; preds = %if.end19, %if.then + %temp.1 = phi i16 [ %add10, %if.then ], [ 0, %if.end19 ] + %ret.1 = phi i16 [ %ret.045, %if.then ], [ %add24, %if.end19 ] + %cmp = icmp ult i16 %inc, %sum + br i1 %cmp, label %while.body, label %while.end + +while.end: ; preds = %cleanup + ret i16 %ret.1 +}