Skip to content

Commit

Permalink
[mlir] Add bubbling patterns for non intersecting reshapes (#103401)
Browse files Browse the repository at this point in the history
Refactored @Max191's PR #94637
to move it to `Tensor`

From the original PR
>This PR adds fusion by expansion patterns to push a tensor.expand_shape
up through a tensor.collapse_shape with non-intersecting reassociations.
Sometimes parallel collapse_shape ops like this can block propagation of
expand_shape ops, so this allows them to pass through each other.

I'm not sure if I put the code/tests in the right places, so let me know
where those go if they aren't.

cc @MaheshRavishankar @hanhanW

---------

Co-authored-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
IanWood1 and Max191 authored Aug 14, 2024
1 parent f6e3dbc commit a95ad2d
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
/// through `tensor.collapse_shape` ops.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold tensor.empty with its
/// consumers.
///
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -2144,6 +2145,7 @@ struct LinalgElementwiseOpFusionPass
// Add elementwise op fusion patterns.
populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
tensor::populateBubbleUpExpandShapePatterns(patterns);

// General canonicalization patterns.
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
Expand Down
75 changes: 75 additions & 0 deletions mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,76 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
return success();
}
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non intersecting reassociations.
struct BubbleUpExpandThroughParallelCollapse
: public OpRewritePattern<tensor::ExpandShapeOp> {
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
auto collapseOp =
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp)
return failure();
auto expandReInds = expandOp.getReassociationIndices();
auto collapseReInds = collapseOp.getReassociationIndices();

// Reshapes are parallel to each other if none of the reassociation indices
// have greater than 1 index for both reshapes.
for (auto [expandReassociation, collapseReassociation] :
llvm::zip_equal(expandReInds, collapseReInds)) {
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
return failure();
}

// Compute new reassociation indices and expanded/collaped shapes.
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
Location loc = expandOp->getLoc();
SmallVector<OpFoldResult> collapseSizes =
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
SmallVector<OpFoldResult> expandSizes(getMixedValues(
expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
SmallVector<OpFoldResult> newExpandSizes;
int64_t index = 0, expandIndex = 0, collapseIndex = 0;
for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
if (collapseReassociation.size() != 1) {
ReassociationIndices newCollapseReassociation;
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
newCollapseReassociation.push_back(index);
newExpandReInds.push_back({index++});
newExpandSizes.push_back(collapseSizes[collapseIndex++]);
}
newCollapseReInds.push_back(newCollapseReassociation);
expandIndex++;
continue;
}
ReassociationIndices newExpandReassociation;
auto expandReassociation = expandReInds[idx];
for (size_t i = 0; i < expandReassociation.size(); ++i) {
newExpandReassociation.push_back(index);
newCollapseReInds.push_back({index++});
newExpandSizes.push_back(expandSizes[expandIndex++]);
}
newExpandReInds.push_back(newExpandReassociation);
collapseIndex++;
}

// Swap reshape order.
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticSizes;
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
auto expandResultType = expandOp.getResultType().clone(staticSizes);
auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
newExpandSizes);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
expandOp, newExpand.getResult(), newCollapseReInds);
return success();
}
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
Expand All @@ -152,3 +222,8 @@ void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
RewritePatternSet &patterns) {
patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
47 changes: 47 additions & 0 deletions mlir/test/Dialect/Tensor/bubble-reshapes.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s

func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @bubble_parallel_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
%expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_full_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
// CHECK: return %[[EXPAND]]

// -----

func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
%collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
%expand = tensor.expand_shape %collapse [[0, 1], [2, 3]]
output_shape [%s0, %s1, %s2, %s3] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
return %expand : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_partial_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
// CHECK: return %[[EXPAND]]
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ struct TestTensorTransforms
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};

Option<bool> testBubbleUpExpandShapePatterns{
*this, "test-expand-shape-bubbling",
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};

Option<bool> testFoldIntoPackAndUnpack{
*this, "test-fold-into-pack-and-unpack",
llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
Expand Down Expand Up @@ -102,6 +107,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateBubbleUpExpandShapePatterns(patterns);
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
Expand Down Expand Up @@ -386,6 +397,8 @@ void TestTensorTransforms::runOnOperation() {
applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
if (testReassociativeReshapeFolding)
applyReassociativeReshapeFoldingPatterns(rootOp);
if (testBubbleUpExpandShapePatterns)
applyBubbleUpExpandShapePatterns(rootOp);
if (testFoldIntoPackAndUnpack)
applyFoldIntoPackAndUnpackPatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
Expand Down

0 comments on commit a95ad2d

Please sign in to comment.