diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp index 4b029d076095..a6b6bf8a165b 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp @@ -148,8 +148,11 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter, if (auto tilingOwner = dyn_cast(owner)) { shouldFuse = !payloadOps.contains(tilingOwner); } - // Do not fuse destination operands. - shouldFuse &= !isDestinationOperand; + // Do not fuse destination operands for reduction tiling. + if (isDestinationOperand && + tilingLevel == IREE::GPU::TilingLevel::Reduction) { + shouldFuse = false; + } if (shouldFuse) { return scf::SCFTileAndFuseOptions::ControlFnResult{ yieldProducerReplacement}; diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir index 22001c76dc6d..18b620f516dd 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_apply_tiling_level.mlir @@ -155,6 +155,30 @@ func.func @matmul_fuse(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor< // ----- +#config = #iree_gpu.lowering_config<{reduction = [0, 0, 8], thread = [8, 8, 0]}> +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @matmul_fuse_destination(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>) -> tensor<64x64xf32> { + %empty = tensor.empty() : tensor<64x64xf32> + %cst = arith.constant 0.0 : f32 + %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32> + %7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32> + return %7 : tensor<64x64xf32> +} + +// Verify that destinations are not fused for reduction tiling. +// CHECK-LABEL: func.func @matmul_fuse_destination +// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.*}} : tensor<64x64xf32>) +// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c8 iter_args(%[[ITER:.+]] = %[[FILL]] +// CHECK: linalg.matmul + +// THREAD-LABEL: func.func @matmul_fuse_destination +// THREAD: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32> +// THREAD: scf.forall {{.*}} shared_outs(%[[INIT:.+]] = %[[EMPTY]] +// THREAD: linalg.fill +// THREAD: linalg.matmul + +// ----- + #config = #iree_gpu.lowering_config<{thread = [8, 8]}> func.func @matmul_cleanup(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> { %c8 = arith.constant 8 : index