Skip to content

Commit

Permalink
[mlir][Transforms] Dialect Conversion: Simplify block conversion API (#…
Browse files Browse the repository at this point in the history
…94866)

This commit simplifies and improves documentation for the part of the
`ConversionPatternRewriter` API that deals with signature conversions.

There are now two public functions for signature conversion:
* `applySignatureConversion` converts a single block signature. This
function used to take a `Region *` (but converted only the entry block).
It now takes a `Block *`.
* `convertRegionTypes` converts all block signatures of a region.

`convertNonEntryRegionTypes` is removed because it is not widely used
and can easily be expressed with a call to `applySignatureConversion`
inside a loop. (See `Detensorize.cpp` for an example.)

Note: For consistency, `convertRegionTypes` could be renamed to
`applySignatureConversion` (overload) in the future. (Or
`applySignatureConversion` renamed to `convertBlockTypes`.)

Also clarify when a type converter and/or signature conversion object is
needed and for what purpose.

Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the
part that deals with signature conversions). This part of the codebase
was quite convoluted and unintuitive.

From a functional perspective, this change is NFC. However, the public
API changes, thus not marking as NFC.

Note for LLVM integration: When you see
`applySignatureConversion(region, ...)`, replace with
`applySignatureConversion(region->front(), ...)`. In the unlikely case
that you see `convertNonEntryRegionTypes`, apply the same changes as
this commit did to `Detensorize.cpp`.

---------

Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
  • Loading branch information
matthias-springer and zero9178 authored Jun 10, 2024
1 parent 65310f3 commit 52050f3
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 148 deletions.
30 changes: 17 additions & 13 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,23 @@ class TypeConverter {
From the perspective of type conversion, the types of block arguments are a bit
special. Throughout the conversion process, blocks may move between regions of
different operations. Given this, the conversion of the types for blocks must be
done explicitly via a conversion pattern. To convert the types of block
arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
be invoked; `convertRegionTypes`. This hook uses a provided type converter to
apply type conversions to all blocks within a given region, and all blocks that
move into that region. As noted above, the conversions performed by this method
use the argument materialization hook on the `TypeConverter`. This hook also
takes an optional `TypeConverter::SignatureConversion` parameter that applies a
custom conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
AffineForOp, etc. To convert the signature of just the region entry block, and
not any other blocks within the region, the `applySignatureConversion` hook may
be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
can be built programmatically:
done explicitly via a conversion pattern.

To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
given region. As noted above, the conversions performed by this method use the
argument materialization hook on the `TypeConverter`. This hook also takes an
optional `TypeConverter::SignatureConversion` parameter that applies a custom
conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to the operation, e.g.,
`func::FuncOp`, `AffineForOp`, etc.

To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.

A signature conversion, `TypeConverter::SignatureConversion`, can be built
programmatically:

```c++
class SignatureConversion {
Expand Down
49 changes: 25 additions & 24 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
template <typename TargetType> TargetType convertType(Type t) const {
template <typename TargetType>
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}

Expand Down Expand Up @@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;

/// Apply a signature conversion to the entry block of the given region. This
/// replaces the entry block with a new block containing the updated
/// signature. The new entry block to the region is returned for convenience.
/// If no block argument types are changing, the entry original block will be
/// Apply a signature conversion to given block. This replaces the block with
/// a new block containing the updated signature. The operations of the given
/// block are inlined into the newly-created block, which is returned.
///
/// If no block argument types are changing, the original block will be
/// left in place and returned.
///
/// If provided, `converter` will be used for any materializations.
/// A signature converison must be provided. (Type converters can construct
/// a signature conversion with `convertBlockSignature`.)
///
/// Optionally, a type converter can be provided to build materializations.
/// Note: If no type converter was provided or the type converter does not
/// specify any suitable argument/target materialization rules, the dialect
/// conversion may fail to legalize unresolved materializations.
Block *
applySignatureConversion(Region *region,
applySignatureConversion(Block *block,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter = nullptr);

/// Convert the types of block arguments within the given region. This
/// Apply a signature conversion to each block in the given region. This
/// replaces each block with a new block containing the updated signature. If
/// an updated signature would match the current signature, the respective
/// block is left in place as is.
/// block is left in place as is. (See `applySignatureConversion` for
/// details.) The new entry block of the region is returned.
///
/// SignatureConversions are computed with the specified type converter.
/// This function returns "failure" if the type converter failed to compute
/// a SignatureConversion for at least one block.
///
/// The entry block may have a special conversion if `entryConversion` is
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
/// Optionally, a special SignatureConversion can be specified for the entry
/// block. This is because the types of the entry block arguments are often
/// tied semantically to the operation.
FailureOr<Block *> convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);

/// Convert the types of block arguments within the given region except for
/// the entry region. This replaces each non-entry block with a new block
/// containing the updated signature. If an updated signature would match the
/// current signature, the respective block is left in place as is.
///
/// If special conversion behavior is needed for the non-entry blocks (for
/// example, we need to convert only a subset of a BB arguments), such
/// behavior can be specified in blockConversions.
LogicalResult convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);

/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getRegion(),
body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
signatureConverter);

// Move the blocks from the forOp into the loopOp. This is the body of the
Expand Down
20 changes: 8 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
ConversionPatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
Region &region = op.getFunctionBody();
SmallVector<TypeConverter::SignatureConversion, 2> conversions;

for (Block &block : llvm::drop_begin(region, 1)) {
conversions.emplace_back(block.getNumArguments());
TypeConverter::SignatureConversion &back = conversions.back();
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
TypeConverter::SignatureConversion conversion(
/*numOrigInputs=*/block.getNumArguments());

for (BlockArgument blockArgument : block.getArguments()) {
int idx = blockArgument.getArgNumber();

if (blockArgsToDetensor.count(blockArgument))
back.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
conversion.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
else
back.addInputs(idx, {block.getArgumentTypes()[idx]});
conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
}
}

if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
conversions))) {
rewriter.cancelOpModification(op);
return failure();
rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}

rewriter.finalizeOpModification(op);
Expand Down
123 changes: 27 additions & 96 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//

/// Attempt to convert the signature of the given block, if successful a new
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);

/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);

/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
Expand Down Expand Up @@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
// Type Conversion

FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
return applySignatureConversion(rewriter, block, converter, *conversion);

// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
if (!converter)
return failure();

// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
return *convertBlockSignature(rewriter, &region->front(), converter,
&conversion);
return nullptr;
}

FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
Expand All @@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (region->empty())
return nullptr;

if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();

FailureOr<Block *> newEntry = convertBlockSignature(
rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}

LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
return success();

// Convert the arguments of each block within the region.
int blockIdx = 0;
assert((blockConversions.empty() ||
blockConversions.size() == region->getBlocks().size() - 1) &&
"expected either to provide no SignatureConversions at all or to "
"provide a SignatureConversion for each non-entry block");

// Convert the arguments of each non-entry block within the region.
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
TypeConverter::SignatureConversion *blockConversion =
blockConversions.empty()
? nullptr
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);

if (failed(convertBlockSignature(rewriter, &block, &converter,
blockConversion)))
// Compute the signature for the block with the provided converter.
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&block);
if (!conversion)
return failure();
}
return success();
// Convert the block with the computed signature.
applySignatureConversion(rewriter, &block, &converter, *conversion);
}

// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
return applySignatureConversion(rewriter, &region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
return applySignatureConversion(rewriter, &region->front(), &converter,
*conversion);
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
Expand Down Expand Up @@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
}

Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->applySignatureConversion(*this, region, conversion, converter);
return impl->applySignatureConversion(*this, block, converter, conversion);
}

FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Expand All @@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}

LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertNonEntryRegionTypes(*this, region, converter,
blockConversions);
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
Expand Down Expand Up @@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
impl.applySignatureConversion(rewriter, block, converter, *conversion);
continue;
}

Expand Down
5 changes: 3 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
rewriter.modifyOpInPlace(
op, [&] { rewriter.applySignatureConversion(&region, result); });
rewriter.modifyOpInPlace(op, [&] {
rewriter.applySignatureConversion(&region.front(), result);
});
return success();
}

Expand Down

0 comments on commit 52050f3

Please sign in to comment.