diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 62bf6b48f18b4a..e6efec14e31a60 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                             metadata.outerDimsPerm);
 }
 
+/// Returns true if the `srcShape` or `destShape` is different from the one in
+/// `op` and populates each with the inferred static shape.
+static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
+                             SmallVectorImpl<int64_t> &destShape) {
+  bool changeNeeded = false;
+  srcShape.assign(op.getSourceType().getShape().begin(),
+                  op.getSourceType().getShape().end());
+  destShape.assign(op.getDestType().getShape().begin(),
+                   op.getDestType().getShape().end());
+  llvm::SmallSetVector<int64_t, 4> innerDims;
+  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
+  auto outerDimsPerm = op.getOuterDimsPerm();
+  int destRank = op.getDestRank();
+  for (auto i : llvm::seq<int64_t>(0, destRank)) {
+    if (innerDims.contains(i))
+      continue;
+    int64_t srcPos = i;
+    int64_t destPos = i;
+    if (!outerDimsPerm.empty())
+      srcPos = outerDimsPerm[destPos];
+    if (ShapedType::isDynamic(srcShape[srcPos]) ==
+        ShapedType::isDynamic(destShape[destPos])) {
+      continue;
+    }
+    int64_t size = srcShape[srcPos];
+    if (ShapedType::isDynamic(size))
+      size = destShape[destPos];
+    srcShape[srcPos] = size;
+    destShape[destPos] = size;
+    changeNeeded = true;
+  }
+  return changeNeeded;
+}
+
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
   /// pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
         [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+
+  // Insert tensor.cast ops if static shape inference is available.
+  SmallVector<int64_t> srcShape, destShape;
+  if (inferStaticShape(unPackOp, srcShape, destShape)) {
+    Location loc = unPackOp.getLoc();
+    Value source = unPackOp.getSource();
+    if (srcShape != unPackOp.getSourceType().getShape()) {
+      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
+      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
+                                               unPackOp.getSource());
+    }
+    Value dest = unPackOp.getDest();
+    if (destShape != unPackOp.getDestType().getShape()) {
+      auto newDestType = unPackOp.getDestType().clone(destShape);
+      dest =
+          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
+    }
+    Value newOp = rewriter.create<UnPackOp>(
+        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+        unPackOp.getOuterDimsPerm());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        unPackOp, unPackOp.getResult().getType(), newOp);
+    return success();
+  }
+
   return failure();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3b6cd799a6f348..e123c77aabd57c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128
 
 // -----
 
+func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
+  return %unpack : tensor<?x?x?x?xf32>
+}
+// CHECK-LABEL: func.func @infer_dest_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
+// CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
+// CHECK:         return %[[CAST_UNPACK]]
+
+// -----
+
+func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
+  %unpack = tensor.unpack %src
+    outer_dims_perm = [2, 1, 3, 0]
+    inner_dims_pos = [2]
+    inner_tiles = [16]
+    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
+  return %unpack : tensor<30x20x?x10xf32>
+}
+// CHECK-LABEL: func.func @infer_src_shape_unpack
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
+// CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
+// CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,19 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
   } : tensor<?x8xi32>
   return %tensor : tensor<?x8xi32>
 }
+
+// -----
+
+func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
+  %dim1 = arith.constant 40 : index
+  %dim2 = arith.constant 80 : index
+  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
+  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
+  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
+  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
+  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
+  return %packed : tensor<10x20x4x4xf32>
+}
+// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
+// CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
+// CHECK:         return %[[SRC]]
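
---

For reviewers, a minimal sketch of the rewrite this pattern performs, on hypothetical IR (the function name @example and the shapes below are illustrative, not from the patch); it mirrors the infer_dest_shape_unpack test above:

// Before canonicalization: %dest is fully dynamic, but dim 0 of the result
// is an untiled outer dim, so its size can be inferred from the static
// source type (8).
func.func @example(%src: tensor<8x16x32xf32>, %dest: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [32]
      into %dest : tensor<8x16x32xf32> -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// After `mlir-opt --canonicalize`, the pattern is expected to yield:
//   %cast = tensor.cast %dest : tensor<?x?xf32> to tensor<8x?xf32>
//   %unpack = tensor.unpack %src inner_dims_pos = [1] inner_tiles = [32]
//       into %cast : tensor<8x16x32xf32> -> tensor<8x?xf32>
//   %res = tensor.cast %unpack : tensor<8x?xf32> to tensor<?x?xf32>
// The final tensor.cast restores the original result type, so uses of the
// unpack op are unaffected and later patterns can fold the casts away.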