diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
index d73cd5686d66e9..eb7fcb63d920d8 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
@@ -24,11 +24,17 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
 #include <cassert>
 
 namespace mlir {
+//===----------------------------------------------------------------------===//
+// Patterns and helpers used by both the KHR and the NV lowering paths.
+//===----------------------------------------------------------------------===//
+
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op directly supports with cooperative matrix type.
 /// Returns false if cannot.
@@ -77,6 +83,119 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+/// Returns true if all the operands are of the same cooperative matrix type.
+static bool allOperandsHaveSameCoopMatrixType(ValueRange operands) {
+  assert(!operands.empty());
+  if (!llvm::all_equal(
+          llvm::map_range(operands, [](Value v) { return v.getType(); })))
+    return false;
+
+  return isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      operands.front().getType());
+}
+
+namespace {
+/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaConstantOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 1);
+    Value cst = adaptor.getOperands().front();
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// the default case.
+struct WmmaElementwiseOpToSPIRVDefaultLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    return success(
+        createElementwiseOp(rewriter, op, coopType, adaptor.getOperands()));
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
+/// matrix times scalar case.
+struct WmmaElementwiseOpToSPIRVScalarMulLowering final
+    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (adaptor.getOperands().size() != 2)
+      return failure();
+
+    // All operands should be of cooperative matrix types.
+    if (!allOperandsHaveSameCoopMatrixType(adaptor.getOperands())) {
+      return rewriter.notifyMatchFailure(op,
+                                         "not all operands are coop matrices");
+    }
+
+    if (op.getOpType() != gpu::MMAElementwiseOp::MULF)
+      return failure();
+
+    // Use the original operands to check whether one of the operands is a splat
+    // scalar value.
+    Value lhs = op.getOperands().front();
+    Value rhs = op.getOperands().back();
+    Value splat = nullptr;
+    Value matrix = nullptr;
+    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      splat = adaptor.getOperands().front();
+      matrix = adaptor.getOperands().back();
+    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
+      matrix = adaptor.getOperands().front();
+      splat = adaptor.getOperands().back();
+    }
+    if (!splat || !matrix)
+      return rewriter.notifyMatchFailure(op, "no splat operand");
+
+    // Constant MMA matrix ops are converted to `spirv.CompositeConstruct` ops.
+    Value scalar;
+    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
+    if (!cc) {
+      return rewriter.notifyMatchFailure(op,
+                                         "splat is not a composite construct");
+    }
+
+    assert(cc.getConstituents().size() == 1);
+    scalar = cc.getConstituents().front();
+
+    auto coopType = getTypeConverter()->convertType(op.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
+        op, coopType, ValueRange{matrix, scalar});
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // SPV_KHR_cooperative_matrix
 //===----------------------------------------------------------------------===//
@@ -262,100 +381,6 @@ struct WmmaMmaOpToSPIRVLowering final
   }
 };
 
-/// Converts GPU MMA ConstantMatrixOp to constant SPIR-V NV cooperative matrix
-/// ops.
-struct WmmaConstantOpToSPIRVLowering final
-    : OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value cst = adaptor.getOperands()[0];
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(subgroupMmaConstantMatrixOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        subgroupMmaConstantMatrixOp, coopType, cst);
-    return success();
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// the default case.
-struct WmmaElementwiseOpToSPIRVDefaultLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    return success(createElementwiseOp(rewriter, elementwiseOp, coopType,
-                                       adaptor.getOperands()));
-  }
-};
-
-/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
-/// matrix times scalar case.
-struct WmmaElementwiseOpToSPIRVScalarMulLowering final
-    : OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(gpu::SubgroupMmaElementwiseOp elementwiseOp,
-                  OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (adaptor.getOperands().size() != 2)
-      return failure();
-    // All operands should be of cooperative matrix types.
-    for (Value operand : adaptor.getOperands()) {
-      if (!isa<spirv::CooperativeMatrixNVType>(operand.getType()))
-        return failure();
-    }
-
-    if (elementwiseOp.getOpType() != gpu::MMAElementwiseOp::MULF)
-      return failure();
-
-    // Use the original operands to check whether one of the operands is a splat
-    // scalar value.
-    Value lhs = elementwiseOp.getOperands().front();
-    Value rhs = elementwiseOp.getOperands().back();
-    Value splat = nullptr;
-    Value matrix = nullptr;
-    if (lhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      splat = adaptor.getOperands().front();
-      matrix = adaptor.getOperands().back();
-    } else if (rhs.getDefiningOp<gpu::SubgroupMmaConstantMatrixOp>()) {
-      matrix = adaptor.getOperands().front();
-      splat = adaptor.getOperands().back();
-    }
-    if (!splat || !matrix)
-      return failure();
-
-    // Constant MMA matrix ops are converted to spirv.CompositeConstruct ops.
-    Value scalar = nullptr;
-    auto cc = splat.getDefiningOp<spirv::CompositeConstructOp>();
-    if (!cc)
-      return failure();
-    assert(cc.getConstituents().size() == 1);
-    scalar = cc.getConstituents().front();
-
-    auto coopType = convertMMAToSPIRVCoopMatrixNVType(
-        cast<gpu::MMAMatrixType>(elementwiseOp.getType()));
-    rewriter.replaceOpWithNewOp<spirv::MatrixTimesScalarOp>(
-        elementwiseOp, coopType, ValueRange{matrix, scalar});
-    return success();
-  }
-};
-
 } // namespace
 } // namespace nv
 } // namespace mlir
@@ -389,19 +414,21 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
-               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+               khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+
+  // Give the following patterns higher benefit to prevail over the default one.
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
 
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
   MLIRContext *context = patterns.getContext();
-  patterns
-      .add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
-           nv::WmmaStoreOpToSPIRVLowering, nv::WmmaConstantOpToSPIRVLowering,
-           nv::WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
+  patterns.add<nv::WmmaLoadOpToSPIRVLowering, nv::WmmaMmaOpToSPIRVLowering,
+               nv::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
-  patterns.add<nv::WmmaElementwiseOpToSPIRVScalarMulLowering>(converter,
-                                                              context,
-                                                              /*benefit=*/2);
+  patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
+                                                          /*benefit=*/2);
 }
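Before the test changes, a brief aside on how these populate functions are consumed. In-tree, the GPU-to-SPIR-V conversion pass does this wiring; the sketch below is a minimal, hypothetical illustration for downstream users. The wrapper name `lowerWmmaToKhrCoopMatrix` is invented, while `SPIRVTypeConverter`, `SPIRVConversionTarget::get`, `applyPartialConversion`, and `populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns` are existing MLIR APIs.

// Minimal usage sketch (not part of the patch): drive the KHR cooperative
// matrix patterns through the standard dialect conversion driver. The helper
// name `lowerWmmaToKhrCoopMatrix` is hypothetical.
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult lowerWmmaToKhrCoopMatrix(ModuleOp module,
                                              spirv::TargetEnvAttr targetAttr) {
  MLIRContext *context = module.getContext();

  // Both the conversion target and the type converter are derived from the
  // SPIR-V target environment, which must advertise the
  // SPV_KHR_cooperative_matrix extension for the converted IR to verify.
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);
  SPIRVTypeConverter typeConverter(targetAttr);

  RewritePatternSet patterns(context);
  populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
                                                        patterns);
  // The scalar-multiplication pattern is registered with benefit 2, so the
  // driver tries it before the default elementwise lowering.
  return applyPartialConversion(module, *target, std::move(patterns));
}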
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
index 0818791b98471d..f129cc8ce84ec3 100644
--- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir
@@ -69,12 +69,106 @@ module attributes {
         -> !gpu.mma_matrix<16x16xf16, "COp">
 
       %i = arith.constant 0 : index
-      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}},
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAD]], %{{.+}},
       gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
         !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
       // CHECK: spirv.Return
       gpu.return
     }
 
+    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+    gpu.func @gpu_wmma_constant_op(%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: %[[CST1F:.+]] = spirv.Constant 1.000000e+00 : f16
+      %cst = arith.constant 1.0 : f16
+      // CHECK: %[[MAT:.+]] = spirv.CompositeConstruct %[[CST1F]] :
+      // CHECK-SAME:   (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[MAT]], %{{.+}},
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_elementwise_op_default(%A: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %B: !gpu.mma_matrix<16x16xf16, "COp">,
+                                              %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_elementwise negatef %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %E = gpu.subgroup_mma_elementwise divf %D, %A :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
+      // CHECK-SAME:   !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+      %F = gpu.subgroup_mma_elementwise extf %E :
+        (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}},
+      gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_times_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_times_scalar(
+      %A: !gpu.mma_matrix<16x16xf16, "COp">, %scalar: f16,
+      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.MatrixTimesScalar %[[A]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}},
+      %C = gpu.subgroup_mma_elementwise mulf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK: %[[D:.+]] = spirv.MatrixTimesScalar %[[C]], %[[S]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>, f16
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[D]], %{{.+}},
+      %D = gpu.subgroup_mma_elementwise mulf %B, %C :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_matrix_plus_scalar
+    // CHECK-SAME:    %[[A:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    // CHECK-SAME:    %[[S:.+]]: f16
+    gpu.func @gpu_wmma_elementwise_op_matrix_plus_scalar(
+      %A : !gpu.mma_matrix<16x16xf16, "COp">, %scalar : f16,
+      %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 0 : index
+
+      // CHECK: %[[SM:.+]] = spirv.CompositeConstruct %[[S]] : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %B = gpu.subgroup_mma_constant_matrix %scalar : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: %[[C:.+]] = spirv.FAdd %[[A]], %[[SM]] : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B :
+        (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %[[C]], %{{.+}},
+      gpu.subgroup_mma_store_matrix %C, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
   }
 }
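One implementation detail worth calling out: the new `allOperandsHaveSameCoopMatrixType` helper combines `llvm::map_range` with `llvm::all_equal` to compare operand types without materializing a temporary container. Below is a standalone sketch of that idiom, assuming only LLVM's ADT headers on the include path; the sample data is invented for illustration and has nothing to do with the patch.

// Standalone illustration of the llvm::map_range + llvm::all_equal idiom:
// map_range lazily projects each element through a callable, and all_equal
// checks that every projected value compares equal (vacuously true for empty
// and single-element ranges).
#include "llvm/ADT/STLExtras.h"

#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<std::string, int>> operands = {
      {"a", 16}, {"b", 16}, {"c", 16}};

  // Project each pair onto its width, then check all projections are equal.
  bool sameWidth = llvm::all_equal(
      llvm::map_range(operands, [](const auto &p) { return p.second; }));

  // Adding an element with a different width makes the check fail.
  operands.push_back({"d", 32});
  bool stillSameWidth = llvm::all_equal(
      llvm::map_range(operands, [](const auto &p) { return p.second; }));

  return (sameWidth && !stillSameWidth) ? 0 : 1;
}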