Skip to content

Commit

Permalink
[mlir] Add bubbling patterns for non-intersecting reshapes
Browse files Browse the repository at this point in the history
  • Loading branch information
Max191 committed Jun 10, 2024
1 parent 503fb1a commit c96c4ad
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 0 deletions.
71 changes: 71 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
private:
ControlFusionFn controlFoldingReshapes;
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations, i.e.
/// every reassociation group is non-trivial (more than one index) in at most
/// one of the two reshapes. Under that condition the reshapes act on disjoint
/// dimensions, so
///   collapse_shape -> expand_shape
/// can be rewritten as the equivalent
///   expand_shape -> collapse_shape,
/// moving the expand_shape earlier in the IR.
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    // Match only when the expand's source is a single-use collapse_shape;
    // after the swap the original collapse becomes dead, which would not hold
    // if it had other users.
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp || !collapseOp->hasOneUse())
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Reshapes are parallel to each other if none of the reassociation indices
    // have greater than 1 index for both reshapes. (Both lists have one group
    // per dimension of the shared intermediate tensor, so zip_equal's
    // same-length requirement holds.)
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes for the
    // swapped (expand-first) order.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    // Sizes of the collapse's source, i.e. the overall input of the pair.
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    // Sizes of the expand's result, i.e. the overall output of the pair.
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    // `index` walks the dims of the new intermediate tensor (result of the
    // new expand / source of the new collapse); `expandIndex` walks the
    // result dims of the original expand; `collapseIndex` walks the source
    // dims of the original collapse.
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        // This group is collapsed (and, by the parallel check, untouched by
        // the expand): the new expand keeps each source dim as-is and the new
        // collapse merges them, preserving the original collapse semantics.
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        // The expand group here has exactly one index, so one result dim.
        expandIndex++;
        continue;
      }
      // This group is expanded (or left alone): the new expand performs the
      // expansion and the new collapse leaves the resulting dims untouched.
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      // The collapse group here has exactly one index, so one source dim.
      collapseIndex++;
    }

    // Swap reshape order: create the new expand on the collapse's source, then
    // replace the original expand with a collapse of the new expand's result.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    // Same element type as the original expand result, with the shape of the
    // new intermediate tensor.
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
Expand Down Expand Up @@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,37 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME: : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
// CHECK: return %[[COLLAPSE]]

// -----

// The collapse merges dims 1 and 2 while the expand splits dim 2 of the
// collapsed tensor; no dimension group is non-trivial in both reshapes, so
// the expand_shape is bubbled above the collapse_shape.
func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @bubble_parallel_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[COLLAPSE]]

// -----

// Negative test: both reshapes have a non-trivial group at position 1
// ([1, 2] in each reassociation), so their reassociations intersect and the
// collapse/expand pair must be left unchanged.
func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
      output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
// CHECK: return %[[EXPAND]]

0 comments on commit c96c4ad

Please sign in to comment.