Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split reduction dim pass #975

Merged
merged 10 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,23 @@ def FoldIntoEltwise : Pass<"fold-into-eltwise", "ModuleOp"> {
"affine::AffineDialect"];
}

def SplitReductionDim : Pass<"split-reduction-dim", "func::FuncOp"> {
let summary = "Split innermost reduction dimension.";
let description = [{
Split innermost reduction dimension and compute it sequentially
using a serial loop and in-place accumulation.
}];
let dependentDialects = ["linalg::LinalgDialect",
"scf::SCFDialect",
"tensor::TensorDialect",
"memref::MemRefDialect",
"affine::AffineDialect",
"arith::ArithDialect"];
let options = [
Option<"tileSize", "tile", "int64_t",
/*default=*/"0",
"Reduction dimension tile size">,
];
}

#endif // TPP_DIALECT_TPP_PASSES
13 changes: 13 additions & 0 deletions lib/TPP/GPU/GpuPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ llvm::cl::list<int64_t>
llvm::cl::list_init<int64_t>(SmallVector<int64_t>{8, 16, 16}),
llvm::cl::CommaSeparated);

// Control GPU vectorization.
llvm::cl::opt<bool> gpuVectorize("gpu-vectorize",
llvm::cl::desc("Vectorize GPU kernel"),
llvm::cl::init(false));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_GPUPIPELINE
Expand Down Expand Up @@ -182,6 +187,14 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions));
pm.addPass(createCleanup());

if (gpuVectorize) {
// Early reduction dimension splitting is incompatible with
// Linalg to XeGPU lowering that expects full GEMM.
// For now, enable only with other vectorization passes.
pm.addPass(createSplitReductionDim(SplitReductionDimOptions{kTile}));
pm.addPass(createCleanup());
}

// Preprocess and bufferize as further conversion requires memref
// abstraction.
pm.addPass(createLowerPacksAndUnPacks());
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_mlir_library(TPPTransforms
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp
Vectorization.cpp
SplitReductionDim.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
112 changes: 112 additions & 0 deletions lib/TPP/Transforms/SplitReductionDim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//===- SplitReductionDim.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
adam-smnk marked this conversation as resolved.
Show resolved Hide resolved
//
// This file implements serial reduction dimension splitting.
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/TransformUtils.h"

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#include <algorithm>

using namespace mlir;
using namespace mlir::tpp;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_SPLITREDUCTIONDIM
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

// Split contraction's innermost reduction dimension.
struct SplitContractionReduction
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

SplitContractionReduction(MLIRContext *ctx, SplitReductionDimOptions options)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), options(options) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
if (options.tileSize <= 0)
return rewriter.notifyMatchFailure(linalgOp,
"invalid reduction tile size");

FailureOr<linalg::ContractionDimensions> dims =
linalg::inferContractionDims(linalgOp);
if (failed(dims))
return rewriter.notifyMatchFailure(linalgOp, "not a contraction");

scf::SCFTilingOptions tilingOpts;
// Tile using a serial loop.
tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
// Tile only the innermost reduction dimension - disable tiling for all
// other dims.
SmallVector<OpFoldResult> tiles(
linalgOp.getNumLoops(),
getAsIndexOpFoldResult(rewriter.getContext(), 0));
tiles[dims->k.back()] =
getAsIndexOpFoldResult(rewriter.getContext(), options.tileSize);
tilingOpts.setTileSizes(tiles);

FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF(
rewriter, cast<TilingInterface>(linalgOp.getOperation()), tilingOpts);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(linalgOp,
"failed to tile contraction");

rewriter.replaceOp(linalgOp, tilingResult->replacements);

return success();
}

private:
SplitReductionDimOptions options;
};

// Split reduction dimension.
struct SplitReductionDim
: public tpp::impl::SplitReductionDimBase<SplitReductionDim> {
using SplitReductionDimBase::SplitReductionDimBase;

void runOnOperation() override {
MLIRContext *ctx = &getContext();

SplitReductionDimOptions options{tileSize};

RewritePatternSet patterns(ctx);
patterns.add<SplitContractionReduction>(ctx, options);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};

} // namespace
Loading