Skip to content

Commit

Permalink
[mlir] Fix delayed object interfaces registration
Browse files Browse the repository at this point in the history
Store both interfaceID and objectID as key for interface registration callback.
Otherwise the implementation allows to register only one external model per one object in the single dialect.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D107274
  • Loading branch information
Vladislav Vinogradov committed Aug 3, 2021
1 parent 4f4f278 commit 9b50844
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 21 deletions.
12 changes: 8 additions & 4 deletions mlir/include/mlir/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Support/TypeID.h"

#include <map>
#include <tuple>

namespace mlir {
class DialectAsmParser;
Expand Down Expand Up @@ -285,7 +286,7 @@ class DialectRegistry {
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
dialectInterfaces;
/// Attribute/Operation/Type interfaces.
SmallVector<std::pair<TypeID, ObjectInterfaceAllocatorFunction>, 2>
SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
objectInterfaces;
};

Expand Down Expand Up @@ -367,7 +368,8 @@ class DialectRegistry {
void addOpInterface() {
StringRef opName = OpTy::getOperationName();
StringRef dialectName = opName.split('.').first;
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
addObjectInterface(dialectName, TypeID::get<OpTy>(),
ModelTy::Interface::getInterfaceID(),
[](MLIRContext *context) {
OpTy::template attachInterface<ModelTy>(*context);
});
Expand Down Expand Up @@ -401,14 +403,16 @@ class DialectRegistry {

/// Add an attribute/operation/type interface constructible with the given
/// allocation function to the dialect identified by its namespace.
void addObjectInterface(StringRef dialectName, TypeID interfaceTypeID,
void addObjectInterface(StringRef dialectName, TypeID objectID,
TypeID interfaceTypeID,
ObjectInterfaceAllocatorFunction allocator);

/// Add an external model for an attribute/type interface to the dialect
/// identified by its namespace.
template <typename ObjectTy, typename ModelTy>
void addStorageUserInterface(StringRef dialectName) {
addObjectInterface(dialectName, ModelTy::Interface::getInterfaceID(),
addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
ModelTy::Interface::getInterfaceID(),
[](MLIRContext *context) {
ObjectTy::template attachInterface<ModelTy>(*context);
});
Expand Down
17 changes: 10 additions & 7 deletions mlir/lib/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,27 @@ void DialectRegistry::addDialectInterface(
}

void DialectRegistry::addObjectInterface(
StringRef dialectName, TypeID interfaceTypeID,
StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
ObjectInterfaceAllocatorFunction allocator) {
assert(allocator && "unexpected null interface allocation function");

auto it = registry.find(dialectName.str());
assert(it != registry.end() &&
"adding an interface for an op from an unregistered dialect");

auto &ifaces = interfaces[it->second.first];
for (const auto &kvp : ifaces.objectInterfaces) {
if (kvp.first == interfaceTypeID) {
auto dialectID = it->second.first;
auto &ifaces = interfaces[dialectID];

for (const auto &info : ifaces.objectInterfaces) {
if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
LLVM_DEBUG(llvm::dbgs()
<< "[" DEBUG_TYPE
"] repeated interface object interface registration");
return;
}
}

ifaces.objectInterfaces.emplace_back(interfaceTypeID, allocator);
ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
}

DialectAllocatorFunctionRef
Expand Down Expand Up @@ -110,8 +113,8 @@ void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
}

// Add attribute, operation and type interfaces.
for (const auto &kvp : it->getSecond().objectInterfaces)
kvp.second(dialect->getContext());
for (const auto &info : it->getSecond().objectInterfaces)
std::get<2>(info)(dialect->getContext());
}

//===----------------------------------------------------------------------===//
Expand Down
43 changes: 33 additions & 10 deletions mlir/unittests/IR/InterfaceAttachmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,55 +321,78 @@ TEST(InterfaceAttachment, Operation) {
ASSERT_FALSE(isa<TestExternalOpInterface>(otherModuleOp.getOperation()));
}

template <class ConcreteOp>
struct TestExternalTestOpModel
: public TestExternalOpInterface::ExternalModel<TestExternalTestOpModel,
test::OpJ> {
: public TestExternalOpInterface::ExternalModel<
TestExternalTestOpModel<ConcreteOp>, ConcreteOp> {
unsigned getNameLengthPlusArg(Operation *op, unsigned arg) const {
return op->getName().getStringRef().size() + arg;
}

static unsigned getNameLengthPlusArgTwice(unsigned arg) {
return test::OpJ::getOperationName().size() + 2 * arg;
return ConcreteOp::getOperationName().size() + 2 * arg;
}
};

TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();

// Construct the context directly from a registry. The interfaces are expected
// to be readily available on operations.
MLIRContext context(registry);
context.loadDialect<test::TestDialect>();

ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
OpBuilder builder(module);
auto op =
auto opJ =
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
auto opH =
builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
auto opI =
builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());

EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
}

TEST(InterfaceAttachment, OperationDelayedContextAppend) {
DialectRegistry registry;
registry.insert<test::TestDialect>();
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel>();
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();

// Construct the context, create ops, and only then append the registry. The
// interfaces are expected to be available after appending the registry.
MLIRContext context;
context.loadDialect<test::TestDialect>();

ModuleOp module = ModuleOp::create(UnknownLoc::get(&context));
OpBuilder builder(module);
auto op =
auto opJ =
builder.create<test::OpJ>(builder.getUnknownLoc(), builder.getI32Type());
auto opH =
builder.create<test::OpH>(builder.getUnknownLoc(), opJ.getResult());
auto opI =
builder.create<test::OpI>(builder.getUnknownLoc(), opJ.getResult());

EXPECT_FALSE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(op.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(opJ.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(opH.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));

context.appendDialectRegistry(registry);

EXPECT_TRUE(isa<TestExternalOpInterface>(module.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(op.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(opJ.getOperation()));
EXPECT_TRUE(isa<TestExternalOpInterface>(opH.getOperation()));
EXPECT_FALSE(isa<TestExternalOpInterface>(opI.getOperation()));
}

} // end namespace

0 comments on commit 9b50844

Please sign in to comment.