Skip to content

Commit

Permalink
code refactoring: renaming as brgemm pass
Browse files Browse the repository at this point in the history
  • Loading branch information
athangam committed Oct 10, 2024
1 parent d70efa2 commit e57f1d9
Show file tree
Hide file tree
Showing 6 changed files with 423 additions and 4 deletions.
4 changes: 2 additions & 2 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def VectorizationPass : Pass<"vectorization-pass",



def LinalgTiling : Pass<"tile-linalg"> {
let summary = "Tile matmul reduction dimension.";
def BrgemmLinalgTiling : Pass<"tile-brgemm-linalg"> {
let summary = "Tile matmul and reduction dimension.";
let description = [{
Tiles the innermost dimensions of the batch reduce matmul operation. Additionally, it swaps the reduction and k dimension loop. The final loop structure is as follows: M-loop->N-loop->reduction-loop->K-loop.
}];
Expand Down
2 changes: 1 addition & 1 deletion lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct DefaultTppPasses
// TODO: This flag will be removed once the vector path becomes the
// default lowering path.
if (linalgToVector) {
pm.addNestedPass<func::FuncOp>(createLinalgTiling());
pm.addNestedPass<func::FuncOp>(createBrgemmLinalgTiling());
pm.addNestedPass<func::FuncOp>(createVectorizationPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
} else {
Expand Down
236 changes: 236 additions & 0 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements parallel loop insertion for tiling.
//
//===----------------------------------------------------------------------===//
#include "TPP/IR/TilingUtils.h"
#include "TPP/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"
#include <iostream>
#define DEBUG_TYPE "brgemm-linalg-tiling"

namespace mlir {
namespace tpp {
#define GEN_PASS_DECL_BRGEMMLINALGTILING
#define GEN_PASS_DEF_BRGEMMLINALGTILING
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace mlir::tpp;

namespace mlir {
namespace tpp {
struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;

LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
: OpRewritePattern(ctx), options(tilingoptions) {}

LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
PatternRewriter &rewriter) const override {

if (!brgemmOp.hasPureBufferSemantics())
return failure();
// Get the M and N tile shape from the user input
SmallVector<int64_t> tileShapeM(options.mTileShape.begin(),
options.mTileShape.end());
SmallVector<int64_t> tileShapeN(options.nTileShape.begin(),
options.nTileShape.end());

if (tileShapeM.size() != 2 || tileShapeN.size() != 2)
return failure();

if (tileShapeM[1] != tileShapeN[0])
return failure();

// Stores the M, N, and K Tile Sizes
SmallVector<int64_t> mxnxkTile(3);
// Stores the M, and N Tile Sizes
SmallVector<int64_t> mxnTile(2);

mxnxkTile[0] = tileShapeM[0];
mxnxkTile[1] = tileShapeN[1];
mxnxkTile[2] = tileShapeM[1];
mxnTile[0] = tileShapeM[0];
mxnTile[1] = tileShapeN[1];

// To assist in calculating the argument and step values for the tiled loop.
SmallVector<int64_t> boundariesOne{1,
static_cast<long>(tileShapeM.size() - 1),
static_cast<long>(mxnxkTile.size() - 1)};

SmallVector<int64_t> tileSizesIndex{static_cast<long>(tileShapeM.size()),
static_cast<long>(tileShapeN.size()),
static_cast<long>(mxnTile.size())};
SmallVector<SmallVector<int64_t>> tileshapes{tileShapeM, tileShapeN, mxnTile};
SmallVector<int> swap_i = {0, 2, 1};
size_t i = 0;
std::map<int, std::map<int, Value>> inductionVars;

// For M, N, and K loops
scf::ForOp innermostForLoop;
// For brgemm reduction loop
scf::ForOp reductionForLoop;

// Creating the tiled loops
for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
itrShapeM++, i++) {
int index = swap_i[i] / boundariesOne[swap_i[i]];
int offset = swap_i[i] / (mxnxkTile.size() - 1);

int operandSize =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()
.size();
int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
auto upperBound =
dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
.getShape()[effectiveOffset];
Location loc = brgemmOp.getLoc();
Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound/(*itrShapeM));
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
zeroCst, ubCstTiledLoop, stepCstTiledLoop);
rewriter.setInsertionPointToStart(loopOp.getBody());
int indexTwo = offset;
int operandSizeTwo =
dyn_cast<MemRefType>(brgemmOp.getOperand(indexTwo).getType())
.getShape()
.size();
int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;

inductionVars[index][effectiveOffset] = loopOp.getInductionVar();

inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
int indexThree = mxnTile.size();
int effectiveOffsetThree =
index +
dyn_cast<MemRefType>(brgemmOp.getOperand(indexThree).getType())
.getShape()
.size() -
tileSizesIndex[indexThree];
if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
inductionVars[indexThree][effectiveOffsetThree] =
loopOp.getInductionVar();
}

innermostForLoop = loopOp;
if ((mxnxkTile.size() - 1) == (i + 1)) {
//Creates the brgemm reduction loop
Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
.getShape()[0]);
Value stepCstReduction = rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
rewriter.setInsertionPointToStart(redloopOp.getBody());
reductionForLoop = redloopOp;
}
}

// Creating subviews
SmallVector<SmallVector<int64_t>> tiles = {tileShapeM, tileShapeN};
for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) {
SmallVector<int64_t> indices;
auto input = brgemmOp.getOperand(i);
auto operandType = input.getType();
SmallVector<OpFoldResult> offsets;
size_t k = 0;
auto tileItr = tileshapes[i].begin();
auto tensorShape = dyn_cast<MemRefType>(operandType).getShape();
SmallVector<OpFoldResult> shape;
SmallVector<OpFoldResult> strides;
for (size_t j = 0; j < tensorShape.size(); j++) {
if (j < tensorShape.size() - tileSizesIndex[i]) {
if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
i < (brgemmOp.getNumOperands() - 1)) {
offsets.push_back(reductionForLoop.getInductionVar());
indices.push_back(tensorShape[j] / tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));

} else {
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
} else {
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / (*tileItr)));
indices.push_back(tensorShape[j] / (*tileItr));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(
inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
k++;
tileItr++;
}
}

auto subview = rewriter.create<memref::SubViewOp>(
brgemmOp.getLoc(), MemRefType(),
input, offsets, shape, strides);
brgemmOp.setOperand(i, subview);
}

rewriter.setInsertionPoint(innermostForLoop.getBody(),
std::prev(innermostForLoop.getBody()->end(), 1));
auto clone = rewriter.clone(*brgemmOp);
brgemmOp.replaceAllUsesWith(clone);
if (brgemmOp->use_empty())
rewriter.eraseOp(brgemmOp);
return success();
}

private:
BrgemmLinalgTilingOptions options;
};

void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling>(patterns.getContext(), options);
}

struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {

using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;

void runOnOperation() override {
BrgemmLinalgTilingOptions options{mTileShape, nTileShape};
RewritePatternSet patterns(&getContext());
populateBrgemmLinalgTilingPatterns(patterns, options);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
}
};
} // namespace tpp
} // namespace mlir
2 changes: 1 addition & 1 deletion lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ add_mlir_library(TPPTransforms
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp
Vectorization.cpp
LinalgTiling.cpp
BrgemmLinalgTiling.cpp
SplitReductionDim.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
33 changes: 33 additions & 0 deletions test/Integration/tile-brgemm-linalg-matmul.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -print > %t.1
// RUN: tpp-opt %s --tile-brgemm-linalg="mTile=8,8 nTile=8,16" --vectorization-pass| tpp-run -e entry --entry-point-result=void -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty

// DIFF-NOT: {{.}}
module {
memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
func.func @entry(%arg0: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 48) {
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
%subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
}
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
scf.forall (%arg1, %arg2) in (8, 48) {
%subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
%subview_1 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
}
scf.forall (%arg1, %arg2) in (8, 48) {
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
%subview_1 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
}
return %alloc : memref<8x48x32x32xf32>
}
}
Loading

0 comments on commit e57f1d9

Please sign in to comment.