From 25d2c05c26490209473b00a0b0c71968afa66f70 Mon Sep 17 00:00:00 2001 From: Max Dawkins Date: Wed, 25 Sep 2024 21:10:21 -0400 Subject: [PATCH] [GPU] Use shared memory for data tiled multi_mma ops Signed-off-by: Max Dawkins --- .../Dialect/GPU/TargetUtils/ConfigUtils.cpp | 2 + .../test/ROCDL/pipeline_tile_and_fuse.mlir | 42 +++++++++++++++---- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index cad64ac9fc18..b7a1ffb5789b 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -78,6 +78,8 @@ setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target, b.getI64ArrayAttr(workgroupTileSizes)); attrs.emplace_back(b.getStringAttr("reduction"), b.getI64ArrayAttr(reductionTileSizes)); + // Promote operands to use shared memory for LHS and RHS. 
+ GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1}); auto configDict = b.getDictionaryAttr(attrs); auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 1b1c87e944a6..4b443142dae7 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -756,11 +756,26 @@ hal.executable public @main { #map3 = affine_map<(d0, d1, d2) -> (d1, d2)> #map4 = affine_map<(d0, d1, d2) -> (d0, d1)> #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect > -#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 0], reduction = [0, 0, 1]}> +#translation_info = #iree_codegen.translation_info< + LLVMGPUTileAndFuse + workgroup_size = [256, 1, 1] + subgroup_size = 64, + { + gpu_pipeline_options = #iree_gpu.pipeline_options< + prefetch_shared_memory = false, + no_reduce_shared_memory_bank_conflicts = true> + } +> +#config = #iree_gpu.lowering_config<{ + workgroup = [1, 1, 0], + reduction = [0, 0, 1], + promote_operands = [0, 1] +}> hal.executable public @main { hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) { hal.executable.export public @matmul_transpose_b_mfma ordinal(0) layout(#pipeline_layout) { @@ -770,13 +785,11 @@ hal.executable public @main { } builtin.module { func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32() - attributes {translation_info = #iree_codegen.translation_info} { + attributes {translation_info = #translation_info} { %c0 = arith.constant 0 : index - %c65536 = arith.constant 65536 : index - %c131072 = arith.constant 131072 : index %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) 
flags("ReadOnly|Indirect") : !flow.dispatch.tensor> - %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c65536) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> - %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c131072) flags(Indirect) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1x8x4x16x4xf32> %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1x4x2x4x16x4xf32> %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x1x8x4x2x4x16x4xf32> @@ -805,9 +818,20 @@ hal.executable public @main { } // CHECK-LABEL: func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32() -// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read {{.+}}, vector<8x1x1x4xf32> -// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read {{.+}}, vector<2x1x1x4xf32> -// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read {{.+}}, vector<8x1x2x1x1x4xf32> +// CHECK-DAG: %[[BINDING_A:.+]] = hal.interface.binding.subspan {{.*}} binding(0) +// CHECK-DAG: %[[BINDING_B:.+]] = hal.interface.binding.subspan {{.*}} binding(1) +// CHECK-DAG: %[[BINDING_C:.+]] = hal.interface.binding.subspan {{.*}} binding(2) +// CHECK-DAG: %[[A_ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x16x4xf32, #gpu.address_space> +// CHECK-DAG: %[[B_ALLOC:.+]] = memref.alloc() : 
memref<1x1x4x2x4x16x4xf32, #gpu.address_space<workgroup>> +// CHECK: gpu.barrier +// CHECK-DAG: %[[A_GLOBAL_LOAD:.+]] = vector.transfer_read %[[BINDING_A]]{{.*}} vector<4xf32> +// CHECK-DAG: %[[B_GLOBAL_LOAD:.+]] = vector.transfer_read %[[BINDING_B]]{{.*}} vector<4xf32> +// CHECK-DAG: vector.transfer_write %[[A_GLOBAL_LOAD]], %[[A_ALLOC]] +// CHECK-DAG: vector.transfer_write %[[B_GLOBAL_LOAD]], %[[B_ALLOC]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x4xf32> +// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read %[[B_ALLOC]]{{.*}} vector<2x1x1x4xf32> +// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x1x2x1x1x4xf32> // CHECK-DAG: %[[C_SLICE00:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [0, 0, 0, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32> // CHECK-DAG: %[[C_SLICE01:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [0, 0, 1, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32> // CHECK-DAG: %[[C_SLICE70:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [7, 0, 0, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32>