Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen][NewPM] Split MachinePostDominators into a concrete analysis result #95113

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 39 additions & 54 deletions llvm/include/llvm/CodeGen/MachinePostDominators.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,78 +15,63 @@
#define LLVM_CODEGEN_MACHINEPOSTDOMINATORS_H

#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include <memory>

namespace llvm {

extern template class DominatorTreeBase<MachineBasicBlock, true>; // PostDomTree

namespace DomTreeBuilder {
using MBBPostDomTree = PostDomTreeBase<MachineBasicBlock>;
using MBBPostDomTreeGraphDiff = GraphDiff<MachineBasicBlock *, true>;

extern template void Calculate<MBBPostDomTree>(MBBPostDomTree &DT);
extern template void InsertEdge<MBBPostDomTree>(MBBPostDomTree &DT,
MachineBasicBlock *From,
MachineBasicBlock *To);
extern template void DeleteEdge<MBBPostDomTree>(MBBPostDomTree &DT,
MachineBasicBlock *From,
MachineBasicBlock *To);
extern template void ApplyUpdates<MBBPostDomTree>(MBBPostDomTree &DT,
MBBPostDomTreeGraphDiff &,
MBBPostDomTreeGraphDiff *);
extern template bool
Verify<MBBPostDomTree>(const MBBPostDomTree &DT,
MBBPostDomTree::VerificationLevel VL);
} // namespace DomTreeBuilder

///
/// MachinePostDominatorTree - an analysis pass wrapper for DominatorTree
/// used to compute the post-dominator tree for MachineFunctions.
///
class MachinePostDominatorTree : public MachineFunctionPass {
using PostDomTreeT = PostDomTreeBase<MachineBasicBlock>;
std::unique_ptr<PostDomTreeT> PDT;
class MachinePostDominatorTree : public PostDomTreeBase<MachineBasicBlock> {
using Base = PostDomTreeBase<MachineBasicBlock>;

public:
static char ID;

MachinePostDominatorTree();

PostDomTreeT &getBase() {
if (!PDT)
PDT.reset(new PostDomTreeT());
return *PDT;
}

FunctionPass *createMachinePostDominatorTreePass();

MachineDomTreeNode *getRootNode() const { return PDT->getRootNode(); }

MachineDomTreeNode *operator[](MachineBasicBlock *BB) const {
return PDT->getNode(BB);
}

MachineDomTreeNode *getNode(MachineBasicBlock *BB) const {
return PDT->getNode(BB);
}

bool dominates(const MachineDomTreeNode *A,
const MachineDomTreeNode *B) const {
return PDT->dominates(A, B);
}

bool dominates(const MachineBasicBlock *A, const MachineBasicBlock *B) const {
return PDT->dominates(A, B);
}
MachinePostDominatorTree() = default;

bool properlyDominates(const MachineDomTreeNode *A,
const MachineDomTreeNode *B) const {
return PDT->properlyDominates(A, B);
}

bool properlyDominates(const MachineBasicBlock *A,
const MachineBasicBlock *B) const {
return PDT->properlyDominates(A, B);
}

bool isVirtualRoot(const MachineDomTreeNode *Node) const {
return PDT->isVirtualRoot(Node);
}

MachineBasicBlock *findNearestCommonDominator(MachineBasicBlock *A,
MachineBasicBlock *B) const {
return PDT->findNearestCommonDominator(A, B);
}
/// Make findNearestCommonDominator(const NodeT *A, const NodeT *B) available.
using Base::findNearestCommonDominator;

/// Returns the nearest common dominator of the given blocks.
/// If that tree node is a virtual root, a nullptr will be returned.
MachineBasicBlock *
findNearestCommonDominator(ArrayRef<MachineBasicBlock *> Blocks) const;
};

class MachinePostDominatorTreeWrapperPass : public MachineFunctionPass {
std::optional<MachinePostDominatorTree> PDT;

public:
static char ID;

MachinePostDominatorTreeWrapperPass();

MachinePostDominatorTree &getPostDomTree() { return *PDT; }
const MachinePostDominatorTree &getPostDomTree() const { return *PDT; }

bool runOnMachineFunction(MachineFunction &MF) override;
void getAnalysisUsage(AnalysisUsage &AU) const override;
void releaseMemory() override { PDT.reset(nullptr); }
void releaseMemory() override { PDT.reset(); }
void verifyAnalysis() const override;
void print(llvm::raw_ostream &OS, const Module *M = nullptr) const override;
};
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ void initializeMachineModuleInfoWrapperPassPass(PassRegistry &);
void initializeMachineOptimizationRemarkEmitterPassPass(PassRegistry&);
void initializeMachineOutlinerPass(PassRegistry&);
void initializeMachinePipelinerPass(PassRegistry&);
void initializeMachinePostDominatorTreePass(PassRegistry&);
void initializeMachinePostDominatorTreeWrapperPassPass(PassRegistry &);
void initializeMachineRegionInfoPassPass(PassRegistry&);
void initializeMachineSanitizerBinaryMetadataPass(PassRegistry &);
void initializeMachineSchedulerPass(PassRegistry&);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void llvm::initializeCodeGen(PassRegistry &Registry) {
initializeMachinePipelinerPass(Registry);
initializeMachineSanitizerBinaryMetadataPass(Registry);
initializeModuloScheduleTestPass(Registry);
initializeMachinePostDominatorTreePass(Registry);
initializeMachinePostDominatorTreeWrapperPassPass(Registry);
initializeMachineRegionInfoPassPass(Registry);
initializeMachineSchedulerPass(Registry);
initializeMachineSinkingPass(Registry);
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/CodeGen/MIRSampleProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
/* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
Expand Down Expand Up @@ -366,8 +366,9 @@ bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
MIRSampleLoader->setInitVals(
&getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(),
&getAnalysis<MachinePostDominatorTree>(), &getAnalysis<MachineLoopInfo>(),
MBFI, &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
&getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree(),
&getAnalysis<MachineLoopInfo>(), MBFI,
&getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());

MF.RenumberBlocks();
if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
Expand Down Expand Up @@ -401,7 +402,7 @@ void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachinePostDominatorTree>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequiredTransitive<MachineLoopInfo>();
AU.addRequired<MachineOptimizationRemarkEmitterPass>();
MachineFunctionPass::getAnalysisUsage(AU);
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/CodeGen/MachineBlockPlacement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ class MachineBlockPlacement : public MachineFunctionPass {
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addRequired<MachineBlockFrequencyInfo>();
if (TailDupPlacement)
AU.addRequired<MachinePostDominatorTree>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
AU.addRequired<ProfileSummaryInfoWrapperPass>();
AU.addRequired<TargetPassConfig>();
Expand All @@ -624,7 +624,7 @@ INITIALIZE_PASS_BEGIN(MachineBlockPlacement, DEBUG_TYPE,
"Branch Probability Basic Block Placement", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBranchProbabilityInfo)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_END(MachineBlockPlacement, DEBUG_TYPE,
Expand Down Expand Up @@ -3417,7 +3417,7 @@ bool MachineBlockPlacement::runOnMachineFunction(MachineFunction &MF) {
TailDupSize = TII->getTailDuplicateSize(PassConfig->getOptLevel());

if (allowTailDupPlacement()) {
MPDT = &getAnalysis<MachinePostDominatorTree>();
MPDT = &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
bool OptForSize = MF.getFunction().hasOptSize() ||
llvm::shouldOptimizeForSize(&MF, PSI, &MBFI->getMBFI());
if (OptForSize)
Expand Down Expand Up @@ -3449,7 +3449,7 @@ bool MachineBlockPlacement::runOnMachineFunction(MachineFunction &MF) {
ComputedEdges.clear();
// Must redo the post-dominator tree if blocks were changed.
if (MPDT)
MPDT->runOnMachineFunction(MF);
MPDT->recalculate(MF);
ChainAllocator.DestroyAll();
buildCFGChains();
}
Expand Down
58 changes: 35 additions & 23 deletions llvm/lib/CodeGen/MachinePostDominators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,46 @@ using namespace llvm;
namespace llvm {
template class DominatorTreeBase<MachineBasicBlock, true>; // PostDomTreeBase

namespace DomTreeBuilder {

template void Calculate<MBBPostDomTree>(MBBPostDomTree &DT);
template void InsertEdge<MBBPostDomTree>(MBBPostDomTree &DT,
MachineBasicBlock *From,
MachineBasicBlock *To);
template void DeleteEdge<MBBPostDomTree>(MBBPostDomTree &DT,
MachineBasicBlock *From,
MachineBasicBlock *To);
template void ApplyUpdates<MBBPostDomTree>(MBBPostDomTree &DT,
MBBPostDomTreeGraphDiff &,
MBBPostDomTreeGraphDiff *);
template bool Verify<MBBPostDomTree>(const MBBPostDomTree &DT,
MBBPostDomTree::VerificationLevel VL);

} // namespace DomTreeBuilder
extern bool VerifyMachineDomInfo;
} // namespace llvm

char MachinePostDominatorTree::ID = 0;
char MachinePostDominatorTreeWrapperPass::ID = 0;

//declare initializeMachinePostDominatorTreePass
INITIALIZE_PASS(MachinePostDominatorTree, "machinepostdomtree",
INITIALIZE_PASS(MachinePostDominatorTreeWrapperPass, "machinepostdomtree",
"MachinePostDominator Tree Construction", true, true)

MachinePostDominatorTree::MachinePostDominatorTree()
: MachineFunctionPass(ID), PDT(nullptr) {
initializeMachinePostDominatorTreePass(*PassRegistry::getPassRegistry());
MachinePostDominatorTreeWrapperPass::MachinePostDominatorTreeWrapperPass()
: MachineFunctionPass(ID), PDT() {
initializeMachinePostDominatorTreeWrapperPassPass(
*PassRegistry::getPassRegistry());
}

FunctionPass *MachinePostDominatorTree::createMachinePostDominatorTreePass() {
return new MachinePostDominatorTree();
}

bool MachinePostDominatorTree::runOnMachineFunction(MachineFunction &F) {
PDT = std::make_unique<PostDomTreeT>();
bool MachinePostDominatorTreeWrapperPass::runOnMachineFunction(
MachineFunction &F) {
PDT = MachinePostDominatorTree();
PDT->recalculate(F);
return false;
}

void MachinePostDominatorTree::getAnalysisUsage(AnalysisUsage &AU) const {
void MachinePostDominatorTreeWrapperPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.setPreservesAll();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand All @@ -54,26 +69,23 @@ MachineBasicBlock *MachinePostDominatorTree::findNearestCommonDominator(

MachineBasicBlock *NCD = Blocks.front();
for (MachineBasicBlock *BB : Blocks.drop_front()) {
NCD = PDT->findNearestCommonDominator(NCD, BB);
NCD = Base::findNearestCommonDominator(NCD, BB);

// Stop when the root is reached.
if (PDT->isVirtualRoot(PDT->getNode(NCD)))
if (isVirtualRoot(getNode(NCD)))
return nullptr;
}

return NCD;
}

void MachinePostDominatorTree::verifyAnalysis() const {
if (PDT && VerifyMachineDomInfo)
if (!PDT->verify(PostDomTreeT::VerificationLevel::Basic)) {
errs() << "MachinePostDominatorTree verification failed\n";

abort();
}
void MachinePostDominatorTreeWrapperPass::verifyAnalysis() const {
if (VerifyMachineDomInfo && PDT &&
!PDT->verify(MachinePostDominatorTree::VerificationLevel::Basic))
report_fatal_error("MachinePostDominatorTree verification failed!");
}

void MachinePostDominatorTree::print(llvm::raw_ostream &OS,
const Module *M) const {
void MachinePostDominatorTreeWrapperPass::print(llvm::raw_ostream &OS,
const Module *M) const {
PDT->print(OS);
}
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/MachineRegionInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ bool MachineRegionInfoPass::runOnMachineFunction(MachineFunction &F) {
releaseMemory();

auto DT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
auto PDT = &getAnalysis<MachinePostDominatorTree>();
auto PDT =
&getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
auto DF = &getAnalysis<MachineDominanceFrontier>();

RI.recalculate(F, DT, PDT, DF);
Expand All @@ -110,7 +111,7 @@ void MachineRegionInfoPass::verifyAnalysis() const {
void MachineRegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachinePostDominatorTree>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequired<MachineDominanceFrontier>();
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand All @@ -131,7 +132,7 @@ char &MachineRegionInfoPassID = MachineRegionInfoPass::ID;
INITIALIZE_PASS_BEGIN(MachineRegionInfoPass, DEBUG_TYPE,
"Detect single entry single exit regions", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier)
INITIALIZE_PASS_END(MachineRegionInfoPass, DEBUG_TYPE,
"Detect single entry single exit regions", true, true)
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/MachineSink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ namespace {
MachineFunctionPass::getAnalysisUsage(AU);
AU.addRequired<AAResultsWrapperPass>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachinePostDominatorTree>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequired<MachineCycleInfoWrapperPass>();
AU.addRequired<MachineBranchProbabilityInfo>();
AU.addPreserved<MachineCycleInfoWrapperPass>();
Expand Down Expand Up @@ -709,7 +709,7 @@ bool MachineSinking::runOnMachineFunction(MachineFunction &MF) {
TRI = STI->getRegisterInfo();
MRI = &MF.getRegInfo();
DT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
PDT = &getAnalysis<MachinePostDominatorTree>();
PDT = &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
CI = &getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
MBFI = UseBlockFreqInfo ? &getAnalysis<MachineBlockFrequencyInfo>() : nullptr;
MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
Expand Down
8 changes: 4 additions & 4 deletions llvm/lib/CodeGen/ShrinkWrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class ShrinkWrap : public MachineFunctionPass {
void init(MachineFunction &MF) {
RCI.runOnMachineFunction(MF);
MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
MPDT = &getAnalysis<MachinePostDominatorTree>();
MPDT = &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
Save = nullptr;
Restore = nullptr;
MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
Expand Down Expand Up @@ -263,7 +263,7 @@ class ShrinkWrap : public MachineFunctionPass {
AU.setPreservesAll();
AU.addRequired<MachineBlockFrequencyInfo>();
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachinePostDominatorTree>();
AU.addRequired<MachinePostDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfo>();
AU.addRequired<MachineOptimizationRemarkEmitterPass>();
MachineFunctionPass::getAnalysisUsage(AU);
Expand All @@ -290,7 +290,7 @@ char &llvm::ShrinkWrapID = ShrinkWrap::ID;
INITIALIZE_PASS_BEGIN(ShrinkWrap, DEBUG_TYPE, "Shrink Wrap Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(ShrinkWrap, DEBUG_TYPE, "Shrink Wrap Pass", false, false)
Expand Down Expand Up @@ -671,7 +671,7 @@ bool ShrinkWrap::postShrinkWrapping(bool HasCandidate, MachineFunction &MF,
Restore = NewRestore;

MDT->recalculate(MF);
MPDT->runOnMachineFunction(MF);
MPDT->recalculate(MF);

assert((MDT->dominates(Save, Restore) && MPDT->dominates(Restore, Save)) &&
"Incorrect save or restore point due to dominance relations");
Expand Down
Loading
Loading