[mlir][tensor] Add support for tensor.pack static shapes inference. #80848
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/80848.diff

2 Files Affected:
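For context, here is a hedged before/after sketch of the rewrite this patch introduces, adapted from the new tests in the diff below (`%src`, `%dest`, and `%cst` are illustrative names):

```mlir
// Before: the source type is fully dynamic even though the destination
// pins down three of the four untiled sizes.
%pack = tensor.pack %src padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [2] inner_tiles = [16]
    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>

// After: the known destination sizes are propagated into the source
// through a tensor.cast, so the pack operates on a mostly static type.
%cast = tensor.cast %src : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
%packed = tensor.pack %cast padding_value(%cst : f32)
    outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [2] inner_tiles = [16]
    into %dest : tensor<30x20x?x10xf32> -> tensor<10x20x30x40x16xf32>
```

Note that the tiled source dim stays dynamic: with a `padding_value`, the destination's outer size 40 only bounds it, it does not determine it.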
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index b21e89ae3a571..737f897fd4fd4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3983,6 +3983,41 @@ static bool paddingIsNotNeeded(PackOp op) {
                           op.getMixedTiles());
 }
 
+// Returns true if the `srcShape` or `destShape` is different from the one in
+// `packOp`.
+static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(packOp.getSourceType().getShape().begin(),
+                  packOp.getSourceType().getShape().end());
+  destShape.assign(packOp.getDestType().getShape().begin(),
+                   packOp.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(packOp.getInnerDimsPos().begin(),
+                   packOp.getInnerDimsPos().end());
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  int srcRank = packOp.getSourceRank();
+  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      destPos = outerDimsPerm[srcPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Fold an unpack(pack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
@@ -4003,6 +4038,31 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     rewriter.finalizeOpModification(packOp);
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(packOp, srcShape, destShape)) {
+    Location loc = packOp.getLoc();
+    Value source = packOp.getSource();
+    if (srcShape != packOp.getSourceType().getShape()) {
+      auto newSrcType = packOp.getSourceType().clone(srcShape);
+      source =
+          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
+    }
+    Value dest = packOp.getDest();
+    if (destShape != packOp.getDestType().getShape()) {
+      auto newDestType = packOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
+    }
+    Value newOp = rewriter.create<tensor::PackOp>(
+        loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
+        packOp.getPaddingValue(), packOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        packOp, packOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
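To make the dimension bookkeeping in `inferStaticShape` concrete: each untiled source dim `i` pairs with destination dim `outerDimsPerm[i]` (or `i` when no permutation is given), and a static size on either side is copied to the other. A walkthrough of the first new test below, with the mapping spelled out as comments:

```mlir
// src  = tensor<?x?x?x?xf32>, dest = tensor<10x20x30x40x16xf32>
// outer_dims_perm = [2, 1, 3, 0], inner_dims_pos = [2]
//
// src dim 0 -> dest dim perm[0] = 2 -> size 30 copied to the source
// src dim 1 -> dest dim perm[1] = 1 -> size 20 copied to the source
// src dim 2 -> skipped: it is the tiled dim, and with a padding_value
//              dest dim 3 (40) does not determine its size
// src dim 3 -> dest dim perm[3] = 0 -> size 10 copied to the source
//
// inferred source shape: tensor<30x20x?x10xf32>
```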
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 7192a719ceb13..a8e08241d28c0 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -791,6 +791,45 @@ func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<312
 
 // -----
 
+func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
+  return %pack : tensor<10x20x30x40x16xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_pack
+// CHECK-SAME:  %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:  %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:       %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:       %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
+// CHECK:       return %[[PACK]]
+
+// -----
+
+func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %src
+    padding_value(%cst : f32)
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
+  return %pack : tensor<?x?x?x?x16xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_pack
+// CHECK-SAME:  %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:  %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:       %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:       %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:       %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<10x20x30x?x16xf32> to tensor<?x?x?x?x16xf32>
+// CHECK:       return %[[CAST_PACK]]
+
+// -----
+
 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
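The second test exercises the mirror-image flow, where static sizes propagate from the source into the destination (same mapping as in the walkthrough above):

```mlir
// src = tensor<30x20x?x10xf32>, dest = tensor<?x?x?x?x16xf32>
//
// src dim 0 (30) -> dest dim perm[0] = 2
// src dim 1 (20) -> dest dim perm[1] = 1
// src dim 2 (?)  -> skipped (tiled dim); dest dim 3 stays dynamic
// src dim 3 (10) -> dest dim perm[3] = 0
//
// inferred dest shape: tensor<10x20x30x?x16xf32>; the new pack's result is
// cast back to tensor<?x?x?x?x16xf32> to preserve the op's original type.
```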
(Review thread anchored on the `rewriter.replaceOpWithNewOp<tensor::CastOp>` call in `PackOp::canonicalize` above.)
Reviewer: Is this well suited for a canonicalization? I'm wondering about cases where a `pack` and `unpack` could have folded away, but this pattern introduces a `tensor.cast` in the middle. Maybe we need the same pattern for unpack too?
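For illustration (hypothetical IR, not from this PR), the concern is a chain like the following, where the casting pattern fires before the pack/unpack pair is folded; all names and shapes here are made up:

```mlir
// Without the casting pattern, a matcher can fold unpack(pack(%x)).
%p = tensor.pack %x inner_dims_pos = [1] inner_tiles = [16]
    into %d0 : tensor<8x?xf32> -> tensor<?x?x16xf32>
%u = tensor.unpack %p inner_dims_pos = [1] inner_tiles = [16]
    into %d1 : tensor<?x?x16xf32> -> tensor<8x?xf32>

// If the casting pattern runs first, the pack reaches the unpack only
// through tensor.cast ops, so a naive unpack(pack(x)) matcher no longer
// sees the pair:
%c0 = tensor.cast %d0 : tensor<?x?x16xf32> to tensor<8x?x16xf32>
%p2 = tensor.pack %x inner_dims_pos = [1] inner_tiles = [16]
    into %c0 : tensor<8x?xf32> -> tensor<8x?x16xf32>
%c1 = tensor.cast %p2 : tensor<8x?x16xf32> to tensor<?x?x16xf32>
%u2 = tensor.unpack %c1 inner_dims_pos = [1] inner_tiles = [16]
    into %d1 : tensor<?x?x16xf32> -> tensor<8x?xf32>
```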
Author (hanhanW): This is common in Linalg, and I think we can add the same functionality to tensor ops. Good point on pack/unpack folding; I think they can still be folded if the patterns are applied in the right order. To make the resulting IR deterministic, we likely need the same inference for unpack, so yes, I will prepare a patch and send it out for review. Do you think it is better to land both patterns together? If so, I will add the update to this PR; if it does not matter, I will land it as a follow-up.
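For illustration only, the analogous unpack rewrite might look like this hypothetical sketch (the actual follow-up patch may differ):

```mlir
// Before: the destination is fully dynamic even though the source pins
// down the untiled outer size.
%u = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [16]
    into %dest : tensor<10x?x16xf32> -> tensor<?x?xf32>

// After: the untiled size (10) flows into the destination; the tiled dim
// stays dynamic because unpacking may drop padding.
%cast = tensor.cast %dest : tensor<?x?xf32> to tensor<10x?xf32>
%u2 = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [16]
    into %cast : tensor<10x?x16xf32> -> tensor<10x?xf32>
%res = tensor.cast %u2 : tensor<10x?xf32> to tensor<?x?xf32>
```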
Reviewer: If there is a possibility of non-deterministic IR, it is probably best to land them at the same time. In this case the pack-unpack canonicalization seems to always apply before this casting pattern, so maybe it is fine here? It is hard to predict exactly what the pattern applicator will do for arbitrary input IR, though.
Reviewer: LGTM, contingent on addressing the discussed pack-unpack case in a follow-up.
Fixes iree-org/iree#16317