Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm][profdata][NFC] Support 64-bit weights in ProfDataUtils #86607

Conversation

ilovepi
Copy link
Contributor

@ilovepi ilovepi commented Mar 26, 2024

Since some places, like SimplifyCFG work with 64-bit weights, we supply an API
in ProfDataUtils to extract the weights accordingly.

We change the API slightly to disambiguate the 64 bit version from the 32 bit
version.

Created using spr 1.3.4
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Paul Kirth (ilovepi)

Changes

Since some places, like SimplifyCFG work with 64-bit weights, we supply an API
in ProfDataUtils to extract the weights accordingly.

We change the API slightly to disambiguate the 64 bit version from the 32 bit
version.


Full diff: https://github.com/llvm/llvm-project/pull/86607.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/ProfDataUtils.h (+7-2)
  • (modified) llvm/lib/IR/ProfDataUtils.cpp (+31-14)
  • (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+1-1)
  • (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+2-5)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c7906..dc983eed13a8d3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -65,10 +65,15 @@ bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights);
 
 /// Faster version of extractBranchWeights() that skips checks and must only
-/// be called with "branch_weights" metadata nodes.
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+/// be called with "branch_weights" metadata nodes. Supports uint32_t.
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights);
 
+/// Faster version of extractBranchWeights() that skips checks and must only
+/// be called with "branch_weights" metadata nodes. Supports uint64_t.
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights);
+
 /// Extract branch weights attatched to an Instruction
 ///
 /// \param I The Instruction to extract weights from.
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a522..b4e09e76993f99 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -65,6 +65,26 @@ bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
   return ProfDataName->getString().equals(Name);
 }
 
+template <typename T,
+          typename = typename std::enable_if<std::is_arithmetic_v<T>>>
+static void extractFromBranchWeightMD(const MDNode *ProfileData,
+                                      SmallVectorImpl<T> &Weights) {
+  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
+
+  unsigned NOps = ProfileData->getNumOperands();
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+
+  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+}
+
 } // namespace
 
 namespace llvm {
@@ -100,24 +120,21 @@ MDNode *getValidBranchWeightMDNode(const Instruction &I) {
   return nullptr;
 }
 
-void extractFromBranchWeightMD(const MDNode *ProfileData,
+void extractFromBranchWeightMD32(const MDNode *ProfileData,
                                SmallVectorImpl<uint32_t> &Weights) {
-  assert(isBranchWeightMD(ProfileData) && "wrong metadata");
-
-  unsigned NOps = ProfileData->getNumOperands();
-  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
-  Weights.resize(NOps - WeightsIdx);
+  extractFromBranchWeightMD(ProfileData, Weights);
+}
 
-  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    assert(Weight && "Malformed branch_weight in MD_prof node");
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
-  }
+void extractFromBranchWeightMD64(const MDNode *ProfileData,
+                               SmallVectorImpl<uint64_t> &Weights) {
+  extractFromBranchWeightMD(ProfileData, Weights);
 }
 
+
+
+
+
+
 bool extractBranchWeights(const MDNode *ProfileData,
                           SmallVectorImpl<uint32_t> &Weights) {
   if (!isBranchWeightMD(ProfileData))
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index bc671171137199..f4b43ce370a5da 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -287,7 +287,7 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     return;
 
   SmallVector<uint32_t, 2> Weights;
-  extractFromBranchWeightMD(WeightMD, Weights);
+  extractFromBranchWeightMD32(WeightMD, Weights);
   if (Weights.size() != 2)
     return;
   uint32_t OrigLoopExitWeight = Weights[0];
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55bbffb18879fb..a425e26d490e4f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1065,11 +1065,8 @@ static int ConstantIntSortPredicate(ConstantInt *const *P1,
 static void GetBranchWeights(Instruction *TI,
                              SmallVectorImpl<uint64_t> &Weights) {
   MDNode *MD = TI->getMetadata(LLVMContext::MD_prof);
-  assert(MD);
-  for (unsigned i = 1, e = MD->getNumOperands(); i < e; ++i) {
-    ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(i));
-    Weights.push_back(CI->getValue().getZExtValue());
-  }
+  assert(MD && "Invalid branch-weight metadata");
+  extractFromBranchWeightMD64(MD, Weights);
 
   // If TI is a conditional eq, the default case is the false case,
   // and the corresponding branch-weight data is at index 2. We swap the

@ilovepi ilovepi requested review from MatzeB and aeubanks March 26, 2024 00:50
Copy link

github-actions bot commented Mar 26, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Created using spr 1.3.4
ilovepi added a commit to ilovepi/llvm-project that referenced this pull request Mar 26, 2024
Since some places, like SimplifyCFG work with 64-bit weights, we supply an API
in ProfDataUtils to extract the weights accordingly.

We change the API slightly to disambiguate the 64 bit version from the 32 bit
version.

Pull Request: llvm#86607
@aeubanks aeubanks requested a review from david-xl March 26, 2024 17:03
@aeubanks
Copy link
Contributor

can you explain when 32-bit vs 64-bit weights are used?

@MatzeB
Copy link
Contributor

MatzeB commented Mar 26, 2024

Be careful here! I think there is a bunch of code that is summing up weights and stores intermediate results in a uint64_t. This is all fine when the weights are uint32_t but risks (silent!) overflow when the weights use the full uint64_t. I haven't looked around much but could find one random example in SimplifyCFG already:

Weights[0] += Weights[i + 1];

@MatzeB
Copy link
Contributor

MatzeB commented Mar 26, 2024

I am assuming that this also serves as preparation to increasing the bit size of the weight annotations?

@ilovepi
Copy link
Contributor Author

ilovepi commented Mar 26, 2024

can you explain when 32-bit vs 64-bit weights are used?

I think we generally work w/ 32-bit weights, but in some cases, we use 64-bit to make sure that summing or scaling don't overflow. @MatzeB mentioned some of this in his comment. My motivation here is to avoid having bespoke handling of branch weight extraction. To a large extent due to to #86609.

Be careful here! I think there is a bunch of code that is summing up weights and stores intermediate results in a uint64_t. This is all fine when the weights are uint32_t but risks (silent!) overflow when the weights use the full uint64_t. I haven't looked around much but could find one random example in SimplifyCFG already:

Weights[0] += Weights[i + 1];

Right, I don't want to change the defaults, I just want to provide better utilities, so people aren't manually walking the MD_prof metadata. When we add the proposed optional field in #86609, the offsets won't be fixed, and IMO, its better to provide the necessary utilities and point people towards those.

That's one of the reasons I left extractBranchWeights() APIs as they are now, only using uint32_t, which is the main way people should be extracting weights. I think the 32/64 APIs are a rather niche use case, but if passes need to operate on 64-bit weights, as in SimplifyCFG, then we should support that in ProfdataUtils.

I am assuming that this also serves as preparation to increasing the bit size of the weight annotations?

That isn't one of my goals.

@ilovepi
Copy link
Contributor Author

ilovepi commented Apr 2, 2024

@aeubanks @MatzeB, did my comments answer your questions adequately?

Created using spr 1.3.4
@ilovepi
Copy link
Contributor Author

ilovepi commented Apr 29, 2024

ping

Created using spr 1.3.4
Copy link
Contributor

@MatzeB MatzeB left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, if this is just for convenience / consistency then fine with me. LGTM, thanks

@ilovepi ilovepi merged commit 7538df9 into main Apr 30, 2024
4 checks passed
@ilovepi ilovepi deleted the users/ilovepi/spr/llvmprofdatanfc-support-64-bit-weights-in-profdatautils branch April 30, 2024 20:53
@ilovepi ilovepi restored the users/ilovepi/spr/llvmprofdatanfc-support-64-bit-weights-in-profdatautils branch April 30, 2024 21:03
@ilovepi ilovepi deleted the users/ilovepi/spr/llvmprofdatanfc-support-64-bit-weights-in-profdatautils branch April 30, 2024 21:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants