Transform to lower packs/unpacks without transpose, except when the packing occurs on a constant (#972)

A transform that enables selectively "reverting" packing on
`linalg.generic` operands, and unpacking on the generic's results, by
lowering `tensor.pack` to `tensor.expand_shape` and `tensor.unpack` to
`tensor.collapse_shape`, both with identity permutation.

The new pass `-lower-packs-unpacks-without-transpose` applies this
reverting to all compatible packed/unpacked `linalg.generic`
operands and results, except those that are constants. The CLI flag
`--lower-pack-unpack-without-transpose` enables this pass in the default
pipeline.
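For intuition, a minimal before/after sketch of the pack side (shapes, SSA names, and exact op syntax are illustrative, not taken from this patch): a pack whose `outer_dims_perm` permutes the outer dims would normally lower to `tensor.expand_shape` plus a `linalg.transpose`; this transform keeps only the reshape and folds the inverse permutation into the consuming generic's indexing map.

// Before: a permuting pack feeding a linalg.generic operand.
%packed = tensor.pack %A outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [32, 32] into %dst
    : tensor<128x256xf32> -> tensor<8x4x32x32xf32>

// After: a pure reshape, no transpose; the generic's indexing map for this
// operand is rewritten to address the un-permuted 4x32x8x32 layout.
%expanded = tensor.expand_shape %A [[0, 1], [2, 3]] output_shape [4, 32, 8, 32]
    : tensor<128x256xf32> into tensor<4x32x8x32xf32>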
rolfmorel authored Sep 25, 2024
1 parent 7b521f2 commit db9b935
Showing 9 changed files with 442 additions and 4 deletions.
10 changes: 9 additions & 1 deletion include/TPP/PassBundles.td
@@ -37,7 +37,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
"unsigned", "Grid-sizes for parallel tasks.">,
Option<"linalgToVector", "linalg-to-vector",
"bool", /*default=*/"false",
"Lower linalg directly to vector.">
"Lower linalg directly to vector.">,
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
"bool", /*default=*/"false",
"Lower non-constant packs and unpacks reverting any dim permutations.">
];
}

@@ -51,6 +54,11 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> {
"memref::MemRefDialect",
"scf::SCFDialect",
"tensor::TensorDialect"];
let options = [
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
"bool", /*default=*/"false",
"Lower non-constant packs and unpacks reverting any dim permutations.">
];
}

def LinalgLowering : Pass<"linalg-lowering", "func::FuncOp"> {
6 changes: 6 additions & 0 deletions include/TPP/Passes.td
@@ -177,6 +177,12 @@ def LowerPacksAndUnPacks : Pass<"lower-packs-unpacks", "func::FuncOp"> {
"tensor::TensorDialect"];
}

def LowerPacksAndUnpacksWithoutTranspose : Pass<"lower-packs-unpacks-without-transpose",
"ModuleOp"> {
let dependentDialects = ["linalg::LinalgDialect",
"tensor::TensorDialect"];
}

def RewriteConvToMatmulOrBrgemm : Pass<"rewrite-conv-to-matmul-or-brgemm",
"func::FuncOp"> {
let summary = "Rewrite Conv2DNhwcHwcfOp/Conv2DNchwFchwOp to Matmul or Brgemm.";
10 changes: 8 additions & 2 deletions lib/TPP/DefaultPipeline.cpp
@@ -57,6 +57,11 @@ llvm::cl::opt<bool> linalgToVector("linalg-to-vector",
llvm::cl::desc("Lower linalg to vector"),
llvm::cl::init(false));

llvm::cl::opt<bool> lowerPackUnpackWithoutTranspose(
"lower-pack-unpack-without-transpose",
llvm::cl::desc("Lower packs and unpacks reverting any dim permutations"),
llvm::cl::init(false));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_DEFAULTPIPELINE
@@ -128,8 +133,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
} else {
// Apply the default preprocessing pass
DefaultTppPassesOptions tppDefaultOptions{
    linalgToLoops, parallelTaskGrid, linalgToVector,
    lowerPackUnpackWithoutTranspose};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
}

4 changes: 3 additions & 1 deletion lib/TPP/DefaultTppPasses.cpp
@@ -85,7 +85,9 @@ struct DefaultTppPasses
pm.addPass(createRewriteBatchMatmulToMatmul());

// Applies a set of passes at the linalg level to fuse and pack.
TppMappingOptions tppMappingOptions{lowerPackUnpackWithoutTranspose};
pm.addPass(createTppMapping(tppMappingOptions));

// Generalize tensor.pack and tensor.unpack.
pm.addPass(createLowerPacksAndUnPacks());
4 changes: 4 additions & 0 deletions lib/TPP/PassBundles/TppMapping.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "TPP/PassUtils.h"

@@ -63,6 +64,9 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
pm.addPass(createPackMatmul());
pm.addPass(createPackVNNI());

if (lowerPackUnpackWithoutTranspose) {
pm.addPass(createLowerPacksAndUnpacksWithoutTranspose());
}
// Postprocess packing.
// Run only canonicalizer at this stage as full cleanup (mostly CSE) can
// mess up tensor producer-consumer chains used for analysis in the
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(TPPTransforms
DecomposeAggregatedOps.cpp
LinalgDeGeneralize.cpp
LowerPacksAndUnpacks.cpp
LowerPacksAndUnpacksWithoutTranspose.cpp
RewriteBatchMatmulToMatmul.cpp
RewriteConvsToMatmulOrBrgemm.cpp
RewriteConvToMatmulImpl.cpp
202 changes: 202 additions & 0 deletions lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp
@@ -0,0 +1,202 @@
//===- LowerPacksAndUnpacksWithoutTranspose.cpp ------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cstdint>
#include <numeric>
using namespace mlir;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_LOWERPACKSANDUNPACKSWITHOUTTRANSPOSE
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

/// Wrapper around linalg::lowerPack which undoes the transpose that lowering
/// might have introduced. The single-user genericOp's indexing_maps attribute
/// is corrected accordingly.
void lowerPackAndFoldTranspose(tensor::PackOp packOp,
linalg::GenericOp genericOp, uint operandIdx,
PatternRewriter &rewriter) {
auto packInversionPerm = tensor::getPackInverseDestPerm(packOp);

auto res = linalg::lowerPack(rewriter, packOp);

if (res->transposeOp) {
// Forget about the permutation of the dims on expandShapeOp.
rewriter.replaceAllOpUsesWith(res->transposeOp, res->expandShapeOp);

// Invert corresponding transposed accesses by the single user, genericOp.
auto indexingMaps = genericOp.getIndexingMapsArray();
auto packInverseMap =
AffineMap::getPermutationMap(packInversionPerm, rewriter.getContext());
auto normalizedIndexingMap =
packInverseMap.compose(indexingMaps[operandIdx]);

SmallVector<Attribute> maps = llvm::to_vector(genericOp.getIndexingMaps());
maps[operandIdx] = AffineMapAttr::get(normalizedIndexingMap);
genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps));
}
}
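To make the indexing-map correction concrete, here is a worked instance (values invented, continuing the sketch in the commit message; the actual permutation is whatever tensor::getPackInverseDestPerm computes):

// The packed operand (8x4x32x32, outer_dims_perm = [1, 0]) was accessed as
//   affine_map<(i, j, ii, jj) -> (j, i, ii, jj)>
// lowerPack's transpose permutation here is [2, 0, 1, 3]; composing its
// inverse, [1, 2, 0, 3], with the map above gives
//   affine_map<(i, j, ii, jj) -> (i, ii, j, jj)>
// which indexes the un-transposed 4x32x8x32 expand_shape result directly.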

struct LowerPackOnInputsFoldingTranspose
: public OpRewritePattern<linalg::GenericOp> {
// Is only called with single-user packOp operands, so callback can always
// find the (use by the) linalg.generic that is the target of the pattern.
using ControlFn = std::function<bool(tensor::PackOp)>;
ControlFn controlFn;

LowerPackOnInputsFoldingTranspose(MLIRContext *context,
ControlFn controlFn = nullptr,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
bool modifiedAnOperand = false;
for (auto &&[operandIdx, inOperand] :
llvm::enumerate(genericOp.getInputs())) {
auto packOp =
dyn_cast_if_present<tensor::PackOp>(inOperand.getDefiningOp());

if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp)))
continue;

lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);

modifiedAnOperand = true;
}

return modifiedAnOperand ? success() : failure();
}
};

struct LowerPackUnpackOnOutputFoldingTranspose
: public OpRewritePattern<linalg::GenericOp> {
// Is only called with single-user packOp operands, so callback can always
// find the (use by the) linalg.generic that is the target of the pattern.
using ControlFn = std::function<bool(tensor::PackOp, tensor::UnPackOp)>;
ControlFn controlFn;

LowerPackUnpackOnOutputFoldingTranspose(MLIRContext *context,
ControlFn controlFn = nullptr,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
bool modifiedAnOperand = false;
size_t numInputs = genericOp.getInputs().size();
for (auto &&[outOperandIdx, outOperand] :
llvm::enumerate(genericOp.getOutputs())) {
size_t operandIdx = numInputs + outOperandIdx;
auto result = genericOp->getResult(outOperandIdx);

if (!result.hasOneUse())
continue;

auto packOp =
dyn_cast_if_present<tensor::PackOp>(outOperand.getDefiningOp());
auto unpackOp =
llvm::dyn_cast<tensor::UnPackOp>(*(result.getUsers().begin()));

if (!packOp || !packOp->hasOneUse() || !unpackOp)
continue;

// Normalize empty outer_dims_perm to its corresponding identity perm.
auto packOuterDimsPerm = SmallVector<long>(packOp.getOuterDimsPerm());
if (packOuterDimsPerm.empty()) {
packOuterDimsPerm =
SmallVector<long>(packOp.getSource().getType().getRank());
std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.end(), 0);
}
auto unpackOuterDimsPerm = SmallVector<long>(unpackOp.getOuterDimsPerm());
if (unpackOuterDimsPerm.empty()) {
unpackOuterDimsPerm =
SmallVector<long>(unpackOp.getResult().getType().getRank());
std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.end(), 0);
}

if (unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() ||
packOuterDimsPerm != unpackOuterDimsPerm ||
(controlFn && !controlFn(packOp, unpackOp)))
continue;

auto unpackDest = unpackOp.getDest();
bool destHasStaticShape = unpackDest.getType().hasStaticShape();

lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);
auto res = linalg::lowerUnPack(rewriter, unpackOp);

// Set genericOp's result type to the adjusted type of the out parameter.
result.setType(genericOp.getOperand(operandIdx).getType());

if (auto transposeOp = res->transposeOp) {
// Forget about the transpose introduced by lowerUnPack.
rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput());
}

// lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest.
// As we ignore permutations and, in the static case, don't do padding,
// we know the underlying buffer will be used as is and hence we do not
// need to specify a dest to update into.
auto extractSliceOp = res->extractSliceOp;
if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) {
auto copyOp =
dyn_cast<linalg::CopyOp>(*extractSliceOp->getUsers().begin());
if (copyOp && copyOp.getOutputs()[0] == unpackDest) {
rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]);
}
}
modifiedAnOperand = true;
}

return modifiedAnOperand ? success() : failure();
}
};
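On the output side, the same illustrative shapes would leave the generic producing the expanded 4x32x8x32 tensor; the trailing unpack reduces to a pure, identity-permutation collapse, and for static shapes the final copy into the unpack's dest is folded away:

// After the pattern (illustrative): no transpose, no copy, just a collapse.
%res = tensor.collapse_shape %gen_out [[0, 1], [2, 3]]
    : tensor<4x32x8x32xf32> into tensor<128x256xf32>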

struct LowerPacksAndUnpacksWithoutTranspose
: public tpp::impl::LowerPacksAndUnpacksWithoutTransposeBase<
LowerPacksAndUnpacksWithoutTranspose> {

void runOnOperation() override {
auto *ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<LowerPackOnInputsFoldingTranspose>(
ctx, [](tensor::PackOp packOp) {
// Only lower packOps whose argument is not a constant.
return !llvm::dyn_cast_if_present<arith::ConstantOp>(
packOp.getOperand(0).getDefiningOp());
});
patterns.add<LowerPackUnpackOnOutputFoldingTranspose>(ctx);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace