diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp index 7f9141c7d3877..4a8a1fa242f14 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp @@ -34,6 +34,15 @@ void transform_dialect::ApplyDropMultiMmaOpUnitDims::populatePatterns( IREE::GPU::populateIREEGPUDropUnitDimsPatterns(patterns); } +//===---------------------------------------------------------------------===// +// ApplyLowerMultiMmaOp +//===---------------------------------------------------------------------===// + +void transform_dialect::ApplyLowerMultiMmaOp::populatePatterns( + RewritePatternSet &patterns) { + IREE::GPU::populateIREEGPULowerMultiMmaPatterns(patterns); +} + //===---------------------------------------------------------------------===// // ApplyLowerValueBarrierOp //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td index 1a33675cc71f3..51c10d4abfe0d 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td @@ -27,6 +27,19 @@ def ApplyDropMultiMmaOpUnitDims : Op, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Populate patterns to lower multi_mma ops to the intrinsic specified by + the |kind| attribute.
+ }]; + + let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; + let assemblyFormat = "attr-dict"; +} + def ApplyLowerValueBarrierOp : Op, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel index 4b644fa4e6152..bb10414c9bcde 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel @@ -19,6 +19,7 @@ iree_lit_test_suite( srcs = enforce_glob( [ "drop_multi_mma_unit_dims.mlir", + "lower_multi_mma.mlir", "lower_vector_barrier.mlir", "transform_fuse_forall.mlir", "vectorize_multi_mma.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt index 8e3ec6dc50163..9929856b42b95 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "drop_multi_mma_unit_dims.mlir" + "lower_multi_mma.mlir" "lower_vector_barrier.mlir" "transform_fuse_forall.mlir" "unroll_multi_mma.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_multi_mma.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_multi_mma.mlir new file mode 100644 index 0000000000000..60255d47bd706 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_multi_mma.mlir @@ -0,0 +1,138 @@ +// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule --split-input-file | FileCheck %s + +#contraction_accesses = [ + affine_map<() -> ()>, + affine_map<() -> ()>, + affine_map<() -> ()> +] +func.func 
@lower_multi_mma_mfma_16x16x16(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vector<4xf32>) -> vector<4xf32> { + %0 = iree_gpu.multi_mma %lhs, %rhs, %acc { + indexing_maps = #contraction_accesses, + iterator_types = [], + kind = #iree_gpu.mma_layout + } : vector<4xf16>, vector<4xf16> into vector<4xf32> + return %0 : vector<4xf32> +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.iree.lower_multi_mma + } : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @lower_multi_mma_mfma_16x16x16 +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<4xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16> +// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<4xf32> +// CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] +// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + +// ----- + +#contraction_accesses = [ + affine_map<() -> ()>, + affine_map<() -> ()>, + affine_map<() -> ()> +] +func.func @lower_multi_mma_mfma_32x32x8(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vector<16xf32>) -> vector<16xf32> { + %0 = iree_gpu.multi_mma %lhs, %rhs, %acc { + indexing_maps = #contraction_accesses, + iterator_types = [], + kind = #iree_gpu.mma_layout + } : vector<4xf16>, vector<4xf16> into vector<16xf32> + return %0 : vector<16xf32> +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.iree.lower_multi_mma + } : !transform.any_op + 
transform.yield + } +} + +// CHECK-LABEL: func @lower_multi_mma_mfma_32x32x8 +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<4xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16> +// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<16xf32> +// CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] +// CHECK-SAME: blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32 +// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32> + +// ----- + +#contraction_accesses = [ + affine_map<() -> ()>, + affine_map<() -> ()>, + affine_map<() -> ()> +] +func.func @lower_multi_mma_wmma_16x16x16(%lhs: vector<16xf16>, %rhs: vector<16xf16>, %acc: vector<8xf32>) -> vector<8xf32> { + %0 = iree_gpu.multi_mma %lhs, %rhs, %acc { + indexing_maps = #contraction_accesses, + iterator_types = [], + kind = #iree_gpu.mma_layout + } : vector<16xf16>, vector<16xf16> into vector<8xf32> + return %0 : vector<8xf32> +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.iree.lower_multi_mma + } : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @lower_multi_mma_wmma_16x16x16 +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<16xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<16xf16> +// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<8xf32> +// CHECK: amdgpu.wmma %[[LHS]] * %[[RHS]] + %[[ACC]] +// CHECK-SAME: : vector<16xf16>, vector<16xf16>, vector<8xf32> + +// ----- + +#contraction_accesses = [ + affine_map<() -> ()>, + affine_map<() -> ()>, + affine_map<() -> ()> +] +func.func @lower_multi_mma_mfma_shape_cast_16x16x16(%lhs: vector<1x4xf16>, %rhs: vector<4x1xf16>, %acc: vector<4x1xf32>) -> vector<4x1xf32> { + %0 = iree_gpu.multi_mma %lhs, %rhs, %acc { + indexing_maps = #contraction_accesses, + iterator_types = 
[], + kind = #iree_gpu.mma_layout + } : vector<1x4xf16>, vector<4x1xf16> into vector<4x1xf32> + return %0 : vector<4x1xf32> +} + +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func { + transform.apply_patterns.iree.lower_multi_mma + } : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @lower_multi_mma_mfma_shape_cast_16x16x16 +// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<1x4xf16> +// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4x1xf16> +// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<4x1xf32> +// CHECK-DAG: %[[LHSCAST:.+]] = vector.shape_cast %[[LHS]] : vector<1x4xf16> to vector<4xf16> +// CHECK-DAG: %[[RHSCAST:.+]] = vector.shape_cast %[[RHS]] : vector<4x1xf16> to vector<4xf16> +// CHECK-DAG: %[[ACCCAST:.+]] = vector.shape_cast %[[ACC]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHSCAST]] * %[[RHSCAST]] + %[[ACCCAST]] +// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32 +// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> +// CHECK: vector.shape_cast %[[MMA]] : vector<4xf32> to vector<4x1xf32> diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index 60596a06cdaf0..105b19d5458ea 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -164,6 +164,59 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter, return success(); } +//===----------------------------------------------------------------------===// +// MultiMmaOp Lowering +//===----------------------------------------------------------------------===// + +namespace 
{ +struct LowerMultiMmaPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp mmaOp, + PatternRewriter &rewriter) const override { + if (mmaOp.hasTensorSemantics()) { + return rewriter.notifyMatchFailure( + mmaOp, "lowering to concrete op requires vector semantics"); + } + SmallVector bounds; + mmaOp.getIterationBounds(bounds); + if (!bounds.empty()) { + return rewriter.notifyMatchFailure(mmaOp, + "must be a single mma operation"); + } + + auto [lhsVectorType, rhsVectorType, accVectorType] = + mmaOp.getKind().getABCVectorTypes(); + + Value aCast = mmaOp.getLhs(); + Value bCast = mmaOp.getRhs(); + Value cCast = mmaOp.getAcc(); + if (aCast.getType() != lhsVectorType) { + aCast = rewriter.create(mmaOp.getLoc(), + lhsVectorType, aCast); + } + if (bCast.getType() != rhsVectorType) { + bCast = rewriter.create(mmaOp.getLoc(), + rhsVectorType, bCast); + } + if (cCast.getType() != accVectorType) { + cCast = rewriter.create(mmaOp.getLoc(), + accVectorType, cCast); + } + + FailureOr concreteMmaOp = mmaOp.getKind().buildMmaOperation( + rewriter, mmaOp.getLoc(), cCast.getType(), aCast, bCast, cCast); + assert(succeeded(concreteMmaOp) && "Failed to create mma op"); + rewriter.replaceOpWithNewOp( + mmaOp, mmaOp.getAcc().getType(), *concreteMmaOp); + return success(); + } +}; +} // namespace + +void populateIREEGPULowerMultiMmaPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + //===----------------------------------------------------------------------===// // MultiMmaOp Unit Dim Folding //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h index c706ee1a2ebeb..3b6b0b1cc8f98 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h +++ 
b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h @@ -35,6 +35,7 @@ LogicalResult fuseForallIntoSlice(RewriterBase &rewriter, tensor::ExtractSliceOp slice); void populateIREEGPUDropUnitDimsPatterns(RewritePatternSet &patterns); +void populateIREEGPULowerMultiMmaPatterns(RewritePatternSet &patterns); void populateIREEGPULowerValueBarrierPatterns(RewritePatternSet &patterns); void populateIREEGPUVectorUnrollPatterns( RewritePatternSet &patterns, const vector::UnrollVectorOptions &options);