[mlir][spirv][gpu] Convert remaining wmma ops to KHR coop matrix
These do not produce extension-specific ops and are handled via common
patterns for both the KHR and the NV coop matrix extensions.

Also improve match failure reporting and error handling in type
conversion.
kuhar committed Sep 15, 2023
1 parent ed4daea commit 992b575
Showing 2 changed files with 224 additions and 103 deletions.
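For illustration, here is a minimal sketch (an editor's example, not code from this commit) of the shape the shared patterns take: each pattern asks the type converter for the target cooperative matrix type, so a single pattern serves both the KHR and the NV lowering path, and a failed type conversion surfaces as a match-failure diagnostic instead of an assertion.

```cpp
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Extension-agnostic sketch: the configured SPIRVTypeConverter decides
// whether the result is a KHR or an NV cooperative matrix type.
struct ConstantMatrixLoweringSketch final
    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Report a descriptive match failure instead of crashing on a null type.
    Type coopType = getTypeConverter()->convertType(op.getType());
    if (!coopType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        op, coopType, adaptor.getOperands().front());
    return success();
  }
};
```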
231 changes: 129 additions & 102 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,11 +24,17 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"

#include <cassert>

namespace mlir {
//===----------------------------------------------------------------------===//
// Patterns and helpers used by both the KHR and the NV lowering paths.
//===----------------------------------------------------------------------===//

/// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
/// when the elementwise op directly supports cooperative matrix types.
/// Returns false if it cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
return false;
}

static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
assert(!operands.empty());
if (!llvm::all_equal(
llvm::map_range(operands, [](Value v) { return v.getType(); })))
return false;

return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
operands.front().getType());
}

namespace {
/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
/// matrix ops.
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
Value cst = adaptor.getOperands().front();
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
return rewriter.notifyMatchFailure(op,
"not all operands are coop matrices");
}

auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");

return success(
createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getOperands().size() != 2)
return failure();

// All operands should be of cooperative matrix types.
if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
return rewriter.notifyMatchFailure(op,
"not all operands are coop matrices");
}

if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
return failure();

// Use the original operands to check whether one of the operands is a splat
// scalar value.
Value lhs = op.getOperands().front();
Value rhs = op.getOperands().back();
Value splat = nullptr;
Value matrix = nullptr;
if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
splat = adaptor.getOperands().front();
matrix = adaptor.getOperands().back();
} else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
matrix = adaptor.getOperands().front();
splat = adaptor.getOperands().back();
}
if (!splat || !matrix)
return rewriter.notifyMatchFailure(op, "no splat operand");

// Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
Value scalar;
auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
if (!cc) {
return rewriter.notifyMatchFailure(op,
"splat is not a composite construct");
}

assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();

auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
op, coopType, ValueRange{matrix, scalar});
return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// SPV_KHR_cooperative_matrix
//===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
}
};

/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
/// ops.
struct WmmaConstantOpToSPIRVLowering final
: OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value cst = adaptor.getOperands()[0];
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
subgroupMmaConstantMatrixOp, coopType, cst);
return success();
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// the default case.
struct WmmaElementwiseOpToSPIRVDefaultLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}
auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
adaptor.getOperands()));
}
};

/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
/// matrix times scalar case.
struct WmmaElementwiseOpToSPIRVScalarMulLowering final
: OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (adaptor.getOperands().size() != 2)
return failure();
// All operands should be of cooperative matrix types.
for (Value operand : adaptor.getOperands()) {
if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
return failure();
}

if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
return failure();

// Use the original operands to check whether one of the operands is a splat
// scalar value.
Value lhs = elementwiseOp.getOperands().front();
Value rhs = elementwiseOp.getOperands().back();
Value splat = nullptr;
Value matrix = nullptr;
if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
splat = adaptor.getOperands().front();
matrix = adaptor.getOperands().back();
} else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
matrix = adaptor.getOperands().front();
splat = adaptor.getOperands().back();
}
if (!splat || !matrix)
return failure();

// Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
Value scalar = nullptr;
auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
if (!cc)
return failure();
assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();

auto coopType = convertMMAToSPIRVCoopMatrixNVType(
cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
elementwiseOp, coopType, ValueRange{matrix, scalar});
return success();
}
};

} // namespace
} // namespace nv
} // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
khr::WmmaStoreOpToSPIRVLowering>(converter, context);
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
}

void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
using namespace mlir;
MLIRContext *context = patterns.getContext();
patterns
.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
// Give the following patterns higher benefit to prevail over the default one.
patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
context,
/*benefit=*/2);
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
/*benefit=*/2);
}
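Usage note (a hedged sketch, not part of this commit): a conversion pass would typically derive the type converter and conversion target from the module's SPIR-V target environment, call one of the populate functions above, and then run dialect conversion. The driver name below is hypothetical, and the !gpu.mma_matrix type conversion is assumed to be registered separately, as the upstream GPU-to-SPIR-V pass does.

```cpp
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: lower gpu.subgroup_mma_* ops in `module` to
// SPV_KHR_cooperative_matrix ops.
static LogicalResult lowerWmmaToKhrCoopMatrix(ModuleOp module) {
  spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(module);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetEnv);

  SPIRVTypeConverter converter(targetEnv);
  // Assumed: the !gpu.mma_matrix -> !spirv.coopmatrix type conversion is
  // registered on `converter` at this point.

  RewritePatternSet patterns(module.getContext());
  populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(converter, patterns);

  return applyPartialConversion(module, *target, std::move(patterns));
}
```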
@@ -69,12 +69,106 @@ module attributes {
-> !gpu.mma_matrix<16x16xf16, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_constant_op
gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
%cst = arith.constant 1.0 : f16
// CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
// CHECK-SAME: (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
%B: !gpu.mma_matrix<16x16xf16, "COp">,
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
// CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%D = gpu.subgroup_mma_elementwise negatef %C :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%E = gpu.subgroup_mma_elementwise divf %D, %A :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
%F = gpu.subgroup_mma_elementwise extf %E :
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">

%i = arith.constant 0 : index
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
// CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
%A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 0 : index

%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
%C = gpu.subgroup_mma_elementwise mulf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>

// CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}}, <RowMajor>
%D = gpu.subgroup_mma_elementwise mulf %B, %C :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}

// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
// CHECK-SAME: %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
// CHECK-SAME: %[[S:.+]]: f16
gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
%A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
%i = arith.constant 0 : index

// CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
%C = gpu.subgroup_mma_elementwise addf %A, %B :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">

// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}}, <RowMajor>
gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
// CHECK: spirv.Return
gpu.return
}
}
}
