From 9b50844fd798b5a81afd4aeb44b053d622747a42 Mon Sep 17 00:00:00 2001 From: Vladislav Vinogradov Date: Mon, 2 Aug 2021 16:42:23 +0300 Subject: [PATCH] [mlir] Fix delayed object interfaces registration Store both interfaceID and objectID as key for interface registration callback. Otherwise the implementation allows to register only one external model per one object in the single dialect. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D107274 --- mlir/include/mlir/IR/Dialect.h | 12 ++++-- mlir/lib/IR/Dialect.cpp | 17 +++++--- mlir/unittests/IR/InterfaceAttachmentTest.cpp | 43 ++++++++++++++----- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index e16b3e47a0140..f615819fd16bb 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -17,6 +17,7 @@ #include "mlir/Support/TypeID.h" #include +#include namespace mlir { class DialectAsmParser; @@ -285,7 +286,7 @@ class DialectRegistry { SmallVector, 2> dialectInterfaces; /// Attribute/Operation/Type interfaces. - SmallVector, 2> + SmallVector, 2> objectInterfaces; }; @@ -367,7 +368,8 @@ class DialectRegistry { void addOpInterface() { StringRef opName = OpTy::getOperationName(); StringRef dialectName = opName.split('.').first; - addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(), + addObjectInterface(dialectName, TypeID::get(), + ModelTy::Interface::getInterfaceID(), [](MLIRContext *context) { OpTy::template attachInterface(*context); }); @@ -401,14 +403,16 @@ class DialectRegistry { /// Add an attribute/operation/type interface constructible with the given /// allocation function to the dialect identified by its namespace. - void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID, + void addObjectInterface(StringRef dialectName, TypeID objectID, + TypeID interfaceTypeID, ObjectInterfaceAllocatorFunction allocator); /// Add an external model for an attribute/type interface to the dialect /// identified by its namespace. template void addStorageUserInterface(StringRef dialectName) { - addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(), + addObjectInterface(dialectName, TypeID::get(), + ModelTy::Interface::getInterfaceID(), [](MLIRContext *context) { ObjectTy::template attachInterface(*context); }); diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 4713463124d92..80c8dabe1f3b9 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -58,16 +58,19 @@ void DialectRegistry::addDialectInterface( } void DialectRegistry::addObjectInterface( - StringRef dialectName, TypeID interfaceTypeID, + StringRef dialectName, TypeID objectID, TypeID interfaceTypeID, ObjectInterfaceAllocatorFunction allocator) { assert(allocator && "unexpected null interface allocation function"); + auto it = registry.find(dialectName.str()); assert(it != registry.end() && "adding an interface for an op from an unregistered dialect"); - auto &ifaces = interfaces[it->second.first]; - for (const auto &kvp : ifaces.objectInterfaces) { - if (kvp.first == interfaceTypeID) { + auto dialectID = it->second.first; + auto &ifaces = interfaces[dialectID]; + + for (const auto &info : ifaces.objectInterfaces) { + if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) { LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "] repeated interface object interface registration"); @@ -75,7 +78,7 @@ void DialectRegistry::addObjectInterface( } } - ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator); + ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator); } DialectAllocatorFunctionRef @@ -110,8 +113,8 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { } // Add attribute, operation and type interfaces. - for (const auto &kvp : it->getSecond().objectInterfaces) - kvp.second(dialect->getContext()); + for (const auto &info : it->getSecond().objectInterfaces) + std::get<2>(info)(dialect->getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp index b83e5a0bf2f77..76124707cbfc7 100644 --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -321,15 +321,16 @@ TEST(InterfaceAttachment, Operation) { ASSERT_FALSE(isa(otherModuleOp.getOperation())); } +template struct TestExternalTestOpModel - : public TestExternalOpInterface::ExternalModel { + : public TestExternalOpInterface::ExternalModel< + TestExternalTestOpModel, ConcreteOp> { unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const { return op->getName().getStringRef().size() + arg; } static unsigned getNameLengthPlusArgTwice(unsigned arg) { - return test::OpJ::getOperationName().size() + 2 * arg; + return ConcreteOp::getOperationName().size() + 2 * arg; } }; @@ -337,39 +338,61 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) { DialectRegistry registry; registry.insert(); registry.addOpInterface(); - registry.addOpInterface(); + registry.addOpInterface>(); + registry.addOpInterface>(); // Construct the context directly from a registry. The interfaces are expected // to be readily available on operations. MLIRContext context(registry); context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module); - auto op = + auto opJ = builder.create(builder.getUnknownLoc(), builder.getI32Type()); + auto opH = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + auto opI = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + EXPECT_TRUE(isa(module.getOperation())); - EXPECT_TRUE(isa(op.getOperation())); + EXPECT_TRUE(isa(opJ.getOperation())); + EXPECT_TRUE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); } TEST(InterfaceAttachment, OperationDelayedContextAppend) { DialectRegistry registry; registry.insert(); registry.addOpInterface(); - registry.addOpInterface(); + registry.addOpInterface>(); + registry.addOpInterface>(); // Construct the context, create ops, and only then append the registry. The // interfaces are expected to be available after appending the registry. MLIRContext context; context.loadDialect(); + ModuleOp module = ModuleOp::create(UnknownLoc::get(&context)); OpBuilder builder(module); - auto op = + auto opJ = builder.create(builder.getUnknownLoc(), builder.getI32Type()); + auto opH = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + auto opI = + builder.create(builder.getUnknownLoc(), opJ.getResult()); + EXPECT_FALSE(isa(module.getOperation())); - EXPECT_FALSE(isa(op.getOperation())); + EXPECT_FALSE(isa(opJ.getOperation())); + EXPECT_FALSE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); + context.appendDialectRegistry(registry); + EXPECT_TRUE(isa(module.getOperation())); - EXPECT_TRUE(isa(op.getOperation())); + EXPECT_TRUE(isa(opJ.getOperation())); + EXPECT_TRUE(isa(opH.getOperation())); + EXPECT_FALSE(isa(opI.getOperation())); } } // end namespace