Skip to content

Commit

Permalink
add controlFn for lowerPack and lowerUnpack
Browse files Browse the repository at this point in the history
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 committed Sep 19, 2024
1 parent c9b33b2 commit c326384
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-decompose-pack-unpack-ops"
Expand All @@ -35,15 +36,25 @@ namespace {
/// Rewrite pattern that lowers a tensor::PackOp by delegating to
/// linalg::lowerPack, optionally gated by a caller-supplied control function.
struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;

  /// Builds the pattern for `context`. When `controlFn` holds a callback, only
  /// ops for which the callback returns success are lowered; when it is
  /// `std::nullopt`, no filtering is applied.
  explicit LowerPackPattern(MLIRContext *context,
                            std::optional<PackUnPackControlFn> controlFn)
      : OpRewritePattern(context), controlFn(controlFn) {}

  /// Returns success if `op` was lowered. Returns failure if the control
  /// function rejects `op` or if linalg::lowerPack cannot lower it.
  LogicalResult matchAndRewrite(tensor::PackOp op,
                                PatternRewriter &rewriter) const override {
    // Consult the caller's filter first. Report the rejection through the
    // rewriter so it is not a silent failure, matching the lowering-failure
    // path below.
    if (controlFn && failed((*controlFn)(op))) {
      return rewriter.notifyMatchFailure(op, "rejected by control function");
    }
    if (failed(linalg::lowerPack(rewriter, op))) {
      return rewriter.notifyMatchFailure(
          op, "cannot lower to pad + expand + transpose");
    }
    return success();
  }

private:
  /// Optional per-op filter; unset means every op is a candidate.
  std::optional<PackUnPackControlFn> controlFn;
};

/// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It
Expand All @@ -52,8 +63,15 @@ struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern;

explicit LowerUnPackPattern(MLIRContext *context,
std::optional<PackUnPackControlFn> controlFn)
: OpRewritePattern(context), controlFn(controlFn) {}

LogicalResult matchAndRewrite(tensor::UnPackOp op,
PatternRewriter &rewriter) const override {
if (controlFn && failed(controlFn.value()(op))) {
return failure();
}
FailureOr<linalg::LowerUnPackOpResult> res =
linalg::lowerUnPack(rewriter, op);
if (failed(res)) {
Expand All @@ -62,23 +80,31 @@ struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
}
return success();
}

private:
std::optional<PackUnPackControlFn> controlFn;
};

// Pass that decomposes tensor pack/unpack ops. In addition to the inherited
// base-class constructors, it can be built from explicit option values plus an
// optional `controlFn` that limits which ops are rewritten.
struct DecomposePackUnPackOpsPass final
: impl::DecomposePackUnPackOpsPassBase<DecomposePackUnPackOpsPass> {
using impl::DecomposePackUnPackOpsPassBase<
DecomposePackUnPackOpsPass>::DecomposePackUnPackOpsPassBase;
explicit DecomposePackUnPackOpsPass(bool tileOuterToOne,
bool useOnlyReshapes) {
// NOTE(review): the two-parameter header directly above looks like the
// superseded form of this constructor; the three-parameter form below is the
// complete one.
explicit DecomposePackUnPackOpsPass(
bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
// `tileOuterToOne` and `useOnlyReshapes` are not declared in this struct;
// presumably they are pass options provided by the base class -- TODO confirm.
this->tileOuterToOne = tileOuterToOne;
this->useOnlyReshapes = useOnlyReshapes;
this->controlFn = controlFn;
}
// Registers the linalg, arith, scf and tensor dialects with `registry`.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, arith::ArithDialect, scf::SCFDialect,
tensor::TensorDialect>();
}

// Pass entry point; defined outside the struct.
void runOnOperation() override;

private:
// Optional filter consulted before a pack/unpack op is rewritten. Empty by
// default, in which case no op is filtered out.
std::optional<PackUnPackControlFn> controlFn = std::nullopt;
};

} // namespace
Expand All @@ -104,7 +130,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
// tiled to one.
if (!tileOuterToOne) {
RewritePatternSet patterns(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
funcOp.emitError(
"failed to apply generalization patterns on pack/unpack ops for "
Expand Down Expand Up @@ -138,6 +164,9 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
return tileSizes;
}));
funcOp->walk([&](tensor::PackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
}
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
scf::tileConsumerAndFuseProducersUsingSCF(
rewriter, cast<TilingInterface>(op.getOperation()), packOptions);
Expand All @@ -163,6 +192,9 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
return tileSizes;
});
funcOp->walk([&](tensor::UnPackOp op) {
if (controlFn && failed(controlFn.value()(op))) {
return;
}
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCF(rewriter, cast<TilingInterface>(op.getOperation()),
unpackTilingOptions);
Expand Down Expand Up @@ -200,7 +232,7 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
{
RewritePatternSet patterns(ctx);
if (useOnlyReshapes) {
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx);
patterns.add<LowerPackPattern, LowerUnPackPattern>(ctx, controlFn);
} else {
patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern,
linalg::GeneralizeOuterUnitDimsUnPackOpPattern>(ctx);
Expand All @@ -212,9 +244,10 @@ void DecomposePackUnPackOpsPass::runOnOperation() {
}

// Factory for the pack/unpack decomposition pass. Forwards `tileOuterToOne`,
// `useOnlyReshapes` and the optional `controlFn` unchanged to the
// DecomposePackUnPackOpsPass constructor.
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes) {
return std::make_unique<DecomposePackUnPackOpsPass>(tileOuterToOne,
useOnlyReshapes);
// NOTE(review): the two-parameter form above looks like the superseded
// version; the three-parameter form below is the one that passes `controlFn`.
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn) {
return std::make_unique<DecomposePackUnPackOpsPass>(
tileOuterToOne, useOnlyReshapes, controlFn);
}

} // namespace mlir::iree_compiler
10 changes: 8 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,15 @@ using ConfigFn =
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass(ConfigFn configFn);

// Callback used to filter pack/unpack ops. It receives the candidate op; a
// failure result makes the decomposition leave that op untouched.
using PackUnPackControlFn = std::function<LogicalResult(Operation *)>;
// Pass to decompose pack and unpack ops into pad/extract_slice and reshape ops.
// If specified, `controlFn` controls which ops get decomposed. The `controlFn`
// should be used with `useOnlyReshapes` set to true.
// TODO(Max191): Add a controlFn upstream for `GeneralizeOuterUnitDim*` patterns
// and remove the need to have `useOnlyReshapes = true` when using `controlFn`.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createDecomposePackUnPackOpsPass(bool tileOuterToOne,
bool useOnlyReshapes = false);
// NOTE(review): the declaration fragment above (with the `= false` default)
// looks like the superseded form; the three-parameter form below has no
// default arguments, so callers must pass all three values.
createDecomposePackUnPackOpsPass(bool tileOuterToOne, bool useOnlyReshapes,
std::optional<PackUnPackControlFn> controlFn);

std::unique_ptr<Pass> createDecomposeSoftmaxPass(bool useFusion);

Expand Down
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
// Step 3. Decompose pack and unpack ops and propagate the resulting reshapes.
funcPassManager.addPass(
createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/false,
/*useOnlyReshapes=*/true));
/*useOnlyReshapes=*/true,
/*controlFn=*/std::nullopt));

// Step 3.5. Expand the inner dimensions of MultiMma ops in preparation for
// distribution to lanes.
Expand Down Expand Up @@ -946,7 +947,9 @@ void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(
createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true));
createDecomposePackUnPackOpsPass(/*tileOuterToOne=*/true,
/*useOnlyReshapes=*/false,
/*controlFn=*/std::nullopt));
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
addGPUVectorizationPasses(funcPassManager);
Expand Down

0 comments on commit c326384

Please sign in to comment.