diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td
index 7a04adbb4..8d698bc1b 100644
--- a/include/TPP/Passes.td
+++ b/include/TPP/Passes.td
@@ -55,8 +55,8 @@ def VectorizationPass : Pass<"vectorization-pass",
-def LinalgTiling : Pass<"tile-linalg"> {
-  let summary = "Tile matmul reduction dimension.";
+def BrgemmLinalgTiling : Pass<"tile-brgemm-linalg"> {
+  let summary = "Tile matmul and reduction dimensions.";
   let description = [{
     Tiles the innermost dimensions of the batch reduce matmul operation.
     Additionally, it swaps the reduction and K dimension loops. The final
     loop structure is: M-loop -> N-loop -> reduction-loop -> K-loop.
   }];
diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp
index 8a19e1689..dd6662ed4 100644
--- a/lib/TPP/DefaultTppPasses.cpp
+++ b/lib/TPP/DefaultTppPasses.cpp
@@ -104,7 +104,7 @@ struct DefaultTppPasses
     // TODO: This flag will be removed once the vector path becomes the
     // default lowering path.
     if (linalgToVector) {
-      pm.addNestedPass<func::FuncOp>(createLinalgTiling());
+      pm.addNestedPass<func::FuncOp>(createBrgemmLinalgTiling());
       pm.addNestedPass<func::FuncOp>(createVectorizationPass());
       pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
     } else {
diff --git a/lib/TPP/Transforms/BrgemmLinalgTiling.cpp b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
new file mode 100644
index 000000000..9090cf885
--- /dev/null
+++ b/lib/TPP/Transforms/BrgemmLinalgTiling.cpp
@@ -0,0 +1,236 @@
+//===- BrgemmLinalgTiling.cpp --------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements tiling of batch reduce matmul operations into nested
+// scf.for loops that operate on subviews of the operands.
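+//
+// The mTile option is interpreted as [m, k] and the nTile option as [k, n];
+// the shared k entry must match. A linalg.batch_reduce_matmul on memrefs is
+// rewritten (roughly) into the loop nest
+//   M-loop -> N-loop -> reduction-loop -> K-loop
+// whose body operates on memref.subview slices of the operands sized by the
+// chosen tiles; see the lit tests added below for concrete examples.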
+//
+//===----------------------------------------------------------------------===//
+#include "TPP/IR/TilingUtils.h"
+#include "TPP/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/Debug.h"
+#include <map>
+
+#define DEBUG_TYPE "brgemm-linalg-tiling"
+
+namespace mlir {
+namespace tpp {
+#define GEN_PASS_DECL_BRGEMMLINALGTILING
+#define GEN_PASS_DEF_BRGEMMLINALGTILING
+#include "TPP/Passes.h.inc"
+} // namespace tpp
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tpp;
+
+namespace mlir {
+namespace tpp {
+
+struct LinalgOpTiling : OpRewritePattern<linalg::BatchReduceMatmulOp> {
+  using OpRewritePattern<linalg::BatchReduceMatmulOp>::OpRewritePattern;
+
+  LinalgOpTiling(MLIRContext *ctx, BrgemmLinalgTilingOptions tilingoptions)
+      : OpRewritePattern(ctx), options(tilingoptions) {}
+
+  LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp brgemmOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (!brgemmOp.hasPureBufferSemantics())
+      return failure();
+
+    // Get the M and N tile shapes from the user input.
+    SmallVector<int> tileShapeM(options.mTileShape.begin(),
+                                options.mTileShape.end());
+    SmallVector<int> tileShapeN(options.nTileShape.begin(),
+                                options.nTileShape.end());
+
+    if (tileShapeM.size() != 2 || tileShapeN.size() != 2)
+      return failure();
+
+    if (tileShapeM[1] != tileShapeN[0])
+      return failure();
+
+    // Stores the M, N, and K tile sizes.
+    SmallVector<int> mxnxkTile(3);
+    // Stores the M and N tile sizes.
+    SmallVector<int> mxnTile(2);
+
+    mxnxkTile[0] = tileShapeM[0];
+    mxnxkTile[1] = tileShapeN[1];
+    mxnxkTile[2] = tileShapeM[1];
+    mxnTile[0] = tileShapeM[0];
+    mxnTile[1] = tileShapeN[1];
+
+    // To assist in calculating the argument and step values for the tiled
+    // loops.
+    SmallVector<int> boundariesOne{1,
+                                   static_cast<int>(tileShapeM.size() - 1),
+                                   static_cast<int>(mxnxkTile.size() - 1)};
+
+    SmallVector<int> tileSizesIndex{static_cast<int>(tileShapeM.size()),
+                                    static_cast<int>(tileShapeN.size()),
+                                    static_cast<int>(mxnTile.size())};
+    SmallVector<SmallVector<int>> tileshapes{tileShapeM, tileShapeN, mxnTile};
+    SmallVector<int> swap_i = {0, 2, 1};
+    size_t i = 0;
+    std::map<int, std::map<int, Value>> inductionVars;
+
+    // For the M, N, and K loops.
+    scf::ForOp innermostForLoop;
+    // For the brgemm reduction loop.
+    scf::ForOp reductionForLoop;
+
+    // Creating the tiled loops. They are created in the order M -> N -> K,
+    // with the batch-reduce (reduction) loop inserted just before the
+    // innermost K loop.
+    for (auto itrShapeM = mxnxkTile.begin(); itrShapeM != mxnxkTile.end();
+         itrShapeM++, i++) {
+      int index = swap_i[i] / boundariesOne[swap_i[i]];
+      int offset = swap_i[i] / (mxnxkTile.size() - 1);
+
+      int operandSize =
+          dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
+              .getShape()
+              .size();
+      int effectiveOffset = operandSize - tileSizesIndex[index] + offset;
+      auto upperBound =
+          dyn_cast<MemRefType>(brgemmOp.getOperand(index).getType())
+              .getShape()[effectiveOffset];
+      Location loc = brgemmOp.getLoc();
+      Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      Value ubCstTiledLoop =
+          rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
+      Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(
+          loc, upperBound / (*itrShapeM));
+      // Creates the M, N, and K tile loops.
+      scf::ForOp loopOp = rewriter.create<scf::ForOp>(
+          brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
+      rewriter.setInsertionPointToStart(loopOp.getBody());
+      int indexTwo = offset;
+      int operandSizeTwo =
+          dyn_cast<MemRefType>(brgemmOp.getOperand(indexTwo).getType())
+              .getShape()
+              .size();
+      int effectiveOffsetTwo = operandSizeTwo - tileSizesIndex[index] + index;
+
+      inductionVars[index][effectiveOffset] = loopOp.getInductionVar();
+      inductionVars[indexTwo][effectiveOffsetTwo] = loopOp.getInductionVar();
+
+      int indexThree = mxnTile.size();
+      int effectiveOffsetThree =
+          index +
+          dyn_cast<MemRefType>(brgemmOp.getOperand(indexThree).getType())
+              .getShape()
+              .size() -
+          tileSizesIndex[indexThree];
+      if (inductionVars[indexThree][effectiveOffsetThree] == NULL) {
+        inductionVars[indexThree][effectiveOffsetThree] =
+            loopOp.getInductionVar();
+      }
+
+      innermostForLoop = loopOp;
+      if ((mxnxkTile.size() - 1) == (i + 1)) {
+        // Creates the brgemm reduction loop.
+        Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
+            loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
+                     .getShape()[0]);
+        Value stepCstReduction =
+            rewriter.create<arith::ConstantIndexOp>(loc, 1);
+        scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
+            brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
+        rewriter.setInsertionPointToStart(redloopOp.getBody());
+        reductionForLoop = redloopOp;
+      }
+    }
+
+    // Creating subviews of the operands, indexed by the induction variables
+    // of the new loops and sized by the tile shapes.
+    SmallVector<SmallVector<int>> tiles = {tileShapeM, tileShapeN};
+    for (size_t i = 0; i < brgemmOp.getNumOperands(); i++) {
+      SmallVector<int64_t> indices;
+      auto input = brgemmOp.getOperand(i);
+      auto operandType = input.getType();
+      SmallVector<OpFoldResult> offsets;
+      size_t k = 0;
+      auto tileItr = tileshapes[i].begin();
+      auto tensorShape = dyn_cast<MemRefType>(operandType).getShape();
+      SmallVector<OpFoldResult> shape;
+      SmallVector<OpFoldResult> strides;
+      for (size_t j = 0; j < tensorShape.size(); j++) {
+        if (j < tensorShape.size() - tileSizesIndex[i]) {
+          if (j == ((tensorShape.size() - tileSizesIndex[i]) - 1) &&
+              i < (brgemmOp.getNumOperands() - 1)) {
+            offsets.push_back(reductionForLoop.getInductionVar());
+            indices.push_back(tensorShape[j] / tensorShape[j]);
+            shape.push_back(
+                rewriter.getIndexAttr(tensorShape[j] / tensorShape[j]));
+            strides.push_back(rewriter.getIndexAttr(1));
+          } else {
+            offsets.push_back(rewriter.getIndexAttr(0));
+            indices.push_back(tensorShape[j]);
+            shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
+            strides.push_back(rewriter.getIndexAttr(1));
+          }
+        } else {
+          shape.push_back(rewriter.getIndexAttr(tensorShape[j] / (*tileItr)));
+          indices.push_back(tensorShape[j] / (*tileItr));
+          strides.push_back(rewriter.getIndexAttr(1));
+          offsets.push_back(
+              inductionVars[i][tensorShape.size() - tileSizesIndex[i] + k]);
+          k++;
+          tileItr++;
+        }
+      }
+
+      auto subview = rewriter.create<memref::SubViewOp>(
+          brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
+      brgemmOp.setOperand(i, subview);
+    }
+
+    rewriter.setInsertionPoint(innermostForLoop.getBody(),
+                               std::prev(innermostForLoop.getBody()->end(), 1));
+    auto clone = rewriter.clone(*brgemmOp);
+    brgemmOp.replaceAllUsesWith(clone);
+    if (brgemmOp->use_empty())
+      rewriter.eraseOp(brgemmOp);
+
+    return success();
+  }
+
+private:
+  BrgemmLinalgTilingOptions options;
+};
+
+void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
+                                        BrgemmLinalgTilingOptions options) {
+  patterns.add<LinalgOpTiling>(patterns.getContext(), options);
+}
+
+struct BrgemmLinalgTiling
+    : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
+
+  using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;
+
+  void runOnOperation() override {
+    BrgemmLinalgTilingOptions options{mTileShape, nTileShape};
+    RewritePatternSet patterns(&getContext());
+    populateBrgemmLinalgTilingPatterns(patterns, options);
+    GreedyRewriteConfig config;
+    config.strictMode = GreedyRewriteStrictness::ExistingOps;
+
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                       config);
+  }
+};
+} // namespace tpp
+} // namespace mlir
diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt
index 0850dbe23..7bb9f0859 100644
--- a/lib/TPP/Transforms/CMakeLists.txt
+++ b/lib/TPP/Transforms/CMakeLists.txt
@@ -25,7 +25,7 @@ add_mlir_library(TPPTransforms
   FoldIntoEltwise.cpp
   FoldAddIntoDest.cpp
   Vectorization.cpp
-  LinalgTiling.cpp
+  BrgemmLinalgTiling.cpp
   SplitReductionDim.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/test/Integration/tile-brgemm-linalg-matmul.mlir b/test/Integration/tile-brgemm-linalg-matmul.mlir
new file mode 100644
index 000000000..c6dc20f20
--- /dev/null
+++ b/test/Integration/tile-brgemm-linalg-matmul.mlir
@@ -0,0 +1,33 @@
+// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -print > %t.1
+// RUN: tpp-opt %s --tile-brgemm-linalg="mTile=8,8 nTile=8,16" --vectorization-pass | tpp-run -e entry --entry-point-result=void -print > %t.2
+// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty
+
+// DIFF-NOT: {{.}}
+module {
+  memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    return %alloc : memref<8x48x32x32xf32>
+  }
+}
diff --git a/test/Passes/pass-tile-brgemm-linalg-matmul.mlir b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir
new file mode 100644
index 000000000..73950092a
--- /dev/null
+++ b/test/Passes/pass-tile-brgemm-linalg-matmul.mlir
@@ -0,0 +1,150 @@
+
+// RUN: tpp-opt %s --tile-brgemm-linalg="mTile=8,8 nTile=8,16" --split-input-file | FileCheck %s
+
+module {
+  func.func @entry(%arg0: memref<16x32x16x32xf32>, %arg1: memref<32x32x32x32xf32>, %arg2: memref<16x32x16x32xf32>) {
+    scf.forall (%arg3, %arg4) in (16, 32) {
+      %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>
+      %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview, %subview_0 : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>) outs(%subview_1 : memref<16x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    return
+  }
+}
+
+// CHECK-LABEL:   func.func @entry(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<16x32x16x32xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: memref<32x32x32x32xf32>,
+// CHECK-SAME:      %[[VAL_2:.*]]: memref<16x32x16x32xf32>) {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           scf.forall (%[[VAL_9:.*]], %[[VAL_10:.*]]) in (16, 32) {
+// CHECK:             %[[VAL_11:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_9]], 0, 0, 0] [1, 32, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>>
+// CHECK:             %[[VAL_12:.*]] = memref.subview %[[VAL_1]]{{\[}}%[[VAL_10]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<32x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:             %[[VAL_13:.*]] = memref.subview %[[VAL_2]]{{\[}}%[[VAL_9]], %[[VAL_10]], 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<16x32x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_7]] step %[[VAL_6]] {
+// CHECK:               scf.for %[[VAL_15:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_6]] {
+// CHECK:                 scf.for %[[VAL_16:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:                   scf.for %[[VAL_17:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_3]] {
+// CHECK:                     %[[VAL_18:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_16]], %[[VAL_14]], %[[VAL_17]]] [1, 2, 4] [1, 1, 1] : memref<32x16x32xf32, strided<[512, 32, 1], offset: ?>> to memref<1x2x4xf32, strided<[512, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_19:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_15]]] [1, 4, 2] [1, 1, 1] : memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_20:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]]] [2, 2] [1, 1] : memref<16x32xf32, strided<[32, 1], offset: ?>> to memref<2x2xf32, strided<[32, 1], offset: ?>>
+// CHECK:                     linalg.batch_reduce_matmul ins(%[[VAL_18]], %[[VAL_19]] : memref<1x2x4xf32, strided<[512, 32, 1], offset: ?>>, memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_20]] : memref<2x2xf32, strided<[32, 1], offset: ?>>)
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+// RUN: tpp-opt %s --tile-brgemm-linalg="mTile=8,8 nTile=8,16" --split-input-file | FileCheck %s
+
+module {
+  memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+  func.func @entry(%arg0: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    scf.forall (%arg1, %arg2) in (8, 48) {
+      %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+      linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+      %subview_1 = memref.subview %alloc_0[%arg1, 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+      linalg.batch_reduce_matmul ins(%subview_1, %0 : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<48x32x32xf32>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+    }
+    return %alloc : memref<8x48x32x32xf32>
+  }
+}
+
+// CHECK-LABEL:   memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
+
+// CHECK-LABEL:   func.func @entry(
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_2:.*]] = arith.constant 48 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 32 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_8:.*]] = memref.get_global @__constant_48x32x32xf32 : memref<48x32x32xf32>
+// CHECK:           %[[VAL_9:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+// CHECK:           scf.forall (%[[VAL_10:.*]], %[[VAL_11:.*]]) in (8, 48) {
+// CHECK:             %[[VAL_12:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_11]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK:             linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_12]] : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+// CHECK:             %[[VAL_13:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_10]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:               scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_3]] {
+// CHECK:                 scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:                   scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:                     %[[VAL_18:.*]] = memref.subview %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_14]], %[[VAL_17]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_19:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_15]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_20:.*]] = memref.subview %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK:                     linalg.batch_reduce_matmul ins(%[[VAL_18]], %[[VAL_19]] : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_20]] : memref<4x2xf32, strided<[32, 1], offset: ?>>)
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_21:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x48x32x32xf32>
+// CHECK:           scf.forall (%[[VAL_22:.*]], %[[VAL_23:.*]]) in (8, 48) {
+// CHECK:             %[[VAL_24:.*]] = memref.subview %[[VAL_21]]{{\[}}%[[VAL_22]], %[[VAL_23]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK:             linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_24]] : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+// CHECK:             %[[VAL_25:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_22]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_26:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:               scf.for %[[VAL_27:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_3]] {
+// CHECK:                 scf.for %[[VAL_28:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:                   scf.for %[[VAL_29:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:                     %[[VAL_30:.*]] = memref.subview %[[VAL_25]]{{\[}}%[[VAL_28]], %[[VAL_26]], %[[VAL_29]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_31:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_28]], %[[VAL_29]], %[[VAL_27]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_32:.*]] = memref.subview %[[VAL_24]]{{\[}}%[[VAL_26]], %[[VAL_27]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK:                     linalg.batch_reduce_matmul ins(%[[VAL_30]], %[[VAL_31]] : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_32]] : memref<4x2xf32, strided<[32, 1], offset: ?>>)
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           scf.forall (%[[VAL_33:.*]], %[[VAL_34:.*]]) in (8, 48) {
+// CHECK:             %[[VAL_35:.*]] = memref.subview %[[VAL_9]]{{\[}}%[[VAL_33]], %[[VAL_34]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
+// CHECK:             linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_35]] : memref<32x32xf32, strided<[32, 1], offset: ?>>)
+// CHECK:             %[[VAL_36:.*]] = memref.subview %[[VAL_21]]{{\[}}%[[VAL_33]], 0, 0, 0] [1, 48, 32, 32] [1, 1, 1, 1] : memref<8x48x32x32xf32> to memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:             scf.for %[[VAL_37:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:               scf.for %[[VAL_38:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_3]] {
+// CHECK:                 scf.for %[[VAL_39:.*]] = %[[VAL_6]] to %[[VAL_2]] step %[[VAL_1]] {
+// CHECK:                   scf.for %[[VAL_40:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_4]] {
+// CHECK:                     %[[VAL_41:.*]] = memref.subview %[[VAL_36]]{{\[}}%[[VAL_39]], %[[VAL_37]], %[[VAL_40]]] [1, 4, 4] [1, 1, 1] : memref<48x32x32xf32, strided<[1024, 32, 1], offset: ?>> to memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_42:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_39]], %[[VAL_40]], %[[VAL_38]]] [1, 4, 2] [1, 1, 1] : memref<48x32x32xf32> to memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>
+// CHECK:                     %[[VAL_43:.*]] = memref.subview %[[VAL_35]]{{\[}}%[[VAL_37]], %[[VAL_38]]] [4, 2] [1, 1] : memref<32x32xf32, strided<[32, 1], offset: ?>> to memref<4x2xf32, strided<[32, 1], offset: ?>>
+// CHECK:                     linalg.batch_reduce_matmul ins(%[[VAL_41]], %[[VAL_42]] : memref<1x4x4xf32, strided<[1024, 32, 1], offset: ?>>, memref<1x4x2xf32, strided<[1024, 32, 1], offset: ?>>) outs(%[[VAL_43]] : memref<4x2xf32, strided<[32, 1], offset: ?>>)
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return %[[VAL_9]] : memref<8x48x32x32xf32>
+// CHECK:         }