
[mlir][tensor] Add support for tensor.unpack static shapes inference. #81702

Merged
hanhanW merged 3 commits into llvm:main from infer-unpack-static-shapes on Feb 20, 2024

Conversation

hanhanW
Contributor

@hanhanW hanhanW commented Feb 14, 2024

The revision does not refactor inferStaticShape for the pack and unpack ops because they can diverge quickly: more dimensions can be inferred (i.e., with inner_tile_sizes) when the pack op does not have a padding value.

This is a follow-up to #80848.
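
To illustrate the divergence, here is a minimal sketch (not taken from this patch; the value names are illustrative) of the extra inference that is only legal for tensor.pack when no padding_value is present:

  // Without a padding_value the tiled source dim must divide exactly, so the
  // static dest shape forces the dynamic source dim to 8 * 16 = 128. With a
  // padding_value, the source dim could be anything in (112, 128], and nothing
  // could be inferred.
  %pack = tensor.pack %src inner_dims_pos = [0] inner_tiles = [16]
    into %dest : tensor<?x256xf32> -> tensor<8x256x16xf32>

For tensor.unpack, by contrast, only the un-tiled outer dims can be propagated 1:1 between source and dest, which is what the inferStaticShape added in this revision does.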

@llvmbot
Collaborator

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)



Full diff: https://github.com/llvm/llvm-project/pull/81702.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+59)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+50)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                             metadata.outerDimsPerm);
 }
 
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      srcPos = outerDimsPerm[destPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
 
 // -----
 
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
 }
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+    %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+    %tile2: index) -> tensor<10x20x?x?xf32> {
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
+  %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+  %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
+  return %packed : tensor<10x20x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]

A reviewer (Contributor) commented on the @infer_and_fold_pack_unpack_same_tiles test added above:

Does the static shape inference pattern kick in for this case? The tile sizes are dynamic here, so from what I can tell it doesn't apply. Can you add an additional test for something like this:

  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>

(might have written this wrong, but I'm wondering if we still get it to fold in this case).

hanhanW (Contributor, Author) replied:

Ah, good point! I verified with -debug that my test case does not work (the pattern does not kick in); yours is better. I'll update the test.
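
For reference, a self-contained sketch of the statically-tiled test the reviewer suggests, completed into a full function (the test actually merged in eac8604 may differ; the function name is reused from the diff above for illustration):

  func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>,
      %dim1: index, %dim2: index) -> tensor<10x20x4x4xf32> {
    %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
    %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4]
      into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
    %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
    %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
    %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4]
      into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
    return %packed : tensor<10x20x4x4xf32>
  }

If the shape inference composes with the existing pack(unpack(x)) -> x fold through the tensor.cast, canonicalization should reduce this to returning %t directly, which is the question the reviewer raises.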

@hanhanW hanhanW merged commit eac8604 into llvm:main Feb 20, 2024
4 checks passed
@hanhanW hanhanW deleted the infer-unpack-static-shapes branch February 20, 2024 00:26