Skip to content

Commit

Permalink
[Codegen][GPU] Fuse into destinations for parallel tiling (#18666)
Browse files Browse the repository at this point in the history
Currently we disable fusion of destinations for all tiling levels, but
that's only required for reduction tiling. Turn on fusion along
destinations for parallel tiling levels.
  • Loading branch information
qedawkins authored Oct 3, 2024
1 parent 24ee841 commit 88153eb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,11 @@ applyTileAndFuseToEachRoot(RewriterBase &rewriter,
if (auto tilingOwner = dyn_cast<TilingInterface>(owner)) {
shouldFuse = !payloadOps.contains(tilingOwner);
}
// Do not fuse destination operands.
shouldFuse &= !isDestinationOperand;
// Do not fuse destination operands for reduction tiling.
if (isDestinationOperand &&
tilingLevel == IREE::GPU::TilingLevel::Reduction) {
shouldFuse = false;
}
if (shouldFuse) {
return scf::SCFTileAndFuseOptions::ControlFnResult{
yieldProducerReplacement};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,30 @@ func.func @matmul_fuse(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<

// -----

// Lowering config attached to the matmul below. It carries both a `reduction`
// and a `thread` tile-size list, so the same input can be exercised at either
// tiling level (presumably a 0 entry means "do not tile this dimension" --
// TODO confirm against the lowering_config attribute documentation).
#config = #iree_gpu.lowering_config<{reduction = [0, 0, 8], thread = [8, 8, 0]}>
// NOTE(review): #map is not referenced by the function below; it looks like a
// leftover alias -- confirm before removing.
#map = affine_map<(d0, d1) -> (d0, d1)>
// 64x64x64 f32 matmul whose accumulator is a zero-filled fresh tensor, i.e. the
// linalg.fill is the producer of the matmul's destination (`outs`) operand.
// This is the case used to test whether the destination producer gets fused
// into the tiled loop.
func.func @matmul_fuse_destination(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>) -> tensor<64x64xf32> {
%empty = tensor.empty() : tensor<64x64xf32>
%cst = arith.constant 0.0 : f32
// Producer of the matmul's destination operand.
%5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32>
// Only the matmul carries the lowering config; the fill has none of its own.
%7 = linalg.matmul {lowering_config = #config} ins(%3, %4 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%5 : tensor<64x64xf32>) -> tensor<64x64xf32>
return %7 : tensor<64x64xf32>
}

// Verify that the destination fill is not fused for reduction tiling (it stays
// outside the scf.for), but is fused into the scf.forall for thread tiling.
// CHECK-LABEL: func.func @matmul_fuse_destination
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.*}} : tensor<64x64xf32>)
// CHECK: scf.for %{{.*}} = %c0 to %c64 step %c8 iter_args(%[[ITER:.+]] = %[[FILL]]
// CHECK: linalg.matmul

// THREAD-LABEL: func.func @matmul_fuse_destination
// THREAD: %[[EMPTY:.+]] = tensor.empty() : tensor<64x64xf32>
// THREAD: scf.forall {{.*}} shared_outs(%[[INIT:.+]] = %[[EMPTY]]
// THREAD: linalg.fill
// THREAD: linalg.matmul

// -----

#config = #iree_gpu.lowering_config<{thread = [8, 8]}>
func.func @matmul_cleanup(%3: tensor<64x64xf32>, %4: tensor<64x64xf32>, %5: tensor<64x64xf32>) -> tensor<64x64xf32> {
%c8 = arith.constant 8 : index
Expand Down

0 comments on commit 88153eb

Please sign in to comment.