Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Support lowering unpack with outer_dims_perm #94477

Merged
merged 1 commit into from
Jun 7, 2024

Conversation

ryan-holt-1
Copy link
Contributor

This commit adds support for lowering tensor.unpack with a non-identity outer_dims_perm. This was previously left as a not-yet-implemented case.

@llvmbot
Copy link
Collaborator

llvmbot commented Jun 5, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Ryan Holt (ryan-holt-1)

Changes

This commit adds support for lowering tensor.unpack with a non-identity outer_dims_perm. This was previously left as a not-yet-implemented case.


Full diff: https://github.com/llvm/llvm-project/pull/94477.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+12-22)
  • (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+14-4)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 91dfac802ad67..f18cfdea2faac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -356,13 +356,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 
 FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
                                                    tensor::UnPackOp unPackOp) {
-  // 1. Filter out NYI cases.
-  if (!unPackOp.getOuterDimsPerm().empty() &&
-      !isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
-    return rewriter.notifyMatchFailure(unPackOp,
-                                       "non-identity outer dims perm NYI");
-  }
-
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -391,20 +384,17 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
     return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                                /*reshapeOp=*/nullptr, extractSliceOp};
   }
-  // 2. Compute the permutation vector to move the last `numPackedDims` into
-  // the `innerPosDims` of a shape of rank `packedRank`.
-  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
-  auto lastDims = llvm::to_vector(
-      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
-  PackingMetadata packingMetadata =
-      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
-  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
-      packedRank, lastDims, packingMetadata.insertPositions);
+
+  // 2. Compute the permutation vector to shuffle packed shape into the shape
+  // before any outer or inner permutations have been applied.
+  PackingMetadata packingMetadata;
+  SmallVector<int64_t> packedToStripMinedShapePerm =
+      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
 
   // 3. Compute the stripMinedShape: this is the packed shape without outer and
   // inner permutations.
   SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
-  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
+  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
 
   // 4. Transpose packedShape to stripMinedShape.
   RankedTensorType stripMinedTensorType =
@@ -412,15 +402,15 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
 
-  // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
+  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
   // permutation.
   SmallVector<OpFoldResult, 4> dims =
       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
-  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+  applyPermutationToVector(dims, packedToStripMinedShapePerm);
   auto emptyOp = rewriter.create<tensor::EmptyOp>(
       loc, dims, stripMinedTensorType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
-      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
+      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
 
   LLVM_DEBUG(
       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
@@ -428,8 +418,8 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                       DBGS() << "packedShape: ");
       DBGSNL();
-      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
-                            DBGS() << "lastDimsToInsertPositionsPerm: ");
+      llvm::interleaveComma(packedToStripMinedShapePerm,
+                            DBGS() << "packedToStripMinedShapePerm: ");
       DBGSNL(); llvm::interleaveComma(
           packingMetadata.reassociations, DBGS() << "reassociations: ",
           [&](ReassociationIndices ri) {
diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
index 926969bfc7388..f34ef4f961483 100644
--- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir
@@ -622,9 +622,20 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
-func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
-  // expected-note @below {{target payload op}}
+// CHECK-LABEL: @unpack_with_outer_dims_perm
+//  CHECK-SAME: %[[ARG0:.*]]: tensor<32x64xf32>, %[[ARG1:.*]]: tensor<2x4x32x8xf32>
+//       CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<4x8x2x32xf32>
+//       CHECK: %[[TRAN:.*]] = linalg.transpose 
+//  CHECK-SAME:   ins(%[[ARG1]] : tensor<2x4x32x8xf32>) 
+//  CHECK-SAME:   outs(%[[EMPTY]] : tensor<4x8x2x32xf32>) 
+//  CHECK-SAME:   permutation = [1, 3, 0, 2]
+//       CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+//  CHECK-SAME:   : tensor<4x8x2x32xf32> into tensor<32x64xf32>
+//       CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [32, 64] [1, 1]
+//  CHECK-SAME:   : tensor<32x64xf32> to tensor<32x64xf32>
+//       CHECK: linalg.copy ins(%[[SLICE]]
+//  CHECK-SAME:   : tensor<32x64xf32>) outs(%[[ARG0]] : tensor<32x64xf32>) -> tensor<32x64xf32>
+func.func @unpack_with_outer_dims_perm(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0] 
     inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
   return %unpack : tensor<32x64xf32>
@@ -634,7 +645,6 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
       : (!transform.any_op) -> !transform.op<"tensor.unpack">
-    // expected-error @below {{cannot lower to transpose + collapse + extract}} 
     transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
       -> (!transform.op<"tensor.empty">,
           !transform.op<"linalg.transpose">,

Copy link
Member

@pashu123 pashu123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

This commit adds support for lowering `tensor.unpack` with a
non-identity `outer_dims_perm`. This was previously left as a
not-yet-implemented case.
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

Copy link
Contributor

@chelini chelini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ryan-holt-1 ryan-holt-1 merged commit 5b2f7a1 into llvm:main Jun 7, 2024
7 checks passed
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants