Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225

Merged
merged 8 commits into from
Mar 12, 2024
15 changes: 14 additions & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,20 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

if (!reshape)
return failure();
return rewriter.notifyMatchFailure(
dim, "Dim op is not defined by a reshape op.");

if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
sahas3 marked this conversation as resolved.
Show resolved Hide resolved
if (auto *definingOp = dim.getIndex().getDefiningOp()) {
if (reshape->isBeforeInBlock(definingOp))
return rewriter.notifyMatchFailure(
dim,
"dim.getIndex is not defined before reshape in the same block.");
} // else dim.getIndex is a block argument to reshape->getBlock
} else if (!dim.getIndex().getParentRegion()->isProperAncestor(
sahas3 marked this conversation as resolved.
Show resolved Hide resolved
reshape->getParentRegion()))
return rewriter.notifyMatchFailure(
dim, "dim.getIndex does not dominate reshape.");
sahas3 marked this conversation as resolved.
Show resolved Hide resolved

// Place the load directly after the reshape to ensure that the shape memref
// was not mutated.
Expand Down
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,11 +824,37 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};

/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
/// operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;

LogicalResult matchAndRewrite(DimOp dim,
PatternRewriter &rewriter) const override {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

if (!reshape)
return failure();

// Since tensors are immutable we don't need to worry about where to place
// the extract call
rewriter.setInsertionPointAfter(dim);
Location loc = dim.getLoc();
Value extract =
rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
extract =
rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
rewriter.replaceOp(dim, extract);
return success();
}
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,59 @@ func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>,
// CHECK-SAME: %[[IDX:[0-9a-z]+]]: index
// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
// CHECK-NOT: memref.dim
// CHECK: return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
%dim = memref.dim %reshape, %arg2 : memref<*xf32>
return %dim : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_for( %arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index

%0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
%2 = memref.dim %0, %arg2 : memref<*xf32>
%3 = arith.muli %arg3, %2 : index
scf.yield %3 : index
}
return %1 : index
}

// -----

// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
// CHECK: memref.reshape
// CHECK: memref.dim
// CHECK-NOT: memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = memref.dim %reshape, %0 : memref<*xf32>
return %dim : index
}

// -----

// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
Expand Down
80 changes: 80 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
// CHECK-NOT: tensor.store
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[DIM]] : index
func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
-> index {
%c3 = arith.constant 3 : index
%0 = tensor.reshape %arg0(%arg1)
: (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// Update the shape to test that the load ends up in the right place.
tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
%1 = tensor.dim %0, %c3 : tensor<*xf32>
return %1 : index
}

// -----

// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_i32(
// CHECK: tensor.extract
// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
// CHECK: return %[[CAST]] : index
func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
-> index {
%c3 = arith.constant 3 : index
%0 = tensor.reshape %arg0(%arg1)
: (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
%1 = tensor.dim %0, %c3 : tensor<*xf32>
return %1 : index
}

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_for(
// CHECK: scf.for
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index

%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>

%1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
%2 = tensor.dim %0, %arg2 : tensor<*xf32>
%3 = arith.muli %arg3, %2 : index
scf.yield %3 : index
}
return %1 : index
}

// -----

// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
// CHECK-LABEL: func @dim_of_reshape_undominated(
// CHECK: arith.muli
// CHECK-NEXT: tensor.extract
// CHECK-NOT: tensor.dim
// CHECK-NOT: tensor.reshape
func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
%c4 = arith.constant 4 : index
%reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%0 = arith.muli %arg2, %c4 : index
%dim = tensor.dim %reshape, %0 : tensor<*xf32>
return %dim : index
}
Loading