From 2144edce15aa9220f323938ed2698f170f46978f Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Sep 2024 09:19:57 -0700 Subject: [PATCH 1/9] First version -pack-to-expand-shape --- include/TPP/Passes.td | 7 + lib/TPP/Transforms/CMakeLists.txt | 1 + .../Transforms/ConvertPackToExpandShape.cpp | 123 ++++++++++++++++++ 3 files changed, 131 insertions(+) create mode 100644 lib/TPP/Transforms/ConvertPackToExpandShape.cpp diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index a1decd3ee..fd561b65e 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -244,6 +244,13 @@ def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> { "arith::ArithDialect"]; } +def ConvertPackToExpandShapePass : Pass<"convert-pack-to-expand-shape", "ModuleOp"> { + let summary = "TODO"; + let description = [{ TODO }]; + let dependentDialects = ["linalg::LinalgDialect", + "tensor::TensorDialect"]; +} + def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> { let summary = "Run linalg element-wise fusion"; } diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 827aeb154..b4a3329c5 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_library(TPPTransforms Bufferize.cpp ConstantFoldPack.cpp ConvertForAllToParallelOp.cpp + ConvertPackToExpandShape.cpp ConvInitSimplify.cpp DecomposeAggregatedOps.cpp LinalgDeGeneralize.cpp diff --git a/lib/TPP/Transforms/ConvertPackToExpandShape.cpp b/lib/TPP/Transforms/ConvertPackToExpandShape.cpp new file mode 100644 index 000000000..88a69c274 --- /dev/null +++ b/lib/TPP/Transforms/ConvertPackToExpandShape.cpp @@ -0,0 +1,123 @@ +//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" +#include "TPP/Transforms/Utils/ValueUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include +using namespace mlir; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_CONVERTPACKTOEXPANDSHAPEPASS +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +static FailureOr> +packToExpandShape(tensor::PackOp packOp, linalg::GenericOp genericOp, + PatternRewriter &rewriter) { + AffineMap affineMap; + // TODO: clean-up + for (auto &use : packOp->getUses()) { + affineMap = genericOp.getMatchingIndexingMap(&use); + break; + } + auto origShape = + dyn_cast(packOp->getOperand(0).getType()).getShape(); + auto packedType = dyn_cast(packOp->getResult(0).getType()); + auto packedShape = packedType.getShape(); + auto packInverseMap = AffineMap::getPermutationMap( + mlir::tensor::getPackInverseDestPerm(packOp), rewriter.getContext()); + auto normalizedShape = applyPermutationMap(packInverseMap, packedShape); + + auto normalizedType = packedType.clone(normalizedShape); + auto normalizedIndexingMap = packInverseMap.compose(affineMap); + + auto innerDimPos = SmallVector(packOp.getInnerDimsPos()); + + SmallVector> associationIndices; + int curDimIdx 
= 0; + for (auto idx : llvm::seq(origShape.size())) { + associationIndices.emplace_back(SmallVector()); + associationIndices.back().push_back(curDimIdx++); + if (llvm::is_contained(innerDimPos, idx)) + associationIndices.back().push_back(curDimIdx++); + } + + auto expandShape = rewriter.create( + genericOp.getLoc(), normalizedType, packOp.getOperand(0), + ArrayRef(associationIndices)); + rewriter.replaceAllOpUsesWith(packOp, expandShape); + + return std::pair(expandShape, normalizedIndexingMap); +} + +struct ConvertPackToExpandShape : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!linalg::isaContractionOpInterface(genericOp)) + return failure(); + // know linalg has two inputs and one output and is a contraction + + // TODO: need way to control which operands to reverted packing on + // for demo purposes just do the first one + auto packOp = dyn_cast_if_present( + genericOp->getOperand(0).getDefiningOp()); + if (!packOp) + return failure(); + + auto res = packToExpandShape(packOp, genericOp, rewriter); + if (!succeeded(res)) + return res; + + auto indexingMaps = genericOp.getIndexingMaps(); + auto indexingMapsAttr = ArrayAttr::get( + rewriter.getContext(), + {{AffineMapAttr::get(res->second), indexingMaps[1], indexingMaps[2]}}); + genericOp.setIndexingMapsAttr(indexingMapsAttr); + + return llvm::success(); + } +}; + +/// Replace linalg.add when destination passing suffices for achieving the sum. 
+struct ConvertPackToExpandShapePass + : public tpp::impl::ConvertPackToExpandShapePassBase< + ConvertPackToExpandShapePass> { + + void runOnOperation() override { + auto *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace From da3e15d708a76c8c17ef60aad16012b4d0398225 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 17 Sep 2024 02:15:14 -0700 Subject: [PATCH 2/9] Support for tensor.unpack to tensor.collapse_shape --- include/TPP/Passes.td | 2 +- .../Transforms/ConvertPackToExpandShape.cpp | 89 +++++++++++++------ 2 files changed, 65 insertions(+), 26 deletions(-) diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index fd561b65e..4902e9907 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -244,7 +244,7 @@ def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> { "arith::ArithDialect"]; } -def ConvertPackToExpandShapePass : Pass<"convert-pack-to-expand-shape", "ModuleOp"> { +def ConvertPackToExpandShapePass : Pass<"pack-unpack-to-expand-collapse-shape", "ModuleOp"> { let summary = "TODO"; let description = [{ TODO }]; let dependentDialects = ["linalg::LinalgDialect", diff --git a/lib/TPP/Transforms/ConvertPackToExpandShape.cpp b/lib/TPP/Transforms/ConvertPackToExpandShape.cpp index 88a69c274..6ffe34b3f 100644 --- a/lib/TPP/Transforms/ConvertPackToExpandShape.cpp +++ b/lib/TPP/Transforms/ConvertPackToExpandShape.cpp @@ -37,14 +37,8 @@ namespace tpp { namespace { static FailureOr> -packToExpandShape(tensor::PackOp packOp, linalg::GenericOp genericOp, +packToExpandShape(tensor::PackOp packOp, AffineMap affineMap, PatternRewriter &rewriter) { - AffineMap affineMap; - // TODO: clean-up - for (auto &use : packOp->getUses()) { - affineMap = genericOp.getMatchingIndexingMap(&use); - break; - } auto origShape = dyn_cast(packOp->getOperand(0).getType()).getShape(); auto packedType = 
dyn_cast(packOp->getResult(0).getType()); @@ -63,18 +57,44 @@ packToExpandShape(tensor::PackOp packOp, linalg::GenericOp genericOp, for (auto idx : llvm::seq(origShape.size())) { associationIndices.emplace_back(SmallVector()); associationIndices.back().push_back(curDimIdx++); + // TODO: is it the case that each dim can only occur once in innerDimPos? if (llvm::is_contained(innerDimPos, idx)) associationIndices.back().push_back(curDimIdx++); } + rewriter.setInsertionPointAfter(packOp); auto expandShape = rewriter.create( - genericOp.getLoc(), normalizedType, packOp.getOperand(0), + packOp->getLoc(), normalizedType, packOp.getOperand(0), ArrayRef(associationIndices)); - rewriter.replaceAllOpUsesWith(packOp, expandShape); return std::pair(expandShape, normalizedIndexingMap); } +static FailureOr +unpackToCollapseShape(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) { + auto origType = + dyn_cast(unpackOp->getResult(0).getType()); + auto origShape = origType.getShape(); + auto innerDimPos = SmallVector(unpackOp.getInnerDimsPos()); + + SmallVector> associationIndices; + int curDimIdx = 0; + for (auto idx : llvm::seq(origShape.size())) { + associationIndices.emplace_back(SmallVector()); + associationIndices.back().push_back(curDimIdx++); + // TODO: is it the case that each dim can only occur once in innerDimPos? 
+ if (llvm::is_contained(innerDimPos, idx)) + associationIndices.back().push_back(curDimIdx++); + } + + rewriter.setInsertionPointAfter(unpackOp); + auto collapseShape = rewriter.create( + unpackOp.getLoc(), origType, unpackOp.getOperand(0), + ArrayRef(associationIndices)); + + return collapseShape; +} + struct ConvertPackToExpandShape : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -84,22 +104,41 @@ struct ConvertPackToExpandShape : public OpRewritePattern { return failure(); // know linalg has two inputs and one output and is a contraction - // TODO: need way to control which operands to reverted packing on - // for demo purposes just do the first one - auto packOp = dyn_cast_if_present( - genericOp->getOperand(0).getDefiningOp()); - if (!packOp) - return failure(); - - auto res = packToExpandShape(packOp, genericOp, rewriter); - if (!succeeded(res)) - return res; - - auto indexingMaps = genericOp.getIndexingMaps(); - auto indexingMapsAttr = ArrayAttr::get( - rewriter.getContext(), - {{AffineMapAttr::get(res->second), indexingMaps[1], indexingMaps[2]}}); - genericOp.setIndexingMapsAttr(indexingMapsAttr); + auto indexingMaps = genericOp.getIndexingMapsArray(); + // TODO: need way to control which operands to revert packing on + auto idxSet = {0, 1, 2}; + Type resultType = nullptr; + for (auto idx : idxSet) { + auto packOp = dyn_cast_if_present( + genericOp->getOperand(idx).getDefiningOp()); + if (!packOp) + return failure(); + + auto res = packToExpandShape(packOp, indexingMaps[idx], rewriter); + if (!succeeded(res)) + return res; + + rewriter.replaceAllOpUsesWith(packOp, res->first); + if (idx == 2) { + resultType = res->first.getResultType(); + genericOp->getOpResult(0).setType(resultType); + } + + SmallVector indexingMaps = + llvm::to_vector(genericOp.getIndexingMaps()); + indexingMaps[idx] = AffineMapAttr::get(res->second); + + genericOp.setIndexingMapsAttr( + ArrayAttr::get(rewriter.getContext(), indexingMaps)); + } + + if (auto 
unpackOp = llvm::dyn_cast( + *(genericOp->getResult(0).getUsers().begin()))) { + auto res = unpackToCollapseShape(unpackOp, rewriter); + if (!succeeded(res)) + return failure(); + rewriter.replaceAllOpUsesWith(unpackOp, *res); + } return llvm::success(); } From 63bbceb990b3ff6958495847a3371e28cc85508d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 17 Sep 2024 11:04:08 -0700 Subject: [PATCH 3/9] Fixes and improvements --- include/TPP/Passes.td | 3 +- lib/TPP/DefaultTppPasses.cpp | 1 + lib/TPP/Transforms/CMakeLists.txt | 2 +- ...pp => PackUnpackToExpandCollapseShape.cpp} | 121 ++++++++++++------ 4 files changed, 83 insertions(+), 44 deletions(-) rename lib/TPP/Transforms/{ConvertPackToExpandShape.cpp => PackUnpackToExpandCollapseShape.cpp} (51%) diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 4902e9907..432871401 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -244,7 +244,8 @@ def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> { "arith::ArithDialect"]; } -def ConvertPackToExpandShapePass : Pass<"pack-unpack-to-expand-collapse-shape", "ModuleOp"> { +def PackUnpackToExpandCollapseShape : Pass<"pack-unpack-to-expand-collapse-shape", + "ModuleOp"> { let summary = "TODO"; let description = [{ TODO }]; let dependentDialects = ["linalg::LinalgDialect", diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index b1a0cd83a..06a63d92b 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -86,6 +86,7 @@ struct DefaultTppPasses // Applies a set of passes at the linalg level to fuse and pack. pm.addPass(createTppMapping()); + pm.addPass(createPackUnpackToExpandCollapseShape()); // Generalize tensor.pack and tensor.unpack. 
pm.addPass(createLowerPacksAndUnPacks()); diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index b4a3329c5..03e8621e8 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_library(TPPTransforms Bufferize.cpp ConstantFoldPack.cpp ConvertForAllToParallelOp.cpp - ConvertPackToExpandShape.cpp + PackUnpackToExpandCollapseShape.cpp ConvInitSimplify.cpp DecomposeAggregatedOps.cpp LinalgDeGeneralize.cpp diff --git a/lib/TPP/Transforms/ConvertPackToExpandShape.cpp b/lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp similarity index 51% rename from lib/TPP/Transforms/ConvertPackToExpandShape.cpp rename to lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp index 6ffe34b3f..5d19c9f47 100644 --- a/lib/TPP/Transforms/ConvertPackToExpandShape.cpp +++ b/lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp @@ -1,4 +1,4 @@ -//===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// +//===- PackUnpackToExpandCollapseShape.cpp -----------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -29,14 +29,14 @@ using namespace mlir; namespace mlir { namespace tpp { -#define GEN_PASS_DEF_CONVERTPACKTOEXPANDSHAPEPASS +#define GEN_PASS_DEF_PACKUNPACKTOEXPANDCOLLAPSESHAPE #include "TPP/Passes.h.inc" } // namespace tpp } // namespace mlir namespace { -static FailureOr> +static std::pair packToExpandShape(tensor::PackOp packOp, AffineMap affineMap, PatternRewriter &rewriter) { auto origShape = @@ -52,10 +52,10 @@ packToExpandShape(tensor::PackOp packOp, AffineMap affineMap, auto innerDimPos = SmallVector(packOp.getInnerDimsPos()); - SmallVector> associationIndices; + SmallVector associationIndices; int curDimIdx = 0; for (auto idx : llvm::seq(origShape.size())) { - associationIndices.emplace_back(SmallVector()); + associationIndices.emplace_back(ReassociationIndices()); associationIndices.back().push_back(curDimIdx++); // TODO: is it the case that each dim can only occur once in innerDimPos? if (llvm::is_contained(innerDimPos, idx)) @@ -70,10 +70,9 @@ packToExpandShape(tensor::PackOp packOp, AffineMap affineMap, return std::pair(expandShape, normalizedIndexingMap); } -static FailureOr +static tensor::CollapseShapeOp unpackToCollapseShape(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) { - auto origType = - dyn_cast(unpackOp->getResult(0).getType()); + auto origType = dyn_cast(unpackOp->getResult(0).getType()); auto origShape = origType.getShape(); auto innerDimPos = SmallVector(unpackOp.getInnerDimsPos()); @@ -95,65 +94,103 @@ unpackToCollapseShape(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) { return collapseShape; } -struct ConvertPackToExpandShape : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct PackOnInputToExpandShape : public OpRewritePattern { + // Is only called with single-user packOp operands, so callback can always + // find the (use by the) linalg.generic that is the target of the pattern. 
+ using ControlFn = std::function; + ControlFn controlFn; + + PackOnInputToExpandShape(MLIRContext *context, ControlFn controlFn = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { if (!linalg::isaContractionOpInterface(genericOp)) return failure(); - // know linalg has two inputs and one output and is a contraction auto indexingMaps = genericOp.getIndexingMapsArray(); - // TODO: need way to control which operands to revert packing on - auto idxSet = {0, 1, 2}; - Type resultType = nullptr; - for (auto idx : idxSet) { + bool modifiedAnOperand = false; + for (auto operandIdx : {0, 1}) { auto packOp = dyn_cast_if_present( - genericOp->getOperand(idx).getDefiningOp()); - if (!packOp) - return failure(); + genericOp->getOperand(operandIdx).getDefiningOp()); - auto res = packToExpandShape(packOp, indexingMaps[idx], rewriter); - if (!succeeded(res)) - return res; + if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp))) + continue; - rewriter.replaceAllOpUsesWith(packOp, res->first); - if (idx == 2) { - resultType = res->first.getResultType(); - genericOp->getOpResult(0).setType(resultType); - } + auto res = packToExpandShape(packOp, indexingMaps[operandIdx], rewriter); + rewriter.replaceAllOpUsesWith(packOp, res.first); - SmallVector indexingMaps = + SmallVector maps = llvm::to_vector(genericOp.getIndexingMaps()); - indexingMaps[idx] = AffineMapAttr::get(res->second); - + maps[operandIdx] = AffineMapAttr::get(res.second); genericOp.setIndexingMapsAttr( - ArrayAttr::get(rewriter.getContext(), indexingMaps)); - } + ArrayAttr::get(rewriter.getContext(), maps)); - if (auto unpackOp = llvm::dyn_cast( - *(genericOp->getResult(0).getUsers().begin()))) { - auto res = unpackToCollapseShape(unpackOp, rewriter); - if (!succeeded(res)) - return failure(); - rewriter.replaceAllOpUsesWith(unpackOp, 
*res); + modifiedAnOperand = true; } + return modifiedAnOperand ? success() : failure(); + } +}; + +struct PackUnpackOnOutputToExpandCollapseShape + : public OpRewritePattern { + // Is only called with single-user packOp operands, so callback can always + // find the (use by the) linalg.generic that is the target of the pattern. + using ControlFn = std::function; + ControlFn controlFn; + + PackUnpackOnOutputToExpandCollapseShape(MLIRContext *context, + ControlFn controlFn = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!linalg::isaContractionOpInterface(genericOp) || + !genericOp->hasOneUse()) + return failure(); + + auto packOp = dyn_cast_if_present( + genericOp->getOperand(2).getDefiningOp()); + auto unpackOp = llvm::dyn_cast( + *(genericOp->getResult(0).getUsers().begin())); + + if (!packOp || !packOp->hasOneUse() || !unpackOp || + (controlFn && !controlFn(packOp, unpackOp))) + return failure(); + + auto res = packToExpandShape(packOp, genericOp.getIndexingMapsArray()[2], + rewriter); + rewriter.replaceAllOpUsesWith(packOp, res.first); + + SmallVector maps = llvm::to_vector(genericOp.getIndexingMaps()); + maps[2] = AffineMapAttr::get(res.second); + genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps)); + + genericOp->getOpResult(0).setType(res.first.getResultType()); + + auto collapseShapeOp = unpackToCollapseShape(unpackOp, rewriter); + rewriter.replaceAllOpUsesWith(unpackOp, collapseShapeOp); + return llvm::success(); } }; -/// Replace linalg.add when destination passing suffices for achieving the sum. 
-struct ConvertPackToExpandShapePass - : public tpp::impl::ConvertPackToExpandShapePassBase< - ConvertPackToExpandShapePass> { +struct PackUnpackToExpandCollapseShape + : public tpp::impl::PackUnpackToExpandCollapseShapeBase< + PackUnpackToExpandCollapseShape> { void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add(ctx, [](tensor::PackOp packOp) { + return !llvm::dyn_cast_if_present( + packOp.getOperand(0).getDefiningOp()); + }); + patterns.add(ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } From d60d39e34b488c4aaba624a4b27a354c49b26003 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Thu, 19 Sep 2024 08:22:29 -0700 Subject: [PATCH 4/9] Fully working version with docs and tests and pass pipeline flag --- include/TPP/PassBundles.td | 10 +- include/TPP/Passes.td | 14 +- lib/TPP/DefaultPipeline.cpp | 10 +- lib/TPP/DefaultTppPasses.cpp | 5 +- lib/TPP/PassBundles/TppMapping.cpp | 4 + lib/TPP/Transforms/CMakeLists.txt | 2 +- .../LowerPacksAndUnpacksWithoutTranspose.cpp | 173 +++++++++++++++ .../PackUnpackToExpandCollapseShape.cpp | 199 ------------------ ...r-packs-and-unpacks-without-transpose.mlir | 125 +++++++++++ 9 files changed, 329 insertions(+), 213 deletions(-) create mode 100644 lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp delete mode 100644 lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp create mode 100644 test/Passes/lower-packs-and-unpacks-without-transpose.mlir diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index 93c8a73ca..ff9c8223c 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -37,7 +37,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> { "unsigned", "Grid-sizes for parallel tasks.">, Option<"linalgToVector", "linalg-to-vector", "bool", /*default=*/"false", - "Lower linalg directly to vector."> + "Lower linalg directly to vector.">, + 
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose-to-vector", + "bool", /*default=*/"false", + "Lower packs and unpacks reverting any dim permutations."> ]; } @@ -51,6 +54,11 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> { "memref::MemRefDialect", "scf::SCFDialect", "tensor::TensorDialect"]; + let options= [ + Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose", + "bool", /*default=*/"false", + "Lower packs and unpacks reverting any dim permutations."> + ]; } def LinalgLowering : Pass<"linalg-lowering", "func::FuncOp"> { diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 432871401..cd9f5c0da 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -177,6 +177,12 @@ def LowerPacksAndUnPacks : Pass<"lower-packs-unpacks", "func::FuncOp"> { "tensor::TensorDialect"]; } +def LowerPacksAndUnpacksWithoutTranspose : Pass<"lower-packs-unpacks-without-transpose", + "ModuleOp"> { + let dependentDialects = ["linalg::LinalgDialect", + "tensor::TensorDialect"]; +} + def RewriteConvToMatmulOrBrgemm : Pass<"rewrite-conv-to-matmul-or-brgemm", "func::FuncOp"> { let summary = "Rewrite Conv2DNhwcHwcfOp/Conv2DNchwFchwOp to Matmul or Brgemm."; @@ -244,14 +250,6 @@ def FoldAddIntoDest : Pass<"fold-add-into-dest", "ModuleOp"> { "arith::ArithDialect"]; } -def PackUnpackToExpandCollapseShape : Pass<"pack-unpack-to-expand-collapse-shape", - "ModuleOp"> { - let summary = "TODO"; - let description = [{ TODO }]; - let dependentDialects = ["linalg::LinalgDialect", - "tensor::TensorDialect"]; -} - def ElementWiseFusion : Pass<"element-wise-fusion", "func::FuncOp"> { let summary = "Run linalg element-wise fusion"; } diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index 2e826869d..f4ec69859 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -57,6 +57,11 @@ llvm::cl::opt linalgToVector("linalg-to-vector", llvm::cl::desc("Lower linalg to vector"), 
llvm::cl::init(false)); +llvm::cl::opt lowerPackUnpackWithoutTranspose( + "lower-pack-unpack-without-transpose", + llvm::cl::desc("Lower packs and unpacks reverting any dim permutations"), + llvm::cl::init(false)); + namespace mlir { namespace tpp { #define GEN_PASS_DEF_DEFAULTPIPELINE @@ -128,8 +133,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend})); } else { // Apply the default preprocessing pass - DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, parallelTaskGrid, - linalgToVector}; + DefaultTppPassesOptions tppDefaultOptions{ + linalgToLoops, parallelTaskGrid, linalgToVector, + lowerPackUnpackWithoutTranspose}; pm.addPass(createDefaultTppPasses(tppDefaultOptions)); } diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index 06a63d92b..11b394a93 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -85,8 +85,9 @@ struct DefaultTppPasses pm.addPass(createRewriteBatchMatmulToMatmul()); // Applies a set of passes at the linalg level to fuse and pack. - pm.addPass(createTppMapping()); - pm.addPass(createPackUnpackToExpandCollapseShape()); + TppMappingOptions tppMappingOptions{ + lowerPackUnpackWithoutTranspose}; + pm.addPass(createTppMapping(tppMappingOptions)); // Generalize tensor.pack and tensor.unpack. 
pm.addPass(createLowerPacksAndUnPacks()); diff --git a/lib/TPP/PassBundles/TppMapping.cpp b/lib/TPP/PassBundles/TppMapping.cpp index d39b62530..e74e368e2 100644 --- a/lib/TPP/PassBundles/TppMapping.cpp +++ b/lib/TPP/PassBundles/TppMapping.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "TPP/PassUtils.h" @@ -63,6 +64,9 @@ struct TppMapping : public tpp::impl::TppMappingBase, pm.addPass(createPackMatmul()); pm.addPass(createPackVNNI()); + if (lowerPackUnpackWithoutTranspose) { + pm.addPass(createLowerPacksAndUnpacksWithoutTranspose()); + } // Postprocess packing. // Run only canonicalizer at this stage as full cleanup (mostly CSE) can // mess up tensor producer-consumer chains used for analysis in the diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 03e8621e8..6a1e83fca 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -4,11 +4,11 @@ add_mlir_library(TPPTransforms Bufferize.cpp ConstantFoldPack.cpp ConvertForAllToParallelOp.cpp - PackUnpackToExpandCollapseShape.cpp ConvInitSimplify.cpp DecomposeAggregatedOps.cpp LinalgDeGeneralize.cpp LowerPacksAndUnpacks.cpp + LowerPacksAndUnpacksWithoutTranspose.cpp RewriteBatchMatmulToMatmul.cpp RewriteConvsToMatmulOrBrgemm.cpp RewriteConvToMatmulImpl.cpp diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp new file mode 100644 index 000000000..552c0a12e --- /dev/null +++ b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp @@ -0,0 +1,173 @@ +//===- LowerPacksAndUnpacksWithoutTranspose.cpp ------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TPP/Passes.h" +#include "TPP/Transforms/Utils/ValueUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include +using namespace mlir; + +namespace mlir { +namespace tpp { +#define GEN_PASS_DEF_LOWERPACKSANDUNPACKSWITHOUTTRANSPOSE +#include "TPP/Passes.h.inc" +} // namespace tpp +} // namespace mlir + +namespace { + +/// Wrapper around linalg::lowerPack which undoes the transpose that might have +/// happened. Single user genericOp's indexing_maps is corrected accordingly. +void lowerPackWithoutTranspose(tensor::PackOp packOp, + linalg::GenericOp genericOp, uint operandIdx, + PatternRewriter &rewriter) { + auto packInversionPerm = tensor::getPackInverseDestPerm(packOp); + + auto res = linalg::lowerPack(rewriter, packOp); + + if (res->transposeOp) { + // Forget about the permutation of the dims on expandShapeOp. + rewriter.replaceAllOpUsesWith(res->transposeOp, res->expandShapeOp); + + // Invert corresponding transposed accesses by the single-user, genericOp. 
+ auto indexingMaps = genericOp.getIndexingMapsArray(); + auto packInverseMap = + AffineMap::getPermutationMap(packInversionPerm, rewriter.getContext()); + auto normalizedIndexingMap = + packInverseMap.compose(indexingMaps[operandIdx]); + + SmallVector maps = llvm::to_vector(genericOp.getIndexingMaps()); + maps[operandIdx] = AffineMapAttr::get(normalizedIndexingMap); + genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps)); + } +} + +struct LowerPackOnInputsWithoutTranspose + : public OpRewritePattern { + // Is only called with single-user packOp operands, so callback can always + // find the (use by the) linalg.generic that is the target of the pattern. + using ControlFn = std::function; + ControlFn controlFn; + + LowerPackOnInputsWithoutTranspose(MLIRContext *context, + ControlFn controlFn = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + bool modifiedAnOperand = false; + for (auto operandIdx : {0, 1}) { + auto packOp = dyn_cast_if_present( + genericOp->getOperand(operandIdx).getDefiningOp()); + + if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp))) + continue; + + lowerPackWithoutTranspose(packOp, genericOp, operandIdx, rewriter); + + modifiedAnOperand = true; + } + + return modifiedAnOperand ? success() : failure(); + } +}; + +struct LowerPackUnpackOnOutputWithoutTranspose + : public OpRewritePattern { + // Is only called with single-user packOp operands, so callback can always + // find the (use by the) linalg.generic that is the target of the pattern. 
+ using ControlFn = std::function; + ControlFn controlFn; + + LowerPackUnpackOnOutputWithoutTranspose(MLIRContext *context, + ControlFn controlFn = nullptr, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} + + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!genericOp->hasOneUse()) + return failure(); + + auto packOp = dyn_cast_if_present( + genericOp->getOperand(2).getDefiningOp()); + auto unpackOp = llvm::dyn_cast( + *(genericOp->getResult(0).getUsers().begin())); + + if (!packOp || !packOp->hasOneUse() || !unpackOp || + (controlFn && !controlFn(packOp, unpackOp))) + return failure(); + + auto unpackDest = unpackOp.getDest(); + bool destHasStaticShape = unpackDest.getType().hasStaticShape(); + + lowerPackWithoutTranspose(packOp, genericOp, /*operandIdx=*/2, rewriter); + auto res = linalg::lowerUnPack(rewriter, unpackOp); + + // Set genericOp's result type to the adjusted type of the out parameter. + genericOp->getOpResult(0).setType(genericOp.getOperand(2).getType()); + + if (auto transposeOp = res->transposeOp) { + // Forget about the transpose introduced by lowerUnPack. + rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput()); + } + + // lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest. + // As we ignore permutations - and in the static case - don't do padding, + // we know the underlying buffer will be used as is and hence we do not need + // to specify a dest to update into. 
+ auto extractSliceOp = res->extractSliceOp; + if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) { + auto copyOp = + dyn_cast(*extractSliceOp->getUsers().begin()); + if (copyOp && copyOp.getOutputs()[0] == unpackDest) { + rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]); + } + } + + return success(); + } +}; + +struct LowerPacksAndUnpacksWithoutTranspose + : public tpp::impl::LowerPacksAndUnpacksWithoutTransposeBase< + LowerPacksAndUnpacksWithoutTranspose> { + + void runOnOperation() override { + auto *ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add( + ctx, [](tensor::PackOp packOp) { + // Only lower packOps whose argument is not a constant. + return !llvm::dyn_cast_if_present( + packOp.getOperand(0).getDefiningOp()); + }); + patterns.add(ctx); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace diff --git a/lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp b/lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp deleted file mode 100644 index 5d19c9f47..000000000 --- a/lib/TPP/Transforms/PackUnpackToExpandCollapseShape.cpp +++ /dev/null @@ -1,199 +0,0 @@ -//===- PackUnpackToExpandCollapseShape.cpp -----------------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Passes.h" -#include "TPP/Transforms/Utils/ValueUtils.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dominance.h" -#include "mlir/Interfaces/DestinationStyleOpInterface.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" -#include -using namespace mlir; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_PACKUNPACKTOEXPANDCOLLAPSESHAPE -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -namespace { - -static std::pair -packToExpandShape(tensor::PackOp packOp, AffineMap affineMap, - PatternRewriter &rewriter) { - auto origShape = - dyn_cast(packOp->getOperand(0).getType()).getShape(); - auto packedType = dyn_cast(packOp->getResult(0).getType()); - auto packedShape = packedType.getShape(); - auto packInverseMap = AffineMap::getPermutationMap( - mlir::tensor::getPackInverseDestPerm(packOp), rewriter.getContext()); - auto normalizedShape = applyPermutationMap(packInverseMap, packedShape); - - auto normalizedType = packedType.clone(normalizedShape); - auto normalizedIndexingMap = packInverseMap.compose(affineMap); - - auto innerDimPos = SmallVector(packOp.getInnerDimsPos()); - - SmallVector associationIndices; - int curDimIdx = 0; - for (auto idx : llvm::seq(origShape.size())) { - associationIndices.emplace_back(ReassociationIndices()); - 
associationIndices.back().push_back(curDimIdx++); - // TODO: is it the case that each dim can only occur once in innerDimPos? - if (llvm::is_contained(innerDimPos, idx)) - associationIndices.back().push_back(curDimIdx++); - } - - rewriter.setInsertionPointAfter(packOp); - auto expandShape = rewriter.create( - packOp->getLoc(), normalizedType, packOp.getOperand(0), - ArrayRef(associationIndices)); - - return std::pair(expandShape, normalizedIndexingMap); -} - -static tensor::CollapseShapeOp -unpackToCollapseShape(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) { - auto origType = dyn_cast(unpackOp->getResult(0).getType()); - auto origShape = origType.getShape(); - auto innerDimPos = SmallVector(unpackOp.getInnerDimsPos()); - - SmallVector> associationIndices; - int curDimIdx = 0; - for (auto idx : llvm::seq(origShape.size())) { - associationIndices.emplace_back(SmallVector()); - associationIndices.back().push_back(curDimIdx++); - // TODO: is it the case that each dim can only occur once in innerDimPos? - if (llvm::is_contained(innerDimPos, idx)) - associationIndices.back().push_back(curDimIdx++); - } - - rewriter.setInsertionPointAfter(unpackOp); - auto collapseShape = rewriter.create( - unpackOp.getLoc(), origType, unpackOp.getOperand(0), - ArrayRef(associationIndices)); - - return collapseShape; -} - -struct PackOnInputToExpandShape : public OpRewritePattern { - // Is only called with single-user packOp operands, so callback can always - // find the (use by the) linalg.generic that is the target of the pattern. 
- using ControlFn = std::function; - ControlFn controlFn; - - PackOnInputToExpandShape(MLIRContext *context, ControlFn controlFn = nullptr, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} - - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!linalg::isaContractionOpInterface(genericOp)) - return failure(); - - auto indexingMaps = genericOp.getIndexingMapsArray(); - bool modifiedAnOperand = false; - for (auto operandIdx : {0, 1}) { - auto packOp = dyn_cast_if_present( - genericOp->getOperand(operandIdx).getDefiningOp()); - - if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp))) - continue; - - auto res = packToExpandShape(packOp, indexingMaps[operandIdx], rewriter); - rewriter.replaceAllOpUsesWith(packOp, res.first); - - SmallVector maps = - llvm::to_vector(genericOp.getIndexingMaps()); - maps[operandIdx] = AffineMapAttr::get(res.second); - genericOp.setIndexingMapsAttr( - ArrayAttr::get(rewriter.getContext(), maps)); - - modifiedAnOperand = true; - } - - return modifiedAnOperand ? success() : failure(); - } -}; - -struct PackUnpackOnOutputToExpandCollapseShape - : public OpRewritePattern { - // Is only called with single-user packOp operands, so callback can always - // find the (use by the) linalg.generic that is the target of the pattern. 
- using ControlFn = std::function; - ControlFn controlFn; - - PackUnpackOnOutputToExpandCollapseShape(MLIRContext *context, - ControlFn controlFn = nullptr, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} - - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!linalg::isaContractionOpInterface(genericOp) || - !genericOp->hasOneUse()) - return failure(); - - auto packOp = dyn_cast_if_present( - genericOp->getOperand(2).getDefiningOp()); - auto unpackOp = llvm::dyn_cast( - *(genericOp->getResult(0).getUsers().begin())); - - if (!packOp || !packOp->hasOneUse() || !unpackOp || - (controlFn && !controlFn(packOp, unpackOp))) - return failure(); - - auto res = packToExpandShape(packOp, genericOp.getIndexingMapsArray()[2], - rewriter); - rewriter.replaceAllOpUsesWith(packOp, res.first); - - SmallVector maps = llvm::to_vector(genericOp.getIndexingMaps()); - maps[2] = AffineMapAttr::get(res.second); - genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps)); - - genericOp->getOpResult(0).setType(res.first.getResultType()); - - auto collapseShapeOp = unpackToCollapseShape(unpackOp, rewriter); - rewriter.replaceAllOpUsesWith(unpackOp, collapseShapeOp); - - return llvm::success(); - } -}; - -struct PackUnpackToExpandCollapseShape - : public tpp::impl::PackUnpackToExpandCollapseShapeBase< - PackUnpackToExpandCollapseShape> { - - void runOnOperation() override { - auto *ctx = &getContext(); - - RewritePatternSet patterns(ctx); - patterns.add(ctx, [](tensor::PackOp packOp) { - return !llvm::dyn_cast_if_present( - packOp.getOperand(0).getDefiningOp()); - }); - patterns.add(ctx); - - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - -} // namespace diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir new file mode 100644 index 
000000000..f5b1cf663 --- /dev/null +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -0,0 +1,125 @@ +// RUN: tpp-opt %s -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s + +// NB: obtained from a M=128, N=256, K=512 linalg.matmul by -pack-matmul +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +func.func @revert_all_packing(%arg0: tensor<128x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<128x256xf32>) -> tensor<128x256xf32> { + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> + %1 = tensor.empty() : tensor<8x16x32x32xf32> + %pack_0 = tensor.pack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<512x256xf32> -> tensor<8x16x32x32xf32> + %2 = tensor.empty() : tensor<4x8x32x32xf32> + %pack_1 = tensor.pack %arg2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %2 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_0 : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_1 : tensor<4x8x32x32xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %4 = arith.mulf %in, %in_2 : f32 + %5 = arith.addf %out, %4 : f32 + linalg.yield %5 : f32 + } -> tensor<4x8x32x32xf32> + %unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg2 : tensor<4x8x32x32xf32> -> tensor<128x256xf32> + return %unpack : tensor<128x256xf32> +} + +// CHECK-LABEL: func.func @revert_all_packing( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// CHECK-SAME: %[[ARG1:.+]]: tensor<512x256xf32>, +// CHECK-SAME: 
%[[ARG2:.+]]: tensor<128x256xf32>) +// CHECK-SAME: -> tensor<128x256xf32> + // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> + // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [16, 32, 8, 32] : tensor<512x256xf32> into tensor<16x32x8x32xf32> + // CHECK: %[[EXP2:.+]] = tensor.expand_shape %[[ARG2]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> + // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[EXP1]] : tensor<4x32x16x32xf32>, tensor<16x32x8x32xf32>) outs(%[[EXP2]] : tensor<4x32x8x32xf32>) + // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> + // CHECK: return %[[COL]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +func.func @only_keep_constant_packed(%arg0: tensor<128x512xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + %cst = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> + %1 = tensor.empty() : tensor<4x8x32x32xf32> + %pack_0 = tensor.pack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_0 : tensor<4x8x32x32xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %3 = arith.mulf %in, %in_1 : 
f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<4x8x32x32xf32> + %unpack = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<4x8x32x32xf32> -> tensor<128x256xf32> + return %unpack : tensor<128x256xf32> +} +// CHECK-LABEL: func.func @only_keep_constant_packed( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) +// CHECK-SAME: -> tensor<128x256xf32> + // NB: even if the following is the case, this does not mean the layout will be preserved in general + // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> + // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> + // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>) + // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> + // CHECK: return %[[COL]] + +// ----- + +// NB: obtained from a M=?, N=256, K=512 linalg.matmul by -pack-matmul +#map = affine_map<()[s0] -> (s0 ceildiv 32)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +module { + func.func @revert_packing_with_leading_dim_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + %cst = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + %cst_0 = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + %0 = affine.apply #map()[%dim] + %1 = tensor.empty(%0) : tensor + %pack = 
tensor.pack %arg0 padding_value(%cst_0 : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor -> tensor + %dim_1 = tensor.dim %arg1, %c0 : tensor + %2 = affine.apply #map()[%dim_1] + %3 = tensor.empty(%2) : tensor + %pack_2 = tensor.pack %arg1 padding_value(%cst_0 : f32) inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %3 : tensor -> tensor + %4 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst : tensor, tensor<8x16x32x32xf32>) outs(%pack_2 : tensor) { + ^bb0(%in: f32, %in_3: f32, %out: f32): + %5 = arith.mulf %in, %in_3 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + } -> tensor + %unpack = tensor.unpack %4 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor -> tensor + return %unpack : tensor + } +} +// CHECK-LABEL: func.func @revert_packing_with_leading_dim_dynamic( +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor) +// CHECK-SAME: -> tensor +// func.func @revert_packing_with_one_dim_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index + // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + // CHECK: %[[M:.*]] = tensor.dim %[[ARG0]], %[[C0]] + // CHECK: %[[M_DUP:.*]] = tensor.dim %[[ARG0]], %[[C0]] + // CHECK: %[[M_ROUNDED_UP:.*]] = affine.apply {{.*}}()[%[[M_DUP]], %[[M]]] + // CHECK: %[[ARG0_PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[M_ROUNDED_UP]], 0] + // CHECK: %[[M_PADDED:.*]] = tensor.dim %[[ARG0_PADDED]], %[[C0]] + // CHECK: %[[NUM_CHUNKS_PADDED_M:.*]] = arith.divui %[[M_PADDED]], %[[C32]] + // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0_PADDED]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[NUM_CHUNKS_PADDED_M]], 32, 16, 32] : tensor into tensor + // CHECK: %[[M_ARG1:.*]] = tensor.dim 
%[[ARG1]], %[[C0]] + // CHECK: %[[M_ARG1_DUP:.*]] = tensor.dim %[[ARG1]], %[[C0]] + // CHECK: %[[M_ARG1_ROUNDED_UP:.*]] = affine.apply {{.*}}()[%[[M_ARG1_DUP]], %[[M_ARG1]]] + // CHECK: %[[ARG1_PADDED:.*]] = tensor.pad %[[ARG1]] low[0, 0] high[%[[M_ARG1_ROUNDED_UP]], 0] + // CHECK: %[[M_ARG1_PADDED:.*]] = tensor.dim %[[ARG1_PADDED]], %[[C0]] + // CHECK: %[[NUM_CHUNKS_PADDED_M_ARG1:.*]] = arith.divui %[[M_ARG1_PADDED]], %[[C32]] + // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1_PADDED]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [%[[NUM_CHUNKS_PADDED_M_ARG1]], 32, 8, 32] : tensor into tensor + // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor) + // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor into tensor + // CHECK: %[[M_DUP2:.*]] = tensor.dim %[[ARG1]], %[[C0]] + // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[COL]][0, 0] [%[[M_DUP2]], 256] [1, 1] : tensor to tensor + // CHECK: %[[COPY:.+]] = linalg.copy ins(%[[SLICE]] : tensor) outs(%[[ARG1]] + // CHECK: return %[[COPY]] From 026f04dc7868ffc455ac1f1ae13a1b3cf0a0a18e Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 23 Sep 2024 05:53:43 -0700 Subject: [PATCH 5/9] Small doc fixes and new test for actual constant as arg to pack --- include/TPP/PassBundles.td | 4 +-- ...r-packs-and-unpacks-without-transpose.mlir | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index ff9c8223c..cba7eb42f 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -40,7 +40,7 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> { "Lower linalg directly to vector.">, Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose-to-vector", "bool", /*default=*/"false", - "Lower packs and unpacks reverting any dim permutations."> + "Lower non-constant packs and unpacks 
reverting any dim permutations."> ]; } @@ -57,7 +57,7 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> { let options= [ Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose", "bool", /*default=*/"false", - "Lower packs and unpacks reverting any dim permutations."> + "Lower non-constant packs and unpacks reverting any dim permutations."> ]; } diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir index f5b1cf663..ddd3679c5 100644 --- a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -67,6 +67,41 @@ func.func @only_keep_constant_packed(%arg0: tensor<128x512xf32>, %arg1: tensor<1 // ----- +#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> +func.func @only_keep_constant_packed_non_prepacked(%arg0: tensor<128x512xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + %cst = arith.constant dense<1.000000e-03> : tensor<512x256xf32> + %cst_empty = tensor.empty() : tensor<8x16x32x32xf32> + %cst_packed = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %cst_empty : tensor<512x256xf32> -> tensor<8x16x32x32xf32> + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> + %1 = tensor.empty() : tensor<4x8x32x32xf32> + %pack_0 = tensor.pack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> + %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst_packed : tensor<4x16x32x32xf32>, 
tensor<8x16x32x32xf32>) outs(%pack_0 : tensor<4x8x32x32xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %3 = arith.mulf %in, %in_1 : f32 + %4 = arith.addf %out, %3 : f32 + linalg.yield %4 : f32 + } -> tensor<4x8x32x32xf32> + %unpack = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<4x8x32x32xf32> -> tensor<128x256xf32> + return %unpack : tensor<128x256xf32> +} +// CHECK-LABEL: func.func @only_keep_constant_packed_non_prepacked( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) +// CHECK-SAME: -> tensor<128x256xf32> + // NB: even if the following is the case, this does not mean the layout will be preserved in general + // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> + // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> + // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>) + // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> + // CHECK: return %[[COL]] + + +// ----- + // NB: obtained from a M=?, N=256, K=512 linalg.matmul by -pack-matmul #map = affine_map<()[s0] -> (s0 ceildiv 32)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> From b1ad586da15b4d00d5eb419b31e6673b66d6a349 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 23 Sep 2024 09:16:41 -0700 Subject: [PATCH 6/9] Address Adam's comments --- include/TPP/PassBundles.td | 2 +- .../LowerPacksAndUnpacksWithoutTranspose.cpp | 96 +++++++++++-------- ...r-packs-and-unpacks-without-transpose.mlir | 54 +++++------ 3 files 
changed, 77 insertions(+), 75 deletions(-) diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index cba7eb42f..3a5c229fa 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -38,7 +38,7 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> { Option<"linalgToVector", "linalg-to-vector", "bool", /*default=*/"false", "Lower linalg directly to vector.">, - Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose-to-vector", + Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose", "bool", /*default=*/"false", "Lower non-constant packs and unpacks reverting any dim permutations."> ]; diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp index 552c0a12e..d431df79b 100644 --- a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp +++ b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp @@ -38,7 +38,7 @@ namespace { /// Wrapper around linalg::lowerPack which undoes the transpose that might have /// happened. Single user genericOp's indexing_maps is corrected accordingly. -void lowerPackWithoutTranspose(tensor::PackOp packOp, +void lowerPackAndFoldTranspose(tensor::PackOp packOp, linalg::GenericOp genericOp, uint operandIdx, PatternRewriter &rewriter) { auto packInversionPerm = tensor::getPackInverseDestPerm(packOp); @@ -49,7 +49,7 @@ void lowerPackWithoutTranspose(tensor::PackOp packOp, // Forget about the permutation of the dims on expandShapeOp. rewriter.replaceAllOpUsesWith(res->transposeOp, res->expandShapeOp); - // Invert corresponding transposed accesses by the single-user, genericOp. + // Invert corresponding transposed accesses by the single user, genericOp. 
auto indexingMaps = genericOp.getIndexingMapsArray(); auto packInverseMap = AffineMap::getPermutationMap(packInversionPerm, rewriter.getContext()); @@ -62,14 +62,14 @@ void lowerPackWithoutTranspose(tensor::PackOp packOp, } } -struct LowerPackOnInputsWithoutTranspose +struct LowerPackOnInputsFoldingTranspose : public OpRewritePattern { // Is only called with single-user packOp operands, so callback can always // find the (use by the) linalg.generic that is the target of the pattern. using ControlFn = std::function; ControlFn controlFn; - LowerPackOnInputsWithoutTranspose(MLIRContext *context, + LowerPackOnInputsFoldingTranspose(MLIRContext *context, ControlFn controlFn = nullptr, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} @@ -77,14 +77,15 @@ struct LowerPackOnInputsWithoutTranspose LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { bool modifiedAnOperand = false; - for (auto operandIdx : {0, 1}) { - auto packOp = dyn_cast_if_present( - genericOp->getOperand(operandIdx).getDefiningOp()); + for (auto &&[operandIdx, inOperand] : + llvm::enumerate(genericOp.getInputs())) { + auto packOp = + dyn_cast_if_present(inOperand.getDefiningOp()); if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp))) continue; - lowerPackWithoutTranspose(packOp, genericOp, operandIdx, rewriter); + lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter); modifiedAnOperand = true; } @@ -93,60 +94,71 @@ struct LowerPackOnInputsWithoutTranspose } }; -struct LowerPackUnpackOnOutputWithoutTranspose +struct LowerPackUnpackOnOutputFoldingTranspose : public OpRewritePattern { // Is only called with single-user packOp operands, so callback can always // find the (use by the) linalg.generic that is the target of the pattern. 
using ControlFn = std::function; ControlFn controlFn; - LowerPackUnpackOnOutputWithoutTranspose(MLIRContext *context, + LowerPackUnpackOnOutputFoldingTranspose(MLIRContext *context, ControlFn controlFn = nullptr, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} LogicalResult matchAndRewrite(linalg::GenericOp genericOp, PatternRewriter &rewriter) const override { - if (!genericOp->hasOneUse()) - return failure(); + bool modifiedAnOperand = false; + size_t numInputs = genericOp.getInputs().size(); + for (auto &&[outOperandIdx, outOperand] : + llvm::enumerate(genericOp.getOutputs())) { + size_t operandIdx = numInputs + outOperandIdx; + auto result = genericOp->getResult(outOperandIdx); - auto packOp = dyn_cast_if_present( - genericOp->getOperand(2).getDefiningOp()); - auto unpackOp = llvm::dyn_cast( - *(genericOp->getResult(0).getUsers().begin())); + if (!result.hasOneUse()) + continue; - if (!packOp || !packOp->hasOneUse() || !unpackOp || - (controlFn && !controlFn(packOp, unpackOp))) - return failure(); + auto packOp = + dyn_cast_if_present(outOperand.getDefiningOp()); + auto unpackOp = + llvm::dyn_cast(*(result.getUsers().begin())); - auto unpackDest = unpackOp.getDest(); - bool destHasStaticShape = unpackDest.getType().hasStaticShape(); + if (!packOp || !packOp->hasOneUse() || !unpackOp || + unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() || + unpackOp.getOuterDimsPerm() != packOp.getOuterDimsPerm() || + (controlFn && !controlFn(packOp, unpackOp))) + continue; - lowerPackWithoutTranspose(packOp, genericOp, /*operandIdx=*/2, rewriter); - auto res = linalg::lowerUnPack(rewriter, unpackOp); + auto unpackDest = unpackOp.getDest(); + bool destHasStaticShape = unpackDest.getType().hasStaticShape(); - // Set genericOp's result type to the adjusted type of the out parameter. 
- genericOp->getOpResult(0).setType(genericOp.getOperand(2).getType()); + lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter); + auto res = linalg::lowerUnPack(rewriter, unpackOp); - if (auto transposeOp = res->transposeOp) { - // Forget about the transpose introduced by lowerUnPack. - rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput()); - } + // Set genericOp's result type to the adjusted type of the out parameter. + result.setType(genericOp.getOperand(operandIdx).getType()); - // lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest. - // As we ignore permutations - and in the static case - don't do padding, - // we know the underlying buffer will be used as is and hence we do not need - // to specify a dest to update into. - auto extractSliceOp = res->extractSliceOp; - if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) { - auto copyOp = - dyn_cast(*extractSliceOp->getUsers().begin()); - if (copyOp && copyOp.getOutputs()[0] == unpackDest) { - rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]); + if (auto transposeOp = res->transposeOp) { + // Forget about the transpose introduced by lowerUnPack. + rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput()); } + + // lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest. + // As we ignore permutations and, in the static case, don't do padding, + // we know the underlying buffer will be used as is and hence we do not + // need to specify a dest to update into. + auto extractSliceOp = res->extractSliceOp; + if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) { + auto copyOp = + dyn_cast(*extractSliceOp->getUsers().begin()); + if (copyOp && copyOp.getOutputs()[0] == unpackDest) { + rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]); + } + } + modifiedAnOperand = true; } - return success(); + return modifiedAnOperand ? 
success() : failure(); } }; @@ -158,13 +170,13 @@ struct LowerPacksAndUnpacksWithoutTranspose auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add( + patterns.add( ctx, [](tensor::PackOp packOp) { // Only lower packOps whose argument is not a constant. return !llvm::dyn_cast_if_present( packOp.getOperand(0).getDefiningOp()); }); - patterns.add(ctx); + patterns.add(ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir index ddd3679c5..2645606ca 100644 --- a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -1,5 +1,27 @@ +// RUN: tpp-opt %s -pack-matmul -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s --check-prefix PACK-CHECK // RUN: tpp-opt %s -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s +func.func @pack_including_constant_then_lower_not_touching_constant(%arg0: tensor<128x512xf32>, + %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { + %weights = arith.constant dense<1.000000e+00> : tensor<512x256xf32> + %0 = linalg.matmul ins(%arg0, %weights : tensor<128x512xf32>, tensor<512x256xf32>) + outs(%arg1 : tensor<128x256xf32>) -> tensor<128x256xf32> + return %0 : tensor<128x256xf32> +} +// PACK-CHECK-LABEL: func.func @pack_including_constant_then_lower_not_touching_constant( +// PACK-CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// PACK-CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) +// PACK-CHECK-SAME: -> tensor<128x256xf32> + // NB: even if the following is the case, this does not mean the layout will be preserved in general + // PACK-CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> + // PACK-CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : 
tensor<128x512xf32> into tensor<4x32x16x32xf32> + // PACK-CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> + // PACK-CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>) + // PACK-CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> + // PACK-CHECK: return %[[COL]] + +// ----- + // NB: obtained from a M=128, N=256, K=512 linalg.matmul by -pack-matmul #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> @@ -35,38 +57,6 @@ func.func @revert_all_packing(%arg0: tensor<128x512xf32>, %arg1: tensor<512x256x // ----- -#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> -func.func @only_keep_constant_packed(%arg0: tensor<128x512xf32>, %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { - %cst = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> - %0 = tensor.empty() : tensor<4x16x32x32xf32> - %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> - %1 = tensor.empty() : tensor<4x8x32x32xf32> - %pack_0 = tensor.pack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> - %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_0 : tensor<4x8x32x32xf32>) { - ^bb0(%in: f32, %in_1: f32, %out: f32): - %3 = arith.mulf %in, %in_1 : f32 - %4 = arith.addf 
%out, %3 : f32 - linalg.yield %4 : f32 - } -> tensor<4x8x32x32xf32> - %unpack = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<4x8x32x32xf32> -> tensor<128x256xf32> - return %unpack : tensor<128x256xf32> -} -// CHECK-LABEL: func.func @only_keep_constant_packed( -// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, -// CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) -// CHECK-SAME: -> tensor<128x256xf32> - // NB: even if the following is the case, this does not mean the layout will be preserved in general - // CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> - // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> - // CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> - // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>) - // CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> - // CHECK: return %[[COL]] - -// ----- - #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)> #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)> From 5d656a4b4343ecb3da06cec8d1e51ace18bdbe7a Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 24 Sep 2024 14:42:54 -0700 Subject: [PATCH 7/9] Address Adam's comments V2 --- .../LowerPacksAndUnpacksWithoutTranspose.cpp | 23 ++++++++-- .../lower-pack-unpack-folding-transpose.mlir | 44 +++++++++++++++++++ ...r-packs-and-unpacks-without-transpose.mlir | 37 ++++++++-------- 3 files changed, 83 insertions(+), 21 deletions(-) create mode 100644 test/Integration/lower-pack-unpack-folding-transpose.mlir diff --git 
a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp index d431df79b..40ba11bc7 100644 --- a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp +++ b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include +#include using namespace mlir; namespace mlir { @@ -123,9 +124,25 @@ struct LowerPackUnpackOnOutputFoldingTranspose auto unpackOp = llvm::dyn_cast(*(result.getUsers().begin())); - if (!packOp || !packOp->hasOneUse() || !unpackOp || - unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() || - unpackOp.getOuterDimsPerm() != packOp.getOuterDimsPerm() || + if (!packOp || !packOp->hasOneUse() || !unpackOp) + continue; + + // Normalize empty outer_dims_perm to its corresponding identity map. + auto packOuterDimsPerm = SmallVector(packOp.getOuterDimsPerm()); + if (packOuterDimsPerm.empty()) { + packOuterDimsPerm = + SmallVector(packOp.getSource().getType().getRank()); + std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.begin(), 0); + } + auto unpackOuterDimsPerm = SmallVector(unpackOp.getOuterDimsPerm()); + if (unpackOuterDimsPerm.empty()) { + unpackOuterDimsPerm = + SmallVector(unpackOp.getResult().getType().getRank()); + std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.begin(), 0); + } + + if (unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() || + packOuterDimsPerm != unpackOuterDimsPerm || (controlFn && !controlFn(packOp, unpackOp))) continue; diff --git a/test/Integration/lower-pack-unpack-folding-transpose.mlir b/test/Integration/lower-pack-unpack-folding-transpose.mlir new file mode 100644 index 000000000..4acf0f08c --- /dev/null +++ b/test/Integration/lower-pack-unpack-folding-transpose.mlir @@ -0,0 +1,44 @@ +// RUN: tpp-run %s -e entry --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/matmul_48x64x96-default-tpp-passes.out +// RUN: tpp-run 
--lower-pack-unpack-without-transpose %s -e entry --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out +// RUN: fpcmp -r 0.09 %S/matmul_48x64x96-default-tpp-passes.out %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out +// RUN: rm %S/matmul_48x64x96-default-tpp-passes.out %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out + +func.func @entry(%A: tensor<128x32xf32>, + %C: tensor<128x64xf32>) -> tensor<128x64xf32> { + %constB = arith.constant dense<[ +[ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1, 17.1, 18.1, 19.1, 20.1, 21.1, 22.1, 23.1, 24.1, 25.1, 26.1, 27.1, 28.1, 29.1, 30.1, 31.1, 32.1, 33.1, 34.1, 35.1, 36.1, 37.1, 38.1, 39.1, 40.1, 41.1, 42.1, 43.1, 44.1, 45.1, 46.1, 47.1, 48.1, 49.1, 50.1, 51.1, 52.1, 53.1, 54.1, 55.1, 56.1, 57.1, 58.1, 59.1, 60.1, 61.1, 62.1, 63.1, 64.1 ], +[ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2, 17.2, 18.2, 19.2, 20.2, 21.2, 22.2, 23.2, 24.2, 25.2, 26.2, 27.2, 28.2, 29.2, 30.2, 31.2, 32.2, 33.2, 34.2, 35.2, 36.2, 37.2, 38.2, 39.2, 40.2, 41.2, 42.2, 43.2, 44.2, 45.2, 46.2, 47.2, 48.2, 49.2, 50.2, 51.2, 52.2, 53.2, 54.2, 55.2, 56.2, 57.2, 58.2, 59.2, 60.2, 61.2, 62.2, 63.2, 64.2 ], +[ 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3, 8.3, 9.3, 10.3, 11.3, 12.3, 13.3, 14.3, 15.3, 16.3, 17.3, 18.3, 19.3, 20.3, 21.3, 22.3, 23.3, 24.3, 25.3, 26.3, 27.3, 28.3, 29.3, 30.3, 31.3, 32.3, 33.3, 34.3, 35.3, 36.3, 37.3, 38.3, 39.3, 40.3, 41.3, 42.3, 43.3, 44.3, 45.3, 46.3, 47.3, 48.3, 49.3, 50.3, 51.3, 52.3, 53.3, 54.3, 55.3, 56.3, 57.3, 58.3, 59.3, 60.3, 61.3, 62.3, 63.3, 64.3 ], +[ 1.4, 2.4, 3.4, 4.4, 5.4, 6.4, 7.4, 8.4, 9.4, 10.4, 11.4, 12.4, 13.4, 14.4, 15.4, 16.4, 17.4, 18.4, 19.4, 20.4, 21.4, 22.4, 23.4, 24.4, 25.4, 26.4, 27.4, 28.4, 29.4, 30.4, 31.4, 32.4, 33.4, 34.4, 35.4, 36.4, 37.4, 38.4, 39.4, 40.4, 41.4, 42.4, 43.4, 44.4, 45.4, 46.4, 
47.4, 48.4, 49.4, 50.4, 51.4, 52.4, 53.4, 54.4, 55.4, 56.4, 57.4, 58.4, 59.4, 60.4, 61.4, 62.4, 63.4, 64.4 ], +[ 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5, 15.5, 16.5, 17.5, 18.5, 19.5, 20.5, 21.5, 22.5, 23.5, 24.5, 25.5, 26.5, 27.5, 28.5, 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 35.5, 36.5, 37.5, 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 44.5, 45.5, 46.5, 47.5, 48.5, 49.5, 50.5, 51.5, 52.5, 53.5, 54.5, 55.5, 56.5, 57.5, 58.5, 59.5, 60.5, 61.5, 62.5, 63.5, 64.5 ], +[ 1.6, 2.6, 3.6, 4.6, 5.6, 6.6, 7.6, 8.6, 9.6, 10.6, 11.6, 12.6, 13.6, 14.6, 15.6, 16.6, 17.6, 18.6, 19.6, 20.6, 21.6, 22.6, 23.6, 24.6, 25.6, 26.6, 27.6, 28.6, 29.6, 30.6, 31.6, 32.6, 33.6, 34.6, 35.6, 36.6, 37.6, 38.6, 39.6, 40.6, 41.6, 42.6, 43.6, 44.6, 45.6, 46.6, 47.6, 48.6, 49.6, 50.6, 51.6, 52.6, 53.6, 54.6, 55.6, 56.6, 57.6, 58.6, 59.6, 60.6, 61.6, 62.6, 63.6, 64.6 ], +[ 1.7, 2.7, 3.7, 4.7, 5.7, 6.7, 7.7, 8.7, 9.7, 10.7, 11.7, 12.7, 13.7, 14.7, 15.7, 16.7, 17.7, 18.7, 19.7, 20.7, 21.7, 22.7, 23.7, 24.7, 25.7, 26.7, 27.7, 28.7, 29.7, 30.7, 31.7, 32.7, 33.7, 34.7, 35.7, 36.7, 37.7, 38.7, 39.7, 40.7, 41.7, 42.7, 43.7, 44.7, 45.7, 46.7, 47.7, 48.7, 49.7, 50.7, 51.7, 52.7, 53.7, 54.7, 55.7, 56.7, 57.7, 58.7, 59.7, 60.7, 61.7, 62.7, 63.7, 64.7 ], +[ 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8, 10.8, 11.8, 12.8, 13.8, 14.8, 15.8, 16.8, 17.8, 18.8, 19.8, 20.8, 21.8, 22.8, 23.8, 24.8, 25.8, 26.8, 27.8, 28.8, 29.8, 30.8, 31.8, 32.8, 33.8, 34.8, 35.8, 36.8, 37.8, 38.8, 39.8, 40.8, 41.8, 42.8, 43.8, 44.8, 45.8, 46.8, 47.8, 48.8, 49.8, 50.8, 51.8, 52.8, 53.8, 54.8, 55.8, 56.8, 57.8, 58.8, 59.8, 60.8, 61.8, 62.8, 63.8, 64.8 ], +[ 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.9, 11.9, 12.9, 13.9, 14.9, 15.9, 16.9, 17.9, 18.9, 19.9, 20.9, 21.9, 22.9, 23.9, 24.9, 25.9, 26.9, 27.9, 28.9, 29.9, 30.9, 31.9, 32.9, 33.9, 34.9, 35.9, 36.9, 37.9, 38.9, 39.9, 40.9, 41.9, 42.9, 43.9, 44.9, 45.9, 46.9, 47.9, 48.9, 49.9, 50.9, 51.9, 52.9, 53.9, 54.9, 55.9, 56.9, 57.9, 58.9, 59.9, 60.9, 61.9, 62.9, 
63.9, 64.9 ], +[ 1.10, 2.10, 3.10, 4.10, 5.10, 6.10, 7.10, 8.10, 9.10, 10.10, 11.10, 12.10, 13.10, 14.10, 15.10, 16.10, 17.10, 18.10, 19.10, 20.10, 21.10, 22.10, 23.10, 24.10, 25.10, 26.10, 27.10, 28.10, 29.10, 30.10, 31.10, 32.10, 33.10, 34.10, 35.10, 36.10, 37.10, 38.10, 39.10, 40.10, 41.10, 42.10, 43.10, 44.10, 45.10, 46.10, 47.10, 48.10, 49.10, 50.10, 51.10, 52.10, 53.10, 54.10, 55.10, 56.10, 57.10, 58.10, 59.10, 60.10, 61.10, 62.10, 63.10, 64.10 ], +[ 1.11, 2.11, 3.11, 4.11, 5.11, 6.11, 7.11, 8.11, 9.11, 10.11, 11.11, 12.11, 13.11, 14.11, 15.11, 16.11, 17.11, 18.11, 19.11, 20.11, 21.11, 22.11, 23.11, 24.11, 25.11, 26.11, 27.11, 28.11, 29.11, 30.11, 31.11, 32.11, 33.11, 34.11, 35.11, 36.11, 37.11, 38.11, 39.11, 40.11, 41.11, 42.11, 43.11, 44.11, 45.11, 46.11, 47.11, 48.11, 49.11, 50.11, 51.11, 52.11, 53.11, 54.11, 55.11, 56.11, 57.11, 58.11, 59.11, 60.11, 61.11, 62.11, 63.11, 64.11 ], +[ 1.12, 2.12, 3.12, 4.12, 5.12, 6.12, 7.12, 8.12, 9.12, 10.12, 11.12, 12.12, 13.12, 14.12, 15.12, 16.12, 17.12, 18.12, 19.12, 20.12, 21.12, 22.12, 23.12, 24.12, 25.12, 26.12, 27.12, 28.12, 29.12, 30.12, 31.12, 32.12, 33.12, 34.12, 35.12, 36.12, 37.12, 38.12, 39.12, 40.12, 41.12, 42.12, 43.12, 44.12, 45.12, 46.12, 47.12, 48.12, 49.12, 50.12, 51.12, 52.12, 53.12, 54.12, 55.12, 56.12, 57.12, 58.12, 59.12, 60.12, 61.12, 62.12, 63.12, 64.12 ], +[ 1.13, 2.13, 3.13, 4.13, 5.13, 6.13, 7.13, 8.13, 9.13, 10.13, 11.13, 12.13, 13.13, 14.13, 15.13, 16.13, 17.13, 18.13, 19.13, 20.13, 21.13, 22.13, 23.13, 24.13, 25.13, 26.13, 27.13, 28.13, 29.13, 30.13, 31.13, 32.13, 33.13, 34.13, 35.13, 36.13, 37.13, 38.13, 39.13, 40.13, 41.13, 42.13, 43.13, 44.13, 45.13, 46.13, 47.13, 48.13, 49.13, 50.13, 51.13, 52.13, 53.13, 54.13, 55.13, 56.13, 57.13, 58.13, 59.13, 60.13, 61.13, 62.13, 63.13, 64.13 ], +[ 1.14, 2.14, 3.14, 4.14, 5.14, 6.14, 7.14, 8.14, 9.14, 10.14, 11.14, 12.14, 13.14, 14.14, 15.14, 16.14, 17.14, 18.14, 19.14, 20.14, 21.14, 22.14, 23.14, 24.14, 25.14, 26.14, 27.14, 28.14, 29.14, 30.14, 
31.14, 32.14, 33.14, 34.14, 35.14, 36.14, 37.14, 38.14, 39.14, 40.14, 41.14, 42.14, 43.14, 44.14, 45.14, 46.14, 47.14, 48.14, 49.14, 50.14, 51.14, 52.14, 53.14, 54.14, 55.14, 56.14, 57.14, 58.14, 59.14, 60.14, 61.14, 62.14, 63.14, 64.14 ], +[ 1.15, 2.15, 3.15, 4.15, 5.15, 6.15, 7.15, 8.15, 9.15, 10.15, 11.15, 12.15, 13.15, 14.15, 15.15, 16.15, 17.15, 18.15, 19.15, 20.15, 21.15, 22.15, 23.15, 24.15, 25.15, 26.15, 27.15, 28.15, 29.15, 30.15, 31.15, 32.15, 33.15, 34.15, 35.15, 36.15, 37.15, 38.15, 39.15, 40.15, 41.15, 42.15, 43.15, 44.15, 45.15, 46.15, 47.15, 48.15, 49.15, 50.15, 51.15, 52.15, 53.15, 54.15, 55.15, 56.15, 57.15, 58.15, 59.15, 60.15, 61.15, 62.15, 63.15, 64.15 ], +[ 1.16, 2.16, 3.16, 4.16, 5.16, 6.16, 7.16, 8.16, 9.16, 10.16, 11.16, 12.16, 13.16, 14.16, 15.16, 16.16, 17.16, 18.16, 19.16, 20.16, 21.16, 22.16, 23.16, 24.16, 25.16, 26.16, 27.16, 28.16, 29.16, 30.16, 31.16, 32.16, 33.16, 34.16, 35.16, 36.16, 37.16, 38.16, 39.16, 40.16, 41.16, 42.16, 43.16, 44.16, 45.16, 46.16, 47.16, 48.16, 49.16, 50.16, 51.16, 52.16, 53.16, 54.16, 55.16, 56.16, 57.16, 58.16, 59.16, 60.16, 61.16, 62.16, 63.16, 64.16 ], +[ 1.17, 2.17, 3.17, 4.17, 5.17, 6.17, 7.17, 8.17, 9.17, 10.17, 11.17, 12.17, 13.17, 14.17, 15.17, 16.17, 17.17, 18.17, 19.17, 20.17, 21.17, 22.17, 23.17, 24.17, 25.17, 26.17, 27.17, 28.17, 29.17, 30.17, 31.17, 32.17, 33.17, 34.17, 35.17, 36.17, 37.17, 38.17, 39.17, 40.17, 41.17, 42.17, 43.17, 44.17, 45.17, 46.17, 47.17, 48.17, 49.17, 50.17, 51.17, 52.17, 53.17, 54.17, 55.17, 56.17, 57.17, 58.17, 59.17, 60.17, 61.17, 62.17, 63.17, 64.17 ], +[ 1.18, 2.18, 3.18, 4.18, 5.18, 6.18, 7.18, 8.18, 9.18, 10.18, 11.18, 12.18, 13.18, 14.18, 15.18, 16.18, 17.18, 18.18, 19.18, 20.18, 21.18, 22.18, 23.18, 24.18, 25.18, 26.18, 27.18, 28.18, 29.18, 30.18, 31.18, 32.18, 33.18, 34.18, 35.18, 36.18, 37.18, 38.18, 39.18, 40.18, 41.18, 42.18, 43.18, 44.18, 45.18, 46.18, 47.18, 48.18, 49.18, 50.18, 51.18, 52.18, 53.18, 54.18, 55.18, 56.18, 57.18, 58.18, 59.18, 60.18, 61.18, 62.18, 
63.18, 64.18 ], +[ 1.19, 2.19, 3.19, 4.19, 5.19, 6.19, 7.19, 8.19, 9.19, 10.19, 11.19, 12.19, 13.19, 14.19, 15.19, 16.19, 17.19, 18.19, 19.19, 20.19, 21.19, 22.19, 23.19, 24.19, 25.19, 26.19, 27.19, 28.19, 29.19, 30.19, 31.19, 32.19, 33.19, 34.19, 35.19, 36.19, 37.19, 38.19, 39.19, 40.19, 41.19, 42.19, 43.19, 44.19, 45.19, 46.19, 47.19, 48.19, 49.19, 50.19, 51.19, 52.19, 53.19, 54.19, 55.19, 56.19, 57.19, 58.19, 59.19, 60.19, 61.19, 62.19, 63.19, 64.19 ], +[ 1.20, 2.20, 3.20, 4.20, 5.20, 6.20, 7.20, 8.20, 9.20, 10.20, 11.20, 12.20, 13.20, 14.20, 15.20, 16.20, 17.20, 18.20, 19.20, 20.20, 21.20, 22.20, 23.20, 24.20, 25.20, 26.20, 27.20, 28.20, 29.20, 30.20, 31.20, 32.20, 33.20, 34.20, 35.20, 36.20, 37.20, 38.20, 39.20, 40.20, 41.20, 42.20, 43.20, 44.20, 45.20, 46.20, 47.20, 48.20, 49.20, 50.20, 51.20, 52.20, 53.20, 54.20, 55.20, 56.20, 57.20, 58.20, 59.20, 60.20, 61.20, 62.20, 63.20, 64.20 ], +[ 1.21, 2.21, 3.21, 4.21, 5.21, 6.21, 7.21, 8.21, 9.21, 10.21, 11.21, 12.21, 13.21, 14.21, 15.21, 16.21, 17.21, 18.21, 19.21, 20.21, 21.21, 22.21, 23.21, 24.21, 25.21, 26.21, 27.21, 28.21, 29.21, 30.21, 31.21, 32.21, 33.21, 34.21, 35.21, 36.21, 37.21, 38.21, 39.21, 40.21, 41.21, 42.21, 43.21, 44.21, 45.21, 46.21, 47.21, 48.21, 49.21, 50.21, 51.21, 52.21, 53.21, 54.21, 55.21, 56.21, 57.21, 58.21, 59.21, 60.21, 61.21, 62.21, 63.21, 64.21 ], +[ 1.22, 2.22, 3.22, 4.22, 5.22, 6.22, 7.22, 8.22, 9.22, 10.22, 11.22, 12.22, 13.22, 14.22, 15.22, 16.22, 17.22, 18.22, 19.22, 20.22, 21.22, 22.22, 23.22, 24.22, 25.22, 26.22, 27.22, 28.22, 29.22, 30.22, 31.22, 32.22, 33.22, 34.22, 35.22, 36.22, 37.22, 38.22, 39.22, 40.22, 41.22, 42.22, 43.22, 44.22, 45.22, 46.22, 47.22, 48.22, 49.22, 50.22, 51.22, 52.22, 53.22, 54.22, 55.22, 56.22, 57.22, 58.22, 59.22, 60.22, 61.22, 62.22, 63.22, 64.22 ], +[ 1.23, 2.23, 3.23, 4.23, 5.23, 6.23, 7.23, 8.23, 9.23, 10.23, 11.23, 12.23, 13.23, 14.23, 15.23, 16.23, 17.23, 18.23, 19.23, 20.23, 21.23, 22.23, 23.23, 24.23, 25.23, 26.23, 27.23, 28.23, 29.23, 30.23, 
31.23, 32.23, 33.23, 34.23, 35.23, 36.23, 37.23, 38.23, 39.23, 40.23, 41.23, 42.23, 43.23, 44.23, 45.23, 46.23, 47.23, 48.23, 49.23, 50.23, 51.23, 52.23, 53.23, 54.23, 55.23, 56.23, 57.23, 58.23, 59.23, 60.23, 61.23, 62.23, 63.23, 64.23 ], +[ 1.24, 2.24, 3.24, 4.24, 5.24, 6.24, 7.24, 8.24, 9.24, 10.24, 11.24, 12.24, 13.24, 14.24, 15.24, 16.24, 17.24, 18.24, 19.24, 20.24, 21.24, 22.24, 23.24, 24.24, 25.24, 26.24, 27.24, 28.24, 29.24, 30.24, 31.24, 32.24, 33.24, 34.24, 35.24, 36.24, 37.24, 38.24, 39.24, 40.24, 41.24, 42.24, 43.24, 44.24, 45.24, 46.24, 47.24, 48.24, 49.24, 50.24, 51.24, 52.24, 53.24, 54.24, 55.24, 56.24, 57.24, 58.24, 59.24, 60.24, 61.24, 62.24, 63.24, 64.24 ], +[ 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25, 10.25, 11.25, 12.25, 13.25, 14.25, 15.25, 16.25, 17.25, 18.25, 19.25, 20.25, 21.25, 22.25, 23.25, 24.25, 25.25, 26.25, 27.25, 28.25, 29.25, 30.25, 31.25, 32.25, 33.25, 34.25, 35.25, 36.25, 37.25, 38.25, 39.25, 40.25, 41.25, 42.25, 43.25, 44.25, 45.25, 46.25, 47.25, 48.25, 49.25, 50.25, 51.25, 52.25, 53.25, 54.25, 55.25, 56.25, 57.25, 58.25, 59.25, 60.25, 61.25, 62.25, 63.25, 64.25 ], +[ 1.26, 2.26, 3.26, 4.26, 5.26, 6.26, 7.26, 8.26, 9.26, 10.26, 11.26, 12.26, 13.26, 14.26, 15.26, 16.26, 17.26, 18.26, 19.26, 20.26, 21.26, 22.26, 23.26, 24.26, 25.26, 26.26, 27.26, 28.26, 29.26, 30.26, 31.26, 32.26, 33.26, 34.26, 35.26, 36.26, 37.26, 38.26, 39.26, 40.26, 41.26, 42.26, 43.26, 44.26, 45.26, 46.26, 47.26, 48.26, 49.26, 50.26, 51.26, 52.26, 53.26, 54.26, 55.26, 56.26, 57.26, 58.26, 59.26, 60.26, 61.26, 62.26, 63.26, 64.26 ], +[ 1.27, 2.27, 3.27, 4.27, 5.27, 6.27, 7.27, 8.27, 9.27, 10.27, 11.27, 12.27, 13.27, 14.27, 15.27, 16.27, 17.27, 18.27, 19.27, 20.27, 21.27, 22.27, 23.27, 24.27, 25.27, 26.27, 27.27, 28.27, 29.27, 30.27, 31.27, 32.27, 33.27, 34.27, 35.27, 36.27, 37.27, 38.27, 39.27, 40.27, 41.27, 42.27, 43.27, 44.27, 45.27, 46.27, 47.27, 48.27, 49.27, 50.27, 51.27, 52.27, 53.27, 54.27, 55.27, 56.27, 57.27, 58.27, 59.27, 60.27, 61.27, 62.27, 
63.27, 64.27 ], +[ 1.28, 2.28, 3.28, 4.28, 5.28, 6.28, 7.28, 8.28, 9.28, 10.28, 11.28, 12.28, 13.28, 14.28, 15.28, 16.28, 17.28, 18.28, 19.28, 20.28, 21.28, 22.28, 23.28, 24.28, 25.28, 26.28, 27.28, 28.28, 29.28, 30.28, 31.28, 32.28, 33.28, 34.28, 35.28, 36.28, 37.28, 38.28, 39.28, 40.28, 41.28, 42.28, 43.28, 44.28, 45.28, 46.28, 47.28, 48.28, 49.28, 50.28, 51.28, 52.28, 53.28, 54.28, 55.28, 56.28, 57.28, 58.28, 59.28, 60.28, 61.28, 62.28, 63.28, 64.28 ], +[ 1.29, 2.29, 3.29, 4.29, 5.29, 6.29, 7.29, 8.29, 9.29, 10.29, 11.29, 12.29, 13.29, 14.29, 15.29, 16.29, 17.29, 18.29, 19.29, 20.29, 21.29, 22.29, 23.29, 24.29, 25.29, 26.29, 27.29, 28.29, 29.29, 30.29, 31.29, 32.29, 33.29, 34.29, 35.29, 36.29, 37.29, 38.29, 39.29, 40.29, 41.29, 42.29, 43.29, 44.29, 45.29, 46.29, 47.29, 48.29, 49.29, 50.29, 51.29, 52.29, 53.29, 54.29, 55.29, 56.29, 57.29, 58.29, 59.29, 60.29, 61.29, 62.29, 63.29, 64.29 ], +[ 1.30, 2.30, 3.30, 4.30, 5.30, 6.30, 7.30, 8.30, 9.30, 10.30, 11.30, 12.30, 13.30, 14.30, 15.30, 16.30, 17.30, 18.30, 19.30, 20.30, 21.30, 22.30, 23.30, 24.30, 25.30, 26.30, 27.30, 28.30, 29.30, 30.30, 31.30, 32.30, 33.30, 34.30, 35.30, 36.30, 37.30, 38.30, 39.30, 40.30, 41.30, 42.30, 43.30, 44.30, 45.30, 46.30, 47.30, 48.30, 49.30, 50.30, 51.30, 52.30, 53.30, 54.30, 55.30, 56.30, 57.30, 58.30, 59.30, 60.30, 61.30, 62.30, 63.30, 64.30 ], +[ 1.31, 2.31, 3.31, 4.31, 5.31, 6.31, 7.31, 8.31, 9.31, 10.31, 11.31, 12.31, 13.31, 14.31, 15.31, 16.31, 17.31, 18.31, 19.31, 20.31, 21.31, 22.31, 23.31, 24.31, 25.31, 26.31, 27.31, 28.31, 29.31, 30.31, 31.31, 32.31, 33.31, 34.31, 35.31, 36.31, 37.31, 38.31, 39.31, 40.31, 41.31, 42.31, 43.31, 44.31, 45.31, 46.31, 47.31, 48.31, 49.31, 50.31, 51.31, 52.31, 53.31, 54.31, 55.31, 56.31, 57.31, 58.31, 59.31, 60.31, 61.31, 62.31, 63.31, 64.31 ], +[ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32, 17.32, 18.32, 19.32, 20.32, 21.32, 22.32, 23.32, 24.32, 25.32, 26.32, 27.32, 28.32, 29.32, 30.32, 
31.32, 32.32, 33.32, 34.32, 35.32, 36.32, 37.32, 38.32, 39.32, 40.32, 41.32, 42.32, 43.32, 44.32, 45.32, 46.32, 47.32, 48.32, 49.32, 50.32, 51.32, 52.32, 53.32, 54.32, 55.32, 56.32, 57.32, 58.32, 59.32, 60.32, 61.32, 62.32, 63.32, 64.32 ] + ]> : tensor<32x64xf32> + %D = linalg.matmul ins(%A, %constB: tensor<128x32xf32>, tensor<32x64xf32>) outs(%C: tensor<128x64xf32>) -> tensor<128x64xf32> + return %D : tensor<128x64xf32> +} diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir index 2645606ca..83abbe516 100644 --- a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -1,24 +1,25 @@ -// RUN: tpp-opt %s -pack-matmul -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s --check-prefix PACK-CHECK +// : tpp-opt %s -pack-matmul -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s --check-prefix PACK-CHECK // RUN: tpp-opt %s -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s -func.func @pack_including_constant_then_lower_not_touching_constant(%arg0: tensor<128x512xf32>, - %arg1: tensor<128x256xf32>) -> tensor<128x256xf32> { - %weights = arith.constant dense<1.000000e+00> : tensor<512x256xf32> - %0 = linalg.matmul ins(%arg0, %weights : tensor<128x512xf32>, tensor<512x256xf32>) - outs(%arg1 : tensor<128x256xf32>) -> tensor<128x256xf32> - return %0 : tensor<128x256xf32> +#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0 * 32 + d2, d1 * 32 + d3)> +func.func @single_packed_arg(%arg0: tensor<128x512xf32>, %arg1: tensor<128x512xf32>) -> tensor<128x512xf32> { + %0 = tensor.empty() : tensor<4x16x32x32xf32> + %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> + %1 = linalg.generic 
{indexing_maps = [#map, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%pack : tensor<4x16x32x32xf32>) outs(%arg1 : tensor<128x512xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<128x512xf32> + return %1 : tensor<128x512xf32> } -// PACK-CHECK-LABEL: func.func @pack_including_constant_then_lower_not_touching_constant( -// PACK-CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, -// PACK-CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) -// PACK-CHECK-SAME: -> tensor<128x256xf32> - // NB: even if the following is the case, this does not mean the layout will be preserved in general - // PACK-CHECK: %[[CST:.*]] = arith.constant dense<1.000000e-03> : tensor<8x16x32x32xf32> - // PACK-CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> - // PACK-CHECK: %[[EXP1:.+]] = tensor.expand_shape %[[ARG1]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 8, 32] : tensor<128x256xf32> into tensor<4x32x8x32xf32> - // PACK-CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]], %[[CST]] : tensor<4x32x16x32xf32>, tensor<8x16x32x32xf32>) outs(%[[EXP1]] : tensor<4x32x8x32xf32>) - // PACK-CHECK: %[[COL:.+]] = tensor.collapse_shape %[[RES]] {{\[}}[0, 1], [2, 3]{{\]}} : tensor<4x32x8x32xf32> into tensor<128x256xf32> - // PACK-CHECK: return %[[COL]] +// CHECK-LABEL: func.func @single_packed_arg( +// CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, +// CHECK-SAME: %[[ARG1:.+]]: tensor<128x512xf32>) +// CHECK-SAME: -> tensor<128x512xf32> + // CHECK: %[[EXP0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} output_shape [4, 32, 16, 32] : tensor<128x512xf32> into tensor<4x32x16x32xf32> + // CHECK: %[[RES:.+]] = linalg.generic {{.*}} ins(%[[EXP0]] : tensor<4x32x16x32xf32>) outs(%[[ARG1]] : tensor<128x512xf32>) + // CHECK-NOT: tensor.collapse_shape + // CHECK: return %[[RES]] // ----- From eb28a4f3a2d2d68ae4f9c1e1b6b1b2ff82a939d2 
Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 25 Sep 2024 04:38:40 -0700 Subject: [PATCH 8/9] Address further comments and add another functional correctness test --- .../LowerPacksAndUnpacksWithoutTranspose.cpp | 6 ++--- ... lower-pack-unpack-without-transpose.mlir} | 23 ++++++++++++++----- ...r-packs-and-unpacks-without-transpose.mlir | 7 ++++-- 3 files changed, 25 insertions(+), 11 deletions(-) rename test/Integration/{lower-pack-unpack-folding-transpose.mlir => lower-pack-unpack-without-transpose.mlir} (86%) diff --git a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp index 40ba11bc7..57b2b6030 100644 --- a/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp +++ b/lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp @@ -127,18 +127,18 @@ struct LowerPackUnpackOnOutputFoldingTranspose if (!packOp || !packOp->hasOneUse() || !unpackOp) continue; - // Normalize empty outer_dims_perm to its corresponding identity map. + // Normalize empty outer_dims_perm to its corresponding identity perm. 
auto packOuterDimsPerm = SmallVector(packOp.getOuterDimsPerm()); if (packOuterDimsPerm.empty()) { packOuterDimsPerm = SmallVector(packOp.getSource().getType().getRank()); - std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.begin(), 0); + std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.end(), 0); } auto unpackOuterDimsPerm = SmallVector(unpackOp.getOuterDimsPerm()); if (unpackOuterDimsPerm.empty()) { unpackOuterDimsPerm = SmallVector(unpackOp.getResult().getType().getRank()); - std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.begin(), 0); + std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.end(), 0); } if (unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() || diff --git a/test/Integration/lower-pack-unpack-folding-transpose.mlir b/test/Integration/lower-pack-unpack-without-transpose.mlir similarity index 86% rename from test/Integration/lower-pack-unpack-folding-transpose.mlir rename to test/Integration/lower-pack-unpack-without-transpose.mlir index 4acf0f08c..ef09e5d57 100644 --- a/test/Integration/lower-pack-unpack-folding-transpose.mlir +++ b/test/Integration/lower-pack-unpack-without-transpose.mlir @@ -1,10 +1,19 @@ -// RUN: tpp-run %s -e entry --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/matmul_48x64x96-default-tpp-passes.out -// RUN: tpp-run --lower-pack-unpack-without-transpose %s -e entry --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out -// RUN: fpcmp -r 0.09 %S/matmul_48x64x96-default-tpp-passes.out %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out -// RUN: rm %S/matmul_48x64x96-default-tpp-passes.out %S/matmul_48x64x96-default-tpp-passes-lower-pack-unpack-without-transpose.out +// RUN: tpp-run %s -e big_matmul --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out +// RUN: tpp-run --lower-pack-unpack-without-transpose %s -e 
big_matmul --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out +// RUN: fpcmp -r 0.001 %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out +// RUN: rm %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out -func.func @entry(%A: tensor<128x32xf32>, - %C: tensor<128x64xf32>) -> tensor<128x64xf32> { +func.func @big_matmul(%A: tensor<1024x2048xf32>, %B: tensor<2048x4096xf32>, %C: tensor<1024x4096xf32>) -> tensor<1024x4096xf32> { + %D = linalg.matmul ins(%A, %B: tensor<1024x2048xf32>, tensor<2048x4096xf32>) outs(%C: tensor<1024x4096xf32>) -> tensor<1024x4096xf32> + return %D : tensor<1024x4096xf32> +} + +// RUN: tpp-run %s -e small_matmul_with_a_cst_arg --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out +// RUN: tpp-run --lower-pack-unpack-without-transpose %s -e small_matmul_with_a_cst_arg --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out +// RUN: fpcmp -r 0.001 %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out +// RUN: rm %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-default-passes-lower-pack-unpack-without-transpose.out + +func.func @small_matmul_with_a_cst_arg(%A: tensor<128x32xf32>, %C: tensor<128x64xf32>) -> 
tensor<128x64xf32> { %constB = arith.constant dense<[ [ 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1, 9.1, 10.1, 11.1, 12.1, 13.1, 14.1, 15.1, 16.1, 17.1, 18.1, 19.1, 20.1, 21.1, 22.1, 23.1, 24.1, 25.1, 26.1, 27.1, 28.1, 29.1, 30.1, 31.1, 32.1, 33.1, 34.1, 35.1, 36.1, 37.1, 38.1, 39.1, 40.1, 41.1, 42.1, 43.1, 44.1, 45.1, 46.1, 47.1, 48.1, 49.1, 50.1, 51.1, 52.1, 53.1, 54.1, 55.1, 56.1, 57.1, 58.1, 59.1, 60.1, 61.1, 62.1, 63.1, 64.1 ], [ 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2, 8.2, 9.2, 10.2, 11.2, 12.2, 13.2, 14.2, 15.2, 16.2, 17.2, 18.2, 19.2, 20.2, 21.2, 22.2, 23.2, 24.2, 25.2, 26.2, 27.2, 28.2, 29.2, 30.2, 31.2, 32.2, 33.2, 34.2, 35.2, 36.2, 37.2, 38.2, 39.2, 40.2, 41.2, 42.2, 43.2, 44.2, 45.2, 46.2, 47.2, 48.2, 49.2, 50.2, 51.2, 52.2, 53.2, 54.2, 55.2, 56.2, 57.2, 58.2, 59.2, 60.2, 61.2, 62.2, 63.2, 64.2 ], @@ -40,5 +49,7 @@ func.func @entry(%A: tensor<128x32xf32>, [ 1.32, 2.32, 3.32, 4.32, 5.32, 6.32, 7.32, 8.32, 9.32, 10.32, 11.32, 12.32, 13.32, 14.32, 15.32, 16.32, 17.32, 18.32, 19.32, 20.32, 21.32, 22.32, 23.32, 24.32, 25.32, 26.32, 27.32, 28.32, 29.32, 30.32, 31.32, 32.32, 33.32, 34.32, 35.32, 36.32, 37.32, 38.32, 39.32, 40.32, 41.32, 42.32, 43.32, 44.32, 45.32, 46.32, 47.32, 48.32, 49.32, 50.32, 51.32, 52.32, 53.32, 54.32, 55.32, 56.32, 57.32, 58.32, 59.32, 60.32, 61.32, 62.32, 63.32, 64.32 ] ]> : tensor<32x64xf32> %D = linalg.matmul ins(%A, %constB: tensor<128x32xf32>, tensor<32x64xf32>) outs(%C: tensor<128x64xf32>) -> tensor<128x64xf32> + %empty = tensor.empty() : tensor<128x640xf32> return %D : tensor<128x64xf32> } + diff --git a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir index 83abbe516..b9d741731 100644 --- a/test/Passes/lower-packs-and-unpacks-without-transpose.mlir +++ b/test/Passes/lower-packs-and-unpacks-without-transpose.mlir @@ -1,4 +1,3 @@ -// : tpp-opt %s -pack-matmul -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s --check-prefix PACK-CHECK 
// RUN: tpp-opt %s -lower-packs-unpacks-without-transpose -canonicalize -split-input-file | FileCheck %s #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> @@ -68,16 +67,20 @@ func.func @only_keep_constant_packed_non_prepacked(%arg0: tensor<128x512xf32>, % %0 = tensor.empty() : tensor<4x16x32x32xf32> %pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %0 : tensor<128x512xf32> -> tensor<4x16x32x32xf32> %1 = tensor.empty() : tensor<4x8x32x32xf32> - %pack_0 = tensor.pack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> + %pack_0 = tensor.pack %arg1 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %1 : tensor<128x256xf32> -> tensor<4x8x32x32xf32> %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %cst_packed : tensor<4x16x32x32xf32>, tensor<8x16x32x32xf32>) outs(%pack_0 : tensor<4x8x32x32xf32>) { ^bb0(%in: f32, %in_1: f32, %out: f32): %3 = arith.mulf %in, %in_1 : f32 %4 = arith.addf %out, %3 : f32 linalg.yield %4 : f32 } -> tensor<4x8x32x32xf32> + // NB: the unpack's outer_dims_perm must match that of the corresponding pack; an elided outer_dims_perm is treated as the identity perm %unpack = tensor.unpack %2 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : tensor<4x8x32x32xf32> -> tensor<128x256xf32> return %unpack : tensor<128x256xf32> } +// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d5)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)> +// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d4)> // CHECK-LABEL: func.func @only_keep_constant_packed_non_prepacked( // CHECK-SAME: %[[ARG0:.+]]: tensor<128x512xf32>, // CHECK-SAME: %[[ARG1:.+]]: tensor<128x256xf32>) From fa330594623253740c628d1c4b56c44ddf7be0be Mon Sep 17 00:00:00 2001 From: Rolf Morel 
Date: Wed, 25 Sep 2024 08:20:15 -0700 Subject: [PATCH 9/9] Use sensible sizes --- test/Integration/lower-pack-unpack-without-transpose.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/Integration/lower-pack-unpack-without-transpose.mlir b/test/Integration/lower-pack-unpack-without-transpose.mlir index ef09e5d57..cb2416cdd 100644 --- a/test/Integration/lower-pack-unpack-without-transpose.mlir +++ b/test/Integration/lower-pack-unpack-without-transpose.mlir @@ -3,9 +3,9 @@ // RUN: fpcmp -r 0.001 %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out // RUN: rm %S/lower-pack-unpack-without-transpose-big_matmul-only-default-passes.out %S/lower-pack-unpack-without-transpose-big_matmul-default-passes-lower-pack-unpack-without-transpose.out -func.func @big_matmul(%A: tensor<1024x2048xf32>, %B: tensor<2048x4096xf32>, %C: tensor<1024x4096xf32>) -> tensor<1024x4096xf32> { - %D = linalg.matmul ins(%A, %B: tensor<1024x2048xf32>, tensor<2048x4096xf32>) outs(%C: tensor<1024x4096xf32>) -> tensor<1024x4096xf32> - return %D : tensor<1024x4096xf32> +func.func @big_matmul(%A: tensor<256x128xf32>, %B: tensor<128x64xf32>, %C: tensor<256x64xf32>) -> tensor<256x64xf32> { + %D = linalg.matmul ins(%A, %B: tensor<256x128xf32>, tensor<128x64xf32>) outs(%C: tensor<256x64xf32>) -> tensor<256x64xf32> + return %D : tensor<256x64xf32> } // RUN: tpp-run %s -e small_matmul_with_a_cst_arg --entry-point-result=void -print -n 1 --seed=123 2>&1 > %S/lower-pack-unpack-without-transpose-small_matmul_with_a_cst_arg-only-default-passes.out