-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Rewrite tensor.pack as a constant #93954
Conversation
Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor Author: Adam Siemieniuk (adam-smnk) Changes: Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor. Patch is 22.72 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/93954.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..b63551c268ddc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -6,10 +6,13 @@
//
//===----------------------------------------------------------------------===//
//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Threading.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};
+/// Rewrite tensor.pack with arith.constant if the pack is writing
+/// to an empty tensor and the destination shape is static.
+struct PackToConstant : OpRewritePattern<tensor::PackOp> {
+ using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto constOp = packOp.getSource().getDefiningOp<arith::ConstantOp>();
+ if (!constOp)
+ return failure();
+ // Must be a dense constant.
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
+ if (!denseAttr)
+ return failure();
+
+ // Bail out if the pack is used as a writing operation i.e.,
+ // the destination is not a tensor.empty.
+ if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+ return rewriter.notifyMatchFailure(packOp,
+ "expects empty tensor destination");
+ // Pack destination must have static shape.
+ if (!packOp.getDestType().hasStaticShape())
+ return rewriter.notifyMatchFailure(
+ packOp, "expects destination with static shape");
+
+ // Pack with padding is not supported currently.
+ // TODO: Insert padding values as a part of rewrite.
+ if (packOp.getPaddingValue())
+ return rewriter.notifyMatchFailure(packOp, "expects no padding value");
+
+ OpBuilder::InsertionGuard guard(rewriter);
+
+ // If it is a splat constant, rewrite the pack directly.
+ if (denseAttr.isSplat()) {
+ DenseElementsAttr packedDenseShape =
+ denseAttr.reshape(packOp.getDestType());
+ rewriter.setInsertionPoint(constOp);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+ return success();
+ }
+
+ // Constant contains non-splat dense values.
+ // Move the data into a new packed buffer. Each value is placed into its new
+ // position as defined by the pack operation.
+ ArrayRef<char> srcRawData = denseAttr.getRawData();
+ SmallVector<char> destRawData(srcRawData.size());
+
+ int64_t numberOfElements = denseAttr.getNumElements();
+ SmallVector<int64_t> strides =
+ computeStrides(packOp.getDestType().getShape());
+
+ // Parallelize raw data movement to speedup large constant packing.
+ parallelFor(
+ packOp.getContext(), 0, numberOfElements,
+ [&](size_t destLinearizedIdx) {
+ // Step 1: De-linearize destination index.
+ // f(lin) = tmp[A][B][C]
+ SmallVector<int64_t> destIndices =
+ delinearize(destLinearizedIdx, strides);
+
+ // Step 2: Arrange the indexes based on the packing information.
+ // Compute inverse of outerDimsPerm to bring the loops into the
+ // canonical form tmp[A][B][a][b].
+ if (!packOp.getOuterDimsPerm().empty()) {
+ SmallVector<int64_t> inversePermutation =
+ invertPermutationVector(packOp.getOuterDimsPerm());
+ SmallVector<int64_t> tileLoops;
+ for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++)
+ tileLoops.push_back(destIndices[i]);
+ applyPermutationToVector(tileLoops, inversePermutation);
+
+ SmallVector<int64_t> pointLoops;
+ for (size_t i = packOp.getSourceType().getRank();
+ i < destIndices.size(); i++) {
+ pointLoops.push_back(destIndices[i]);
+ }
+
+ destIndices = tileLoops;
+ destIndices.append(pointLoops.begin(), pointLoops.end());
+ }
+ assert(destIndices.size() ==
+ static_cast<size_t>(packOp.getDestType().getRank()));
+
+ // After interchanging the outermost tiled loop we end up in the
+ // canonical form tmp[A][B][a][b]. Squash the point loops with the
+ // tiled ones.
+ llvm::DenseSet<int64_t> tiledLoops(packOp.getInnerDimsPos().begin(),
+ packOp.getInnerDimsPos().end());
+ llvm::DenseMap<int64_t, int64_t> mappingTileToPointLoops;
+ // Map the position of the tiled loops with the point one.
+ // For example:
+ // [A][B] -> [A][B][a][b]
+ // entry: [A : 0] [a : 2]
+ // entry: [B : 1] [b : 3]
+ // [A][B] -> [A][B][b]
+ // entry: [B : 1] [b : 2]
+ for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos()))
+ mappingTileToPointLoops[tileLoop] = idx;
+
+ SmallVector<int64_t> srcIndices;
+ SmallVector<int64_t> tilesSizes = packOp.getStaticTiles();
+ int64_t numberOfTileLoops = packOp.getSourceType().getRank();
+ size_t tilePosIdx = 0;
+ for (int64_t i = 0; i < numberOfTileLoops; i++) {
+ if (!tiledLoops.count(i)) {
+ // Loop is not tiled.
+ srcIndices.push_back(destIndices[i]);
+ } else {
+ // Loop is tiled, account for the point loop distance.
+ srcIndices.push_back(
+ destIndices[i] * tilesSizes[tilePosIdx] +
+ destIndices[numberOfTileLoops + mappingTileToPointLoops[i]]);
+ tilePosIdx++;
+ }
+ }
+ assert(srcIndices.size() == static_cast<size_t>(numberOfTileLoops));
+
+ int64_t srcLinearizedIdx = linearize(
+ srcIndices, computeStrides(packOp.getSourceType().getShape()));
+ assert(srcLinearizedIdx < numberOfElements);
+
+ // Step 3: Do the packing.
+ // Copy the source element byte-wise to its packed destination
+ // position.
+ size_t elementByteSize =
+ denseAttr.getRawData().size() / denseAttr.getNumElements();
+ for (size_t i = 0; i < elementByteSize; i++) {
+ destRawData[destLinearizedIdx * elementByteSize + i] =
+ srcRawData[srcLinearizedIdx * elementByteSize + i];
+ }
+ });
+
+ // Fail gracefully if something went wrong.
+  bool detectSplat = false;
+  if (!DenseElementsAttr::isValidRawBuffer(packOp.getDestType(), destRawData,
+                                           detectSplat))
+ return rewriter.notifyMatchFailure(
+ packOp, "failed to create packed raw data buffer");
+
+ // Replace the pack with a new constant.
+ auto packedDenseShape =
+ DenseElementsAttr::getFromRawBuffer(packOp.getDestType(), destRawData);
+ rewriter.setInsertionPoint(constOp);
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::tensor::populateRewriteAsConstantPatterns(
RewritePatternSet &patterns) {
- patterns.add<GenerateToConstant>(patterns.getContext());
+ patterns.add<GenerateToConstant, PackToConstant>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..045cb5a0da1d5 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,211 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
} : tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}
+
+// CHECK-LABEL: func.func @fold_pack_with_splat
+// CHECK: %[[CST:.+]] = arith.constant dense<1> : tensor<8x2x1x1x32x32xi64>
+// CHECK-NEXT: return %[[CST]] : tensor<8x2x1x1x32x32xi64>
+func.func @fold_pack_with_splat() -> tensor<8x2x1x1x32x32xi64> {
+ %cst = arith.constant dense<1> : tensor<1x1x64x256xi64>
+ %0 = tensor.empty() : tensor<8x2x1x1x32x32xi64>
+ %pack = tensor.pack %cst outer_dims_perm = [3, 2, 0, 1] inner_dims_pos = [2, 3] inner_tiles = [32, 32]
+ into %0 : tensor<1x1x64x256xi64> -> tensor<8x2x1x1x32x32xi64>
+ return %pack : tensor<8x2x1x1x32x32xi64>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x4x2xf32>
+func.func @fold_pack_with_non_splat() -> tensor<2x4x4x2xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<2x4x4x2xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+ into %0 : tensor<8x8xf32> -> tensor<2x4x4x2xf32>
+ return %pack : tensor<2x4x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_dims_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 8.000000e+00, 1.600000e+01, 2.400000e+01], [1.000000e+00, 9.000000e+00, 1.700000e+01, 2.500000e+01]
+// CHECK-SAME: [4.000000e+00, 1.200000e+01, 2.000000e+01, 2.800000e+01], [5.000000e+00, 1.300000e+01, 2.100000e+01, 2.900000e+01]
+// CHECK-SAME: [8.000000e+00, 1.600000e+01, 2.400000e+01, 3.200000e+01], [9.000000e+00, 1.700000e+01, 2.500000e+01, 3.300000e+01]
+// CHECK-SAME: [1.200000e+01, 2.000000e+01, 2.800000e+01, 3.600000e+01], [1.300000e+01, 2.100000e+01, 2.900000e+01, 3.700000e+01]
+// CHECK-SAME: [1.600000e+01, 2.400000e+01, 3.200000e+01, 4.000000e+01], [1.700000e+01, 2.500000e+01, 3.300000e+01, 4.100000e+01]
+// CHECK-SAME: [2.000000e+01, 2.800000e+01, 3.600000e+01, 4.400000e+01], [2.100000e+01, 2.900000e+01, 3.700000e+01, 4.500000e+01]
+// CHECK-SAME: [2.400000e+01, 3.200000e+01, 4.000000e+01, 4.900000e+01], [2.500000e+01, 3.300000e+01, 4.100000e+01, 5.000000e+01]
+// CHECK-SAME: [2.800000e+01, 3.600000e+01, 4.400000e+01, 5.300000e+01], [2.900000e+01, 3.700000e+01, 4.500000e+01, 5.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<2x4x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_dims_reordered() -> tensor<2x4x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<2x4x2x4xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+ into %0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+ return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_tiles_reordered
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [8.000000e+00, 9.000000e+00, 1.000000e+01, 1.100000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00], [1.200000e+01, 1.300000e+01, 1.400000e+01, 1.500000e+01]
+// CHECK-SAME: [1.600000e+01, 1.700000e+01, 1.800000e+01, 1.900000e+01], [2.400000e+01, 2.500000e+01, 2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [2.000000e+01, 2.100000e+01, 2.200000e+01, 2.300000e+01], [2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01, 3.400000e+01, 3.500000e+01], [4.000000e+01, 4.100000e+01, 4.200000e+01, 4.300000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01, 3.800000e+01, 3.900000e+01], [4.400000e+01, 4.500000e+01, 4.600000e+01, 4.700000e+01]
+// CHECK-SAME: [4.900000e+01, 5.000000e+01, 5.100000e+01, 5.200000e+01], [5.700000e+01, 5.800000e+01, 5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [5.300000e+01, 5.400000e+01, 5.500000e+01, 5.600000e+01], [6.100000e+01, 6.200000e+01, 6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x2x4xf32>
+func.func @fold_pack_with_non_splat_with_inner_tiles_reordered() -> tensor<4x2x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<4x2x2x4xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 4]
+ into %0 : tensor<8x8xf32> -> tensor<4x2x2x4xf32>
+ return %pack : tensor<4x2x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_outer_permutation
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [8.000000e+00, 9.000000e+00], [1.600000e+01, 1.700000e+01], [2.400000e+01, 2.500000e+01]
+// CHECK-SAME: [3.200000e+01, 3.300000e+01], [4.000000e+01, 4.100000e+01], [4.900000e+01, 5.000000e+01], [5.700000e+01, 5.800000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [1.000000e+01, 1.100000e+01], [1.800000e+01, 1.900000e+01], [2.600000e+01, 2.700000e+01]
+// CHECK-SAME: [3.400000e+01, 3.500000e+01], [4.200000e+01, 4.300000e+01], [5.100000e+01, 5.200000e+01], [5.900000e+01, 6.000000e+01]
+// CHECK-SAME: [4.000000e+00, 5.000000e+00], [1.200000e+01, 1.300000e+01], [2.000000e+01, 2.100000e+01], [2.800000e+01, 2.900000e+01]
+// CHECK-SAME: [3.600000e+01, 3.700000e+01], [4.400000e+01, 4.500000e+01], [5.300000e+01, 5.400000e+01], [6.100000e+01, 6.200000e+01]
+// CHECK-SAME: [6.000000e+00, 7.000000e+00], [1.400000e+01, 1.500000e+01], [2.200000e+01, 2.300000e+01], [3.000000e+01, 3.100000e+01]
+// CHECK-SAME: [3.800000e+01, 3.900000e+01], [4.600000e+01, 4.700000e+01], [5.500000e+01, 5.600000e+01], [6.300000e+01, 6.400000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<4x2x4x2xf32>
+func.func @fold_pack_with_non_splat_with_outer_permutation() -> tensor<4x2x4x2xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %0 = tensor.empty() : tensor<4x2x4x2xf32>
+ %pack = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 2]
+ into %0 : tensor<8x8xf32> -> tensor<4x2x4x2xf32>
+ return %pack : tensor<4x2x4x2xf32>
+}
+
+// CHECK-LABEL: func.func @fold_pack_with_non_splat_with_inner_and_outer
+// CHECK: %[[CST:.+]] = arith.constant
+// CHECK-SAME: [0.000000e+00, 1.000000e+00], [4.000000e+00, 5.000000e+00]
+// CHECK-SAME: [8.000000e+00, 9.000000e+00], [1.200000e+01, 1.300000e+01]
+// CHECK-SAME: [2.000000e+00, 3.000000e+00], [6.000000e+00, 7.000000e+00]
+// CHECK-SAME: [1.000000e+01, 1.100000e+01], [1.400000e+01, 1.500000e+01]
+// CHECK-NOT: tensor.pack
+// CHECK: return %[[CST]] : tensor<1x2x2x2x2xf32>
+func.func @fold_pack_with_non_splat_with_inner_and_outer_permutations() -> tensor<1x2x2x2x2xf32> {
+ %cst = arith.constant dense <[[[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]],
+ [[8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]]]> : tensor<1x2x2x4xf32>
+ %0 = tensor.empty() : tensor<1x2x2x2x2xf32>
+ %1 = tensor.pack %cst outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [2]
+ into %0 : tensor<1x2x2x4xf32> -> tensor<1x2x2x2x2xf32>
+ return %1 : tensor<1x2x2x2x2xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_pack_into_non_empty_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<2x4x2x4xf32>
+func.func @no_fold_pack_into_non_empty_with_non_splat(%arg0: tensor<2x4x2x4xf32>) -> tensor<2x4x2x4xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0],
+ [49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0],
+ [57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0]]> : tensor<8x8xf32>
+ %pack = tensor.pack %cst inner_dims_pos = [1, 0] inner_tiles = [2, 4]
+ into %arg0 : tensor<8x8xf32> -> tensor<2x4x2x4xf32>
+ return %pack : tensor<2x4x2x4xf32>
+}
+
+// CHECK-LABEL: func.func @no_fold_dynamic_inner_tile_pack_with_non_splat
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK: return %[[PACK]] : tensor<?x4x2x?xf32>
+func.func @no_fold_dynamic_inner_tile_pack_with_non_splat(%outer: index, %tile: index) -> tensor<?x4x2x?xf32> {
+ %cst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ [8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0],
+ [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0],
+ [24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0],
+ [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0],
+ [40.0, 41....
[truncated]
|
@adam-smnk We've wanted something similar to this (fold away tensor.pack), but opted to improve the folding of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some terms need to be updated in the PR; can you help rephrase them?
- There are no loops in tensor ops. They should be something like `tiled data dimensions` and `outer data dimensions`, which are similar to what is documented in the `.td` file.
- What is a `point loop`? Should they just be `tiled data dimension`?
// Bail out if the pack is used as a writing operation i.e., | ||
// the destination is not a tensor.empty. | ||
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) | ||
return rewriter.notifyMatchFailure(packOp, | ||
"expects empty tensor destination"); | ||
// Pack destination must have static shape. | ||
if (!packOp.getDestType().hasStaticShape()) | ||
return rewriter.notifyMatchFailure( | ||
packOp, "expects destination with static shape"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can drop the comments because the failure message already document it for us.
// Pack with padding is not supported currently. | ||
// TODO: Insert padding values as a part of rewrite. | ||
if (packOp.getPaddingValue()) | ||
return rewriter.notifyMatchFailure(packOp, "expects no padding value"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: perhaps say that it is NIY (not implemented yet) in the failure message.
// If it is a splat constant, rewrite the pack directly. | ||
if (denseAttr.isSplat()) { | ||
DenseElementsAttr packedDenseShape = | ||
denseAttr.reshape(packOp.getDestType()); | ||
rewriter.setInsertionPoint(constOp); | ||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(packOp, packedDenseShape); | ||
|
||
return success(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case is already covered in folders.
llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Lines 4278 to 4287 in 07bd439
OpFoldResult PackOp::fold(FoldAdaptor adaptor) { | |
std::optional<Attribute> paddingValue; | |
if (auto pad = adaptor.getPaddingValue()) | |
paddingValue = pad; | |
if (OpFoldResult reshapedSource = reshapeConstantSource( | |
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()), | |
getDestType(), paddingValue)) | |
return reshapedSource; | |
return {}; | |
} |
SmallVector<int64_t> inversePermutation = | ||
invertPermutationVector(packOp.getOuterDimsPerm()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move it before where it is used, i.e., line 120.
SmallVector<int64_t> inversePermutation = | ||
invertPermutationVector(packOp.getOuterDimsPerm()); | ||
SmallVector<int64_t> tileLoops; | ||
for (int64_t i = 0; i < packOp.getSourceType().getRank(); i++) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we can use packOp.getSourceRank()
.
SmallVector<int64_t> pointLoops; | ||
for (size_t i = packOp.getSourceType().getRank(); | ||
i < destIndices.size(); i++) { | ||
pointLoops.push_back(destIndices[i]); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can simplify it by using llvm::to_vector
+ llvm::seq<int64_t>
. Also we can use getSourceRank()
.
SmallVector<int64_t> pointLoops; | |
for (size_t i = packOp.getSourceType().getRank(); | |
i < destIndices.size(); i++) { | |
pointLoops.push_back(destIndices[i]); | |
} | |
SmallVector<int64_t> pointLoops = llvm::to_vector(llvm::seq<int64_t>(packOp.getSourceRank(), destIndices.size()));
[optional] using packOp.getDestRank()
instead of destIndices.size()
is slightly better to me. Because it directly ties to the pack op, so people don't need to look at what destIndices
is.
destIndices = tileLoops; | ||
destIndices.append(pointLoops.begin(), pointLoops.end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After reviewing the changes, can you try if getPackInverseDestPerm
+ applyPermutationToVector
works?
SmallVector<int64_t> getPackInverseDestPerm(tensor::PackOp packOp); |
llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
Lines 217 to 226 in 13b6284
/// Apply the permutation defined by `permutation` to `inVec`. | |
/// Element `i` in `inVec` is mapped to location `j = permutation[i]`. | |
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation | |
/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', | |
/// 'b']`. | |
template <typename T, unsigned N> | |
void applyPermutationToVector(SmallVector<T, N> &inVec, | |
ArrayRef<int64_t> permutation) { | |
inVec = applyPermutation(inVec, permutation); | |
} |
destIndices.append(pointLoops.begin(), pointLoops.end()); | ||
} | ||
assert(destIndices.size() == | ||
static_cast<size_t>(packOp.getDestType().getRank())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use getDestRank()
method.
// Map the position of the tiled loops with the point one. | ||
// For example: | ||
// [A][B] -> [A][B][a][b] | ||
// entry: [A : 0] [a : 2] | ||
// entry: [B : 1] [b : 3] | ||
// [A][B] -> [A][B][b] | ||
// entry: [B : 1] [b : 2] | ||
for (auto [idx, tileLoop] : llvm::enumerate(packOp.getInnerDimsPos())) | ||
mappingTileToPointLoops[tileLoop] = idx; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the comment is off? Do we have entries for inner dims? All the values in getInnerDimsPos
are less than the rank of source, so we won't have [a : 2]
, [b : 3]
and [b : 2] entries. Do I misunderstand something?
|
||
SmallVector<int64_t> srcIndices; | ||
SmallVector<int64_t> tilesSizes = packOp.getStaticTiles(); | ||
int64_t numberOfTileLoops = packOp.getSourceType().getRank(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: getSourceRank()
@sabauma That's a great approach too. It could be enough, I'll have to try it out. Thanks! |
@@ -45,9 +48,159 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> { | |||
} | |||
}; | |||
|
|||
/// Rewrite tensor.pack with arith.constant if the pack is writing | |||
/// to an empty tensor and the destination shape is static. | |||
struct PackToConstant : OpRewritePattern<tensor::PackOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isnt the expectation that these patterns come with a way to control when these patterns get applied?
@sabauma Thanks for the pointers. The Linalg constant folders are indeed sufficient. For reference, the implementation: plaidml/tpp-mlir#921 Then I think we can close this PR to avoid duplicating functionality. |
Adds a pattern to rewrite tensor.pack into arith.constant to avoid runtime packing of a constant tensor.