Skip to content

Commit

Permalink
did refactoring to the pass
Browse files Browse the repository at this point in the history
  • Loading branch information
athangam committed Oct 8, 2024
1 parent dcb9f57 commit 09240f2
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions lib/TPP/Transforms/LinalgTiling.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

//===- LinalgTiling.cpp -----------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Expand Down Expand Up @@ -58,26 +57,28 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {

LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp linalgOp,
PatternRewriter &rewriter) const override {

if (!linalgOp.hasPureBufferSemantics())
return failure();
// Get the MXN tile shape from the user input
std::vector<int64_t> tileShapeM(options.mTileShape.begin(),
options.mTileShape.end());
std::vector<int64_t> tileShapeN(options.nTileShape.begin(),
options.nTileShape.end());
std::vector<int64_t> finaltile(3);
std::vector<int64_t> resulttile(2);
if (tileShapeM.size() == 2 && tileShapeN.size() == 2) {
if (tileShapeM[1] == tileShapeN[0]) {
finaltile[0] = tileShapeM[0];
finaltile[1] = tileShapeN[1];
finaltile[2] = tileShapeM[1];
resulttile[0] = tileShapeM[0];
resulttile[1] = tileShapeN[1];
} else {
return failure();
}
} else {
return failure();
}

if (tileShapeM.size() != 2 || tileShapeN.size() != 2)
return failure();

if (tileShapeM[1] != tileShapeN[0])
return failure();

finaltile[0] = tileShapeM[0];
finaltile[1] = tileShapeN[1];
finaltile[2] = tileShapeM[1];
resulttile[0] = tileShapeM[0];
resulttile[1] = tileShapeN[1];

SmallVector<int64_t> boundariesOne{1,
static_cast<long>(tileShapeM.size() - 1),
Expand Down Expand Up @@ -195,7 +196,7 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
{indices}, dyn_cast<MemRefType>(operandType).getElementType());
auto [staticStrides, staticOffset] = getStridesAndOffset(subviewType);
auto subview = rewriter.create<memref::SubViewOp>(
linalgOp.getLoc(), nullptr /*dyn_cast<MemRefType>(newSubviewType)*/,
linalgOp.getLoc(), nullptr /* dyn_cast<MemRefType>(subviewType)*/,
input, offsets, shape, strides);
linalgOp.setOperand(i, subview);
}
Expand Down

0 comments on commit 09240f2

Please sign in to comment.