diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..b4da6d3d37354 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -595,18 +595,17 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -614,9 +613,9 @@ class ExpansionInfo {
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -624,15 +623,17 @@ class ExpansionInfo {
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });

   reassociation.clear();
   expandedShapeMap.clear();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }

-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }

 /// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }

 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
 /// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }

   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+          linalgOp, fusableOpOperand, reassociationIndices,
+          expandedShape, rewriter)))
     return std::nullopt;

   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..57904f912a35b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor,
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor into tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor into tensor
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
+// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor,
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor into tensor
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor into tensor
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor into tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor into tensor
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ0]], 2, %[[SZ1]]] : tensor into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ1]]] : tensor into tensor<3x4x?x?xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ0]], 2, %[[SZ1]], 3, 4, %[[SZ2]]] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
@@ -195,7 +160,7 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
 // CHECK-SAME: : tensor<8x33x4xf32>
 // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
 // CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
 // CHECK: %[[T2:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel"]
@@ -235,7 +200,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor,
 }

 // Only check the body in the indexed version of the test.
-// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
 // CHECK: func @indexed_consumer_reshape_producer_fusion
 // CHECK: linalg.generic
 // CHECK: ^{{.*}}(
@@ -245,7 +210,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor,
 // CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
 // CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
 // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]]()[%[[IDX1]], %[[IDX0]]]
 // CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
 // CHECK: %[[T5:.+]] = arith.index_cast %[[T3]]
 // CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
@@ -284,8 +249,7 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor,
 }

 // Only check the body in the indexed version of the test.
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 5 + s1 * 20 + s2)>
 // CHECK: func @indexed_producer_reshape_consumer_fusion
 // CHECK: linalg.generic
 // CHECK: ^{{.*}}(
@@ -295,12 +259,11 @@ func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor,
 // CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
 // CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
 // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]])
-// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]])
+// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]]()[%[[IDX2]], %[[IDX1]], %[[IDX3]]]
 // CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]]
 // CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]]
 // CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]]
-// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]]
+// CHECK: %[[T7:.+]] = arith.index_cast %[[T1]]
 // CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
 // CHECK: linalg.yield %[[T8]]
@@ -339,16 +302,15 @@ func.func @reshape_as_consumer_permutation
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 7 + s1 * 42 + s2)>
 // CHECK: func @reshape_as_consumer_permutation
 // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32>
 // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32>
 // CHECK-DAG: %[[INIT:.+]] = tensor.empty()
 // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [5, 6, 7, 2, 3, 4] : tensor<210x6x4xi32> into tensor<5x6x7x2x3x4xi32>
 // CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [5, 6, 7, 4] : tensor<210x4xi32> into tensor<5x6x7x4xi32>
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32>
 // CHECK: %[[T4:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>)
@@ -362,13 +324,12 @@ func.func @reshape_as_consumer_permutation
 // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
 // CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index
 // CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index
-// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]])
-// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]])
-// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]])
+// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]]()[%[[IDX1]], %[[IDX0]]]
+// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]]()[%[[IDX3]], %[[IDX2]], %[[IDX4]]]
 // CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]]
 // CHECK: %[[T9:.+]] = arith.index_cast %[[T5]]
 // CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]]
-// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]]
+// CHECK: %[[T11:.+]] = arith.index_cast %[[T6]]
 // CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]]
 // CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]]
 // CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]]
@@ -403,7 +364,7 @@ func.func @reshape_as_producer_projected_permutation(
 // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 8)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 8)>
 // CHECK: @reshape_as_producer_projected_permutation
 // CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32>
 // CHECK: %[[RES:.+]] = linalg.generic
@@ -416,7 +377,7 @@ func.func @reshape_as_producer_projected_permutation(
 // CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index
 // CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index
 // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index
-// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]](%[[IDX1]], %[[IDX0]])
+// CHECK-DAG: %[[T0:.+]] = affine.apply #[[MAP2]]()[%[[IDX1]], %[[IDX0]]]
 // CHECK: %[[T1:.+]] = arith.index_cast %[[T0]] : index to i32
 // CHECK: %[[T2:.+]] = arith.addi %[[ARG1]], %[[T1]] : i32
 // CHECK: %[[T3:.+]] = arith.index_cast %[[IDX2]] : index to i32
@@ -458,21 +419,9 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor,
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_1]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_1]], 4, 5, %[[DIM_2]]] : tensor into tensor
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor into tensor
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor into tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[SZ0]]] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -482,7 +431,7 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor,

 // -----

-func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor {
+func.func @fuse_dynamic_dims(%arg0: tensor) -> tensor {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor into tensor
   %1 = tensor.dim %0, %c0 : tensor
@@ -498,39 +447,21 @@ func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor {
   return %3 : tensor
 }

-// CHECK: func @no_fuse_dynamic_dims
+// CHECK: func @fuse_dynamic_dims
 // CHECK-SAME: %[[ARG0:.+]]: tensor
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
 // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[EMPTY]] {{\[}}[0, 1]{{\]}}
+// CHECK-SAME: output_shape [%[[D0]], %[[D1]]]
 // CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor)
-// CHECK: return %[[GENERIC]]
-
-// -----
-
-func.func @no_fuse_mismatched_dynamism(%arg0: tensor<2x1xi64>, %arg1: tensor) -> tensor<2xi64> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<2x1xi64> into tensor<2xi64>
-  %1 = tensor.empty() : tensor<2xi64>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> (d0)>,
-                      affine_map<(d0) -> (d0)>,
-                      affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%0, %arg1 : tensor<2xi64>, tensor)
-    outs(%1 : tensor<2xi64>) {
-  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):
-    %3 = arith.addi %arg4, %arg5 : i64
-    linalg.yield %3 : i64
-  } -> tensor<2xi64>
-  return %2 : tensor<2xi64>
-}
-
-// CHECK: func @no_fuse_mismatched_dynamism
-// CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64>
-// CHECK-SAME: %[[ARG1:.+]]: tensor
-// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
-// CHECK: %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor)
-// CHECK: return %[[GENERIC]]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[EXPAND_SHAPE]] :
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]{{\]}}
+// CHECK: return %[[COLLAPSE]]

 // -----
@@ -562,32 +493,10 @@ func.func @reshape_as_consumer_permutation_with_multiple_results
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index, %[[SZ3:.+]]: index, %[[SZ4:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C12]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C2]] : index
-// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[VAL_0]], %[[VAL_1]], 2, %[[DIM_1]]] : tensor into tensor<3x4x?x?x2x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_2]], %[[C12]] : index
-// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[VAL_2]], %[[DIM_3]]] : tensor into tensor<3x4x?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_6:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[DIM_7:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_5]], %[[C2]] : index
-// CHECK: %[[VAL_4:.+]] = arith.divsi %[[DIM_7]], %[[C12]] : index
-// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[VAL_3]], 2, %[[DIM_6]], 3, 4, %[[VAL_4]]] : tensor into tensor
-// CHECK: %[[DIM_9:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_10:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[DIM_11:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
-// CHECK: %[[VAL_5:.+]] = arith.divsi %[[DIM_10]], %[[C2]] : index
-// CHECK: %[[VAL_6:.+]] = arith.divsi %[[DIM_11]], %[[C12]] : index
-// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[DIM_9]], %[[VAL_5]], 2, 3, 4, %[[VAL_6]]] : tensor into tensor
+// CHECK: %[[RESHAPE0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3, 4], [5]] output_shape [3, 4, %[[SZ2]], %[[SZ4]], 2, %[[SZ3]]] : tensor into tensor<3x4x?x?x2x?xf32>
+// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [3, 4, %[[SZ2]], %[[SZ3]]] : tensor into tensor<3x4x?x?xf32>
+// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [%[[SZ4]], 2, %[[SZ3]], 3, 4, %[[SZ2]]] : tensor into tensor
+// CHECK: %[[RESHAPE3:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2], [3, 4, 5]] output_shape [%[[SZ3]], %[[SZ4]], 2, 3, 4, %[[SZ2]]] : tensor into tensor
 // CHECK: %[[GENERIC:.+]]:2 = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
 // CHECK-SAME: ins(%[[RESHAPE0]], %[[RESHAPE1]] :
@@ -662,17 +571,10 @@ func.func @generic_op_reshape_consumer_fusion_reduction(%arg0 : tensor,
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
 // CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[VAL_0]], 4, 5, %[[DIM_0]]] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor into tensor
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[SZ1]], 4, 5, %[[DIM]]] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "reduction"]
@@ -712,21 +614,12 @@ func.func @generic_op_reshape_producer_fusion_with_reduction(%arg0 : tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[VAL_0]], 8, 4, %[[VAL_1]], 7] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C8]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C7]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 8, %[[VAL_3]], 7] : tensor into tensor
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2], [3, 4]] output_shape [%[[DIM_0]], 8, 4, %[[DIM]], 7] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM_0]], 8, %[[DIM]], 7] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel"]
@@ -759,21 +652,9 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor,
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], %[[VAL_0]], 4, 5] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], %[[VAL_1]], 4, 5] : tensor into tensor
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor into tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor into tensor
+// CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], %[[SZ1]], 4, 5] : tensor into tensor
 // CHECK: %[[T4:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -803,20 +684,12 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor,
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor
-// CHECK: %[[C8:.+]] = arith.constant 8 : index
-// CHECK: %[[C7:.+]] = arith.constant 7 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
 // CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor into tensor
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG2]], %[[C0]] : tensor
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
-// CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor into tensor
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor into tensor
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[DIM]], 7, %[[DIM_0]], 8] : tensor into tensor
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
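
For reference, an illustrative sketch (not part of the patch) of the pattern the change enables. It is adapted from the `@fuse_dynamic_dims` test above, which was previously named `@no_fuse_dynamic_dims` and was rejected by the removed `validateDynamicDimExpansion` check; the tensor types below are assumed (`tensor<?x?xf32>`/`tensor<?xf32>`), since the test's types are elided here. With this change, the collapse is propagated past the elementwise op: the generic is rewritten on the expanded 2-D domain and a collapse_shape is emitted on its result.

// Illustrative input IR (assumed types), not a line of this diff.
func.func @collapse_then_elementwise(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
  %c0 = arith.constant 0 : index
  // Collapse a fully dynamic 2-D tensor into a 1-D tensor.
  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %d0 = tensor.dim %0, %c0 : tensor<?xf32>
  %init = tensor.empty(%d0) : tensor<?xf32>
  // Elementwise consumer of the collapsed tensor; after this patch the
  // reshape is propagated past this op even though every dim is dynamic.
  %1 = linalg.generic
      {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]}
      ins(%0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %in : f32
    linalg.yield %2 : f32
  } -> tensor<?xf32>
  return %1 : tensor<?xf32>
}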