Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix #66455

Merged
merged 1 commit into from
Sep 19, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 15, 2023

These do not produce extension-specific ops and are handled via common
patterns for both the KHR and the NV coop matrix extension.

Also improve match failure reporting and error handling in type conversion.

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-core

Changes These do not produce extension-specific ops and are handled via common patterns for both the KHR and the NV coop matrix extension.

Also improve match failure reporting and error handling in type conversion.

Patch is 36.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66455.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h (+10)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+5-1)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td (+25)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp (+16-4)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp (+268-106)
  • (added) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir (+174)
  • (modified) mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-nv-coop-matrix.mlir (+2-1)
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index 6c4643da1884900..c258513ed4878ea 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -30,11 +30,21 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the KHR Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
 /// using the NV Cooperative Matrix extension.
 void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixType
+convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
+
 /// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
 /// `type`.
 spirv::CooperativeMatrixNVType
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 3218760931b8cb0..5e0f976b18f7da5 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   let options = [
     Option<"use64bitIndex", "use-64bit-index",
            "bool", /*default=*/"false",
-           "Use 64-bit integers to convert index types">
+           "Use 64-bit integers to convert index types">,
+    Option<"useCoopMatrixNV", "use-coop-matrix-nv",
+           "bool", /*default=*/"false",
+           "Use the NV cooperative matrix extension insted of the KHR extension"
+           " to lower GPU WMMA ops">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index b5ea0774f589d16..34c76c5e9382302 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Type":$result, "Value":$pointer,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, result, pointer, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
   );
 
   let results = (outs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$pointer, "Value":$object,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, pointer, object, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
+      build($_builder, $_state, a, b, c,
+            spirv::CooperativeMatrixOperandsKHRAttr{});
+    }]>
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index d0ce58597f980d4..5b05c45bf602509 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
     SPIRVConversionOptions options;
     options.use64bitIndex = this->use64bitIndex;
     SPIRVTypeConverter typeConverter(targetAttr, options);
-    typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-      return convertMMAToSPIRVCoopMatrixNVType(type);
+
+    typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
+                                    gpu::MMAMatrixType type) -> Type {
+      if (useNV)
+        return convertMMAToSPIRVCoopMatrixNVType(type);
+
+      return convertMMAToSPIRVCoopMatrixType(type);
     });
+
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
-                                                         patterns);
+    if (this->useCoopMatrixNV) {
+      populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                           patterns);
+    } else {
+      populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+                                                            patterns);
+    }
+
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index bf3fff027fe384a..eb7fcb63d920d8f 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -18,12 +18,22 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
 
-namespace mlir::nv {
-namespace {
+#include <cassert>
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
@@ -31,9 +41,11 @@ namespace {
 ///
 /// See SPV_NV_cooperative_matrix for supported elementwise ops.
 static bool createElementwiseOp(ConversionPatternRewriter &builder,
-                                gpu::SubgroupMmaElementwiseOp op,
-                                spirv::CooperativeMatrixNVType coopType,
+                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                 ValueRange operands) {
+  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      coopType)));
+
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +83,223 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+  assert(!operands.empty());
+  if (!llvm::all_equal(
+          llvm::map_range(operands, [](Value v) { return v.getType(); })))
+    return false;
+
+  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      operands.front().getType());
+}
+
+namespace {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Value cst = adaptor.getOperands().front();
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    return success(
+        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// matrix times scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return failure();
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = op.getOperands().front();
+    Value rhs = op.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return rewriter.notifyMatchFailure(op, "no splat operand");
+
+    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+    Value scalar;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc) {
+      return rewriter.notifyMatchFailure(op,
+                                         "splat is not a composite construct");
+    }
+
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        op, coopType, ValueRange{matrix, scalar});
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace khr {
+namespace {
+
+/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
+    MemRefType memrefType = op.getSrcMemref().getType();
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    auto coopType =
+        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
+        op, coopType, bufferPtr, strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
+        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
+/// dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
+        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
+        adaptor.getOpC());
+    return success();
+  }
+};
+
+} // namespace
+} // namespace khr
+
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace nv {
+namespace {
+
 /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
 /// dialect.
 struct WmmaLoadOpToSPIRVLowering final
@@ -152,102 +381,9 @@ struct WmmaMmaOpToSPIRVLowering final
   }
 };
 
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        subgroupMmaConstantMatrixOp, coopType, cst);
-    return success();
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
-                                       adaptor.getOperands()));
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gp...

@kuhar
Copy link
Member Author

kuhar commented Sep 15, 2023

Note: This PR is stacked on top of #66311, so please review the second commit only. Once the parent PR lands, I'll rebase this one.

These do not produce extension-specific ops and are handled via common
patterns for both the KHR and the NV coop matrix extension.

Also improve match failure reporting and error handling in type
conversion.
@kuhar
Copy link
Member Author

kuhar commented Sep 15, 2023

#66311 landed, rebase.

@kuhar
Copy link
Member Author

kuhar commented Sep 19, 2023

Ping @qedawkins @raikonenfnu

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants