Skip to content

Commit

Permalink
Split reduction dim pass (#975)
Browse files Browse the repository at this point in the history
Adds a pass to allow tiling contraction's innermost reduction dimension
using serial loop with in-place accumulation.

Compared to other available transformations, this rewrite computes
reduction sequentially with in-place accumulation which avoids temporary
allocation and separate reduction operation. This tiling strategy is
more friendly in terms of register and memory pressure more suitable for
low-level GPU kernel generation. Similarly, restriction to the innermost
dimension is there to simplify both usage and pass logic as the rewrite
is geared toward progressive GEMM lowering.

Additionally, a GPU vectorization control flag is added to allow
grouping of passes which target lowering through vector operations and
might not be compatible with other existing lowering strategies.
Effectively, this pass perform separate K-dim split which is currently
baked into linalg-to-xegpu lowering, and the two are not fully
compatible.
  • Loading branch information
adam-smnk authored Oct 3, 2024
1 parent db9b935 commit 37781c1
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 0 deletions.
19 changes: 19 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -506,4 +506,23 @@ def FoldIntoEltwise : Pass<"fold-into-eltwise", "ModuleOp"> {
"affine::AffineDialect"];
}

def SplitReductionDim : Pass<"split-reduction-dim", "func::FuncOp"> {
let summary = "Split innermost reduction dimension.";
let description = [{
Split innermost reduction dimension and compute it sequentially
using a serial loop and in-place accumulation.
}];
let dependentDialects = ["linalg::LinalgDialect",
"scf::SCFDialect",
"tensor::TensorDialect",
"memref::MemRefDialect",
"affine::AffineDialect",
"arith::ArithDialect"];
let options = [
Option<"tileSize", "tile", "int64_t",
/*default=*/"0",
"Reduction dimension tile size">,
];
}

#endif // TPP_DIALECT_TPP_PASSES
13 changes: 13 additions & 0 deletions lib/TPP/GPU/GpuPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ llvm::cl::list<int64_t>
llvm::cl::list_init<int64_t>(SmallVector<int64_t>{8, 16, 16}),
llvm::cl::CommaSeparated);

// Control GPU vectorization.
llvm::cl::opt<bool> gpuVectorize("gpu-vectorize",
llvm::cl::desc("Vectorize GPU kernel"),
llvm::cl::init(false));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_GPUPIPELINE
Expand Down Expand Up @@ -182,6 +187,14 @@ struct GpuPipeline : public tpp::impl::GpuPipelineBase<GpuPipeline>,
pm.addPass(createTileConsumerAndFuseProducers(threadTileOptions));
pm.addPass(createCleanup());

if (gpuVectorize) {
// Early reduction dimension splitting is incompatible with
// Linalg to XeGPU lowering that expects full GEMM.
// For now, enable only with other vectorization passes.
pm.addPass(createSplitReductionDim(SplitReductionDimOptions{kTile}));
pm.addPass(createCleanup());
}

// Preprocess and bufferize as further conversion requires memref
// abstraction.
pm.addPass(createLowerPacksAndUnPacks());
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_mlir_library(TPPTransforms
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp
Vectorization.cpp
SplitReductionDim.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
112 changes: 112 additions & 0 deletions lib/TPP/Transforms/SplitReductionDim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//===- SplitReductionDim.cpp -------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements serial reduction dimension splitting.
//
//===----------------------------------------------------------------------===//

#include "TPP/Passes.h"
#include "TPP/Transforms/Utils/TransformUtils.h"

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#include <algorithm>

using namespace mlir;
using namespace mlir::tpp;

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_SPLITREDUCTIONDIM
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

namespace {

// Split contraction's innermost reduction dimension.
struct SplitContractionReduction
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

SplitContractionReduction(MLIRContext *ctx, SplitReductionDimOptions options)
: OpInterfaceRewritePattern<linalg::LinalgOp>(ctx), options(options) {}

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
if (options.tileSize <= 0)
return rewriter.notifyMatchFailure(linalgOp,
"invalid reduction tile size");

FailureOr<linalg::ContractionDimensions> dims =
linalg::inferContractionDims(linalgOp);
if (failed(dims))
return rewriter.notifyMatchFailure(linalgOp, "not a contraction");

scf::SCFTilingOptions tilingOpts;
// Tile using a serial loop.
tilingOpts.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
// Tile only the innermost reduction dimension - disable tiling for all
// other dims.
SmallVector<OpFoldResult> tiles(
linalgOp.getNumLoops(),
getAsIndexOpFoldResult(rewriter.getContext(), 0));
tiles[dims->k.back()] =
getAsIndexOpFoldResult(rewriter.getContext(), options.tileSize);
tilingOpts.setTileSizes(tiles);

FailureOr<scf::SCFTilingResult> tilingResult = scf::tileUsingSCF(
rewriter, cast<TilingInterface>(linalgOp.getOperation()), tilingOpts);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(linalgOp,
"failed to tile contraction");

rewriter.replaceOp(linalgOp, tilingResult->replacements);

return success();
}

private:
SplitReductionDimOptions options;
};

// Split reduction dimension.
struct SplitReductionDim
: public tpp::impl::SplitReductionDimBase<SplitReductionDim> {
using SplitReductionDimBase::SplitReductionDimBase;

void runOnOperation() override {
MLIRContext *ctx = &getContext();

SplitReductionDimOptions options{tileSize};

RewritePatternSet patterns(ctx);
patterns.add<SplitContractionReduction>(ctx, options);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};

} // namespace
Loading

0 comments on commit 37781c1

Please sign in to comment.