diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index c0897408986fb3..88fbad4d6b9d82 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -65,9 +65,14 @@ bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights); /// Faster version of extractBranchWeights() that skips checks and must only -/// be called with "branch_weights" metadata nodes. -void extractFromBranchWeightMD(const MDNode *ProfileData, - SmallVectorImpl &Weights); +/// be called with "branch_weights" metadata nodes. Supports uint32_t. +void extractFromBranchWeightMD32(const MDNode *ProfileData, + SmallVectorImpl &Weights); + +/// Faster version of extractBranchWeights() that skips checks and must only +/// be called with "branch_weights" metadata nodes. Supports uint64_t. +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights); /// Extract branch weights attatched to an Instruction /// diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index dc86f4204b1a1d..874ead5bc63f15 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { return ProfDataName->getString().equals(Name); } +template >> +static void extractFromBranchWeightMD(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + assert(isBranchWeightMD(ProfileData) && "wrong metadata"); + + unsigned NOps = ProfileData->getNumOperands(); + assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); + Weights.resize(NOps - WeightsIdx); + + for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { + ConstantInt *Weight = + mdconst::dyn_extract(ProfileData->getOperand(Idx)); + assert(Weight && "Malformed branch_weight in MD_prof node"); + assert(Weight->getValue().getActiveBits() <= 32 && + "Too many bits for uint32_t"); + Weights[Idx - WeightsIdx] = Weight->getZExtValue(); + } +} + } // namespace namespace llvm { @@ -100,22 +120,14 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) { return nullptr; } -void extractFromBranchWeightMD(const MDNode *ProfileData, - SmallVectorImpl &Weights) { - assert(isBranchWeightMD(ProfileData) && "wrong metadata"); - - unsigned NOps = ProfileData->getNumOperands(); - assert(WeightsIdx < NOps && "Weights Index must be less than NOps."); - Weights.resize(NOps - WeightsIdx); +void extractFromBranchWeightMD32(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); +} - for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { - ConstantInt *Weight = - mdconst::dyn_extract(ProfileData->getOperand(Idx)); - assert(Weight && "Malformed branch_weight in MD_prof node"); - assert(Weight->getValue().getActiveBits() <= 32 && - "Too many bits for uint32_t"); - Weights[Idx - WeightsIdx] = Weight->getZExtValue(); - } +void extractFromBranchWeightMD64(const MDNode *ProfileData, + SmallVectorImpl &Weights) { + extractFromBranchWeightMD(ProfileData, Weights); } bool extractBranchWeights(const MDNode *ProfileData, diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp index 0f55af3b6eddf8..5cd96412a322d1 100644 --- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI, return; SmallVector Weights; - extractFromBranchWeightMD(WeightMD, Weights); + extractFromBranchWeightMD32(WeightMD, Weights); if (Weights.size() != 2) return; uint32_t OrigLoopExitWeight = Weights[0]; diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index 4db72461c95e47..5a44a11ecfd2c5 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -1066,11 +1066,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1, static void GetBranchWeights(Instruction *TI, SmallVectorImpl &Weights) { MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); - assert(MD); - for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) { - ConstantInt *CI = mdconst::extract(MD->getOperand(i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + assert(MD && "Invalid branch-weight metadata"); + extractFromBranchWeightMD64(MD, Weights); // If TI is a conditional eq, the default case is the false case, // and the corresponding branch-weight data is at index 2. We swap the