From 52050f3ff388773b9345d421d968a7d1ee880531 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 10 Jun 2024 21:49:52 +0200 Subject: [PATCH] [mlir][Transforms] Dialect Conversion: Simplify block conversion API (#94866) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions. There are now two public functions for signature conversion: * `applySignatureConversion` converts a single block signature. This function used to take a `Region *` (but converted only the entry block). It now takes a `Block *`. * `convertRegionTypes` converts all block signatures of a region. `convertNonEntryRegionTypes` is removed because it is not widely used and can easily be expressed with a call to `applySignatureConversion` inside a loop. (See `Detensorize.cpp` for an example.) Note: For consistency, `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future. (Or `applySignatureConversion` renamed to `convertBlockTypes`.) Also clarify when a type converter and/or signature conversion object is needed and for what purpose. Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive. From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC. Note for LLVM integration: When you see `applySignatureConversion(region, ...)`, replace with `applySignatureConversion(region->front(), ...)`. In the unlikely case that you see `convertNonEntryRegionTypes`, apply the same changes as this commit did to `Detensorize.cpp`. --------- Co-authored-by: Markus Böck --- mlir/docs/DialectConversion.md | 30 +++-- .../mlir/Transforms/DialectConversion.h | 49 +++---- mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 2 +- .../Dialect/Linalg/Transforms/Detensorize.cpp | 20 ++- .../Transforms/Utils/DialectConversion.cpp | 123 ++++-------------- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 +- 6 files changed, 81 insertions(+), 148 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index a355d5a90e4d1b..69781bb868bbf8 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -372,19 +372,23 @@ class TypeConverter { From the perspective of type conversion, the types of block arguments are a bit special. Throughout the conversion process, blocks may move between regions of different operations. Given this, the conversion of the types for blocks must be -done explicitly via a conversion pattern. To convert the types of block -arguments within a Region, a custom hook on the `ConversionPatternRewriter` must -be invoked; `convertRegionTypes`. This hook uses a provided type converter to -apply type conversions to all blocks within a given region, and all blocks that -move into that region. As noted above, the conversions performed by this method -use the argument materialization hook on the `TypeConverter`. This hook also -takes an optional `TypeConverter::SignatureConversion` parameter that applies a -custom conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to details on the operation, e.g. func::FuncOp, -AffineForOp, etc. To convert the signature of just the region entry block, and -not any other blocks within the region, the `applySignatureConversion` hook may -be used instead. A signature conversion, `TypeConverter::SignatureConversion`, -can be built programmatically: +done explicitly via a conversion pattern. + +To convert the types of block arguments within a Region, a custom hook on the +`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook +uses a provided type converter to apply type conversions to all blocks of a +given region. As noted above, the conversions performed by this method use the +argument materialization hook on the `TypeConverter`. This hook also takes an +optional `TypeConverter::SignatureConversion` parameter that applies a custom +conversion to the entry block of the region. The types of the entry block +arguments are often tied semantically to the operation, e.g., +`func::FuncOp`, `AffineForOp`, etc. + +To convert the signature of just one given block, the +`applySignatureConversion` hook can be used. + +A signature conversion, `TypeConverter::SignatureConversion`, can be built +programmatically: ```c++ class SignatureConversion { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f6c51499f271c5..f83f3a3fdf9929 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -247,7 +247,8 @@ class TypeConverter { /// Attempts a 1-1 type conversion, expecting the result type to be /// `TargetType`. Returns the converted type cast to `TargetType` on success, /// and a null type on conversion or cast failure. - template TargetType convertType(Type t) const { + template + TargetType convertType(Type t) const { return dyn_cast_or_null(convertType(t)); } @@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter { public: ~ConversionPatternRewriter() override; - /// Apply a signature conversion to the entry block of the given region. This - /// replaces the entry block with a new block containing the updated - /// signature. The new entry block to the region is returned for convenience. - /// If no block argument types are changing, the entry original block will be + /// Apply a signature conversion to given block. This replaces the block with + /// a new block containing the updated signature. The operations of the given + /// block are inlined into the newly-created block, which is returned. + /// + /// If no block argument types are changing, the original block will be /// left in place and returned. /// - /// If provided, `converter` will be used for any materializations. + /// A signature converison must be provided. (Type converters can construct + /// a signature conversion with `convertBlockSignature`.) + /// + /// Optionally, a type converter can be provided to build materializations. + /// Note: If no type converter was provided or the type converter does not + /// specify any suitable argument/target materialization rules, the dialect + /// conversion may fail to legalize unresolved materializations. Block * - applySignatureConversion(Region *region, + applySignatureConversion(Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter = nullptr); - /// Convert the types of block arguments within the given region. This + /// Apply a signature conversion to each block in the given region. This /// replaces each block with a new block containing the updated signature. If /// an updated signature would match the current signature, the respective - /// block is left in place as is. + /// block is left in place as is. (See `applySignatureConversion` for + /// details.) The new entry block of the region is returned. + /// + /// SignatureConversions are computed with the specified type converter. + /// This function returns "failure" if the type converter failed to compute + /// a SignatureConversion for at least one block. /// - /// The entry block may have a special conversion if `entryConversion` is - /// provided. On success, the new entry block to the region is returned for - /// convenience. Otherwise, failure is returned. + /// Optionally, a special SignatureConversion can be specified for the entry + /// block. This is because the types of the entry block arguments are often + /// tied semantically to the operation. FailureOr convertRegionTypes( Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); - /// Convert the types of block arguments within the given region except for - /// the entry region. This replaces each non-entry block with a new block - /// containing the updated signature. If an updated signature would match the - /// current signature, the respective block is left in place as is. - /// - /// If special conversion behavior is needed for the non-entry blocks (for - /// example, we need to convert only a subset of a BB arguments), such - /// behavior can be specified in blockConversions. - LogicalResult convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef blockConversions); - /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index d90cf931385fcc..f62de1f17a6668 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern { signatureConverter.remapInput(0, newIndVar); for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) signatureConverter.remapInput(i, header->getArgument(i)); - body = rewriter.applySignatureConversion(&forOp.getRegion(), + body = rewriter.applySignatureConversion(&forOp.getRegion().front(), signatureConverter); // Move the blocks from the forOp into the loopOp. This is the body of the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 22968096a68913..af38485291182f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion ConversionPatternRewriter &rewriter) const override { rewriter.startOpModification(op); Region ®ion = op.getFunctionBody(); - SmallVector conversions; - for (Block &block : llvm::drop_begin(region, 1)) { - conversions.emplace_back(block.getNumArguments()); - TypeConverter::SignatureConversion &back = conversions.back(); + for (Block &block : + llvm::make_early_inc_range(llvm::drop_begin(region, 1))) { + TypeConverter::SignatureConversion conversion( + /*numOrigInputs=*/block.getNumArguments()); for (BlockArgument blockArgument : block.getArguments()) { int idx = blockArgument.getArgNumber(); if (blockArgsToDetensor.count(blockArgument)) - back.addInputs(idx, {getTypeConverter()->convertType( - block.getArgumentTypes()[idx])}); + conversion.addInputs(idx, {getTypeConverter()->convertType( + block.getArgumentTypes()[idx])}); else - back.addInputs(idx, {block.getArgumentTypes()[idx]}); + conversion.addInputs(idx, {block.getArgumentTypes()[idx]}); } - } - if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter, - conversions))) { - rewriter.cancelOpModification(op); - return failure(); + rewriter.applySignatureConversion(&block, conversion, getTypeConverter()); } rewriter.finalizeOpModification(op); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d407d60334c70d..2f0efe1b1e454e 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // Type Conversion //===--------------------------------------------------------------------===// - /// Attempt to convert the signature of the given block, if successful a new - /// block is returned containing the new arguments. Returns `block` if it did - /// not require conversion. - FailureOr convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion = nullptr); - - /// Convert the types of non-entry block arguments within the given region. - LogicalResult convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef blockConversions = {}); - - /// Apply a signature conversion on the given region, using `converter` for - /// materializations if not null. - Block * - applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter); - /// Convert the types of block arguments within the given region. FailureOr convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, @@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { //===----------------------------------------------------------------------===// // Type Conversion -FailureOr ConversionPatternRewriterImpl::convertBlockSignature( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, - TypeConverter::SignatureConversion *conversion) { - if (conversion) - return applySignatureConversion(rewriter, block, converter, *conversion); - - // If a converter wasn't provided, and the block wasn't already converted, - // there is nothing we can do. - if (!converter) - return failure(); - - // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter->convertBlockSignature(block)) - return applySignatureConversion(rewriter, block, converter, *conversion); - return failure(); -} - -Block *ConversionPatternRewriterImpl::applySignatureConversion( - ConversionPatternRewriter &rewriter, Region *region, - TypeConverter::SignatureConversion &conversion, - const TypeConverter *converter) { - if (!region->empty()) - return *convertBlockSignature(rewriter, ®ion->front(), converter, - &conversion); - return nullptr; -} - FailureOr ConversionPatternRewriterImpl::convertRegionTypes( ConversionPatternRewriter &rewriter, Region *region, const TypeConverter &converter, @@ -1330,42 +1281,29 @@ FailureOr ConversionPatternRewriterImpl::convertRegionTypes( if (region->empty()) return nullptr; - if (failed(convertNonEntryRegionTypes(rewriter, region, converter))) - return failure(); - - FailureOr newEntry = convertBlockSignature( - rewriter, ®ion->front(), &converter, entryConversion); - return newEntry; -} - -LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, - ArrayRef blockConversions) { - regionToConverter[region] = &converter; - if (region->empty()) - return success(); - - // Convert the arguments of each block within the region. - int blockIdx = 0; - assert((blockConversions.empty() || - blockConversions.size() == region->getBlocks().size() - 1) && - "expected either to provide no SignatureConversions at all or to " - "provide a SignatureConversion for each non-entry block"); - + // Convert the arguments of each non-entry block within the region. for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { - TypeConverter::SignatureConversion *blockConversion = - blockConversions.empty() - ? nullptr - : const_cast( - &blockConversions[blockIdx++]); - - if (failed(convertBlockSignature(rewriter, &block, &converter, - blockConversion))) + // Compute the signature for the block with the provided converter. + std::optional conversion = + converter.convertBlockSignature(&block); + if (!conversion) return failure(); - } - return success(); + // Convert the block with the computed signature. + applySignatureConversion(rewriter, &block, &converter, *conversion); + } + + // Convert the entry block. If an entry signature conversion was provided, + // use that one. Otherwise, compute the signature with the type converter. + if (entryConversion) + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *entryConversion); + std::optional conversion = + converter.convertBlockSignature(®ion->front()); + if (!conversion) + return failure(); + return applySignatureConversion(rewriter, ®ion->front(), &converter, + *conversion); } Block *ConversionPatternRewriterImpl::applySignatureConversion( @@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) { } Block *ConversionPatternRewriter::applySignatureConversion( - Region *region, TypeConverter::SignatureConversion &conversion, + Block *block, TypeConverter::SignatureConversion &conversion, const TypeConverter *converter) { - assert(!impl->wasOpReplaced(region->getParentOp()) && + assert(!impl->wasOpReplaced(block->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(*this, region, conversion, converter); + return impl->applySignatureConversion(*this, block, converter, conversion); } FailureOr ConversionPatternRewriter::convertRegionTypes( @@ -1693,16 +1631,6 @@ FailureOr ConversionPatternRewriter::convertRegionTypes( return impl->convertRegionTypes(*this, region, converter, entryConversion); } -LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( - Region *region, const TypeConverter &converter, - ArrayRef blockConversions) { - assert(!impl->wasOpReplaced(region->getParentOp()) && - "attempting to apply a signature conversion to a block within a " - "replaced/erased op"); - return impl->convertNonEntryRegionTypes(*this, region, converter, - blockConversions); -} - void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { LLVM_DEBUG({ @@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the region of the block has a type converter, try to convert the block // directly. if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - if (failed(impl.convertBlockSignature(rewriter, block, converter))) { + std::optional conversion = + converter->convertBlockSignature(block); + if (!conversion) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); } + impl.applySignatureConversion(rewriter, block, converter, *conversion); continue; } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f9f7d4eacf948a..a14a5da3410980 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter if (failed( converter.convertSignatureArgs(entry->getArgumentTypes(), result))) return failure(); - rewriter.modifyOpInPlace( - op, [&] { rewriter.applySignatureConversion(®ion, result); }); + rewriter.modifyOpInPlace(op, [&] { + rewriter.applySignatureConversion(®ion.front(), result); + }); return success(); }