diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..7aa8a0b37c219c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
 private:
   ControlFusionFn controlFoldingReshapes;
 };
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // The two reshapes are parallel to each other if no reassociation group
+    // has more than one index in both reshapes.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      auto expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap the reshape order.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
 } // namespace

 //===---------------------------------------------------------------------===//
@@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
       controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
+  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }

 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index b8df5fc88e1999..86c2904218385c 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -887,3 +887,37 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
 //      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   return %[[EXPAND]]
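
Usage note: populateFoldReshapeOpsByExpansionPatterns registers
BubbleUpExpandThroughParallelCollapse unconditionally, since the pattern takes
no control function (it only reorders two tensor reshapes and touches no
linalg op). Below is a minimal sketch of driving these patterns from a pass
via the greedy rewrite driver; runReshapeFusion and the accept-everything
control function are illustrative assumptions, not part of this patch:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Apply the reshape-by-expansion fusion patterns, which now include
// BubbleUpExpandThroughParallelCollapse, to all ops nested under `op`.
static mlir::LogicalResult runReshapeFusion(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  // Accept every fusion opportunity; a real pass would typically filter on
  // the fused operand here.
  mlir::linalg::ControlFusionFn controlFn = [](mlir::OpOperand *) {
    return true;
  };
  mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}

Running under the greedy driver also matters for correctness of the new
pattern's output: the tensor.dim ops created by tensor::getMixedSizes for
unused dimensions are dead after the rewrite and get cleaned up by folding.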