diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index 931606c6f8fe12..67763d657152d5 100644 --- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -806,6 +806,34 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) { return false; } +// A memory location wrapper that represents a MemoryLocation, `MemLoc`, +// defined by `MemDef`. +struct MemoryLocationWrapper { + MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef) + : MemLoc(MemLoc), MemDef(MemDef) { + assert(MemLoc.Ptr && "MemLoc should be not null"); + UnderlyingObject = getUnderlyingObject(MemLoc.Ptr); + DefInst = MemDef->getMemoryInst(); + } + + MemoryLocation MemLoc; + const Value *UnderlyingObject; + MemoryDef *MemDef; + Instruction *DefInst; +}; + +// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s) +// defined by this MemoryDef. +struct MemoryDefWrapper { + MemoryDefWrapper(MemoryDef *MemDef, std::optional MemLoc) { + DefInst = MemDef->getMemoryInst(); + if (MemLoc.has_value()) + DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef); + } + Instruction *DefInst; + std::optional DefinedLocation = std::nullopt; +}; + struct DSEState { Function &F; AliasAnalysis &AA; @@ -1119,6 +1147,15 @@ struct DSEState { return MemoryLocation::getOrNone(I); } + std::optional getLocForInst(Instruction *I) { + if (isMemTerminatorInst(I)) { + if (auto Loc = getLocForTerminator(I)) { + return Loc->first; + } + } + return getLocForWrite(I); + } + /// Assuming this instruction has a dead analyzable write, can we delete /// this instruction? bool isRemovable(Instruction *I) { @@ -2132,182 +2169,196 @@ struct DSEState { } return MadeChange; } -}; -static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, - DominatorTree &DT, PostDominatorTree &PDT, - const TargetLibraryInfo &TLI, - const LoopInfo &LI) { - bool MadeChange = false; + // Try to eliminate dead defs that access `KillingLocWrapper.MemLoc` and are + // killed by `KillingLocWrapper.MemDef`. Return whether + // any changes were made, and whether `KillingLocWrapper.DefInst` was deleted. + std::pair + eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper); - DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); - // For each store: - for (unsigned I = 0; I < State.MemDefs.size(); I++) { - MemoryDef *KillingDef = State.MemDefs[I]; - if (State.SkipStores.count(KillingDef)) + // Try to eliminate dead defs killed by `KillingDefWrapper` and return the + // change state: whether make any change. + bool eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper); +}; + +std::pair +DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) { + bool Changed = false; + bool DeletedKillingLoc = false; + unsigned ScanLimit = MemorySSAScanLimit; + unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit; + unsigned PartialLimit = MemorySSAPartialStoreLimit; + // Worklist of MemoryAccesses that may be killed by + // "KillingLocWrapper.MemDef". + SmallSetVector ToCheck; + // Track MemoryAccesses that have been deleted in the loop below, so we can + // skip them. Don't use SkipStores for this, which may contain reused + // MemoryAccess addresses. + SmallPtrSet Deleted; + [[maybe_unused]] unsigned OrigNumSkipStores = SkipStores.size(); + ToCheck.insert(KillingLocWrapper.MemDef->getDefiningAccess()); + + // Check if MemoryAccesses in the worklist are killed by + // "KillingLocWrapper.MemDef". + for (unsigned I = 0; I < ToCheck.size(); I++) { + MemoryAccess *Current = ToCheck[I]; + if (Deleted.contains(Current)) continue; - Instruction *KillingI = KillingDef->getMemoryInst(); + std::optional MaybeDeadAccess = getDomMemoryDef( + KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc, + KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit, + isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit); - std::optional MaybeKillingLoc; - if (State.isMemTerminatorInst(KillingI)) { - if (auto KillingLoc = State.getLocForTerminator(KillingI)) - MaybeKillingLoc = KillingLoc->first; - } else { - MaybeKillingLoc = State.getLocForWrite(KillingI); + if (!MaybeDeadAccess) { + LLVM_DEBUG(dbgs() << " finished walk\n"); + continue; } - - if (!MaybeKillingLoc) { - LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " - << *KillingI << "\n"); + MemoryAccess *DeadAccess = *MaybeDeadAccess; + LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess); + if (isa(DeadAccess)) { + LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n"); + for (Value *V : cast(DeadAccess)->incoming_values()) { + MemoryAccess *IncomingAccess = cast(V); + BasicBlock *IncomingBlock = IncomingAccess->getBlock(); + BasicBlock *PhiBlock = DeadAccess->getBlock(); + + // We only consider incoming MemoryAccesses that come before the + // MemoryPhi. Otherwise we could discover candidates that do not + // strictly dominate our starting def. + if (PostOrderNumbers[IncomingBlock] > PostOrderNumbers[PhiBlock]) + ToCheck.insert(IncomingAccess); + } continue; } - MemoryLocation KillingLoc = *MaybeKillingLoc; - assert(KillingLoc.Ptr && "KillingLoc should not be null"); - const Value *KillingUndObj = getUnderlyingObject(KillingLoc.Ptr); - LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " - << *KillingDef << " (" << *KillingI << ")\n"); - - unsigned ScanLimit = MemorySSAScanLimit; - unsigned WalkerStepLimit = MemorySSAUpwardsStepLimit; - unsigned PartialLimit = MemorySSAPartialStoreLimit; - // Worklist of MemoryAccesses that may be killed by KillingDef. - SmallSetVector ToCheck; - // Track MemoryAccesses that have been deleted in the loop below, so we can - // skip them. Don't use SkipStores for this, which may contain reused - // MemoryAccess addresses. - SmallPtrSet Deleted; - [[maybe_unused]] unsigned OrigNumSkipStores = State.SkipStores.size(); - ToCheck.insert(KillingDef->getDefiningAccess()); - - bool Shortend = false; - bool IsMemTerm = State.isMemTerminatorInst(KillingI); - // Check if MemoryAccesses in the worklist are killed by KillingDef. - for (unsigned I = 0; I < ToCheck.size(); I++) { - MemoryAccess *Current = ToCheck[I]; - if (Deleted.contains(Current)) - continue; - - std::optional MaybeDeadAccess = State.getDomMemoryDef( - KillingDef, Current, KillingLoc, KillingUndObj, ScanLimit, - WalkerStepLimit, IsMemTerm, PartialLimit); - - if (!MaybeDeadAccess) { - LLVM_DEBUG(dbgs() << " finished walk\n"); + MemoryDefWrapper DeadDefWrapper( + cast(DeadAccess), + getLocForInst(cast(DeadAccess)->getMemoryInst())); + MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation; + LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n"); + ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess()); + NumGetDomMemoryDefPassed++; + + if (!DebugCounter::shouldExecute(MemorySSACounter)) + continue; + if (isMemTerminatorInst(KillingLocWrapper.DefInst)) { + if (KillingLocWrapper.UnderlyingObject != DeadLocWrapper.UnderlyingObject) continue; + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " + << *DeadLocWrapper.DefInst << "\n KILLER: " + << *KillingLocWrapper.DefInst << '\n'); + deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted); + ++NumFastStores; + Changed = true; + } else { + // Check if DeadI overwrites KillingI. + int64_t KillingOffset = 0; + int64_t DeadOffset = 0; + OverwriteResult OR = + isOverwrite(KillingLocWrapper.DefInst, DeadLocWrapper.DefInst, + KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc, + KillingOffset, DeadOffset); + if (OR == OW_MaybePartial) { + auto Iter = + IOLs.insert(std::make_pair( + DeadLocWrapper.DefInst->getParent(), InstOverlapIntervalsTy())); + auto &IOL = Iter.first->second; + OR = isPartialOverwrite(KillingLocWrapper.MemLoc, DeadLocWrapper.MemLoc, + KillingOffset, DeadOffset, + DeadLocWrapper.DefInst, IOL); } - - MemoryAccess *DeadAccess = *MaybeDeadAccess; - LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DeadAccess); - if (isa(DeadAccess)) { - LLVM_DEBUG(dbgs() << "\n ... adding incoming values to worklist\n"); - for (Value *V : cast(DeadAccess)->incoming_values()) { - MemoryAccess *IncomingAccess = cast(V); - BasicBlock *IncomingBlock = IncomingAccess->getBlock(); - BasicBlock *PhiBlock = DeadAccess->getBlock(); - - // We only consider incoming MemoryAccesses that come before the - // MemoryPhi. Otherwise we could discover candidates that do not - // strictly dominate our starting def. - if (State.PostOrderNumbers[IncomingBlock] > - State.PostOrderNumbers[PhiBlock]) - ToCheck.insert(IncomingAccess); + if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) { + auto *DeadSI = dyn_cast(DeadLocWrapper.DefInst); + auto *KillingSI = dyn_cast(KillingLocWrapper.DefInst); + // We are re-using tryToMergePartialOverlappingStores, which requires + // DeadSI to dominate KillingSI. + // TODO: implement tryToMergeParialOverlappingStores using MemorySSA. + if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) { + if (Constant *Merged = tryToMergePartialOverlappingStores( + KillingSI, DeadSI, KillingOffset, DeadOffset, DL, BatchAA, + &DT)) { + + // Update stored value of earlier store to merged constant. + DeadSI->setOperand(0, Merged); + ++NumModifiedStores; + Changed = true; + DeletedKillingLoc = true; + + // Remove killing store and remove any outstanding overlap + // intervals for the updated store. + deleteDeadInstruction(KillingSI, &Deleted); + auto I = IOLs.find(DeadSI->getParent()); + if (I != IOLs.end()) + I->second.erase(DeadSI); + break; + } } - continue; } - auto *DeadDefAccess = cast(DeadAccess); - Instruction *DeadI = DeadDefAccess->getMemoryInst(); - LLVM_DEBUG(dbgs() << " (" << *DeadI << ")\n"); - ToCheck.insert(DeadDefAccess->getDefiningAccess()); - NumGetDomMemoryDefPassed++; - - if (!DebugCounter::shouldExecute(MemorySSACounter)) - continue; - - MemoryLocation DeadLoc = *State.getLocForWrite(DeadI); - - if (IsMemTerm) { - const Value *DeadUndObj = getUnderlyingObject(DeadLoc.Ptr); - if (KillingUndObj != DeadUndObj) - continue; - LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI - << "\n KILLER: " << *KillingI << '\n'); - State.deleteDeadInstruction(DeadI, &Deleted); + if (OR == OW_Complete) { + LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " + << *DeadLocWrapper.DefInst << "\n KILLER: " + << *KillingLocWrapper.DefInst << '\n'); + deleteDeadInstruction(DeadLocWrapper.DefInst, &Deleted); ++NumFastStores; - MadeChange = true; - } else { - // Check if DeadI overwrites KillingI. - int64_t KillingOffset = 0; - int64_t DeadOffset = 0; - OverwriteResult OR = State.isOverwrite( - KillingI, DeadI, KillingLoc, DeadLoc, KillingOffset, DeadOffset); - if (OR == OW_MaybePartial) { - auto Iter = State.IOLs.insert( - std::make_pair( - DeadI->getParent(), InstOverlapIntervalsTy())); - auto &IOL = Iter.first->second; - OR = isPartialOverwrite(KillingLoc, DeadLoc, KillingOffset, - DeadOffset, DeadI, IOL); - } - - if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) { - auto *DeadSI = dyn_cast(DeadI); - auto *KillingSI = dyn_cast(KillingI); - // We are re-using tryToMergePartialOverlappingStores, which requires - // DeadSI to dominate KillingSI. - // TODO: implement tryToMergeParialOverlappingStores using MemorySSA. - if (DeadSI && KillingSI && DT.dominates(DeadSI, KillingSI)) { - if (Constant *Merged = tryToMergePartialOverlappingStores( - KillingSI, DeadSI, KillingOffset, DeadOffset, State.DL, - State.BatchAA, &DT)) { - - // Update stored value of earlier store to merged constant. - DeadSI->setOperand(0, Merged); - ++NumModifiedStores; - MadeChange = true; - - Shortend = true; - // Remove killing store and remove any outstanding overlap - // intervals for the updated store. - State.deleteDeadInstruction(KillingSI, &Deleted); - auto I = State.IOLs.find(DeadSI->getParent()); - if (I != State.IOLs.end()) - I->second.erase(DeadSI); - break; - } - } - } - - if (OR == OW_Complete) { - LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " << *DeadI - << "\n KILLER: " << *KillingI << '\n'); - State.deleteDeadInstruction(DeadI, &Deleted); - ++NumFastStores; - MadeChange = true; - } + Changed = true; } } + } - assert(State.SkipStores.size() - OrigNumSkipStores == Deleted.size() && - "SkipStores and Deleted out of sync?"); + assert(SkipStores.size() - OrigNumSkipStores == Deleted.size() && + "SkipStores and Deleted out of sync?"); - // Check if the store is a no-op. - if (!Shortend && State.storeIsNoop(KillingDef, KillingUndObj)) { - LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " << *KillingI - << '\n'); - State.deleteDeadInstruction(KillingI); - NumRedundantStores++; - MadeChange = true; - continue; - } + return {Changed, DeletedKillingLoc}; +} - // Can we form a calloc from a memset/malloc pair? - if (!Shortend && State.tryFoldIntoCalloc(KillingDef, KillingUndObj)) { - LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n" - << " DEAD: " << *KillingI << '\n'); - State.deleteDeadInstruction(KillingI); - MadeChange = true; +bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) { + if (!KillingDefWrapper.DefinedLocation.has_value()) { + LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " + << *KillingDefWrapper.DefInst << "\n"); + return false; + } + + auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation; + LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " + << *KillingLocWrapper.MemDef << " (" + << *KillingLocWrapper.DefInst << ")\n"); + auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper); + + // Check if the store is a no-op. + if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef, + KillingLocWrapper.UnderlyingObject)) { + LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: " + << *KillingLocWrapper.DefInst << '\n'); + deleteDeadInstruction(KillingLocWrapper.DefInst); + NumRedundantStores++; + return true; + } + // Can we form a calloc from a memset/malloc pair? + if (!DeletedKillingLoc && + tryFoldIntoCalloc(KillingLocWrapper.MemDef, + KillingLocWrapper.UnderlyingObject)) { + LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n" + << " DEAD: " << *KillingLocWrapper.DefInst << '\n'); + deleteDeadInstruction(KillingLocWrapper.DefInst); + return true; + } + return Changed; +} + +static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, + DominatorTree &DT, PostDominatorTree &PDT, + const TargetLibraryInfo &TLI, + const LoopInfo &LI) { + bool MadeChange = false; + DSEState State(F, AA, MSSA, DT, PDT, TLI, LI); + // For each store: + for (unsigned I = 0; I < State.MemDefs.size(); I++) { + MemoryDef *KillingDef = State.MemDefs[I]; + if (State.SkipStores.count(KillingDef)) continue; - } + + MemoryDefWrapper KillingDefWrapper( + KillingDef, State.getLocForInst(KillingDef->getMemoryInst())); + MadeChange |= State.eliminateDeadDefs(KillingDefWrapper); } if (EnablePartialOverwriteTracking)