Transform to lower packs/unpacks without transpose, except when the packing occurs on a constant #972

Merged
merged 9 commits on Sep 25, 2024
10 changes: 9 additions & 1 deletion include/TPP/PassBundles.td
@@ -37,7 +37,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
"unsigned", "Grid-sizes for parallel tasks.">,
Option<"linalgToVector", "linalg-to-vector",
"bool", /*default=*/"false",
"Lower linalg directly to vector.">
"Lower linalg directly to vector.">,
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
"bool", /*default=*/"false",
"Lower non-constant packs and unpacks reverting any dim permutations.">
];
}
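
For context, a minimal sketch of enabling the new bundle option programmatically, mirroring the DefaultPipeline.cpp change further down (field and factory names follow the TableGen entry above; `pm` is assumed to be an existing mlir::PassManager):

// Sketch: enable the new lowering when building the default TPP bundle.
mlir::tpp::DefaultTppPassesOptions options;
options.lowerPackUnpackWithoutTranspose = true;
pm.addPass(mlir::tpp::createDefaultTppPasses(options));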

@@ -51,6 +54,11 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> {
"memref::MemRefDialect",
"scf::SCFDialect",
"tensor::TensorDialect"];
let options = [
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
"bool", /*default=*/"false",
"Lower non-constant packs and unpacks reverting any dim permutations.">
];
}

def LinalgLowering : Pass<"linalg-lowering", "func::FuncOp"> {
6 changes: 6 additions & 0 deletions include/TPP/Passes.td
@@ -177,6 +177,12 @@ def LowerPacksAndUnPacks : Pass<"lower-packs-unpacks", "func::FuncOp"> {
"tensor::TensorDialect"];
}

def LowerPacksAndUnpacksWithoutTranspose : Pass<"lower-packs-unpacks-without-transpose",
"ModuleOp"> {
let dependentDialects = ["linalg::LinalgDialect",
"tensor::TensorDialect"];
}

def RewriteConvToMatmulOrBrgemm : Pass<"rewrite-conv-to-matmul-or-brgemm",
"func::FuncOp"> {
let summary = "Rewrite Conv2DNhwcHwcfOp/Conv2DNchwFchwOp to Matmul or Brgemm.";
10 changes: 8 additions & 2 deletions lib/TPP/DefaultPipeline.cpp
@@ -57,6 +57,11 @@ llvm::cl::opt<bool> linalgToVector("linalg-to-vector",
llvm::cl::desc("Lower linalg to vector"),
llvm::cl::init(false));

llvm::cl::opt<bool> lowerPackUnpackWithoutTranspose(
"lower-pack-unpack-without-transpose",
llvm::cl::desc("Lower packs and unpacks, reverting any dim permutations"),
llvm::cl::init(false));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_DEFAULTPIPELINE
@@ -128,8 +133,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
} else {
// Apply the default preprocessing pass
DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, parallelTaskGrid,
linalgToVector};
DefaultTppPassesOptions tppDefaultOptions{
linalgToLoops, parallelTaskGrid, linalgToVector,
lowerPackUnpackWithoutTranspose};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
}

4 changes: 3 additions & 1 deletion lib/TPP/DefaultTppPasses.cpp
@@ -85,7 +85,9 @@ struct DefaultTppPasses
pm.addPass(createRewriteBatchMatmulToMatmul());

// Applies a set of passes at the linalg level to fuse and pack.
pm.addPass(createTppMapping());
TppMappingOptions tppMappingOptions{
lowerPackUnpackWithoutTranspose};
pm.addPass(createTppMapping(tppMappingOptions));

// Generalize tensor.pack and tensor.unpack.
pm.addPass(createLowerPacksAndUnPacks());
4 changes: 4 additions & 0 deletions lib/TPP/PassBundles/TppMapping.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "TPP/PassUtils.h"

@@ -63,6 +64,9 @@ struct TppMapping : public tpp::impl::TppMappingBase<TppMapping>,
pm.addPass(createPackMatmul());
pm.addPass(createPackVNNI());

if (lowerPackUnpackWithoutTranspose) {
pm.addPass(createLowerPacksAndUnpacksWithoutTranspose());
}
// Postprocess packing.
// Run only canonicalizer at this stage as full cleanup (mostly CSE) can
// mess up tensor producer-consumer chains used for analysis in the
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_library(TPPTransforms
DecomposeAggregatedOps.cpp
LinalgDeGeneralize.cpp
LowerPacksAndUnpacks.cpp
LowerPacksAndUnpacksWithoutTranspose.cpp
RewriteBatchMatmulToMatmul.cpp
RewriteConvsToMatmulOrBrgemm.cpp
RewriteConvToMatmulImpl.cpp
202 changes: 202 additions & 0 deletions lib/TPP/Transforms/LowerPacksAndUnpacksWithoutTranspose.cpp
@@ -0,0 +1,202 @@
//===- LowerPacksAndUnpacksWithoutTranspose.cpp -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/ValueUtils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <cstdint>
#include <numeric>
using namespace mlir;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_LOWERPACKSANDUNPACKSWITHOUTTRANSPOSE
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

/// Wrapper around linalg::lowerPack that undoes the transpose the lowering
/// may have introduced. The indexing maps of the single user, a
/// linalg.generic, are corrected accordingly.
void lowerPackAndFoldTranspose(tensor::PackOp packOp,
linalg::GenericOp genericOp, unsigned operandIdx,
PatternRewriter &rewriter) {
auto packInversionPerm = tensor::getPackInverseDestPerm(packOp);

auto res = linalg::lowerPack(rewriter, packOp);

if (res->transposeOp) {
// Forget about the permutation of the dims on expandShapeOp.
rewriter.replaceAllOpUsesWith(res->transposeOp, res->expandShapeOp);

// Invert corresponding transposed accesses by the single user, genericOp.
auto indexingMaps = genericOp.getIndexingMapsArray();
auto packInverseMap =
AffineMap::getPermutationMap(packInversionPerm, rewriter.getContext());
auto normalizedIndexingMap =
packInverseMap.compose(indexingMaps[operandIdx]);

SmallVector<Attribute> maps = llvm::to_vector(genericOp.getIndexingMaps());
maps[operandIdx] = AffineMapAttr::get(normalizedIndexingMap);
genericOp.setIndexingMapsAttr(ArrayAttr::get(rewriter.getContext(), maps));
}
}
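
The essential fix-up above is composing the inverse of the pack's destination permutation with the operand's indexing map. A standalone sketch of that AffineMap manipulation (illustrative values, not taken from this PR):

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

// Sketch: compose a permutation map with an indexing map, as
// lowerPackAndFoldTranspose does for the generic's operand. With
// permutation {1, 0}, accesses (d0, d1) become (d1, d0), i.e. the
// transpose introduced by pack lowering is undone in the map rather
// than in the IR.
void composeExample() {
  mlir::MLIRContext ctx;
  llvm::SmallVector<unsigned> perm = {1, 0};
  auto permMap = mlir::AffineMap::getPermutationMap(perm, &ctx);
  auto idMap = mlir::AffineMap::getMultiDimIdentityMap(2, &ctx);
  auto fixed = permMap.compose(idMap); // (d0, d1) -> (d1, d0)
  fixed.dump();
}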

struct LowerPackOnInputsFoldingTranspose
: public OpRewritePattern<linalg::GenericOp> {
// The pattern only fires on packOp operands with a single use, so the
// callback can always find the use by the linalg.generic that the pattern
// targets.
using ControlFn = std::function<bool(tensor::PackOp)>;
ControlFn controlFn;

LowerPackOnInputsFoldingTranspose(MLIRContext *context,
ControlFn controlFn = nullptr,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
bool modifiedAnOperand = false;
for (auto &&[operandIdx, inOperand] :
llvm::enumerate(genericOp.getInputs())) {
auto packOp =
dyn_cast_if_present<tensor::PackOp>(inOperand.getDefiningOp());

if (!packOp || !packOp->hasOneUse() || (controlFn && !controlFn(packOp)))
continue;

lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);

modifiedAnOperand = true;
}

return modifiedAnOperand ? success() : failure();
}
};
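
A hypothetical use of the control hook, beyond the is-not-a-constant check registered in runOnOperation below (the rank-2 predicate is made up for illustration; getSourceRank is the upstream tensor::PackOp helper):

// Sketch: only lower packs of rank-2 sources; other packOps are left
// for the generic lower-packs-unpacks pass.
patterns.add<LowerPackOnInputsFoldingTranspose>(
    ctx, [](tensor::PackOp packOp) { return packOp.getSourceRank() == 2; });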

struct LowerPackUnpackOnOutputFoldingTranspose
: public OpRewritePattern<linalg::GenericOp> {
// The pattern only fires on packOp operands with a single use, so the
// callback can always find the use by the linalg.generic that the pattern
// targets.
using ControlFn = std::function<bool(tensor::PackOp, tensor::UnPackOp)>;
ControlFn controlFn;

LowerPackUnpackOnOutputFoldingTranspose(MLIRContext *context,
ControlFn controlFn = nullptr,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
bool modifiedAnOperand = false;
size_t numInputs = genericOp.getInputs().size();
for (auto &&[outOperandIdx, outOperand] :
llvm::enumerate(genericOp.getOutputs())) {
size_t operandIdx = numInputs + outOperandIdx;
auto result = genericOp->getResult(outOperandIdx);

if (!result.hasOneUse())
continue;

auto packOp =
dyn_cast_if_present<tensor::PackOp>(outOperand.getDefiningOp());
auto unpackOp =
llvm::dyn_cast<tensor::UnPackOp>(*(result.getUsers().begin()));

if (!packOp || !packOp->hasOneUse() || !unpackOp)
continue;

// Normalize an empty outer_dims_perm to the corresponding identity
// permutation.
auto packOuterDimsPerm = SmallVector<int64_t>(packOp.getOuterDimsPerm());
if (packOuterDimsPerm.empty()) {
packOuterDimsPerm =
SmallVector<int64_t>(packOp.getSource().getType().getRank());
std::iota(packOuterDimsPerm.begin(), packOuterDimsPerm.end(), 0);
}
auto unpackOuterDimsPerm = SmallVector<int64_t>(unpackOp.getOuterDimsPerm());
if (unpackOuterDimsPerm.empty()) {
unpackOuterDimsPerm =
SmallVector<int64_t>(unpackOp.getResult().getType().getRank());
std::iota(unpackOuterDimsPerm.begin(), unpackOuterDimsPerm.end(), 0);
}

if (unpackOp.getInnerDimsPos() != packOp.getInnerDimsPos() ||
packOuterDimsPerm != unpackOuterDimsPerm ||
(controlFn && !controlFn(packOp, unpackOp)))
continue;

auto unpackDest = unpackOp.getDest();
bool destHasStaticShape = unpackDest.getType().hasStaticShape();

lowerPackAndFoldTranspose(packOp, genericOp, operandIdx, rewriter);
auto res = linalg::lowerUnPack(rewriter, unpackOp);

// Set genericOp's result type to the adjusted type of the out parameter.
result.setType(genericOp.getOperand(operandIdx).getType());

if (auto transposeOp = res->transposeOp) {
// Forget about the transpose introduced by lowerUnPack.
rewriter.replaceAllOpUsesWith(transposeOp, transposeOp.getInput());
}

// lowerUnPack introduces a copy to maintain DPS w.r.t. unpackOp's dest.
// As we ignore permutations and, in the static case, don't do padding,
// we know the underlying buffer will be used as is and hence we do not
// need to specify a dest to update into.
auto extractSliceOp = res->extractSliceOp;
if (destHasStaticShape && extractSliceOp && extractSliceOp->hasOneUse()) {
auto copyOp =
dyn_cast<linalg::CopyOp>(*extractSliceOp->getUsers().begin());
if (copyOp && copyOp.getOutputs()[0] == unpackDest) {
rewriter.replaceAllOpUsesWith(copyOp, copyOp.getInputs()[0]);
}
}
modifiedAnOperand = true;
}

return modifiedAnOperand ? success() : failure();
}
};

struct LowerPacksAndUnpacksWithoutTranspose
: public tpp::impl::LowerPacksAndUnpacksWithoutTransposeBase<
LowerPacksAndUnpacksWithoutTranspose> {

void runOnOperation() override {
auto *ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<LowerPackOnInputsFoldingTranspose>(
ctx, [](tensor::PackOp packOp) {
// Only lower packOps whose argument is not a constant.
return !llvm::dyn_cast_if_present<arith::ConstantOp>(
packOp.getOperand(0).getDefiningOp());
});
patterns.add<LowerPackUnpackOnOutputFoldingTranspose>(ctx);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace
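
Finally, a minimal sketch of driving just this pass over an already-parsed module (standard MLIR tooling setup assumed; the factory name is generated from the Passes.td entry above):

#include "TPP/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Sketch: run only the new pass; `module` is an already-parsed ModuleOp.
mlir::LogicalResult lowerWithoutTranspose(mlir::ModuleOp module) {
  mlir::PassManager pm(module->getContext());
  pm.addPass(mlir::tpp::createLowerPacksAndUnpacksWithoutTranspose());
  return pm.run(module);
}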