[AMDGPU] Make AMDGPULowerKernelArguments a module pass #112790

Open
wants to merge 1 commit into base: main
Conversation

kerbowa (Member) commented Oct 17, 2024

After c4d8920, AMDGPULowerKernelArguments may clone functions and modify their kernel signatures to enable preloading of hidden kernel arguments. The original functions are left behind as dead declarations, which may cause issues for the downstream toolchain.

This patch makes AMDGPULowerKernelArguments a module pass so that these leftover declarations can be safely erased.

There is also some small refactoring to avoid duplicating logic between the two pass managers. The pass interfaces are updated to resemble other AMDGPU passes that have already been migrated to the new pass manager.
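For context, here is a minimal sketch of the module-pass structure this enables. It is simplified, not the patch itself: the class name is invented, the per-kernel body is reduced to a placeholder, and the final erase loop is assumed from the description (the diff below is truncated before that point).

// Sketch only: names like lowerKernelArguments and FunctionsToErase follow the
// patch, but the bodies are reduced to the control flow that motivates moving
// from a function pass to a module pass.
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"

using namespace llvm;

namespace {
class LowerKernelArgumentsSketch {
  // Originals that were replaced by preload clones; they cannot be erased
  // safely from inside a function pass.
  SmallVector<Function *> FunctionsToErase;

  bool lowerKernelArguments(Function &F) {
    // Per-kernel lowering. When hidden-argument preloading clones a kernel,
    // the original is recorded in FunctionsToErase rather than deleted
    // immediately, since the module is still being iterated.
    return false; // placeholder
  }

public:
  bool runOnModule(Module &M) {
    bool Changed = false;
    for (Function &F : M)
      Changed |= lowerKernelArguments(F);

    // Module scope makes this safe: erase the leftover declarations once the
    // walk over all functions has finished.
    for (Function *F : FunctionsToErase) {
      F->eraseFromParent();
      Changed = true;
    }
    return Changed;
  }
};
} // end anonymous namespace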

llvmbot (Collaborator) commented Oct 17, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Austin Kerbow (kerbowa)

Changes



Patch is 29.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112790.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPU.h (+4-4)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp (+243-222)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUPassRegistry.def (+2-2)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp (+2-2)
  • (modified) llvm/test/CodeGen/AMDGPU/llc-pipeline.ll (+5-10)
  • (modified) llvm/test/CodeGen/AMDGPU/preload-implicit-kernargs-IR-lowering.ll (+11-2)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPU.h b/llvm/lib/Target/AMDGPU/AMDGPU.h
index 342d55e828bca5..9ffd1f3977213e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPU.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPU.h
@@ -111,9 +111,9 @@ ModulePass *createAMDGPUCtorDtorLoweringLegacyPass();
 void initializeAMDGPUCtorDtorLoweringLegacyPass(PassRegistry &);
 extern char &AMDGPUCtorDtorLoweringLegacyPassID;
 
-FunctionPass *createAMDGPULowerKernelArgumentsPass();
-void initializeAMDGPULowerKernelArgumentsPass(PassRegistry &);
-extern char &AMDGPULowerKernelArgumentsID;
+ModulePass *createAMDGPULowerKernelArgumentsLegacyPass(const TargetMachine *TM);
+void initializeAMDGPULowerKernelArgumentsLegacyPass(PassRegistry &);
+extern char &AMDGPULowerKernelArgumentsLegacyPassID;
 
 FunctionPass *createAMDGPUPromoteKernelArgumentsPass();
 void initializeAMDGPUPromoteKernelArgumentsPass(PassRegistry &);
@@ -310,7 +310,7 @@ class AMDGPULowerKernelArgumentsPass
 
 public:
   AMDGPULowerKernelArgumentsPass(TargetMachine &TM) : TM(TM){};
-  PreservedAnalyses run(Function &, FunctionAnalysisManager &);
+  PreservedAnalyses run(Module &, ModuleAnalysisManager &);
 };
 
 struct AMDGPUAttributorOptions {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
index 6573176492b7f3..7b986b4385023e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
@@ -131,7 +131,6 @@ class PreloadKernelArgInfo {
 
     NF->setAttributes(AL);
     F.replaceAllUsesWith(NF);
-    F.setCallingConv(CallingConv::C);
 
     return NF;
   }
@@ -169,8 +168,9 @@ class PreloadKernelArgInfo {
   }
 
   // Try to allocate SGPRs to preload implicit kernel arguments.
-  void tryAllocImplicitArgPreloadSGPRs(uint64_t ImplicitArgsBaseOffset,
-                                       IRBuilder<> &Builder) {
+  void tryAllocImplicitArgPreloadSGPRs(
+      uint64_t ImplicitArgsBaseOffset, IRBuilder<> &Builder,
+      SmallVectorImpl<Function *> &FunctionsToErase) {
     Function *ImplicitArgPtr = Intrinsic::getDeclarationIfExists(
         F.getParent(), Intrinsic::amdgcn_implicitarg_ptr);
     if (!ImplicitArgPtr)
@@ -239,6 +239,7 @@ class PreloadKernelArgInfo {
     unsigned LastHiddenArgIndex = getHiddenArgFromOffset(PreloadEnd[-1].second);
     Function *NF = cloneFunctionWithPreloadImplicitArgs(LastHiddenArgIndex);
     assert(NF);
+    FunctionsToErase.push_back(&F);
     for (const auto *I = ImplicitArgLoads.begin(); I != PreloadEnd; ++I) {
       LoadInst *LoadInst = I->first;
       unsigned LoadOffset = I->second;
@@ -250,264 +251,284 @@ class PreloadKernelArgInfo {
   }
 };
 
-class AMDGPULowerKernelArguments : public FunctionPass {
-public:
-  static char ID;
+class AMDGPULowerKernelArguments {
+  const TargetMachine &TM;
+  SmallVector<Function *> FunctionsToErase;
 
-  AMDGPULowerKernelArguments() : FunctionPass(ID) {}
+public:
+  AMDGPULowerKernelArguments(const TargetMachine &TM) : TM(TM) {}
+
+  // skip allocas
+  static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
+    BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
+    for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
+      AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
+
+      // If this is a dynamic alloca, the value may depend on the loaded kernargs,
+      // so loads will need to be inserted before it.
+      if (!AI || !AI->isStaticAlloca())
+        break;
+    }
 
-  bool runOnFunction(Function &F) override;
+    return InsPt;
+  }
 
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<TargetPassConfig>();
-    AU.setPreservesAll();
- }
-};
+  bool lowerKernelArguments(Function &F) {
+    CallingConv::ID CC = F.getCallingConv();
+    if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
+      return false;
 
-} // end anonymous namespace
+    const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
+    LLVMContext &Ctx = F.getParent()->getContext();
+    const DataLayout &DL = F.getDataLayout();
+    BasicBlock &EntryBlock = *F.begin();
+    IRBuilder<> Builder(&EntryBlock, getInsertPt(EntryBlock));
 
-// skip allocas
-static BasicBlock::iterator getInsertPt(BasicBlock &BB) {
-  BasicBlock::iterator InsPt = BB.getFirstInsertionPt();
-  for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
-    AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
+    const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
+    const uint64_t BaseOffset = ST.getExplicitKernelArgOffset();
 
-    // If this is a dynamic alloca, the value may depend on the loaded kernargs,
-    // so loads will need to be inserted before it.
-    if (!AI || !AI->isStaticAlloca())
-      break;
-  }
+    Align MaxAlign;
+    // FIXME: Alignment is broken with explicit arg offset.;
+    const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
+    if (TotalKernArgSize == 0)
+      return false;
 
-  return InsPt;
-}
+    CallInst *KernArgSegment =
+        Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
+                                nullptr, F.getName() + ".kernarg.segment");
+    KernArgSegment->addRetAttr(Attribute::NonNull);
+    KernArgSegment->addRetAttr(
+        Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
 
-static bool lowerKernelArguments(Function &F, const TargetMachine &TM) {
-  CallingConv::ID CC = F.getCallingConv();
-  if (CC != CallingConv::AMDGPU_KERNEL || F.arg_empty())
-    return false;
-
-  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
-  LLVMContext &Ctx = F.getParent()->getContext();
-  const DataLayout &DL = F.getDataLayout();
-  BasicBlock &EntryBlock = *F.begin();
-  IRBuilder<> Builder(&EntryBlock, getInsertPt(EntryBlock));
-
-  const Align KernArgBaseAlign(16); // FIXME: Increase if necessary
-  const uint64_t BaseOffset = ST.getExplicitKernelArgOffset();
-
-  Align MaxAlign;
-  // FIXME: Alignment is broken with explicit arg offset.;
-  const uint64_t TotalKernArgSize = ST.getKernArgSegmentSize(F, MaxAlign);
-  if (TotalKernArgSize == 0)
-    return false;
-
-  CallInst *KernArgSegment =
-      Builder.CreateIntrinsic(Intrinsic::amdgcn_kernarg_segment_ptr, {}, {},
-                              nullptr, F.getName() + ".kernarg.segment");
-  KernArgSegment->addRetAttr(Attribute::NonNull);
-  KernArgSegment->addRetAttr(
-      Attribute::getWithDereferenceableBytes(Ctx, TotalKernArgSize));
-
-  uint64_t ExplicitArgOffset = 0;
-  // Preloaded kernel arguments must be sequential.
-  bool InPreloadSequence = true;
-  PreloadKernelArgInfo PreloadInfo(F, ST);
-
-  for (Argument &Arg : F.args()) {
-    const bool IsByRef = Arg.hasByRefAttr();
-    Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
-    MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
-    Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);
-
-    uint64_t Size = DL.getTypeSizeInBits(ArgTy);
-    uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
-
-    uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
-    uint64_t LastExplicitArgOffset = ExplicitArgOffset;
-    ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
-
-    // Guard against the situation where hidden arguments have already been
-    // lowered and added to the kernel function signiture, i.e. in a situation
-    // where this pass has run twice.
-    if (Arg.hasAttribute("amdgpu-hidden-argument"))
-      break;
-
-    // Try to preload this argument into user SGPRs.
-    if (Arg.hasInRegAttr() && InPreloadSequence && ST.hasKernargPreload() &&
-        !Arg.getType()->isAggregateType())
-      if (PreloadInfo.tryAllocPreloadSGPRs(AllocSize, EltOffset,
-                                           LastExplicitArgOffset))
-        continue;
+    uint64_t ExplicitArgOffset = 0;
+    // Preloaded kernel arguments must be sequential.
+    bool InPreloadSequence = true;
+    PreloadKernelArgInfo PreloadInfo(F, ST);
 
-    InPreloadSequence = false;
+    for (Argument &Arg : F.args()) {
+      const bool IsByRef = Arg.hasByRefAttr();
+      Type *ArgTy = IsByRef ? Arg.getParamByRefType() : Arg.getType();
+      MaybeAlign ParamAlign = IsByRef ? Arg.getParamAlign() : std::nullopt;
+      Align ABITypeAlign = DL.getValueOrABITypeAlignment(ParamAlign, ArgTy);
+
+      uint64_t Size = DL.getTypeSizeInBits(ArgTy);
+      uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
+
+      uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
+      uint64_t LastExplicitArgOffset = ExplicitArgOffset;
+      ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
+
+      // Guard against the situation where hidden arguments have already been
+      // lowered and added to the kernel function signiture, i.e. in a situation
+      // where this pass has run twice.
+      if (Arg.hasAttribute("amdgpu-hidden-argument"))
+        break;
+
+      // Try to preload this argument into user SGPRs.
+      if (Arg.hasInRegAttr() && InPreloadSequence && ST.hasKernargPreload() &&
+          !Arg.getType()->isAggregateType())
+        if (PreloadInfo.tryAllocPreloadSGPRs(AllocSize, EltOffset,
+                                            LastExplicitArgOffset))
+          continue;
 
-    if (Arg.use_empty())
-      continue;
+      InPreloadSequence = false;
 
-    // If this is byval, the loads are already explicit in the function. We just
-    // need to rewrite the pointer values.
-    if (IsByRef) {
-      Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
-          Builder.getInt8Ty(), KernArgSegment, EltOffset,
-          Arg.getName() + ".byval.kernarg.offset");
+      if (Arg.use_empty())
+        continue;
 
-      Value *CastOffsetPtr =
-          Builder.CreateAddrSpaceCast(ArgOffsetPtr, Arg.getType());
-      Arg.replaceAllUsesWith(CastOffsetPtr);
-      continue;
-    }
+      // If this is byval, the loads are already explicit in the function. We just
+      // need to rewrite the pointer values.
+      if (IsByRef) {
+        Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
+            Builder.getInt8Ty(), KernArgSegment, EltOffset,
+            Arg.getName() + ".byval.kernarg.offset");
 
-    if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
-      // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
-      // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
-      // can't represent this with range metadata because it's only allowed for
-      // integer types.
-      if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
-           PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
-          !ST.hasUsableDSOffset())
+        Value *CastOffsetPtr =
+            Builder.CreateAddrSpaceCast(ArgOffsetPtr, Arg.getType());
+        Arg.replaceAllUsesWith(CastOffsetPtr);
         continue;
+      }
 
-      // FIXME: We can replace this with equivalent alias.scope/noalias
-      // metadata, but this appears to be a lot of work.
-      if (Arg.hasNoAliasAttr())
-        continue;
-    }
+      if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
+        // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
+        // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
+        // can't represent this with range metadata because it's only allowed for
+        // integer types.
+        if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
+            PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
+            !ST.hasUsableDSOffset())
+          continue;
 
-    auto *VT = dyn_cast<FixedVectorType>(ArgTy);
-    bool IsV3 = VT && VT->getNumElements() == 3;
-    bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
-
-    VectorType *V4Ty = nullptr;
-
-    int64_t AlignDownOffset = alignDown(EltOffset, 4);
-    int64_t OffsetDiff = EltOffset - AlignDownOffset;
-    Align AdjustedAlign = commonAlignment(
-        KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);
-
-    Value *ArgPtr;
-    Type *AdjustedArgTy;
-    if (DoShiftOpt) { // FIXME: Handle aggregate types
-      // Since we don't have sub-dword scalar loads, avoid doing an extload by
-      // loading earlier than the argument address, and extracting the relevant
-      // bits.
-      // TODO: Update this for GFX12 which does have scalar sub-dword loads.
-      //
-      // Additionally widen any sub-dword load to i32 even if suitably aligned,
-      // so that CSE between different argument loads works easily.
-      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
-          Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
-          Arg.getName() + ".kernarg.offset.align.down");
-      AdjustedArgTy = Builder.getInt32Ty();
-    } else {
-      ArgPtr = Builder.CreateConstInBoundsGEP1_64(
-          Builder.getInt8Ty(), KernArgSegment, EltOffset,
-          Arg.getName() + ".kernarg.offset");
-      AdjustedArgTy = ArgTy;
-    }
+        // FIXME: We can replace this with equivalent alias.scope/noalias
+        // metadata, but this appears to be a lot of work.
+        if (Arg.hasNoAliasAttr())
+          continue;
+      }
 
-    if (IsV3 && Size >= 32) {
-      V4Ty = FixedVectorType::get(VT->getElementType(), 4);
-      // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
-      AdjustedArgTy = V4Ty;
-    }
+      auto *VT = dyn_cast<FixedVectorType>(ArgTy);
+      bool IsV3 = VT && VT->getNumElements() == 3;
+      bool DoShiftOpt = Size < 32 && !ArgTy->isAggregateType();
+
+      VectorType *V4Ty = nullptr;
+
+      int64_t AlignDownOffset = alignDown(EltOffset, 4);
+      int64_t OffsetDiff = EltOffset - AlignDownOffset;
+      Align AdjustedAlign = commonAlignment(
+          KernArgBaseAlign, DoShiftOpt ? AlignDownOffset : EltOffset);
+
+      Value *ArgPtr;
+      Type *AdjustedArgTy;
+      if (DoShiftOpt) { // FIXME: Handle aggregate types
+        // Since we don't have sub-dword scalar loads, avoid doing an extload by
+        // loading earlier than the argument address, and extracting the relevant
+        // bits.
+        // TODO: Update this for GFX12 which does have scalar sub-dword loads.
+        //
+        // Additionally widen any sub-dword load to i32 even if suitably aligned,
+        // so that CSE between different argument loads works easily.
+        ArgPtr = Builder.CreateConstInBoundsGEP1_64(
+            Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
+            Arg.getName() + ".kernarg.offset.align.down");
+        AdjustedArgTy = Builder.getInt32Ty();
+      } else {
+        ArgPtr = Builder.CreateConstInBoundsGEP1_64(
+            Builder.getInt8Ty(), KernArgSegment, EltOffset,
+            Arg.getName() + ".kernarg.offset");
+        AdjustedArgTy = ArgTy;
+      }
 
-    LoadInst *Load =
-        Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
-    Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
+      if (IsV3 && Size >= 32) {
+        V4Ty = FixedVectorType::get(VT->getElementType(), 4);
+        // Use the hack that clang uses to avoid SelectionDAG ruining v3 loads
+        AdjustedArgTy = V4Ty;
+      }
 
-    MDBuilder MDB(Ctx);
+      LoadInst *Load =
+          Builder.CreateAlignedLoad(AdjustedArgTy, ArgPtr, AdjustedAlign);
+      Load->setMetadata(LLVMContext::MD_invariant_load, MDNode::get(Ctx, {}));
 
-    if (isa<PointerType>(ArgTy)) {
-      if (Arg.hasNonNullAttr())
-        Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
+      MDBuilder MDB(Ctx);
 
-      uint64_t DerefBytes = Arg.getDereferenceableBytes();
-      if (DerefBytes != 0) {
-        Load->setMetadata(
-          LLVMContext::MD_dereferenceable,
-          MDNode::get(Ctx,
-                      MDB.createConstant(
-                        ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
-      }
+      if (isa<PointerType>(ArgTy)) {
+        if (Arg.hasNonNullAttr())
+          Load->setMetadata(LLVMContext::MD_nonnull, MDNode::get(Ctx, {}));
 
-      uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
-      if (DerefOrNullBytes != 0) {
-        Load->setMetadata(
-          LLVMContext::MD_dereferenceable_or_null,
-          MDNode::get(Ctx,
-                      MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
-                                                          DerefOrNullBytes))));
+        uint64_t DerefBytes = Arg.getDereferenceableBytes();
+        if (DerefBytes != 0) {
+          Load->setMetadata(
+            LLVMContext::MD_dereferenceable,
+            MDNode::get(Ctx,
+                        MDB.createConstant(
+                          ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
+        }
+
+        uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
+        if (DerefOrNullBytes != 0) {
+          Load->setMetadata(
+            LLVMContext::MD_dereferenceable_or_null,
+            MDNode::get(Ctx,
+                        MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
+                                                            DerefOrNullBytes))));
+        }
+
+        if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
+          Load->setMetadata(
+              LLVMContext::MD_align,
+              MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
+                                  Builder.getInt64Ty(), ParamAlign->value()))));
+        }
       }
 
-      if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
-        Load->setMetadata(
-            LLVMContext::MD_align,
-            MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
-                                 Builder.getInt64Ty(), ParamAlign->value()))));
+      // TODO: Convert noalias arg to !noalias
+
+      if (DoShiftOpt) {
+        Value *ExtractBits = OffsetDiff == 0 ?
+          Load : Builder.CreateLShr(Load, OffsetDiff * 8);
+
+        IntegerType *ArgIntTy = Builder.getIntNTy(Size);
+        Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
+        Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
+                                              Arg.getName() + ".load");
+        Arg.replaceAllUsesWith(NewVal);
+      } else if (IsV3) {
+        Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
+                                                  Arg.getName() + ".load");
+        Arg.replaceAllUsesWith(Shuf);
+      } else {
+        Load->setName(Arg.getName() + ".load");
+        Arg.replaceAllUsesWith(Load);
       }
     }
 
-    // TODO: Convert noalias arg to !noalias
-
-    if (DoShiftOpt) {
-      Value *ExtractBits = OffsetDiff == 0 ?
-        Load : Builder.CreateLShr(Load, OffsetDiff * 8);
-
-      IntegerType *ArgIntTy = Builder.getIntNTy(Size);
-      Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
-      Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
-                                            Arg.getName() + ".load");
-      Arg.replaceAllUsesWith(NewVal);
-    } else if (IsV3) {
-      Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},
-                                                Arg.getName() + ".load");
-      Arg.replaceAllUsesWith(Shuf);
-    } else {
-      Load->setName(Arg.getName() + ".load");
-      Arg.replaceAllUsesWith(Load);
+    KernArgSegment->addRetAttr(
+        Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
+
+    if (InPreloadSequence) {
+      uint64_t ImplicitArgsBaseOffset =
+          alignTo(ExplicitArgOffset, ST.getAlignmentForImplicitArgPtr()) +
+          BaseOffset;
+      PreloadInfo.tryAllocImplicitArgPreloadSGPRs(ImplicitArgsBaseOffset,
+                                                  Builder, FunctionsToErase);
     }
+
+    return true;
   }
 
-  KernArgSegment->addRetAttr(
-      Attribute::getWithAlignment(Ctx, std::max(KernArgBaseAlign, MaxAlign)));
+  bool runOnModule(Module &M) {
+    bool Changed = false;
 
-  if (InPreloadSequence) {
-    uint64_t ImplicitArgsBaseOffset =
-        alignTo(ExplicitArgOffset, ST.getAlignmentForImplicitArgPtr()) +
-        BaseOffset;
-    PreloadInfo.tryAllocImplicitArgPreloadSGPRs(ImplicitArgsBaseOffset,
-                                                Builder);
+    for (Function &F : M)
+      Changed |= lowerKernelArguments(F);
+
+  ...
[truncated]


⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 4512bbe7467c1c0f884304e5654d1070df58d6f8 a8cb03ff24446a85ea82963d3585204b0874a55a --extensions cpp,h -- llvm/lib/Target/AMDGPU/AMDGPU.h llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
The diff from clang-format is shown below.
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
index 7b986b4385..02ca044c4b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelArguments.cpp
@@ -264,8 +264,8 @@ public:
     for (BasicBlock::iterator E = BB.end(); InsPt != E; ++InsPt) {
       AllocaInst *AI = dyn_cast<AllocaInst>(&*InsPt);
 
-      // If this is a dynamic alloca, the value may depend on the loaded kernargs,
-      // so loads will need to be inserted before it.
+      // If this is a dynamic alloca, the value may depend on the loaded
+      // kernargs, so loads will need to be inserted before it.
       if (!AI || !AI->isStaticAlloca())
         break;
     }
@@ -314,7 +314,8 @@ public:
       uint64_t Size = DL.getTypeSizeInBits(ArgTy);
       uint64_t AllocSize = DL.getTypeAllocSize(ArgTy);
 
-      uint64_t EltOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
+      uint64_t EltOffset =
+          alignTo(ExplicitArgOffset, ABITypeAlign) + BaseOffset;
       uint64_t LastExplicitArgOffset = ExplicitArgOffset;
       ExplicitArgOffset = alignTo(ExplicitArgOffset, ABITypeAlign) + AllocSize;
 
@@ -328,7 +329,7 @@ public:
       if (Arg.hasInRegAttr() && InPreloadSequence && ST.hasKernargPreload() &&
           !Arg.getType()->isAggregateType())
         if (PreloadInfo.tryAllocPreloadSGPRs(AllocSize, EltOffset,
-                                            LastExplicitArgOffset))
+                                             LastExplicitArgOffset))
           continue;
 
       InPreloadSequence = false;
@@ -336,8 +337,8 @@ public:
       if (Arg.use_empty())
         continue;
 
-      // If this is byval, the loads are already explicit in the function. We just
-      // need to rewrite the pointer values.
+      // If this is byval, the loads are already explicit in the function. We
+      // just need to rewrite the pointer values.
       if (IsByRef) {
         Value *ArgOffsetPtr = Builder.CreateConstInBoundsGEP1_64(
             Builder.getInt8Ty(), KernArgSegment, EltOffset,
@@ -351,11 +352,11 @@ public:
 
       if (PointerType *PT = dyn_cast<PointerType>(ArgTy)) {
         // FIXME: Hack. We rely on AssertZext to be able to fold DS addressing
-        // modes on SI to know the high bits are 0 so pointer adds don't wrap. We
-        // can't represent this with range metadata because it's only allowed for
-        // integer types.
+        // modes on SI to know the high bits are 0 so pointer adds don't wrap.
+        // We can't represent this with range metadata because it's only allowed
+        // for integer types.
         if ((PT->getAddressSpace() == AMDGPUAS::LOCAL_ADDRESS ||
-            PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
+             PT->getAddressSpace() == AMDGPUAS::REGION_ADDRESS) &&
             !ST.hasUsableDSOffset())
           continue;
 
@@ -380,12 +381,12 @@ public:
       Type *AdjustedArgTy;
       if (DoShiftOpt) { // FIXME: Handle aggregate types
         // Since we don't have sub-dword scalar loads, avoid doing an extload by
-        // loading earlier than the argument address, and extracting the relevant
-        // bits.
+        // loading earlier than the argument address, and extracting the
+        // relevant bits.
         // TODO: Update this for GFX12 which does have scalar sub-dword loads.
         //
-        // Additionally widen any sub-dword load to i32 even if suitably aligned,
-        // so that CSE between different argument loads works easily.
+        // Additionally widen any sub-dword load to i32 even if suitably
+        // aligned, so that CSE between different argument loads works easily.
         ArgPtr = Builder.CreateConstInBoundsGEP1_64(
             Builder.getInt8Ty(), KernArgSegment, AlignDownOffset,
             Arg.getName() + ".kernarg.offset.align.down");
@@ -416,39 +417,38 @@ public:
         uint64_t DerefBytes = Arg.getDereferenceableBytes();
         if (DerefBytes != 0) {
           Load->setMetadata(
-            LLVMContext::MD_dereferenceable,
-            MDNode::get(Ctx,
-                        MDB.createConstant(
-                          ConstantInt::get(Builder.getInt64Ty(), DerefBytes))));
+              LLVMContext::MD_dereferenceable,
+              MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
+                                   Builder.getInt64Ty(), DerefBytes))));
         }
 
         uint64_t DerefOrNullBytes = Arg.getDereferenceableOrNullBytes();
         if (DerefOrNullBytes != 0) {
           Load->setMetadata(
-            LLVMContext::MD_dereferenceable_or_null,
-            MDNode::get(Ctx,
-                        MDB.createConstant(ConstantInt::get(Builder.getInt64Ty(),
-                                                            DerefOrNullBytes))));
+              LLVMContext::MD_dereferenceable_or_null,
+              MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
+                                   Builder.getInt64Ty(), DerefOrNullBytes))));
         }
 
         if (MaybeAlign ParamAlign = Arg.getParamAlign()) {
           Load->setMetadata(
               LLVMContext::MD_align,
-              MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
-                                  Builder.getInt64Ty(), ParamAlign->value()))));
+              MDNode::get(Ctx,
+                          MDB.createConstant(ConstantInt::get(
+                              Builder.getInt64Ty(), ParamAlign->value()))));
         }
       }
 
       // TODO: Convert noalias arg to !noalias
 
       if (DoShiftOpt) {
-        Value *ExtractBits = OffsetDiff == 0 ?
-          Load : Builder.CreateLShr(Load, OffsetDiff * 8);
+        Value *ExtractBits =
+            OffsetDiff == 0 ? Load : Builder.CreateLShr(Load, OffsetDiff * 8);
 
         IntegerType *ArgIntTy = Builder.getIntNTy(Size);
         Value *Trunc = Builder.CreateTrunc(ExtractBits, ArgIntTy);
-        Value *NewVal = Builder.CreateBitCast(Trunc, ArgTy,
-                                              Arg.getName() + ".load");
+        Value *NewVal =
+            Builder.CreateBitCast(Trunc, ArgTy, Arg.getName() + ".load");
         Arg.replaceAllUsesWith(NewVal);
       } else if (IsV3) {
         Value *Shuf = Builder.CreateShuffleVector(Load, ArrayRef<int>{0, 1, 2},

shiltian (Contributor) commented:

I'd make the refactoring and the change to a module pass separate PRs.
