From 82383c7700363c74c91486a91d846dd14b08576f Mon Sep 17 00:00:00 2001 From: ryan-holt-1 Date: Wed, 5 Jun 2024 10:33:45 -0400 Subject: [PATCH] [mlir][linalg] Support lowering unpack with outer_dims_perm This commit adds support for lowering `tensor.unpack` with a non-identity `outer_dims_perm`. This was previously left as a not-yet-implemented case. --- .../Dialect/Linalg/Transforms/Transforms.cpp | 48 ++++++++----------- .../Dialect/Linalg/transform-lower-pack.mlir | 18 +++++-- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 91dfac802ad672..f30ef235e9cd34 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp) { - // 1. Filter out NYI cases. - if (!unPackOp.getOuterDimsPerm().empty() && - !isIdentityPermutation(unPackOp.getOuterDimsPerm())) { - return rewriter.notifyMatchFailure(unPackOp, - "non-identity outer dims perm NYI"); - } - Location loc = unPackOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); @@ -391,36 +384,33 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr, /*reshapeOp=*/nullptr, extractSliceOp}; } - // 2. Compute the permutation vector to move the last `numPackedDims` into - // the `innerPosDims` of a shape of rank `packedRank`. 
- int64_t numPackedDims = unPackOp.getInnerDimsPos().size(); - auto lastDims = llvm::to_vector( - llvm::seq<int64_t>(packedRank - numPackedDims, packedRank)); - PackingMetadata packingMetadata = - computePackingMetadata(packedRank, unPackOp.getInnerDimsPos()); - SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector( - packedRank, lastDims, packingMetadata.insertPositions); - - // 3. Compute the stripMinedShape: this is the packed shape without outer and + + // 1. Compute the permutation vector to shuffle packed shape into the shape + // before any outer or inner permutations have been applied. + PackingMetadata packingMetadata; + SmallVector<int64_t> packedToStripMinedShapePerm = + tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata); + + // 2. Compute the stripMinedShape: this is the packed shape without outer and // inner permutations. SmallVector<int64_t> stripMinedShape(packedTensorType.getShape()); - applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm); + applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm); - // 4. Transpose packedShape to stripMinedShape. + // 3. Transpose packedShape to stripMinedShape. RankedTensorType stripMinedTensorType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( stripMinedTensorType, packingMetadata.reassociations); - // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm + // Get dynamic dims from input tensor based on packedToStripMinedShapePerm // permutation. 
SmallVector<OpFoldResult> dims = tensor::getMixedSizes(rewriter, loc, unPackOp.getSource()); - applyPermutationToVector(dims, lastDimsToInsertPositionsPerm); + applyPermutationToVector(dims, packedToStripMinedShapePerm); auto emptyOp = rewriter.create<tensor::EmptyOp>( loc, dims, stripMinedTensorType.getElementType()); auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm); + loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); LLVM_DEBUG( DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions, @@ -428,8 +418,8 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(), DBGS() << "packedShape: "); DBGSNL(); - llvm::interleaveComma(lastDimsToInsertPositionsPerm, - DBGS() << "lastDimsToInsertPositionsPerm: "); + llvm::interleaveComma(packedToStripMinedShapePerm, + DBGS() << "packedToStripMinedShapePerm: "); DBGSNL(); llvm::interleaveComma( packingMetadata.reassociations, DBGS() << "reassociations: ", [&](ReassociationIndices ri) { @@ -439,12 +429,12 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: "); DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); - // 5. Collapse from the stripMinedShape to the padded result. + // 4. Collapse from the stripMinedShape to the padded result. auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>( loc, collapsedType, transposeOp->getResult(0), packingMetadata.reassociations); - // 6. ExtractSlice. + // 5. ExtractSlice. int64_t destRank = destTensorType.getRank(); auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( loc, destTensorType, reshapeOp->getResult(0), @@ -452,11 +442,11 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter, tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()), SmallVector<OpFoldResult>(destRank, one)); - // 7. Inject a copy to preserve DPS. + // 6. Inject a copy to preserve DPS. 
auto copyOp = rewriter.create<linalg::CopyOp>( loc, extractSliceOp->getResult(0), unPackOp.getDest()); - // 8. Replace unPackOp by extractSliceOp. + // 7. Replace unPackOp by copyOp. rewriter.replaceOp(unPackOp, copyOp->getResults()); return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 926969bfc73880..f34ef4f961483d 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} { // ----- -// At the moment, we cannot lower tensor.unpack with outer_dims_perm. -func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> { - // expected-note @below {{target payload op}} +// CHECK-LABEL: @unpack_with_outer_dims_perm +// CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32> +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32> +// CHECK: %[[TRAN:.*]] = linalg.transpose +// CHECK-SAME: ins(%[[ARG1]] : tensor<2x4x32x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) +// CHECK-SAME: permutation = [1, 3, 0, 2] +// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]] +// CHECK-SAME: : tensor<4x8x2x32xf32> into tensor<32x64xf32> +// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1] +// CHECK-SAME: : tensor<32x64xf32> to tensor<32x64xf32> +// CHECK: linalg.copy ins(%[[SLICE]] +// CHECK-SAME: : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32> +func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> { %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32> return %unpack : tensor<32x64xf32> @@ -634,7 
+645,6 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op : (!transform.any_op) -> !transform.op<"tensor.unpack"> - // expected-error @below {{cannot lower to transpose + collapse + extract}} transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">) -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">,