Skip to content

Commit

Permalink
[mlir][linalg] Support lowering unpack with outer_dims_perm (#94477)
Browse files Browse the repository at this point in the history
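This commit adds support for lowering `tensor.unpack` with a
non-identity `outer_dims_perm`. This was previously left as a
not-yet-implemented (NYI) case: `linalg::lowerUnPack` bailed out with a
match failure. The transpose permutation is now computed with
`tensor::getUnPackInverseSrcPerm`, which shuffles the packed shape back
into the shape it had before any outer or inner permutation was applied.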
  • Loading branch information
ryan-holt-1 authored Jun 7, 2024
1 parent 2c3723d commit 5b2f7a1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 33 deletions.
48 changes: 19 additions & 29 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,

FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
tensor::UnPackOp unPackOp) {
// 1. Filter out NYI cases.
if (!unPackOp.getOuterDimsPerm().empty() &&
!isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
return rewriter.notifyMatchFailure(unPackOp,
"non-identity outer dims perm NYI");
}

Location loc = unPackOp->getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
Expand Down Expand Up @@ -391,45 +384,42 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
/*reshapeOp=*/nullptr, extractSliceOp};
}
// 2. Compute the permutation vector to move the last `numPackedDims` into
// the `innerPosDims` of a shape of rank `packedRank`.
int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
auto lastDims = llvm::to_vector(
llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
PackingMetadata packingMetadata =
computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
packedRank, lastDims, packingMetadata.insertPositions);

// 3. Compute the stripMinedShape: this is the packed shape without outer and

// 1. Compute the permutation vector to shuffle packed shape into the shape
// before any outer or inner permutations have been applied.
PackingMetadata packingMetadata;
SmallVector<int64_t> packedToStripMinedShapePerm =
tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);

// 2. Compute the stripMinedShape: this is the packed shape without outer and
// inner permutations.
SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

// 4. Transpose packedShape to stripMinedShape.
// 3. Transpose packedShape to stripMinedShape.
RankedTensorType stripMinedTensorType =
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMinedTensorType, packingMetadata.reassociations);

// Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
SmallVector<OpFoldResult, 4> dims =
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
loc, dims, stripMinedTensorType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

LLVM_DEBUG(
DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
DBGS() << "insertPositions: ");
DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
DBGS() << "packedShape: ");
DBGSNL();
llvm::interleaveComma(lastDimsToInsertPositionsPerm,
DBGS() << "lastDimsToInsertPositionsPerm: ");
llvm::interleaveComma(packedToStripMinedShapePerm,
DBGS() << "packedToStripMinedShapePerm: ");
DBGSNL(); llvm::interleaveComma(
packingMetadata.reassociations, DBGS() << "reassociations: ",
[&](ReassociationIndices ri) {
Expand All @@ -439,24 +429,24 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

// 5. Collapse from the stripMinedShape to the padded result.
// 4. Collapse from the stripMinedShape to the padded result.
auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
loc, collapsedType, transposeOp->getResult(0),
packingMetadata.reassociations);

// 6. ExtractSlice.
// 5. ExtractSlice.
int64_t destRank = destTensorType.getRank();
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
loc, destTensorType, reshapeOp->getResult(0),
SmallVector<OpFoldResult>(destRank, zero),
tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
SmallVector<OpFoldResult>(destRank, one));

// 7. Inject a copy to preserve DPS.
// 6. Inject a copy to preserve DPS.
auto copyOp = rewriter.create<linalg::CopyOp>(
loc, extractSliceOp->getResult(0), unPackOp.getDest());

// 8. Replace unPackOp by extractSliceOp.
// 7. Replace unPackOp by copyOp.
rewriter.replaceOp(unPackOp, copyOp->getResults());

return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
Expand Down
18 changes: 14 additions & 4 deletions mlir/test/Dialect/Linalg/transform-lower-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} {

// -----

// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
// expected-note @below {{target payload op}}
// CHECK-LABEL: @unpack_with_outer_dims_perm
// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
// CHECK: %[[TRAN:.*]] = linalg.transpose
// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>)
// CHECK-SAME: permutation = [1, 3, 0, 2]
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32>
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
// CHECK-SAME: : tensor<32x64xf32> to tensor<32x64xf32>
// CHECK: linalg.copy ins(%[[SLICE]]
// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
%unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
return %unpack : tensor<32x64xf32>
Expand All @@ -634,7 +645,6 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
: (!transform.any_op) -> !transform.op<"tensor.unpack">
// expected-error @below {{cannot lower to transpose + collapse + extract}}
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
-> (!transform.op<"tensor.empty">,
!transform.op<"linalg.transpose">,
Expand Down

0 comments on commit 5b2f7a1

Please sign in to comment.