diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td
index 93c8a73ca..3a5c229fa 100644
--- a/include/TPP/PassBundles.td
+++ b/include/TPP/PassBundles.td
@@ -37,7 +37,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
            "unsigned", "Grid-sizes for parallel tasks.">,
     Option<"linalgToVector", "linalg-to-vector",
            "bool", /*default=*/"false",
-           "Lower linalg directly to vector.">
+           "Lower linalg directly to vector.">,
+    Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
+           "bool", /*default=*/"false",
+           "Lower non-constant packs and unpacks reverting any dim permutations.">
   ];
 }
 
@@ -51,6 +54,11 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> {
                            "memref::MemRefDialect",
                            "scf::SCFDialect",
                            "tensor::TensorDialect"];
+  let options = [
+    Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
+           "bool", /*default=*/"false",
+           "Lower non-constant packs and unpacks reverting any dim permutations.">
+  ];
 }
 
 def LinalgLowering : Pass<"linalg-lowering", "func::FuncOp"> {
diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index a1decd3ee..cd9f5c0da 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -177,6 +177,12 @@ def LowerPacksAndUnPacks : Pass<"lower-packs-unpacks", "func::FuncOp"> {
                            "tensor::TensorDialect"];
 }
 
+def LowerPacksAndUnpacksWithoutTranspose : Pass<"lower-packs-unpacks-without-transpose",
+                                                "ModuleOp"> {
+  let dependentDialects = ["linalg::LinalgDialect",
+                           "tensor::TensorDialect"];
+}
+
 def RewriteConvToMatmulOrBrgemm : Pass<"rewrite-conv-to-matmul-or-brgemm",
                                        "func::FuncOp"> {
   let summary = "Rewrite Conv2DNhwcHwcfOp/Conv2DNchwFchwOp to Matmul or Brgemm.";
diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp
index 2e826869d..f4ec69859 100644
--- a/lib/TPP/DefaultPipeline.cpp
+++ b/lib/TPP/DefaultPipeline.cpp
@@ -57,6 +57,11 @@
 llvm::cl::opt<bool> linalgToVector("linalg-to-vector",
                                    llvm::cl::desc("Lower linalg to vector"),
                                    llvm::cl::init(false));
 
+llvm::cl::opt<bool> lowerPackUnpackWithoutTranspose(
+    "lower-pack-unpack-without-transpose",
+    llvm::cl::desc("Lower packs and unpacks reverting any dim permutations"),
+    llvm::cl::init(false));
+
 namespace mlir {
 namespace tpp {
 #define GEN_PASS_DEF_DEFAULTPIPELINE
@@ -128,8 +133,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
       pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
     } else {
       // Apply the default preprocessing pass
-      DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, parallelTaskGrid,
-                                                linalgToVector};
+      DefaultTppPassesOptions tppDefaultOptions{
+          linalgToLoops, parallelTaskGrid, linalgToVector,
+          lowerPackUnpackWithoutTranspose};
       pm.addPass(createDefaultTppPasses(tppDefaultOptions));
     }
 
diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp
index b1a0cd83a..11b394a93 100644
--- a/lib/TPP/DefaultTppPasses.cpp
+++ b/lib/TPP/DefaultTppPasses.cpp
@@ -85,7 +85,9 @@ struct DefaultTppPasses
     pm.addPass(createRewriteBatchMatmulToMatmul());
 
     // Applies a set of passes at the linalg level to fuse and pack.
-    pm.addPass(createTppMapping());
+    TppMappingOptions tppMappingOptions{
+        lowerPackUnpackWithoutTranspose};
+    pm.addPass(createTppMapping(tppMappingOptions));
 
     // Generalize tensor.pack and tensor.unpack.
     pm.addPass(createLowerPacksAndUnPacks());
diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp
index d39b62530..e74e368e2 100644
--- a/lib/TPP/PassBundles/TppMapping.cpp
+++ b/lib/TPP/PassBundles/TppMapping.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
 
 #include "TPP/PassUtils.h"
 
@@ -63,6 +64,9 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
     pm.addPass(createPackMatmul());
     pm.addPass(createPackVNNI());
 
+    if (lowerPackUnpackWithoutTranspose) {
+      pm.addPass(createLowerPacksAndUnpacksWithoutTranspose());
+    }
     // Postprocess packing.
     // Run only canonicalizer at this stage as full cleanup (mostly CSE) can
    // mess up tensor producer-consumer chains used for analysis in the
diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt
index 827aeb154..6a1e83fca 100644
--- a/lib/TPP/Transforms/CMakeLists.txt
+++ b/lib/TPP/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(TPPTransforms
   DecomposeAggregatedOps.cpp
   LinalgDeGeneralize.cpp
   LowerPacksAndUnpacks.cpp
+  LowerPacksAndUnpacksWithoutTranspose.cpp
   RewriteBatchMatmulToMatmul.cpp
   RewriteConvsToMatmulOrBrgemm.cpp
   RewriteConvToMatmulImpl.cpp
diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp
new file mode 100644
index 000000000..57b2b6030
--- /dev/null
+++ b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp
@@ -0,0 +1,202 @@
+//===- LowerPacksAndUnpacksWithoutTranspose.cpp ------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TPP/Passes.h"
+#include "TPP/Transforms/Utils/ValueUtils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include <functional>
+#include <numeric>
+
+using namespace mlir;
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DEF_LOWERPACKSANDUNPACKSWITHOUTTRANSPOSE
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+namespace {
+
+/// Wrapper around linalg::lowerPack which undoes the transpose that might have
+/// happened. The single-user genericOp's indexing_maps are corrected
+/// accordingly.
+void lowerPackAndFoldTranspose(tensor::PackOp packOp,
+                               linalg::GenericOp genericOp, uint operandIdx,
+                               PatternRewriter &rewriter) {
+  auto packInversionPerm = tensor::getPackInverseDestPerm(packOp);
+
+  auto res = linalg::lowerPack(rewriter, packOp);
+
+  if (res->transposeOp) {
+    // Forget about the permutation of the dims on expandShapeOp.
+    rewriter.replaceAllOpUsesWith(res->transposeOp, res->expandShapeOp);
+
+    // Invert corresponding transposed accesses by the single user, genericOp.
+    auto indexingMaps = genericOp.getIndexingMapsArray();
+    auto packInverseMap =
+        AffineMap::getPermutationMap(packInversionPerm, rewriter.getContext());
+    auto normalizedIndexingMap =
+        packInverseMap.compose(indexingMaps[operandIdx]);
+
+    SmallVector<Attribute> maps = llvm::to_vector(genericOp.getIndexingMaps());
+    maps[operandIdx] = AffineMapAttr::get(normalizedIndexingMap);
+    genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps));
+  }
+}
+
+struct LowerPackOnInputsFoldingTranspose
+    : public OpRewritePattern<linalg::GenericOp> {
+  // Is only called with single-user packOp operands, so the callback can
+  // always find the (use by the) linalg.generic that is the target of the
+  // pattern.
+  using ControlFn = std::function<bool(tensor::PackOp)>;
+  ControlFn controlFn;
+
+  LowerPackOnInputsFoldingTranspose(MLIRContext *context,
+                                    ControlFn controlFn = nullptr,
+                                    PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit),
+        controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    bool modifiedAnOperand = false;
+    for (auto &&[operandIdx, inOperand] :
+         llvm::enumerate(genericOp.getInputs())) {
+      auto packOp =
+          dyn_cast_if_present<tensor::PackOp>(inOperand.getDefiningOp());
+
+      if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp)))
+        continue;
+
+      lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);
+
+      modifiedAnOperand = true;
+    }
+
+    return modifiedAnOperand ? success() : failure();
+  }
+};
+
+struct LowerPackUnpackOnOutputFoldingTranspose
+    : public OpRewritePattern<linalg::GenericOp> {
+  // Is only called with single-user packOp operands, so the callback can
+  // always find the (use by the) linalg.generic that is the target of the
+  // pattern.
+  using ControlFn = std::function<bool(tensor::PackOp, tensor::UnPackOp)>;
+  ControlFn controlFn;
+
+  LowerPackUnpackOnOutputFoldingTranspose(MLIRContext *context,
+                                          ControlFn controlFn = nullptr,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit),
+        controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    bool modifiedAnOperand = false;
+    size_t numInputs = genericOp.getInputs().size();
+    for (auto &&[outOperandIdx, outOperand] :
+         llvm::enumerate(genericOp.getOutputs())) {
+      size_t operandIdx = numInputs + outOperandIdx;
+      auto result = genericOp->getResult(outOperandIdx);
+
+      if (!result.hasOneUse())
+        continue;
+
+      auto packOp =
+          dyn_cast_if_present<tensor::PackOp>(outOperand.getDefiningOp());
+      auto unpackOp =
+          llvm::dyn_cast<tensor::UnPackOp>(*(result.getUsers().begin()));
+
+      if (!packOp || !packOp->hasOneUse() || !unpackOp)
+        continue;
+
+      // Normalize empty outer_dims_perm to its corresponding identity perm.
+      auto packOuterDimsPerm = SmallVector<int64_t>(packOp.getOuterDimsPerm());
+      if (packOuterDimsPerm.empty()) {
+        packOuterDimsPerm =
+            SmallVector<int64_t>(packOp.getSource().getType().getRank());
+        std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.end(), 0);
+      }
+      auto unpackOuterDimsPerm =
+          SmallVector<int64_t>(unpackOp.getOuterDimsPerm());
+      if (unpackOuterDimsPerm.empty()) {
+        unpackOuterDimsPerm =
+            SmallVector<int64_t>(unpackOp.getResult().getType().getRank());
+        std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.end(), 0);
+      }
+
+      if (unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() ||
+          packOuterDimsPerm != unpackOuterDimsPerm ||
+          (controlFn && !controlFn(packOp, unpackOp)))
+        continue;
+
+      auto unpackDest = unpackOp.getDest();
+      bool destHasStaticShape = unpackDest.getType().hasStaticShape();
+
+      lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);
+      auto res = linalg::lowerUnPack(rewriter, unpackOp);
+
+      // Set genericOp's result type to the adjusted type of the out parameter.
+      result.setType(genericOp.getOperand(operandIdx).getType());
+
+      if (auto transposeOp = res->transposeOp) {
+        // Forget about the transpose introduced by lowerUnPack.
+        rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput());
+      }
+
+      // lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest.
+      // As we ignore permutations and, in the static case, don't do padding,
+      // we know the underlying buffer will be used as is and hence we do not
+      // need to specify a dest to update into.
+      auto extractSliceOp = res->extractSliceOp;
+      if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) {
+        auto copyOp =
+            dyn_cast<linalg::CopyOp>(*extractSliceOp->getUsers().begin());
+        if (copyOp && copyOp.getOutputs()[0] == unpackDest) {
+          rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]);
+        }
+      }
+      modifiedAnOperand = true;
+    }
+
+    return modifiedAnOperand ? success() : failure();
+  }
+};
+
+struct LowerPacksAndUnpacksWithoutTranspose
+    : public tpp::impl::LowerPacksAndUnpacksWithoutTransposeBase<
+          LowerPacksAndUnpacksWithoutTranspose> {
+
+  void runOnOperation() override {
+    auto *ctx = &getContext();
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<LowerPackOnInputsFoldingTranspose>(
+        ctx, [](tensor::PackOp packOp) {
+          // Only lower packOps whose argument is not a constant.
+          return !llvm::dyn_cast_if_present<arith::ConstantOp>(
+              packOp.getOperand(0).getDefiningOp());
+        });
+    patterns.add<LowerPackUnpackOnOutputFoldingTranspose>(ctx);
+
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
+} // namespace
diff --git a/test/Integration/lower-pack-unpack-without-transpose.mlir b/test/Integration/lower-pack-unpack-without-transpose.mlir
new file mode 100644
index 000000000..cb2416cdd
--- /dev/null
+++ b/test/Integration/lower-pack-unpack-without-transpose.mlir
@@ -0,0 +1,55 @@
+// RUN: tpp-run %s -e big_matmul --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out
+// RUN: tpp-run --lower-pack-unpack-without-transpose %s -e big_matmul --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out
+// RUN: fpcmp -r 0.001 %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out
+// RUN: rm %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out
+
+func.func @big_matmul(%A: tensor<256x128xf32>, %B: tensor<128x64xf32>, %C: tensor<256x64xf32>) -> tensor<256x64xf32> {
+  %D = linalg.matmul ins(%A, %B: tensor<256x128xf32>, tensor<128x64xf32>) outs(%C: tensor<256x64xf32>) -> tensor<256x64xf32>
+  return %D : tensor<256x64xf32>
+}
+
+// RUN: tpp-run %s -e small_matmul_with_a_cst_arg --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out
+// RUN: tpp-run --lower-pack-unpack-without-transpose %s -e small_matmul_with_a_cst_arg --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out
+// RUN: fpcmp -r 0.001 %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out
+// RUN: rm %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out
+
+func.func @small_matmul_with_a_cst_arg(%A: tensor<128x32xf32>, %C: tensor<128x64xf32>) -> tensor<128x64xf32> {
+  %constB = arith.constant dense<[
+[ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1, 17.1, 18.1, 19.1, 20.1, 21.1, 22.1, 23.1, 24.1, 25.1, 26.1, 27.1, 28.1, 29.1, 30.1, 31.1, 32.1, 33.1, 34.1, 35.1, 36.1, 37.1, 38.1, 39.1, 40.1, 41.1, 42.1, 43.1, 44.1, 45.1, 46.1, 47.1, 48.1, 49.1, 50.1, 51.1, 52.1, 53.1, 54.1, 55.1, 56.1, 57.1, 58.1, 59.1, 60.1, 61.1, 62.1, 63.1, 64.1 ],
+[ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2, 17.2, 18.2, 19.2, 20.2, 21.2, 22.2, 23.2, 24.2, 25.2, 26.2, 27.2, 28.2, 29.2, 30.2, 31.2, 32.2, 33.2, 34.2, 35.2, 36.2, 37.2, 38.2, 39.2, 40.2, 41.2, 42.2, 43.2, 44.2, 45.2, 46.2, 47.2, 48.2, 49.2, 50.2, 51.2, 52.2, 53.2, 54.2, 55.2, 56.2, 57.2, 58.2, 59.2, 60.2, 61.2, 62.2, 63.2, 64.2 ],
+[ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3, 17.3, 18.3, 19.3, 20.3, 21.3, 22.3, 23.3, 24.3, 25.3, 26.3, 27.3,
28.3, 29.3, 30.3, 31.3, 32.3, 33.3, 34.3, 35.3, 36.3, 37.3, 38.3, 39.3, 40.3, 41.3, 42.3, 43.3, 44.3, 45.3, 46.3, 47.3, 48.3, 49.3, 50.3, 51.3, 52.3, 53.3, 54.3, 55.3, 56.3, 57.3, 58.3, 59.3, 60.3, 61.3, 62.3, 63.3, 64.3 ], +[ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4, 17.4, 18.4, 19.4, 20.4, 21.4, 22.4, 23.4, 24.4, 25.4, 26.4, 27.4, 28.4, 29.4, 30.4, 31.4, 32.4, 33.4, 34.4, 35.4, 36.4, 37.4, 38.4, 39.4, 40.4, 41.4, 42.4, 43.4, 44.4, 45.4, 46.4, 47.4, 48.4, 49.4, 50.4, 51.4, 52.4, 53.4, 54.4, 55.4, 56.4, 57.4, 58.4, 59.4, 60.4, 61.4, 62.4, 63.4, 64.4 ], +[ 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5, 21.5, 22.5, 23.5, 24.5, 25.5, 26.5, 27.5, 28.5, 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 35.5, 36.5, 37.5, 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 44.5, 45.5, 46.5, 47.5, 48.5, 49.5, 50.5, 51.5, 52.5, 53.5, 54.5, 55.5, 56.5, 57.5, 58.5, 59.5, 60.5, 61.5, 62.5, 63.5, 64.5 ], +[ 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6, 17.6, 18.6, 19.6, 20.6, 21.6, 22.6, 23.6, 24.6, 25.6, 26.6, 27.6, 28.6, 29.6, 30.6, 31.6, 32.6, 33.6, 34.6, 35.6, 36.6, 37.6, 38.6, 39.6, 40.6, 41.6, 42.6, 43.6, 44.6, 45.6, 46.6, 47.6, 48.6, 49.6, 50.6, 51.6, 52.6, 53.6, 54.6, 55.6, 56.6, 57.6, 58.6, 59.6, 60.6, 61.6, 62.6, 63.6, 64.6 ], +[ 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7, 17.7, 18.7, 19.7, 20.7, 21.7, 22.7, 23.7, 24.7, 25.7, 26.7, 27.7, 28.7, 29.7, 30.7, 31.7, 32.7, 33.7, 34.7, 35.7, 36.7, 37.7, 38.7, 39.7, 40.7, 41.7, 42.7, 43.7, 44.7, 45.7, 46.7, 47.7, 48.7, 49.7, 50.7, 51.7, 52.7, 53.7, 54.7, 55.7, 56.7, 57.7, 58.7, 59.7, 60.7, 61.7, 62.7, 63.7, 64.7 ], +[ 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8, 17.8, 18.8, 19.8, 20.8, 21.8, 22.8, 23.8, 24.8, 25.8, 26.8, 27.8, 28.8, 29.8, 30.8, 31.8, 32.8, 33.8, 34.8, 35.8, 36.8, 37.8, 38.8, 39.8, 40.8, 41.8, 42.8, 43.8, 44.8, 45.8, 46.8, 47.8, 48.8, 49.8, 50.8, 51.8, 52.8, 53.8, 54.8, 55.8, 56.8, 57.8, 58.8, 59.8, 60.8, 61.8, 62.8, 63.8, 64.8 ], +[ 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9, 17.9, 18.9, 19.9, 20.9, 21.9, 22.9, 23.9, 24.9, 25.9, 26.9, 27.9, 28.9, 29.9, 30.9, 31.9, 32.9, 33.9, 34.9, 35.9, 36.9, 37.9, 38.9, 39.9, 40.9, 41.9, 42.9, 43.9, 44.9, 45.9, 46.9, 47.9, 48.9, 49.9, 50.9, 51.9, 52.9, 53.9, 54.9, 55.9, 56.9, 57.9, 58.9, 59.9, 60.9, 61.9, 62.9, 63.9, 64.9 ], +[ 1.10, 2.10, 3.10, 4.10, 5.10, 6.10, 7.10, 8.10, 9.10, 10.10, 11.10, 12.10, 13.10, 14.10, 15.10, 16.10, 17.10, 18.10, 19.10, 20.10, 21.10, 22.10, 23.10, 24.10, 25.10, 26.10, 27.10, 28.10, 29.10, 30.10, 31.10, 32.10, 33.10, 34.10, 35.10, 36.10, 37.10, 38.10, 39.10, 40.10, 41.10, 42.10, 43.10, 44.10, 45.10, 46.10, 47.10, 48.10, 49.10, 50.10, 51.10, 52.10, 53.10, 54.10, 55.10, 56.10, 57.10, 58.10, 59.10, 60.10, 61.10, 62.10, 63.10, 64.10 ], +[ 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11, 17.11, 18.11, 19.11, 20.11, 21.11, 22.11, 23.11, 24.11, 25.11, 26.11, 27.11, 28.11, 29.11, 30.11, 31.11, 32.11, 33.11, 34.11, 35.11, 36.11, 37.11, 38.11, 39.11, 40.11, 41.11, 42.11, 43.11, 44.11, 45.11, 46.11, 47.11, 48.11, 49.11, 50.11, 51.11, 52.11, 53.11, 54.11, 55.11, 56.11, 57.11, 58.11, 59.11, 60.11, 61.11, 62.11, 63.11, 64.11 ], +[ 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12, 17.12, 18.12, 19.12, 20.12, 21.12, 22.12, 23.12, 24.12, 
25.12, 26.12, 27.12, 28.12, 29.12, 30.12, 31.12, 32.12, 33.12, 34.12, 35.12, 36.12, 37.12, 38.12, 39.12, 40.12, 41.12, 42.12, 43.12, 44.12, 45.12, 46.12, 47.12, 48.12, 49.12, 50.12, 51.12, 52.12, 53.12, 54.12, 55.12, 56.12, 57.12, 58.12, 59.12, 60.12, 61.12, 62.12, 63.12, 64.12 ], +[ 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13, 17.13, 18.13, 19.13, 20.13, 21.13, 22.13, 23.13, 24.13, 25.13, 26.13, 27.13, 28.13, 29.13, 30.13, 31.13, 32.13, 33.13, 34.13, 35.13, 36.13, 37.13, 38.13, 39.13, 40.13, 41.13, 42.13, 43.13, 44.13, 45.13, 46.13, 47.13, 48.13, 49.13, 50.13, 51.13, 52.13, 53.13, 54.13, 55.13, 56.13, 57.13, 58.13, 59.13, 60.13, 61.13, 62.13, 63.13, 64.13 ], +[ 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14, 17.14, 18.14, 19.14, 20.14, 21.14, 22.14, 23.14, 24.14, 25.14, 26.14, 27.14, 28.14, 29.14, 30.14, 31.14, 32.14, 33.14, 34.14, 35.14, 36.14, 37.14, 38.14, 39.14, 40.14, 41.14, 42.14, 43.14, 44.14, 45.14, 46.14, 47.14, 48.14, 49.14, 50.14, 51.14, 52.14, 53.14, 54.14, 55.14, 56.14, 57.14, 58.14, 59.14, 60.14, 61.14, 62.14, 63.14, 64.14 ], +[ 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15, 17.15, 18.15, 19.15, 20.15, 21.15, 22.15, 23.15, 24.15, 25.15, 26.15, 27.15, 28.15, 29.15, 30.15, 31.15, 32.15, 33.15, 34.15, 35.15, 36.15, 37.15, 38.15, 39.15, 40.15, 41.15, 42.15, 43.15, 44.15, 45.15, 46.15, 47.15, 48.15, 49.15, 50.15, 51.15, 52.15, 53.15, 54.15, 55.15, 56.15, 57.15, 58.15, 59.15, 60.15, 61.15, 62.15, 63.15, 64.15 ], +[ 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16, 17.16, 18.16, 19.16, 20.16, 21.16, 22.16, 23.16, 24.16, 25.16, 26.16, 27.16, 28.16, 29.16, 30.16, 31.16, 32.16, 33.16, 34.16, 35.16, 36.16, 37.16, 38.16, 39.16, 40.16, 41.16, 42.16, 43.16, 44.16, 45.16, 46.16, 47.16, 48.16, 49.16, 50.16, 51.16, 52.16, 53.16, 54.16, 55.16, 56.16, 57.16, 58.16, 59.16, 60.16, 61.16, 62.16, 63.16, 64.16 ], +[ 1.17, 2.17, 3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17, 17.17, 18.17, 19.17, 20.17, 21.17, 22.17, 23.17, 24.17, 25.17, 26.17, 27.17, 28.17, 29.17, 30.17, 31.17, 32.17, 33.17, 34.17, 35.17, 36.17, 37.17, 38.17, 39.17, 40.17, 41.17, 42.17, 43.17, 44.17, 45.17, 46.17, 47.17, 48.17, 49.17, 50.17, 51.17, 52.17, 53.17, 54.17, 55.17, 56.17, 57.17, 58.17, 59.17, 60.17, 61.17, 62.17, 63.17, 64.17 ], +[ 1.18, 2.18, 3.18, 4.18, 5.18, 6.18, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18, 17.18, 18.18, 19.18, 20.18, 21.18, 22.18, 23.18, 24.18, 25.18, 26.18, 27.18, 28.18, 29.18, 30.18, 31.18, 32.18, 33.18, 34.18, 35.18, 36.18, 37.18, 38.18, 39.18, 40.18, 41.18, 42.18, 43.18, 44.18, 45.18, 46.18, 47.18, 48.18, 49.18, 50.18, 51.18, 52.18, 53.18, 54.18, 55.18, 56.18, 57.18, 58.18, 59.18, 60.18, 61.18, 62.18, 63.18, 64.18 ], +[ 1.19, 2.19, 3.19, 4.19, 5.19, 6.19, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19, 17.19, 18.19, 19.19, 20.19, 21.19, 22.19, 23.19, 24.19, 25.19, 26.19, 27.19, 28.19, 29.19, 30.19, 31.19, 32.19, 33.19, 34.19, 35.19, 36.19, 37.19, 38.19, 39.19, 40.19, 41.19, 42.19, 43.19, 44.19, 45.19, 46.19, 47.19, 48.19, 49.19, 50.19, 51.19, 52.19, 53.19, 54.19, 55.19, 56.19, 57.19, 58.19, 59.19, 60.19, 61.19, 62.19, 63.19, 64.19 ], +[ 1.20, 2.20, 3.20, 4.20, 5.20, 6.20, 7.20, 8.20, 9.20, 10.20, 11.20, 12.20, 13.20, 14.20, 15.20, 16.20, 17.20, 18.20, 19.20, 20.20, 21.20, 22.20, 23.20, 24.20, 
25.20, 26.20, 27.20, 28.20, 29.20, 30.20, 31.20, 32.20, 33.20, 34.20, 35.20, 36.20, 37.20, 38.20, 39.20, 40.20, 41.20, 42.20, 43.20, 44.20, 45.20, 46.20, 47.20, 48.20, 49.20, 50.20, 51.20, 52.20, 53.20, 54.20, 55.20, 56.20, 57.20, 58.20, 59.20, 60.20, 61.20, 62.20, 63.20, 64.20 ], +[ 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21, 17.21, 18.21, 19.21, 20.21, 21.21, 22.21, 23.21, 24.21, 25.21, 26.21, 27.21, 28.21, 29.21, 30.21, 31.21, 32.21, 33.21, 34.21, 35.21, 36.21, 37.21, 38.21, 39.21, 40.21, 41.21, 42.21, 43.21, 44.21, 45.21, 46.21, 47.21, 48.21, 49.21, 50.21, 51.21, 52.21, 53.21, 54.21, 55.21, 56.21, 57.21, 58.21, 59.21, 60.21, 61.21, 62.21, 63.21, 64.21 ], +[ 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22, 17.22, 18.22, 19.22, 20.22, 21.22, 22.22, 23.22, 24.22, 25.22, 26.22, 27.22, 28.22, 29.22, 30.22, 31.22, 32.22, 33.22, 34.22, 35.22, 36.22, 37.22, 38.22, 39.22, 40.22, 41.22, 42.22, 43.22, 44.22, 45.22, 46.22, 47.22, 48.22, 49.22, 50.22, 51.22, 52.22, 53.22, 54.22, 55.22, 56.22, 57.22, 58.22, 59.22, 60.22, 61.22, 62.22, 63.22, 64.22 ], +[ 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23, 17.23, 18.23, 19.23, 20.23, 21.23, 22.23, 23.23, 24.23, 25.23, 26.23, 27.23, 28.23, 29.23, 30.23, 31.23, 32.23, 33.23, 34.23, 35.23, 36.23, 37.23, 38.23, 39.23, 40.23, 41.23, 42.23, 43.23, 44.23, 45.23, 46.23, 47.23, 48.23, 49.23, 50.23, 51.23, 52.23, 53.23, 54.23, 55.23, 56.23, 57.23, 58.23, 59.23, 60.23, 61.23, 62.23, 63.23, 64.23 ], +[ 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24, 17.24, 18.24, 19.24, 20.24, 21.24, 22.24, 23.24, 24.24, 25.24, 26.24, 27.24, 28.24, 29.24, 30.24, 31.24, 32.24, 33.24, 34.24, 35.24, 36.24, 37.24, 38.24, 39.24, 40.24, 41.24, 42.24, 43.24, 44.24, 45.24, 46.24, 47.24, 48.24, 49.24, 50.24, 51.24, 52.24, 53.24, 54.24, 55.24, 56.24, 57.24, 58.24, 59.24, 60.24, 61.24, 62.24, 63.24, 64.24 ], +[ 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25, 17.25, 18.25, 19.25, 20.25, 21.25, 22.25, 23.25, 24.25, 25.25, 26.25, 27.25, 28.25, 29.25, 30.25, 31.25, 32.25, 33.25, 34.25, 35.25, 36.25, 37.25, 38.25, 39.25, 40.25, 41.25, 42.25, 43.25, 44.25, 45.25, 46.25, 47.25, 48.25, 49.25, 50.25, 51.25, 52.25, 53.25, 54.25, 55.25, 56.25, 57.25, 58.25, 59.25, 60.25, 61.25, 62.25, 63.25, 64.25 ], +[ 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26, 17.26, 18.26, 19.26, 20.26, 21.26, 22.26, 23.26, 24.26, 25.26, 26.26, 27.26, 28.26, 29.26, 30.26, 31.26, 32.26, 33.26, 34.26, 35.26, 36.26, 37.26, 38.26, 39.26, 40.26, 41.26, 42.26, 43.26, 44.26, 45.26, 46.26, 47.26, 48.26, 49.26, 50.26, 51.26, 52.26, 53.26, 54.26, 55.26, 56.26, 57.26, 58.26, 59.26, 60.26, 61.26, 62.26, 63.26, 64.26 ], +[ 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27, 17.27, 18.27, 19.27, 20.27, 21.27, 22.27, 23.27, 24.27, 25.27, 26.27, 27.27, 28.27, 29.27, 30.27, 31.27, 32.27, 33.27, 34.27, 35.27, 36.27, 37.27, 38.27, 39.27, 40.27, 41.27, 42.27, 43.27, 44.27, 45.27, 46.27, 47.27, 48.27, 49.27, 50.27, 51.27, 52.27, 53.27, 54.27, 55.27, 56.27, 57.27, 58.27, 59.27, 60.27, 61.27, 62.27, 63.27, 64.27 ], +[ 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28, 17.28, 18.28, 19.28, 20.28, 21.28, 22.28, 23.28, 24.28, 
25.28, 26.28, 27.28, 28.28, 29.28, 30.28, 31.28, 32.28, 33.28, 34.28, 35.28, 36.28, 37.28, 38.28, 39.28, 40.28, 41.28, 42.28, 43.28, 44.28, 45.28, 46.28, 47.28, 48.28, 49.28, 50.28, 51.28, 52.28, 53.28, 54.28, 55.28, 56.28, 57.28, 58.28, 59.28, 60.28, 61.28, 62.28, 63.28, 64.28 ], +[ 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29, 17.29, 18.29, 19.29, 20.29, 21.29, 22.29, 23.29, 24.29, 25.29, 26.29, 27.29, 28.29, 29.29, 30.29, 31.29, 32.29, 33.29, 34.29, 35.29, 36.29, 37.29, 38.29, 39.29, 40.29, 41.29, 42.29, 43.29, 44.29, 45.29, 46.29, 47.29, 48.29, 49.29, 50.29, 51.29, 52.29, 53.29, 54.29, 55.29, 56.29, 57.29, 58.29, 59.29, 60.29, 61.29, 62.29, 63.29, 64.29 ], +[ 1.30, 2.30, 3.30, 4.30, 5.30, 6.30, 7.30, 8.30, 9.30, 10.30, 11.30, 12.30, 13.30, 14.30, 15.30, 16.30, 17.30, 18.30, 19.30, 20.30, 21.30, 22.30, 23.30, 24.30, 25.30, 26.30, 27.30, 28.30, 29.30, 30.30, 31.30, 32.30, 33.30, 34.30, 35.30, 36.30, 37.30, 38.30, 39.30, 40.30, 41.30, 42.30, 43.30, 44.30, 45.30, 46.30, 47.30, 48.30, 49.30, 50.30, 51.30, 52.30, 53.30, 54.30, 55.30, 56.30, 57.30, 58.30, 59.30, 60.30, 61.30, 62.30, 63.30, 64.30 ], +[ 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31, 17.31, 18.31, 19.31, 20.31, 21.31, 22.31, 23.31, 24.31, 25.31, 26.31, 27.31, 28.31, 29.31, 30.31, 31.31, 32.31, 33.31, 34.31, 35.31, 36.31, 37.31, 38.31, 39.31, 40.31, 41.31, 42.31, 43.31, 44.31, 45.31, 46.31, 47.31, 48.31, 49.31, 50.31, 51.31, 52.31, 53.31, 54.31, 55.31, 56.31, 57.31, 58.31, 59.31, 60.31, 61.31, 62.31, 63.31, 64.31 ], +[ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32, 17.32, 18.32, 19.32, 20.32, 21.32, 22.32, 23.32, 24.32, 25.32, 26.32, 27.32, 28.32, 29.32, 30.32, 31.32, 32.32, 33.32, 34.32, 35.32, 36.32, 37.32, 38.32, 39.32, 40.32, 41.32, 42.32, 43.32, 44.32, 45.32, 46.32, 47.32, 48.32, 49.32, 50.32, 51.32, 52.32, 53.32, 54.32, 55.32, 56.32, 57.32, 58.32, 59.32, 60.32, 61.32, 62.32, 63.32, 64.32 ] + ]> : tensor<32x64xf32> + %D = linalg.matmul ins(%A, %constB: tensor<128x32xf32>, tensor<32x64xf32>) outs(%C: tensor<128x64xf32>) -> tensor<128x64xf32> + %empty = tensor.empty() : tensor<128x640xf32> + return %D : tensor<128x64xf32> +} + diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir new file mode 100644 index 000000000..b9d741731 --- /dev/null +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -0,0 +1,154 @@ +// RUN: tpp-opt %s -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 32 + d2, d1 * 32 + d3)> +func.func @single_packed_arg(%arg0: tensor<128x512xf32>, %arg1: tensor<128x512xf32>) -> tensor<128x512xf32> { + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> + %1 = linalg.generic {indexing_maps = [#map, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%pack : tensor<4x16x32x32xf32>) outs(%arg1 : tensor<128x512xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<128x512xf32> + return %1 : tensor<128x512xf32> +} +// CHECK-LABEL: func.func @single_packed_arg( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// CHECK-SAME: 
%[[ARG1:.+]]: tensor<128x512xf32>)
+// CHECK-SAME: -> tensor<128x512xf32>
+  // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32>
+  // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]] : tensor<4x32x16x32xf32>) outs(%[[ARG1]] : tensor<128x512xf32>)
+  // CHECK-NOT: tensor.collapse_shape
+  // CHECK: return %[[RES]]
+
+// -----
+
+// NB: obtained from a M=128, N=256, K=512 linalg.matmul by -pack-matmul
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+func.func @revert_all_packing(%arg0: tensor<128x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %0 = tensor.empty() : tensor<4x16x32x32xf32>
+  %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32>
+  %1 = tensor.empty() : tensor<8x16x32x32xf32>
+  %pack_0 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<512x256xf32> -> tensor<8x16x32x32xf32>
+  %2 = tensor.empty() : tensor<4x8x32x32xf32>
+  %pack_1 = tensor.pack %arg2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
+  %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_0 : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_1 : tensor<4x8x32x32xf32>) {
+  ^bb0(%in: f32, %in_2: f32, %out: f32):
+    %4 = arith.mulf %in, %in_2 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<4x8x32x32xf32>
+  %unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg2 : tensor<4x8x32x32xf32> -> tensor<128x256xf32>
+  return %unpack : tensor<128x256xf32>
+}
+
+// CHECK-LABEL: func.func @revert_all_packing(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<512x256xf32>,
+// CHECK-SAME: %[[ARG2:.+]]: tensor<128x256xf32>)
+// CHECK-SAME: -> tensor<128x256xf32>
+  // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32>
+  // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32>
+  // CHECK: %[[EXP2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32>
+  // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[EXP1]] : tensor<4x32x16x32xf32>, tensor<16x32x8x32xf32>) outs(%[[EXP2]] : tensor<4x32x8x32xf32>)
+  // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32>
+  // CHECK: return %[[COL]]
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+func.func @only_keep_constant_packed_non_prepacked(%arg0: tensor<128x512xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %cst = arith.constant dense<1.000000e-03> : tensor<512x256xf32>
+  %cst_empty = tensor.empty() : tensor<8x16x32x32xf32>
+  %cst_packed = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %cst_empty : tensor<512x256xf32> -> tensor<8x16x32x32xf32>
+  %0 = tensor.empty() : tensor<4x16x32x32xf32>
+  %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32>
+  %1 = tensor.empty() : tensor<4x8x32x32xf32>
+  %pack_0 = tensor.pack %arg1 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32>
+  %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst_packed : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_0 : tensor<4x8x32x32xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %3 = arith.mulf %in, %in_1 : f32
+    %4 = arith.addf %out, %3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<4x8x32x32xf32>
+  // NB: unpack's outer_dims_perm should match those of the corresponding pack - in case of elision it should match an identity perm
+  %unpack = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<4x8x32x32xf32> -> tensor<128x256xf32>
+  return %unpack : tensor<128x256xf32>
+}
+// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d5)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d4)>
+// CHECK-LABEL: func.func @only_keep_constant_packed_non_prepacked(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>)
+// CHECK-SAME: -> tensor<128x256xf32>
+  // NB: even if the following is the case, this does not mean the layout will be preserved in general
+  // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32>
+  // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32>
+  // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32>
+  // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>)
+  // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32>
+  // CHECK: return %[[COL]]
+
+
+// -----
+
+// NB: obtained from a M=?, N=256, K=512 linalg.matmul by -pack-matmul
+#map = affine_map<()[s0] -> (s0 ceildiv 32)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+module {
+  func.func @revert_packing_with_leading_dim_dynamic(%arg0: tensor<?x512xf32>, %arg1: tensor<?x256xf32>) -> tensor<?x256xf32> {
+    %cst = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32>
+    %cst_0 = arith.constant 0.000000e+00 : f32
+    %c0 = arith.constant 0 : index
+    %dim = tensor.dim %arg0, %c0 : tensor<?x512xf32>
+    %0 = affine.apply #map()[%dim]
+    %1 = tensor.empty(%0) : tensor<?x16x32x32xf32>
+    %pack = tensor.pack %arg0 padding_value(%cst_0 : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<?x512xf32> -> tensor<?x16x32x32xf32>
+    %dim_1 = tensor.dim %arg1, %c0 : tensor<?x256xf32>
+    %2 = affine.apply #map()[%dim_1]
+    %3 = tensor.empty(%2) : tensor<?x8x32x32xf32>
+    %pack_2 = tensor.pack %arg1 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor<?x256xf32> -> tensor<?x8x32x32xf32>
+    %4 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst : tensor<?x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_2 : tensor<?x8x32x32xf32>) {
+    ^bb0(%in: f32, %in_3: f32, %out: f32):
+      %5 = arith.mulf %in, %in_3 : f32
+      %6 = arith.addf %out, %5 : f32
+      linalg.yield %6 : f32
+    } -> tensor<?x8x32x32xf32>
+    %unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<?x8x32x32xf32> -> tensor<?x256xf32>
+    return %unpack : tensor<?x256xf32>
+  }
+}
+// CHECK-LABEL: func.func @revert_packing_with_leading_dim_dynamic(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x512xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x256xf32>)
+// CHECK-SAME: -> tensor<?x256xf32>
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32>
+  // CHECK: %[[M:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[M_DUP:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+  // CHECK: %[[M_ROUNDED_UP:.*]] = affine.apply {{.*}}()[%[[M_DUP]], %[[M]]]
+  // CHECK: %[[ARG0_PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[M_ROUNDED_UP]], 0]
+  // CHECK: %[[M_PADDED:.*]] = tensor.dim %[[ARG0_PADDED]], %[[C0]]
+  // CHECK: %[[NUM_CHUNKS_PADDED_M:.*]] = arith.divui %[[M_PADDED]], %[[C32]]
+  // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0_PADDED]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[NUM_CHUNKS_PADDED_M]], 32, 16, 32] : tensor<?x512xf32> into tensor<?x32x16x32xf32>
+  // CHECK: %[[M_ARG1:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  // CHECK: %[[M_ARG1_DUP:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  // CHECK: %[[M_ARG1_ROUNDED_UP:.*]] = affine.apply {{.*}}()[%[[M_ARG1_DUP]], %[[M_ARG1]]]
+  // CHECK: %[[ARG1_PADDED:.*]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[M_ARG1_ROUNDED_UP]], 0]
+  // CHECK: %[[M_ARG1_PADDED:.*]] = tensor.dim %[[ARG1_PADDED]], %[[C0]]
+  // CHECK: %[[NUM_CHUNKS_PADDED_M_ARG1:.*]] = arith.divui %[[M_ARG1_PADDED]], %[[C32]]
+  // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1_PADDED]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[NUM_CHUNKS_PADDED_M_ARG1]], 32, 8, 32] : tensor<?x256xf32> into tensor<?x32x8x32xf32>
+  // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<?x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<?x32x8x32xf32>)
+  // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<?x32x8x32xf32> into tensor<?x256xf32>
+  // CHECK: %[[M_DUP2:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+  // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COL]][0, 0] [%[[M_DUP2]], 256] [1, 1] : tensor<?x256xf32> to tensor<?x256xf32>
+  // CHECK: %[[COPY:.+]] = linalg.copy ins(%[[SLICE]] : tensor<?x256xf32>) outs(%[[ARG1]]
+  // CHECK: return %[[COPY]]