[GPU] Use shared memory for data tiled multi_mma ops
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Max191 committed Oct 1, 2024
1 parent 20a7638 commit 25d2c05
Showing 2 changed files with 35 additions and 9 deletions.
@@ -78,6 +78,8 @@ setDataTiledMultiMmaLoweringConfig(IREE::GPU::TargetAttr target,
                      b.getI64ArrayAttr(workgroupTileSizes));
   attrs.emplace_back(b.getStringAttr("reduction"),
                      b.getI64ArrayAttr(reductionTileSizes));
+  // Promote operands to use shared memory for LHS and RHS.
+  GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs, {0, 1});
   auto configDict = b.getDictionaryAttr(attrs);
   auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

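For readers skimming the diff: the two added lines extend the config dictionary this function already builds. Below is a minimal, hypothetical sketch of the same assembly as a standalone helper; it uses only the calls visible in the hunk, and the helper name and signature are illustrative rather than IREE API:

```cpp
// Hypothetical helper mirroring the hunk above: build an
// #iree_gpu.lowering_config carrying workgroup/reduction tile sizes plus
// the new promoted-operand list. Operand indices 0 and 1 are the multi_mma
// LHS and RHS; listing them requests that those operands be staged in
// workgroup shared memory during lowering.
static IREE::GPU::LoweringConfigAttr
buildDataTiledMultiMmaConfig(MLIRContext *context, OpBuilder &b,
                             ArrayRef<int64_t> workgroupTileSizes,
                             ArrayRef<int64_t> reductionTileSizes) {
  SmallVector<NamedAttribute> attrs;
  attrs.emplace_back(b.getStringAttr("workgroup"),
                     b.getI64ArrayAttr(workgroupTileSizes));
  attrs.emplace_back(b.getStringAttr("reduction"),
                     b.getI64ArrayAttr(reductionTileSizes));
  // Prints as `promote_operands = [0, 1]` in the config; compare the test
  // expectations below.
  IREE::GPU::LoweringConfigAttr::setPromotedOperandList(context, attrs,
                                                        {0, 1});
  DictionaryAttr configDict = b.getDictionaryAttr(attrs);
  return IREE::GPU::LoweringConfigAttr::get(context, configDict);
}
```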
@@ -756,11 +756,26 @@ hal.executable public @main {
 #map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
 #map4 = affine_map<(d0, d1, d2) -> (d0, d1)>
 #pipeline_layout = #hal.pipeline.layout<bindings = [
   #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+  #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
   #hal.pipeline.binding<storage_buffer, Indirect>],
   flags = Indirect
 >
-#config = #iree_gpu.lowering_config<{workgroup = [1, 1, 0], reduction = [0, 0, 1]}>
+#translation_info = #iree_codegen.translation_info<
+  LLVMGPUTileAndFuse
+  workgroup_size = [256, 1, 1]
+  subgroup_size = 64,
+  {
+    gpu_pipeline_options = #iree_gpu.pipeline_options<
+      prefetch_shared_memory = false,
+      no_reduce_shared_memory_bank_conflicts = true>
+  }
+>
+#config = #iree_gpu.lowering_config<{
+  workgroup = [1, 1, 0],
+  reduction = [0, 0, 1],
+  promote_operands = [0, 1]
+}>
 hal.executable public @main {
 hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
 hal.executable.export public @matmul_transpose_b_mfma ordinal(0) layout(#pipeline_layout) {
@@ -770,13 +785,11 @@ hal.executable public @main {
 }
 builtin.module {
 func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32()
-    attributes {translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [256, 1, 1] subgroup_size = 64>} {
+    attributes {translation_info = #translation_info} {
 %c0 = arith.constant 0 : index
-%c65536 = arith.constant 65536 : index
-%c131072 = arith.constant 131072 : index
 %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x8x4x16x4xf32>>
-%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c65536) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4x2x4x16x4xf32>>
-%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c131072) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>>
+%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4x2x4x16x4xf32>>
+%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>>
 %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x8x4x16x4xf32>> -> tensor<1x1x8x4x16x4xf32>
 %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4x2x4x16x4xf32>> -> tensor<1x1x4x2x4x16x4xf32>
 %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [1, 1, 8, 4, 2, 4, 16, 4], strides = [1, 1, 1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readwrite:tensor<1x1x8x4x2x4x16x4xf32>> -> tensor<1x1x8x4x2x4x16x4xf32>
@@ -805,9 +818,20 @@ hal.executable public @main {
 }

 // CHECK-LABEL: func.func @multi_mma_data_tiled_unrolled_MFMA_F32_16x16x4_F32()
-// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read {{.+}}, vector<8x1x1x4xf32>
-// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read {{.+}}, vector<2x1x1x4xf32>
-// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read {{.+}}, vector<8x1x2x1x1x4xf32>
+// CHECK-DAG: %[[BINDING_A:.+]] = hal.interface.binding.subspan {{.*}} binding(0)
+// CHECK-DAG: %[[BINDING_B:.+]] = hal.interface.binding.subspan {{.*}} binding(1)
+// CHECK-DAG: %[[BINDING_C:.+]] = hal.interface.binding.subspan {{.*}} binding(2)
+// CHECK-DAG: %[[A_ALLOC:.+]] = memref.alloc() : memref<1x1x8x4x16x4xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG: %[[B_ALLOC:.+]] = memref.alloc() : memref<1x1x4x2x4x16x4xf32, #gpu.address_space<workgroup>>
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[A_GLOBAL_LOAD:.+]] = vector.transfer_read %[[BINDING_A]]{{.*}} vector<4xf32>
+// CHECK-DAG: %[[B_GLOBAL_LOAD:.+]] = vector.transfer_read %[[BINDING_B]]{{.*}} vector<4xf32>
+// CHECK-DAG: vector.transfer_write %[[A_GLOBAL_LOAD]], %[[A_ALLOC]]
+// CHECK-DAG: vector.transfer_write %[[B_GLOBAL_LOAD]], %[[B_ALLOC]]
+// CHECK: gpu.barrier
+// CHECK-DAG: %[[A_READ:.+]] = vector.transfer_read %[[A_ALLOC]]{{.*}} vector<8x1x1x4xf32>
+// CHECK-DAG: %[[B_READ:.+]] = vector.transfer_read %[[B_ALLOC]]{{.*}} vector<2x1x1x4xf32>
+// CHECK-DAG: %[[C_READ:.+]] = vector.transfer_read %[[BINDING_C]]{{.*}} vector<8x1x2x1x1x4xf32>
 // CHECK-DAG: %[[C_SLICE00:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [0, 0, 0, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
 // CHECK-DAG: %[[C_SLICE01:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [0, 0, 1, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
 // CHECK-DAG: %[[C_SLICE70:.+]] = vector.extract_strided_slice %[[C_READ]] {{.*}}offsets = [7, 0, 0, 0, 0, 0]{{.*}} : vector<8x1x2x1x1x4xf32> to vector<1x1x1x1x1x4xf32>
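The new CHECK sequence asserts the staging pattern that promotion produces: allocate A and B tiles in workgroup memory, cooperatively copy them in from the bindings, synchronize, then feed the MMA reads from the staged tiles. As a rough, hypothetical analogy only (CUDA-flavored C++, not IREE's actual output; sizes are taken from the tensor shapes above):

```cpp
// Sketch of the staging pattern the CHECK lines describe. Both tiles
// happen to hold 2048 floats: 1*1*8*4*16*4 = 2048 and 1*1*4*2*4*16*4 = 2048.
__global__ void stagedMultiMma(const float *aGlobal, const float *bGlobal,
                               float *cGlobal) {
  // memref.alloc() in #gpu.address_space<workgroup> ~ shared-memory tiles.
  __shared__ float aTile[2048];
  __shared__ float bTile[2048];
  int tid = threadIdx.x;  // workgroup_size = [256, 1, 1] in the test
  // Each thread moves vector<4xf32>-sized chunks: the global
  // vector.transfer_read / shared vector.transfer_write pair in the IR.
  for (int base = tid * 4; base < 2048; base += 256 * 4) {
    for (int v = 0; v < 4; ++v) {
      aTile[base + v] = aGlobal[base + v];
      bTile[base + v] = bGlobal[base + v];
    }
  }
  __syncthreads();  // the gpu.barrier between staging and compute
  // ... the MFMA computation now reads aTile/bTile instead of global
  // memory and accumulates into C (elided).
}
```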
