Skip to content

Commit

Permalink
Apply initializes attribute in DSE
Browse files Browse the repository at this point in the history
  • Loading branch information
haopliu committed Aug 26, 2024
1 parent c811ea4 commit 58dd8a4
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 42 deletions.
226 changes: 184 additions & 42 deletions llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRangeList.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
Expand Down Expand Up @@ -164,6 +165,10 @@ static cl::opt<bool>
OptimizeMemorySSA("dse-optimize-memoryssa", cl::init(true), cl::Hidden,
cl::desc("Allow DSE to optimize memory accesses."));

static cl::opt<bool> EnableInitializesImprovement(
"enable-dse-initializes-attr-improvement", cl::init(false), cl::Hidden,
cl::desc("Enable the initializes attr improvement in DSE"));

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -809,8 +814,10 @@ bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
// A memory location wrapper that represents a MemoryLocation, `MemLoc`,
// defined by `MemDef`.
struct MemoryLocationWrapper {
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef)
: MemLoc(MemLoc), MemDef(MemDef) {
MemoryLocationWrapper(MemoryLocation MemLoc, MemoryDef *MemDef,
bool DefByInitializesAttr)
: MemLoc(MemLoc), MemDef(MemDef),
DefByInitializesAttr(DefByInitializesAttr) {
assert(MemLoc.Ptr && "MemLoc should be not null");
UnderlyingObject = getUnderlyingObject(MemLoc.Ptr);
DefInst = MemDef->getMemoryInst();
Expand All @@ -820,20 +827,121 @@ struct MemoryLocationWrapper {
const Value *UnderlyingObject;
MemoryDef *MemDef;
Instruction *DefInst;
bool DefByInitializesAttr = false;
};

// A memory def wrapper that represents a MemoryDef and the MemoryLocation(s)
// defined by this MemoryDef.
struct MemoryDefWrapper {
MemoryDefWrapper(MemoryDef *MemDef, std::optional<MemoryLocation> MemLoc) {
MemoryDefWrapper(
MemoryDef *MemDef,
const SmallVectorImpl<std::pair<MemoryLocation, bool>> &MemLocations) {
DefInst = MemDef->getMemoryInst();
if (MemLoc.has_value())
DefinedLocation = MemoryLocationWrapper(*MemLoc, MemDef);
for (auto &[MemLoc, DefByInitializesAttr] : MemLocations)
DefinedLocations.push_back(
MemoryLocationWrapper(MemLoc, MemDef, DefByInitializesAttr));
}
Instruction *DefInst;
std::optional<MemoryLocationWrapper> DefinedLocation = std::nullopt;
SmallVector<MemoryLocationWrapper, 1> DefinedLocations;
};

bool HasInitializesAttr(Instruction *I) {
CallBase *CB = dyn_cast<CallBase>(I);
if (!CB)
return false;

for (size_t Idx = 0; Idx < CB->arg_size(); Idx++)
if (CB->paramHasAttr(Idx, Attribute::Initializes))
return true;
return false;
}

struct ArgumentInitInfo {
size_t Idx = -1;
ConstantRangeList Inits;
bool HasDeadOnUnwindAttr = false;
bool FuncHasNoUnwindAttr = false;
};

ConstantRangeList
GetMergedInitAttr(const SmallVectorImpl<ArgumentInitInfo> &Args) {
if (Args.empty())
return {};

// To address unwind, the function should have nounwind attribute or the
// arguments have dead_on_unwind attribute. Otherwise, return empty.
for (const auto &Arg : Args) {
if (!Arg.FuncHasNoUnwindAttr && !Arg.HasDeadOnUnwindAttr)
return {};
if (Arg.Inits.empty())
return {};
}

if (Args.size() == 1)
return Args[0].Inits;

ConstantRangeList MergedIntervals = Args[0].Inits;
for (size_t i = 1; i < Args.size(); i++) {
MergedIntervals = MergedIntervals.intersectWith(Args[i].Inits);
}
return MergedIntervals;
}

// Return the locations wrote by the initializes attribute.
// Note that this function considers:
// 1. Unwind edge: apply "initializes" attribute only if the callee has
// "nounwind" attribute or the argument has "dead_on_unwind" attribute.
// 2. Argument alias: for aliasing arguments, the "initializes" attribute is
// the merged range list of their "initializes" attributes.
SmallVector<MemoryLocation, 1>
GetInitializesArgMemLoc(const Instruction *I, BatchAAResults &BatchAA) {
const CallBase *CB = dyn_cast<CallBase>(I);
if (!CB)
return {};

bool HasNoUnwindAttr = CB->hasFnAttr(Attribute::NoUnwind);
SmallMapVector<Value *, SmallVector<ArgumentInitInfo, 2>, 2> Arguments;
for (size_t Idx = 0; Idx < CB->arg_size(); Idx++) {
bool HasDeadOnUnwindAttr = CB->paramHasAttr(Idx, Attribute::DeadOnUnwind);

ConstantRangeList Inits;
if (CB->paramHasAttr(Idx, Attribute::Initializes))
Inits = CB->getParamAttr(Idx, Attribute::Initializes)
.getValueAsConstantRangeList();

ArgumentInitInfo InitInfo{Idx, Inits, HasDeadOnUnwindAttr, HasNoUnwindAttr};
Value *CurArg = CB->getArgOperand(Idx);
bool FoundAliasing = false;
for (auto &[Arg, AliasList] : Arguments) {
if (BatchAA.isMustAlias(Arg, CurArg)) {
FoundAliasing = true;
AliasList.push_back(InitInfo);
}
}
if (!FoundAliasing)
Arguments[CurArg] = {InitInfo};
}

SmallVector<MemoryLocation, 1> Locations;
for (const auto &[_, Args] : Arguments) {
auto MergedInitAttr = GetMergedInitAttr(Args);
if (MergedInitAttr.empty())
continue;

for (const auto &Arg : Args) {
for (const auto &Range : MergedInitAttr) {
int64_t Start = Range.getLower().getSExtValue();
int64_t End = Range.getUpper().getSExtValue();
if (Start == 0)
Locations.push_back(MemoryLocation(CB->getArgOperand(Arg.Idx),
LocationSize::precise(End - Start),
CB->getAAMetadata()));
}
}
}
return Locations;
}

struct DSEState {
Function &F;
AliasAnalysis &AA;
Expand Down Expand Up @@ -911,7 +1019,8 @@ struct DSEState {

auto *MD = dyn_cast_or_null<MemoryDef>(MA);
if (MD && MemDefs.size() < MemorySSADefsPerBlockLimit &&
(getLocForWrite(&I) || isMemTerminatorInst(&I)))
(getLocForWrite(&I) || isMemTerminatorInst(&I) ||
HasInitializesAttr(&I)))
MemDefs.push_back(MD);
}
}
Expand Down Expand Up @@ -1147,13 +1256,24 @@ struct DSEState {
return MemoryLocation::getOrNone(I);
}

std::optional<MemoryLocation> getLocForInst(Instruction *I) {
SmallVector<std::pair<MemoryLocation, bool>, 1>
getLocForInst(Instruction *I, bool consider_initializes_attr) {
SmallVector<std::pair<MemoryLocation, bool>, 1> Locations;
if (isMemTerminatorInst(I)) {
if (auto Loc = getLocForTerminator(I)) {
return Loc->first;
if (auto Loc = getLocForTerminator(I))
Locations.push_back(std::make_pair(Loc->first, false));
return Locations;
}

if (auto Loc = getLocForWrite(I))
Locations.push_back(std::make_pair(*Loc, false));

if (consider_initializes_attr) {
for (auto &MemLoc : GetInitializesArgMemLoc(I, BatchAA)) {
Locations.push_back(std::make_pair(MemLoc, true));
}
}
return getLocForWrite(I);
return Locations;
}

/// Assuming this instruction has a dead analyzable write, can we delete
Expand Down Expand Up @@ -1365,7 +1485,8 @@ struct DSEState {
getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *StartAccess,
const MemoryLocation &KillingLoc, const Value *KillingUndObj,
unsigned &ScanLimit, unsigned &WalkerStepLimit,
bool IsMemTerm, unsigned &PartialLimit) {
bool IsMemTerm, unsigned &PartialLimit,
bool IsInitializesAttrMemLoc) {
if (ScanLimit == 0 || WalkerStepLimit == 0) {
LLVM_DEBUG(dbgs() << "\n ... hit scan limit\n");
return std::nullopt;
Expand Down Expand Up @@ -1602,7 +1723,19 @@ struct DSEState {

// Uses which may read the original MemoryDef mean we cannot eliminate the
// original MD. Stop walk.
if (isReadClobber(MaybeDeadLoc, UseInst)) {
// If KillingDef is a CallInst with "initializes" attribute, the reads in
// Callee would be dominated by initializations, so this should be safe.
bool IsKillingDefFromInitAttr = false;
if (IsInitializesAttrMemLoc) {
if (KillingI == UseInst &&
KillingUndObj == getUnderlyingObject(MaybeDeadLoc.Ptr)) {
IsKillingDefFromInitAttr = true;
// Note that, we don't need to check aliasing arguments here since
// aliasing has been considered at the begining.
}
}

if (isReadClobber(MaybeDeadLoc, UseInst) && !IsKillingDefFromInitAttr) {
LLVM_DEBUG(dbgs() << " ... found read clobber\n");
return std::nullopt;
}
Expand Down Expand Up @@ -2207,7 +2340,8 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
std::optional<MemoryAccess *> MaybeDeadAccess = getDomMemoryDef(
KillingLocWrapper.MemDef, Current, KillingLocWrapper.MemLoc,
KillingLocWrapper.UnderlyingObject, ScanLimit, WalkerStepLimit,
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit);
isMemTerminatorInst(KillingLocWrapper.DefInst), PartialLimit,
KillingLocWrapper.DefByInitializesAttr);

if (!MaybeDeadAccess) {
LLVM_DEBUG(dbgs() << " finished walk\n");
Expand All @@ -2232,8 +2366,11 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
}
MemoryDefWrapper DeadDefWrapper(
cast<MemoryDef>(DeadAccess),
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst()));
MemoryLocationWrapper &DeadLocWrapper = *DeadDefWrapper.DefinedLocation;
getLocForInst(cast<MemoryDef>(DeadAccess)->getMemoryInst(),
/*consider_initializes_attr=*/false));
assert(DeadDefWrapper.DefinedLocations.size() == 1);
MemoryLocationWrapper &DeadLocWrapper =
DeadDefWrapper.DefinedLocations.front();
LLVM_DEBUG(dbgs() << " (" << *DeadLocWrapper.DefInst << ")\n");
ToCheck.insert(DeadLocWrapper.MemDef->getDefiningAccess());
NumGetDomMemoryDefPassed++;
Expand Down Expand Up @@ -2311,37 +2448,41 @@ DSEState::eliminateDeadDefs(const MemoryLocationWrapper &KillingLocWrapper) {
}

bool DSEState::eliminateDeadDefs(const MemoryDefWrapper &KillingDefWrapper) {
if (!KillingDefWrapper.DefinedLocation.has_value()) {
if (KillingDefWrapper.DefinedLocations.empty()) {
LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for "
<< *KillingDefWrapper.DefInst << "\n");
return false;
}

auto &KillingLocWrapper = *KillingDefWrapper.DefinedLocation;
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
<< *KillingLocWrapper.MemDef << " ("
<< *KillingLocWrapper.DefInst << ")\n");
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);

// Check if the store is a no-op.
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
<< *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
NumRedundantStores++;
return true;
}
// Can we form a calloc from a memset/malloc pair?
if (!DeletedKillingLoc &&
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
return true;
bool MadeChange = false;
for (auto &KillingLocWrapper : KillingDefWrapper.DefinedLocations) {
LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by "
<< *KillingLocWrapper.MemDef << " ("
<< *KillingLocWrapper.DefInst << ")\n");
auto [Changed, DeletedKillingLoc] = eliminateDeadDefs(KillingLocWrapper);

// Check if the store is a no-op.
if (!DeletedKillingLoc && storeIsNoop(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n DEAD: "
<< *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
NumRedundantStores++;
MadeChange = true;
continue;
}
// Can we form a calloc from a memset/malloc pair?
if (!DeletedKillingLoc &&
tryFoldIntoCalloc(KillingLocWrapper.MemDef,
KillingLocWrapper.UnderlyingObject)) {
LLVM_DEBUG(dbgs() << "DSE: Remove memset after forming calloc:\n"
<< " DEAD: " << *KillingLocWrapper.DefInst << '\n');
deleteDeadInstruction(KillingLocWrapper.DefInst);
MadeChange = true;
continue;
}
}
return Changed;
return MadeChange;
}

static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
Expand All @@ -2357,7 +2498,8 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
continue;

MemoryDefWrapper KillingDefWrapper(
KillingDef, State.getLocForInst(KillingDef->getMemoryInst()));
KillingDef, State.getLocForInst(KillingDef->getMemoryInst(),
EnableInitializesImprovement));
MadeChange |= State.eliminateDeadDefs(KillingDefWrapper);
}

Expand Down
45 changes: 45 additions & 0 deletions llvm/test/Transforms/DeadStoreElimination/inter-procedural.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=function-attrs,dse -enable-dse-initializes-attr-improvement -S | FileCheck %s

; Function Attrs: mustprogress nounwind uwtable
define void @write_only_arg(ptr nocapture noundef writeonly initializes((0, 2)) %ptr) {
store i16 100, ptr %ptr
ret void
}

; Function Attrs: mustprogress nounwind uwtable memory(none argmem: readwrite)
define i16 @write_then_read_arg(ptr nocapture noundef initializes((0, 2)) %ptr) {
store i16 10, ptr %ptr
%l = load i16, ptr %ptr
ret i16 %l
}

; Function Attrs: mustprogress nounwind uwtable
define i16 @write_only_caller() {
; CHECK-LABEL: @write_only_caller(
; CHECK-NEXT: %ptr = alloca i16, align 2
; CHECK-NEXT: call void @write_only_arg(ptr %ptr)
; CHECK-NEXT: %l = load i16, ptr %ptr
; CHECK-NEXT: ret i16 %l
;
%ptr = alloca i16
store i16 0, ptr %ptr
call void @write_only_arg(ptr %ptr)
%l = load i16, ptr %ptr
ret i16 %l
}

; Function Attrs: mustprogress nounwind uwtable
define i16 @write_then_read_caller() {
; CHECK-LABEL: @write_then_read_caller(
; CHECK-NEXT: %ptr = alloca i16, align 2
; CHECK-NEXT: %call = call i16 @write_then_read_arg(ptr %ptr)
; CHECK-NEXT: %l = load i16, ptr %ptr
; CHECK-NEXT: ret i16 %l
;
%ptr = alloca i16
store i16 0, ptr %ptr
%call = call i16 @write_then_read_arg(ptr %ptr)
%l = load i16, ptr %ptr
ret i16 %l
}

0 comments on commit 58dd8a4

Please sign in to comment.