[mlir][tensor] Add support for tensor.unpack static shapes inference. (#81702)

This revision does not refactor inferStaticShape to be shared between the
pack and unpack ops, because the two implementations can diverge quickly:
for a pack op without a padding value, more dimensions can be inferred
(i.e., via the inner tile sizes).

This is a follow-up to #80848.
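As a sketch of the new rewrite (shapes taken from the tests added in this
patch; the SSA names are illustrative), an unpack op with a fully static
source and a dynamic destination now canonicalizes to an unpack on refined
types, with tensor.cast ops reconciling the original types:

  // Before:
  %0 = tensor.unpack %src outer_dims_perm = [2, 1, 3, 0]
      inner_dims_pos = [2] inner_tiles = [16]
      into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>

  // After: the outer dims of the destination are inferred from the source.
  %cast_dest = tensor.cast %dest : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
  %unpack = tensor.unpack %src outer_dims_perm = [2, 1, 3, 0]
      inner_dims_pos = [2] inner_tiles = [16]
      into %cast_dest : tensor<10x20x30x40x16xf32> -> tensor<30x20x?x10xf32>
  %0 = tensor.cast %unpack : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>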
hanhanW authored Feb 20, 2024
1 parent c3b87a8 commit eac8604
Showing 2 changed files with 110 additions and 0 deletions.
59 changes: 59 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4229,6 +4229,40 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                            metadata.outerDimsPerm);
}

/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
  auto outerDimsPerm = op.getOuterDimsPerm();
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!outerDimsPerm.empty())
      srcPos = outerDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
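// A worked trace of the inference above, using the shapes from the
// @infer_dest_shape_unpack test added in this patch: given a source of type
// tensor<10x20x30x40x16xf32> with outer_dims_perm = [2, 1, 3, 0] and
// inner_dims_pos = [2], dest dim 0 maps to src dim 2 (30), dest dim 1 to
// src dim 1 (20), and dest dim 3 to src dim 0 (10); dest dim 2 is an inner
// (tiled) dim and is skipped. A destination of type tensor<?x?x?x?xf32>
// thus refines to tensor<30x20x?x10xf32>.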

LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// pack(unpack(x)) -> x
@@ -4251,6 +4285,31 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
    }
    Value newOp = rewriter.create<tensor::UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}

51 changes: 51 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -909,6 +909,41 @@ func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128

// -----

func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
  return %unpack : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func.func @infer_dest_shape_unpack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<30x20x?x10xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
// CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<30x20x?x10xf32> to tensor<?x?x?x?xf32>
// CHECK: return %[[CAST_UNPACK]]
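// Note: dest dim 2 stays dynamic even though the source is fully static: it
// is the tiled dim (inner_dims_pos = [2]), and an unpack may drop padding,
// so its extent need not equal src dim 3 * inner tile (40 * 16).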

// -----

func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
  %unpack = tensor.unpack %src
    outer_dims_perm = [2, 1, 3, 0]
    inner_dims_pos = [2]
    inner_tiles = [16]
    into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
  return %unpack : tensor<30x20x?x10xf32>
}
// CHECK-LABEL: func.func @infer_src_shape_unpack
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<10x20x30x?x16xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
// CHECK: return %[[UNPACK]]

// -----

// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
@@ -2176,3 +2211,19 @@ func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
  } : tensor<?x8xi32>
  return %tensor : tensor<?x8xi32>
}

// -----

func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
  %dim1 = arith.constant 40 : index
  %dim2 = arith.constant 80 : index
  %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
  %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
  %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
  %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
  %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
  return %packed : tensor<10x20x4x4xf32>
}
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
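// Note: the unpack/pack pair above presumably folds away because shape
// inference refines the dynamic tensor.empty destination, the intermediate
// tensor.cast then folds, and the existing pack(unpack(x)) -> x pattern
// applies since both ops share the same inner_dims_pos and inner_tiles.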
