-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Add support for tensor.unpack static shapes inference. #81702
Conversation
The revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: there are more dimensions that can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value.
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Han-Chung Wang (hanhanW) ChangesThe revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: there are more dimensions that can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value. This is a follow-up of #80848 Full diff: https://github.com/llvm/llvm-project/pull/81702.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index bb72cba96ad935..1df15c0372e6e2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4212,6 +4212,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
metadata.outerDimsPerm);
}
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+ SmallVectorImpl<int64_t> &destShape) {
+ bool changeNeeded = false;
+ srcShape.assign(op.getSourceType().getShape().begin(),
+ op.getSourceType().getShape().end());
+ destShape.assign(op.getDestType().getShape().begin(),
+ op.getDestType().getShape().end());
+ llvm::SmallSetVector<int64_t, 4> innerDims;
+ innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+ auto outerDimsPerm = op.getOuterDimsPerm();
+ int destRank = op.getDestRank();
+ for (auto i : llvm::seq<int64_t>(0, destRank)) {
+ if (innerDims.contains(i))
+ continue;
+ int64_t srcPos = i;
+ int64_t destPos = i;
+ if (!outerDimsPerm.empty())
+ srcPos = outerDimsPerm[destPos];
+ if (ShapedType::isDynamic(srcShape[srcPos]) ==
+ ShapedType::isDynamic(destShape[destPos])) {
+ continue;
+ }
+ int64_t size = srcShape[srcPos];
+ if (ShapedType::isDynamic(size))
+ size = destShape[destPos];
+ srcShape[srcPos] = size;
+ destShape[destPos] = size;
+ changeNeeded = true;
+ }
+ return changeNeeded;
+}
+
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
/// pack(unpack(x)) -> x
@@ -4234,6 +4268,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+
+  // Insert tensor.cast ops if static shape inference is available.
+ SmallVector<int64_t> srcShape, destShape;
+ if (inferStaticShape(unPackOp, srcShape, destShape)) {
+ Location loc = unPackOp.getLoc();
+ Value source = unPackOp.getSource();
+ if (srcShape != unPackOp.getSourceType().getShape()) {
+ auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+ source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+ unPackOp.getSource());
+ }
+ Value dest = unPackOp.getDest();
+ if (destShape != unPackOp.getDestType().getShape()) {
+ auto newDestType = unPackOp.getDestType().clone(destShape);
+ dest =
+ rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+ }
+ Value newOp = rewriter.create<tensor::UnPackOp>(
+ loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+ unPackOp.getOuterDimsPerm());
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ unPackOp, unPackOp.getResult().getType(), newOp);
+ return success();
+ }
+
return failure();
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..35619d098f008a 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
// -----
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+ return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK: return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+ %unpack = tensor.unpack %src
+ outer_dims_perm = [2, 1, 3, 0]
+ inner_dims_pos = [2]
+ inner_tiles = [16]
+ into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+ return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
+// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,18 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
} : tensor<?x8xi32>
return %tensor : tensor<?x8xi32>
}
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x?x?xf32>,
+ %dim1: index, %dim2: index, %dim3: index, %dim4: index, %tile1: index,
+ %tile2: index) -> tensor<10x20x?x?xf32> {
+ %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+ %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
+ %tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
+ %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
+ return %packed : tensor<10x20x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
+// CHECK: return %[[SRC]]
|
%tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
%unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<10x20x?x?xf32> -> tensor<?x?xf32>
%tensor_empty1 = tensor.empty(%dim3, %dim4) : tensor<10x20x?x?xf32>
%packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<10x20x?x?xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the static shape inference pattern kick in for this case? The tile sizes are dynamic here, so from what I can tell it doesn't apply. Can you add an additional test for something like this
%tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
%unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
%cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
%tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
%packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
(might have written this wrong, but I'm wondering if we still get it to fold in this case).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, good point! I verified that my test case does not work with -debug
; yours is better. I'll update the test.
The revision does not refactor the inferStaticShape for pack and unpack ops because they can diverge quickly: there are more dimensions that can be inferred (i.e., with inner_tile_sizes) if the pack op does not have a padding value.
This is a follow-up of #80848