From c515c780244e3ecbb1fcfd06b3ad588d8d22c28e Mon Sep 17 00:00:00 2001 From: Matthias Gehre <matthias.gehre@amd.com> Date: Thu, 18 Apr 2024 15:47:08 +0200 Subject: [PATCH] [mlir][Bufferization] castOrReallocMemRefValue: Use BufferizationOptions (#89175) This allows configuring both the op used for allocation and the op used for copying memrefs. It also changes the default behavior because the default allocation in `BufferizationOptions` creates `memref.alloc` with `alignment = 64` where we used to create `memref.alloca` without any alignment before. Fixes ``` // TODO: Use alloc/memcpy callback from BufferizationOptions if called via // BufferizableOpInterface impl of ToMemrefOp. ``` --- .../Dialect/Bufferization/IR/Bufferization.h | 6 ++-- .../Bufferization/IR/BufferizationOps.cpp | 31 +++++++++++-------- .../Bufferization/Transforms/Bufferize.cpp | 8 +++-- .../one-shot-module-bufferize-out-params.mlir | 2 +- .../Transforms/one-shot-module-bufferize.mlir | 2 +- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h index e98b5728b38ef8..6f19dca2e82224 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -53,12 +53,14 @@ void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, /// This function returns `failure()` in case of unsupported casts. E.g., casts /// with differing element types or memory spaces. FailureOr<Value> castOrReallocMemRefValue(OpBuilder &b, Value value, - MemRefType type); + MemRefType type, + const BufferizationOptions &options); /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the /// to_memref op are different, a memref.cast is needed. 
LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, - ToMemrefOp toMemref); + ToMemrefOp toMemref, + const BufferizationOptions &options); /// Add the canonicalization patterns for bufferization.dealloc to the given /// pattern set to make them available to other passes (such as diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index a656c812a59feb..0acb0c24ab313b 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -23,9 +23,9 @@ using namespace mlir::bufferization; // Helper functions //===----------------------------------------------------------------------===// -FailureOr<Value> -mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value, - MemRefType destType) { +FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue( + OpBuilder &b, Value value, MemRefType destType, + const BufferizationOptions &options) { auto srcType = llvm::cast<MemRefType>(value.getType()); // Element type, rank and memory space must match. @@ -73,18 +73,21 @@ mlir::bufferization::castOrReallocMemRefValue, Value size = b.create<memref::DimOp>(loc, value, i); dynamicOperands.push_back(size); } - // TODO: Use alloc/memcpy callback from BufferizationOptions if called via - // BufferizableOpInterface impl of ToMemrefOp. - Value copy = b.create<memref::AllocaOp>(loc, destType, dynamicOperands); - b.create<memref::CopyOp>(loc, value, copy); + + FailureOr<Value> copy = + options.createAlloc(b, loc, destType, dynamicOperands); + if (failed(copy)) + return failure(); + if (failed(options.createMemCpy(b, loc, value, *copy))) + return failure(); return copy; } /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the /// to_memref op are different, a memref.cast is needed. 
-LogicalResult -mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter, - ToMemrefOp toMemref) { +LogicalResult mlir::bufferization::foldToMemrefToTensorPair( + RewriterBase &rewriter, ToMemrefOp toMemref, + const BufferizationOptions &options) { auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>(); if (!memrefToTensor) return failure(); @@ -105,7 +108,7 @@ mlir::bufferization::foldToMemrefToTensorPair, // Ranked memref -> Ranked memref cast. if (rankedSrcType && rankedDestType) { FailureOr<Value> replacement = castOrReallocMemRefValue( - rewriter, memrefToTensor.getMemref(), rankedDestType); + rewriter, memrefToTensor.getMemref(), rankedDestType, options); if (failed(replacement)) return failure(); @@ -795,7 +798,9 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> { LogicalResult matchAndRewrite(ToMemrefOp toMemref, PatternRewriter &rewriter) const final { - return foldToMemrefToTensorPair(rewriter, toMemref); + BufferizationOptions options; + options.bufferAlignment = 0; + return foldToMemrefToTensorPair(rewriter, toMemref, options); } }; @@ -843,7 +848,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, const BufferizationOptions &options) { // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. - (void)foldToMemrefToTensorPair(rewriter, *this); + (void)foldToMemrefToTensorPair(rewriter, *this, options); // Note: The return value of `bufferize` indicates whether there was an error // or not. (And not whether the pattern matched or not.) 
return success(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 32f4e6a0fe8901..7ba347a1f15e47 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -74,8 +74,10 @@ BufferizeTypeConverter::BufferizeTypeConverter() { auto rankedDestType = dyn_cast<MemRefType>(type); if (!rankedDestType) return nullptr; + BufferizationOptions options; + options.bufferAlignment = 0; FailureOr<Value> replacement = - castOrReallocMemRefValue(builder, inputs[0], rankedDestType); + castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options); if (failed(replacement)) return nullptr; return *replacement; @@ -512,8 +514,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op, // Fold all to_memref(to_tensor(x)) pairs. for (Operation *op : toMemrefOps) { rewriter.setInsertionPoint(op); - (void)bufferization::foldToMemrefToTensorPair(rewriter, - cast<ToMemrefOp>(op)); + (void)bufferization::foldToMemrefToTensorPair( + rewriter, cast<ToMemrefOp>(op), options); } // Remove all dead to_tensor ops. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir index de75b288855f94..9cf44c335d551e 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir @@ -84,7 +84,7 @@ func.func @main(%t: tensor<5xf32>) -> (f32, f32) { // Note: This alloc is not needed, but it is inserted before the returned buffer // is promoted to an out param to reconcile mismatching layout maps on return // value and function signature. 
-// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32> +// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<2x5xf32> // CHECK-NO-LAYOUT: memref.copy %[[subview]], %[[alloc2]] // CHECK-NO-LAYOUT: memref.copy %[[alloc2]], %[[r]] diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index 429c9e4dea9e93..0248afb11f1672 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -52,7 +52,7 @@ func.func private @external_func_with_return_val(tensor<4xi32>) -> f32 // CHECK-NO-LAYOUT-MAP-LABEL: func @return_extract_slice(%{{.*}}) -> memref<2x?xf32> // CHECK-NO-LAYOUT-MAP: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<20x10xf32> // CHECK-NO-LAYOUT-MAP: %[[subview:.*]] = memref.subview {{.*}} : memref<20x10xf32> to memref<2x?xf32, strided<[10, 1], offset: ?>> -// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) : memref<2x?xf32> +// CHECK-NO-LAYOUT-MAP: %[[alloc_no_layout:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<2x?xf32> // CHECK-NO-LAYOUT-MAP: memref.copy %[[subview]], %[[alloc_no_layout]] // TODO: %alloc should be deallocated here, but we currently do not dealloc // buffers that are inserted due to to_tensor/to_memref canonicalization (when