Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Fix bugs in expand_shape patterns after semantics changes #94631

Merged
merged 4 commits into from
Jun 7, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Jun 6, 2024

After the output_shape field was added to expand_shape ops, dynamically sized expand shapes are now possible, but this was not accounted for in the folder. This PR tightens the constraints of the folder to fix this.

@llvmbot
Copy link
Collaborator

llvmbot commented Jun 6, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

After the output_shape field was added to expand_shape ops, dynamically sized expand shapes are now possible, but this was not accounted for in the folder. This PR tightens the constraints of the folder to fix this.


Full diff: https://github.com/llvm/llvm-project/pull/94631.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+46-10)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+55-2)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..3b986f4a60064 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,55 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  // Otherwise, only 1 dynamic dimension is allowed.
+  if (srcType == resultType &&
+      llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  auto inverseReassociations = reshapeSrcOp.getReassociationIndices();
+  if (reassociations != inverseReassociations)
+    return nullptr;
+  ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+  ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+  if (llvm::none_of(reassociations, [&](auto reInd) {
+        auto srcSlice = expandedSrcShape.slice(reInd.front(), reInd.size());
+        auto resSlice = expandedResultShape.slice(reInd.front(), reInd.size());
+        return srcSlice == resSlice &&
+               llvm::count_if(srcSlice, ShapedType::isDynamic) > 1;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
 }
 
@@ -360,10 +394,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
           composedReassociation.push_back(srcIndices);
-        else
+        } else {
           return std::nullopt;
+        }
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..4a04d37d4be29 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-//   CHECK-NOT:   linalg.{{.*}}shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,60 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
 
 // -----
 

@Max191 Max191 requested review from MaheshRavishankar, qedawkins and jaingaurav and removed request for jaingaurav June 6, 2024 15:40
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h Outdated Show resolved Hide resolved
mlir/test/Dialect/Tensor/canonicalize.mlir Show resolved Hide resolved
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h Outdated Show resolved Hide resolved
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h Outdated Show resolved Hide resolved
if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
return reshapeSrcOp.getSrc();
ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
ArrayRef<int64_t> expandedResultShape = resultType.getShape();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: unused variable

@Max191 Max191 merged commit 2117677 into llvm:main Jun 7, 2024
7 checks passed
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants