-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
base: main
Are you sure you want to change the base?
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
Conversation
…dynamic shapes. With `tensor.expand_shape` allowing expanding dynamic dimension into multiple dynamic dimension, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimension. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesWith Patch is 46.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127943.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..b4da6d3d37354 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -595,18 +595,17 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
- ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+ ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
- ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+ ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -614,9 +613,9 @@ class ExpansionInfo {
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
- SmallVector<SmallVector<int64_t>> expandedShapeMap;
+ SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
- SmallVector<int64_t> originalLoopExtent;
+ SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -624,15 +623,17 @@ class ExpansionInfo {
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
- SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
- originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(linalgOp);
+ originalLoopExtent = llvm::map_to_vector(
+ linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+ [](Range r) { return r.size; });
reassociation.clear();
expandedShapeMap.clear();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
+ ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- if (!linalgOp.hasIndexSemantics())
- return success();
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- for (int64_t shape : expandedShape.drop_front()) {
- if (ShapedType::isDynamic(shape)) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot expand due to index semantics and dynamic dims");
- }
- }
- }
- return success();
-}
-
/// Return the indexing map to use in the expanded op for a given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
- AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<int64_t> expandedStaticShape;
+ SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+ ArrayRef<OpFoldResult> dimExpansion =
+ expansionInfo.getExpandedShapeOfDim(dim);
+ llvm::append_range(expandedStaticShape,
+ llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+ std::optional<int64_t> staticShape =
+ getConstantIntValue(ofr);
+ if (staticShape) {
+ return staticShape.value();
+ }
+ return ShapedType::kDynamic;
+ }));
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
- return RankedTensorType::get(expandedShape, originalType.getElementType());
+ return {expandedShape, RankedTensorType::get(expandedStaticShape,
+ originalType.getElementType())};
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
- ArrayRef<int64_t> expandedDimsShape =
+ ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
- Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+ OpFoldResult newIndex =
+ rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
- assert(!ShapedType::isDynamic(std::get<0>(it)));
- AffineExpr idx, acc;
+ AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
- newIndex = rewriter.create<affine::AffineApplyOp>(
- indexOp.getLoc(), idx + acc * std::get<0>(it),
- ValueRange{std::get<1>(it), newIndex});
- }
- rewriter.replaceOp(indexOp, newIndex);
- }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- bool foundDynamic = false;
- for (int64_t shape : expandedShape) {
- if (!ShapedType::isDynamic(shape))
- continue;
- if (foundDynamic) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot infer expanded shape with multiple dynamic "
- "dims in the same reassociation group");
- }
- foundDynamic = true;
+ bindSymbols(rewriter.getContext(), shape);
+ newIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, indexOp.getLoc(), idx + acc * shape,
+ ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
}
+ Value newIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+ rewriter.replaceOp(indexOp, newIndexVal);
}
- return success();
}
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
- // Check if reshape is expanding or collapsing.
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
- bool isExpanding = (expandingReshapeOp != nullptr);
- RankedTensorType expandedType = isExpanding
- ? expandingReshapeOp.getResultType()
- : collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
+ SmallVector<AffineMap, 4> reassociationIndices;
+ Value src;
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+ expandedShape = expandingReshapeOp.getMixedOutputShape();
+ reassociationIndices = expandingReshapeOp.getReassociationMaps();
+ src = expandingReshapeOp.getSrc();
+ } else {
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+ expandedShape = tensor::getMixedSizes(
+ rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+ reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+ src = collapsingReshapeOp.getSrc();
+ }
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
- linalgOp, fusableOpOperand,
- isExpanding ? expandingReshapeOp.getReassociationMaps()
- : collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), collapsedType.getShape(), rewriter)))
- return std::nullopt;
-
- // TODO: With the support of multiple dynamic dims expansion in
- // tensor.expand_shape op, this case can be handled.
- if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
- return std::nullopt;
-
- if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+ linalgOp, fusableOpOperand, reassociationIndices,
+ expandedShape, rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
- expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
- : collapsingReshapeOp.getSrc());
+ expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOperandShape;
+ RankedTensorType expandedOperandType;
+ std::tie(expandedOperandShape, expandedOperandType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation));
+ loc, expandedOperandType, opOperand->get(), reassociation,
+ expandedOperandShape));
continue;
}
}
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
- RankedTensorType expandedOutputType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOutputShape;
+ RankedTensorType expandedOutputType;
+ std::tie(expandedOutputShape, expandedOutputType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, opOperand.get(), reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation,
+ expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..57904f912a35b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: ...
[truncated]
|
You can test this locally with the following command:git-clang-format --diff 7c24041895bc46dc19634e285a8907c787f8a3f9 09fa4040b54618703c900899c63b9585ab83552c --extensions cpp -- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index b4da6d3d37..935155f356 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -806,9 +806,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}
ExpansionInfo expansionInfo;
- if (failed(expansionInfo.compute(
- linalgOp, fusableOpOperand, reassociationIndices,
- expandedShape, rewriter)))
+ if (failed(expansionInfo.compute(linalgOp, fusableOpOperand,
+ reassociationIndices, expandedShape,
+ rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
|
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
With
tensor.expand_shape
allowing expanding dynamic dimension into multiple dynamic dimension, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimension.