From be778c7a74e5a7c1796bb20c0a4a86560acbdfe2 Mon Sep 17 00:00:00 2001 From: athangam Date: Fri, 4 Oct 2024 01:50:51 -0700 Subject: [PATCH] Cleaning up the pass LinalgTiling file --- lib/TPP/Transforms/LinalgTiling.cpp | 130 +++++++++++----------------- 1 file changed, 50 insertions(+), 80 deletions(-) diff --git a/lib/TPP/Transforms/LinalgTiling.cpp b/lib/TPP/Transforms/LinalgTiling.cpp index 9e8702d6e..c05b617c4 100644 --- a/lib/TPP/Transforms/LinalgTiling.cpp +++ b/lib/TPP/Transforms/LinalgTiling.cpp @@ -1,3 +1,4 @@ + //===- LinalgTiling.cpp -----------------------------------------*- C++-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. @@ -49,11 +50,6 @@ using namespace std; namespace mlir { namespace tpp { - -// namespace { - -// template - struct LinalgOpTiling : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -62,35 +58,26 @@ struct LinalgOpTiling : OpRewritePattern { LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp linalgOp, PatternRewriter &rewriter) const override { - //std::cout << "The First operation" << "\n"; - // Get the MXN tile shape + // Get the MXN tile shape from the user input std::vector tileShapeM(options.mTileShape.begin(), options.mTileShape.end()); std::vector tileShapeN(options.nTileShape.begin(), options.nTileShape.end()); std::vector finaltile(3); - std::set_union(tileShapeM.begin(), tileShapeM.end(), tileShapeN.begin(), - tileShapeN.end(), finaltile.begin()); - - // Swap from MxKxN to MxNxK - int64_t temp_k = finaltile[1]; - finaltile[1] = finaltile[2]; - finaltile[2] = temp_k; - //std::cout << "Break 0.3" << "\n"; - std::vector resulttileOne(1); - std::set_difference(tileShapeM.begin(), tileShapeM.end(), - tileShapeN.begin(), tileShapeN.end(), - resulttileOne.begin()); - - std::vector resulttileTwo(1); - std::set_difference(tileShapeN.begin(), tileShapeN.end(), - tileShapeM.begin(), tileShapeM.end(), - resulttileTwo.begin()); - std::vector resulttile(2); - std::set_union(resulttileOne.begin(), resulttileOne.end(), - resulttileTwo.begin(), resulttileTwo.end(), - resulttile.begin()); + if (tileShapeM.size() == 2 && tileShapeN.size() == 2) { + if (tileShapeM[1] == tileShapeN[0]) { + finaltile[0] = tileShapeM[0]; + finaltile[1] = tileShapeN[1]; + finaltile[2] = tileShapeM[1]; + resulttile[0] = tileShapeM[0]; + resulttile[1] = tileShapeN[1]; + } else { + return failure(); + } + } else { + return failure(); + } SmallVector boundariesOne{1, static_cast(tileShapeM.size() - 1), @@ -103,25 +90,17 @@ struct LinalgOpTiling : OpRewritePattern { std::vector tempTileN(tileShapeN.begin(), tileShapeN.end()); SmallVector> tileshapes{tempTileM, tempTileN, resulttile}; - std::vector i_temp = {0, 2, 1}; - int i = 0; + std::vector swap_i = {0, 2, 1}; + size_t i = 0; map> inductionVars; scf::ForOp innermostForLoop; scf::ForOp reductionForLoop; - + // Creating the tiled loops for (auto itrShapeM = finaltile.begin(); itrShapeM != finaltile.end(); itrShapeM++, i++) { - int index = i_temp[i] / boundariesOne[i_temp[i]]; - int offset = i_temp[i] / (finaltile.size() - 1); - //std::cout << "The index: " << index << "\n"; - /*auto testing = dyn_cast(linalgOp.getOperand(index).getType()); - if (testing == NULL) { - std::cout << "Null ptr" << "\n"; - return success(); - } else { - std::cout << "Not Null ptr" << "\n"; - }*/ + int index = swap_i[i] / boundariesOne[swap_i[i]]; + int offset = swap_i[i] / (finaltile.size() - 1); int operandSize = dyn_cast(linalgOp.getOperand(index).getType()) @@ -148,7 +127,6 @@ struct LinalgOpTiling : OpRewritePattern { inductionVars[index][effectiveOffset] = loopOp.getInductionVar(); inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar(); - //std::cout << "Break 1" << "\n"; int indexThree = resulttile.size(); int effectiveOffsetThree = index + @@ -157,29 +135,26 @@ struct LinalgOpTiling : OpRewritePattern { .size() - tileSizesIndex[indexThree]; if (inductionVars[indexThree][effectiveOffsetThree] == NULL) { - inductionVars[indexThree][effectiveOffsetThree] = - loopOp.getInductionVar(); } - //std::cout << "The six: " << index << " " << effectiveOffset << indexTwo << " " << effectiveOffsetTwo << indexThree << " " << effectiveOffsetThree << "\n"; - innermostForLoop = loopOp; - if ((finaltile.size()-1) == (i+1)) { - Value zeroCst1 = rewriter.create(loc, 0); - Value ubCst1 = rewriter.create(loc, dyn_cast(linalgOp.getOperand(0).getType()).getShape()[0]); - Value stepCst1 = rewriter.create(loc, 1); - scf::ForOp redloopOp = rewriter.create(linalgOp.getLoc(), - zeroCst1, ubCst1, stepCst1); - rewriter.setInsertionPointToStart(redloopOp.getBody()); - reductionForLoop = redloopOp; + inductionVars[indexThree][effectiveOffsetThree] = + loopOp.getInductionVar(); + } + innermostForLoop = loopOp; + if ((finaltile.size() - 1) == (i + 1)) { + Value zeroCst1 = rewriter.create(loc, 0); + Value ubCst1 = rewriter.create( + loc, dyn_cast(linalgOp.getOperand(0).getType()) + .getShape()[0]); + Value stepCst1 = rewriter.create(loc, 1); + scf::ForOp redloopOp = rewriter.create( + linalgOp.getLoc(), zeroCst1, ubCst1, stepCst1); + rewriter.setInsertionPointToStart(redloopOp.getBody()); + reductionForLoop = redloopOp; } } - //std::cout << "Code after introduction of loops: " << "\n"; - - //std:: cout << "Break 2" << "\n"; + // Creating subviews SmallVector> tiles = {tileShapeM, tileShapeN}; - - auto contractionDim = inferContractionDims(linalgOp); - - for (size_t i = 0; i < linalgOp.getNumOperands() ; i++) { + for (size_t i = 0; i < linalgOp.getNumOperands(); i++) { SmallVector indices; auto input = linalgOp.getOperand(i); @@ -192,27 +167,25 @@ struct LinalgOpTiling : OpRewritePattern { SmallVector strides; for (size_t j = 0; j < tensorShape.size(); j++) { if (j < tensorShape.size() - tileSizesIndex[i]) { - //std::cout << "The inside: " << tensorShape.size() << ":" << tileSizesIndex[i] << "\n"; - if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) && i < (linalgOp.getNumOperands()-1)) { - offsets.push_back(reductionForLoop.getInductionVar()); - //std::cout << "The inside: " << tensorShape[j] << "\n"; - indices.push_back(tensorShape[j]/32); - shape.push_back(rewriter.getIndexAttr(tensorShape[j]/32)); - strides.push_back(rewriter.getIndexAttr(1)); - - } else { - offsets.push_back(rewriter.getIndexAttr(0)); - indices.push_back(tensorShape[j]); - shape.push_back(rewriter.getIndexAttr(tensorShape[j])); - strides.push_back(rewriter.getIndexAttr(1));} + if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) && + i < (linalgOp.getNumOperands() - 1)) { + offsets.push_back(reductionForLoop.getInductionVar()); + indices.push_back(tensorShape[j] / 32); + shape.push_back(rewriter.getIndexAttr(tensorShape[j] / 32)); + strides.push_back(rewriter.getIndexAttr(1)); + + } else { + offsets.push_back(rewriter.getIndexAttr(0)); + indices.push_back(tensorShape[j]); + shape.push_back(rewriter.getIndexAttr(tensorShape[j])); + strides.push_back(rewriter.getIndexAttr(1)); + } } else { - //std::cout << "The tileItr: " << i << ":" << j << ":" << tensorShape[j] << ":" << (*tileItr) << "\n"; shape.push_back(rewriter.getIndexAttr(tensorShape[j] / (*tileItr))); indices.push_back(tensorShape[j] / (*tileItr)); strides.push_back(rewriter.getIndexAttr(1)); offsets.push_back( inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]); - //std::cout << "InduVar: " << i << " " << tensorShape.size() - tileSizesIndex[i] + k << "\n"; k++; tileItr++; } @@ -226,15 +199,13 @@ struct LinalgOpTiling : OpRewritePattern { input, offsets, shape, strides); linalgOp.setOperand(i, subview); } - //std::cout << "Break 4" << "\n"; + rewriter.setInsertionPoint(innermostForLoop.getBody(), std::prev(innermostForLoop.getBody()->end(), 1)); auto clone = rewriter.clone(*linalgOp); linalgOp.replaceAllUsesWith(clone); if (linalgOp->use_empty()) rewriter.eraseOp(linalgOp); - - //std::cout << "Break 5" << "\n"; return success(); } @@ -255,7 +226,6 @@ struct LinalgTiling : public tpp::impl::LinalgTilingBase { LinalgTilingOptions options{mTileShape, nTileShape}; RewritePatternSet patterns(&getContext()); populateLinalgTilingPatterns(patterns, options); - //std::cout << "Break 6" << "\n"; GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps;