
[mlir][tensor] Rewrite tensor.pack as a constant #93954

Closed

Conversation

adam-smnk
Contributor

Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.

@llvmbot
Collaborator

llvmbot commented May 31, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Adam Siemieniuk (adam-smnk)

Changes

Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.


Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93954.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+154-1)
  • (modified) mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (+208)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..b63551c268ddc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -6,10 +6,13 @@
 //
 //===----------------------------------------------------------------------===//
 //
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Threading.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+    if (!constOp)
+      return failure();
+    // Must be a dense constant.
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+    if (!denseAttr)
+      return failure();
+
+    // Bail out if the pack is used as a write operation, i.e.,
+    // the destination is not a tensor.empty.
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return rewriter.notifyMatchFailure(packOp,
+                                         "expects empty tensor destination");
+    // Pack destination must have static shape.
+    if (!packOp.getDestType().hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          packOp, "expects destination with static shape");
+
+    // Pack with padding is not supported currently.
+    // TODO: Insert padding values as a part of rewrite.
+    if (packOp.getPaddingValue())
+      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // If it is a splat constant, rewrite the pack directly.
+    if (denseAttr.isSplat()) {
+      DenseElementsAttr packedDenseShape =
+          denseAttr.reshape(packOp.getDestType());
+      rewriter.setInsertionPoint(constOp);
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+      return success();
+    }
+
+    // Constant contains non-splat dense values.
+    // Move the data into a new packed buffer. Each value is placed into its new
+    // position as defined by the pack operation.
+    ArrayRef<char> srcRawData = denseAttr.getRawData();
+    SmallVector<char> destRawData(srcRawData.size());
+
+    int64_t numberOfElements = denseAttr.getNumElements();
+    SmallVector<int64_t> strides =
+        computeStrides(packOp.getDestType().getShape());
+
+    // Parallelize raw data movement to speed up large constant packing.
+    parallelFor(
+        packOp.getContext(), 0, numberOfElements,
+        [&](size_t destLinearizedIdx) {
+          // Step 1: De-linearize destination index.
+          // f(lin) = tmp[A][B][C]
+          SmallVector<int64_t> destIndices =
+              delinearize(destLinearizedIdx, strides);
+
+          // Step 2: Arrange the indexes based on the packing information.
+          // Compute inverse of outerDimsPerm to bring the loops into the
+          // canonical form tmp[A][B][a][b].
+          if (!packOp.getOuterDimsPerm().empty()) {
+            SmallVector<int64_t> inversePermutation =
+                invertPermutationVector(packOp.getOuterDimsPerm());
+            SmallVector<int64_t> tileLoops;
+            for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
+              tileLoops.push_back(destIndices[i]);
+            applyPermutationToVector(tileLoops, inversePermutation);
+
+            SmallVector<int64_t> pointLoops;
+            for (size_t i = packOp.getSourceType().getRank();
+                 i < destIndices.size(); i++) {
+              pointLoops.push_back(destIndices[i]);
+            }
+
+            destIndices = tileLoops;
+            destIndices.append(pointLoops.begin(), pointLoops.end());
+          }
+          assert(destIndices.size() ==
+                 static_cast<size_t>(packOp.getDestType().getRank()));
+
+          // After interchanging the outermost tiled loop we end up in the
+          // canonical form tmp[A][B][a][b]. Squash the point loops with the
+          // tiled ones.
+          llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
+                                             packOp.getInnerDimsPos().end());
+          llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
+          // Map the position of the tiled loops with the point one.
+          // For example:
+          // [A][B] -> [A][B][a][b]
+          // entry: [A : 0] [a : 2]
+          // entry: [B : 1] [b : 3]
+          // [A][B] -> [A][B][b]
+          // entry: [B : 1] [b : 2]
+          for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
+            mappingTileToPointLoops[tileLoop] = idx;
+
+          SmallVector<int64_t> srcIndices;
+          SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
+          int64_t numberOfTileLoops = packOp.getSourceType().getRank();
+          size_t tilePosIdx = 0;
+          for (int64_t i = 0; i < numberOfTileLoops; i++) {
+            if (!tiledLoops.count(i)) {
+              // Loop is not tiled.
+              srcIndices.push_back(destIndices[i]);
+            } else {
+              // Loop is tiled, account for the point loop distance.
+              srcIndices.push_back(
+                  destIndices[i] * tilesSizes[tilePosIdx] +
+                  destIndices[numberOfTileLoops + mappingTileToPointLoops[i]]);
+              tilePosIdx++;
+            }
+          }
+          assert(srcIndices.size() == static_cast<size_t>(numberOfTileLoops));
+
+          int64_t srcLinearizedIdx = linearize(
+              srcIndices, computeStrides(packOp.getSourceType().getShape()));
+          assert(srcLinearizedIdx < numberOfElements);
+
+          // Step 3: Do the packing.
+          // Copy the source element byte-wise to its packed destination
+          // position.
+          size_t elementByteSize =
+              denseAttr.getRawData().size() / denseAttr.getNumElements();
+          for (size_t i = 0; i < elementByteSize; i++) {
+            destRawData[destLinearizedIdx * elementByteSize + i] =
+                srcRawData[srcLinearizedIdx * elementByteSize + i];
+          }
+        });
+
+    // Fail gracefully if something went wrong.
+    bool detectSplat = false;
+    if (!DenseElementsAttr::isValidRawBuffer(packOp.getDestType(), destRawData,
+                                             detectSplat))
+      return rewriter.notifyMatchFailure(
+          packOp, "failed to create packed raw data buffer");
+
+    // Replace the pack with a new constant.
+    auto packedDenseShape =
+        DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
+    rewriter.setInsertionPoint(constOp);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateRewriteAsConstantPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<GenerateToConstant>(patterns.getContext());
+  patterns.add<GenerateToConstant, PackToConstant>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..045cb5a0da1d5 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,211 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
   } : tensor<2x3x5xf32>
   return %0 : tensor<2x3x5xf32>
 }
+
+// CHECK-LABEL: func.func @fold_pack_with_splat
+// CHECK: %[[CST:.+]] = arith.constant dense<1> : tensor<8x2x1x1x32x32xi64>
+// CHECK-NEXT: return %[[CST]] : tensor<8x2x1x1x32x32xi64>
+func.func @fold_pack_with_splat() ->  tensor<8x2x1x1x32x32xi64> {
+  %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
+  %0 = tensor.empty() : tensor<8x2x1x1x32x32xi64>
+  %pack = tensor.pack %cst outer_dims_perm = [3, 2, 0, 1] inner_dims_pos = [2, 3] inner_tiles = [32, 32]
+    into %0 : tensor<1x1x64x256xi64> -> tensor<8x2x1x1x32x32xi64>
+  return  %pack : tensor<8x2x1x1x32x32xi64>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x4x2xf32>
+func.func @fold_pack_with_non_splat() -> tensor<2x4x4x2xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<2x4x4x2xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+    into %0 : tensor<8x8xf32> -> tensor<2x4x4x2xf32>
+  return %pack : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_dims_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 8.000000e+00, 1.600000e+01, 2.400000e+01], [1.000000e+00, 9.000000e+00, 1.700000e+01, 2.500000e+01]
+// CHECK-SAME: [4.000000e+00, 1.200000e+01, 2.000000e+01, 2.800000e+01], [5.000000e+00, 1.300000e+01, 2.100000e+01, 2.900000e+01]
+// CHECK-SAME: [8.000000e+00, 1.600000e+01, 2.400000e+01, 3.200000e+01], [9.000000e+00, 1.700000e+01, 2.500000e+01, 3.300000e+01]
+// CHECK-SAME: [1.200000e+01, 2.000000e+01, 2.800000e+01, 3.600000e+01], [1.300000e+01, 2.100000e+01, 2.900000e+01, 3.700000e+01]
+// CHECK-SAME: [1.600000e+01, 2.400000e+01, 3.200000e+01, 4.000000e+01], [1.700000e+01, 2.500000e+01, 3.300000e+01, 4.100000e+01]
+// CHECK-SAME: [2.000000e+01, 2.800000e+01, 3.600000e+01, 4.400000e+01], [2.100000e+01, 2.900000e+01, 3.700000e+01, 4.500000e+01]
+// CHECK-SAME: [2.400000e+01, 3.200000e+01, 4.000000e+01, 4.900000e+01], [2.500000e+01, 3.300000e+01, 4.100000e+01, 5.000000e+01]
+// CHECK-SAME: [2.800000e+01, 3.600000e+01, 4.400000e+01, 5.300000e+01], [2.900000e+01, 3.700000e+01, 4.500000e+01, 5.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_dims_reordered() -> tensor<2x4x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<2x4x2x4xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+    into %0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+  return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_tiles_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00], [1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]
+// CHECK-SAME: [1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01], [2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01], [2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01, 3.400000e+01, 3.500000e+01], [4.000000e+01, 4.100000e+01, 4.200000e+01, 4.300000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01, 3.800000e+01, 3.900000e+01], [4.400000e+01, 4.500000e+01, 4.600000e+01, 4.700000e+01]
+// CHECK-SAME: [4.900000e+01, 5.000000e+01, 5.100000e+01, 5.200000e+01], [5.700000e+01, 5.800000e+01, 5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [5.300000e+01, 5.400000e+01, 5.500000e+01, 5.600000e+01], [6.100000e+01, 6.200000e+01, 6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_tiles_reordered() -> tensor<4x2x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<4x2x2x4xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+    into %0 : tensor<8x8xf32> -> tensor<4x2x2x4xf32>
+  return %pack : tensor<4x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_outer_permutation
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x4x2xf32>
+func.func @fold_pack_with_non_splat_with_outer_permutation() -> tensor<4x2x4x2xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %0 = tensor.empty() : tensor<4x2x4x2xf32>
+  %pack = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+    into %0 : tensor<8x8xf32> -> tensor<4x2x4x2xf32>
+  return %pack : tensor<4x2x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_and_outer
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [4.000000e+00, 5.000000e+00]
+// CHECK-SAME: [8.000000e+00, 9.000000e+00], [1.200000e+01, 1.300000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [6.000000e+00, 7.000000e+00]
+// CHECK-SAME: [1.000000e+01, 1.100000e+01], [1.400000e+01, 1.500000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<1x2x2x2x2xf32>
+func.func @fold_pack_with_non_splat_with_inner_and_outer_permutations() -> tensor<1x2x2x2x2xf32> {
+  %cst = arith.constant dense <[[[[0.0, 1.0, 2.0, 3.0],   [4.0, 5.0, 6.0, 7.0]],
+                                 [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]]]> : tensor<1x2x2x4xf32>
+  %0 = tensor.empty() : tensor<1x2x2x2x2xf32>
+  %1 = tensor.pack %cst outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+    into %0 : tensor<1x2x2x4xf32> -> tensor<1x2x2x2x2xf32>
+  return %1 : tensor<1x2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_pack_into_non_empty_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<2x4x2x4xf32>
+func.func @no_fold_pack_into_non_empty_with_non_splat(%arg0: tensor<2x4x2x4xf32>) -> tensor<2x4x2x4xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+                               [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+                               [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+  %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+    into %arg0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+  return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_dynamic_inner_tile_pack_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<?x4x2x?xf32>
+func.func @no_fold_dynamic_inner_tile_pack_with_non_splat(%outer: index, %tile: index) -> tensor<?x4x2x?xf32> {
+  %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+                               [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+                               [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+                               [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+                               [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+                               [40.0, 41....
[truncated]

@adam-smnk adam-smnk requested a review from chelini May 31, 2024 12:24
@sabauma
Contributor

sabauma commented May 31, 2024

@adam-smnk We've wanted something similar to this (fold away tensor.pack), but opted to improve the folding of linalg.transpose (#92589) and add a rewrite for tensor.pad here (#92691). This works out since tensor.pack lowers to a transpose + pad sequence. Maybe that would be sufficient to suit your use case?

Contributor

@hanhanW hanhanW left a comment


Some terms need to be updated in the PR; can you help rephrase them?

  1. There are no loops in tensor ops. They should be something like tiled data dimensions and outer data dimensions, similar to what is documented in the td file.
  2. What is a point loop? Should it just be a tiled data dimension?

Comment on lines +66 to +74
// Bail out if the pack is used as a write operation, i.e.,
// the destination is not a tensor.empty.
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
  return rewriter.notifyMatchFailure(packOp,
                                     "expects empty tensor destination");
// Pack destination must have static shape.
if (!packOp.getDestType().hasStaticShape())
  return rewriter.notifyMatchFailure(
      packOp, "expects destination with static shape");

I think we can drop the comments because the failure messages already document it for us.

// Pack with padding is not supported currently.
// TODO: Insert padding values as a part of rewrite.
if (packOp.getPaddingValue())
  return rewriter.notifyMatchFailure(packOp, "expects no padding value");

nit: perhaps say that it is NIY (not implemented yet) in the failure message.

Comment on lines +83 to +91
// If it is a splat constant, rewrite the pack directly.
if (denseAttr.isSplat()) {
  DenseElementsAttr packedDenseShape =
      denseAttr.reshape(packOp.getDestType());
  rewriter.setInsertionPoint(constOp);
  rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);

  return success();
}

This case is already covered by the op's folder:

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}

Comment on lines +116 to +117
SmallVector<int64_t> inversePermutation =
    invertPermutationVector(packOp.getOuterDimsPerm());

nit: move it before where it is used, i.e., line 120.

SmallVector<int64_t> inversePermutation =
    invertPermutationVector(packOp.getOuterDimsPerm());
SmallVector<int64_t> tileLoops;
for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)

nit: we can use packOp.getSourceRank().

Comment on lines +123 to +127
SmallVector<int64_t> pointLoops;
for (size_t i = packOp.getSourceType().getRank();
     i < destIndices.size(); i++) {
  pointLoops.push_back(destIndices[i]);
}

I think we can simplify it by using llvm::to_vector + llvm::seq<int64_t>. Also we can use getSourceRank().

Suggested change

SmallVector<int64_t> pointLoops;
for (size_t i = packOp.getSourceType().getRank();
     i < destIndices.size(); i++) {
  pointLoops.push_back(destIndices[i]);
}

SmallVector<int64_t> pointLoops = llvm::to_vector(
    llvm::seq<int64_t>(packOp.getSourceRank(), destIndices.size()));

[optional] Using packOp.getDestRank() instead of destIndices.size() is slightly better to me because it directly ties to the pack op, so people don't need to look up what destIndices is.

Comment on lines +129 to +130
destIndices = tileLoops;
destIndices.append(pointLoops.begin(), pointLoops.end());

After reviewing the changes, can you check whether getPackInverseDestPerm + applyPermutationToVector works?

SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp);

/// Apply the permutation defined by `permutation` to `inVec`.
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
/// 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
                              ArrayRef<int64_t> permutation) {
  inVec = applyPermutation(inVec, permutation);
}

  destIndices.append(pointLoops.begin(), pointLoops.end());
}
assert(destIndices.size() ==
       static_cast<size_t>(packOp.getDestType().getRank()));

nit: use getDestRank() method.

Comment on lines +141 to +149
// Map the position of the tiled loops with the point one.
// For example:
// [A][B] -> [A][B][a][b]
// entry: [A : 0] [a : 2]
// entry: [B : 1] [b : 3]
// [A][B] -> [A][B][b]
// entry: [B : 1] [b : 2]
for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
mappingTileToPointLoops[tileLoop] = idx;

I think the comment is off? Do we have entries for inner dims? All the values in getInnerDimsPos are less than the rank of the source, so we won't have [a : 2], [b : 3], and [b : 2] entries. Do I misunderstand something?


SmallVector<int64_t> srcIndices;
SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
int64_t numberOfTileLoops = packOp.getSourceType().getRank();

nit: getSourceRank()

@adam-smnk
Contributor Author

@sabauma That's a great approach too. It could be enough; I'll have to try it out. Thanks!

@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};

/// Rewrite tensor.pack with arith.constant if the pack is writing
/// to an empty tensor and the destination shape is static.
struct PackToConstant : OpRewritePattern<tensor::PackOp> {

Isn't the expectation that these patterns come with a way to control when they get applied?

@adam-smnk
Contributor Author

@sabauma Thanks for the pointers. The Linalg constant folders are indeed sufficient.
I recreated this tensor.pack rewrite functionality by temporarily lowering constant packs, and then running constant folding and canonicalization patterns. Now only the op selection is custom and the rest is powered by upstream logic.

For reference, the implementation: plaidml/tpp-mlir#921

Then I think we can close this PR to avoid duplicating functionality.

@adam-smnk adam-smnk closed this Jun 7, 2024