[mlir][tensor] Add support for tensor.pack static shapes inference. #80848

Merged: 2 commits into llvm:main from infer-pack-static-shapes, Feb 14, 2024

Conversation

hanhanW (Contributor) commented Feb 6, 2024

llvmbot (Collaborator) commented Feb 6, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/80848.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+60)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+39)
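
For orientation before the raw diff, here is a minimal sketch of the new canonicalization's effect, modeled on the test cases added below. The function names and shapes are illustrative, not from the patch. When an untiled source dimension is dynamic but the corresponding destination dimension is static (or vice versa), the pattern materializes the known size through tensor.cast ops:

// Before canonicalization (hypothetical input).
func.func @sketch_before(%src: tensor<?x?xf32>, %dest: tensor<10x20x16xf32>) -> tensor<10x20x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %pack = tensor.pack %src
    padding_value(%cst : f32)
    inner_dims_pos = [1]
    inner_tiles = [16]
    into %dest : tensor<?x?xf32> -> tensor<10x20x16xf32>
  return %pack : tensor<10x20x16xf32>
}

// After canonicalization (sketch of the expected output): dim 0 of %src is
// inferred to be 10 from the destination; dim 1 is a tiled dim, so it stays
// dynamic. The pattern also casts the result back to the original type, which
// is a no-op here because the destination was already fully static.
func.func @sketch_after(%src: tensor<?x?xf32>, %dest: tensor<10x20x16xf32>) -> tensor<10x20x16xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %cast = tensor.cast %src : tensor<?x?xf32> to tensor<10x?xf32>
  %pack = tensor.pack %cast
    padding_value(%cst : f32)
    inner_dims_pos = [1]
    inner_tiles = [16]
    into %dest : tensor<10x?xf32> -> tensor<10x20x16xf32>
  return %pack : tensor<10x20x16xf32>
}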
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a571..737f897fd4fd4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
                                       op.getMixedTiles());
 }
 
+// Returns true if the `srcShape` or `destShape` is different from the one in
+// `packOp`.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(packOp.getSourceType().getShape().begin(),
+                  packOp.getSourceType().getShape().end());
+  destShape.assign(packOp.getDestType().getShape().begin(),
+                   packOp.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(packOp.getInnerDimsPos().begin(),
+                   packOp.getInnerDimsPos().end());
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  int srcRank = packOp.getSourceRank();
+  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      destPos = outerDimsPerm[srcPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Fold an unpack(pack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7192a719ceb13..a8e08241d28c0 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -791,6 +791,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
 
 // -----
 
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+   %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+  return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK:         return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+   %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+  return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:         return %[[CAST_PACK]]
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>

Review thread on mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (marked outdated and resolved), attached to the new code that rebuilds the pack op and casts its result back to the original type:
Contributor:
Is this well suited for a canonicalization? I'm wondering about cases where a pack and unpack could have folded away but this pattern introduces a tensor.cast in the middle. Maybe we need the same pattern for unpack too?
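
To make the concern concrete, a hypothetical sketch (illustrative IR, not from this patch; the exact fold also requires matching attributes and padding):

func.func @roundtrip(%x: tensor<10x20x16xf32>, %init: tensor<?x?xf32>) -> tensor<10x20x16xf32> {
  // The existing pack(unpack(x)) -> x canonicalization could erase both ops.
  %0 = tensor.unpack %x inner_dims_pos = [1] inner_tiles = [16]
      into %init : tensor<10x20x16xf32> -> tensor<?x?xf32>
  %1 = tensor.pack %0 inner_dims_pos = [1] inner_tiles = [16]
      into %x : tensor<?x?xf32> -> tensor<10x20x16xf32>
  return %1 : tensor<10x20x16xf32>
}
// If the new casting pattern rewrites %1 first, the unpack result is hidden
// behind a cast, so the direct pack(unpack(x)) fold no longer matches:
//   %cast = tensor.cast %0 : tensor<?x?xf32> to tensor<10x?xf32>
//   %1 = tensor.pack %cast ... into %x : tensor<10x?xf32> -> tensor<10x20x16xf32>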

Contributor Author (hanhanW):
This is common in Linalg, and I think we can add the same functionality to tensor ops. Good point on pack/unpack folding. I think they can still be folded if the patterns are applied in the right order; to make the resulting IR deterministic, we might need the same pattern for unpack. So yes, I will prepare a patch and send it out for review. Do you think it is better to land both patterns together? If so, I will add the update to this PR; if it does not matter, I will land it as a follow-up.

Contributor:
If there is a possibility of non-deterministic IR, it is probably best to land them at the same time. In this case the pack-unpack canonicalization seems to always apply before this casting pattern, so maybe it is OK here? It is hard to predict exactly what the pattern applicator will do for arbitrary input IR, though.

qedawkins (Contributor) left a comment:
LGTM, contingent on addressing the discussed pack-unpack case in a follow-up.

hanhanW merged commit bc08cc2 into llvm:main on Feb 14, 2024
6 checks passed
hanhanW deleted the infer-pack-static-shapes branch on February 14, 2024 at 04:20
hanhanW added a commit that referenced this pull request Feb 20, 2024
…#81702)

The revision does not refactor inferStaticShape to be shared between the pack
and unpack ops, because the two implementations can diverge quickly: more
dimensions can be inferred (i.e., using the inner tile sizes) when the pack op
does not have a padding value; see the sketch below.

This is a follow-up of #80848
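
A rough illustration of the inner-tile point (hypothetical IR, not from either patch): without a padding_value, every tile must be full, so the tiled source dimension is also pinned down by the static destination shape.

func.func @full_tiles_only(%src: tensor<?x?xf32>, %dest: tensor<10x20x16xf32>) -> tensor<10x20x16xf32> {
  // No padding value: dim 1 of %src must be exactly 20 * 16 = 320, so in
  // principle it could be inferred too (a cast to tensor<10x320xf32>).
  // The pattern in this PR only infers untiled dims, here dim 0 = 10.
  %pack = tensor.pack %src
    inner_dims_pos = [1]
    inner_tiles = [16]
    into %dest : tensor<?x?xf32> -> tensor<10x20x16xf32>
  return %pack : tensor<10x20x16xf32>
}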
Development

Successfully merging this pull request may close these issues.

Llama-7b Crashes during IREE Compilation
3 participants