Skip to content

Commit

Permalink
Clean up the LinalgTiling pass file
Browse files (browse the repository at this point in the history)
  • Loading branch information
athangam committed Oct 4, 2024
1 parent 5bed750 commit be778c7
Showing 1 changed file with 50 additions and 80 deletions.
130 changes: 50 additions & 80 deletions lib/TPP/Transforms/LinalgTiling.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

//===- LinalgTiling.cpp -----------------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
Expand Down Expand Up @@ -49,11 +50,6 @@ using namespace std;

namespace mlir {
namespace tpp {

// namespace {

// template <typename LinalgOp>

struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;

Expand All @@ -62,35 +58,26 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {

LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp linalgOp,
PatternRewriter &rewriter) const override {
//std::cout << "The First operation" << "\n";
// Get the MXN tile shape
// Get the MXN tile shape from the user input
std::vector<int64_t> tileShapeM(options.mTileShape.begin(),
options.mTileShape.end());
std::vector<int64_t> tileShapeN(options.nTileShape.begin(),
options.nTileShape.end());
std::vector<int64_t> finaltile(3);
std::set_union(tileShapeM.begin(), tileShapeM.end(), tileShapeN.begin(),
tileShapeN.end(), finaltile.begin());

// Swap from MxKxN to MxNxK
int64_t temp_k = finaltile[1];
finaltile[1] = finaltile[2];
finaltile[2] = temp_k;
//std::cout << "Break 0.3" << "\n";
std::vector<int64_t> resulttileOne(1);
std::set_difference(tileShapeM.begin(), tileShapeM.end(),
tileShapeN.begin(), tileShapeN.end(),
resulttileOne.begin());

std::vector<int64_t> resulttileTwo(1);
std::set_difference(tileShapeN.begin(), tileShapeN.end(),
tileShapeM.begin(), tileShapeM.end(),
resulttileTwo.begin());

std::vector<int64_t> resulttile(2);
std::set_union(resulttileOne.begin(), resulttileOne.end(),
resulttileTwo.begin(), resulttileTwo.end(),
resulttile.begin());
if (tileShapeM.size() == 2 && tileShapeN.size() == 2) {
if (tileShapeM[1] == tileShapeN[0]) {
finaltile[0] = tileShapeM[0];
finaltile[1] = tileShapeN[1];
finaltile[2] = tileShapeM[1];
resulttile[0] = tileShapeM[0];
resulttile[1] = tileShapeN[1];
} else {
return failure();
}
} else {
return failure();
}

SmallVector<int64_t> boundariesOne{1,
static_cast<long>(tileShapeM.size() - 1),
Expand All @@ -103,25 +90,17 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
std::vector<int64_t> tempTileN(tileShapeN.begin(), tileShapeN.end());
SmallVector<std::vector<int64_t>> tileshapes{tempTileM, tempTileN,
resulttile};
std::vector<int> i_temp = {0, 2, 1};
int i = 0;
std::vector<int> swap_i = {0, 2, 1};
size_t i = 0;
map<int, map<int, Value>> inductionVars;
scf::ForOp innermostForLoop;
scf::ForOp reductionForLoop;


// Creating the tiled loops
for (auto itrShapeM = finaltile.begin(); itrShapeM != finaltile.end();
itrShapeM++, i++) {
int index = i_temp[i] / boundariesOne[i_temp[i]];
int offset = i_temp[i] / (finaltile.size() - 1);
//std::cout << "The index: " << index << "\n";
/*auto testing = dyn_cast<MemRefType>(linalgOp.getOperand(index).getType());
if (testing == NULL) {
std::cout << "Null ptr" << "\n";
return success();
} else {
std::cout << "Not Null ptr" << "\n";
}*/
int index = swap_i[i] / boundariesOne[swap_i[i]];
int offset = swap_i[i] / (finaltile.size() - 1);

int operandSize =
dyn_cast<MemRefType>(linalgOp.getOperand(index).getType())
Expand All @@ -148,7 +127,6 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
inductionVars[index][effectiveOffset] = loopOp.getInductionVar();

inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
//std::cout << "Break 1" << "\n";
int indexThree = resulttile.size();
int effectiveOffsetThree =
index +
Expand All @@ -157,29 +135,26 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
.size() -
tileSizesIndex[indexThree];
if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
inductionVars[indexThree][effectiveOffsetThree] =
loopOp.getInductionVar(); }
//std::cout << "The six: " << index << " " << effectiveOffset << indexTwo << " " << effectiveOffsetTwo << indexThree << " " << effectiveOffsetThree << "\n";
innermostForLoop = loopOp;
if ((finaltile.size()-1) == (i+1)) {
Value zeroCst1 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCst1 = rewriter.create<arith::ConstantIndexOp>(loc, dyn_cast<MemRefType>(linalgOp.getOperand(0).getType()).getShape()[0]);
Value stepCst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(linalgOp.getLoc(),
zeroCst1, ubCst1, stepCst1);
rewriter.setInsertionPointToStart(redloopOp.getBody());
reductionForLoop = redloopOp;
inductionVars[indexThree][effectiveOffsetThree] =
loopOp.getInductionVar();
}
innermostForLoop = loopOp;
if ((finaltile.size() - 1) == (i + 1)) {
Value zeroCst1 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCst1 = rewriter.create<arith::ConstantIndexOp>(
loc, dyn_cast<MemRefType>(linalgOp.getOperand(0).getType())
.getShape()[0]);
Value stepCst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
linalgOp.getLoc(), zeroCst1, ubCst1, stepCst1);
rewriter.setInsertionPointToStart(redloopOp.getBody());
reductionForLoop = redloopOp;
}
}

//std::cout << "Code after introduction of loops: " << "\n";

//std:: cout << "Break 2" << "\n";
// Creating subviews
SmallVector<std::vector<int64_t>> tiles = {tileShapeM, tileShapeN};

auto contractionDim = inferContractionDims(linalgOp);

for (size_t i = 0; i < linalgOp.getNumOperands() ; i++) {
for (size_t i = 0; i < linalgOp.getNumOperands(); i++) {
SmallVector<int64_t> indices;

auto input = linalgOp.getOperand(i);
Expand All @@ -192,27 +167,25 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
SmallVector<OpFoldResult> strides;
for (size_t j = 0; j < tensorShape.size(); j++) {
if (j < tensorShape.size() - tileSizesIndex[i]) {
//std::cout << "The inside: " << tensorShape.size() << ":" << tileSizesIndex[i] << "\n";
if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) && i < (linalgOp.getNumOperands()-1)) {
offsets.push_back(reductionForLoop.getInductionVar());
//std::cout << "The inside: " << tensorShape[j] << "\n";
indices.push_back(tensorShape[j]/32);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]/32));
strides.push_back(rewriter.getIndexAttr(1));

} else {
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));}
if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
i < (linalgOp.getNumOperands() - 1)) {
offsets.push_back(reductionForLoop.getInductionVar());
indices.push_back(tensorShape[j] / 32);
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / 32));
strides.push_back(rewriter.getIndexAttr(1));

} else {
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}
} else {
//std::cout << "The tileItr: " << i << ":" << j << ":" << tensorShape[j] << ":" << (*tileItr) << "\n";
shape.push_back(rewriter.getIndexAttr(tensorShape[j] / (*tileItr)));
indices.push_back(tensorShape[j] / (*tileItr));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(
inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
//std::cout << "InduVar: " << i << " " << tensorShape.size() - tileSizesIndex[i] + k << "\n";
k++;
tileItr++;
}
Expand All @@ -226,15 +199,13 @@ struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
input, offsets, shape, strides);
linalgOp.setOperand(i, subview);
}
//std::cout << "Break 4" << "\n";

rewriter.setInsertionPoint(innermostForLoop.getBody(),
std::prev(innermostForLoop.getBody()->end(), 1));
auto clone = rewriter.clone(*linalgOp);
linalgOp.replaceAllUsesWith(clone);
if (linalgOp->use_empty())
rewriter.eraseOp(linalgOp);

//std::cout << "Break 5" << "\n";
return success();
}

Expand All @@ -255,7 +226,6 @@ struct LinalgTiling : public tpp::impl::LinalgTilingBase<LinalgTiling> {
LinalgTilingOptions options{mTileShape, nTileShape};
RewritePatternSet patterns(&getContext());
populateLinalgTilingPatterns(patterns, options);
//std::cout << "Break 6" << "\n";
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;

Expand Down

0 comments on commit be778c7

Please sign in to comment.