Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Target] Improve ROCDL gpu serialization API #95456

Merged
merged 2 commits into from
Jun 17, 2024

Conversation

fabianmcg
Copy link
Contributor

@fabianmcg fabianmcg commented Jun 13, 2024

This patch improves the ROCDL gpu serialization API by:

  • Introducing the enum AMDGCNLibraries for specifying the AMD GCN device code libraries to use during linking.
  • Removing getCommonBitcodeLibs in favor of AMDGCNLibraries. Previously getCommonBitcodeLibs would try to load all AMD GCN bitcode librariesm now it will only load the requested libraries.
  • Exposing the compileToBinary method and making it virtual, allowing downstream users to re-use this method.
  • Exposing moduleToObjectImpl, this method provides a prototype flow for compiling to binary, allowing downstream users to re-use this method.
  • It also avoids constructing the control variables if no device libraries are being used.

This patch also changes the behavior of the CMake flag DEFAULT_ROCM_PATH. Before it would fall back to a default value of /opt/rocm if not specified. However, that default value causes fragile builds in environments with ROCm. Now, the flag falls back to the empty string, making it clear that the user must provide a value at LLVM build time.

@fabianmcg fabianmcg force-pushed the pr-improve-rcodl-serialization branch 3 times, most recently from 3b74b15 to 9a04741 Compare June 13, 2024 19:53
@fabianmcg fabianmcg marked this pull request as ready for review June 13, 2024 20:27
@fabianmcg fabianmcg requested a review from krzysz00 June 13, 2024 20:27
@llvmbot
Copy link
Collaborator

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Fabian Mora (fabianmcg)

Changes

This patch improves the ROCDL gpu serialization API by:

  • Introducing the class AMDGCNLibraryList.
    This class provides a structured API for specifying the AMD GCN device code libraries to use during linking.
  • Removing getCommonBitcodeLibs in favor of AMDGCNLibraryList. Previously getCommonBitcodeLibs would try to load all AMD GCN bitcode libraries.
  • Exposing the compileToBinary method and making it virtual, allowing downstream users to re-use this method.
  • Exposing moduleToObjectImpl, this method provides a prototype flow for compiling to binary, allowing downstream users to re-use this method.

This patch also changes the behavior of the CMake flag DEFAULT_ROCM_PATH. Before it would fall back to a default value of /opt/rocm if not specified. However, that default value causes fragile builds in environments with ROCm. Now, the flag falls back to the empty string, making it clear that the user must provide a value at LLVM build time.


Patch is 21.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95456.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Target/LLVM/ROCDL/Utils.h (+76-10)
  • (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Target/LLVM/CMakeLists.txt (+1-6)
  • (modified) mlir/lib/Target/LLVM/ROCDL/Target.cpp (+163-119)
diff --git a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
index 374fa65bd02e3..acbdb06be3f67 100644
--- a/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
+++ b/mlir/include/mlir/Target/LLVM/ROCDL/Utils.h
@@ -27,6 +27,64 @@ namespace ROCDL {
 /// 5. Returns an empty string.
 StringRef getROCMPath();
 
+/// Helper class for specifying the AMD GCN device libraries required for
+/// compilation.
+class AMDGCNLibraryList {
+public:
+  typedef enum : uint32_t {
+    None = 0,
+    Ockl = 1,
+    Ocml = 2,
+    OpenCL = 4,
+    Hip = 8,
+    LastLib = Hip,
+    All = (LastLib << 1) - 1
+  } Library;
+
+  explicit AMDGCNLibraryList(uint32_t libs = All) : libList(All & libs) {}
+
+  /// Return a list with no libraries.
+  static AMDGCNLibraryList getEmpty() { return AMDGCNLibraryList(None); }
+
+  /// Return the libraries needed for compiling code with OpenCL calls.
+  static AMDGCNLibraryList getOpenCL() {
+    return AMDGCNLibraryList(Ockl | Ocml | OpenCL);
+  }
+
+  /// Returns true if the list is empty.
+  bool isEmpty() const { return libList == None; }
+
+  /// Adds a library to the list.
+  AMDGCNLibraryList addLibrary(Library lib) {
+    libList = libList | lib;
+    return *this;
+  }
+
+  /// Adds all the libraries in `list` to the library list.
+  AMDGCNLibraryList addList(AMDGCNLibraryList list) {
+    libList = libList | list.libList;
+    return *this;
+  }
+
+  /// Removes a library from the list.
+  AMDGCNLibraryList removeLibrary(Library lib) {
+    libList = libList & ~lib;
+    return *this;
+  }
+
+  /// Returns true if `lib` is in the list of libraries.
+  bool requiresLibrary(Library lib) const { return (libList & lib) != None; }
+
+  /// Returns true if `libList` contains all the libraries in `libs`.
+  bool containLibraries(uint32_t libs) const {
+    return (libList & libs) != None;
+  }
+
+private:
+  /// Library list.
+  uint32_t libList;
+};
+
 /// Base class for all ROCDL serializations from GPU modules into binary
 /// strings. By default this class serializes into LLVM bitcode.
 class SerializeGPUModuleBase : public LLVM::ModuleToObject {
@@ -49,8 +107,8 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
   /// Returns the bitcode files to be loaded.
   ArrayRef<std::string> getFileList() const;
 
-  /// Appends standard ROCm device libraries like `ocml.bc`, `ockl.bc`, etc.
-  LogicalResult appendStandardLibs();
+  /// Appends standard ROCm device Library to `fileList`.
+  LogicalResult appendStandardLibs(AMDGCNLibraryList libs);
 
   /// Loads the bitcode files in `fileList`.
   virtual std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
@@ -63,15 +121,20 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
   LogicalResult handleBitcodeFile(llvm::Module &module) override;
 
 protected:
-  /// Appends the paths of common ROCm device libraries to `libs`.
-  LogicalResult getCommonBitcodeLibs(llvm::SmallVector<std::string> &libs,
-                                     SmallVector<char, 256> &libPath,
-                                     StringRef isaVersion);
-
   /// Adds `oclc` control variables to the LLVM module.
-  void addControlVariables(llvm::Module &module, bool wave64, bool daz,
-                           bool finiteOnly, bool unsafeMath, bool fastMath,
-                           bool correctSqrt, StringRef abiVer);
+  void addControlVariables(llvm::Module &module, AMDGCNLibraryList libs,
+                           bool wave64, bool daz, bool finiteOnly,
+                           bool unsafeMath, bool fastMath, bool correctSqrt,
+                           StringRef abiVer);
+
+  /// Compiles assembly to a binary.
+  virtual std::optional<SmallVector<char, 0>>
+  compileToBinary(const std::string &serializedISA);
+
+  /// Default implementation of `ModuleToObject::moduleToObject`.
+  std::optional<SmallVector<char, 0>>
+  moduleToObjectImpl(const gpu::TargetOptions &targetOptions,
+                     llvm::Module &llvmModule);
 
   /// Returns the assembled ISA.
   std::optional<SmallVector<char, 0>> assembleIsa(StringRef isa);
@@ -84,6 +147,9 @@ class SerializeGPUModuleBase : public LLVM::ModuleToObject {
 
   /// List of LLVM bitcode files to link to.
   SmallVector<std::string> fileList;
+
+  /// AMD GCN libraries to use when linking, the default is using all.
+  AMDGCNLibraryList deviceLibs = AMDGCNLibraryList::getEmpty();
 };
 } // namespace ROCDL
 } // namespace mlir
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 61ab298ebfb98..08c8aea36fac9 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -106,7 +106,7 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS)
       "Building mlir with ROCm support requires the AMDGPU backend")
   endif()
 
-  set(DEFAULT_ROCM_PATH "/opt/rocm" CACHE PATH "Fallback path to search for ROCm installs")
+  set(DEFAULT_ROCM_PATH "" CACHE PATH "Fallback path to search for ROCm installs")
   target_compile_definitions(obj.MLIRGPUTransforms
     PRIVATE
     __DEFAULT_ROCM_PATH__="${DEFAULT_ROCM_PATH}"
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index 5a3fa160850b4..4393ff1775ef9 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -123,17 +123,12 @@ add_mlir_dialect_library(MLIRROCDLTarget
   )
 
 if(MLIR_ENABLE_ROCM_CONVERSIONS)
-  if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD))
-    message(SEND_ERROR
-      "Building mlir with ROCm support requires the AMDGPU backend")
-  endif()
-
   if (DEFINED ROCM_PATH)
     set(DEFAULT_ROCM_PATH "${ROCM_PATH}" CACHE PATH "Fallback path to search for ROCm installs")
   elseif(DEFINED ENV{ROCM_PATH})
     set(DEFAULT_ROCM_PATH "$ENV{ROCM_PATH}" CACHE PATH "Fallback path to search for ROCm installs")
   else()
-    set(DEFAULT_ROCM_PATH "/opt/rocm" CACHE PATH "Fallback path to search for ROCm installs")
+    set(DEFAULT_ROCM_PATH "" CACHE PATH "Fallback path to search for ROCm installs")
   endif()
   message(VERBOSE "MLIR Default ROCM toolkit path: ${DEFAULT_ROCM_PATH}")
 
diff --git a/mlir/lib/Target/LLVM/ROCDL/Target.cpp b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
index cc13e5b7436ea..dd4311f2b4b39 100644
--- a/mlir/lib/Target/LLVM/ROCDL/Target.cpp
+++ b/mlir/lib/Target/LLVM/ROCDL/Target.cpp
@@ -17,9 +17,6 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Support/FileUtilities.h"
 #include "mlir/Target/LLVM/ROCDL/Utils.h"
-#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
-#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Export.h"
 
 #include "llvm/IR/Constants.h"
@@ -112,8 +109,9 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
       if (auto file = dyn_cast<StringAttr>(attr))
         fileList.push_back(file.str());
 
-  // Append standard ROCm device bitcode libraries to the files to be loaded.
-  (void)appendStandardLibs();
+  // By default add all libraries if the toolkit path is not empty.
+  if (!getToolkitPath().empty())
+    deviceLibs = AMDGCNLibraryList(AMDGCNLibraryList::All);
 }
 
 void SerializeGPUModuleBase::init() {
@@ -138,29 +136,70 @@ ArrayRef<std::string> SerializeGPUModuleBase::getFileList() const {
   return fileList;
 }
 
-LogicalResult SerializeGPUModuleBase::appendStandardLibs() {
+LogicalResult
+SerializeGPUModuleBase::appendStandardLibs(AMDGCNLibraryList libs) {
+  if (libs.isEmpty())
+    return success();
   StringRef pathRef = getToolkitPath();
-  if (!pathRef.empty()) {
-    SmallVector<char, 256> path;
-    path.insert(path.begin(), pathRef.begin(), pathRef.end());
-    llvm::sys::path::append(path, "amdgcn", "bitcode");
-    pathRef = StringRef(path.data(), path.size());
-    if (!llvm::sys::fs::is_directory(pathRef)) {
-      getOperation().emitRemark() << "ROCm amdgcn bitcode path: " << pathRef
-                                  << " does not exist or is not a directory.";
-      return failure();
-    }
-    StringRef isaVersion =
-        llvm::AMDGPU::getArchNameAMDGCN(llvm::AMDGPU::parseArchAMDGCN(chip));
-    isaVersion.consume_front("gfx");
-    return getCommonBitcodeLibs(fileList, path, isaVersion);
+  // Fail if the toolkit is empty.
+  if (pathRef.empty())
+    return failure();
+
+  // Get the path for the device libraries
+  SmallString<256> path;
+  path.insert(path.begin(), pathRef.begin(), pathRef.end());
+  llvm::sys::path::append(path, "amdgcn", "bitcode");
+  pathRef = StringRef(path.data(), path.size());
+
+  // Fail if the path is invalid.
+  if (!llvm::sys::fs::is_directory(pathRef)) {
+    getOperation().emitRemark() << "ROCm amdgcn bitcode path: " << pathRef
+                                << " does not exist or is not a directory.";
+    return failure();
   }
+
+  // Get the ISA version.
+  StringRef isaVersion =
+      llvm::AMDGPU::getArchNameAMDGCN(llvm::AMDGPU::parseArchAMDGCN(chip));
+  isaVersion.consume_front("gfx");
+
+  // Helper function for adding a library.
+  auto addLib = [&](const Twine &lib) -> bool {
+    auto baseSize = path.size();
+    llvm::sys::path::append(path, lib);
+    StringRef pathRef(path.data(), path.size());
+    if (!llvm::sys::fs::is_regular_file(pathRef)) {
+      getOperation().emitRemark() << "Bitcode library path: " << pathRef
+                                  << " does not exist or is not a file.\n";
+      return true;
+    }
+    fileList.push_back(pathRef.str());
+    path.truncate(baseSize);
+    return false;
+  };
+
+  // Add ROCm device libraries. Fail if any of the libraries is not found, ie.
+  // if any of the `addLib` failed.
+  if ((libs.requiresLibrary(AMDGCNLibraryList::Ocml) && addLib("ocml.bc")) ||
+      (libs.requiresLibrary(AMDGCNLibraryList::Ockl) && addLib("ockl.bc")) ||
+      (libs.requiresLibrary(AMDGCNLibraryList::Hip) && addLib("hip.bc")) ||
+      (libs.requiresLibrary(AMDGCNLibraryList::OpenCL) &&
+       addLib("opencl.bc")) ||
+      (libs.containLibraries(AMDGCNLibraryList::Ocml |
+                             AMDGCNLibraryList::Ockl) &&
+       addLib("oclc_isa_version_" + isaVersion + ".bc")))
+    return failure();
   return success();
 }
 
 std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
 SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
   SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
+  // Return if there are no libs to load.
+  if (deviceLibs.isEmpty() && fileList.empty())
+    return bcFiles;
+  if (failed(appendStandardLibs(deviceLibs)))
+    return std::nullopt;
   if (failed(loadBitcodeFilesFromList(module.getContext(), fileList, bcFiles,
                                       true)))
     return std::nullopt;
@@ -174,80 +213,79 @@ LogicalResult SerializeGPUModuleBase::handleBitcodeFile(llvm::Module &module) {
   // Stop spamming us with clang version numbers
   if (auto *ident = module.getNamedMetadata("llvm.ident"))
     module.eraseNamedMetadata(ident);
+  // Override the libModules datalayout and target triple with the compiler's
+  // data layout should there be a discrepency.
+  setDataLayoutAndTriple(module);
   return success();
 }
 
 void SerializeGPUModuleBase::handleModulePreLink(llvm::Module &module) {
-  [[maybe_unused]] std::optional<llvm::TargetMachine *> targetMachine =
+  std::optional<llvm::TargetMachine *> targetMachine =
       getOrCreateTargetMachine();
   assert(targetMachine && "expect a TargetMachine");
-  addControlVariables(module, target.hasWave64(), target.hasDaz(),
+  // If all libraries are not set, traverse the module to determine which
+  // libraries are required.
+  if (!deviceLibs.requiresLibrary(AMDGCNLibraryList::All)) {
+    for (llvm::Function &f : module.functions()) {
+      if (f.hasExternalLinkage() && f.hasName() && !f.hasExactDefinition()) {
+        StringRef funcName = f.getName();
+        if ("printf" == funcName)
+          deviceLibs.addList(AMDGCNLibraryList::getOpenCL());
+        if (funcName.starts_with("__ockl_"))
+          deviceLibs.addLibrary(AMDGCNLibraryList::Ockl);
+        if (funcName.starts_with("__ocml_"))
+          deviceLibs.addLibrary(AMDGCNLibraryList::Ocml);
+      }
+    }
+  }
+  addControlVariables(module, deviceLibs, target.hasWave64(), target.hasDaz(),
                       target.hasFiniteOnly(), target.hasUnsafeMath(),
                       target.hasFastMath(), target.hasCorrectSqrt(),
                       target.getAbi());
 }
 
-// Get the paths of ROCm device libraries.
-LogicalResult SerializeGPUModuleBase::getCommonBitcodeLibs(
-    llvm::SmallVector<std::string> &libs, SmallVector<char, 256> &libPath,
-    StringRef isaVersion) {
-  auto addLib = [&](StringRef path) -> bool {
-    if (!llvm::sys::fs::is_regular_file(path)) {
-      getOperation().emitRemark() << "Bitcode library path: " << path
-                                  << " does not exist or is not a file.\n";
-      return true;
-    }
-    libs.push_back(path.str());
-    return false;
-  };
-  auto getLibPath = [&libPath](Twine lib) {
-    auto baseSize = libPath.size();
-    llvm::sys::path::append(libPath, lib + ".bc");
-    std::string path(StringRef(libPath.data(), libPath.size()).str());
-    libPath.truncate(baseSize);
-    return path;
-  };
-
-  // Add ROCm device libraries. Fail if any of the libraries is not found.
-  if (addLib(getLibPath("ocml")) || addLib(getLibPath("ockl")) ||
-      addLib(getLibPath("hip")) || addLib(getLibPath("opencl")) ||
-      addLib(getLibPath("oclc_isa_version_" + isaVersion)))
-    return failure();
-  return success();
-}
-
 void SerializeGPUModuleBase::addControlVariables(
-    llvm::Module &module, bool wave64, bool daz, bool finiteOnly,
-    bool unsafeMath, bool fastMath, bool correctSqrt, StringRef abiVer) {
-  llvm::Type *i8Ty = llvm::Type::getInt8Ty(module.getContext());
-  auto addControlVariable = [i8Ty, &module](StringRef name, bool enable) {
+    llvm::Module &module, AMDGCNLibraryList libs, bool wave64, bool daz,
+    bool finiteOnly, bool unsafeMath, bool fastMath, bool correctSqrt,
+    StringRef abiVer) {
+  // Return if no device libraries are required.
+  if (libs.isEmpty())
+    return;
+  // Helper function for adding control variables.
+  auto addControlVariable = [&module](StringRef name, uint32_t value,
+                                      uint32_t bitwidth) {
+    if (module.getNamedGlobal(name)) {
+      return;
+    }
+    llvm::IntegerType *type =
+        llvm::IntegerType::getIntNTy(module.getContext(), bitwidth);
     llvm::GlobalVariable *controlVariable = new llvm::GlobalVariable(
-        module, i8Ty, true, llvm::GlobalValue::LinkageTypes::LinkOnceODRLinkage,
-        llvm::ConstantInt::get(i8Ty, enable), name, nullptr,
-        llvm::GlobalValue::ThreadLocalMode::NotThreadLocal, 4);
+        module, /*isConstant=*/type, true,
+        llvm::GlobalValue::LinkageTypes::LinkOnceODRLinkage,
+        llvm::ConstantInt::get(type, value), name, /*before=*/nullptr,
+        /*threadLocalMode=*/llvm::GlobalValue::ThreadLocalMode::NotThreadLocal,
+        /*addressSpace=*/4);
     controlVariable->setVisibility(
         llvm::GlobalValue::VisibilityTypes::ProtectedVisibility);
-    controlVariable->setAlignment(llvm::MaybeAlign(1));
+    controlVariable->setAlignment(llvm::MaybeAlign(bitwidth / 8));
     controlVariable->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Local);
   };
-  addControlVariable("__oclc_finite_only_opt", finiteOnly || fastMath);
-  addControlVariable("__oclc_unsafe_math_opt", unsafeMath || fastMath);
-  addControlVariable("__oclc_daz_opt", daz || fastMath);
-  addControlVariable("__oclc_correctly_rounded_sqrt32",
-                     correctSqrt && !fastMath);
-  addControlVariable("__oclc_wavefrontsize64", wave64);
-
-  llvm::Type *i32Ty = llvm::Type::getInt32Ty(module.getContext());
-  int abi = 500;
-  abiVer.getAsInteger(0, abi);
-  llvm::GlobalVariable *abiVersion = new llvm::GlobalVariable(
-      module, i32Ty, true, llvm::GlobalValue::LinkageTypes::LinkOnceODRLinkage,
-      llvm::ConstantInt::get(i32Ty, abi), "__oclc_ABI_version", nullptr,
-      llvm::GlobalValue::ThreadLocalMode::NotThreadLocal, 4);
-  abiVersion->setVisibility(
-      llvm::GlobalValue::VisibilityTypes::ProtectedVisibility);
-  abiVersion->setAlignment(llvm::MaybeAlign(4));
-  abiVersion->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Local);
+  // Add ocml related control variables.
+  if (libs.requiresLibrary(AMDGCNLibraryList::Ocml)) {
+    addControlVariable("__oclc_finite_only_opt", finiteOnly || fastMath, 8);
+    addControlVariable("__oclc_daz_opt", daz || fastMath, 8);
+    addControlVariable("__oclc_correctly_rounded_sqrt32",
+                       correctSqrt && !fastMath, 8);
+    addControlVariable("__oclc_unsafe_math_opt", unsafeMath || fastMath, 8);
+  }
+  // Add ocml or ockl related control variables.
+  if (libs.containLibraries(AMDGCNLibraryList::Ocml |
+                            AMDGCNLibraryList::Ockl)) {
+    addControlVariable("__oclc_wavefrontsize64", wave64, 8);
+    int abi = 500;
+    abiVer.getAsInteger(0, abi);
+    addControlVariable("__oclc_ABI_version", abi, 32);
+  }
 }
 
 std::optional<SmallVector<char, 0>>
@@ -312,43 +350,11 @@ SerializeGPUModuleBase::assembleIsa(StringRef isa) {
 
   parser->setTargetParser(*tap);
   parser->Run(false);
-
   return result;
 }
 
-#if MLIR_ENABLE_ROCM_CONVERSIONS
-namespace {
-class AMDGPUSerializer : public SerializeGPUModuleBase {
-public:
-  AMDGPUSerializer(Operation &module, ROCDLTargetAttr target,
-                   const gpu::TargetOptions &targetOptions);
-
-  gpu::GPUModuleOp getOperation();
-
-  // Compile to HSA.
-  std::optional<SmallVector<char, 0>>
-  compileToBinary(const std::string &serializedISA);
-
-  std::optional<SmallVector<char, 0>>
-  moduleToObject(llvm::Module &llvmModule) override;
-
-private:
-  // Target options.
-  gpu::TargetOptions targetOptions;
-};
-} // namespace
-
-AMDGPUSerializer::AMDGPUSerializer(Operation &module, ROCDLTargetAttr target,
-                                   const gpu::TargetOptions &targetOptions)
-    : SerializeGPUModuleBase(module, target, targetOptions),
-      targetOptions(targetOptions) {}
-
-gpu::GPUModuleOp AMDGPUSerializer::getOperation() {
-  return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
-}
-
 std::optional<SmallVector<char, 0>>
-AMDGPUSerializer::compileToBinary(const std::string &serializedISA) {
+SerializeGPUModuleBase::compileToBinary(const std::string &serializedISA) {
   // Assemble the ISA.
   std::optional<SmallVector<char, 0>> isaBinary = assembleIsa(serializedISA);
 
@@ -407,13 +413,13 @@ AMDGPUSerializer::compileToBinary(const std::string &serializedISA) {
   return SmallVector<char, 0>(buffer.begin(), buffer.end());
 }
 
-std::optional<SmallVector<char, 0>>
-AMDGPUSerializer::moduleToObject(llvm::Module &llvmModule) {
+std::optional<SmallVector<char, 0>> SerializeGPUModuleBase::moduleToObjectImpl(
+    const gpu::TargetOptions &targetOptions, llvm::Module &llvmModule) {
   // Return LLVM IR if the compilation target is offload.
 #define DEBUG_TYPE "serialize-to-llvm"
   LLVM_DEBUG({
-    llvm::dbgs() << "LLVM IR for module: " << getOperation().getNameAttr()
-                 << "\n"
+    llvm::dbgs() << "LLVM IR for module: "
+                 << cast<gpu::GPUModuleOp>(getOperation()).getNameAttr() << "\n"
                  << llvmModule << "\n";
   });
 #undef DEBUG_TYPE
@@ -437,7 +443,8 @@ AMDGPUSerializer::moduleToObject(llvm::Module &llvmModule) {
   }
 #define DEBUG_TYPE "serialize-to-isa"
   LLVM_DEBUG({
-    llvm::dbgs() << "ISA for module: " << getOperation().getNameAttr() << "\n"
+    llvm::dbgs() << "ISA for module: "
+                 << cast<gpu::GPUModuleOp>(getOperation()).getNameAttr() << "\n"
                  << *serializedISA << "\n";
   });
 #undef DEBUG_TYPE
@@ -448,6 +455,38 @@ AMDGPUSerializer::moduleToObject(llvm::Module &llvmModule) {
   // Compile to binary.
   return compileToBinary(*serializedISA);
 }
+
+#if MLIR_ENABLE_ROCM_CONVERSIONS
+namespace {
+class AMDGPUSerializer : public SerializeGPUModuleBase {
+public:
+  AMDGPUSerializer(Operation &module, ROCDLTargetAttr target,
+                   const gpu::TargetOptions &targetOptions);
+
+  gpu::GPUModuleOp getOperation();
+
+  std::optional<SmallVector<char, 0>>
+  moduleToObject(llvm::Module &llvmModule) override;
+
+private:
+  // Target options.
...
[truncated]

@fabianmcg fabianmcg force-pushed the pr-improve-rcodl-serialization branch from 427561c to 289a55e Compare June 13, 2024 20:40
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got a few notes

@@ -27,6 +27,62 @@ namespace ROCDL {
/// 5. Returns an empty string.
StringRef getROCMPath();

/// Helper class for specifying the AMD GCN device libraries required for
/// compilation.
class AMDGCNLibraryList {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, isn't this just enum class AMDGCNLibraryList : uint32_t?

... Actually, on top of that - there's already general support for bit enums. This could and should be tablegen'd

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to avoid going to tablegen for such a small class, but I'll switch it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's more that there's a lot of common infrastructure around these sorts of flag enums (including, say, printing them) that can be autogenerated. And it'll be useful to have if someone ever wants to put this as an attribute somewhere

}

/// Adds all the libraries in `list` to the library list.
AMDGCNLibraryList addList(AMDGCNLibraryList list) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, this is bitEnumSet() or some similarly-phrased function

if (DEFINED ROCM_PATH)
set(DEFAULT_ROCM_PATH "${ROCM_PATH}" CACHE PATH "Fallback path to search for ROCm installs")
elseif(DEFINED ENV{ROCM_PATH})
set(DEFAULT_ROCM_PATH "$ENV{ROCM_PATH}" CACHE PATH "Fallback path to search for ROCm installs")
else()
set(DEFAULT_ROCM_PATH "/opt/rocm" CACHE PATH "Fallback path to search for ROCm installs")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is meant to be a fallback so that the build goes to look in /opt/rocm for the device libraries if there are no clues about where they are in the user's environment, though?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and the flag is still available, however, having a hard coded value causes issues in windows, as the value is for linux. a better solution would be detecting it with CMake, I'll look into if there's anything like findHip in CMake.

@@ -123,17 +123,12 @@ add_mlir_dialect_library(MLIRROCDLTarget
)

if(MLIR_ENABLE_ROCM_CONVERSIONS)
if (NOT ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD))
message(SEND_ERROR
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because there's a duplicate check? What happens if the AMDGPU backend isn't available?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLIR_ENABLE_ROCM_CONVERSIONS is already an alias for checking whether AMDGPU is being built or not.

@fabianmcg
Copy link
Contributor Author

@krzysz00 I'm still not convinced on using tablegen for the AMD GCN lib enum as there are no consumers for it as an attribute, and we generally keep things minimal and internal unless there's a use case, and right now there are none. Also, I would argue that the enum in question is completely internal to the compilation process and it shouldn't be exposed.
However, I switched the class to an enum class, and by using https://www.llvm.org/doxygen/BitmaskEnum_8h.html I removed many of the methods, reducing code duplication.

With respect to CMake and DEFAULT_ROCM_PATH, right now using FindHip is not robust enough, as the module still lives inside /opt/rocm and not CMake. Hence, I'm arguing that it's the responsibility of the builder/maintainer to specify DEFAULT_ROCM_PATH when building LLVM, otherwise we will encounter fragile builds and issues with windows.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These all seem like reasonable changes to me, approved

@fabianmcg fabianmcg merged commit 954cb5f into llvm:main Jun 17, 2024
7 checks passed
fabianmcg added a commit that referenced this pull request Jun 17, 2024
@fabianmcg fabianmcg deleted the pr-improve-rcodl-serialization branch June 17, 2024 16:51
fabianmcg added a commit that referenced this pull request Jun 17, 2024
Reland: #95456

This patch improves the ROCDL gpu serialization API by:
- Introducing the enum `AMDGCNLibraries` for specifying the AMD GCN
device code libraries to use during linking.
- Removing `getCommonBitcodeLibs` in favor of `AMDGCNLibraries`.
Previously `getCommonBitcodeLibs` would try to load all AMD GCN bitcode
librariesm now it will only load the requested libraries.
- Exposing the `compileToBinary` method and making it virtual, allowing
downstream users to re-use this method.
- Exposing `moduleToObjectImpl`, this method provides a prototype flow
for compiling to binary, allowing downstream users to re-use this
method.
- It also avoids constructing the control variables if no device
libraries are being used.
- Changes the style of the error messages to be composable, ie no full
stops.
- Adds an error message for when the ROCm toolkit can't be found but it
was required.
fabianmcg added a commit that referenced this pull request Jun 25, 2024
Reland: #95456

This patch improves the ROCDL gpu serialization API by:
- Introducing the enum `AMDGCNLibraries` for specifying the AMD GCN
device code libraries to use during linking.
- Removing `getCommonBitcodeLibs` in favor of `AMDGCNLibraries`.
Previously `getCommonBitcodeLibs` would try to load all AMD GCN bitcode
librariesm now it will only load the requested libraries.
- Exposing the `compileToBinary` method and making it virtual, allowing
downstream users to re-use this method.
- Exposing `moduleToObjectImpl`, this method provides a prototype flow
for compiling to binary, allowing downstream users to re-use this
method.
- It also avoids constructing the control variables if no device
libraries are being used.
- Changes the style of the error messages to be composable, ie no full
stops.
- Adds an error message for when the ROCm toolkit can't be found but it
was required.
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
Reland: llvm#95456

This patch improves the ROCDL gpu serialization API by:
- Introducing the enum `AMDGCNLibraries` for specifying the AMD GCN
device code libraries to use during linking.
- Removing `getCommonBitcodeLibs` in favor of `AMDGCNLibraries`.
Previously `getCommonBitcodeLibs` would try to load all AMD GCN bitcode
librariesm now it will only load the requested libraries.
- Exposing the `compileToBinary` method and making it virtual, allowing
downstream users to re-use this method.
- Exposing `moduleToObjectImpl`, this method provides a prototype flow
for compiling to binary, allowing downstream users to re-use this
method.
- It also avoids constructing the control variables if no device
libraries are being used.
- Changes the style of the error messages to be composable, ie no full
stops.
- Adds an error message for when the ROCm toolkit can't be found but it
was required.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants